Actual source code: bvblas.c
slepc-3.22.2 2024-12-02
1: /*
2: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
3: SLEPc - Scalable Library for Eigenvalue Problem Computations
4: Copyright (c) 2002-, Universitat Politecnica de Valencia, Spain
6: This file is part of SLEPc.
7: SLEPc is distributed under a 2-clause BSD license (see LICENSE).
8: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
9: */
10: /*
11: BV private kernels that use the BLAS
12: */
14: #include <slepc/private/bvimpl.h>
15: #include <slepcblaslapack.h>
17: #define BLOCKSIZE 64
19: /*
20: C := alpha*A*B + beta*C
22: A is mxk (ld=lda), B is kxn (ld=ldb), C is mxn (ld=ldc)
23: */
24: PetscErrorCode BVMult_BLAS_Private(BV bv,PetscInt m_,PetscInt n_,PetscInt k_,PetscScalar alpha,const PetscScalar *A,PetscInt lda_,const PetscScalar *B,PetscInt ldb_,PetscScalar beta,PetscScalar *C,PetscInt ldc_)
25: {
26: PetscBLASInt m,n,k,lda,ldb,ldc;
27: #if defined(PETSC_HAVE_FBLASLAPACK) || defined(PETSC_HAVE_F2CBLASLAPACK)
28: PetscBLASInt l,bs=BLOCKSIZE;
29: #endif
31: PetscFunctionBegin;
32: PetscCall(PetscBLASIntCast(m_,&m));
33: PetscCall(PetscBLASIntCast(n_,&n));
34: PetscCall(PetscBLASIntCast(k_,&k));
35: PetscCall(PetscBLASIntCast(lda_,&lda));
36: PetscCall(PetscBLASIntCast(ldb_,&ldb));
37: PetscCall(PetscBLASIntCast(ldc_,&ldc));
38: #if defined(PETSC_HAVE_FBLASLAPACK) || defined(PETSC_HAVE_F2CBLASLAPACK)
39: l = m % bs;
40: if (l) PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&l,&n,&k,&alpha,(PetscScalar*)A,&lda,(PetscScalar*)B,&ldb,&beta,C,&ldc));
41: for (;l<m;l+=bs) {
42: PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&bs,&n,&k,&alpha,(PetscScalar*)A+l,&lda,(PetscScalar*)B,&ldb,&beta,C+l,&ldc));
43: }
44: #else
45: if (m) PetscCallBLAS("BLASgemm",BLASgemm_("N","N",&m,&n,&k,&alpha,(PetscScalar*)A,&lda,(PetscScalar*)B,&ldb,&beta,C,&ldc));
46: #endif
47: PetscCall(PetscLogFlops(2.0*m*n*k));
48: PetscFunctionReturn(PETSC_SUCCESS);
49: }
51: /*
52: y := alpha*A*x + beta*y
54: A is nxk (ld=lda)
55: */
56: PetscErrorCode BVMultVec_BLAS_Private(BV bv,PetscInt n_,PetscInt k_,PetscScalar alpha,const PetscScalar *A,PetscInt lda_,const PetscScalar *x,PetscScalar beta,PetscScalar *y)
57: {
58: PetscBLASInt n,k,lda,one=1;
60: PetscFunctionBegin;
61: PetscCall(PetscBLASIntCast(n_,&n));
62: PetscCall(PetscBLASIntCast(k_,&k));
63: PetscCall(PetscBLASIntCast(lda_,&lda));
64: if (n) PetscCallBLAS("BLASgemv",BLASgemv_("N",&n,&k,&alpha,A,&lda,x,&one,&beta,y,&one));
65: PetscCall(PetscLogFlops(2.0*n*k));
66: PetscFunctionReturn(PETSC_SUCCESS);
67: }
69: /*
70: A(:,s:e-1) := A*B(:,s:e-1)
72: A is mxk (ld=lda), B is kxn (ld=ldb), n=e-s
73: */
74: PetscErrorCode BVMultInPlace_BLAS_Private(BV bv,PetscInt m_,PetscInt k_,PetscInt s,PetscInt e,PetscScalar *A,PetscInt lda_,const PetscScalar *B,PetscInt ldb_,PetscBool btrans)
75: {
76: PetscScalar *pb,zero=0.0,one=1.0;
77: PetscBLASInt m,n,k,l,lda,ldb,bs=BLOCKSIZE;
78: PetscInt j,n_=e-s;
79: const char *bt;
81: PetscFunctionBegin;
82: PetscCall(PetscBLASIntCast(m_,&m));
83: PetscCall(PetscBLASIntCast(n_,&n));
84: PetscCall(PetscBLASIntCast(k_,&k));
85: PetscCall(PetscBLASIntCast(lda_,&lda));
86: PetscCall(PetscBLASIntCast(ldb_,&ldb));
87: PetscCall(BVAllocateWork_Private(bv,BLOCKSIZE*n_));
88: if (PetscUnlikely(btrans)) {
89: pb = (PetscScalar*)B+s;
90: bt = "C";
91: } else {
92: pb = (PetscScalar*)B+s*ldb;
93: bt = "N";
94: }
95: l = m % bs;
96: if (l) {
97: PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&l,&n,&k,&one,A,&lda,pb,&ldb,&zero,bv->work,&l));
98: for (j=0;j<n;j++) PetscCall(PetscArraycpy(A+(s+j)*lda,bv->work+j*l,l));
99: }
100: for (;l<m;l+=bs) {
101: PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&bs,&n,&k,&one,A+l,&lda,pb,&ldb,&zero,bv->work,&bs));
102: for (j=0;j<n;j++) PetscCall(PetscArraycpy(A+(s+j)*lda+l,bv->work+j*bs,bs));
103: }
104: PetscCall(PetscLogFlops(2.0*m*n*k));
105: PetscFunctionReturn(PETSC_SUCCESS);
106: }
108: /*
109: V := V*B
111: V is mxn (ld=m), B is nxn (ld=k)
112: */
113: PetscErrorCode BVMultInPlace_Vecs_Private(BV bv,PetscInt m_,PetscInt n_,PetscInt k_,Vec *V,const PetscScalar *B,PetscBool btrans)
114: {
115: PetscScalar zero=0.0,one=1.0,*out,*pout;
116: const PetscScalar *pin;
117: PetscBLASInt m = 0,n,k,l,bs=BLOCKSIZE;
118: PetscInt j;
119: const char *bt;
121: PetscFunctionBegin;
122: PetscCall(PetscBLASIntCast(m_,&m));
123: PetscCall(PetscBLASIntCast(n_,&n));
124: PetscCall(PetscBLASIntCast(k_,&k));
125: PetscCall(BVAllocateWork_Private(bv,2*BLOCKSIZE*n_));
126: out = bv->work+BLOCKSIZE*n_;
127: if (btrans) bt = "C";
128: else bt = "N";
129: l = m % bs;
130: if (l) {
131: for (j=0;j<n;j++) {
132: PetscCall(VecGetArrayRead(V[j],&pin));
133: PetscCall(PetscArraycpy(bv->work+j*l,pin,l));
134: PetscCall(VecRestoreArrayRead(V[j],&pin));
135: }
136: PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&l,&n,&n,&one,bv->work,&l,(PetscScalar*)B,&k,&zero,out,&l));
137: for (j=0;j<n;j++) {
138: PetscCall(VecGetArray(V[j],&pout));
139: PetscCall(PetscArraycpy(pout,out+j*l,l));
140: PetscCall(VecRestoreArray(V[j],&pout));
141: }
142: }
143: for (;l<m;l+=bs) {
144: for (j=0;j<n;j++) {
145: PetscCall(VecGetArrayRead(V[j],&pin));
146: PetscCall(PetscArraycpy(bv->work+j*bs,pin+l,bs));
147: PetscCall(VecRestoreArrayRead(V[j],&pin));
148: }
149: PetscCallBLAS("BLASgemm",BLASgemm_("N",bt,&bs,&n,&n,&one,bv->work,&bs,(PetscScalar*)B,&k,&zero,out,&bs));
150: for (j=0;j<n;j++) {
151: PetscCall(VecGetArray(V[j],&pout));
152: PetscCall(PetscArraycpy(pout+l,out+j*bs,bs));
153: PetscCall(VecRestoreArray(V[j],&pout));
154: }
155: }
156: PetscCall(PetscLogFlops(2.0*n*n*k));
157: PetscFunctionReturn(PETSC_SUCCESS);
158: }
160: /*
161: B := alpha*A + beta*B
163: A,B are nxk
164: */
165: PetscErrorCode BVAXPY_BLAS_Private(BV bv,PetscInt n_,PetscInt k_,PetscScalar alpha,const PetscScalar *A,PetscInt lda_,PetscScalar beta,PetscScalar *B,PetscInt ldb_)
166: {
167: PetscBLASInt m,one=1;
168: PetscInt i,j;
170: PetscFunctionBegin;
171: if (lda_==n_ && ldb_==n_) {
172: PetscCall(PetscBLASIntCast(n_*k_,&m));
173: if (beta!=(PetscScalar)1.0) PetscCallBLAS("BLASscal",BLASscal_(&m,&beta,B,&one));
174: PetscCallBLAS("BLASaxpy",BLASaxpy_(&m,&alpha,A,&one,B,&one));
175: } else {
176: if (beta!=(PetscScalar)1.0) {
177: for (j=0;j<k_;j++) {
178: for (i=0;i<n_;i++) {
179: B[i+j*ldb_] = alpha*A[i+j*lda_] + beta*B[i+j*ldb_];
180: }
181: }
182: } else {
183: for (j=0;j<k_;j++) {
184: for (i=0;i<n_;i++) {
185: B[i+j*ldb_] += alpha*A[i+j*lda_];
186: }
187: }
188: }
189: }
190: PetscCall(PetscLogFlops((beta==(PetscScalar)1.0)?2.0*n_*k_:3.0*n_*k_));
191: PetscFunctionReturn(PETSC_SUCCESS);
192: }
194: /*
195: C := A'*B
197: A' is mxk (ld=lda), B is kxn (ld=ldb), C is mxn (ld=ldc)
198: */
199: PetscErrorCode BVDot_BLAS_Private(BV bv,PetscInt m_,PetscInt n_,PetscInt k_,const PetscScalar *A,PetscInt lda_,const PetscScalar *B,PetscInt ldb_,PetscScalar *C,PetscInt ldc_,PetscBool mpi)
200: {
201: PetscScalar zero=0.0,one=1.0,*CC;
202: PetscBLASInt m,n,k,lda,ldb,ldc,j;
203: PetscMPIInt len;
205: PetscFunctionBegin;
206: PetscCall(PetscBLASIntCast(m_,&m));
207: PetscCall(PetscBLASIntCast(n_,&n));
208: PetscCall(PetscBLASIntCast(k_,&k));
209: PetscCall(PetscBLASIntCast(lda_,&lda));
210: PetscCall(PetscBLASIntCast(ldb_,&ldb));
211: PetscCall(PetscBLASIntCast(ldc_,&ldc));
212: if (mpi) {
213: if (ldc==m) {
214: PetscCall(BVAllocateWork_Private(bv,m*n));
215: if (k) PetscCallBLAS("BLASgemm",BLASgemm_("C","N",&m,&n,&k,&one,(PetscScalar*)A,&lda,(PetscScalar*)B,&ldb,&zero,bv->work,&ldc));
216: else PetscCall(PetscArrayzero(bv->work,m*n));
217: PetscCall(PetscMPIIntCast(m*n,&len));
218: PetscCallMPI(MPIU_Allreduce(bv->work,C,len,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)bv)));
219: } else {
220: PetscCall(BVAllocateWork_Private(bv,2*m*n));
221: CC = bv->work+m*n;
222: if (k) PetscCallBLAS("BLASgemm",BLASgemm_("C","N",&m,&n,&k,&one,(PetscScalar*)A,&lda,(PetscScalar*)B,&ldb,&zero,bv->work,&m));
223: else PetscCall(PetscArrayzero(bv->work,m*n));
224: PetscCall(PetscMPIIntCast(m*n,&len));
225: PetscCallMPI(MPIU_Allreduce(bv->work,CC,len,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)bv)));
226: for (j=0;j<n;j++) PetscCall(PetscArraycpy(C+j*ldc,CC+j*m,m));
227: }
228: } else {
229: if (k) PetscCallBLAS("BLASgemm",BLASgemm_("C","N",&m,&n,&k,&one,(PetscScalar*)A,&lda,(PetscScalar*)B,&ldb,&zero,C,&ldc));
230: }
231: PetscCall(PetscLogFlops(2.0*m*n*k));
232: PetscFunctionReturn(PETSC_SUCCESS);
233: }
235: /*
236: y := A'*x
238: A is nxk (ld=lda)
239: */
240: PetscErrorCode BVDotVec_BLAS_Private(BV bv,PetscInt n_,PetscInt k_,const PetscScalar *A,PetscInt lda_,const PetscScalar *x,PetscScalar *y,PetscBool mpi)
241: {
242: PetscScalar zero=0.0,done=1.0;
243: PetscBLASInt n,k,lda,one=1;
244: PetscMPIInt len;
246: PetscFunctionBegin;
247: PetscCall(PetscBLASIntCast(n_,&n));
248: PetscCall(PetscBLASIntCast(k_,&k));
249: PetscCall(PetscBLASIntCast(lda_,&lda));
250: if (mpi) {
251: PetscCall(BVAllocateWork_Private(bv,k));
252: if (n) PetscCallBLAS("BLASgemv",BLASgemv_("C",&n,&k,&done,A,&lda,x,&one,&zero,bv->work,&one));
253: else PetscCall(PetscArrayzero(bv->work,k));
254: PetscCall(PetscMPIIntCast(k,&len));
255: PetscCallMPI(MPIU_Allreduce(bv->work,y,len,MPIU_SCALAR,MPIU_SUM,PetscObjectComm((PetscObject)bv)));
256: } else {
257: if (n) PetscCallBLAS("BLASgemv",BLASgemv_("C",&n,&k,&done,A,&lda,x,&one,&zero,y,&one));
258: }
259: PetscCall(PetscLogFlops(2.0*n*k));
260: PetscFunctionReturn(PETSC_SUCCESS);
261: }
263: /*
264: Scale n scalars
265: */
266: PetscErrorCode BVScale_BLAS_Private(BV bv,PetscInt n_,PetscScalar *A,PetscScalar alpha)
267: {
268: PetscBLASInt n,one=1;
270: PetscFunctionBegin;
271: if (PetscUnlikely(alpha == (PetscScalar)0.0)) PetscCall(PetscArrayzero(A,n_));
272: else if (alpha!=(PetscScalar)1.0) {
273: PetscCall(PetscBLASIntCast(n_,&n));
274: PetscCallBLAS("BLASscal",BLASscal_(&n,&alpha,A,&one));
275: PetscCall(PetscLogFlops(1.0*n));
276: }
277: PetscFunctionReturn(PETSC_SUCCESS);
278: }