Actual source code: fnsqrt.c
slepc-3.22.2 2024-12-02
1: /*
2: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
3: SLEPc - Scalable Library for Eigenvalue Problem Computations
4: Copyright (c) 2002-, Universitat Politecnica de Valencia, Spain
6: This file is part of SLEPc.
7: SLEPc is distributed under a 2-clause BSD license (see LICENSE).
8: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
9: */
10: /*
11: Square root function sqrt(x)
12: */
14: #include <slepc/private/fnimpl.h>
15: #include <slepcblaslapack.h>
17: static PetscErrorCode FNEvaluateFunction_Sqrt(FN fn,PetscScalar x,PetscScalar *y)
18: {
19: PetscFunctionBegin;
20: #if !defined(PETSC_USE_COMPLEX)
21: PetscCheck(x>=0.0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Function not defined in the requested value");
22: #endif
23: *y = PetscSqrtScalar(x);
24: PetscFunctionReturn(PETSC_SUCCESS);
25: }
27: static PetscErrorCode FNEvaluateDerivative_Sqrt(FN fn,PetscScalar x,PetscScalar *y)
28: {
29: PetscFunctionBegin;
30: PetscCheck(x!=0.0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Derivative not defined in the requested value");
31: #if !defined(PETSC_USE_COMPLEX)
32: PetscCheck(x>0.0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Derivative not defined in the requested value");
33: #endif
34: *y = 1.0/(2.0*PetscSqrtScalar(x));
35: PetscFunctionReturn(PETSC_SUCCESS);
36: }
38: static PetscErrorCode FNEvaluateFunctionMat_Sqrt_Schur(FN fn,Mat A,Mat B)
39: {
40: PetscBLASInt n=0;
41: PetscScalar *T;
42: PetscInt m;
44: PetscFunctionBegin;
45: if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
46: PetscCall(MatDenseGetArray(B,&T));
47: PetscCall(MatGetSize(A,&m,NULL));
48: PetscCall(PetscBLASIntCast(m,&n));
49: PetscCall(FNSqrtmSchur(fn,n,T,n,PETSC_FALSE));
50: PetscCall(MatDenseRestoreArray(B,&T));
51: PetscFunctionReturn(PETSC_SUCCESS);
52: }
54: static PetscErrorCode FNEvaluateFunctionMatVec_Sqrt_Schur(FN fn,Mat A,Vec v)
55: {
56: PetscBLASInt n=0;
57: PetscScalar *T;
58: PetscInt m;
59: Mat B;
61: PetscFunctionBegin;
62: PetscCall(FN_AllocateWorkMat(fn,A,&B));
63: PetscCall(MatDenseGetArray(B,&T));
64: PetscCall(MatGetSize(A,&m,NULL));
65: PetscCall(PetscBLASIntCast(m,&n));
66: PetscCall(FNSqrtmSchur(fn,n,T,n,PETSC_TRUE));
67: PetscCall(MatDenseRestoreArray(B,&T));
68: PetscCall(MatGetColumnVector(B,v,0));
69: PetscCall(FN_FreeWorkMat(fn,&B));
70: PetscFunctionReturn(PETSC_SUCCESS);
71: }
73: static PetscErrorCode FNEvaluateFunctionMat_Sqrt_DBP(FN fn,Mat A,Mat B)
74: {
75: PetscBLASInt n=0;
76: PetscScalar *T;
77: PetscInt m;
79: PetscFunctionBegin;
80: if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
81: PetscCall(MatDenseGetArray(B,&T));
82: PetscCall(MatGetSize(A,&m,NULL));
83: PetscCall(PetscBLASIntCast(m,&n));
84: PetscCall(FNSqrtmDenmanBeavers(fn,n,T,n,PETSC_FALSE));
85: PetscCall(MatDenseRestoreArray(B,&T));
86: PetscFunctionReturn(PETSC_SUCCESS);
87: }
89: static PetscErrorCode FNEvaluateFunctionMat_Sqrt_NS(FN fn,Mat A,Mat B)
90: {
91: PetscBLASInt n=0;
92: PetscScalar *Ba;
93: PetscInt m;
95: PetscFunctionBegin;
96: if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
97: PetscCall(MatDenseGetArray(B,&Ba));
98: PetscCall(MatGetSize(A,&m,NULL));
99: PetscCall(PetscBLASIntCast(m,&n));
100: PetscCall(FNSqrtmNewtonSchulz(fn,n,Ba,n,PETSC_FALSE));
101: PetscCall(MatDenseRestoreArray(B,&Ba));
102: PetscFunctionReturn(PETSC_SUCCESS);
103: }
105: #define MAXIT 50
107: /*
108: Computes the principal square root of the matrix A using the
109: Sadeghi iteration. A is overwritten with sqrtm(A).
110: */
111: PetscErrorCode FNSqrtmSadeghi(FN fn,PetscBLASInt n,PetscScalar *A,PetscBLASInt ld)
112: {
113: PetscScalar *M,*M2,*G,*X=A,*work,work1,sqrtnrm;
114: PetscScalar szero=0.0,sone=1.0,smfive=-5.0,s1d16=1.0/16.0;
115: PetscReal tol,Mres=0.0,nrm,rwork[1],done=1.0;
116: PetscInt i,it;
117: PetscBLASInt N,*piv=NULL,info,lwork=0,query=-1,one=1,zero=0;
118: PetscBool converged=PETSC_FALSE;
119: unsigned int ftz;
121: PetscFunctionBegin;
122: N = n*n;
123: tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
124: PetscCall(SlepcSetFlushToZero(&ftz));
126: /* query work size */
127: PetscCallBLAS("LAPACKgetri",LAPACKgetri_(&n,A,&ld,piv,&work1,&query,&info));
128: PetscCall(PetscBLASIntCast((PetscInt)PetscRealPart(work1),&lwork));
130: PetscCall(PetscMalloc5(N,&M,N,&M2,N,&G,lwork,&work,n,&piv));
131: PetscCall(PetscArraycpy(M,A,N));
133: /* scale M */
134: nrm = LAPACKlange_("fro",&n,&n,M,&n,rwork);
135: if (nrm>1.0) {
136: sqrtnrm = PetscSqrtReal(nrm);
137: PetscCallBLAS("LAPACKlascl",LAPACKlascl_("G",&zero,&zero,&nrm,&done,&N,&one,M,&N,&info));
138: SlepcCheckLapackInfo("lascl",info);
139: tol *= nrm;
140: }
141: PetscCall(PetscInfo(fn,"||A||_F = %g, new tol: %g\n",(double)nrm,(double)tol));
143: /* X = I */
144: PetscCall(PetscArrayzero(X,N));
145: for (i=0;i<n;i++) X[i+i*ld] = 1.0;
147: for (it=0;it<MAXIT && !converged;it++) {
149: /* G = (5/16)*I + (1/16)*M*(15*I-5*M+M*M) */
150: PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,M,&ld,M,&ld,&szero,M2,&ld));
151: PetscCallBLAS("BLASaxpy",BLASaxpy_(&N,&smfive,M,&one,M2,&one));
152: for (i=0;i<n;i++) M2[i+i*ld] += 15.0;
153: PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&s1d16,M,&ld,M2,&ld,&szero,G,&ld));
154: for (i=0;i<n;i++) G[i+i*ld] += 5.0/16.0;
156: /* X = X*G */
157: PetscCall(PetscArraycpy(M2,X,N));
158: PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,M2,&ld,G,&ld,&szero,X,&ld));
160: /* M = M*inv(G*G) */
161: PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,G,&ld,G,&ld,&szero,M2,&ld));
162: PetscCallBLAS("LAPACKgetrf",LAPACKgetrf_(&n,&n,M2,&ld,piv,&info));
163: SlepcCheckLapackInfo("getrf",info);
164: PetscCallBLAS("LAPACKgetri",LAPACKgetri_(&n,M2,&ld,piv,work,&lwork,&info));
165: SlepcCheckLapackInfo("getri",info);
167: PetscCall(PetscArraycpy(G,M,N));
168: PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,G,&ld,M2,&ld,&szero,M,&ld));
170: /* check ||I-M|| */
171: PetscCall(PetscArraycpy(M2,M,N));
172: for (i=0;i<n;i++) M2[i+i*ld] -= 1.0;
173: Mres = LAPACKlange_("fro",&n,&n,M2,&n,rwork);
174: PetscCheck(!PetscIsNanReal(Mres),PETSC_COMM_SELF,PETSC_ERR_FP,"The computed norm is not-a-number");
175: if (Mres<=tol) converged = PETSC_TRUE;
176: PetscCall(PetscInfo(fn,"it: %" PetscInt_FMT " res: %g\n",it,(double)Mres));
177: PetscCall(PetscLogFlops(8.0*n*n*n+2.0*n*n+2.0*n*n*n/3.0+4.0*n*n*n/3.0+2.0*n*n*n+2.0*n*n));
178: }
180: PetscCheck(Mres<=tol,PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations",MAXIT);
182: /* undo scaling */
183: if (nrm>1.0) PetscCallBLAS("BLASscal",BLASscal_(&N,&sqrtnrm,A,&one));
185: PetscCall(PetscFree5(M,M2,G,work,piv));
186: PetscCall(SlepcResetFlushToZero(&ftz));
187: PetscFunctionReturn(PETSC_SUCCESS);
188: }
190: #if defined(PETSC_HAVE_CUDA)
191: #include "../src/sys/classes/fn/impls/cuda/fnutilcuda.h"
192: #include <slepccupmblas.h>
194: #if defined(PETSC_HAVE_MAGMA)
195: #include <slepcmagma.h>
197: /*
198: * Matrix square root by Sadeghi iteration. CUDA version.
199: * Computes the principal square root of the matrix A using the
200: * Sadeghi iteration. A is overwritten with sqrtm(A).
201: */
202: PetscErrorCode FNSqrtmSadeghi_CUDAm(FN fn,PetscBLASInt n,PetscScalar *d_A,PetscBLASInt ld)
203: {
204: PetscScalar *d_M,*d_M2,*d_G,*d_work,alpha;
205: const PetscScalar szero=0.0,sone=1.0,smfive=-5.0,s15=15.0,s1d16=1.0/16.0;
206: PetscReal tol,Mres=0.0,nrm,sqrtnrm=1.0;
207: PetscInt it,nb,lwork;
208: PetscBLASInt *piv,N;
209: const PetscBLASInt one=1;
210: PetscBool converged=PETSC_FALSE;
211: cublasHandle_t cublasv2handle;
213: PetscFunctionBegin;
214: PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA)); /* For CUDA event timers */
215: PetscCall(PetscCUBLASGetHandle(&cublasv2handle));
216: PetscCall(SlepcMagmaInit());
217: N = n*n;
218: tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
220: PetscCall(PetscMalloc1(n,&piv));
221: PetscCallCUDA(cudaMalloc((void **)&d_M,sizeof(PetscScalar)*N));
222: PetscCallCUDA(cudaMalloc((void **)&d_M2,sizeof(PetscScalar)*N));
223: PetscCallCUDA(cudaMalloc((void **)&d_G,sizeof(PetscScalar)*N));
225: nb = magma_get_xgetri_nb(n);
226: lwork = nb*n;
227: PetscCallCUDA(cudaMalloc((void **)&d_work,sizeof(PetscScalar)*lwork));
228: PetscCall(PetscLogGpuTimeBegin());
230: /* M = A */
231: PetscCallCUDA(cudaMemcpy(d_M,d_A,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
233: /* scale M */
234: PetscCallCUBLAS(cublasXnrm2(cublasv2handle,N,d_M,one,&nrm));
235: if (nrm>1.0) {
236: sqrtnrm = PetscSqrtReal(nrm);
237: alpha = 1.0/nrm;
238: PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&alpha,d_M,one));
239: tol *= nrm;
240: }
241: PetscCall(PetscInfo(fn,"||A||_F = %g, new tol: %g\n",(double)nrm,(double)tol));
243: /* X = I */
244: PetscCallCUDA(cudaMemset(d_A,0,sizeof(PetscScalar)*N));
245: PetscCall(set_diagonal(n,d_A,ld,sone));
247: for (it=0;it<MAXIT && !converged;it++) {
249: /* G = (5/16)*I + (1/16)*M*(15*I-5*M+M*M) */
250: PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_M,ld,d_M,ld,&szero,d_M2,ld));
251: PetscCallCUBLAS(cublasXaxpy(cublasv2handle,N,&smfive,d_M,one,d_M2,one));
252: PetscCall(shift_diagonal(n,d_M2,ld,s15));
253: PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&s1d16,d_M,ld,d_M2,ld,&szero,d_G,ld));
254: PetscCall(shift_diagonal(n,d_G,ld,5.0/16.0));
256: /* X = X*G */
257: PetscCallCUDA(cudaMemcpy(d_M2,d_A,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
258: PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_M2,ld,d_G,ld,&szero,d_A,ld));
260: /* M = M*inv(G*G) */
261: PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_G,ld,d_G,ld,&szero,d_M2,ld));
262: /* magma */
263: PetscCallMAGMA(magma_xgetrf_gpu,n,n,d_M2,ld,piv);
264: PetscCallMAGMA(magma_xgetri_gpu,n,d_M2,ld,piv,d_work,lwork);
265: /* magma */
266: PetscCallCUDA(cudaMemcpy(d_G,d_M,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
267: PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&sone,d_G,ld,d_M2,ld,&szero,d_M,ld));
269: /* check ||I-M|| */
270: PetscCallCUDA(cudaMemcpy(d_M2,d_M,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
271: PetscCall(shift_diagonal(n,d_M2,ld,-1.0));
272: PetscCallCUBLAS(cublasXnrm2(cublasv2handle,N,d_M2,one,&Mres));
273: PetscCheck(!PetscIsNanReal(Mres),PETSC_COMM_SELF,PETSC_ERR_FP,"The computed norm is not-a-number");
274: if (Mres<=tol) converged = PETSC_TRUE;
275: PetscCall(PetscInfo(fn,"it: %" PetscInt_FMT " res: %g\n",it,(double)Mres));
276: PetscCall(PetscLogGpuFlops(8.0*n*n*n+2.0*n*n+2.0*n*n*n/3.0+4.0*n*n*n/3.0+2.0*n*n*n+2.0*n*n));
277: }
279: PetscCheck(Mres<=tol,PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations", MAXIT);
281: if (nrm>1.0) {
282: alpha = sqrtnrm;
283: PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&alpha,d_A,one));
284: }
285: PetscCall(PetscLogGpuTimeEnd());
287: PetscCallCUDA(cudaFree(d_M));
288: PetscCallCUDA(cudaFree(d_M2));
289: PetscCallCUDA(cudaFree(d_G));
290: PetscCallCUDA(cudaFree(d_work));
291: PetscCall(PetscFree(piv));
292: PetscFunctionReturn(PETSC_SUCCESS);
293: }
294: #endif /* PETSC_HAVE_MAGMA */
295: #endif /* PETSC_HAVE_CUDA */
297: static PetscErrorCode FNEvaluateFunctionMat_Sqrt_Sadeghi(FN fn,Mat A,Mat B)
298: {
299: PetscBLASInt n=0;
300: PetscScalar *Ba;
301: PetscInt m;
303: PetscFunctionBegin;
304: if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
305: PetscCall(MatDenseGetArray(B,&Ba));
306: PetscCall(MatGetSize(A,&m,NULL));
307: PetscCall(PetscBLASIntCast(m,&n));
308: PetscCall(FNSqrtmSadeghi(fn,n,Ba,n));
309: PetscCall(MatDenseRestoreArray(B,&Ba));
310: PetscFunctionReturn(PETSC_SUCCESS);
311: }
313: #if defined(PETSC_HAVE_CUDA)
314: PetscErrorCode FNEvaluateFunctionMat_Sqrt_NS_CUDA(FN fn,Mat A,Mat B)
315: {
316: PetscBLASInt n=0;
317: PetscScalar *Ba;
318: PetscInt m;
320: PetscFunctionBegin;
321: if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
322: PetscCall(MatDenseCUDAGetArray(B,&Ba));
323: PetscCall(MatGetSize(A,&m,NULL));
324: PetscCall(PetscBLASIntCast(m,&n));
325: PetscCall(FNSqrtmNewtonSchulz_CUDA(fn,n,Ba,n,PETSC_FALSE));
326: PetscCall(MatDenseCUDARestoreArray(B,&Ba));
327: PetscFunctionReturn(PETSC_SUCCESS);
328: }
330: #if defined(PETSC_HAVE_MAGMA)
331: PetscErrorCode FNEvaluateFunctionMat_Sqrt_DBP_CUDAm(FN fn,Mat A,Mat B)
332: {
333: PetscBLASInt n=0;
334: PetscScalar *T;
335: PetscInt m;
337: PetscFunctionBegin;
338: if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
339: PetscCall(MatDenseCUDAGetArray(B,&T));
340: PetscCall(MatGetSize(A,&m,NULL));
341: PetscCall(PetscBLASIntCast(m,&n));
342: PetscCall(FNSqrtmDenmanBeavers_CUDAm(fn,n,T,n,PETSC_FALSE));
343: PetscCall(MatDenseCUDARestoreArray(B,&T));
344: PetscFunctionReturn(PETSC_SUCCESS);
345: }
347: PetscErrorCode FNEvaluateFunctionMat_Sqrt_Sadeghi_CUDAm(FN fn,Mat A,Mat B)
348: {
349: PetscBLASInt n=0;
350: PetscScalar *Ba;
351: PetscInt m;
353: PetscFunctionBegin;
354: if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
355: PetscCall(MatDenseCUDAGetArray(B,&Ba));
356: PetscCall(MatGetSize(A,&m,NULL));
357: PetscCall(PetscBLASIntCast(m,&n));
358: PetscCall(FNSqrtmSadeghi_CUDAm(fn,n,Ba,n));
359: PetscCall(MatDenseCUDARestoreArray(B,&Ba));
360: PetscFunctionReturn(PETSC_SUCCESS);
361: }
362: #endif /* PETSC_HAVE_MAGMA */
363: #endif /* PETSC_HAVE_CUDA */
365: static PetscErrorCode FNView_Sqrt(FN fn,PetscViewer viewer)
366: {
367: PetscBool isascii;
368: char str[50];
369: const char *methodname[] = {
370: "Schur method for the square root",
371: "Denman-Beavers (product form)",
372: "Newton-Schulz iteration",
373: "Sadeghi iteration"
374: };
375: const int nmeth=PETSC_STATIC_ARRAY_LENGTH(methodname);
377: PetscFunctionBegin;
378: PetscCall(PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&isascii));
379: if (isascii) {
380: if (fn->beta==(PetscScalar)1.0) {
381: if (fn->alpha==(PetscScalar)1.0) PetscCall(PetscViewerASCIIPrintf(viewer," square root: sqrt(x)\n"));
382: else {
383: PetscCall(SlepcSNPrintfScalar(str,sizeof(str),fn->alpha,PETSC_TRUE));
384: PetscCall(PetscViewerASCIIPrintf(viewer," square root: sqrt(%s*x)\n",str));
385: }
386: } else {
387: PetscCall(SlepcSNPrintfScalar(str,sizeof(str),fn->beta,PETSC_TRUE));
388: if (fn->alpha==(PetscScalar)1.0) PetscCall(PetscViewerASCIIPrintf(viewer," square root: %s*sqrt(x)\n",str));
389: else {
390: PetscCall(PetscViewerASCIIPrintf(viewer," square root: %s",str));
391: PetscCall(PetscViewerASCIIUseTabs(viewer,PETSC_FALSE));
392: PetscCall(SlepcSNPrintfScalar(str,sizeof(str),fn->alpha,PETSC_TRUE));
393: PetscCall(PetscViewerASCIIPrintf(viewer,"*sqrt(%s*x)\n",str));
394: PetscCall(PetscViewerASCIIUseTabs(viewer,PETSC_TRUE));
395: }
396: }
397: if (fn->method<nmeth) PetscCall(PetscViewerASCIIPrintf(viewer," computing matrix functions with: %s\n",methodname[fn->method]));
398: }
399: PetscFunctionReturn(PETSC_SUCCESS);
400: }
402: SLEPC_EXTERN PetscErrorCode FNCreate_Sqrt(FN fn)
403: {
404: PetscFunctionBegin;
405: fn->ops->evaluatefunction = FNEvaluateFunction_Sqrt;
406: fn->ops->evaluatederivative = FNEvaluateDerivative_Sqrt;
407: fn->ops->evaluatefunctionmat[0] = FNEvaluateFunctionMat_Sqrt_Schur;
408: fn->ops->evaluatefunctionmat[1] = FNEvaluateFunctionMat_Sqrt_DBP;
409: fn->ops->evaluatefunctionmat[2] = FNEvaluateFunctionMat_Sqrt_NS;
410: fn->ops->evaluatefunctionmat[3] = FNEvaluateFunctionMat_Sqrt_Sadeghi;
411: #if defined(PETSC_HAVE_CUDA)
412: fn->ops->evaluatefunctionmatcuda[2] = FNEvaluateFunctionMat_Sqrt_NS_CUDA;
413: #if defined(PETSC_HAVE_MAGMA)
414: fn->ops->evaluatefunctionmatcuda[1] = FNEvaluateFunctionMat_Sqrt_DBP_CUDAm;
415: fn->ops->evaluatefunctionmatcuda[3] = FNEvaluateFunctionMat_Sqrt_Sadeghi_CUDAm;
416: #endif /* PETSC_HAVE_MAGMA */
417: #endif /* PETSC_HAVE_CUDA */
418: fn->ops->evaluatefunctionmatvec[0] = FNEvaluateFunctionMatVec_Sqrt_Schur;
419: fn->ops->view = FNView_Sqrt;
420: PetscFunctionReturn(PETSC_SUCCESS);
421: }