Actual source code: fnutil.c
slepc-3.22.2 2024-12-02
1: /*
2: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
3: SLEPc - Scalable Library for Eigenvalue Problem Computations
4: Copyright (c) 2002-, Universitat Politecnica de Valencia, Spain
6: This file is part of SLEPc.
7: SLEPc is distributed under a 2-clause BSD license (see LICENSE).
8: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
9: */
10: /*
11: Utility subroutines common to several implementations
12: */
14: #include <slepc/private/fnimpl.h>
15: #include <slepcblaslapack.h>
17: /*
18: Compute the square root of an upper quasi-triangular matrix T,
19: using Higham's algorithm (Linear Algebra Appl. 88/89, 1987). T is overwritten with sqrtm(T).
20: */
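/*
   In outline, the loop below computes R = sqrtm(T) one (block) column at a time:
   each diagonal 1x1 or 2x2 block is replaced by its square root, R_jj = sqrt(T_jj),
   and each block above the diagonal is then obtained from the Sylvester equation
       R_ii*R_ij + R_ij*R_jj = T_ij - sum_{i<k<j} R_ik*R_kj,
   where the right-hand side is accumulated with BLASgemm and the equation is
   solved with LAPACKtrsyl.
*/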
21: static PetscErrorCode SlepcMatDenseSqrt(PetscBLASInt n,PetscScalar *T,PetscBLASInt ld)
22: {
23: PetscScalar one=1.0,mone=-1.0;
24: PetscReal scal;
25: PetscBLASInt i,j,si,sj,r,ione=1,info;
26: #if !defined(PETSC_USE_COMPLEX)
27: PetscReal alpha,theta,mu,mu2;
28: #endif
30: PetscFunctionBegin;
31: for (j=0;j<n;j++) {
32: #if defined(PETSC_USE_COMPLEX)
33: sj = 1;
34: T[j+j*ld] = PetscSqrtScalar(T[j+j*ld]);
35: #else
36: sj = (j==n-1 || T[j+1+j*ld] == 0.0)? 1: 2;
37: if (sj==1) {
38: PetscCheck(T[j+j*ld]>=0.0,PETSC_COMM_SELF,PETSC_ERR_USER_INPUT,"Matrix has a real negative eigenvalue, no real primary square root exists");
39: T[j+j*ld] = PetscSqrtReal(T[j+j*ld]);
40: } else {
41: /* square root of 2x2 block */
42: theta = (T[j+j*ld]+T[j+1+(j+1)*ld])/2.0;
43: mu = (T[j+j*ld]-T[j+1+(j+1)*ld])/2.0;
44: mu2 = -mu*mu-T[j+1+j*ld]*T[j+(j+1)*ld];
45: mu = PetscSqrtReal(mu2);
46: if (theta>0.0) alpha = PetscSqrtReal((theta+PetscSqrtReal(theta*theta+mu2))/2.0);
47: else alpha = mu/PetscSqrtReal(2.0*(-theta+PetscSqrtReal(theta*theta+mu2)));
48: T[j+j*ld] /= 2.0*alpha;
49: T[j+1+(j+1)*ld] /= 2.0*alpha;
50: T[j+(j+1)*ld] /= 2.0*alpha;
51: T[j+1+j*ld] /= 2.0*alpha;
52: T[j+j*ld] += alpha-theta/(2.0*alpha);
53: T[j+1+(j+1)*ld] += alpha-theta/(2.0*alpha);
54: }
55: #endif
56: for (i=j-1;i>=0;i--) {
57: #if defined(PETSC_USE_COMPLEX)
58: si = 1;
59: #else
60: si = (i==0 || T[i+(i-1)*ld] == 0.0)? 1: 2;
61: if (si==2) i--;
62: #endif
63: /* solve Sylvester equation of order si x sj */
64: r = j-i-si;
65: if (r) PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&si,&sj,&r,&mone,T+i+(i+si)*ld,&ld,T+i+si+j*ld,&ld,&one,T+i+j*ld,&ld));
66: PetscCallBLAS("LAPACKtrsyl",LAPACKtrsyl_("N","N",&ione,&si,&sj,T+i+i*ld,&ld,T+j+j*ld,&ld,T+i+j*ld,&ld,&scal,&info));
67: SlepcCheckLapackInfo("trsyl",info);
68: PetscCheck(scal==1.0,PETSC_COMM_SELF,PETSC_ERR_SUP,"Current implementation cannot handle scale factor %g",(double)scal);
69: }
70: if (sj==2) j++;
71: }
72: PetscFunctionReturn(PETSC_SUCCESS);
73: }
75: #define BLOCKSIZE 64
77: /*
78: Schur method for the square root of a dense matrix T (its Schur form is computed internally).
79: T is overwritten with sqrtm(T).
80: If firstonly is true, only the first column of T will contain relevant values.
81: */
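/*
   In outline: the input matrix stored in T (column-major, leading dimension ld) is
   first reduced to Schur form A = Q*T*Q' with LAPACKgees; the (quasi-)triangular
   factor is square-rooted blockwise, applying SlepcMatDenseSqrt to diagonal blocks
   of size about BLOCKSIZE and filling the off-diagonal blocks via Sylvester
   equations; finally the result is backtransformed as sqrtm(A) = Q*sqrtm(T)*Q',
   or only its first column when firstonly is set.
*/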
82: PetscErrorCode FNSqrtmSchur(FN fn,PetscBLASInt n,PetscScalar *T,PetscBLASInt ld,PetscBool firstonly)
83: {
84: PetscBLASInt i,j,k,r,ione=1,sdim,lwork,*s,*p,info,bs=BLOCKSIZE;
85: PetscScalar *wr,*W,*Q,*work,one=1.0,zero=0.0,mone=-1.0;
86: PetscInt m,nblk;
87: PetscReal scal;
88: #if defined(PETSC_USE_COMPLEX)
89: PetscReal *rwork;
90: #else
91: PetscReal *wi;
92: #endif
94: PetscFunctionBegin;
95: m = n;
96: nblk = (m+bs-1)/bs;
97: lwork = 5*n;
98: k = firstonly? 1: n;
100: /* compute Schur decomposition A*Q = Q*T */
101: #if !defined(PETSC_USE_COMPLEX)
102: PetscCall(PetscMalloc7(m,&wr,m,&wi,m*k,&W,m*m,&Q,lwork,&work,nblk,&s,nblk,&p));
103: PetscCallBLAS("LAPACKgees",LAPACKgees_("V","N",NULL,&n,T,&ld,&sdim,wr,wi,Q,&ld,work,&lwork,NULL,&info));
104: #else
105: PetscCall(PetscMalloc7(m,&wr,m,&rwork,m*k,&W,m*m,&Q,lwork,&work,nblk,&s,nblk,&p));
106: PetscCallBLAS("LAPACKgees",LAPACKgees_("V","N",NULL,&n,T,&ld,&sdim,wr,Q,&ld,work,&lwork,rwork,NULL,&info));
107: #endif
108: SlepcCheckLapackInfo("gees",info);
110: /* determine block sizes and positions, to avoid cutting 2x2 blocks */
111: j = 0;
112: p[j] = 0;
113: do {
114: s[j] = PetscMin(bs,n-p[j]);
115: #if !defined(PETSC_USE_COMPLEX)
116: if (p[j]+s[j]!=n && T[p[j]+s[j]+(p[j]+s[j]-1)*ld]!=0.0) s[j]++;
117: #endif
118: if (p[j]+s[j]==n) break;
119: j++;
120: p[j] = p[j-1]+s[j-1];
121: } while (1);
122: nblk = j+1;
124: for (j=0;j<nblk;j++) {
125: /* evaluate f(T_jj) */
126: PetscCall(SlepcMatDenseSqrt(s[j],T+p[j]+p[j]*ld,ld));
127: for (i=j-1;i>=0;i--) {
128: /* solve Sylvester equation for block (i,j) */
129: r = p[j]-p[i]-s[i];
130: if (r) PetscCallBLAS("BLASgemm",BLASgemm_("N","N",s+i,s+j,&r,&mone,T+p[i]+(p[i]+s[i])*ld,&ld,T+p[i]+s[i]+p[j]*ld,&ld,&one,T+p[i]+p[j]*ld,&ld));
131: PetscCallBLAS("LAPACKtrsyl",LAPACKtrsyl_("N","N",&ione,s+i,s+j,T+p[i]+p[i]*ld,&ld,T+p[j]+p[j]*ld,&ld,T+p[i]+p[j]*ld,&ld,&scal,&info));
132: SlepcCheckLapackInfo("trsyl",info);
133: PetscCheck(scal==1.0,PETSC_COMM_SELF,PETSC_ERR_SUP,"Current implementation cannot handle scale factor %g",(double)scal);
134: }
135: }
137: /* backtransform B = Q*T*Q' */
138: PetscCallBLAS("BLASgemm",BLASgemm_("N","C",&n,&k,&n,&one,T,&ld,Q,&ld,&zero,W,&ld));
139: PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&k,&n,&one,Q,&ld,W,&ld,&zero,T,&ld));
141: /* flop count: Schur decomposition, triangular square root, and backtransform */
142: PetscCall(PetscLogFlops(25.0*n*n*n+n*n*n/3.0+4.0*n*n*k));
144: #if !defined(PETSC_USE_COMPLEX)
145: PetscCall(PetscFree7(wr,wi,W,Q,work,s,p));
146: #else
147: PetscCall(PetscFree7(wr,rwork,W,Q,work,s,p));
148: #endif
149: PetscFunctionReturn(PETSC_SUCCESS);
150: }
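#if 0
/* Illustrative sketch, not part of the library build: one possible way to call
   FNSqrtmSchur on a small dense matrix. The helper name ExampleSqrtmSchur is
   hypothetical; it assumes column-major storage with leading dimension equal to n
   and that an FN object fn is available (it is only used for logging context). */
static PetscErrorCode ExampleSqrtmSchur(FN fn)
{
  PetscScalar A[9] = {4.0,0.0,0.0, 1.0,9.0,0.0, 0.0,2.0,16.0}; /* 3x3, column-major */

  PetscFunctionBegin;
  PetscCall(FNSqrtmSchur(fn,3,A,3,PETSC_FALSE)); /* overwrites A with sqrtm(A) */
  PetscFunctionReturn(PETSC_SUCCESS);
}
#endif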
152: #define DBMAXIT 25
154: /*
155: Computes the principal square root of the matrix T using the product form
156: of the Denman-Beavers iteration.
157: T is overwritten with sqrtm(T) or inv(sqrtm(T)) depending on flag inv.
158: */
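/*
   In outline: with M_0 = A and X_0 = A (or X_0 = I when inv is true), the product
   form of the Denman-Beavers iteration reads
       X_{k+1} = X_k*(I + inv(M_k))/2,
       M_{k+1} = I/2 + (M_k + inv(M_k))/4,
   with M_k -> I and X_k -> sqrtm(A) (resp. inv(sqrtm(A))). Below, X is stored in T,
   ||M_k - I||_F serves as stopping criterion, and determinantal scaling by
   g = |det(M_k)|^(-1/(2n)) is applied until the relative change of T drops
   below 1e-2.
*/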
159: PetscErrorCode FNSqrtmDenmanBeavers(FN fn,PetscBLASInt n,PetscScalar *T,PetscBLASInt ld,PetscBool inv)
160: {
161: PetscScalar *Told,*M=NULL,*invM,*work,work1,prod,alpha;
162: PetscScalar szero=0.0,sone=1.0,smone=-1.0,spfive=0.5,sp25=0.25;
163: PetscReal tol,Mres=0.0,detM,g,reldiff,fnormdiff,fnormT,rwork[1];
164: PetscBLASInt N,i,it,*piv=NULL,info,query=-1,lwork;
165: const PetscBLASInt one=1;
166: PetscBool converged=PETSC_FALSE,scale;
167: unsigned int ftz;
169: PetscFunctionBegin;
170: N = n*n;
171: tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
172: scale = PetscDefined(USE_REAL_SINGLE)? PETSC_FALSE: PETSC_TRUE;
173: PetscCall(SlepcSetFlushToZero(&ftz));
175: /* query work size */
176: PetscCallBLAS("LAPACKgetri",LAPACKgetri_(&n,M,&ld,piv,&work1,&query,&info));
177: PetscCall(PetscBLASIntCast((PetscInt)PetscRealPart(work1),&lwork));
178: PetscCall(PetscMalloc5(lwork,&work,n,&piv,n*n,&Told,n*n,&M,n*n,&invM));
179: PetscCall(PetscArraycpy(M,T,n*n));
181: if (inv) { /* start recurrence with I instead of A */
182: PetscCall(PetscArrayzero(T,n*n));
183: for (i=0;i<n;i++) T[i+i*ld] += 1.0;
184: }
186: for (it=0;it<DBMAXIT && !converged;it++) {
188: if (scale) { /* g = (abs(det(M)))^(-1/(2*n)) */
189: PetscCall(PetscArraycpy(invM,M,n*n));
190: PetscCallBLAS("LAPACKgetrf",LAPACKgetrf_(&n,&n,invM,&ld,piv,&info));
191: SlepcCheckLapackInfo("getrf",info);
192: prod = invM[0];
193: for (i=1;i<n;i++) prod *= invM[i+i*ld];
194: detM = PetscAbsScalar(prod);
195: g = (detM>PETSC_MAX_REAL)? 0.5: PetscPowReal(detM,-1.0/(2.0*n));
196: alpha = g;
197: PetscCallBLAS("BLASscal",BLASscal_(&N,&alpha,T,&one));
198: alpha = g*g;
199: PetscCallBLAS("BLASscal",BLASscal_(&N,&alpha,M,&one));
200: PetscCall(PetscLogFlops(2.0*n*n*n/3.0+2.0*n*n));
201: }
203: PetscCall(PetscArraycpy(Told,T,n*n));
204: PetscCall(PetscArraycpy(invM,M,n*n));
206: PetscCallBLAS("LAPACKgetrf",LAPACKgetrf_(&n,&n,invM,&ld,piv,&info));
207: SlepcCheckLapackInfo("getrf",info);
208: PetscCallBLAS("LAPACKgetri",LAPACKgetri_(&n,invM,&ld,piv,work,&lwork,&info));
209: SlepcCheckLapackInfo("getri",info);
210: PetscCall(PetscLogFlops(2.0*n*n*n/3.0+4.0*n*n*n/3.0));
212: for (i=0;i<n;i++) invM[i+i*ld] += 1.0;
213: PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&spfive,Told,&ld,invM,&ld,&szero,T,&ld));
214: for (i=0;i<n;i++) invM[i+i*ld] -= 1.0;
216: PetscCallBLAS("BLASaxpy",BLASaxpy_(&N,&sone,invM,&one,M,&one));
217: PetscCallBLAS("BLASscal",BLASscal_(&N,&sp25,M,&one));
218: for (i=0;i<n;i++) M[i+i*ld] -= 0.5;
219: PetscCall(PetscLogFlops(2.0*n*n*n+2.0*n*n));
221: Mres = LAPACKlange_("F",&n,&n,M,&n,rwork);
222: for (i=0;i<n;i++) M[i+i*ld] += 1.0;
224: if (scale) {
225: /* reldiff = norm(T - Told,'fro')/norm(T,'fro') */
226: PetscCallBLAS("BLASaxpy",BLASaxpy_(&N,&smone,T,&one,Told,&one));
227: fnormdiff = LAPACKlange_("F",&n,&n,Told,&n,rwork);
228: fnormT = LAPACKlange_("F",&n,&n,T,&n,rwork);
229: PetscCall(PetscLogFlops(7.0*n*n));
230: reldiff = fnormdiff/fnormT;
231: PetscCall(PetscInfo(fn,"it: %" PetscBLASInt_FMT " reldiff: %g scale: %g tol*scale: %g\n",it,(double)reldiff,(double)g,(double)(tol*g)));
232: if (reldiff<1e-2) scale = PETSC_FALSE; /* Switch off scaling */
233: }
235: if (Mres<=tol) converged = PETSC_TRUE;
236: }
238: PetscCheck(Mres<=tol,PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations",DBMAXIT);
239: PetscCall(PetscFree5(work,piv,Told,M,invM));
240: PetscCall(SlepcResetFlushToZero(&ftz));
241: PetscFunctionReturn(PETSC_SUCCESS);
242: }
244: #define NSMAXIT 50
246: /*
247: Computes the principal square root of the matrix A using the Newton-Schulz iteration.
248: A is overwritten with sqrtm(A) or inv(sqrtm(A)) depending on flag inv.
249: */
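/*
   In outline: the coupled Newton-Schulz recurrence
       Y_{k+1} = Y_k*(3I - Z_k*Y_k)/2,   Z_{k+1} = (3I - Z_k*Y_k)*Z_k/2,
   started from Y_0 = A, Z_0 = I, gives Y_k -> sqrtm(A) and Z_k -> inv(sqrtm(A)).
   It converges quadratically provided ||I - A|| < 1, which is the reason for the
   initial scaling of A below. The iteration is inversion-free, at the cost of
   three GEMMs per step.
*/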
250: PetscErrorCode FNSqrtmNewtonSchulz(FN fn,PetscBLASInt n,PetscScalar *A,PetscBLASInt ld,PetscBool inv)
251: {
252: PetscScalar *Y=A,*Yold,*Z,*Zold,*M;
253: PetscScalar szero=0.0,sone=1.0,smone=-1.0,spfive=0.5,sthree=3.0;
254: PetscReal sqrtnrm,tol,Yres=0.0,nrm,rwork[1],done=1.0;
255: PetscBLASInt info,i,it,N,one=1,zero=0;
256: PetscBool converged=PETSC_FALSE;
257: unsigned int ftz;
259: PetscFunctionBegin;
260: N = n*n;
261: tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
262: PetscCall(SlepcSetFlushToZero(&ftz));
264: PetscCall(PetscMalloc4(N,&Yold,N,&Z,N,&Zold,N,&M));
266: /* scale: use Z as workspace to compute nrm = ||A-I||_F */
267: PetscCall(PetscArraycpy(Z,A,N));
268: for (i=0;i<n;i++) Z[i+i*ld] -= 1.0;
269: nrm = LAPACKlange_("fro",&n,&n,Z,&n,rwork);
270: sqrtnrm = PetscSqrtReal(nrm);
271: PetscCallBLAS("LAPACKlascl",LAPACKlascl_("G",&zero,&zero,&nrm,&done,&N,&one,A,&N,&info));
272: SlepcCheckLapackInfo("lascl",info);
273: tol *= nrm;
274: PetscCall(PetscInfo(fn,"||I-A||_F = %g, new tol: %g\n",(double)nrm,(double)tol));
275: PetscCall(PetscLogFlops(2.0*n*n));
277: /* Z = I */
278: PetscCall(PetscArrayzero(Z,N));
279: for (i=0;i<n;i++) Z[i+i*ld] = 1.0;
281: for (it=0;it<NSMAXIT && !converged;it++) {
282: /* Yold = Y, Zold = Z */
283: PetscCall(PetscArraycpy(Yold,Y,N));
284: PetscCall(PetscArraycpy(Zold,Z,N));
286: /* M = (3*I-Zold*Yold) */
287: PetscCall(PetscArrayzero(M,N));
288: for (i=0;i<n;i++) M[i+i*ld] = sthree;
289: PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&smone,Zold,&ld,Yold,&ld,&sone,M,&ld));
291: /* Y = (1/2)*Yold*M, Z = (1/2)*M*Zold */
292: PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&spfive,Yold,&ld,M,&ld,&szero,Y,&ld));
293: PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&spfive,M,&ld,Zold,&ld,&szero,Z,&ld));
295: /* reldiff = norm(Y-Yold,'fro')/norm(Y,'fro') */
296: PetscCallBLAS("BLASaxpy",BLASaxpy_(&N,&smone,Y,&one,Yold,&one));
297: Yres = LAPACKlange_("fro",&n,&n,Yold,&n,rwork);
298: PetscCheck(!PetscIsNanReal(Yres),PETSC_COMM_SELF,PETSC_ERR_FP,"The computed norm is not-a-number");
299: if (Yres<=tol) converged = PETSC_TRUE;
300: PetscCall(PetscInfo(fn,"it: %" PetscBLASInt_FMT " res: %g\n",it,(double)Yres));
302: PetscCall(PetscLogFlops(6.0*n*n*n+2.0*n*n));
303: }
305: PetscCheck(Yres<=tol,PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations",NSMAXIT);
307: /* undo scaling */
308: if (inv) {
309: PetscCall(PetscArraycpy(A,Z,N));
310: PetscCallBLAS("LAPACKlascl",LAPACKlascl_("G",&zero,&zero,&sqrtnrm,&done,&N,&one,A,&N,&info));
311: } else PetscCallBLAS("LAPACKlascl",LAPACKlascl_("G",&zero,&zero,&done,&sqrtnrm,&N,&one,A,&N,&info));
312: SlepcCheckLapackInfo("lascl",info);
314: PetscCall(PetscFree4(Yold,Z,Zold,M));
315: PetscCall(SlepcResetFlushToZero(&ftz));
316: PetscFunctionReturn(PETSC_SUCCESS);
317: }
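#if 0
/* Illustrative sketch, not part of the library build: the scalar analogue of the
   coupled Newton-Schulz iteration above. For a real a with |1-a| < 1, the pair
   y_{k+1} = y_k*(3-z_k*y_k)/2, z_{k+1} = (3-z_k*y_k)*z_k/2, started from y_0 = a,
   z_0 = 1, converges quadratically to sqrt(a) and 1/sqrt(a). */
#include <stdio.h>
#include <math.h>

int main(void)
{
  double a = 0.7, y = a, z = 1.0;
  int    k;

  for (k=0;k<6;k++) {
    double m = 3.0 - z*y;                   /* scalar counterpart of M = 3I - Z*Y */
    double ynew = 0.5*y*m, znew = 0.5*m*z;
    y = ynew; z = znew;
    printf("k=%d  y=%.15f  z=%.15f\n",k,y,z);
  }
  printf("sqrt(a)=%.15f  1/sqrt(a)=%.15f\n",sqrt(a),1.0/sqrt(a));
  return 0;
}
#endif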
319: #if defined(PETSC_HAVE_CUDA)
320: #include "../src/sys/classes/fn/impls/cuda/fnutilcuda.h"
321: #include <slepccupmblas.h>
323: /*
324: * Matrix square root by Newton-Schulz iteration. CUDA version.
325: * Computes the principal square root of the matrix A using the
326: * Newton-Schulz iteration. A is overwritten with sqrtm(A) or inv(sqrtm(A)) depending on flag inv.
327: */
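/*
   Same algorithm as the host FNSqrtmNewtonSchulz above, with all matrices kept in
   device memory: the GEMMs, AXPYs and Frobenius norms are performed with cuBLAS,
   and small helper kernels from fnutilcuda.h (set_diagonal) build the identity
   terms. The scaling and convergence logic are unchanged.
*/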
328: PetscErrorCode FNSqrtmNewtonSchulz_CUDA(FN fn,PetscBLASInt n,PetscScalar *d_A,PetscBLASInt ld,PetscBool inv)
329: {
330: PetscScalar *d_Yold,*d_Z,*d_Zold,*d_M,alpha;
331: PetscReal nrm,sqrtnrm,tol,Yres=0.0;
332: const PetscScalar szero=0.0,sone=1.0,smone=-1.0,spfive=0.5,sthree=3.0;
333: PetscInt it;
334: PetscBLASInt N;
335: const PetscBLASInt one=1;
336: PetscBool converged=PETSC_FALSE;
337: cublasHandle_t cublasv2handle;
339: PetscFunctionBegin;
340: PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA)); /* For CUDA event timers */
341: PetscCall(PetscCUBLASGetHandle(&cublasv2handle));
342: N = n*n;
343: tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
345: PetscCallCUDA(cudaMalloc((void **)&d_Yold,sizeof(PetscScalar)*N));
346: PetscCallCUDA(cudaMalloc((void **)&d_Z,sizeof(PetscScalar)*N));
347: PetscCallCUDA(cudaMalloc((void **)&d_Zold,sizeof(PetscScalar)*N));
348: PetscCallCUDA(cudaMalloc((void **)&d_M,sizeof(PetscScalar)*N));
350: PetscCall(PetscLogGpuTimeBegin());
352: /* Z = I; */
353: PetscCallCUDA(cudaMemset(d_Z,0,sizeof(PetscScalar)*N));
354: PetscCall(set_diagonal(n,d_Z,ld,sone));
356: /* scale: use Z as workspace to compute nrm = ||I-A||_F */
357: PetscCallCUBLAS(cublasXaxpy(cublasv2handle,N,&smone,d_A,one,d_Z,one));
358: PetscCallCUBLAS(cublasXnrm2(cublasv2handle,N,d_Z,one,&nrm));
359: sqrtnrm = PetscSqrtReal(nrm);
360: alpha = 1.0/nrm;
361: PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&alpha,d_A,one));
362: tol *= nrm;
363: PetscCall(PetscInfo(fn,"||I-A||_F = %g, new tol: %g\n",(double)nrm,(double)tol));
364: PetscCall(PetscLogGpuFlops(2.0*n*n));
366: /* Z = I; */
367: PetscCallCUDA(cudaMemset(d_Z,0,sizeof(PetscScalar)*N));
368: PetscCall(set_diagonal(n,d_Z,ld,sone));
370: for (it=0;it<NSMAXIT && !converged;it++) {
371: /* Yold = Y, Zold = Z */
372: PetscCallCUDA(cudaMemcpy(d_Yold,d_A,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
373: PetscCallCUDA(cudaMemcpy(d_Zold,d_Z,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
375: /* M = (3*I - Zold*Yold) */
376: PetscCallCUDA(cudaMemset(d_M,0,sizeof(PetscScalar)*N));
377: PetscCall(set_diagonal(n,d_M,ld,sthree));
378: PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&smone,d_Zold,ld,d_Yold,ld,&sone,d_M,ld));
380: /* Y = (1/2) * Yold * M, Z = (1/2) * M * Zold */
381: PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&spfive,d_Yold,ld,d_M,ld,&szero,d_A,ld));
382: PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&spfive,d_M,ld,d_Zold,ld,&szero,d_Z,ld));
384: /* reldiff = norm(Y-Yold,'fro')/norm(Y,'fro') */
385: PetscCallCUBLAS(cublasXaxpy(cublasv2handle,N,&smone,d_A,one,d_Yold,one));
386: PetscCallCUBLAS(cublasXnrm2(cublasv2handle,N,d_Yold,one,&Yres));
387: PetscCheck(!PetscIsNanReal(Yres),PETSC_COMM_SELF,PETSC_ERR_FP,"The computed norm is not-a-number");
388: if (Yres<=tol) converged = PETSC_TRUE;
389: PetscCall(PetscInfo(fn,"it: %" PetscInt_FMT " res: %g\n",it,(double)Yres));
391: PetscCall(PetscLogGpuFlops(6.0*n*n*n+2.0*n*n));
392: }
394: PetscCheck(Yres<=tol,PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations", NSMAXIT);
396: /* undo scaling */
397: if (inv) {
398: alpha = 1.0/sqrtnrm;
399: PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&alpha,d_Z,one));
400: PetscCallCUDA(cudaMemcpy(d_A,d_Z,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
401: } else {
402: alpha = sqrtnrm;
403: PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&alpha,d_A,one));
404: }
406: PetscCall(PetscLogGpuTimeEnd());
407: PetscCallCUDA(cudaFree(d_Yold));
408: PetscCallCUDA(cudaFree(d_Z));
409: PetscCallCUDA(cudaFree(d_Zold));
410: PetscCallCUDA(cudaFree(d_M));
411: PetscFunctionReturn(PETSC_SUCCESS);
412: }
414: #if defined(PETSC_HAVE_MAGMA)
415: #include <slepcmagma.h>
417: /*
418: * Matrix square root by product form of Denman-Beavers iteration. CUDA version.
419: * Computes the principal square root of the matrix T using the product form
420: * of the Denman-Beavers iteration. T is overwritten with sqrtm(T) or inv(sqrtm(T)) depending on flag inv.
421: */
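/*
   GPU variant of FNSqrtmDenmanBeavers above: the LU factorization and explicit
   inversion of M_k are done with MAGMA (magma_xgetrf_gpu, magma_xgetri_gpu), the
   remaining updates with cuBLAS, and the diagonal shifts with small kernels from
   fnutilcuda.h. The determinantal scaling and the stopping criterion match the
   host version.
*/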
422: PetscErrorCode FNSqrtmDenmanBeavers_CUDAm(FN fn,PetscBLASInt n,PetscScalar *d_T,PetscBLASInt ld,PetscBool inv)
423: {
424: PetscScalar *d_Told,*d_M,*d_invM,*d_work,prod,szero=0.0,sone=1.0,smone=-1.0,spfive=0.5,sneg_pfive=-0.5,sp25=0.25,alpha;
425: PetscReal tol,Mres=0.0,detM,g,reldiff,fnormdiff,fnormT;
426: PetscInt it,lwork,nb;
427: PetscBLASInt N,one=1,*piv=NULL;
428: PetscBool converged=PETSC_FALSE,scale;
429: cublasHandle_t cublasv2handle;
431: PetscFunctionBegin;
432: PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA)); /* For CUDA event timers */
433: PetscCall(PetscCUBLASGetHandle(&cublasv2handle));
434: PetscCall(SlepcMagmaInit());
435: N = n*n;
436: scale = PetscDefined(USE_REAL_SINGLE)? PETSC_FALSE: PETSC_TRUE;
437: tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
439: /* query work size */
440: nb = magma_get_xgetri_nb(n);
441: lwork = nb*n;
442: PetscCall(PetscMalloc1(n,&piv));
443: PetscCallCUDA(cudaMalloc((void **)&d_work,sizeof(PetscScalar)*lwork));
444: PetscCallCUDA(cudaMalloc((void **)&d_Told,sizeof(PetscScalar)*N));
445: PetscCallCUDA(cudaMalloc((void **)&d_M,sizeof(PetscScalar)*N));
446: PetscCallCUDA(cudaMalloc((void **)&d_invM,sizeof(PetscScalar)*N));
448: PetscCall(PetscLogGpuTimeBegin());
449: PetscCallCUDA(cudaMemcpy(d_M,d_T,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
450: if (inv) { /* start recurrence with I instead of A */
451: PetscCallCUDA(cudaMemset(d_T,0,sizeof(PetscScalar)*N));
452: PetscCall(set_diagonal(n,d_T,ld,1.0));
453: }
455: for (it=0;it<DBMAXIT && !converged;it++) {
457: if (scale) { /* g = (abs(det(M)))^(-1/(2*n)); */
458: PetscCallCUDA(cudaMemcpy(d_invM,d_M,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
459: PetscCallMAGMA(magma_xgetrf_gpu,n,n,d_invM,ld,piv);
460: PetscCall(mult_diagonal(n,d_invM,ld,&prod));
461: detM = PetscAbsScalar(prod);
462: g = (detM>PETSC_MAX_REAL)? 0.5: PetscPowReal(detM,-1.0/(2.0*n));
463: alpha = g;
464: PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&alpha,d_T,one));
465: alpha = g*g;
466: PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&alpha,d_M,one));
467: PetscCall(PetscLogGpuFlops(2.0*n*n*n/3.0+2.0*n*n));
468: }
470: PetscCallCUDA(cudaMemcpy(d_Told,d_T,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
471: PetscCallCUDA(cudaMemcpy(d_invM,d_M,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
473: PetscCallMAGMA(magma_xgetrf_gpu,n,n,d_invM,ld,piv);
474: PetscCallMAGMA(magma_xgetri_gpu,n,d_invM,ld,piv,d_work,lwork);
475: PetscCall(PetscLogGpuFlops(2.0*n*n*n/3.0+4.0*n*n*n/3.0));
477: PetscCall(shift_diagonal(n,d_invM,ld,sone));
478: PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&spfive,d_Told,ld,d_invM,ld,&szero,d_T,ld));
479: PetscCall(shift_diagonal(n,d_invM,ld,smone));
481: PetscCallCUBLAS(cublasXaxpy(cublasv2handle,N,&sone,d_invM,one,d_M,one));
482: PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&sp25,d_M,one));
483: PetscCall(shift_diagonal(n,d_M,ld,sneg_pfive));
484: PetscCall(PetscLogGpuFlops(2.0*n*n*n+2.0*n*n));
486: PetscCallCUBLAS(cublasXnrm2(cublasv2handle,N,d_M,one,&Mres));
487: PetscCall(shift_diagonal(n,d_M,ld,sone));
489: if (scale) {
490: /* reldiff = norm(T - Told,'fro')/norm(T,'fro'); */
491: PetscCallCUBLAS(cublasXaxpy(cublasv2handle,N,&smone,d_T,one,d_Told,one));
492: PetscCallCUBLAS(cublasXnrm2(cublasv2handle,N,d_Told,one,&fnormdiff));
493: PetscCallCUBLAS(cublasXnrm2(cublasv2handle,N,d_T,one,&fnormT));
494: PetscCall(PetscLogGpuFlops(7.0*n*n));
495: reldiff = fnormdiff/fnormT;
496: PetscCall(PetscInfo(fn,"it: %" PetscInt_FMT " reldiff: %g scale: %g tol*scale: %g\n",it,(double)reldiff,(double)g,(double)(tol*g)));
497: if (reldiff<1e-2) scale = PETSC_FALSE; /* Switch to no scaling. */
498: }
500: PetscCall(PetscInfo(fn,"it: %" PetscInt_FMT " Mres: %g\n",it,(double)Mres));
501: if (Mres<=tol) converged = PETSC_TRUE;
502: }
504: PetscCheck(Mres<=tol,PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations", DBMAXIT);
505: PetscCall(PetscLogGpuTimeEnd());
506: PetscCall(PetscFree(piv));
507: PetscCallCUDA(cudaFree(d_work));
508: PetscCallCUDA(cudaFree(d_Told));
509: PetscCallCUDA(cudaFree(d_M));
510: PetscCallCUDA(cudaFree(d_invM));
511: PetscFunctionReturn(PETSC_SUCCESS);
512: }
513: #endif /* PETSC_HAVE_MAGMA */
515: #endif /* PETSC_HAVE_CUDA */
517: #define ITMAX 5
519: /*
520: Estimate norm(A^m,1) by the block 1-norm power method (required workspace is 11*n scalars)
521: */
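/*
   In outline: this is a block 1-norm estimator in the style of Higham and Tisseur
   (SIAM J. Matrix Anal. Appl., 2000) with t=2 test columns, applied to A^m without
   forming the power explicitly: Y = A^m*X and Z = (A^T)^m*S (conjugate transpose
   in the complex case) are obtained by m repeated GEMMs. The workspace holds X, Y,
   Z, S, S_old (2*n scalars each) plus n reals for zvals, i.e. the 11*n scalars
   mentioned above.
*/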
522: static PetscErrorCode SlepcNormEst1(PetscBLASInt n,PetscScalar *A,PetscInt m,PetscScalar *work,PetscRandom rand,PetscReal *nrm)
523: {
524: PetscScalar *X,*Y,*Z,*S,*S_old,*aux,val,sone=1.0,szero=0.0;
525: PetscReal est=0.0,est_old,vals[2]={0.0,0.0},*zvals,maxzval[2],raux;
526: PetscBLASInt i,j,t=2,it=0,ind[2],est_j=0,m1;
528: PetscFunctionBegin;
529: X = work;
530: Y = work + 2*n;
531: Z = work + 4*n;
532: S = work + 6*n;
533: S_old = work + 8*n;
534: zvals = (PetscReal*)(work + 10*n);
536: for (i=0;i<n;i++) { /* X has columns of unit 1-norm */
537: X[i] = 1.0/n;
538: PetscCall(PetscRandomGetValue(rand,&val));
539: if (PetscRealPart(val) < 0.5) X[i+n] = -1.0/n;
540: else X[i+n] = 1.0/n;
541: }
542: for (i=0;i<t*n;i++) S[i] = 0.0;
543: ind[0] = 0; ind[1] = 0;
544: est_old = 0;
545: while (1) {
546: it++;
547: for (j=0;j<m;j++) { /* Y = A^m*X */
548: PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&t,&n,&sone,A,&n,X,&n,&szero,Y,&n));
549: if (j<m-1) SlepcSwap(X,Y,aux);
550: }
551: for (j=0;j<t;j++) { /* vals[j] = norm(Y(:,j),1) */
552: vals[j] = 0.0;
553: for (i=0;i<n;i++) vals[j] += PetscAbsScalar(Y[i+j*n]);
554: }
555: if (vals[0]<vals[1]) {
556: SlepcSwap(vals[0],vals[1],raux);
557: m1 = 1;
558: } else m1 = 0;
559: est = vals[0];
560: if (est>est_old || it==2) est_j = ind[m1];
561: if (it>=2 && est<=est_old) {
562: est = est_old;
563: break;
564: }
565: est_old = est;
566: if (it>ITMAX) break;
567: SlepcSwap(S,S_old,aux);
568: for (i=0;i<t*n;i++) { /* S = sign(Y) */
569: S[i] = (PetscRealPart(Y[i]) < 0.0)? -1.0: 1.0;
570: }
571: for (j=0;j<m;j++) { /* Z = (A^T)^m*S */
572: PetscCallBLAS("BLASgemm",BLASgemm_("C","N",&n,&t,&n,&sone,A,&n,S,&n,&szero,Z,&n));
573: if (j<m-1) SlepcSwap(S,Z,aux);
574: }
575: maxzval[0] = -1; maxzval[1] = -1;
576: ind[0] = 0; ind[1] = 0;
577: for (i=0;i<n;i++) { /* zvals[i] = norm(Z(i,:),inf) */
578: zvals[i] = PetscMax(PetscAbsScalar(Z[i+0*n]),PetscAbsScalar(Z[i+1*n]));
579: if (zvals[i]>maxzval[0]) {
580: maxzval[0] = zvals[i];
581: ind[0] = i;
582: } else if (zvals[i]>maxzval[1]) {
583: maxzval[1] = zvals[i];
584: ind[1] = i;
585: }
586: }
587: if (it>=2 && maxzval[0]==zvals[est_j]) break;
588: for (i=0;i<t*n;i++) X[i] = 0.0;
589: for (j=0;j<t;j++) X[ind[j]+j*n] = 1.0;
590: }
591: *nrm = est;
592: /* Flop count is roughly (it * 2*m * t * gemv) = 4*it*m*t*n*n */
593: PetscCall(PetscLogFlops(4.0*it*m*t*n*n));
594: PetscFunctionReturn(PETSC_SUCCESS);
595: }
597: #define SMALLN 100
599: /*
600: Estimate norm(A^m,1) (required workspace is 2*n*n scalars)
601: */
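/*
   In outline: for small n (n < SMALLN) the power A^m is formed explicitly and its
   1-norm taken with LAPACKlange. Otherwise, if all entries of A are real and
   nonnegative, the exact value follows from norm(A^m,1) = norm((A^T)^m*e,inf)
   with e the vector of ones, computed with m GEMVs; in the general case the
   randomized estimator SlepcNormEst1 above is used.
*/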
602: PetscErrorCode SlepcNormAm(PetscBLASInt n,PetscScalar *A,PetscInt m,PetscScalar *work,PetscRandom rand,PetscReal *nrm)
603: {
604: PetscScalar *v=work,*w=work+n*n,*aux,sone=1.0,szero=0.0;
605: PetscReal rwork[1],tmp;
606: PetscBLASInt i,j,one=1;
607: PetscBool isrealpos=PETSC_TRUE;
609: PetscFunctionBegin;
610: if (n<SMALLN) { /* compute matrix power explicitly */
611: if (m==1) {
612: *nrm = LAPACKlange_("O",&n,&n,A,&n,rwork);
613: PetscCall(PetscLogFlops(1.0*n*n));
614: } else { /* m>=2 */
615: PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,A,&n,A,&n,&szero,v,&n));
616: for (j=0;j<m-2;j++) {
617: PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,A,&n,v,&n,&szero,w,&n));
618: SlepcSwap(v,w,aux);
619: }
620: *nrm = LAPACKlange_("O",&n,&n,v,&n,rwork);
621: PetscCall(PetscLogFlops(2.0*n*n*n*(m-1)+1.0*n*n));
622: }
623: } else {
624: for (i=0;i<n;i++)
625: for (j=0;j<n;j++)
626: #if defined(PETSC_USE_COMPLEX)
627: if (PetscRealPart(A[i+j*n])<0.0 || PetscImaginaryPart(A[i+j*n])!=0.0) { isrealpos = PETSC_FALSE; break; }
628: #else
629: if (A[i+j*n]<0.0) { isrealpos = PETSC_FALSE; break; }
630: #endif
631: if (isrealpos) { /* for positive matrices only */
632: for (i=0;i<n;i++) v[i] = 1.0;
633: for (j=0;j<m;j++) { /* w = A'*v */
634: PetscCallBLAS("BLASgemv",BLASgemv_("C",&n,&n,&sone,A,&n,v,&one,&szero,w,&one));
635: SlepcSwap(v,w,aux);
636: }
637: PetscCall(PetscLogFlops(2.0*n*n*m));
638: *nrm = 0.0;
639: for (i=0;i<n;i++) if ((tmp = PetscAbsScalar(v[i])) > *nrm) *nrm = tmp; /* norm(v,inf) */
640: } else PetscCall(SlepcNormEst1(n,A,m,work,rand,nrm));
641: }
642: PetscFunctionReturn(PETSC_SUCCESS);
643: }