LCOV - code coverage report
Current view: top level - sys/classes/fn/impls/sqrt - fnsqrt.c (source / functions) Hit Total Coverage
Test: SLEPc Lines: 156 158 98.7 %
Date: 2024-04-18 01:01:30 Functions: 10 10 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /*
       2             :    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
       3             :    SLEPc - Scalable Library for Eigenvalue Problem Computations
       4             :    Copyright (c) 2002-, Universitat Politecnica de Valencia, Spain
       5             : 
       6             :    This file is part of SLEPc.
       7             :    SLEPc is distributed under a 2-clause BSD license (see LICENSE).
       8             :    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
       9             : */
      10             : /*
      11             :    Square root function  sqrt(x)
      12             : */
      13             : 
      14             : #include <slepc/private/fnimpl.h>      /*I "slepcfn.h" I*/
      15             : #include <slepcblaslapack.h>
      16             : 
      17         952 : static PetscErrorCode FNEvaluateFunction_Sqrt(FN fn,PetscScalar x,PetscScalar *y)
      18             : {
      19         952 :   PetscFunctionBegin;
      20             : #if !defined(PETSC_USE_COMPLEX)
      21         952 :   PetscCheck(x>=0.0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Function not defined in the requested value");
      22             : #endif
      23         952 :   *y = PetscSqrtScalar(x);
      24         952 :   PetscFunctionReturn(PETSC_SUCCESS);
      25             : }
      26             : 
      27          20 : static PetscErrorCode FNEvaluateDerivative_Sqrt(FN fn,PetscScalar x,PetscScalar *y)
      28             : {
      29          20 :   PetscFunctionBegin;
      30          20 :   PetscCheck(x!=0.0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Derivative not defined in the requested value");
      31             : #if !defined(PETSC_USE_COMPLEX)
      32          20 :   PetscCheck(x>0.0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Derivative not defined in the requested value");
      33             : #endif
      34          20 :   *y = 1.0/(2.0*PetscSqrtScalar(x));
      35          20 :   PetscFunctionReturn(PETSC_SUCCESS);
      36             : }
      37             : 
      38          19 : static PetscErrorCode FNEvaluateFunctionMat_Sqrt_Schur(FN fn,Mat A,Mat B)
      39             : {
      40          19 :   PetscBLASInt   n=0;
      41          19 :   PetscScalar    *T;
      42          19 :   PetscInt       m;
      43             : 
      44          19 :   PetscFunctionBegin;
      45          19 :   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
      46          19 :   PetscCall(MatDenseGetArray(B,&T));
      47          19 :   PetscCall(MatGetSize(A,&m,NULL));
      48          19 :   PetscCall(PetscBLASIntCast(m,&n));
      49          19 :   PetscCall(FNSqrtmSchur(fn,n,T,n,PETSC_FALSE));
      50          19 :   PetscCall(MatDenseRestoreArray(B,&T));
      51          19 :   PetscFunctionReturn(PETSC_SUCCESS);
      52             : }
      53             : 
      54          14 : static PetscErrorCode FNEvaluateFunctionMatVec_Sqrt_Schur(FN fn,Mat A,Vec v)
      55             : {
      56          14 :   PetscBLASInt   n=0;
      57          14 :   PetscScalar    *T;
      58          14 :   PetscInt       m;
      59          14 :   Mat            B;
      60             : 
      61          14 :   PetscFunctionBegin;
      62          14 :   PetscCall(FN_AllocateWorkMat(fn,A,&B));
      63          14 :   PetscCall(MatDenseGetArray(B,&T));
      64          14 :   PetscCall(MatGetSize(A,&m,NULL));
      65          14 :   PetscCall(PetscBLASIntCast(m,&n));
      66          14 :   PetscCall(FNSqrtmSchur(fn,n,T,n,PETSC_TRUE));
      67          14 :   PetscCall(MatDenseRestoreArray(B,&T));
      68          14 :   PetscCall(MatGetColumnVector(B,v,0));
      69          14 :   PetscCall(FN_FreeWorkMat(fn,&B));
      70          14 :   PetscFunctionReturn(PETSC_SUCCESS);
      71             : }
      72             : 
      73          12 : static PetscErrorCode FNEvaluateFunctionMat_Sqrt_DBP(FN fn,Mat A,Mat B)
      74             : {
      75          12 :   PetscBLASInt   n=0;
      76          12 :   PetscScalar    *T;
      77          12 :   PetscInt       m;
      78             : 
      79          12 :   PetscFunctionBegin;
      80          12 :   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
      81          12 :   PetscCall(MatDenseGetArray(B,&T));
      82          12 :   PetscCall(MatGetSize(A,&m,NULL));
      83          12 :   PetscCall(PetscBLASIntCast(m,&n));
      84          12 :   PetscCall(FNSqrtmDenmanBeavers(fn,n,T,n,PETSC_FALSE));
      85          12 :   PetscCall(MatDenseRestoreArray(B,&T));
      86          12 :   PetscFunctionReturn(PETSC_SUCCESS);
      87             : }
      88             : 
      89          12 : static PetscErrorCode FNEvaluateFunctionMat_Sqrt_NS(FN fn,Mat A,Mat B)
      90             : {
      91          12 :   PetscBLASInt   n=0;
      92          12 :   PetscScalar    *Ba;
      93          12 :   PetscInt       m;
      94             : 
      95          12 :   PetscFunctionBegin;
      96          12 :   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
      97          12 :   PetscCall(MatDenseGetArray(B,&Ba));
      98          12 :   PetscCall(MatGetSize(A,&m,NULL));
      99          12 :   PetscCall(PetscBLASIntCast(m,&n));
     100          12 :   PetscCall(FNSqrtmNewtonSchulz(fn,n,Ba,n,PETSC_FALSE));
     101          12 :   PetscCall(MatDenseRestoreArray(B,&Ba));
     102          12 :   PetscFunctionReturn(PETSC_SUCCESS);
     103             : }
     104             : 
     105             : #define MAXIT 50
     106             : 
     107             : /*
     108             :    Computes the principal square root of the matrix A using the
     109             :    Sadeghi iteration. A is overwritten with sqrtm(A).
     110             :  */
     111          24 : PetscErrorCode FNSqrtmSadeghi(FN fn,PetscBLASInt n,PetscScalar *A,PetscBLASInt ld)
     112             : {
     113          24 :   PetscScalar    *M,*M2,*G,*X=A,*work,work1,sqrtnrm;
     114          24 :   PetscScalar    szero=0.0,sone=1.0,smfive=-5.0,s1d16=1.0/16.0;
     115          24 :   PetscReal      tol,Mres=0.0,nrm,rwork[1],done=1.0;
     116          24 :   PetscInt       i,it;
     117          24 :   PetscBLASInt   N,*piv=NULL,info,lwork=0,query=-1,one=1,zero=0;
     118          24 :   PetscBool      converged=PETSC_FALSE;
     119          24 :   unsigned int   ftz;
     120             : 
     121          24 :   PetscFunctionBegin;
     122          24 :   N = n*n;
     123          24 :   tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
     124          24 :   PetscCall(SlepcSetFlushToZero(&ftz));
     125             : 
     126             :   /* query work size */
     127          24 :   PetscCallBLAS("LAPACKgetri",LAPACKgetri_(&n,A,&ld,piv,&work1,&query,&info));
     128          24 :   PetscCall(PetscBLASIntCast((PetscInt)PetscRealPart(work1),&lwork));
     129             : 
     130          24 :   PetscCall(PetscMalloc5(N,&M,N,&M2,N,&G,lwork,&work,n,&piv));
     131          24 :   PetscCall(PetscArraycpy(M,A,N));
     132             : 
     133             :   /* scale M */
     134          24 :   nrm = LAPACKlange_("fro",&n,&n,M,&n,rwork);
     135          24 :   if (nrm>1.0) {
     136          24 :     sqrtnrm = PetscSqrtReal(nrm);
     137          24 :     PetscCallBLAS("LAPACKlascl",LAPACKlascl_("G",&zero,&zero,&nrm,&done,&N,&one,M,&N,&info));
     138          24 :     SlepcCheckLapackInfo("lascl",info);
     139          24 :     tol *= nrm;
     140             :   }
     141          24 :   PetscCall(PetscInfo(fn,"||A||_F = %g, new tol: %g\n",(double)nrm,(double)tol));
     142             : 
     143             :   /* X = I */
     144          24 :   PetscCall(PetscArrayzero(X,N));
     145        1344 :   for (i=0;i<n;i++) X[i+i*ld] = 1.0;
     146             : 
     147         124 :   for (it=0;it<MAXIT && !converged;it++) {
     148             : 
     149             :     /* G = (5/16)*I + (1/16)*M*(15*I-5*M+M*M) */
     150         100 :     PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,M,&ld,M,&ld,&szero,M2,&ld));
     151         100 :     PetscCallBLAS("BLASaxpy",BLASaxpy_(&N,&smfive,M,&one,M2,&one));
     152        6140 :     for (i=0;i<n;i++) M2[i+i*ld] += 15.0;
     153         100 :     PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&s1d16,M,&ld,M2,&ld,&szero,G,&ld));
     154        6140 :     for (i=0;i<n;i++) G[i+i*ld] += 5.0/16.0;
     155             : 
     156             :     /* X = X*G */
     157         100 :     PetscCall(PetscArraycpy(M2,X,N));
     158         100 :     PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,M2,&ld,G,&ld,&szero,X,&ld));
     159             : 
     160             :     /* M = M*inv(G*G) */
     161         100 :     PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,G,&ld,G,&ld,&szero,M2,&ld));
     162         100 :     PetscCallBLAS("LAPACKgetrf",LAPACKgetrf_(&n,&n,M2,&ld,piv,&info));
     163         100 :     SlepcCheckLapackInfo("getrf",info);
     164         100 :     PetscCallBLAS("LAPACKgetri",LAPACKgetri_(&n,M2,&ld,piv,work,&lwork,&info));
     165         100 :     SlepcCheckLapackInfo("getri",info);
     166             : 
     167         100 :     PetscCall(PetscArraycpy(G,M,N));
     168         100 :     PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,G,&ld,M2,&ld,&szero,M,&ld));
     169             : 
     170             :     /* check ||I-M|| */
     171         100 :     PetscCall(PetscArraycpy(M2,M,N));
     172        6140 :     for (i=0;i<n;i++) M2[i+i*ld] -= 1.0;
     173         100 :     Mres = LAPACKlange_("fro",&n,&n,M2,&n,rwork);
     174         100 :     PetscCheck(!PetscIsNanReal(Mres),PETSC_COMM_SELF,PETSC_ERR_FP,"The computed norm is not-a-number");
     175         100 :     if (Mres<=tol) converged = PETSC_TRUE;
     176         100 :     PetscCall(PetscInfo(fn,"it: %" PetscInt_FMT " res: %g\n",it,(double)Mres));
     177         100 :     PetscCall(PetscLogFlops(8.0*n*n*n+2.0*n*n+2.0*n*n*n/3.0+4.0*n*n*n/3.0+2.0*n*n*n+2.0*n*n));
     178             :   }
     179             : 
     180          24 :   PetscCheck(Mres<=tol,PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations",MAXIT);
     181             : 
     182             :   /* undo scaling */
     183          24 :   if (nrm>1.0) PetscCallBLAS("BLASscal",BLASscal_(&N,&sqrtnrm,A,&one));
     184             : 
     185          24 :   PetscCall(PetscFree5(M,M2,G,work,piv));
     186          24 :   PetscCall(SlepcResetFlushToZero(&ftz));
     187          24 :   PetscFunctionReturn(PETSC_SUCCESS);
     188             : }
     189             : 
     190             : #if defined(PETSC_HAVE_CUDA)
     191             : #include "../src/sys/classes/fn/impls/cuda/fnutilcuda.h"
     192             : #include <slepccublas.h>
     193             : 
     194             : #if defined(PETSC_HAVE_MAGMA)
     195             : #include <slepcmagma.h>
     196             : 
     197             : /*
     198             :  * Matrix square root by Sadeghi iteration. CUDA version.
     199             :  * Computes the principal square root of the matrix A using the
     200             :  * Sadeghi iteration. A is overwritten with sqrtm(A).
     201             :  */
     202             : PetscErrorCode FNSqrtmSadeghi_CUDAm(FN fn,PetscBLASInt n,PetscScalar *d_A,PetscBLASInt ld)
     203             : {
     204             :   PetscScalar        *d_M,*d_M2,*d_G,*d_work,alpha;
     205             :   const PetscScalar  szero=0.0,sone=1.0,smfive=-5.0,s15=15.0,s1d16=1.0/16.0;
     206             :   PetscReal          tol,Mres=0.0,nrm,sqrtnrm=1.0;
     207             :   PetscInt           it,nb,lwork;
     208             :   PetscBLASInt       *piv,N;
     209             :   const PetscBLASInt one=1;
     210             :   PetscBool          converged=PETSC_FALSE;
     211             :   cublasHandle_t     cublasv2handle;
     212             : 
     213             :   PetscFunctionBegin;
     214             :   PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA)); /* For CUDA event timers */
     215             :   PetscCall(PetscCUBLASGetHandle(&cublasv2handle));
     216             :   PetscCall(SlepcMagmaInit());
     217             :   N = n*n;
     218             :   tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
     219             : 
     220             :   PetscCall(PetscMalloc1(n,&piv));
     221             :   PetscCallCUDA(cudaMalloc((void **)&d_M,sizeof(PetscScalar)*N));
     222             :   PetscCallCUDA(cudaMalloc((void **)&d_M2,sizeof(PetscScalar)*N));
     223             :   PetscCallCUDA(cudaMalloc((void **)&d_G,sizeof(PetscScalar)*N));
     224             : 
     225             :   nb = magma_get_xgetri_nb(n);
     226             :   lwork = nb*n;
     227             :   PetscCallCUDA(cudaMalloc((void **)&d_work,sizeof(PetscScalar)*lwork));
     228             :   PetscCall(PetscLogGpuTimeBegin());
     229             : 
     230             :   /* M = A */
     231             :   PetscCallCUDA(cudaMemcpy(d_M,d_A,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
     232             : 
     233             :   /* scale M */
     234             :   PetscCallCUBLAS(cublasXnrm2(cublasv2handle,N,d_M,one,&nrm));
     235             :   if (nrm>1.0) {
     236             :     sqrtnrm = PetscSqrtReal(nrm);
     237             :     alpha = 1.0/nrm;
     238             :     PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&alpha,d_M,one));
     239             :     tol *= nrm;
     240             :   }
     241             :   PetscCall(PetscInfo(fn,"||A||_F = %g, new tol: %g\n",(double)nrm,(double)tol));
     242             : 
     243             :   /* X = I */
     244             :   PetscCallCUDA(cudaMemset(d_A,0,sizeof(PetscScalar)*N));
     245             :   PetscCall(set_diagonal(n,d_A,ld,sone));
     246             : 
     247             :   for (it=0;it<MAXIT && !converged;it++) {
     248             : 
     249             :     /* G = (5/16)*I + (1/16)*M*(15*I-5*M+M*M) */
     250             :     PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_M,ld,d_M,ld,&szero,d_M2,ld));
     251             :     PetscCallCUBLAS(cublasXaxpy(cublasv2handle,N,&smfive,d_M,one,d_M2,one));
     252             :     PetscCall(shift_diagonal(n,d_M2,ld,s15));
     253             :     PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&s1d16,d_M,ld,d_M2,ld,&szero,d_G,ld));
     254             :     PetscCall(shift_diagonal(n,d_G,ld,5.0/16.0));
     255             : 
     256             :     /* X = X*G */
     257             :     PetscCallCUDA(cudaMemcpy(d_M2,d_A,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
     258             :     PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_M2,ld,d_G,ld,&szero,d_A,ld));
     259             : 
     260             :     /* M = M*inv(G*G) */
     261             :     PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_G,ld,d_G,ld,&szero,d_M2,ld));
     262             :     /* magma */
     263             :     PetscCallMAGMA(magma_xgetrf_gpu,n,n,d_M2,ld,piv);
     264             :     PetscCallMAGMA(magma_xgetri_gpu,n,d_M2,ld,piv,d_work,lwork);
     265             :     /* magma */
     266             :     PetscCallCUDA(cudaMemcpy(d_G,d_M,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
     267             :     PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_G,ld,d_M2,ld,&szero,d_M,ld));
     268             : 
     269             :     /* check ||I-M|| */
     270             :     PetscCallCUDA(cudaMemcpy(d_M2,d_M,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
     271             :     PetscCall(shift_diagonal(n,d_M2,ld,-1.0));
     272             :     PetscCallCUBLAS(cublasXnrm2(cublasv2handle,N,d_M2,one,&Mres));
     273             :     PetscCheck(!PetscIsNanReal(Mres),PETSC_COMM_SELF,PETSC_ERR_FP,"The computed norm is not-a-number");
     274             :     if (Mres<=tol) converged = PETSC_TRUE;
     275             :     PetscCall(PetscInfo(fn,"it: %" PetscInt_FMT " res: %g\n",it,(double)Mres));
     276             :     PetscCall(PetscLogGpuFlops(8.0*n*n*n+2.0*n*n+2.0*n*n*n/3.0+4.0*n*n*n/3.0+2.0*n*n*n+2.0*n*n));
     277             :   }
     278             : 
     279             :   PetscCheck(Mres<=tol,PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations", MAXIT);
     280             : 
     281             :   if (nrm>1.0) {
     282             :     alpha = sqrtnrm;
     283             :     PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&alpha,d_A,one));
     284             :   }
     285             :   PetscCall(PetscLogGpuTimeEnd());
     286             : 
     287             :   PetscCallCUDA(cudaFree(d_M));
     288             :   PetscCallCUDA(cudaFree(d_M2));
     289             :   PetscCallCUDA(cudaFree(d_G));
     290             :   PetscCallCUDA(cudaFree(d_work));
     291             :   PetscCall(PetscFree(piv));
     292             :   PetscFunctionReturn(PETSC_SUCCESS);
     293             : }
     294             : #endif /* PETSC_HAVE_MAGMA */
     295             : #endif /* PETSC_HAVE_CUDA */
     296             : 
     297          12 : static PetscErrorCode FNEvaluateFunctionMat_Sqrt_Sadeghi(FN fn,Mat A,Mat B)
     298             : {
     299          12 :   PetscBLASInt   n=0;
     300          12 :   PetscScalar    *Ba;
     301          12 :   PetscInt       m;
     302             : 
     303          12 :   PetscFunctionBegin;
     304          12 :   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
     305          12 :   PetscCall(MatDenseGetArray(B,&Ba));
     306          12 :   PetscCall(MatGetSize(A,&m,NULL));
     307          12 :   PetscCall(PetscBLASIntCast(m,&n));
     308          12 :   PetscCall(FNSqrtmSadeghi(fn,n,Ba,n));
     309          12 :   PetscCall(MatDenseRestoreArray(B,&Ba));
     310          12 :   PetscFunctionReturn(PETSC_SUCCESS);
     311             : }
     312             : 
     313             : #if defined(PETSC_HAVE_CUDA)
     314             : PetscErrorCode FNEvaluateFunctionMat_Sqrt_NS_CUDA(FN fn,Mat A,Mat B)
     315             : {
     316             :   PetscBLASInt   n=0;
     317             :   PetscScalar    *Ba;
     318             :   PetscInt       m;
     319             : 
     320             :   PetscFunctionBegin;
     321             :   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
     322             :   PetscCall(MatDenseCUDAGetArray(B,&Ba));
     323             :   PetscCall(MatGetSize(A,&m,NULL));
     324             :   PetscCall(PetscBLASIntCast(m,&n));
     325             :   PetscCall(FNSqrtmNewtonSchulz_CUDA(fn,n,Ba,n,PETSC_FALSE));
     326             :   PetscCall(MatDenseCUDARestoreArray(B,&Ba));
     327             :   PetscFunctionReturn(PETSC_SUCCESS);
     328             : }
     329             : 
     330             : #if defined(PETSC_HAVE_MAGMA)
     331             : PetscErrorCode FNEvaluateFunctionMat_Sqrt_DBP_CUDAm(FN fn,Mat A,Mat B)
     332             : {
     333             :   PetscBLASInt   n=0;
     334             :   PetscScalar    *T;
     335             :   PetscInt       m;
     336             : 
     337             :   PetscFunctionBegin;
     338             :   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
     339             :   PetscCall(MatDenseCUDAGetArray(B,&T));
     340             :   PetscCall(MatGetSize(A,&m,NULL));
     341             :   PetscCall(PetscBLASIntCast(m,&n));
     342             :   PetscCall(FNSqrtmDenmanBeavers_CUDAm(fn,n,T,n,PETSC_FALSE));
     343             :   PetscCall(MatDenseCUDARestoreArray(B,&T));
     344             :   PetscFunctionReturn(PETSC_SUCCESS);
     345             : }
     346             : 
     347             : PetscErrorCode FNEvaluateFunctionMat_Sqrt_Sadeghi_CUDAm(FN fn,Mat A,Mat B)
     348             : {
     349             :   PetscBLASInt   n=0;
     350             :   PetscScalar    *Ba;
     351             :   PetscInt       m;
     352             : 
     353             :   PetscFunctionBegin;
     354             :   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
     355             :   PetscCall(MatDenseCUDAGetArray(B,&Ba));
     356             :   PetscCall(MatGetSize(A,&m,NULL));
     357             :   PetscCall(PetscBLASIntCast(m,&n));
     358             :   PetscCall(FNSqrtmSadeghi_CUDAm(fn,n,Ba,n));
     359             :   PetscCall(MatDenseCUDARestoreArray(B,&Ba));
     360             :   PetscFunctionReturn(PETSC_SUCCESS);
     361             : }
     362             : #endif /* PETSC_HAVE_MAGMA */
     363             : #endif /* PETSC_HAVE_CUDA */
     364             : 
     365          15 : static PetscErrorCode FNView_Sqrt(FN fn,PetscViewer viewer)
     366             : {
     367          15 :   PetscBool      isascii;
     368          15 :   char           str[50];
     369          15 :   const char     *methodname[] = {
     370             :                   "Schur method for the square root",
     371             :                   "Denman-Beavers (product form)",
     372             :                   "Newton-Schulz iteration",
     373             :                   "Sadeghi iteration"
     374             :   };
     375          15 :   const int      nmeth=PETSC_STATIC_ARRAY_LENGTH(methodname);
     376             : 
     377          15 :   PetscFunctionBegin;
     378          15 :   PetscCall(PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&isascii));
     379          15 :   if (isascii) {
     380          15 :     if (fn->beta==(PetscScalar)1.0) {
     381           1 :       if (fn->alpha==(PetscScalar)1.0) PetscCall(PetscViewerASCIIPrintf(viewer,"  square root: sqrt(x)\n"));
     382             :       else {
     383           0 :         PetscCall(SlepcSNPrintfScalar(str,sizeof(str),fn->alpha,PETSC_TRUE));
     384           0 :         PetscCall(PetscViewerASCIIPrintf(viewer,"  square root: sqrt(%s*x)\n",str));
     385             :       }
     386             :     } else {
     387          14 :       PetscCall(SlepcSNPrintfScalar(str,sizeof(str),fn->beta,PETSC_TRUE));
     388          14 :       if (fn->alpha==(PetscScalar)1.0) PetscCall(PetscViewerASCIIPrintf(viewer,"  square root: %s*sqrt(x)\n",str));
     389             :       else {
     390          14 :         PetscCall(PetscViewerASCIIPrintf(viewer,"  square root: %s",str));
     391          14 :         PetscCall(PetscViewerASCIIUseTabs(viewer,PETSC_FALSE));
     392          14 :         PetscCall(SlepcSNPrintfScalar(str,sizeof(str),fn->alpha,PETSC_TRUE));
     393          14 :         PetscCall(PetscViewerASCIIPrintf(viewer,"*sqrt(%s*x)\n",str));
     394          14 :         PetscCall(PetscViewerASCIIUseTabs(viewer,PETSC_TRUE));
     395             :       }
     396             :     }
     397          15 :     if (fn->method<nmeth) PetscCall(PetscViewerASCIIPrintf(viewer,"  computing matrix functions with: %s\n",methodname[fn->method]));
     398             :   }
     399          15 :   PetscFunctionReturn(PETSC_SUCCESS);
     400             : }
     401             : 
     402          27 : SLEPC_EXTERN PetscErrorCode FNCreate_Sqrt(FN fn)
     403             : {
     404          27 :   PetscFunctionBegin;
     405          27 :   fn->ops->evaluatefunction          = FNEvaluateFunction_Sqrt;
     406          27 :   fn->ops->evaluatederivative        = FNEvaluateDerivative_Sqrt;
     407          27 :   fn->ops->evaluatefunctionmat[0]    = FNEvaluateFunctionMat_Sqrt_Schur;
     408          27 :   fn->ops->evaluatefunctionmat[1]    = FNEvaluateFunctionMat_Sqrt_DBP;
     409          27 :   fn->ops->evaluatefunctionmat[2]    = FNEvaluateFunctionMat_Sqrt_NS;
     410          27 :   fn->ops->evaluatefunctionmat[3]    = FNEvaluateFunctionMat_Sqrt_Sadeghi;
     411             : #if defined(PETSC_HAVE_CUDA)
     412             :   fn->ops->evaluatefunctionmatcuda[2] = FNEvaluateFunctionMat_Sqrt_NS_CUDA;
     413             : #if defined(PETSC_HAVE_MAGMA)
     414             :   fn->ops->evaluatefunctionmatcuda[1] = FNEvaluateFunctionMat_Sqrt_DBP_CUDAm;
     415             :   fn->ops->evaluatefunctionmatcuda[3] = FNEvaluateFunctionMat_Sqrt_Sadeghi_CUDAm;
     416             : #endif /* PETSC_HAVE_MAGMA */
     417             : #endif /* PETSC_HAVE_CUDA */
     418          27 :   fn->ops->evaluatefunctionmatvec[0] = FNEvaluateFunctionMatVec_Sqrt_Schur;
     419          27 :   fn->ops->view                      = FNView_Sqrt;
     420          27 :   PetscFunctionReturn(PETSC_SUCCESS);
     421             : }

Generated by: LCOV version 1.14