Actual source code: fninvsqrt.c

slepc-3.21.1 2024-04-26
Report Typos and Errors
  1: /*
  2:    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  3:    SLEPc - Scalable Library for Eigenvalue Problem Computations
  4:    Copyright (c) 2002-, Universitat Politecnica de Valencia, Spain

  6:    This file is part of SLEPc.
  7:    SLEPc is distributed under a 2-clause BSD license (see LICENSE).
  8:    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  9: */
 10: /*
 11:    Inverse square root function  x^(-1/2)
 12: */

 14: #include <slepc/private/fnimpl.h>
 15: #include <slepcblaslapack.h>

 17: static PetscErrorCode FNEvaluateFunction_Invsqrt(FN fn,PetscScalar x,PetscScalar *y)
 18: {
 19:   PetscFunctionBegin;
 20:   PetscCheck(x!=0.0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Function not defined in the requested value");
 21: #if !defined(PETSC_USE_COMPLEX)
 22:   PetscCheck(x>0.0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Function not defined in the requested value");
 23: #endif
 24:   *y = 1.0/PetscSqrtScalar(x);
 25:   PetscFunctionReturn(PETSC_SUCCESS);
 26: }

 28: static PetscErrorCode FNEvaluateDerivative_Invsqrt(FN fn,PetscScalar x,PetscScalar *y)
 29: {
 30:   PetscFunctionBegin;
 31:   PetscCheck(x!=0.0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Derivative not defined in the requested value");
 32: #if !defined(PETSC_USE_COMPLEX)
 33:   PetscCheck(x>0.0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Derivative not defined in the requested value");
 34: #endif
 35:   *y = -1.0/(2.0*PetscPowScalarReal(x,1.5));
 36:   PetscFunctionReturn(PETSC_SUCCESS);
 37: }

 39: static PetscErrorCode FNEvaluateFunctionMat_Invsqrt_Schur(FN fn,Mat A,Mat B)
 40: {
 41:   PetscBLASInt   n=0,ld,*ipiv,info;
 42:   PetscScalar    *Ba,*Wa;
 43:   PetscInt       m;
 44:   Mat            W;

 46:   PetscFunctionBegin;
 47:   PetscCall(FN_AllocateWorkMat(fn,A,&W));
 48:   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
 49:   PetscCall(MatDenseGetArray(B,&Ba));
 50:   PetscCall(MatDenseGetArray(W,&Wa));
 51:   /* compute B = sqrtm(A) */
 52:   PetscCall(MatGetSize(A,&m,NULL));
 53:   PetscCall(PetscBLASIntCast(m,&n));
 54:   ld = n;
 55:   PetscCall(FNSqrtmSchur(fn,n,Ba,n,PETSC_FALSE));
 56:   /* compute B = A\B */
 57:   PetscCall(PetscMalloc1(ld,&ipiv));
 58:   PetscCallBLAS("LAPACKgesv",LAPACKgesv_(&n,&n,Wa,&ld,ipiv,Ba,&ld,&info));
 59:   SlepcCheckLapackInfo("gesv",info);
 60:   PetscCall(PetscLogFlops(2.0*n*n*n/3.0+2.0*n*n*n));
 61:   PetscCall(PetscFree(ipiv));
 62:   PetscCall(MatDenseRestoreArray(W,&Wa));
 63:   PetscCall(MatDenseRestoreArray(B,&Ba));
 64:   PetscCall(FN_FreeWorkMat(fn,&W));
 65:   PetscFunctionReturn(PETSC_SUCCESS);
 66: }

 68: static PetscErrorCode FNEvaluateFunctionMatVec_Invsqrt_Schur(FN fn,Mat A,Vec v)
 69: {
 70:   PetscBLASInt   n=0,ld,*ipiv,info,one=1;
 71:   PetscScalar    *Ba,*Wa;
 72:   PetscInt       m;
 73:   Mat            B,W;

 75:   PetscFunctionBegin;
 76:   PetscCall(FN_AllocateWorkMat(fn,A,&B));
 77:   PetscCall(FN_AllocateWorkMat(fn,A,&W));
 78:   PetscCall(MatDenseGetArray(B,&Ba));
 79:   PetscCall(MatDenseGetArray(W,&Wa));
 80:   /* compute B_1 = sqrtm(A)*e_1 */
 81:   PetscCall(MatGetSize(A,&m,NULL));
 82:   PetscCall(PetscBLASIntCast(m,&n));
 83:   ld = n;
 84:   PetscCall(FNSqrtmSchur(fn,n,Ba,n,PETSC_TRUE));
 85:   /* compute B_1 = A\B_1 */
 86:   PetscCall(PetscMalloc1(ld,&ipiv));
 87:   PetscCallBLAS("LAPACKgesv",LAPACKgesv_(&n,&one,Wa,&ld,ipiv,Ba,&ld,&info));
 88:   SlepcCheckLapackInfo("gesv",info);
 89:   PetscCall(PetscFree(ipiv));
 90:   PetscCall(MatDenseRestoreArray(W,&Wa));
 91:   PetscCall(MatDenseRestoreArray(B,&Ba));
 92:   PetscCall(MatGetColumnVector(B,v,0));
 93:   PetscCall(FN_FreeWorkMat(fn,&W));
 94:   PetscCall(FN_FreeWorkMat(fn,&B));
 95:   PetscFunctionReturn(PETSC_SUCCESS);
 96: }

 98: static PetscErrorCode FNEvaluateFunctionMat_Invsqrt_DBP(FN fn,Mat A,Mat B)
 99: {
100:   PetscBLASInt   n=0;
101:   PetscScalar    *T;
102:   PetscInt       m;

104:   PetscFunctionBegin;
105:   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
106:   PetscCall(MatDenseGetArray(B,&T));
107:   PetscCall(MatGetSize(A,&m,NULL));
108:   PetscCall(PetscBLASIntCast(m,&n));
109:   PetscCall(FNSqrtmDenmanBeavers(fn,n,T,n,PETSC_TRUE));
110:   PetscCall(MatDenseRestoreArray(B,&T));
111:   PetscFunctionReturn(PETSC_SUCCESS);
112: }

114: static PetscErrorCode FNEvaluateFunctionMat_Invsqrt_NS(FN fn,Mat A,Mat B)
115: {
116:   PetscBLASInt   n=0;
117:   PetscScalar    *T;
118:   PetscInt       m;

120:   PetscFunctionBegin;
121:   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
122:   PetscCall(MatDenseGetArray(B,&T));
123:   PetscCall(MatGetSize(A,&m,NULL));
124:   PetscCall(PetscBLASIntCast(m,&n));
125:   PetscCall(FNSqrtmNewtonSchulz(fn,n,T,n,PETSC_TRUE));
126:   PetscCall(MatDenseRestoreArray(B,&T));
127:   PetscFunctionReturn(PETSC_SUCCESS);
128: }

130: static PetscErrorCode FNEvaluateFunctionMat_Invsqrt_Sadeghi(FN fn,Mat A,Mat B)
131: {
132:   PetscBLASInt   n=0,ld,*ipiv,info;
133:   PetscScalar    *Ba,*Wa;
134:   PetscInt       m;
135:   Mat            W;

137:   PetscFunctionBegin;
138:   PetscCall(FN_AllocateWorkMat(fn,A,&W));
139:   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
140:   PetscCall(MatDenseGetArray(B,&Ba));
141:   PetscCall(MatDenseGetArray(W,&Wa));
142:   /* compute B = sqrtm(A) */
143:   PetscCall(MatGetSize(A,&m,NULL));
144:   PetscCall(PetscBLASIntCast(m,&n));
145:   ld = n;
146:   PetscCall(FNSqrtmSadeghi(fn,n,Ba,n));
147:   /* compute B = A\B */
148:   PetscCall(PetscMalloc1(ld,&ipiv));
149:   PetscCallBLAS("LAPACKgesv",LAPACKgesv_(&n,&n,Wa,&ld,ipiv,Ba,&ld,&info));
150:   SlepcCheckLapackInfo("gesv",info);
151:   PetscCall(PetscLogFlops(2.0*n*n*n/3.0+2.0*n*n*n));
152:   PetscCall(PetscFree(ipiv));
153:   PetscCall(MatDenseRestoreArray(W,&Wa));
154:   PetscCall(MatDenseRestoreArray(B,&Ba));
155:   PetscCall(FN_FreeWorkMat(fn,&W));
156:   PetscFunctionReturn(PETSC_SUCCESS);
157: }

159: #if defined(PETSC_HAVE_CUDA)
160: PetscErrorCode FNEvaluateFunctionMat_Invsqrt_NS_CUDA(FN fn,Mat A,Mat B)
161: {
162:   PetscBLASInt   n=0;
163:   PetscScalar    *Ba;
164:   PetscInt       m;

166:   PetscFunctionBegin;
167:   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
168:   PetscCall(MatDenseCUDAGetArray(B,&Ba));
169:   PetscCall(MatGetSize(A,&m,NULL));
170:   PetscCall(PetscBLASIntCast(m,&n));
171:   PetscCall(FNSqrtmNewtonSchulz_CUDA(fn,n,Ba,n,PETSC_TRUE));
172:   PetscCall(MatDenseCUDARestoreArray(B,&Ba));
173:   PetscFunctionReturn(PETSC_SUCCESS);
174: }

176: #if defined(PETSC_HAVE_MAGMA)
177: #include <slepcmagma.h>

179: PetscErrorCode FNEvaluateFunctionMat_Invsqrt_DBP_CUDAm(FN fn,Mat A,Mat B)
180: {
181:   PetscBLASInt   n=0;
182:   PetscScalar    *T;
183:   PetscInt       m;

185:   PetscFunctionBegin;
186:   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
187:   PetscCall(MatDenseCUDAGetArray(B,&T));
188:   PetscCall(MatGetSize(A,&m,NULL));
189:   PetscCall(PetscBLASIntCast(m,&n));
190:   PetscCall(FNSqrtmDenmanBeavers_CUDAm(fn,n,T,n,PETSC_TRUE));
191:   PetscCall(MatDenseCUDARestoreArray(B,&T));
192:   PetscFunctionReturn(PETSC_SUCCESS);
193: }

195: PetscErrorCode FNEvaluateFunctionMat_Invsqrt_Sadeghi_CUDAm(FN fn,Mat A,Mat B)
196: {
197:   PetscBLASInt   n=0,ld,*ipiv;
198:   PetscScalar    *Ba,*Wa;
199:   PetscInt       m;
200:   Mat            W;

202:   PetscFunctionBegin;
203:   PetscCall(FN_AllocateWorkMat(fn,A,&W));
204:   if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
205:   PetscCall(MatDenseCUDAGetArray(B,&Ba));
206:   PetscCall(MatDenseCUDAGetArray(W,&Wa));
207:   /* compute B = sqrtm(A) */
208:   PetscCall(MatGetSize(A,&m,NULL));
209:   PetscCall(PetscBLASIntCast(m,&n));
210:   ld = n;
211:   PetscCall(FNSqrtmSadeghi_CUDAm(fn,n,Ba,n));
212:   /* compute B = A\B */
213:   PetscCall(SlepcMagmaInit());
214:   PetscCall(PetscMalloc1(ld,&ipiv));
215:   PetscCallMAGMA(magma_xgesv_gpu,n,n,Wa,ld,ipiv,Ba,ld);
216:   PetscCall(PetscLogFlops(2.0*n*n*n/3.0+2.0*n*n*n));
217:   PetscCall(PetscFree(ipiv));
218:   PetscCall(MatDenseCUDARestoreArray(W,&Wa));
219:   PetscCall(MatDenseCUDARestoreArray(B,&Ba));
220:   PetscCall(FN_FreeWorkMat(fn,&W));
221:   PetscFunctionReturn(PETSC_SUCCESS);
222: }
223: #endif /* PETSC_HAVE_MAGMA */
224: #endif /* PETSC_HAVE_CUDA */

226: static PetscErrorCode FNView_Invsqrt(FN fn,PetscViewer viewer)
227: {
228:   PetscBool      isascii;
229:   char           str[50];
230:   const char     *methodname[] = {
231:                   "Schur method for inv(A)*sqrtm(A)",
232:                   "Denman-Beavers (product form)",
233:                   "Newton-Schulz iteration",
234:                   "Sadeghi iteration"
235:   };
236:   const int      nmeth=PETSC_STATIC_ARRAY_LENGTH(methodname);

238:   PetscFunctionBegin;
239:   PetscCall(PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&isascii));
240:   if (isascii) {
241:     if (fn->beta==(PetscScalar)1.0) {
242:       if (fn->alpha==(PetscScalar)1.0) PetscCall(PetscViewerASCIIPrintf(viewer,"  inverse square root: x^(-1/2)\n"));
243:       else {
244:         PetscCall(SlepcSNPrintfScalar(str,sizeof(str),fn->alpha,PETSC_TRUE));
245:         PetscCall(PetscViewerASCIIPrintf(viewer,"  inverse square root: (%s*x)^(-1/2)\n",str));
246:       }
247:     } else {
248:       PetscCall(SlepcSNPrintfScalar(str,sizeof(str),fn->beta,PETSC_TRUE));
249:       if (fn->alpha==(PetscScalar)1.0) PetscCall(PetscViewerASCIIPrintf(viewer,"  inverse square root: %s*x^(-1/2)\n",str));
250:       else {
251:         PetscCall(PetscViewerASCIIPrintf(viewer,"  inverse square root: %s",str));
252:         PetscCall(PetscViewerASCIIUseTabs(viewer,PETSC_FALSE));
253:         PetscCall(SlepcSNPrintfScalar(str,sizeof(str),fn->alpha,PETSC_TRUE));
254:         PetscCall(PetscViewerASCIIPrintf(viewer,"*(%s*x)^(-1/2)\n",str));
255:         PetscCall(PetscViewerASCIIUseTabs(viewer,PETSC_TRUE));
256:       }
257:     }
258:     if (fn->method<nmeth) PetscCall(PetscViewerASCIIPrintf(viewer,"  computing matrix functions with: %s\n",methodname[fn->method]));
259:   }
260:   PetscFunctionReturn(PETSC_SUCCESS);
261: }

263: SLEPC_EXTERN PetscErrorCode FNCreate_Invsqrt(FN fn)
264: {
265:   PetscFunctionBegin;
266:   fn->ops->evaluatefunction          = FNEvaluateFunction_Invsqrt;
267:   fn->ops->evaluatederivative        = FNEvaluateDerivative_Invsqrt;
268:   fn->ops->evaluatefunctionmat[0]    = FNEvaluateFunctionMat_Invsqrt_Schur;
269:   fn->ops->evaluatefunctionmat[1]    = FNEvaluateFunctionMat_Invsqrt_DBP;
270:   fn->ops->evaluatefunctionmat[2]    = FNEvaluateFunctionMat_Invsqrt_NS;
271:   fn->ops->evaluatefunctionmat[3]    = FNEvaluateFunctionMat_Invsqrt_Sadeghi;
272: #if defined(PETSC_HAVE_CUDA)
273:   fn->ops->evaluatefunctionmatcuda[2] = FNEvaluateFunctionMat_Invsqrt_NS_CUDA;
274: #if defined(PETSC_HAVE_MAGMA)
275:   fn->ops->evaluatefunctionmatcuda[1] = FNEvaluateFunctionMat_Invsqrt_DBP_CUDAm;
276:   fn->ops->evaluatefunctionmatcuda[3] = FNEvaluateFunctionMat_Invsqrt_Sadeghi_CUDAm;
277: #endif /* PETSC_HAVE_MAGMA */
278: #endif /* PETSC_HAVE_CUDA */
279:   fn->ops->evaluatefunctionmatvec[0] = FNEvaluateFunctionMatVec_Invsqrt_Schur;
280:   fn->ops->view                      = FNView_Invsqrt;
281:   PetscFunctionReturn(PETSC_SUCCESS);
282: }