Line data Source code
1 : /*
2 : - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
3 : SLEPc - Scalable Library for Eigenvalue Problem Computations
4 : Copyright (c) 2002-, Universitat Politecnica de Valencia, Spain
5 :
6 : This file is part of SLEPc.
7 : SLEPc is distributed under a 2-clause BSD license (see LICENSE).
8 : - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
9 : */
10 : /*
11 : Inverse square root function x^(-1/2)
12 : */
13 :
14 : #include <slepc/private/fnimpl.h> /*I "slepcfn.h" I*/
15 : #include <slepcblaslapack.h>
16 :
17 48 : static PetscErrorCode FNEvaluateFunction_Invsqrt(FN fn,PetscScalar x,PetscScalar *y)
18 : {
19 48 : PetscFunctionBegin;
20 48 : PetscCheck(x!=0.0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Function not defined in the requested value");
21 : #if !defined(PETSC_USE_COMPLEX)
22 48 : PetscCheck(x>0.0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Function not defined in the requested value");
23 : #endif
24 48 : *y = 1.0/PetscSqrtScalar(x);
25 48 : PetscFunctionReturn(PETSC_SUCCESS);
26 : }
27 :
28 8 : static PetscErrorCode FNEvaluateDerivative_Invsqrt(FN fn,PetscScalar x,PetscScalar *y)
29 : {
30 8 : PetscFunctionBegin;
31 8 : PetscCheck(x!=0.0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Derivative not defined in the requested value");
32 : #if !defined(PETSC_USE_COMPLEX)
33 8 : PetscCheck(x>0.0,PETSC_COMM_SELF,PETSC_ERR_ARG_OUTOFRANGE,"Derivative not defined in the requested value");
34 : #endif
35 8 : *y = -1.0/(2.0*PetscPowScalarReal(x,1.5));
36 8 : PetscFunctionReturn(PETSC_SUCCESS);
37 : }
38 :
39 4 : static PetscErrorCode FNEvaluateFunctionMat_Invsqrt_Schur(FN fn,Mat A,Mat B)
40 : {
41 4 : PetscBLASInt n=0,ld,*ipiv,info;
42 4 : PetscScalar *Ba,*Wa;
43 4 : PetscInt m;
44 4 : Mat W;
45 :
46 4 : PetscFunctionBegin;
47 4 : PetscCall(FN_AllocateWorkMat(fn,A,&W));
48 4 : if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
49 4 : PetscCall(MatDenseGetArray(B,&Ba));
50 4 : PetscCall(MatDenseGetArray(W,&Wa));
51 : /* compute B = sqrtm(A) */
52 4 : PetscCall(MatGetSize(A,&m,NULL));
53 4 : PetscCall(PetscBLASIntCast(m,&n));
54 4 : ld = n;
55 4 : PetscCall(FNSqrtmSchur(fn,n,Ba,n,PETSC_FALSE));
56 : /* compute B = A\B */
57 4 : PetscCall(PetscMalloc1(ld,&ipiv));
58 4 : PetscCallBLAS("LAPACKgesv",LAPACKgesv_(&n,&n,Wa,&ld,ipiv,Ba,&ld,&info));
59 4 : SlepcCheckLapackInfo("gesv",info);
60 4 : PetscCall(PetscLogFlops(2.0*n*n*n/3.0+2.0*n*n*n));
61 4 : PetscCall(PetscFree(ipiv));
62 4 : PetscCall(MatDenseRestoreArray(W,&Wa));
63 4 : PetscCall(MatDenseRestoreArray(B,&Ba));
64 4 : PetscCall(FN_FreeWorkMat(fn,&W));
65 4 : PetscFunctionReturn(PETSC_SUCCESS);
66 : }
67 :
68 4 : static PetscErrorCode FNEvaluateFunctionMatVec_Invsqrt_Schur(FN fn,Mat A,Vec v)
69 : {
70 4 : PetscBLASInt n=0,ld,*ipiv,info,one=1;
71 4 : PetscScalar *Ba,*Wa;
72 4 : PetscInt m;
73 4 : Mat B,W;
74 :
75 4 : PetscFunctionBegin;
76 4 : PetscCall(FN_AllocateWorkMat(fn,A,&B));
77 4 : PetscCall(FN_AllocateWorkMat(fn,A,&W));
78 4 : PetscCall(MatDenseGetArray(B,&Ba));
79 4 : PetscCall(MatDenseGetArray(W,&Wa));
80 : /* compute B_1 = sqrtm(A)*e_1 */
81 4 : PetscCall(MatGetSize(A,&m,NULL));
82 4 : PetscCall(PetscBLASIntCast(m,&n));
83 4 : ld = n;
84 4 : PetscCall(FNSqrtmSchur(fn,n,Ba,n,PETSC_TRUE));
85 : /* compute B_1 = A\B_1 */
86 4 : PetscCall(PetscMalloc1(ld,&ipiv));
87 4 : PetscCallBLAS("LAPACKgesv",LAPACKgesv_(&n,&one,Wa,&ld,ipiv,Ba,&ld,&info));
88 4 : SlepcCheckLapackInfo("gesv",info);
89 4 : PetscCall(PetscFree(ipiv));
90 4 : PetscCall(MatDenseRestoreArray(W,&Wa));
91 4 : PetscCall(MatDenseRestoreArray(B,&Ba));
92 4 : PetscCall(MatGetColumnVector(B,v,0));
93 4 : PetscCall(FN_FreeWorkMat(fn,&W));
94 4 : PetscCall(FN_FreeWorkMat(fn,&B));
95 4 : PetscFunctionReturn(PETSC_SUCCESS);
96 : }
97 :
98 12 : static PetscErrorCode FNEvaluateFunctionMat_Invsqrt_DBP(FN fn,Mat A,Mat B)
99 : {
100 12 : PetscBLASInt n=0;
101 12 : PetscScalar *T;
102 12 : PetscInt m;
103 :
104 12 : PetscFunctionBegin;
105 12 : if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
106 12 : PetscCall(MatDenseGetArray(B,&T));
107 12 : PetscCall(MatGetSize(A,&m,NULL));
108 12 : PetscCall(PetscBLASIntCast(m,&n));
109 12 : PetscCall(FNSqrtmDenmanBeavers(fn,n,T,n,PETSC_TRUE));
110 12 : PetscCall(MatDenseRestoreArray(B,&T));
111 12 : PetscFunctionReturn(PETSC_SUCCESS);
112 : }
113 :
114 12 : static PetscErrorCode FNEvaluateFunctionMat_Invsqrt_NS(FN fn,Mat A,Mat B)
115 : {
116 12 : PetscBLASInt n=0;
117 12 : PetscScalar *T;
118 12 : PetscInt m;
119 :
120 12 : PetscFunctionBegin;
121 12 : if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
122 12 : PetscCall(MatDenseGetArray(B,&T));
123 12 : PetscCall(MatGetSize(A,&m,NULL));
124 12 : PetscCall(PetscBLASIntCast(m,&n));
125 12 : PetscCall(FNSqrtmNewtonSchulz(fn,n,T,n,PETSC_TRUE));
126 12 : PetscCall(MatDenseRestoreArray(B,&T));
127 12 : PetscFunctionReturn(PETSC_SUCCESS);
128 : }
129 :
130 12 : static PetscErrorCode FNEvaluateFunctionMat_Invsqrt_Sadeghi(FN fn,Mat A,Mat B)
131 : {
132 12 : PetscBLASInt n=0,ld,*ipiv,info;
133 12 : PetscScalar *Ba,*Wa;
134 12 : PetscInt m;
135 12 : Mat W;
136 :
137 12 : PetscFunctionBegin;
138 12 : PetscCall(FN_AllocateWorkMat(fn,A,&W));
139 12 : if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
140 12 : PetscCall(MatDenseGetArray(B,&Ba));
141 12 : PetscCall(MatDenseGetArray(W,&Wa));
142 : /* compute B = sqrtm(A) */
143 12 : PetscCall(MatGetSize(A,&m,NULL));
144 12 : PetscCall(PetscBLASIntCast(m,&n));
145 12 : ld = n;
146 12 : PetscCall(FNSqrtmSadeghi(fn,n,Ba,n));
147 : /* compute B = A\B */
148 12 : PetscCall(PetscMalloc1(ld,&ipiv));
149 12 : PetscCallBLAS("LAPACKgesv",LAPACKgesv_(&n,&n,Wa,&ld,ipiv,Ba,&ld,&info));
150 12 : SlepcCheckLapackInfo("gesv",info);
151 12 : PetscCall(PetscLogFlops(2.0*n*n*n/3.0+2.0*n*n*n));
152 12 : PetscCall(PetscFree(ipiv));
153 12 : PetscCall(MatDenseRestoreArray(W,&Wa));
154 12 : PetscCall(MatDenseRestoreArray(B,&Ba));
155 12 : PetscCall(FN_FreeWorkMat(fn,&W));
156 12 : PetscFunctionReturn(PETSC_SUCCESS);
157 : }
158 :
159 : #if defined(PETSC_HAVE_CUDA)
160 : PetscErrorCode FNEvaluateFunctionMat_Invsqrt_NS_CUDA(FN fn,Mat A,Mat B)
161 : {
162 : PetscBLASInt n=0;
163 : PetscScalar *Ba;
164 : PetscInt m;
165 :
166 : PetscFunctionBegin;
167 : if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
168 : PetscCall(MatDenseCUDAGetArray(B,&Ba));
169 : PetscCall(MatGetSize(A,&m,NULL));
170 : PetscCall(PetscBLASIntCast(m,&n));
171 : PetscCall(FNSqrtmNewtonSchulz_CUDA(fn,n,Ba,n,PETSC_TRUE));
172 : PetscCall(MatDenseCUDARestoreArray(B,&Ba));
173 : PetscFunctionReturn(PETSC_SUCCESS);
174 : }
175 :
176 : #if defined(PETSC_HAVE_MAGMA)
177 : #include <slepcmagma.h>
178 :
179 : PetscErrorCode FNEvaluateFunctionMat_Invsqrt_DBP_CUDAm(FN fn,Mat A,Mat B)
180 : {
181 : PetscBLASInt n=0;
182 : PetscScalar *T;
183 : PetscInt m;
184 :
185 : PetscFunctionBegin;
186 : if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
187 : PetscCall(MatDenseCUDAGetArray(B,&T));
188 : PetscCall(MatGetSize(A,&m,NULL));
189 : PetscCall(PetscBLASIntCast(m,&n));
190 : PetscCall(FNSqrtmDenmanBeavers_CUDAm(fn,n,T,n,PETSC_TRUE));
191 : PetscCall(MatDenseCUDARestoreArray(B,&T));
192 : PetscFunctionReturn(PETSC_SUCCESS);
193 : }
194 :
195 : PetscErrorCode FNEvaluateFunctionMat_Invsqrt_Sadeghi_CUDAm(FN fn,Mat A,Mat B)
196 : {
197 : PetscBLASInt n=0,ld,*ipiv;
198 : PetscScalar *Ba,*Wa;
199 : PetscInt m;
200 : Mat W;
201 :
202 : PetscFunctionBegin;
203 : PetscCall(FN_AllocateWorkMat(fn,A,&W));
204 : if (A!=B) PetscCall(MatCopy(A,B,SAME_NONZERO_PATTERN));
205 : PetscCall(MatDenseCUDAGetArray(B,&Ba));
206 : PetscCall(MatDenseCUDAGetArray(W,&Wa));
207 : /* compute B = sqrtm(A) */
208 : PetscCall(MatGetSize(A,&m,NULL));
209 : PetscCall(PetscBLASIntCast(m,&n));
210 : ld = n;
211 : PetscCall(FNSqrtmSadeghi_CUDAm(fn,n,Ba,n));
212 : /* compute B = A\B */
213 : PetscCall(SlepcMagmaInit());
214 : PetscCall(PetscMalloc1(ld,&ipiv));
215 : PetscCallMAGMA(magma_xgesv_gpu,n,n,Wa,ld,ipiv,Ba,ld);
216 : PetscCall(PetscLogFlops(2.0*n*n*n/3.0+2.0*n*n*n));
217 : PetscCall(PetscFree(ipiv));
218 : PetscCall(MatDenseCUDARestoreArray(W,&Wa));
219 : PetscCall(MatDenseCUDARestoreArray(B,&Ba));
220 : PetscCall(FN_FreeWorkMat(fn,&W));
221 : PetscFunctionReturn(PETSC_SUCCESS);
222 : }
223 : #endif /* PETSC_HAVE_MAGMA */
224 : #endif /* PETSC_HAVE_CUDA */
225 :
226 8 : static PetscErrorCode FNView_Invsqrt(FN fn,PetscViewer viewer)
227 : {
228 8 : PetscBool isascii;
229 8 : char str[50];
230 8 : const char *methodname[] = {
231 : "Schur method for inv(A)*sqrtm(A)",
232 : "Denman-Beavers (product form)",
233 : "Newton-Schulz iteration",
234 : "Sadeghi iteration"
235 : };
236 8 : const int nmeth=PETSC_STATIC_ARRAY_LENGTH(methodname);
237 :
238 8 : PetscFunctionBegin;
239 8 : PetscCall(PetscObjectTypeCompare((PetscObject)viewer,PETSCVIEWERASCII,&isascii));
240 8 : if (isascii) {
241 8 : if (fn->beta==(PetscScalar)1.0) {
242 0 : if (fn->alpha==(PetscScalar)1.0) PetscCall(PetscViewerASCIIPrintf(viewer," inverse square root: x^(-1/2)\n"));
243 : else {
244 0 : PetscCall(SlepcSNPrintfScalar(str,sizeof(str),fn->alpha,PETSC_TRUE));
245 0 : PetscCall(PetscViewerASCIIPrintf(viewer," inverse square root: (%s*x)^(-1/2)\n",str));
246 : }
247 : } else {
248 8 : PetscCall(SlepcSNPrintfScalar(str,sizeof(str),fn->beta,PETSC_TRUE));
249 8 : if (fn->alpha==(PetscScalar)1.0) PetscCall(PetscViewerASCIIPrintf(viewer," inverse square root: %s*x^(-1/2)\n",str));
250 : else {
251 8 : PetscCall(PetscViewerASCIIPrintf(viewer," inverse square root: %s",str));
252 8 : PetscCall(PetscViewerASCIIUseTabs(viewer,PETSC_FALSE));
253 8 : PetscCall(SlepcSNPrintfScalar(str,sizeof(str),fn->alpha,PETSC_TRUE));
254 8 : PetscCall(PetscViewerASCIIPrintf(viewer,"*(%s*x)^(-1/2)\n",str));
255 8 : PetscCall(PetscViewerASCIIUseTabs(viewer,PETSC_TRUE));
256 : }
257 : }
258 8 : if (fn->method<nmeth) PetscCall(PetscViewerASCIIPrintf(viewer," computing matrix functions with: %s\n",methodname[fn->method]));
259 : }
260 8 : PetscFunctionReturn(PETSC_SUCCESS);
261 : }
262 :
263 8 : SLEPC_EXTERN PetscErrorCode FNCreate_Invsqrt(FN fn)
264 : {
265 8 : PetscFunctionBegin;
266 8 : fn->ops->evaluatefunction = FNEvaluateFunction_Invsqrt;
267 8 : fn->ops->evaluatederivative = FNEvaluateDerivative_Invsqrt;
268 8 : fn->ops->evaluatefunctionmat[0] = FNEvaluateFunctionMat_Invsqrt_Schur;
269 8 : fn->ops->evaluatefunctionmat[1] = FNEvaluateFunctionMat_Invsqrt_DBP;
270 8 : fn->ops->evaluatefunctionmat[2] = FNEvaluateFunctionMat_Invsqrt_NS;
271 8 : fn->ops->evaluatefunctionmat[3] = FNEvaluateFunctionMat_Invsqrt_Sadeghi;
272 : #if defined(PETSC_HAVE_CUDA)
273 : fn->ops->evaluatefunctionmatcuda[2] = FNEvaluateFunctionMat_Invsqrt_NS_CUDA;
274 : #if defined(PETSC_HAVE_MAGMA)
275 : fn->ops->evaluatefunctionmatcuda[1] = FNEvaluateFunctionMat_Invsqrt_DBP_CUDAm;
276 : fn->ops->evaluatefunctionmatcuda[3] = FNEvaluateFunctionMat_Invsqrt_Sadeghi_CUDAm;
277 : #endif /* PETSC_HAVE_MAGMA */
278 : #endif /* PETSC_HAVE_CUDA */
279 8 : fn->ops->evaluatefunctionmatvec[0] = FNEvaluateFunctionMatVec_Invsqrt_Schur;
280 8 : fn->ops->view = FNView_Invsqrt;
281 8 : PetscFunctionReturn(PETSC_SUCCESS);
282 : }
|