Actual source code: bvblas.c

slepc-3.21.1 2024-04-26
Report Typos and Errors
  1: /*
  2:    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  3:    SLEPc - Scalable Library for Eigenvalue Problem Computations
  4:    Copyright (c) 2002-, Universitat Politecnica de Valencia, Spain

  6:    This file is part of SLEPc.
  7:    SLEPc is distributed under a 2-clause BSD license (see LICENSE).
  8:    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  9: */
 10: /*
 11:    BV private kernels that use the BLAS
 12: */

 14: #include <slepc/private/bvimpl.h>
 15: #include <slepcblaslapack.h>

 /* Number of rows handled per gemm call when tall products are split into row chunks */
#define BLOCKSIZE (64)

 19: /*
 20:     C := alpha*A*B + beta*C

 22:     A is mxk (ld=lda), B is kxn (ld=ldb), C is mxn (ld=ldc)
 23: */
 24: PetscErrorCode BVMult_BLAS_Private(BV bv,PetscInt m_,PetscInt n_,PetscInt k_,PetscScalar alpha,const PetscScalar *A,PetscInt lda_,const PetscScalar *B,PetscInt ldb_,PetscScalar beta,PetscScalar *C,PetscInt ldc_)
 25: {
 26:   PetscBLASInt   m,n,k,lda,ldb,ldc;
 27: #if defined(PETSC_HAVE_FBLASLAPACK) || defined(PETSC_HAVE_F2CBLASLAPACK)
 28:   PetscBLASInt   l,bs=BLOCKSIZE;
 29: #endif

 31:   PetscFunctionBegin;
 32:   PetscCall(PetscBLASIntCast(m_,&m));
 33:   PetscCall(PetscBLASIntCast(n_,&n));
 34:   PetscCall(PetscBLASIntCast(k_,&k));
 35:   PetscCall(PetscBLASIntCast(lda_,&lda));
 36:   PetscCall(PetscBLASIntCast(ldb_,&ldb));
 37:   PetscCall(PetscBLASIntCast(ldc_,&ldc));
 38: #if defined(PETSC_HAVE_FBLASLAPACK) || defined(PETSC_HAVE_F2CBLASLAPACK)
 39:   l = m % bs;
 40:   if (l) PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&l,&n,&k,&alpha,(PetscScalar*)A,&lda,(PetscScalar*)B,&ldb,&beta,C,&ldc));
 41:   for (;l<m;l+=bs) {
 42:     PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&bs,&n,&k,&alpha,(PetscScalar*)A+l,&lda,(PetscScalar*)B,&ldb,&beta,C+l,&ldc));
 43:   }
 44: #else
 45:   if (m) PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&m,&n,&k,&alpha,(PetscScalar*)A,&lda,(PetscScalar*)B,&ldb,&beta,C,&ldc));
 46: #endif
 47:   PetscCall(PetscLogFlops(2.0*m*n*k));
 48:   PetscFunctionReturn(PETSC_SUCCESS);
 49: }

 51: /*
 52:     y := alpha*A*x + beta*y

 54:     A is nxk (ld=lda)
 55: */
 56: PetscErrorCode BVMultVec_BLAS_Private(BV bv,PetscInt n_,PetscInt k_,PetscScalar alpha,const PetscScalar *A,PetscInt lda_,const PetscScalar *x,PetscScalar beta,PetscScalar *y)
 57: {
 58:   PetscBLASInt   n,k,lda,one=1;

 60:   PetscFunctionBegin;
 61:   PetscCall(PetscBLASIntCast(n_,&n));
 62:   PetscCall(PetscBLASIntCast(k_,&k));
 63:   PetscCall(PetscBLASIntCast(lda_,&lda));
 64:   if (n) PetscCallBLAS("BLASgemv",BLASgemv_("N",&n,&k,&alpha,A,&lda,x,&one,&beta,y,&one));
 65:   PetscCall(PetscLogFlops(2.0*n*k));
 66:   PetscFunctionReturn(PETSC_SUCCESS);
 67: }

 69: /*
 70:     A(:,s:e-1) := A*B(:,s:e-1)

 72:     A is mxk (ld=lda), B is kxn (ld=ldb), n=e-s
 73: */
 74: PetscErrorCode BVMultInPlace_BLAS_Private(BV bv,PetscInt m_,PetscInt k_,PetscInt s,PetscInt e,PetscScalar *A,PetscInt lda_,const PetscScalar *B,PetscInt ldb_,PetscBool btrans)
 75: {
 76:   PetscScalar    *pb,zero=0.0,one=1.0;
 77:   PetscBLASInt   m,n,k,l,lda,ldb,bs=BLOCKSIZE;
 78:   PetscInt       j,n_=e-s;
 79:   const char     *bt;

 81:   PetscFunctionBegin;
 82:   PetscCall(PetscBLASIntCast(m_,&m));
 83:   PetscCall(PetscBLASIntCast(n_,&n));
 84:   PetscCall(PetscBLASIntCast(k_,&k));
 85:   PetscCall(PetscBLASIntCast(lda_,&lda));
 86:   PetscCall(PetscBLASIntCast(ldb_,&ldb));
 87:   PetscCall(BVAllocateWork_Private(bv,BLOCKSIZE*n_));
 88:   if (PetscUnlikely(btrans)) {
 89:     pb = (PetscScalar*)B+s;
 90:     bt = "C";
 91:   } else {
 92:     pb = (PetscScalar*)B+s*ldb;
 93:     bt = "N";
 94:   }
 95:   l = m % bs;
 96:   if (l) {
 97:     PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&l,&n,&k,&one,A,&lda,pb,&ldb,&zero,bv->work,&l));
 98:     for (j=0;j<n;j++) PetscCall(PetscArraycpy(A+(s+j)*lda,bv->work+j*l,l));
 99:   }
100:   for (;l<m;l+=bs) {
101:     PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&bs,&n,&k,&one,A+l,&lda,pb,&ldb,&zero,bv->work,&bs));
102:     for (j=0;j<n;j++) PetscCall(PetscArraycpy(A+(s+j)*lda+l,bv->work+j*bs,bs));
103:   }
104:   PetscCall(PetscLogFlops(2.0*m*n*k));
105:   PetscFunctionReturn(PETSC_SUCCESS);
106: }

108: /*
109:     V := V*B

111:     V is mxn (ld=m), B is nxn (ld=k)
112: */
113: PetscErrorCode BVMultInPlace_Vecs_Private(BV bv,PetscInt m_,PetscInt n_,PetscInt k_,Vec *V,const PetscScalar *B,PetscBool btrans)
114: {
115:   PetscScalar       zero=0.0,one=1.0,*out,*pout;
116:   const PetscScalar *pin;
117:   PetscBLASInt      m = 0,n,k,l,bs=BLOCKSIZE;
118:   PetscInt          j;
119:   const char        *bt;

121:   PetscFunctionBegin;
122:   PetscCall(PetscBLASIntCast(m_,&m));
123:   PetscCall(PetscBLASIntCast(n_,&n));
124:   PetscCall(PetscBLASIntCast(k_,&k));
125:   PetscCall(BVAllocateWork_Private(bv,2*BLOCKSIZE*n_));
126:   out = bv->work+BLOCKSIZE*n_;
127:   if (btrans) bt = "C";
128:   else bt = "N";
129:   l = m % bs;
130:   if (l) {
131:     for (j=0;j<n;j++) {
132:       PetscCall(VecGetArrayRead(V[j],&pin));
133:       PetscCall(PetscArraycpy(bv->work+j*l,pin,l));
134:       PetscCall(VecRestoreArrayRead(V[j],&pin));
135:     }
136:     PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&l,&n,&n,&one,bv->work,&l,(PetscScalar*)B,&k,&zero,out,&l));
137:     for (j=0;j<n;j++) {
138:       PetscCall(VecGetArray(V[j],&pout));
139:       PetscCall(PetscArraycpy(pout,out+j*l,l));
140:       PetscCall(VecRestoreArray(V[j],&pout));
141:     }
142:   }
143:   for (;l<m;l+=bs) {
144:     for (j=0;j<n;j++) {
145:       PetscCall(VecGetArrayRead(V[j],&pin));
146:       PetscCall(PetscArraycpy(bv->work+j*bs,pin+l,bs));
147:       PetscCall(VecRestoreArrayRead(V[j],&pin));
148:     }
149:     PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&bs,&n,&n,&one,bv->work,&bs,(PetscScalar*)B,&k,&zero,out,&bs));
150:     for (j=0;j<n;j++) {
151:       PetscCall(VecGetArray(V[j],&pout));
152:       PetscCall(PetscArraycpy(pout+l,out+j*bs,bs));
153:       PetscCall(VecRestoreArray(V[j],&pout));
154:     }
155:   }
156:   PetscCall(PetscLogFlops(2.0*n*n*k));
157:   PetscFunctionReturn(PETSC_SUCCESS);
158: }

160: /*
161:     B := alpha*A + beta*B

163:     A,B are nxk
164: */
165: PetscErrorCode BVAXPY_BLAS_Private(BV bv,PetscInt n_,PetscInt k_,PetscScalar alpha,const PetscScalar *A,PetscInt lda_,PetscScalar beta,PetscScalar *B,PetscInt ldb_)
166: {
167:   PetscBLASInt   m,one=1;
168:   PetscInt       i,j;

170:   PetscFunctionBegin;
171:   if (lda_==n_ && ldb_==n_) {
172:     PetscCall(PetscBLASIntCast(n_*k_,&m));
173:     if (beta!=(PetscScalar)1.0) PetscCallBLAS("BLASscal",BLASscal_(&m,&beta,B,&one));
174:     PetscCallBLAS("BLASaxpy",BLASaxpy_(&m,&alpha,A,&one,B,&one));
175:   } else {
176:     if (beta!=(PetscScalar)1.0) {
177:       for (j=0;j<k_;j++) {
178:         for (i=0;i<n_;i++) {
179:           B[i+j*ldb_] = alpha*A[i+j*lda_] + beta*B[i+j*ldb_];
180:         }
181:       }
182:     } else {
183:       for (j=0;j<k_;j++) {
184:         for (i=0;i<n_;i++) {
185:           B[i+j*ldb_] += alpha*A[i+j*lda_];
186:         }
187:       }
188:     }
189:   }
190:   PetscCall(PetscLogFlops((beta==(PetscScalar)1.0)?2.0*n_*k_:3.0*n_*k_));
191:   PetscFunctionReturn(PETSC_SUCCESS);
192: }

194: /*
195:     C := A'*B

197:     A' is mxk (ld=lda), B is kxn (ld=ldb), C is mxn (ld=ldc)
198: */
199: PetscErrorCode BVDot_BLAS_Private(BV bv,PetscInt m_,PetscInt n_,PetscInt k_,const PetscScalar *A,PetscInt lda_,const PetscScalar *B,PetscInt ldb_,PetscScalar *C,PetscInt ldc_,PetscBool mpi)
200: {
201:   PetscScalar    zero=0.0,one=1.0,*CC;
202:   PetscBLASInt   m,n,k,lda,ldb,ldc,j;
203:   PetscMPIInt    len;

205:   PetscFunctionBegin;
206:   PetscCall(PetscBLASIntCast(m_,&m));
207:   PetscCall(PetscBLASIntCast(n_,&n));
208:   PetscCall(PetscBLASIntCast(k_,&k));
209:   PetscCall(PetscBLASIntCast(lda_,&lda));
210:   PetscCall(PetscBLASIntCast(ldb_,&ldb));
211:   PetscCall(PetscBLASIntCast(ldc_,&ldc));
212:   if (mpi) {
213:     if (ldc==m) {
214:       PetscCall(BVAllocateWork_Private(bv,m*n));
215:       if (k) PetscCallBLAS("BLASgemm",BLASgemm_("C","N",&m,&n,&k,&one,(PetscScalar*)A,&lda,(PetscScalar*)B,&ldb,&zero,bv->work,&ldc));
216:       else PetscCall(PetscArrayzero(bv->work,m*n));
217:       PetscCall(PetscMPIIntCast(m*n,&len));
218:       PetscCall(MPIU_Allreduce(bv->work,C,len,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)bv)));
219:     } else {
220:       PetscCall(BVAllocateWork_Private(bv,2*m*n));
221:       CC = bv->work+m*n;
222:       if (k) PetscCallBLAS("BLASgemm",BLASgemm_("C","N",&m,&n,&k,&one,(PetscScalar*)A,&lda,(PetscScalar*)B,&ldb,&zero,bv->work,&m));
223:       else PetscCall(PetscArrayzero(bv->work,m*n));
224:       PetscCall(PetscMPIIntCast(m*n,&len));
225:       PetscCall(MPIU_Allreduce(bv->work,CC,len,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)bv)));
226:       for (j=0;j<n;j++) PetscCall(PetscArraycpy(C+j*ldc,CC+j*m,m));
227:     }
228:   } else {
229:     if (k) PetscCallBLAS("BLASgemm",BLASgemm_("C","N",&m,&n,&k,&one,(PetscScalar*)A,&lda,(PetscScalar*)B,&ldb,&zero,C,&ldc));
230:   }
231:   PetscCall(PetscLogFlops(2.0*m*n*k));
232:   PetscFunctionReturn(PETSC_SUCCESS);
233: }

235: /*
236:     y := A'*x

238:     A is nxk (ld=lda)
239: */
240: PetscErrorCode BVDotVec_BLAS_Private(BV bv,PetscInt n_,PetscInt k_,const PetscScalar *A,PetscInt lda_,const PetscScalar *x,PetscScalar *y,PetscBool mpi)
241: {
242:   PetscScalar    zero=0.0,done=1.0;
243:   PetscBLASInt   n,k,lda,one=1;
244:   PetscMPIInt    len;

246:   PetscFunctionBegin;
247:   PetscCall(PetscBLASIntCast(n_,&n));
248:   PetscCall(PetscBLASIntCast(k_,&k));
249:   PetscCall(PetscBLASIntCast(lda_,&lda));
250:   if (mpi) {
251:     PetscCall(BVAllocateWork_Private(bv,k));
252:     if (n) PetscCallBLAS("BLASgemv",BLASgemv_("C",&n,&k,&done,A,&lda,x,&one,&zero,bv->work,&one));
253:     else PetscCall(PetscArrayzero(bv->work,k));
254:     PetscCall(PetscMPIIntCast(k,&len));
255:     PetscCall(MPIU_Allreduce(bv->work,y,len,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)bv)));
256:   } else {
257:     if (n) PetscCallBLAS("BLASgemv",BLASgemv_("C",&n,&k,&done,A,&lda,x,&one,&zero,y,&one));
258:   }
259:   PetscCall(PetscLogFlops(2.0*n*k));
260:   PetscFunctionReturn(PETSC_SUCCESS);
261: }

263: /*
264:     Scale n scalars
265: */
266: PetscErrorCode BVScale_BLAS_Private(BV bv,PetscInt n_,PetscScalar *A,PetscScalar alpha)
267: {
268:   PetscBLASInt   n,one=1;

270:   PetscFunctionBegin;
271:   if (PetscUnlikely(alpha == (PetscScalar)0.0)) PetscCall(PetscArrayzero(A,n_));
272:   else if (alpha!=(PetscScalar)1.0) {
273:     PetscCall(PetscBLASIntCast(n_,&n));
274:     PetscCallBLAS("BLASscal",BLASscal_(&n,&alpha,A,&one));
275:     PetscCall(PetscLogFlops(1.0*n));
276:   }
277:   PetscFunctionReturn(PETSC_SUCCESS);
278: }