Line data Source code
1 : /*
2 : - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
3 : SLEPc - Scalable Library for Eigenvalue Problem Computations
4 : Copyright (c) 2002-, Universitat Politecnica de Valencia, Spain
5 :
6 : This file is part of SLEPc.
7 : SLEPc is distributed under a 2-clause BSD license (see LICENSE).
8 : - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
9 : */
10 : /*
11 : Utility subroutines common to several impls
12 : */
13 :
14 : #include <slepc/private/fnimpl.h> /*I "slepcfn.h" I*/
15 : #include <slepcblaslapack.h>
16 :
17 : /*
18 : Compute the square root of an upper quasi-triangular matrix T,
19 : using Higham's algorithm (LAA 88, 1987). T is overwritten with sqrtm(T).
20 : */
21 126 : static PetscErrorCode SlepcMatDenseSqrt(PetscBLASInt n,PetscScalar *T,PetscBLASInt ld)
22 : {
23 126 : PetscScalar one=1.0,mone=-1.0;
24 126 : PetscReal scal;
25 126 : PetscBLASInt i,j,si,sj,r,ione=1,info;
26 : #if !defined(PETSC_USE_COMPLEX)
27 : PetscReal alpha,theta,mu,mu2;
28 : #endif
29 :
30 126 : PetscFunctionBegin;
31 4955 : for (j=0;j<n;j++) {
32 : #if defined(PETSC_USE_COMPLEX)
33 4829 : sj = 1;
34 4829 : T[j+j*ld] = PetscSqrtScalar(T[j+j*ld]);
35 : #else
36 : sj = (j==n-1 || T[j+1+j*ld] == 0.0)? 1: 2;
37 : if (sj==1) {
38 : PetscCheck(T[j+j*ld]>=0.0,PETSC_COMM_SELF,PETSC_ERR_USER_INPUT,"Matrix has a real negative eigenvalue, no real primary square root exists");
39 : T[j+j*ld] = PetscSqrtReal(T[j+j*ld]);
40 : } else {
41 : /* square root of 2x2 block */
42 : theta = (T[j+j*ld]+T[j+1+(j+1)*ld])/2.0;
43 : mu = (T[j+j*ld]-T[j+1+(j+1)*ld])/2.0;
44 : mu2 = -mu*mu-T[j+1+j*ld]*T[j+(j+1)*ld];
45 : mu = PetscSqrtReal(mu2);
46 : if (theta>0.0) alpha = PetscSqrtReal((theta+PetscSqrtReal(theta*theta+mu2))/2.0);
47 : else alpha = mu/PetscSqrtReal(2.0*(-theta+PetscSqrtReal(theta*theta+mu2)));
48 : T[j+j*ld] /= 2.0*alpha;
49 : T[j+1+(j+1)*ld] /= 2.0*alpha;
50 : T[j+(j+1)*ld] /= 2.0*alpha;
51 : T[j+1+j*ld] /= 2.0*alpha;
52 : T[j+j*ld] += alpha-theta/(2.0*alpha);
53 : T[j+1+(j+1)*ld] += alpha-theta/(2.0*alpha);
54 : }
55 : #endif
56 131369 : for (i=j-1;i>=0;i--) {
57 : #if defined(PETSC_USE_COMPLEX)
58 126540 : si = 1;
59 : #else
60 : si = (i==0 || T[i+(i-1)*ld] == 0.0)? 1: 2;
61 : if (si==2) i--;
62 : #endif
63 : /* solve Sylvester equation of order si x sj */
64 126540 : r = j-i-si;
65 126540 : if (r) PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&si,&sj,&r,&mone,T+i+(i+si)*ld,&ld,T+i+si+j*ld,&ld,&one,T+i+j*ld,&ld));
66 126540 : PetscCallBLAS("LAPACKtrsyl",LAPACKtrsyl_("N","N",&ione,&si,&sj,T+i+i*ld,&ld,T+j+j*ld,&ld,T+i+j*ld,&ld,&scal,&info));
67 126540 : SlepcCheckLapackInfo("trsyl",info);
68 126540 : PetscCheck(scal==1.0,PETSC_COMM_SELF,PETSC_ERR_SUP,"Current implementation cannot handle scale factor %g",(double)scal);
69 : }
70 4829 : if (sj==2) j++;
71 : }
72 126 : PetscFunctionReturn(PETSC_SUCCESS);
73 : }
74 :
75 : #define BLOCKSIZE 64
76 :
77 : /*
78 : Schur method for the square root of an upper quasi-triangular matrix T.
79 : T is overwritten with sqrtm(T).
80 : If firstonly then only the first column of T will contain relevant values.
81 : */
82 73 : PetscErrorCode FNSqrtmSchur(FN fn,PetscBLASInt n,PetscScalar *T,PetscBLASInt ld,PetscBool firstonly)
83 : {
84 73 : PetscBLASInt i,j,k,r,ione=1,sdim,lwork,*s,*p,info,bs=BLOCKSIZE;
85 73 : PetscScalar *wr,*W,*Q,*work,one=1.0,zero=0.0,mone=-1.0;
86 73 : PetscInt m,nblk;
87 73 : PetscReal scal;
88 : #if defined(PETSC_USE_COMPLEX)
89 73 : PetscReal *rwork;
90 : #else
91 : PetscReal *wi;
92 : #endif
93 :
94 73 : PetscFunctionBegin;
95 73 : m = n;
96 73 : nblk = (m+bs-1)/bs;
97 73 : lwork = 5*n;
98 73 : k = firstonly? 1: n;
99 :
100 : /* compute Schur decomposition A*Q = Q*T */
101 : #if !defined(PETSC_USE_COMPLEX)
102 : PetscCall(PetscMalloc7(m,&wr,m,&wi,m*k,&W,m*m,&Q,lwork,&work,nblk,&s,nblk,&p));
103 : PetscCallBLAS("LAPACKgees",LAPACKgees_("V","N",NULL,&n,T,&ld,&sdim,wr,wi,Q,&ld,work,&lwork,NULL,&info));
104 : #else
105 73 : PetscCall(PetscMalloc7(m,&wr,m,&rwork,m*k,&W,m*m,&Q,lwork,&work,nblk,&s,nblk,&p));
106 73 : PetscCallBLAS("LAPACKgees",LAPACKgees_("V","N",NULL,&n,T,&ld,&sdim,wr,Q,&ld,work,&lwork,rwork,NULL,&info));
107 : #endif
108 73 : SlepcCheckLapackInfo("gees",info);
109 :
110 : /* determine block sizes and positions, to avoid cutting 2x2 blocks */
111 73 : j = 0;
112 73 : p[j] = 0;
113 126 : do {
114 126 : s[j] = PetscMin(bs,n-p[j]);
115 : #if !defined(PETSC_USE_COMPLEX)
116 : if (p[j]+s[j]!=n && T[p[j]+s[j]+(p[j]+s[j]-1)*ld]!=0.0) s[j]++;
117 : #endif
118 126 : if (p[j]+s[j]==n) break;
119 53 : j++;
120 53 : p[j] = p[j-1]+s[j-1];
121 53 : } while (1);
122 199 : nblk = j+1;
123 :
124 199 : for (j=0;j<nblk;j++) {
125 : /* evaluate f(T_jj) */
126 126 : PetscCall(SlepcMatDenseSqrt(s[j],T+p[j]+p[j]*ld,ld));
127 179 : for (i=j-1;i>=0;i--) {
128 : /* solve Sylvester equation for block (i,j) */
129 53 : r = p[j]-p[i]-s[i];
130 53 : if (r) PetscCallBLAS("BLASgemm",BLASgemm_("N","N",s+i,s+j,&r,&mone,T+p[i]+(p[i]+s[i])*ld,&ld,T+p[i]+s[i]+p[j]*ld,&ld,&one,T+p[i]+p[j]*ld,&ld));
131 53 : PetscCallBLAS("LAPACKtrsyl",LAPACKtrsyl_("N","N",&ione,s+i,s+j,T+p[i]+p[i]*ld,&ld,T+p[j]+p[j]*ld,&ld,T+p[i]+p[j]*ld,&ld,&scal,&info));
132 53 : SlepcCheckLapackInfo("trsyl",info);
133 53 : PetscCheck(scal==1.0,PETSC_COMM_SELF,PETSC_ERR_SUP,"Current implementation cannot handle scale factor %g",(double)scal);
134 : }
135 : }
136 :
137 : /* backtransform B = Q*T*Q' */
138 73 : PetscCallBLAS("BLASgemm",BLASgemm_("N","C",&n,&k,&n,&one,T,&ld,Q,&ld,&zero,W,&ld));
139 73 : PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&k,&n,&one,Q,&ld,W,&ld,&zero,T,&ld));
140 :
141 : /* flop count: Schur decomposition, triangular square root, and backtransform */
142 73 : PetscCall(PetscLogFlops(25.0*n*n*n+n*n*n/3.0+4.0*n*n*k));
143 :
144 : #if !defined(PETSC_USE_COMPLEX)
145 : PetscCall(PetscFree7(wr,wi,W,Q,work,s,p));
146 : #else
147 73 : PetscCall(PetscFree7(wr,rwork,W,Q,work,s,p));
148 : #endif
149 73 : PetscFunctionReturn(PETSC_SUCCESS);
150 : }
151 :
152 : #define DBMAXIT 25
153 :
154 : /*
155 : Computes the principal square root of the matrix T using the product form
156 : of the Denman-Beavers iteration.
157 : T is overwritten with sqrtm(T) or inv(sqrtm(T)) depending on flag inv.
158 : */
159 24 : PetscErrorCode FNSqrtmDenmanBeavers(FN fn,PetscBLASInt n,PetscScalar *T,PetscBLASInt ld,PetscBool inv)
160 : {
161 24 : PetscScalar *Told,*M=NULL,*invM,*work,work1,prod,alpha;
162 24 : PetscScalar szero=0.0,sone=1.0,smone=-1.0,spfive=0.5,sp25=0.25;
163 24 : PetscReal tol,Mres=0.0,detM,g,reldiff,fnormdiff,fnormT,rwork[1];
164 24 : PetscBLASInt N,i,it,*piv=NULL,info,query=-1,lwork;
165 24 : const PetscBLASInt one=1;
166 24 : PetscBool converged=PETSC_FALSE,scale;
167 24 : unsigned int ftz;
168 :
169 24 : PetscFunctionBegin;
170 24 : N = n*n;
171 24 : tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
172 24 : scale = PetscDefined(USE_REAL_SINGLE)? PETSC_FALSE: PETSC_TRUE;
173 24 : PetscCall(SlepcSetFlushToZero(&ftz));
174 :
175 : /* query work size */
176 24 : PetscCallBLAS("LAPACKgetri",LAPACKgetri_(&n,M,&ld,piv,&work1,&query,&info));
177 24 : PetscCall(PetscBLASIntCast((PetscInt)PetscRealPart(work1),&lwork));
178 24 : PetscCall(PetscMalloc5(lwork,&work,n,&piv,n*n,&Told,n*n,&M,n*n,&invM));
179 24 : PetscCall(PetscArraycpy(M,T,n*n));
180 :
181 24 : if (inv) { /* start recurrence with I instead of A */
182 12 : PetscCall(PetscArrayzero(T,n*n));
183 132 : for (i=0;i<n;i++) T[i+i*ld] += 1.0;
184 : }
185 :
186 144 : for (it=0;it<DBMAXIT && !converged;it++) {
187 :
188 120 : if (scale) { /* g = (abs(det(M)))^(-1/(2*n)) */
189 72 : PetscCall(PetscArraycpy(invM,M,n*n));
190 72 : PetscCallBLAS("LAPACKgetrf",LAPACKgetrf_(&n,&n,invM,&ld,piv,&info));
191 72 : SlepcCheckLapackInfo("getrf",info);
192 72 : prod = invM[0];
193 3960 : for (i=1;i<n;i++) prod *= invM[i+i*ld];
194 72 : detM = PetscAbsScalar(prod);
195 72 : g = (detM>PETSC_MAX_REAL)? 0.5: PetscPowReal(detM,-1.0/(2.0*n));
196 72 : alpha = g;
197 72 : PetscCallBLAS("BLASscal",BLASscal_(&N,&alpha,T,&one));
198 72 : alpha = g*g;
199 72 : PetscCallBLAS("BLASscal",BLASscal_(&N,&alpha,M,&one));
200 72 : PetscCall(PetscLogFlops(2.0*n*n*n/3.0+2.0*n*n));
201 : }
202 :
203 120 : PetscCall(PetscArraycpy(Told,T,n*n));
204 120 : PetscCall(PetscArraycpy(invM,M,n*n));
205 :
206 120 : PetscCallBLAS("LAPACKgetrf",LAPACKgetrf_(&n,&n,invM,&ld,piv,&info));
207 120 : SlepcCheckLapackInfo("getrf",info);
208 120 : PetscCallBLAS("LAPACKgetri",LAPACKgetri_(&n,invM,&ld,piv,work,&lwork,&info));
209 120 : SlepcCheckLapackInfo("getri",info);
210 120 : PetscCall(PetscLogFlops(2.0*n*n*n/3.0+4.0*n*n*n/3.0));
211 :
212 7080 : for (i=0;i<n;i++) invM[i+i*ld] += 1.0;
213 120 : PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&spfive,Told,&ld,invM,&ld,&szero,T,&ld));
214 7080 : for (i=0;i<n;i++) invM[i+i*ld] -= 1.0;
215 :
216 120 : PetscCallBLAS("BLASaxpy",BLASaxpy_(&N,&sone,invM,&one,M,&one));
217 120 : PetscCallBLAS("BLASscal",BLASscal_(&N,&sp25,M,&one));
218 7080 : for (i=0;i<n;i++) M[i+i*ld] -= 0.5;
219 120 : PetscCall(PetscLogFlops(2.0*n*n*n+2.0*n*n));
220 :
221 120 : Mres = LAPACKlange_("F",&n,&n,M,&n,rwork);
222 7200 : for (i=0;i<n;i++) M[i+i*ld] += 1.0;
223 :
224 120 : if (scale) {
225 : /* reldiff = norm(T - Told,'fro')/norm(T,'fro') */
226 72 : PetscCallBLAS("BLASaxpy",BLASaxpy_(&N,&smone,T,&one,Told,&one));
227 72 : fnormdiff = LAPACKlange_("F",&n,&n,Told,&n,rwork);
228 72 : fnormT = LAPACKlange_("F",&n,&n,T,&n,rwork);
229 72 : PetscCall(PetscLogFlops(7.0*n*n));
230 72 : reldiff = fnormdiff/fnormT;
231 72 : PetscCall(PetscInfo(fn,"it: %" PetscBLASInt_FMT " reldiff: %g scale: %g tol*scale: %g\n",it,(double)reldiff,(double)g,(double)(tol*g)));
232 72 : if (reldiff<1e-2) scale = PETSC_FALSE; /* Switch off scaling */
233 : }
234 :
235 120 : if (Mres<=tol) converged = PETSC_TRUE;
236 : }
237 :
238 24 : PetscCheck(Mres<=tol,PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations",DBMAXIT);
239 24 : PetscCall(PetscFree5(work,piv,Told,M,invM));
240 24 : PetscCall(SlepcResetFlushToZero(&ftz));
241 24 : PetscFunctionReturn(PETSC_SUCCESS);
242 : }
243 :
244 : #define NSMAXIT 50
245 :
246 : /*
247 : Computes the principal square root of the matrix A using the Newton-Schulz iteration.
248 : T is overwritten with sqrtm(T) or inv(sqrtm(T)) depending on flag inv.
249 : */
250 24 : PetscErrorCode FNSqrtmNewtonSchulz(FN fn,PetscBLASInt n,PetscScalar *A,PetscBLASInt ld,PetscBool inv)
251 : {
252 24 : PetscScalar *Y=A,*Yold,*Z,*Zold,*M;
253 24 : PetscScalar szero=0.0,sone=1.0,smone=-1.0,spfive=0.5,sthree=3.0;
254 24 : PetscReal sqrtnrm,tol,Yres=0.0,nrm,rwork[1],done=1.0;
255 24 : PetscBLASInt info,i,it,N,one=1,zero=0;
256 24 : PetscBool converged=PETSC_FALSE;
257 24 : unsigned int ftz;
258 :
259 24 : PetscFunctionBegin;
260 24 : N = n*n;
261 24 : tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
262 24 : PetscCall(SlepcSetFlushToZero(&ftz));
263 :
264 24 : PetscCall(PetscMalloc4(N,&Yold,N,&Z,N,&Zold,N,&M));
265 :
266 : /* scale */
267 24 : PetscCall(PetscArraycpy(Z,A,N));
268 1344 : for (i=0;i<n;i++) Z[i+i*ld] -= 1.0;
269 24 : nrm = LAPACKlange_("fro",&n,&n,Z,&n,rwork);
270 24 : sqrtnrm = PetscSqrtReal(nrm);
271 24 : PetscCallBLAS("LAPACKlascl",LAPACKlascl_("G",&zero,&zero,&nrm,&done,&N,&one,A,&N,&info));
272 24 : SlepcCheckLapackInfo("lascl",info);
273 24 : tol *= nrm;
274 24 : PetscCall(PetscInfo(fn,"||I-A||_F = %g, new tol: %g\n",(double)nrm,(double)tol));
275 24 : PetscCall(PetscLogFlops(2.0*n*n));
276 :
277 : /* Z = I */
278 24 : PetscCall(PetscArrayzero(Z,N));
279 1344 : for (i=0;i<n;i++) Z[i+i*ld] = 1.0;
280 :
281 280 : for (it=0;it<NSMAXIT && !converged;it++) {
282 : /* Yold = Y, Zold = Z */
283 256 : PetscCall(PetscArraycpy(Yold,Y,N));
284 256 : PetscCall(PetscArraycpy(Zold,Z,N));
285 :
286 : /* M = (3*I-Zold*Yold) */
287 256 : PetscCall(PetscArrayzero(M,N));
288 16856 : for (i=0;i<n;i++) M[i+i*ld] = sthree;
289 256 : PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&smone,Zold,&ld,Yold,&ld,&sone,M,&ld));
290 :
291 : /* Y = (1/2)*Yold*M, Z = (1/2)*M*Zold */
292 256 : PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&spfive,Yold,&ld,M,&ld,&szero,Y,&ld));
293 256 : PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&spfive,M,&ld,Zold,&ld,&szero,Z,&ld));
294 :
295 : /* reldiff = norm(Y-Yold,'fro')/norm(Y,'fro') */
296 256 : PetscCallBLAS("BLASaxpy",BLASaxpy_(&N,&smone,Y,&one,Yold,&one));
297 256 : Yres = LAPACKlange_("fro",&n,&n,Yold,&n,rwork);
298 256 : PetscCheck(!PetscIsNanReal(Yres),PETSC_COMM_SELF,PETSC_ERR_FP,"The computed norm is not-a-number");
299 256 : if (Yres<=tol) converged = PETSC_TRUE;
300 256 : PetscCall(PetscInfo(fn,"it: %" PetscBLASInt_FMT " res: %g\n",it,(double)Yres));
301 :
302 256 : PetscCall(PetscLogFlops(6.0*n*n*n+2.0*n*n));
303 : }
304 :
305 24 : PetscCheck(Yres<=tol,PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations",NSMAXIT);
306 :
307 : /* undo scaling */
308 24 : if (inv) {
309 12 : PetscCall(PetscArraycpy(A,Z,N));
310 12 : PetscCallBLAS("LAPACKlascl",LAPACKlascl_("G",&zero,&zero,&sqrtnrm,&done,&N,&one,A,&N,&info));
311 12 : } else PetscCallBLAS("LAPACKlascl",LAPACKlascl_("G",&zero,&zero,&done,&sqrtnrm,&N,&one,A,&N,&info));
312 24 : SlepcCheckLapackInfo("lascl",info);
313 :
314 24 : PetscCall(PetscFree4(Yold,Z,Zold,M));
315 24 : PetscCall(SlepcResetFlushToZero(&ftz));
316 24 : PetscFunctionReturn(PETSC_SUCCESS);
317 : }
318 :
319 : #if defined(PETSC_HAVE_CUDA)
320 : #include "../src/sys/classes/fn/impls/cuda/fnutilcuda.h"
321 : #include <slepccupmblas.h>
322 :
323 : /*
324 : * Matrix square root by Newton-Schulz iteration. CUDA version.
325 : * Computes the principal square root of the matrix A using the
326 : * Newton-Schulz iteration. A is overwritten with sqrtm(A).
327 : */
328 : PetscErrorCode FNSqrtmNewtonSchulz_CUDA(FN fn,PetscBLASInt n,PetscScalar *d_A,PetscBLASInt ld,PetscBool inv)
329 : {
330 : PetscScalar *d_Yold,*d_Z,*d_Zold,*d_M,alpha;
331 : PetscReal nrm,sqrtnrm,tol,Yres=0.0;
332 : const PetscScalar szero=0.0,sone=1.0,smone=-1.0,spfive=0.5,sthree=3.0;
333 : PetscInt it;
334 : PetscBLASInt N;
335 : const PetscBLASInt one=1;
336 : PetscBool converged=PETSC_FALSE;
337 : cublasHandle_t cublasv2handle;
338 :
339 : PetscFunctionBegin;
340 : PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA)); /* For CUDA event timers */
341 : PetscCall(PetscCUBLASGetHandle(&cublasv2handle));
342 : N = n*n;
343 : tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
344 :
345 : PetscCallCUDA(cudaMalloc((void **)&d_Yold,sizeof(PetscScalar)*N));
346 : PetscCallCUDA(cudaMalloc((void **)&d_Z,sizeof(PetscScalar)*N));
347 : PetscCallCUDA(cudaMalloc((void **)&d_Zold,sizeof(PetscScalar)*N));
348 : PetscCallCUDA(cudaMalloc((void **)&d_M,sizeof(PetscScalar)*N));
349 :
350 : PetscCall(PetscLogGpuTimeBegin());
351 :
352 : /* Z = I; */
353 : PetscCallCUDA(cudaMemset(d_Z,0,sizeof(PetscScalar)*N));
354 : PetscCall(set_diagonal(n,d_Z,ld,sone));
355 :
356 : /* scale */
357 : PetscCallCUBLAS(cublasXaxpy(cublasv2handle,N,&smone,d_A,one,d_Z,one));
358 : PetscCallCUBLAS(cublasXnrm2(cublasv2handle,N,d_Z,one,&nrm));
359 : sqrtnrm = PetscSqrtReal(nrm);
360 : alpha = 1.0/nrm;
361 : PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&alpha,d_A,one));
362 : tol *= nrm;
363 : PetscCall(PetscInfo(fn,"||I-A||_F = %g, new tol: %g\n",(double)nrm,(double)tol));
364 : PetscCall(PetscLogGpuFlops(2.0*n*n));
365 :
366 : /* Z = I; */
367 : PetscCallCUDA(cudaMemset(d_Z,0,sizeof(PetscScalar)*N));
368 : PetscCall(set_diagonal(n,d_Z,ld,sone));
369 :
370 : for (it=0;it<NSMAXIT && !converged;it++) {
371 : /* Yold = Y, Zold = Z */
372 : PetscCallCUDA(cudaMemcpy(d_Yold,d_A,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
373 : PetscCallCUDA(cudaMemcpy(d_Zold,d_Z,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
374 :
375 : /* M = (3*I - Zold*Yold) */
376 : PetscCallCUDA(cudaMemset(d_M,0,sizeof(PetscScalar)*N));
377 : PetscCall(set_diagonal(n,d_M,ld,sthree));
378 : PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&smone,d_Zold,ld,d_Yold,ld,&sone,d_M,ld));
379 :
380 : /* Y = (1/2) * Yold * M, Z = (1/2) * M * Zold */
381 : PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&spfive,d_Yold,ld,d_M,ld,&szero,d_A,ld));
382 : PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&spfive,d_M,ld,d_Zold,ld,&szero,d_Z,ld));
383 :
384 : /* reldiff = norm(Y-Yold,'fro')/norm(Y,'fro') */
385 : PetscCallCUBLAS(cublasXaxpy(cublasv2handle,N,&smone,d_A,one,d_Yold,one));
386 : PetscCallCUBLAS(cublasXnrm2(cublasv2handle,N,d_Yold,one,&Yres));
387 : PetscCheck(!PetscIsNanReal(Yres),PETSC_COMM_SELF,PETSC_ERR_FP,"The computed norm is not-a-number");
388 : if (Yres<=tol) converged = PETSC_TRUE;
389 : PetscCall(PetscInfo(fn,"it: %" PetscInt_FMT " res: %g\n",it,(double)Yres));
390 :
391 : PetscCall(PetscLogGpuFlops(6.0*n*n*n+2.0*n*n));
392 : }
393 :
394 : PetscCheck(Yres<=tol,PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations", NSMAXIT);
395 :
396 : /* undo scaling */
397 : if (inv) {
398 : alpha = 1.0/sqrtnrm;
399 : PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&alpha,d_Z,one));
400 : PetscCallCUDA(cudaMemcpy(d_A,d_Z,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
401 : } else {
402 : alpha = sqrtnrm;
403 : PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&alpha,d_A,one));
404 : }
405 :
406 : PetscCall(PetscLogGpuTimeEnd());
407 : PetscCallCUDA(cudaFree(d_Yold));
408 : PetscCallCUDA(cudaFree(d_Z));
409 : PetscCallCUDA(cudaFree(d_Zold));
410 : PetscCallCUDA(cudaFree(d_M));
411 : PetscFunctionReturn(PETSC_SUCCESS);
412 : }
413 :
414 : #if defined(PETSC_HAVE_MAGMA)
415 : #include <slepcmagma.h>
416 :
417 : /*
418 : * Matrix square root by product form of Denman-Beavers iteration. CUDA version.
419 : * Computes the principal square root of the matrix T using the product form
420 : * of the Denman-Beavers iteration. T is overwritten with sqrtm(T).
421 : */
422 : PetscErrorCode FNSqrtmDenmanBeavers_CUDAm(FN fn,PetscBLASInt n,PetscScalar *d_T,PetscBLASInt ld,PetscBool inv)
423 : {
424 : PetscScalar *d_Told,*d_M,*d_invM,*d_work,prod,szero=0.0,sone=1.0,smone=-1.0,spfive=0.5,sneg_pfive=-0.5,sp25=0.25,alpha;
425 : PetscReal tol,Mres=0.0,detM,g,reldiff,fnormdiff,fnormT;
426 : PetscInt it,lwork,nb;
427 : PetscBLASInt N,one=1,*piv=NULL;
428 : PetscBool converged=PETSC_FALSE,scale;
429 : cublasHandle_t cublasv2handle;
430 :
431 : PetscFunctionBegin;
432 : PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA)); /* For CUDA event timers */
433 : PetscCall(PetscCUBLASGetHandle(&cublasv2handle));
434 : PetscCall(SlepcMagmaInit());
435 : N = n*n;
436 : scale = PetscDefined(USE_REAL_SINGLE)? PETSC_FALSE: PETSC_TRUE;
437 : tol = PetscSqrtReal((PetscReal)n)*PETSC_MACHINE_EPSILON/2;
438 :
439 : /* query work size */
440 : nb = magma_get_xgetri_nb(n);
441 : lwork = nb*n;
442 : PetscCall(PetscMalloc1(n,&piv));
443 : PetscCallCUDA(cudaMalloc((void **)&d_work,sizeof(PetscScalar)*lwork));
444 : PetscCallCUDA(cudaMalloc((void **)&d_Told,sizeof(PetscScalar)*N));
445 : PetscCallCUDA(cudaMalloc((void **)&d_M,sizeof(PetscScalar)*N));
446 : PetscCallCUDA(cudaMalloc((void **)&d_invM,sizeof(PetscScalar)*N));
447 :
448 : PetscCall(PetscLogGpuTimeBegin());
449 : PetscCallCUDA(cudaMemcpy(d_M,d_T,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
450 : if (inv) { /* start recurrence with I instead of A */
451 : PetscCallCUDA(cudaMemset(d_T,0,sizeof(PetscScalar)*N));
452 : PetscCall(set_diagonal(n,d_T,ld,1.0));
453 : }
454 :
455 : for (it=0;it<DBMAXIT && !converged;it++) {
456 :
457 : if (scale) { /* g = (abs(det(M)))^(-1/(2*n)); */
458 : PetscCallCUDA(cudaMemcpy(d_invM,d_M,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
459 : PetscCallMAGMA(magma_xgetrf_gpu,n,n,d_invM,ld,piv);
460 : PetscCall(mult_diagonal(n,d_invM,ld,&prod));
461 : detM = PetscAbsScalar(prod);
462 : g = (detM>PETSC_MAX_REAL)? 0.5: PetscPowReal(detM,-1.0/(2.0*n));
463 : alpha = g;
464 : PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&alpha,d_T,one));
465 : alpha = g*g;
466 : PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&alpha,d_M,one));
467 : PetscCall(PetscLogGpuFlops(2.0*n*n*n/3.0+2.0*n*n));
468 : }
469 :
470 : PetscCallCUDA(cudaMemcpy(d_Told,d_T,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
471 : PetscCallCUDA(cudaMemcpy(d_invM,d_M,sizeof(PetscScalar)*N,cudaMemcpyDeviceToDevice));
472 :
473 : PetscCallMAGMA(magma_xgetrf_gpu,n,n,d_invM,ld,piv);
474 : PetscCallMAGMA(magma_xgetri_gpu,n,d_invM,ld,piv,d_work,lwork);
475 : PetscCall(PetscLogGpuFlops(2.0*n*n*n/3.0+4.0*n*n*n/3.0));
476 :
477 : PetscCall(shift_diagonal(n,d_invM,ld,sone));
478 : PetscCallCUBLAS(cublasXgemm(cublasv2handle,CUBLAS_OP_N,CUBLAS_OP_N,n,n,n,&spfive,d_Told,ld,d_invM,ld,&szero,d_T,ld));
479 : PetscCall(shift_diagonal(n,d_invM,ld,smone));
480 :
481 : PetscCallCUBLAS(cublasXaxpy(cublasv2handle,N,&sone,d_invM,one,d_M,one));
482 : PetscCallCUBLAS(cublasXscal(cublasv2handle,N,&sp25,d_M,one));
483 : PetscCall(shift_diagonal(n,d_M,ld,sneg_pfive));
484 : PetscCall(PetscLogGpuFlops(2.0*n*n*n+2.0*n*n));
485 :
486 : PetscCallCUBLAS(cublasXnrm2(cublasv2handle,N,d_M,one,&Mres));
487 : PetscCall(shift_diagonal(n,d_M,ld,sone));
488 :
489 : if (scale) {
490 : /* reldiff = norm(T - Told,'fro')/norm(T,'fro'); */
491 : PetscCallCUBLAS(cublasXaxpy(cublasv2handle,N,&smone,d_T,one,d_Told,one));
492 : PetscCallCUBLAS(cublasXnrm2(cublasv2handle,N,d_Told,one,&fnormdiff));
493 : PetscCallCUBLAS(cublasXnrm2(cublasv2handle,N,d_T,one,&fnormT));
494 : PetscCall(PetscLogGpuFlops(7.0*n*n));
495 : reldiff = fnormdiff/fnormT;
496 : PetscCall(PetscInfo(fn,"it: %" PetscInt_FMT " reldiff: %g scale: %g tol*scale: %g\n",it,(double)reldiff,(double)g,(double)tol*g));
497 : if (reldiff<1e-2) scale = PETSC_FALSE; /* Switch to no scaling. */
498 : }
499 :
500 : PetscCall(PetscInfo(fn,"it: %" PetscInt_FMT " Mres: %g\n",it,(double)Mres));
501 : if (Mres<=tol) converged = PETSC_TRUE;
502 : }
503 :
504 : PetscCheck(Mres<=tol,PETSC_COMM_SELF,PETSC_ERR_LIB,"SQRTM not converged after %d iterations", DBMAXIT);
505 : PetscCall(PetscLogGpuTimeEnd());
506 : PetscCall(PetscFree(piv));
507 : PetscCallCUDA(cudaFree(d_work));
508 : PetscCallCUDA(cudaFree(d_Told));
509 : PetscCallCUDA(cudaFree(d_M));
510 : PetscCallCUDA(cudaFree(d_invM));
511 : PetscFunctionReturn(PETSC_SUCCESS);
512 : }
513 : #endif /* PETSC_HAVE_MAGMA */
514 :
515 : #endif /* PETSC_HAVE_CUDA */
516 :
517 : #define ITMAX 5
518 :
519 : /*
520 : Estimate norm(A^m,1) by block 1-norm power method (required workspace is 11*n)
521 : */
522 28 : static PetscErrorCode SlepcNormEst1(PetscBLASInt n,PetscScalar *A,PetscInt m,PetscScalar *work,PetscRandom rand,PetscReal *nrm)
523 : {
524 28 : PetscScalar *X,*Y,*Z,*S,*S_old,*aux,val,sone=1.0,szero=0.0;
525 28 : PetscReal est=0.0,est_old,vals[2]={0.0,0.0},*zvals,maxzval[2],raux;
526 28 : PetscBLASInt i,j,t=2,it=0,ind[2],est_j=0,m1;
527 :
528 28 : PetscFunctionBegin;
529 28 : X = work;
530 28 : Y = work + 2*n;
531 28 : Z = work + 4*n;
532 28 : S = work + 6*n;
533 28 : S_old = work + 8*n;
534 28 : zvals = (PetscReal*)(work + 10*n);
535 :
536 3398 : for (i=0;i<n;i++) { /* X has columns of unit 1-norm */
537 3370 : X[i] = 1.0/n;
538 3370 : PetscCall(PetscRandomGetValue(rand,&val));
539 3370 : if (PetscRealPart(val) < 0.5) X[i+n] = -1.0/n;
540 1713 : else X[i+n] = 1.0/n;
541 : }
542 6768 : for (i=0;i<t*n;i++) S[i] = 0.0;
543 28 : ind[0] = 0; ind[1] = 0;
544 28 : est_old = 0;
545 56 : while (1) {
546 56 : it++;
547 252 : for (j=0;j<m;j++) { /* Y = A^m*X */
548 196 : PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&t,&n,&sone,A,&n,X,&n,&szero,Y,&n));
549 196 : if (j<m-1) SlepcSwap(X,Y,aux);
550 : }
551 168 : for (j=0;j<t;j++) { /* vals[j] = norm(Y(:,j),1) */
552 112 : vals[j] = 0.0;
553 13592 : for (i=0;i<n;i++) vals[j] += PetscAbsScalar(Y[i+j*n]);
554 : }
555 56 : if (vals[0]<vals[1]) {
556 36 : SlepcSwap(vals[0],vals[1],raux);
557 36 : m1 = 1;
558 : } else m1 = 0;
559 56 : est = vals[0];
560 56 : if (est>est_old || it==2) est_j = ind[m1];
561 56 : if (it>=2 && est<=est_old) {
562 : est = est_old;
563 : break;
564 : }
565 56 : est_old = est;
566 56 : if (it>ITMAX) break;
567 13536 : SlepcSwap(S,S_old,aux);
568 13536 : for (i=0;i<t*n;i++) { /* S = sign(Y) */
569 23245 : S[i] = (PetscRealPart(Y[i]) < 0.0)? -1.0: 1.0;
570 : }
571 252 : for (j=0;j<m;j++) { /* Z = (A^T)^m*S */
572 196 : PetscCallBLAS("BLASgemm",BLASgemm_("C","N",&n,&t,&n,&sone,A,&n,S,&n,&szero,Z,&n));
573 196 : if (j<m-1) SlepcSwap(S,Z,aux);
574 : }
575 56 : maxzval[0] = -1; maxzval[1] = -1;
576 56 : ind[0] = 0; ind[1] = 0;
577 6796 : for (i=0;i<n;i++) { /* zvals[i] = norm(Z(i,:),inf) */
578 6740 : zvals[i] = PetscMax(PetscAbsScalar(Z[i+0*n]),PetscAbsScalar(Z[i+1*n]));
579 6740 : if (zvals[i]>maxzval[0]) {
580 642 : maxzval[0] = zvals[i];
581 642 : ind[0] = i;
582 6098 : } else if (zvals[i]>maxzval[1]) {
583 264 : maxzval[1] = zvals[i];
584 264 : ind[1] = i;
585 : }
586 : }
587 56 : if (it>=2 && maxzval[0]==zvals[est_j]) break;
588 6768 : for (i=0;i<t*n;i++) X[i] = 0.0;
589 84 : for (j=0;j<t;j++) X[ind[j]+j*n] = 1.0;
590 : }
591 28 : *nrm = est;
592 : /* Flop count is roughly (it * 2*m * t*gemv) = 4*its*m*t*n*n */
593 28 : PetscCall(PetscLogFlops(4.0*it*m*t*n*n));
594 28 : PetscFunctionReturn(PETSC_SUCCESS);
595 : }
596 :
597 : #define SMALLN 100
598 :
599 : /*
600 : Estimate norm(A^m,1) (required workspace is 2*n*n)
601 : */
602 541 : PetscErrorCode SlepcNormAm(PetscBLASInt n,PetscScalar *A,PetscInt m,PetscScalar *work,PetscRandom rand,PetscReal *nrm)
603 : {
604 541 : PetscScalar *v=work,*w=work+n*n,*aux,sone=1.0,szero=0.0;
605 541 : PetscReal rwork[1],tmp;
606 541 : PetscBLASInt i,j,one=1;
607 541 : PetscBool isrealpos=PETSC_TRUE;
608 :
609 541 : PetscFunctionBegin;
610 541 : if (n<SMALLN) { /* compute matrix power explicitly */
611 495 : if (m==1) {
612 0 : *nrm = LAPACKlange_("O",&n,&n,A,&n,rwork);
613 0 : PetscCall(PetscLogFlops(1.0*n*n));
614 : } else { /* m>=2 */
615 495 : PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,A,&n,A,&n,&szero,v,&n));
616 7498 : for (j=0;j<m-2;j++) {
617 7003 : PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&n,&n,&n,&sone,A,&n,v,&n,&szero,w,&n));
618 7003 : SlepcSwap(v,w,aux);
619 : }
620 495 : *nrm = LAPACKlange_("O",&n,&n,v,&n,rwork);
621 495 : PetscCall(PetscLogFlops(2.0*n*n*n*(m-1)+1.0*n*n));
622 : }
623 : } else {
624 5501 : for (i=0;i<n;i++)
625 442126 : for (j=0;j<n;j++)
626 : #if defined(PETSC_USE_COMPLEX)
627 440029 : if (PetscRealPart(A[i+j*n])<0.0 || PetscImaginaryPart(A[i+j*n])!=0.0) { isrealpos = PETSC_FALSE; break; }
628 : #else
629 : if (A[i+j*n]<0.0) { isrealpos = PETSC_FALSE; break; }
630 : #endif
631 46 : if (isrealpos) { /* for positive matrices only */
632 2103 : for (i=0;i<n;i++) v[i] = 1.0;
633 440 : for (j=0;j<m;j++) { /* w = A'*v */
634 422 : PetscCallBLAS("BLASgemv",BLASgemv_("C",&n,&n,&sone,A,&n,v,&one,&szero,w,&one));
635 422 : SlepcSwap(v,w,aux);
636 : }
637 18 : PetscCall(PetscLogFlops(2.0*n*n*m));
638 18 : *nrm = 0.0;
639 2103 : for (i=0;i<n;i++) if ((tmp = PetscAbsScalar(v[i])) > *nrm) *nrm = tmp; /* norm(v,inf) */
640 28 : } else PetscCall(SlepcNormEst1(n,A,m,work,rand,nrm));
641 : }
642 541 : PetscFunctionReturn(PETSC_SUCCESS);
643 : }
|