LCOV - code coverage report
Current view: top level - sys/classes/fn/impls/invsqrt - fninvsqrt.c (source / functions) Hit Total Coverage
Test: SLEPc Lines: 133 136 97.8 %
Date: 2024-11-23 00:39:48 Functions: 9 9 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /*
       2             :    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
       3             :    SLEPc - Scalable Library for Eigenvalue Problem Computations
       4             :    Copyright (c) 2002-, Universitat Politecnica de Valencia, Spain
       5             : 
       6             :    This file is part of SLEPc.
       7             :    SLEPc is distributed under a 2-clause BSD license (see LICENSE).
       8             :    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
       9             : */
      10             : /*
      11             :    Inverse square root function  x^(-1/2)
      12             : */
      13             : 
      14             : #include <slepc/private/fnimpl.h>      /*I "slepcfn.h" I*/
      15             : #include <slepcblaslapack.h>
      16             : 
      17          48 : static PetscErrorCode FNEvaluateFunction_Invsqrt(FN fn,PetscScalar x,PetscScalar *y)
      18             : {
      19          48 :   PetscFunctionBegin;
      20          48 :   PetscCheck(x!=0.0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Function not defined in the requested value");
      21             : #if !defined(PETSC_USE_COMPLEX)
      22             :   PetscCheck(x>0.0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Function not defined in the requested value");
      23             : #endif
      24          48 :   *y = 1.0/PetscSqrtScalar(x);
      25          48 :   PetscFunctionReturn(PETSC_SUCCESS);
      26             : }
      27             : 
      28           8 : static PetscErrorCode FNEvaluateDerivative_Invsqrt(FN fn,PetscScalar x,PetscScalar *y)
      29             : {
      30           8 :   PetscFunctionBegin;
      31           8 :   PetscCheck(x!=0.0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Derivative not defined in the requested value");
      32             : #if !defined(PETSC_USE_COMPLEX)
      33             :   PetscCheck(x>0.0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Derivative not defined in the requested value");
      34             : #endif
      35           8 :   *y = -1.0/(2.0*PetscPowScalarReal(x,1.5));
      36           8 :   PetscFunctionReturn(PETSC_SUCCESS);
      37             : }
      38             : 
      39           4 : static PetscErrorCode FNEvaluateFunctionMat_Invsqrt_Schur(FN fn,Mat A,Mat B)
      40             : {
      41           4 :   PetscBLASInt   n=0,ld,*ipiv,info;
      42           4 :   PetscScalar    *Ba,*Wa;
      43           4 :   PetscInt       m;
      44           4 :   Mat            W;
      45             : 
      46           4 :   PetscFunctionBegin;
      47           4 :   PetscCall(FN_AllocateWorkMat(fn,A,&W));
      48           4 :   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
      49           4 :   PetscCall(MatDenseGetArray(B,&Ba));
      50           4 :   PetscCall(MatDenseGetArray(W,&Wa));
      51             :   /* compute B = sqrtm(A) */
      52           4 :   PetscCall(MatGetSize(A,&m,NULL));
      53           4 :   PetscCall(PetscBLASIntCast(m,&n));
      54           4 :   ld = n;
      55           4 :   PetscCall(FNSqrtmSchur(fn,n,Ba,n,PETSC_FALSE));
      56             :   /* compute B = A\B */
      57           4 :   PetscCall(PetscMalloc1(ld,&ipiv));
      58           4 :   PetscCallBLAS("LAPACKgesv",LAPACKgesv_(&n,&n,Wa,&ld,ipiv,Ba,&ld,&info));
      59           4 :   SlepcCheckLapackInfo("gesv",info);
      60           4 :   PetscCall(PetscLogFlops(2.0*n*n*n/3.0+2.0*n*n*n));
      61           4 :   PetscCall(PetscFree(ipiv));
      62           4 :   PetscCall(MatDenseRestoreArray(W,&Wa));
      63           4 :   PetscCall(MatDenseRestoreArray(B,&Ba));
      64           4 :   PetscCall(FN_FreeWorkMat(fn,&W));
      65           4 :   PetscFunctionReturn(PETSC_SUCCESS);
      66             : }
      67             : 
      68           4 : static PetscErrorCode FNEvaluateFunctionMatVec_Invsqrt_Schur(FN fn,Mat A,Vec v)
      69             : {
      70           4 :   PetscBLASInt   n=0,ld,*ipiv,info,one=1;
      71           4 :   PetscScalar    *Ba,*Wa;
      72           4 :   PetscInt       m;
      73           4 :   Mat            B,W;
      74             : 
      75           4 :   PetscFunctionBegin;
      76           4 :   PetscCall(FN_AllocateWorkMat(fn,A,&B));
      77           4 :   PetscCall(FN_AllocateWorkMat(fn,A,&W));
      78           4 :   PetscCall(MatDenseGetArray(B,&Ba));
      79           4 :   PetscCall(MatDenseGetArray(W,&Wa));
      80             :   /* compute B_1 = sqrtm(A)*e_1 */
      81           4 :   PetscCall(MatGetSize(A,&m,NULL));
      82           4 :   PetscCall(PetscBLASIntCast(m,&n));
      83           4 :   ld = n;
      84           4 :   PetscCall(FNSqrtmSchur(fn,n,Ba,n,PETSC_TRUE));
      85             :   /* compute B_1 = A\B_1 */
      86           4 :   PetscCall(PetscMalloc1(ld,&ipiv));
      87           4 :   PetscCallBLAS("LAPACKgesv",LAPACKgesv_(&n,&one,Wa,&ld,ipiv,Ba,&ld,&info));
      88           4 :   SlepcCheckLapackInfo("gesv",info);
      89           4 :   PetscCall(PetscFree(ipiv));
      90           4 :   PetscCall(MatDenseRestoreArray(W,&Wa));
      91           4 :   PetscCall(MatDenseRestoreArray(B,&Ba));
      92           4 :   PetscCall(MatGetColumnVector(B,v,0));
      93           4 :   PetscCall(FN_FreeWorkMat(fn,&W));
      94           4 :   PetscCall(FN_FreeWorkMat(fn,&B));
      95           4 :   PetscFunctionReturn(PETSC_SUCCESS);
      96             : }
      97             : 
      98          12 : static PetscErrorCode FNEvaluateFunctionMat_Invsqrt_DBP(FN fn,Mat A,Mat B)
      99             : {
     100          12 :   PetscBLASInt   n=0;
     101          12 :   PetscScalar    *T;
     102          12 :   PetscInt       m;
     103             : 
     104          12 :   PetscFunctionBegin;
     105          12 :   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
     106          12 :   PetscCall(MatDenseGetArray(B,&T));
     107          12 :   PetscCall(MatGetSize(A,&m,NULL));
     108          12 :   PetscCall(PetscBLASIntCast(m,&n));
     109          12 :   PetscCall(FNSqrtmDenmanBeavers(fn,n,T,n,PETSC_TRUE));
     110          12 :   PetscCall(MatDenseRestoreArray(B,&T));
     111          12 :   PetscFunctionReturn(PETSC_SUCCESS);
     112             : }
     113             : 
     114          12 : static PetscErrorCode FNEvaluateFunctionMat_Invsqrt_NS(FN fn,Mat A,Mat B)
     115             : {
     116          12 :   PetscBLASInt   n=0;
     117          12 :   PetscScalar    *T;
     118          12 :   PetscInt       m;
     119             : 
     120          12 :   PetscFunctionBegin;
     121          12 :   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
     122          12 :   PetscCall(MatDenseGetArray(B,&T));
     123          12 :   PetscCall(MatGetSize(A,&m,NULL));
     124          12 :   PetscCall(PetscBLASIntCast(m,&n));
     125          12 :   PetscCall(FNSqrtmNewtonSchulz(fn,n,T,n,PETSC_TRUE));
     126          12 :   PetscCall(MatDenseRestoreArray(B,&T));
     127          12 :   PetscFunctionReturn(PETSC_SUCCESS);
     128             : }
     129             : 
     130          12 : static PetscErrorCode FNEvaluateFunctionMat_Invsqrt_Sadeghi(FN fn,Mat A,Mat B)
     131             : {
     132          12 :   PetscBLASInt   n=0,ld,*ipiv,info;
     133          12 :   PetscScalar    *Ba,*Wa;
     134          12 :   PetscInt       m;
     135          12 :   Mat            W;
     136             : 
     137          12 :   PetscFunctionBegin;
     138          12 :   PetscCall(FN_AllocateWorkMat(fn,A,&W));
     139          12 :   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
     140          12 :   PetscCall(MatDenseGetArray(B,&Ba));
     141          12 :   PetscCall(MatDenseGetArray(W,&Wa));
     142             :   /* compute B = sqrtm(A) */
     143          12 :   PetscCall(MatGetSize(A,&m,NULL));
     144          12 :   PetscCall(PetscBLASIntCast(m,&n));
     145          12 :   ld = n;
     146          12 :   PetscCall(FNSqrtmSadeghi(fn,n,Ba,n));
     147             :   /* compute B = A\B */
     148          12 :   PetscCall(PetscMalloc1(ld,&ipiv));
     149          12 :   PetscCallBLAS("LAPACKgesv",LAPACKgesv_(&n,&n,Wa,&ld,ipiv,Ba,&ld,&info));
     150          12 :   SlepcCheckLapackInfo("gesv",info);
     151          12 :   PetscCall(PetscLogFlops(2.0*n*n*n/3.0+2.0*n*n*n));
     152          12 :   PetscCall(PetscFree(ipiv));
     153          12 :   PetscCall(MatDenseRestoreArray(W,&Wa));
     154          12 :   PetscCall(MatDenseRestoreArray(B,&Ba));
     155          12 :   PetscCall(FN_FreeWorkMat(fn,&W));
     156          12 :   PetscFunctionReturn(PETSC_SUCCESS);
     157             : }
     158             : 
     159             : #if defined(PETSC_HAVE_CUDA)
     160             : PetscErrorCode FNEvaluateFunctionMat_Invsqrt_NS_CUDA(FN fn,Mat A,Mat B)
     161             : {
     162             :   PetscBLASInt   n=0;
     163             :   PetscScalar    *Ba;
     164             :   PetscInt       m;
     165             : 
     166             :   PetscFunctionBegin;
     167             :   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
     168             :   PetscCall(MatDenseCUDAGetArray(B,&Ba));
     169             :   PetscCall(MatGetSize(A,&m,NULL));
     170             :   PetscCall(PetscBLASIntCast(m,&n));
     171             :   PetscCall(FNSqrtmNewtonSchulz_CUDA(fn,n,Ba,n,PETSC_TRUE));
     172             :   PetscCall(MatDenseCUDARestoreArray(B,&Ba));
     173             :   PetscFunctionReturn(PETSC_SUCCESS);
     174             : }
     175             : 
     176             : #if defined(PETSC_HAVE_MAGMA)
     177             : #include <slepcmagma.h>
     178             : 
     179             : PetscErrorCode FNEvaluateFunctionMat_Invsqrt_DBP_CUDAm(FN fn,Mat A,Mat B)
     180             : {
     181             :   PetscBLASInt   n=0;
     182             :   PetscScalar    *T;
     183             :   PetscInt       m;
     184             : 
     185             :   PetscFunctionBegin;
     186             :   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
     187             :   PetscCall(MatDenseCUDAGetArray(B,&T));
     188             :   PetscCall(MatGetSize(A,&m,NULL));
     189             :   PetscCall(PetscBLASIntCast(m,&n));
     190             :   PetscCall(FNSqrtmDenmanBeavers_CUDAm(fn,n,T,n,PETSC_TRUE));
     191             :   PetscCall(MatDenseCUDARestoreArray(B,&T));
     192             :   PetscFunctionReturn(PETSC_SUCCESS);
     193             : }
     194             : 
     195             : PetscErrorCode FNEvaluateFunctionMat_Invsqrt_Sadeghi_CUDAm(FN fn,Mat A,Mat B)
     196             : {
     197             :   PetscBLASInt   n=0,ld,*ipiv;
     198             :   PetscScalar    *Ba,*Wa;
     199             :   PetscInt       m;
     200             :   Mat            W;
     201             : 
     202             :   PetscFunctionBegin;
     203             :   PetscCall(FN_AllocateWorkMat(fn,A,&W));
     204             :   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
     205             :   PetscCall(MatDenseCUDAGetArray(B,&Ba));
     206             :   PetscCall(MatDenseCUDAGetArray(W,&Wa));
     207             :   /* compute B = sqrtm(A) */
     208             :   PetscCall(MatGetSize(A,&m,NULL));
     209             :   PetscCall(PetscBLASIntCast(m,&n));
     210             :   ld = n;
     211             :   PetscCall(FNSqrtmSadeghi_CUDAm(fn,n,Ba,n));
     212             :   /* compute B = A\B */
     213             :   PetscCall(SlepcMagmaInit());
     214             :   PetscCall(PetscMalloc1(ld,&ipiv));
     215             :   PetscCallMAGMA(magma_xgesv_gpu,n,n,Wa,ld,ipiv,Ba,ld);
     216             :   PetscCall(PetscLogFlops(2.0*n*n*n/3.0+2.0*n*n*n));
     217             :   PetscCall(PetscFree(ipiv));
     218             :   PetscCall(MatDenseCUDARestoreArray(W,&Wa));
     219             :   PetscCall(MatDenseCUDARestoreArray(B,&Ba));
     220             :   PetscCall(FN_FreeWorkMat(fn,&W));
     221             :   PetscFunctionReturn(PETSC_SUCCESS);
     222             : }
     223             : #endif /* PETSC_HAVE_MAGMA */
     224             : #endif /* PETSC_HAVE_CUDA */
     225             : 
     226           8 : static PetscErrorCode FNView_Invsqrt(FN fn,PetscViewer viewer)
     227             : {
     228           8 :   PetscBool      isascii;
     229           8 :   char           str[50];
     230           8 :   const char     *methodname[] = {
     231             :                   "Schur method for inv(A)*sqrtm(A)",
     232             :                   "Denman-Beavers (product form)",
     233             :                   "Newton-Schulz iteration",
     234             :                   "Sadeghi iteration"
     235             :   };
     236           8 :   const int      nmeth=PETSC_STATIC_ARRAY_LENGTH(methodname);
     237             : 
     238           8 :   PetscFunctionBegin;
     239           8 :   PetscCall(PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&isascii));
     240           8 :   if (isascii) {
     241           8 :     if (fn->beta==(PetscScalar)1.0) {
     242           0 :       if (fn->alpha==(PetscScalar)1.0) PetscCall(PetscViewerASCIIPrintf(viewer,"  inverse square root: x^(-1/2)\n"));
     243             :       else {
     244           0 :         PetscCall(SlepcSNPrintfScalar(str,sizeof(str),fn->alpha,PETSC_TRUE));
     245           0 :         PetscCall(PetscViewerASCIIPrintf(viewer,"  inverse square root: (%s*x)^(-1/2)\n",str));
     246             :       }
     247             :     } else {
     248           8 :       PetscCall(SlepcSNPrintfScalar(str,sizeof(str),fn->beta,PETSC_TRUE));
     249           8 :       if (fn->alpha==(PetscScalar)1.0) PetscCall(PetscViewerASCIIPrintf(viewer,"  inverse square root: %s*x^(-1/2)\n",str));
     250             :       else {
     251           8 :         PetscCall(PetscViewerASCIIPrintf(viewer,"  inverse square root: %s",str));
     252           8 :         PetscCall(PetscViewerASCIIUseTabs(viewer,PETSC_FALSE));
     253           8 :         PetscCall(SlepcSNPrintfScalar(str,sizeof(str),fn->alpha,PETSC_TRUE));
     254           8 :         PetscCall(PetscViewerASCIIPrintf(viewer,"*(%s*x)^(-1/2)\n",str));
     255           8 :         PetscCall(PetscViewerASCIIUseTabs(viewer,PETSC_TRUE));
     256             :       }
     257             :     }
     258           8 :     if (fn->method<nmeth) PetscCall(PetscViewerASCIIPrintf(viewer,"  computing matrix functions with: %s\n",methodname[fn->method]));
     259             :   }
     260           8 :   PetscFunctionReturn(PETSC_SUCCESS);
     261             : }
     262             : 
     263           8 : SLEPC_EXTERN PetscErrorCode FNCreate_Invsqrt(FN fn)
     264             : {
     265           8 :   PetscFunctionBegin;
     266           8 :   fn->ops->evaluatefunction          = FNEvaluateFunction_Invsqrt;
     267           8 :   fn->ops->evaluatederivative        = FNEvaluateDerivative_Invsqrt;
     268           8 :   fn->ops->evaluatefunctionmat[0]    = FNEvaluateFunctionMat_Invsqrt_Schur;
     269           8 :   fn->ops->evaluatefunctionmat[1]    = FNEvaluateFunctionMat_Invsqrt_DBP;
     270           8 :   fn->ops->evaluatefunctionmat[2]    = FNEvaluateFunctionMat_Invsqrt_NS;
     271           8 :   fn->ops->evaluatefunctionmat[3]    = FNEvaluateFunctionMat_Invsqrt_Sadeghi;
     272             : #if defined(PETSC_HAVE_CUDA)
     273             :   fn->ops->evaluatefunctionmatcuda[2] = FNEvaluateFunctionMat_Invsqrt_NS_CUDA;
     274             : #if defined(PETSC_HAVE_MAGMA)
     275             :   fn->ops->evaluatefunctionmatcuda[1] = FNEvaluateFunctionMat_Invsqrt_DBP_CUDAm;
     276             :   fn->ops->evaluatefunctionmatcuda[3] = FNEvaluateFunctionMat_Invsqrt_Sadeghi_CUDAm;
     277             : #endif /* PETSC_HAVE_MAGMA */
     278             : #endif /* PETSC_HAVE_CUDA */
     279           8 :   fn->ops->evaluatefunctionmatvec[0] = FNEvaluateFunctionMatVec_Invsqrt_Schur;
     280           8 :   fn->ops->view                      = FNView_Invsqrt;
     281           8 :   PetscFunctionReturn(PETSC_SUCCESS);
     282             : }

Generated by: LCOV version 1.14