LCOV - code coverage report
Current view: top level - sys/classes/fn/impls - fnutil.c (source / functions) Hit Total Coverage
Test: SLEPc Lines: 262 265 98.9 %
Date: 2024-12-18 00:42:09 Functions: 6 6 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /*
       2             :    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
       3             :    SLEPc - Scalable Library for Eigenvalue Problem Computations
       4             :    Copyright (c) 2002-, Universitat Politecnica de Valencia, Spain
       5             : 
       6             :    This file is part of SLEPc.
       7             :    SLEPc is distributed under a 2-clause BSD license (see LICENSE).
       8             :    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
       9             : */
      10             : /*
      11             :    Utility subroutines common to several impls
      12             : */
      13             : 
      14             : #include <slepc/private/fnimpl.h>      /*I "slepcfn.h" I*/
      15             : #include <slepcblaslapack.h>
      16             : 
      17             : /*
      18             :    Compute the square root of an upper quasi-triangular matrix T,
      19             :    using Higham's algorithm (LAA 88, 1987). T is overwritten with sqrtm(T).
      20             :  */
      21          98 : static PetscErrorCode SlepcMatDenseSqrt(PetscBLASInt n,PetscScalar *T,PetscBLASInt ld)
      22             : {
      23          98 :   PetscScalar  one=1.0,mone=-1.0;
      24          98 :   PetscReal    scal;
      25          98 :   PetscBLASInt i,j,si,sj,r,ione=1,info;
      26             : #if !defined(PETSC_USE_COMPLEX)
      27          98 :   PetscReal    alpha,theta,mu,mu2;
      28             : #endif
      29             : 
      30          98 :   PetscFunctionBegin;
      31        3323 :   for (j=0;j<n;j++) {
      32             : #if defined(PETSC_USE_COMPLEX)
      33             :     sj = 1;
      34             :     T[j+j*ld] = PetscSqrtScalar(T[j+j*ld]);
      35             : #else
      36        3225 :     sj = (j==n-1 || T[j+1+j*ld] == 0.0)? 1: 2;
      37        3225 :     if (sj==1) {
      38        2721 :       PetscCheck(T[j+j*ld]>=0.0,PETSC_COMM_SELF,PETSC_ERR_USER_INPUT,"Matrix has a real negative eigenvalue, no real primary square root exists");
      39        2721 :       T[j+j*ld] = PetscSqrtReal(T[j+j*ld]);
      40             :     } else {
      41             :       /* square root of 2x2 block */
      42         504 :       theta = (T[j+j*ld]+T[j+1+(j+1)*ld])/2.0;
      43         504 :       mu    = (T[j+j*ld]-T[j+1+(j+1)*ld])/2.0;
      44         504 :       mu2   = -mu*mu-T[j+1+j*ld]*T[j+(j+1)*ld];
      45         504 :       mu    = PetscSqrtReal(mu2);
      46         504 :       if (theta>0.0) alpha = PetscSqrtReal((theta+PetscSqrtReal(theta*theta+mu2))/2.0);
      47           0 :       else alpha = mu/PetscSqrtReal(2.0*(-theta+PetscSqrtReal(theta*theta+mu2)));
      48         504 :       T[j+j*ld]       /= 2.0*alpha;
      49         504 :       T[j+1+(j+1)*ld] /= 2.0*alpha;
      50         504 :       T[j+(j+1)*ld]   /= 2.0*alpha;
      51         504 :       T[j+1+j*ld]     /= 2.0*alpha;
      52         504 :       T[j+j*ld]       += alpha-theta/(2.0*alpha);
      53         504 :       T[j+1+(j+1)*ld] += alpha-theta/(2.0*alpha);
      54             :     }
      55             : #endif
      56       78753 :     for (i=j-1;i>=0;i--) {
      57             : #if defined(PETSC_USE_COMPLEX)
      58             :       si = 1;
      59             : #else
      60       75528 :       si = (i==0 || T[i+(i-1)*ld] == 0.0)? 1: 2;
      61       75528 :       if (si==2) i--;
      62             : #endif
      63             :       /* solve Sylvester equation of order si x sj */
      64       75528 :       r = j-i-si;
      65       75528 :       if (r) PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&si,&sj,&r,&mone,T+i+(i+si)*ld,&ld,T+i+si+j*ld,&ld,&one,T+i+j*ld,&ld));
      66       75528 :       PetscCallBLAS("LAPACKtrsyl",LAPACKtrsyl_("N","N",&ione,&si,&sj,T+i+i*ld,&ld,T+j+j*ld,&ld,T+i+j*ld,&ld,&scal,&info));
      67       75528 :       SlepcCheckLapackInfo("trsyl",info);
      68       75528 :       PetscCheck(scal==1.0,PETSC_COMM_SELF,PETSC_ERR_SUP,"Current implementation cannot handle scale factor %g",(double)scal);
      69             :     }
      70        3225 :     if (sj==2) j++;
      71             :   }
      72          98 :   PetscFunctionReturn(PETSC_SUCCESS);
      73             : }
      74             : 
      75             : #define BLOCKSIZE 64
      76             : 
      77             : /*
      78             :    Schur method for the square root of an upper quasi-triangular matrix T.
      79             :    T is overwritten with sqrtm(T).
      80             :    If firstonly then only the first column of T will contain relevant values.
      81             :  */
      82          59 : PetscErrorCode FNSqrtmSchur(FN fn,PetscBLASInt n,PetscScalar *T,PetscBLASInt ld,PetscBool firstonly)
      83             : {
      84          59 :   PetscBLASInt   i,j,k,r,ione=1,sdim,lwork,*s,*p,info,bs=BLOCKSIZE;
      85          59 :   PetscScalar    *wr,*W,*Q,*work,one=1.0,zero=0.0,mone=-1.0;
      86          59 :   PetscInt       m,nblk;
      87          59 :   PetscReal      scal;
      88             : #if defined(PETSC_USE_COMPLEX)
      89             :   PetscReal      *rwork;
      90             : #else
      91          59 :   PetscReal      *wi;
      92             : #endif
      93             : 
      94          59 :   PetscFunctionBegin;
      95          59 :   m     = n;
      96          59 :   nblk  = (m+bs-1)/bs;
      97          59 :   lwork = 5*n;
      98          59 :   k     = firstonly? 1: n;
      99             : 
     100             :   /* compute Schur decomposition A*Q = Q*T */
     101             : #if !defined(PETSC_USE_COMPLEX)
     102          59 :   PetscCall(PetscMalloc7(m,&wr,m,&wi,m*k,&W,m*m,&Q,lwork,&work,nblk,&s,nblk,&p));
     103          59 :   PetscCallBLAS("LAPACKgees",LAPACKgees_("V","N",NULL,&n,T,&ld,&sdim,wr,wi,Q,&ld,work,&lwork,NULL,&info));
     104             : #else
     105             :   PetscCall(PetscMalloc7(m,&wr,m,&rwork,m*k,&W,m*m,&Q,lwork,&work,nblk,&s,nblk,&p));
     106             :   PetscCallBLAS("LAPACKgees",LAPACKgees_("V","N",NULL,&n,T,&ld,&sdim,wr,Q,&ld,work,&lwork,rwork,NULL,&info));
     107             : #endif
     108          59 :   SlepcCheckLapackInfo("gees",info);
     109             : 
     110             :   /* determine block sizes and positions, to avoid cutting 2x2 blocks */
     111          59 :   j = 0;
     112          59 :   p[j] = 0;
     113          98 :   do {
     114          98 :     s[j] = PetscMin(bs,n-p[j]);
     115             : #if !defined(PETSC_USE_COMPLEX)
     116          98 :     if (p[j]+s[j]!=n && T[p[j]+s[j]+(p[j]+s[j]-1)*ld]!=0.0) s[j]++;
     117             : #endif
     118          98 :     if (p[j]+s[j]==n) break;
     119          39 :     j++;
     120          39 :     p[j] = p[j-1]+s[j-1];
     121          39 :   } while (1);
     122         157 :   nblk = j+1;
     123             : 
     124         157 :   for (j=0;j<nblk;j++) {
     125             :     /* evaluate f(T_jj) */
     126          98 :     PetscCall(SlepcMatDenseSqrt(s[j],T+p[j]+p[j]*ld,ld));
     127         137 :     for (i=j-1;i>=0;i--) {
     128             :       /* solve Sylvester equation for block (i,j) */
     129          39 :       r = p[j]-p[i]-s[i];
     130          39 :       if (r) PetscCallBLAS("BLASgemm",BLASgemm_("N","N",s+i,s+j,&r,&mone,T+p[i]+(p[i]+s[i])*ld,&ld,T+p[i]+s[i]+p[j]*ld,&ld,&one,T+p[i]+p[j]*ld,&ld));
     131          39 :       PetscCallBLAS("LAPACKtrsyl",LAPACKtrsyl_("N","N",&ione,s+i,s+j,T+p[i]+p[i]*ld,&ld,T+p[j]+p[j]*ld,&ld,T+p[i]+p[j]*ld,&ld,&scal,&info));
     132          39 :       SlepcCheckLapackInfo("trsyl",info);
     133          39 :       PetscCheck(scal==1.0,PETSC_COMM_SELF,PETSC_ERR_SUP,"Current implementation cannot handle scale factor %g",(double)scal);
     134             :     }
     135             :   }
     136             : 
     137             :   /* backtransform B = Q*T*Q' */
     138          59 :   PetscCallBLAS("BLASgemm",BLASgemm_("N","C",&n,&k,&n,&one,T,&ld,Q,&ld,&zero,W,&ld));
     139          59 :   PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&k,&n,&one,Q,&ld,W,&ld,&zero,T,&ld));
     140             : 
     141             :   /* flop count: Schur decomposition, triangular square root, and backtransform */
     142          59 :   PetscCall(PetscLogFlops(25.0*n*n*n+n*n*n/3.0+4.0*n*n*k));
     143             : 
     144             : #if !defined(PETSC_USE_COMPLEX)
     145          59 :   PetscCall(PetscFree7(wr,wi,W,Q,work,s,p));
     146             : #else
     147             :   PetscCall(PetscFree7(wr,rwork,W,Q,work,s,p));
     148             : #endif
     149          59 :   PetscFunctionReturn(PETSC_SUCCESS);
     150             : }
     151             : 
     152             : #define DBMAXIT 25
     153             : 
     154             : /*
     155             :    Computes the principal square root of the matrix T using the product form
     156             :    of the Denman-Beavers iteration.
     157             :    T is overwritten with sqrtm(T) or inv(sqrtm(T)) depending on flag inv.
     158             :  */
     159          24 : PetscErrorCode FNSqrtmDenmanBeavers(FN fn,PetscBLASInt n,PetscScalar *T,PetscBLASInt ld,PetscBool inv)
     160             : {
     161          24 :   PetscScalar        *Told,*M=NULL,*invM,*work,work1,prod,alpha;
     162          24 :   PetscScalar        szero=0.0,sone=1.0,smone=-1.0,spfive=0.5,sp25=0.25;
     163          24 :   PetscReal          tol,Mres=0.0,detM,g,reldiff,fnormdiff,fnormT,rwork[1];
     164          24 :   PetscBLASInt       N,i,it,*piv=NULL,info,query=-1,lwork;
     165          24 :   const PetscBLASInt one=1;
     166          24 :   PetscBool          converged=PETSC_FALSE,scale;
     167          24 :   unsigned int       ftz;
     168             : 
     169          24 :   PetscFunctionBegin;
     170          24 :   N = n*n;
     171          24 :   tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
     172          24 :   scale = PetscDefined(USE_REAL_SINGLE)? PETSC_FALSE: PETSC_TRUE;
     173          24 :   PetscCall(SlepcSetFlushToZero(&ftz));
     174             : 
     175             :   /* query work size */
     176          24 :   PetscCallBLAS("LAPACKgetri",LAPACKgetri_(&n,M,&ld,piv,&work1,&query,&info));
     177          24 :   PetscCall(PetscBLASIntCast((PetscInt)PetscRealPart(work1),&lwork));
     178          24 :   PetscCall(PetscMalloc5(lwork,&work,n,&piv,n*n,&Told,n*n,&M,n*n,&invM));
     179          24 :   PetscCall(PetscArraycpy(M,T,n*n));
     180             : 
     181          24 :   if (inv) {  /* start recurrence with I instead of A */
     182          12 :     PetscCall(PetscArrayzero(T,n*n));
     183         132 :     for (i=0;i<n;i++) T[i+i*ld] += 1.0;
     184             :   }
     185             : 
     186         144 :   for (it=0;it<DBMAXIT && !converged;it++) {
     187             : 
     188         120 :     if (scale) {  /* g = (abs(det(M)))^(-1/(2*n)) */
     189          72 :       PetscCall(PetscArraycpy(invM,M,n*n));
     190          72 :       PetscCallBLAS("LAPACKgetrf",LAPACKgetrf_(&n,&n,invM,&ld,piv,&info));
     191          72 :       SlepcCheckLapackInfo("getrf",info);
     192          72 :       prod = invM[0];
     193        3960 :       for (i=1;i<n;i++) prod *= invM[i+i*ld];
     194          72 :       detM = PetscAbsScalar(prod);
     195          72 :       g = (detM>PETSC_MAX_REAL)? 0.5: PetscPowReal(detM,-1.0/(2.0*n));
     196          72 :       alpha = g;
     197          72 :       PetscCallBLAS("BLASscal",BLASscal_(&N,&alpha,T,&one));
     198          72 :       alpha = g*g;
     199          72 :       PetscCallBLAS("BLASscal",BLASscal_(&N,&alpha,M,&one));
     200          72 :       PetscCall(PetscLogFlops(2.0*n*n*n/3.0+2.0*n*n));
     201             :     }
     202             : 
     203         120 :     PetscCall(PetscArraycpy(Told,T,n*n));
     204         120 :     PetscCall(PetscArraycpy(invM,M,n*n));
     205             : 
     206         120 :     PetscCallBLAS("LAPACKgetrf",LAPACKgetrf_(&n,&n,invM,&ld,piv,&info));
     207         120 :     SlepcCheckLapackInfo("getrf",info);
     208         120 :     PetscCallBLAS("LAPACKgetri",LAPACKgetri_(&n,invM,&ld,piv,work,&lwork,&info));
     209         120 :     SlepcCheckLapackInfo("getri",info);
     210         120 :     PetscCall(PetscLogFlops(2.0*n*n*n/3.0+4.0*n*n*n/3.0));
     211             : 
     212        7080 :     for (i=0;i<n;i++) invM[i+i*ld] += 1.0;
     213         120 :     PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&spfive,Told,&ld,invM,&ld,&szero,T,&ld));
     214        7080 :     for (i=0;i<n;i++) invM[i+i*ld] -= 1.0;
     215             : 
     216         120 :     PetscCallBLAS("BLASaxpy",BLASaxpy_(&N,&sone,invM,&one,M,&one));
     217         120 :     PetscCallBLAS("BLASscal",BLASscal_(&N,&sp25,M,&one));
     218        7080 :     for (i=0;i<n;i++) M[i+i*ld] -= 0.5;
     219         120 :     PetscCall(PetscLogFlops(2.0*n*n*n+2.0*n*n));
     220             : 
     221         120 :     Mres = LAPACKlange_("F",&n,&n,M,&n,rwork);
     222        7200 :     for (i=0;i<n;i++) M[i+i*ld] += 1.0;
     223             : 
     224         120 :     if (scale) {
     225             :       /* reldiff = norm(T - Told,'fro')/norm(T,'fro') */
     226          72 :       PetscCallBLAS("BLASaxpy",BLASaxpy_(&N,&smone,T,&one,Told,&one));
     227          72 :       fnormdiff = LAPACKlange_("F",&n,&n,Told,&n,rwork);
     228          72 :       fnormT = LAPACKlange_("F",&n,&n,T,&n,rwork);
     229          72 :       PetscCall(PetscLogFlops(7.0*n*n));
     230          72 :       reldiff = fnormdiff/fnormT;
     231          72 :       PetscCall(PetscInfo(fn,"it: %" PetscBLASInt_FMT " reldiff: %g scale: %g tol*scale: %g\n",it,(double)reldiff,(double)g,(double)(tol*g)));
     232          72 :       if (reldiff<1e-2) scale = PETSC_FALSE;  /* Switch off scaling */
     233             :     }
     234             : 
     235         120 :     if (Mres<=tol) converged = PETSC_TRUE;
     236             :   }
     237             : 
     238          24 :   PetscCheck(Mres<=tol,PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations",DBMAXIT);
     239          24 :   PetscCall(PetscFree5(work,piv,Told,M,invM));
     240          24 :   PetscCall(SlepcResetFlushToZero(&ftz));
     241          24 :   PetscFunctionReturn(PETSC_SUCCESS);
     242             : }
     243             : 
     244             : #define NSMAXIT 50
     245             : 
     246             : /*
     247             :    Computes the principal square root of the matrix A using the Newton-Schulz iteration.
     248             :    T is overwritten with sqrtm(T) or inv(sqrtm(T)) depending on flag inv.
     249             :  */
     250          24 : PetscErrorCode FNSqrtmNewtonSchulz(FN fn,PetscBLASInt n,PetscScalar *A,PetscBLASInt ld,PetscBool inv)
     251             : {
     252          24 :   PetscScalar    *Y=A,*Yold,*Z,*Zold,*M;
     253          24 :   PetscScalar    szero=0.0,sone=1.0,smone=-1.0,spfive=0.5,sthree=3.0;
     254          24 :   PetscReal      sqrtnrm,tol,Yres=0.0,nrm,rwork[1],done=1.0;
     255          24 :   PetscBLASInt   info,i,it,N,one=1,zero=0;
     256          24 :   PetscBool      converged=PETSC_FALSE;
     257          24 :   unsigned int   ftz;
     258             : 
     259          24 :   PetscFunctionBegin;
     260          24 :   N = n*n;
     261          24 :   tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
     262          24 :   PetscCall(SlepcSetFlushToZero(&ftz));
     263             : 
     264          24 :   PetscCall(PetscMalloc4(N,&Yold,N,&Z,N,&Zold,N,&M));
     265             : 
     266             :   /* scale */
     267          24 :   PetscCall(PetscArraycpy(Z,A,N));
     268        1344 :   for (i=0;i<n;i++) Z[i+i*ld] -= 1.0;
     269          24 :   nrm = LAPACKlange_("fro",&n,&n,Z,&n,rwork);
     270          24 :   sqrtnrm = PetscSqrtReal(nrm);
     271          24 :   PetscCallBLAS("LAPACKlascl",LAPACKlascl_("G",&zero,&zero,&nrm,&done,&N,&one,A,&N,&info));
     272          24 :   SlepcCheckLapackInfo("lascl",info);
     273          24 :   tol *= nrm;
     274          24 :   PetscCall(PetscInfo(fn,"||I-A||_F = %g, new tol: %g\n",(double)nrm,(double)tol));
     275          24 :   PetscCall(PetscLogFlops(2.0*n*n));
     276             : 
     277             :   /* Z = I */
     278          24 :   PetscCall(PetscArrayzero(Z,N));
     279        1344 :   for (i=0;i<n;i++) Z[i+i*ld] = 1.0;
     280             : 
     281         280 :   for (it=0;it<NSMAXIT && !converged;it++) {
     282             :     /* Yold = Y, Zold = Z */
     283         256 :     PetscCall(PetscArraycpy(Yold,Y,N));
     284         256 :     PetscCall(PetscArraycpy(Zold,Z,N));
     285             : 
     286             :     /* M = (3*I-Zold*Yold) */
     287         256 :     PetscCall(PetscArrayzero(M,N));
     288       16856 :     for (i=0;i<n;i++) M[i+i*ld] = sthree;
     289         256 :     PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&smone,Zold,&ld,Yold,&ld,&sone,M,&ld));
     290             : 
     291             :     /* Y = (1/2)*Yold*M, Z = (1/2)*M*Zold */
     292         256 :     PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&spfive,Yold,&ld,M,&ld,&szero,Y,&ld));
     293         256 :     PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&spfive,M,&ld,Zold,&ld,&szero,Z,&ld));
     294             : 
     295             :     /* reldiff = norm(Y-Yold,'fro')/norm(Y,'fro') */
     296         256 :     PetscCallBLAS("BLASaxpy",BLASaxpy_(&N,&smone,Y,&one,Yold,&one));
     297         256 :     Yres = LAPACKlange_("fro",&n,&n,Yold,&n,rwork);
     298         256 :     PetscCheck(!PetscIsNanReal(Yres),PETSC_COMM_SELF,PETSC_ERR_FP,"The computed norm is not-a-number");
     299         256 :     if (Yres<=tol) converged = PETSC_TRUE;
     300         256 :     PetscCall(PetscInfo(fn,"it: %" PetscBLASInt_FMT " res: %g\n",it,(double)Yres));
     301             : 
     302         256 :     PetscCall(PetscLogFlops(6.0*n*n*n+2.0*n*n));
     303             :   }
     304             : 
     305          24 :   PetscCheck(Yres<=tol,PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations",NSMAXIT);
     306             : 
     307             :   /* undo scaling */
     308          24 :   if (inv) {
     309          12 :     PetscCall(PetscArraycpy(A,Z,N));
     310          12 :     PetscCallBLAS("LAPACKlascl",LAPACKlascl_("G",&zero,&zero,&sqrtnrm,&done,&N,&one,A,&N,&info));
     311          12 :   } else PetscCallBLAS("LAPACKlascl",LAPACKlascl_("G",&zero,&zero,&done,&sqrtnrm,&N,&one,A,&N,&info));
     312          24 :   SlepcCheckLapackInfo("lascl",info);
     313             : 
     314          24 :   PetscCall(PetscFree4(Yold,Z,Zold,M));
     315          24 :   PetscCall(SlepcResetFlushToZero(&ftz));
     316          24 :   PetscFunctionReturn(PETSC_SUCCESS);
     317             : }
     318             : 
     319             : #if defined(PETSC_HAVE_CUDA)
     320             : #include "../src/sys/classes/fn/impls/cuda/fnutilcuda.h"
     321             : #include <slepccupmblas.h>
     322             : 
     323             : /*
     324             :  * Matrix square root by Newton-Schulz iteration. CUDA version.
     325             :  * Computes the principal square root of the matrix A using the
     326             :  * Newton-Schulz iteration. A is overwritten with sqrtm(A).
     327             :  */
     328             : PetscErrorCode FNSqrtmNewtonSchulz_CUDA(FN fn,PetscBLASInt n,PetscScalar *d_A,PetscBLASInt ld,PetscBool inv)
     329             : {
     330             :   PetscScalar        *d_Yold,*d_Z,*d_Zold,*d_M,alpha;
     331             :   PetscReal          nrm,sqrtnrm,tol,Yres=0.0;
     332             :   const PetscScalar  szero=0.0,sone=1.0,smone=-1.0,spfive=0.5,sthree=3.0;
     333             :   PetscInt           it;
     334             :   PetscBLASInt       N;
     335             :   const PetscBLASInt one=1;
     336             :   PetscBool          converged=PETSC_FALSE;
     337             :   cublasHandle_t     cublasv2handle;
     338             : 
     339             :   PetscFunctionBegin;
     340             :   PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA)); /* For CUDA event timers */
     341             :   PetscCall(PetscCUBLASGetHandle(&cublasv2handle));
     342             :   N = n*n;
     343             :   tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
     344             : 
     345             :   PetscCallCUDA(cudaMalloc((void **)&d_Yold,sizeof(PetscScalar)*N));
     346             :   PetscCallCUDA(cudaMalloc((void **)&d_Z,sizeof(PetscScalar)*N));
     347             :   PetscCallCUDA(cudaMalloc((void **)&d_Zold,sizeof(PetscScalar)*N));
     348             :   PetscCallCUDA(cudaMalloc((void **)&d_M,sizeof(PetscScalar)*N));
     349             : 
     350             :   PetscCall(PetscLogGpuTimeBegin());
     351             : 
     352             :   /* Z = I; */
     353             :   PetscCallCUDA(cudaMemset(d_Z,0,sizeof(PetscScalar)*N));
     354             :   PetscCall(set_diagonal(n,d_Z,ld,sone));
     355             : 
     356             :   /* scale */
     357             :   PetscCallCUBLAS(cublasXaxpy(cublasv2handle,N,&smone,d_A,one,d_Z,one));
     358             :   PetscCallCUBLAS(cublasXnrm2(cublasv2handle,N,d_Z,one,&nrm));
     359             :   sqrtnrm = PetscSqrtReal(nrm);
     360             :   alpha = 1.0/nrm;
     361             :   PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&alpha,d_A,one));
     362             :   tol *= nrm;
     363             :   PetscCall(PetscInfo(fn,"||I-A||_F = %g, new tol: %g\n",(double)nrm,(double)tol));
     364             :   PetscCall(PetscLogGpuFlops(2.0*n*n));
     365             : 
     366             :   /* Z = I; */
     367             :   PetscCallCUDA(cudaMemset(d_Z,0,sizeof(PetscScalar)*N));
     368             :   PetscCall(set_diagonal(n,d_Z,ld,sone));
     369             : 
     370             :   for (it=0;it<NSMAXIT && !converged;it++) {
     371             :     /* Yold = Y, Zold = Z */
     372             :     PetscCallCUDA(cudaMemcpy(d_Yold,d_A,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
     373             :     PetscCallCUDA(cudaMemcpy(d_Zold,d_Z,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
     374             : 
     375             :     /* M = (3*I - Zold*Yold) */
     376             :     PetscCallCUDA(cudaMemset(d_M,0,sizeof(PetscScalar)*N));
     377             :     PetscCall(set_diagonal(n,d_M,ld,sthree));
     378             :     PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&smone,d_Zold,ld,d_Yold,ld,&sone,d_M,ld));
     379             : 
     380             :     /* Y = (1/2) * Yold * M, Z = (1/2) * M * Zold */
     381             :     PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&spfive,d_Yold,ld,d_M,ld,&szero,d_A,ld));
     382             :     PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&spfive,d_M,ld,d_Zold,ld,&szero,d_Z,ld));
     383             : 
     384             :     /* reldiff = norm(Y-Yold,'fro')/norm(Y,'fro') */
     385             :     PetscCallCUBLAS(cublasXaxpy(cublasv2handle,N,&smone,d_A,one,d_Yold,one));
     386             :     PetscCallCUBLAS(cublasXnrm2(cublasv2handle,N,d_Yold,one,&Yres));
     387             :     PetscCheck(!PetscIsNanReal(Yres),PETSC_COMM_SELF,PETSC_ERR_FP,"The computed norm is not-a-number");
     388             :     if (Yres<=tol) converged = PETSC_TRUE;
     389             :     PetscCall(PetscInfo(fn,"it: %" PetscInt_FMT " res: %g\n",it,(double)Yres));
     390             : 
     391             :     PetscCall(PetscLogGpuFlops(6.0*n*n*n+2.0*n*n));
     392             :   }
     393             : 
     394             :   PetscCheck(Yres<=tol,PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations", NSMAXIT);
     395             : 
     396             :   /* undo scaling */
     397             :   if (inv) {
     398             :     alpha = 1.0/sqrtnrm;
     399             :     PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&alpha,d_Z,one));
     400             :     PetscCallCUDA(cudaMemcpy(d_A,d_Z,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
     401             :   } else {
     402             :     alpha = sqrtnrm;
     403             :     PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&alpha,d_A,one));
     404             :   }
     405             : 
     406             :   PetscCall(PetscLogGpuTimeEnd());
     407             :   PetscCallCUDA(cudaFree(d_Yold));
     408             :   PetscCallCUDA(cudaFree(d_Z));
     409             :   PetscCallCUDA(cudaFree(d_Zold));
     410             :   PetscCallCUDA(cudaFree(d_M));
     411             :   PetscFunctionReturn(PETSC_SUCCESS);
     412             : }
     413             : 
     414             : #if defined(PETSC_HAVE_MAGMA)
     415             : #include <slepcmagma.h>
     416             : 
     417             : /*
     418             :  * Matrix square root by product form of Denman-Beavers iteration. CUDA version.
     419             :  * Computes the principal square root of the matrix T using the product form
     420             :  * of the Denman-Beavers iteration. T is overwritten with sqrtm(T).
     421             :  */
     422             : PetscErrorCode FNSqrtmDenmanBeavers_CUDAm(FN fn,PetscBLASInt n,PetscScalar *d_T,PetscBLASInt ld,PetscBool inv)
     423             : {
     424             :   PetscScalar    *d_Told,*d_M,*d_invM,*d_work,prod,szero=0.0,sone=1.0,smone=-1.0,spfive=0.5,sneg_pfive=-0.5,sp25=0.25,alpha;
     425             :   PetscReal      tol,Mres=0.0,detM,g,reldiff,fnormdiff,fnormT;
     426             :   PetscInt       it,lwork,nb;
     427             :   PetscBLASInt   N,one=1,*piv=NULL;
     428             :   PetscBool      converged=PETSC_FALSE,scale;
     429             :   cublasHandle_t cublasv2handle;
     430             : 
     431             :   PetscFunctionBegin;
     432             :   PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA)); /* For CUDA event timers */
     433             :   PetscCall(PetscCUBLASGetHandle(&cublasv2handle));
     434             :   PetscCall(SlepcMagmaInit());
     435             :   N = n*n;
     436             :   scale = PetscDefined(USE_REAL_SINGLE)? PETSC_FALSE: PETSC_TRUE;
     437             :   tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
     438             : 
     439             :   /* query work size */
     440             :   nb = magma_get_xgetri_nb(n);
     441             :   lwork = nb*n;
     442             :   PetscCall(PetscMalloc1(n,&piv));
     443             :   PetscCallCUDA(cudaMalloc((void **)&d_work,sizeof(PetscScalar)*lwork));
     444             :   PetscCallCUDA(cudaMalloc((void **)&d_Told,sizeof(PetscScalar)*N));
     445             :   PetscCallCUDA(cudaMalloc((void **)&d_M,sizeof(PetscScalar)*N));
     446             :   PetscCallCUDA(cudaMalloc((void **)&d_invM,sizeof(PetscScalar)*N));
     447             : 
     448             :   PetscCall(PetscLogGpuTimeBegin());
     449             :   PetscCallCUDA(cudaMemcpy(d_M,d_T,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
     450             :   if (inv) {  /* start recurrence with I instead of A */
     451             :     PetscCallCUDA(cudaMemset(d_T,0,sizeof(PetscScalar)*N));
     452             :     PetscCall(set_diagonal(n,d_T,ld,1.0));
     453             :   }
     454             : 
     455             :   for (it=0;it<DBMAXIT && !converged;it++) {
     456             : 
     457             :     if (scale) { /* g = (abs(det(M)))^(-1/(2*n)); */
     458             :       PetscCallCUDA(cudaMemcpy(d_invM,d_M,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
     459             :       PetscCallMAGMA(magma_xgetrf_gpu,n,n,d_invM,ld,piv);
     460             :       PetscCall(mult_diagonal(n,d_invM,ld,&prod));
     461             :       detM = PetscAbsScalar(prod);
     462             :       g = (detM>PETSC_MAX_REAL)? 0.5: PetscPowReal(detM,-1.0/(2.0*n));
     463             :       alpha = g;
     464             :       PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&alpha,d_T,one));
     465             :       alpha = g*g;
     466             :       PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&alpha,d_M,one));
     467             :       PetscCall(PetscLogGpuFlops(2.0*n*n*n/3.0+2.0*n*n));
     468             :     }
     469             : 
     470             :     PetscCallCUDA(cudaMemcpy(d_Told,d_T,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
     471             :     PetscCallCUDA(cudaMemcpy(d_invM,d_M,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
     472             : 
     473             :     PetscCallMAGMA(magma_xgetrf_gpu,n,n,d_invM,ld,piv);
     474             :     PetscCallMAGMA(magma_xgetri_gpu,n,d_invM,ld,piv,d_work,lwork);
     475             :     PetscCall(PetscLogGpuFlops(2.0*n*n*n/3.0+4.0*n*n*n/3.0));
     476             : 
     477             :     PetscCall(shift_diagonal(n,d_invM,ld,sone));
     478             :     PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&spfive,d_Told,ld,d_invM,ld,&szero,d_T,ld));
     479             :     PetscCall(shift_diagonal(n,d_invM,ld,smone));
     480             : 
     481             :     PetscCallCUBLAS(cublasXaxpy(cublasv2handle,N,&sone,d_invM,one,d_M,one));
     482             :     PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&sp25,d_M,one));
     483             :     PetscCall(shift_diagonal(n,d_M,ld,sneg_pfive));
     484             :     PetscCall(PetscLogGpuFlops(2.0*n*n*n+2.0*n*n));
     485             : 
     486             :     PetscCallCUBLAS(cublasXnrm2(cublasv2handle,N,d_M,one,&Mres));
     487             :     PetscCall(shift_diagonal(n,d_M,ld,sone));
     488             : 
     489             :     if (scale) {
     490             :       /* reldiff = norm(T - Told,'fro')/norm(T,'fro'); */
     491             :       PetscCallCUBLAS(cublasXaxpy(cublasv2handle,N,&smone,d_T,one,d_Told,one));
     492             :       PetscCallCUBLAS(cublasXnrm2(cublasv2handle,N,d_Told,one,&fnormdiff));
     493             :       PetscCallCUBLAS(cublasXnrm2(cublasv2handle,N,d_T,one,&fnormT));
     494             :       PetscCall(PetscLogGpuFlops(7.0*n*n));
     495             :       reldiff = fnormdiff/fnormT;
     496             :       PetscCall(PetscInfo(fn,"it: %" PetscInt_FMT " reldiff: %g scale: %g tol*scale: %g\n",it,(double)reldiff,(double)g,(double)tol*g));
     497             :       if (reldiff<1e-2) scale = PETSC_FALSE; /* Switch to no scaling. */
     498             :     }
     499             : 
     500             :     PetscCall(PetscInfo(fn,"it: %" PetscInt_FMT " Mres: %g\n",it,(double)Mres));
     501             :     if (Mres<=tol) converged = PETSC_TRUE;
     502             :   }
     503             : 
     504             :   PetscCheck(Mres<=tol,PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations", DBMAXIT);
     505             :   PetscCall(PetscLogGpuTimeEnd());
     506             :   PetscCall(PetscFree(piv));
     507             :   PetscCallCUDA(cudaFree(d_work));
     508             :   PetscCallCUDA(cudaFree(d_Told));
     509             :   PetscCallCUDA(cudaFree(d_M));
     510             :   PetscCallCUDA(cudaFree(d_invM));
     511             :   PetscFunctionReturn(PETSC_SUCCESS);
     512             : }
     513             : #endif /* PETSC_HAVE_MAGMA */
     514             : 
     515             : #endif /* PETSC_HAVE_CUDA */
     516             : 
     517             : #define ITMAX 5
     518             : 
     519             : /*
     520             :    Estimate norm(A^m,1) by block 1-norm power method (required workspace is 11*n)
     521             : */
     522          28 : static PetscErrorCode SlepcNormEst1(PetscBLASInt n,PetscScalar *A,PetscInt m,PetscScalar *work,PetscRandom rand,PetscReal *nrm)
     523             : {
     524          28 :   PetscScalar    *X,*Y,*Z,*S,*S_old,*aux,val,sone=1.0,szero=0.0;
     525          28 :   PetscReal      est=0.0,est_old,vals[2]={0.0,0.0},*zvals,maxzval[2],raux;
     526          28 :   PetscBLASInt   i,j,t=2,it=0,ind[2],est_j=0,m1;
     527             : 
     528          28 :   PetscFunctionBegin;
     529          28 :   X = work;
     530          28 :   Y = work + 2*n;
     531          28 :   Z = work + 4*n;
     532          28 :   S = work + 6*n;
     533          28 :   S_old = work + 8*n;
     534          28 :   zvals = (PetscReal*)(work + 10*n);
     535             : 
     536        3398 :   for (i=0;i<n;i++) {  /* X has columns of unit 1-norm */
     537        3370 :     X[i] = 1.0/n;
     538        3370 :     PetscCall(PetscRandomGetValue(rand,&val));
     539        3370 :     if (PetscRealPart(val) < 0.5) X[i+n] = -1.0/n;
     540        1606 :     else X[i+n] = 1.0/n;
     541             :   }
     542        6768 :   for (i=0;i<t*n;i++) S[i] = 0.0;
     543          28 :   ind[0] = 0; ind[1] = 0;
     544          28 :   est_old = 0;
     545          66 :   while (1) {
     546          66 :     it++;
     547         282 :     for (j=0;j<m;j++) {  /* Y = A^m*X */
     548         216 :       PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&t,&n,&sone,A,&n,X,&n,&szero,Y,&n));
     549         216 :       if (j<m-1) SlepcSwap(X,Y,aux);
     550             :     }
     551         198 :     for (j=0;j<t;j++) {  /* vals[j] = norm(Y(:,j),1) */
     552         132 :       vals[j] = 0.0;
     553       16016 :       for (i=0;i<n;i++) vals[j] += PetscAbsScalar(Y[i+j*n]);
     554             :     }
     555          66 :     if (vals[0]<vals[1]) {
     556          32 :       SlepcSwap(vals[0],vals[1],raux);
     557          32 :       m1 = 1;
     558             :     } else m1 = 0;
     559          66 :     est = vals[0];
     560          66 :     if (est>est_old || it==2) est_j = ind[m1];
     561          66 :     if (it>=2 && est<=est_old) {
     562             :       est = est_old;
     563             :       break;
     564             :     }
     565          66 :     est_old = est;
     566          66 :     if (it>ITMAX) break;
     567       15950 :     SlepcSwap(S,S_old,aux);
     568       15950 :     for (i=0;i<t*n;i++) {  /* S = sign(Y) */
     569       27933 :       S[i] = (PetscRealPart(Y[i]) < 0.0)? -1.0: 1.0;
     570             :     }
     571         282 :     for (j=0;j<m;j++) {  /* Z = (A^T)^m*S */
     572         216 :       PetscCallBLAS("BLASgemm",BLASgemm_("C","N",&n,&t,&n,&sone,A,&n,S,&n,&szero,Z,&n));
     573         216 :       if (j<m-1) SlepcSwap(S,Z,aux);
     574             :     }
     575          66 :     maxzval[0] = -1; maxzval[1] = -1;
     576          66 :     ind[0] = 0; ind[1] = 0;
     577        8008 :     for (i=0;i<n;i++) {  /* zvals[i] = norm(Z(i,:),inf) */
     578        7942 :       zvals[i] = PetscMax(PetscAbsScalar(Z[i+0*n]),PetscAbsScalar(Z[i+1*n]));
     579        7942 :       if (zvals[i]>maxzval[0]) {
     580         815 :         maxzval[0] = zvals[i];
     581         815 :         ind[0] = i;
     582        7127 :       } else if (zvals[i]>maxzval[1]) {
     583         393 :         maxzval[1] = zvals[i];
     584         393 :         ind[1] = i;
     585             :       }
     586             :     }
     587          66 :     if (it>=2 && maxzval[0]==zvals[est_j]) break;
     588        9182 :     for (i=0;i<t*n;i++) X[i] = 0.0;
     589         114 :     for (j=0;j<t;j++) X[ind[j]+j*n] = 1.0;
     590             :   }
     591          28 :   *nrm = est;
     592             :   /* Flop count is roughly (it * 2*m * t*gemv) = 4*its*m*t*n*n */
     593          28 :   PetscCall(PetscLogFlops(4.0*it*m*t*n*n));
     594          28 :   PetscFunctionReturn(PETSC_SUCCESS);
     595             : }
     596             : 
     597             : #define SMALLN 100
     598             : 
     599             : /*
     600             :    Estimate norm(A^m,1) (required workspace is 2*n*n)
     601             : */
     602         491 : PetscErrorCode SlepcNormAm(PetscBLASInt n,PetscScalar *A,PetscInt m,PetscScalar *work,PetscRandom rand,PetscReal *nrm)
     603             : {
     604         491 :   PetscScalar    *v=work,*w=work+n*n,*aux,sone=1.0,szero=0.0;
     605         491 :   PetscReal      rwork[1],tmp;
     606         491 :   PetscBLASInt   i,j,one=1;
     607         491 :   PetscBool      isrealpos=PETSC_TRUE;
     608             : 
     609         491 :   PetscFunctionBegin;
     610         491 :   if (n<SMALLN) {   /* compute matrix power explicitly */
     611         445 :     if (m==1) {
     612           0 :       *nrm = LAPACKlange_("O",&n,&n,A,&n,rwork);
     613           0 :       PetscCall(PetscLogFlops(1.0*n*n));
     614             :     } else {  /* m>=2 */
     615         445 :       PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,A,&n,A,&n,&szero,v,&n));
     616        7204 :       for (j=0;j<m-2;j++) {
     617        6759 :         PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,A,&n,v,&n,&szero,w,&n));
     618        6759 :         SlepcSwap(v,w,aux);
     619             :       }
     620         445 :       *nrm = LAPACKlange_("O",&n,&n,v,&n,rwork);
     621         445 :       PetscCall(PetscLogFlops(2.0*n*n*n*(m-1)+1.0*n*n));
     622             :     }
     623             :   } else {
     624        5501 :     for (i=0;i<n;i++)
     625      442126 :       for (j=0;j<n;j++)
     626             : #if defined(PETSC_USE_COMPLEX)
     627             :         if (PetscRealPart(A[i+j*n])<0.0 || PetscImaginaryPart(A[i+j*n])!=0.0) { isrealpos = PETSC_FALSE; break; }
     628             : #else
     629      440029 :         if (A[i+j*n]<0.0) { isrealpos = PETSC_FALSE; break; }
     630             : #endif
     631          46 :     if (isrealpos) {   /* for positive matrices only */
     632        2103 :       for (i=0;i<n;i++) v[i] = 1.0;
     633         436 :       for (j=0;j<m;j++) {  /* w = A'*v */
     634         418 :         PetscCallBLAS("BLASgemv",BLASgemv_("C",&n,&n,&sone,A,&n,v,&one,&szero,w,&one));
     635         418 :         SlepcSwap(v,w,aux);
     636             :       }
     637          18 :       PetscCall(PetscLogFlops(2.0*n*n*m));
     638          18 :       *nrm = 0.0;
     639        2103 :       for (i=0;i<n;i++) if ((tmp = PetscAbsScalar(v[i])) > *nrm) *nrm = tmp;   /* norm(v,inf) */
     640          28 :     } else PetscCall(SlepcNormEst1(n,A,m,work,rand,nrm));
     641             :   }
     642         491 :   PetscFunctionReturn(PETSC_SUCCESS);
     643             : }

Generated by: LCOV version 1.14