Line data Source code
1 : /*
2 : - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
3 : SLEPc - Scalable Library for Eigenvalue Problem Computations
4 : Copyright (c) 2002-, Universitat Politecnica de Valencia, Spain
5 :
6 : This file is part of SLEPc.
7 : SLEPc is distributed under a 2-clause BSD license (see LICENSE).
8 : - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
9 : */
10 : /*
11 : BV private kernels that use the BLAS
12 : */
13 :
14 : #include <slepc/private/bvimpl.h>
15 : #include <slepcblaslapack.h>
16 :
17 : #define BLOCKSIZE 64
18 :
19 : /*
20 : C := alpha*A*B + beta*C
21 :
22 : A is mxk (ld=lda), B is kxn (ld=ldb), C is mxn (ld=ldc)
23 : */
24 5593 : PetscErrorCode BVMult_BLAS_Private(BV bv,PetscInt m_,PetscInt n_,PetscInt k_,PetscScalar alpha,const PetscScalar *A,PetscInt lda_,const PetscScalar *B,PetscInt ldb_,PetscScalar beta,PetscScalar *C,PetscInt ldc_)
25 : {
26 5593 : PetscBLASInt m,n,k,lda,ldb,ldc;
27 : #if defined(PETSC_HAVE_FBLASLAPACK) || defined(PETSC_HAVE_F2CBLASLAPACK)
28 : PetscBLASInt l,bs=BLOCKSIZE;
29 : #endif
30 :
31 5593 : PetscFunctionBegin;
32 5593 : PetscCall(PetscBLASIntCast(m_,&m));
33 5593 : PetscCall(PetscBLASIntCast(n_,&n));
34 5593 : PetscCall(PetscBLASIntCast(k_,&k));
35 5593 : PetscCall(PetscBLASIntCast(lda_,&lda));
36 5593 : PetscCall(PetscBLASIntCast(ldb_,&ldb));
37 5593 : PetscCall(PetscBLASIntCast(ldc_,&ldc));
38 : #if defined(PETSC_HAVE_FBLASLAPACK) || defined(PETSC_HAVE_F2CBLASLAPACK)
39 : l = m % bs;
40 : if (l) PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&l,&n,&k,&alpha,(PetscScalar*)A,&lda,(PetscScalar*)B,&ldb,&beta,C,&ldc));
41 : for (;l<m;l+=bs) {
42 : PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&bs,&n,&k,&alpha,(PetscScalar*)A+l,&lda,(PetscScalar*)B,&ldb,&beta,C+l,&ldc));
43 : }
44 : #else
45 5593 : if (m) PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&m,&n,&k,&alpha,(PetscScalar*)A,&lda,(PetscScalar*)B,&ldb,&beta,C,&ldc));
46 : #endif
47 5593 : PetscCall(PetscLogFlops(2.0*m*n*k));
48 5593 : PetscFunctionReturn(PETSC_SUCCESS);
49 : }
50 :
51 : /*
52 : y := alpha*A*x + beta*y
53 :
54 : A is nxk (ld=lda)
55 : */
56 605358 : PetscErrorCode BVMultVec_BLAS_Private(BV bv,PetscInt n_,PetscInt k_,PetscScalar alpha,const PetscScalar *A,PetscInt lda_,const PetscScalar *x,PetscScalar beta,PetscScalar *y)
57 : {
58 605358 : PetscBLASInt n,k,lda,one=1;
59 :
60 605358 : PetscFunctionBegin;
61 605358 : PetscCall(PetscBLASIntCast(n_,&n));
62 605358 : PetscCall(PetscBLASIntCast(k_,&k));
63 605358 : PetscCall(PetscBLASIntCast(lda_,&lda));
64 605358 : if (n) PetscCallBLAS("BLASgemv",BLASgemv_("N",&n,&k,&alpha,A,&lda,x,&one,&beta,y,&one));
65 605358 : PetscCall(PetscLogFlops(2.0*n*k));
66 605358 : PetscFunctionReturn(PETSC_SUCCESS);
67 : }
68 :
69 : /*
70 : A(:,s:e-1) := A*B(:,s:e-1)
71 :
72 : A is mxk (ld=lda), B is kxn (ld=ldb), n=e-s
73 : */
74 26448 : PetscErrorCode BVMultInPlace_BLAS_Private(BV bv,PetscInt m_,PetscInt k_,PetscInt s,PetscInt e,PetscScalar *A,PetscInt lda_,const PetscScalar *B,PetscInt ldb_,PetscBool btrans)
75 : {
76 26448 : PetscScalar *pb,zero=0.0,one=1.0;
77 26448 : PetscBLASInt m,n,k,l,lda,ldb,bs=BLOCKSIZE;
78 26448 : PetscInt j,n_=e-s;
79 26448 : const char *bt;
80 :
81 26448 : PetscFunctionBegin;
82 26448 : PetscCall(PetscBLASIntCast(m_,&m));
83 26448 : PetscCall(PetscBLASIntCast(n_,&n));
84 26448 : PetscCall(PetscBLASIntCast(k_,&k));
85 26448 : PetscCall(PetscBLASIntCast(lda_,&lda));
86 26448 : PetscCall(PetscBLASIntCast(ldb_,&ldb));
87 26448 : PetscCall(BVAllocateWork_Private(bv,BLOCKSIZE*n_));
88 26448 : if (PetscUnlikely(btrans)) {
89 3 : pb = (PetscScalar*)B+s;
90 3 : bt = "C";
91 : } else {
92 26445 : pb = (PetscScalar*)B+s*ldb;
93 26445 : bt = "N";
94 : }
95 26448 : l = m % bs;
96 26448 : if (l) {
97 25349 : PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&l,&n,&k,&one,A,&lda,pb,&ldb,&zero,bv->work,&l));
98 205189 : for (j=0;j<n;j++) PetscCall(PetscArraycpy(A+(s+j)*lda,bv->work+j*l,l));
99 : }
100 153769 : for (;l<m;l+=bs) {
101 127321 : PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&bs,&n,&k,&one,A+l,&lda,pb,&ldb,&zero,bv->work,&bs));
102 1274674 : for (j=0;j<n;j++) PetscCall(PetscArraycpy(A+(s+j)*lda+l,bv->work+j*bs,bs));
103 : }
104 26448 : PetscCall(PetscLogFlops(2.0*m*n*k));
105 26448 : PetscFunctionReturn(PETSC_SUCCESS);
106 : }
107 :
108 : /*
109 : V := V*B
110 :
111 : V is mxn (ld=m), B is nxn (ld=k)
112 : */
113 118 : PetscErrorCode BVMultInPlace_Vecs_Private(BV bv,PetscInt m_,PetscInt n_,PetscInt k_,Vec *V,const PetscScalar *B,PetscBool btrans)
114 : {
115 118 : PetscScalar zero=0.0,one=1.0,*out,*pout;
116 118 : const PetscScalar *pin;
117 118 : PetscBLASInt m = 0,n,k,l,bs=BLOCKSIZE;
118 118 : PetscInt j;
119 118 : const char *bt;
120 :
121 118 : PetscFunctionBegin;
122 118 : PetscCall(PetscBLASIntCast(m_,&m));
123 118 : PetscCall(PetscBLASIntCast(n_,&n));
124 118 : PetscCall(PetscBLASIntCast(k_,&k));
125 118 : PetscCall(BVAllocateWork_Private(bv,2*BLOCKSIZE*n_));
126 118 : out = bv->work+BLOCKSIZE*n_;
127 118 : if (btrans) bt = "C";
128 117 : else bt = "N";
129 118 : l = m % bs;
130 118 : if (l) {
131 575 : for (j=0;j<n;j++) {
132 457 : PetscCall(VecGetArrayRead(V[j],&pin));
133 457 : PetscCall(PetscArraycpy(bv->work+j*l,pin,l));
134 457 : PetscCall(VecRestoreArrayRead(V[j],&pin));
135 : }
136 118 : PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&l,&n,&n,&one,bv->work,&l,(PetscScalar*)B,&k,&zero,out,&l));
137 575 : for (j=0;j<n;j++) {
138 457 : PetscCall(VecGetArray(V[j],&pout));
139 457 : PetscCall(PetscArraycpy(pout,out+j*l,l));
140 457 : PetscCall(VecRestoreArray(V[j],&pout));
141 : }
142 : }
143 544 : for (;l<m;l+=bs) {
144 2028 : for (j=0;j<n;j++) {
145 1602 : PetscCall(VecGetArrayRead(V[j],&pin));
146 1602 : PetscCall(PetscArraycpy(bv->work+j*bs,pin+l,bs));
147 1602 : PetscCall(VecRestoreArrayRead(V[j],&pin));
148 : }
149 426 : PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&bs,&n,&n,&one,bv->work,&bs,(PetscScalar*)B,&k,&zero,out,&bs));
150 2028 : for (j=0;j<n;j++) {
151 1602 : PetscCall(VecGetArray(V[j],&pout));
152 1602 : PetscCall(PetscArraycpy(pout+l,out+j*bs,bs));
153 1602 : PetscCall(VecRestoreArray(V[j],&pout));
154 : }
155 : }
156 118 : PetscCall(PetscLogFlops(2.0*n*n*k));
157 118 : PetscFunctionReturn(PETSC_SUCCESS);
158 : }
159 :
160 : /*
161 : B := alpha*A + beta*B
162 :
163 : A,B are nxk
164 : */
165 1125 : PetscErrorCode BVAXPY_BLAS_Private(BV bv,PetscInt n_,PetscInt k_,PetscScalar alpha,const PetscScalar *A,PetscInt lda_,PetscScalar beta,PetscScalar *B,PetscInt ldb_)
166 : {
167 1125 : PetscBLASInt m,one=1;
168 1125 : PetscInt i,j;
169 :
170 1125 : PetscFunctionBegin;
171 1125 : if (lda_==n_ && ldb_==n_) {
172 1109 : PetscCall(PetscBLASIntCast(n_*k_,&m));
173 1109 : if (beta!=(PetscScalar)1.0) PetscCallBLAS("BLASscal",BLASscal_(&m,&beta,B,&one));
174 1109 : PetscCallBLAS("BLASaxpy",BLASaxpy_(&m,&alpha,A,&one,B,&one));
175 : } else {
176 16 : if (beta!=(PetscScalar)1.0) {
177 0 : for (j=0;j<k_;j++) {
178 0 : for (i=0;i<n_;i++) {
179 0 : B[i+j*ldb_] = alpha*A[i+j*lda_] + beta*B[i+j*ldb_];
180 : }
181 : }
182 : } else {
183 148 : for (j=0;j<k_;j++) {
184 2328 : for (i=0;i<n_;i++) {
185 2196 : B[i+j*ldb_] += alpha*A[i+j*lda_];
186 : }
187 : }
188 : }
189 : }
190 1125 : PetscCall(PetscLogFlops((beta==(PetscScalar)1.0)?2.0*n_*k_:3.0*n_*k_));
191 1125 : PetscFunctionReturn(PETSC_SUCCESS);
192 : }
193 :
194 : /*
195 : C := A'*B
196 :
197 : A' is mxk (ld=lda), B is kxn (ld=ldb), C is mxn (ld=ldc)
198 : */
199 26625 : PetscErrorCode BVDot_BLAS_Private(BV bv,PetscInt m_,PetscInt n_,PetscInt k_,const PetscScalar *A,PetscInt lda_,const PetscScalar *B,PetscInt ldb_,PetscScalar *C,PetscInt ldc_,PetscBool mpi)
200 : {
201 26625 : PetscScalar zero=0.0,one=1.0,*CC;
202 26625 : PetscBLASInt m,n,k,lda,ldb,ldc,j;
203 26625 : PetscMPIInt len;
204 :
205 26625 : PetscFunctionBegin;
206 26625 : PetscCall(PetscBLASIntCast(m_,&m));
207 26625 : PetscCall(PetscBLASIntCast(n_,&n));
208 26625 : PetscCall(PetscBLASIntCast(k_,&k));
209 26625 : PetscCall(PetscBLASIntCast(lda_,&lda));
210 26625 : PetscCall(PetscBLASIntCast(ldb_,&ldb));
211 26625 : PetscCall(PetscBLASIntCast(ldc_,&ldc));
212 26625 : if (mpi) {
213 3611 : if (ldc==m) {
214 2300 : PetscCall(BVAllocateWork_Private(bv,m*n));
215 2300 : if (k) PetscCallBLAS("BLASgemm",BLASgemm_("C","N",&m,&n,&k,&one,(PetscScalar*)A,&lda,(PetscScalar*)B,&ldb,&zero,bv->work,&ldc));
216 0 : else PetscCall(PetscArrayzero(bv->work,m*n));
217 2300 : PetscCall(PetscMPIIntCast(m*n,&len));
218 6900 : PetscCallMPI(MPIU_Allreduce(bv->work,C,len,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)bv)));
219 : } else {
220 1311 : PetscCall(BVAllocateWork_Private(bv,2*m*n));
221 1311 : CC = bv->work+m*n;
222 1311 : if (k) PetscCallBLAS("BLASgemm",BLASgemm_("C","N",&m,&n,&k,&one,(PetscScalar*)A,&lda,(PetscScalar*)B,&ldb,&zero,bv->work,&m));
223 0 : else PetscCall(PetscArrayzero(bv->work,m*n));
224 1311 : PetscCall(PetscMPIIntCast(m*n,&len));
225 3933 : PetscCallMPI(MPIU_Allreduce(bv->work,CC,len,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)bv)));
226 14785 : for (j=0;j<n;j++) PetscCall(PetscArraycpy(C+j*ldc,CC+j*m,m));
227 : }
228 : } else {
229 23014 : if (k) PetscCallBLAS("BLASgemm",BLASgemm_("C","N",&m,&n,&k,&one,(PetscScalar*)A,&lda,(PetscScalar*)B,&ldb,&zero,C,&ldc));
230 : }
231 26625 : PetscCall(PetscLogFlops(2.0*m*n*k));
232 26625 : PetscFunctionReturn(PETSC_SUCCESS);
233 : }
234 :
235 : /*
236 : y := A'*x
237 :
238 : A is nxk (ld=lda)
239 : */
240 464080 : PetscErrorCode BVDotVec_BLAS_Private(BV bv,PetscInt n_,PetscInt k_,const PetscScalar *A,PetscInt lda_,const PetscScalar *x,PetscScalar *y,PetscBool mpi)
241 : {
242 464080 : PetscScalar zero=0.0,done=1.0;
243 464080 : PetscBLASInt n,k,lda,one=1;
244 464080 : PetscMPIInt len;
245 :
246 464080 : PetscFunctionBegin;
247 464080 : PetscCall(PetscBLASIntCast(n_,&n));
248 464080 : PetscCall(PetscBLASIntCast(k_,&k));
249 464080 : PetscCall(PetscBLASIntCast(lda_,&lda));
250 464080 : if (mpi) {
251 23874 : PetscCall(BVAllocateWork_Private(bv,k));
252 23874 : if (n) PetscCallBLAS("BLASgemv",BLASgemv_("C",&n,&k,&done,A,&lda,x,&one,&zero,bv->work,&one));
253 0 : else PetscCall(PetscArrayzero(bv->work,k));
254 23874 : PetscCall(PetscMPIIntCast(k,&len));
255 71622 : PetscCallMPI(MPIU_Allreduce(bv->work,y,len,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)bv)));
256 : } else {
257 440206 : if (n) PetscCallBLAS("BLASgemv",BLASgemv_("C",&n,&k,&done,A,&lda,x,&one,&zero,y,&one));
258 : }
259 464080 : PetscCall(PetscLogFlops(2.0*n*k));
260 464080 : PetscFunctionReturn(PETSC_SUCCESS);
261 : }
262 :
263 : /*
264 : Scale n scalars
265 : */
266 243952 : PetscErrorCode BVScale_BLAS_Private(BV bv,PetscInt n_,PetscScalar *A,PetscScalar alpha)
267 : {
268 243952 : PetscBLASInt n,one=1;
269 :
270 243952 : PetscFunctionBegin;
271 243952 : if (PetscUnlikely(alpha == (PetscScalar)0.0)) PetscCall(PetscArrayzero(A,n_));
272 243691 : else if (alpha!=(PetscScalar)1.0) {
273 243683 : PetscCall(PetscBLASIntCast(n_,&n));
274 243683 : PetscCallBLAS("BLASscal",BLASscal_(&n,&alpha,A,&one));
275 243683 : PetscCall(PetscLogFlops(1.0*n));
276 : }
277 243952 : PetscFunctionReturn(PETSC_SUCCESS);
278 : }
|