Actual source code: cycliccuda.cu
slepc-3.21.2 2024-09-25
1: /*
2: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
3: SLEPc - Scalable Library for Eigenvalue Problem Computations
4: Copyright (c) 2002-, Universitat Politecnica de Valencia, Spain
6: This file is part of SLEPc.
7: SLEPc is distributed under a 2-clause BSD license (see LICENSE).
8: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
9: */
10: /*
11: SLEPc singular value solver: "cyclic" (CUDA implementation)
12: */
13: #include <slepc/private/svdimpl.h>
14: #include "../src/svd/impls/cyclic/cyclic.h"
/*
   MatMult_Cyclic_CUDA - Shell-matrix product y = H*x for the cyclic operator,
   working directly on the CUDA arrays of x and y.

   The local part of x is viewed as [x1; x2] and the local part of y as [y1; y2].
   The split point is the local row count of ctx->A. The product computed is
       y1 = A*x2,    y2 = AT*x1
   (ctx->AT presumably holds the (conjugate) transpose of A -- confirm where the
   shell context is built).

   The halves are aliased with VecCUDAPlaceArray, so no data is copied in the
   aligned case. When ctx->misaligned is set, the bottom halves are staged through
   the work vectors wx2/wy2 to prevent CUDA errors caused by the offset bottom part.

   Input:  B - shell matrix whose context is an SVD_CYCLIC_SHELL
           x - input vector
   Output: y - result vector
*/
PetscErrorCode MatMult_Cyclic_CUDA(Mat B,Vec x,Vec y)
{
  SVD_CYCLIC_SHELL  *shell;
  PetscInt          nrows;
  const PetscScalar *xarr;
  PetscScalar       *yarr;
  Vec               xbot,ybot;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(B,&shell));
  PetscCall(MatGetLocalSize(shell->A,&nrows,NULL));

  /* alias top/bottom halves of x and y onto the context vectors */
  PetscCall(VecCUDAGetArrayRead(x,&xarr));
  PetscCall(VecCUDAGetArrayWrite(y,&yarr));
  PetscCall(VecCUDAPlaceArray(shell->x1,xarr));
  PetscCall(VecCUDAPlaceArray(shell->x2,xarr+nrows));
  PetscCall(VecCUDAPlaceArray(shell->y1,yarr));
  PetscCall(VecCUDAPlaceArray(shell->y2,yarr+nrows));

  /* misaligned bottom halves are routed through separate work vectors */
  xbot = shell->misaligned ? shell->wx2 : shell->x2;
  ybot = shell->misaligned ? shell->wy2 : shell->y2;
  if (shell->misaligned) PetscCall(VecCopy(shell->x2,xbot));
  PetscCall(MatMult(shell->A,xbot,shell->y1));
  PetscCall(MatMult(shell->AT,shell->x1,ybot));
  if (shell->misaligned) PetscCall(VecCopy(ybot,shell->y2));

  /* detach the aliased arrays before handing x and y back */
  PetscCall(VecCUDAResetArray(shell->x1));
  PetscCall(VecCUDAResetArray(shell->x2));
  PetscCall(VecCUDAResetArray(shell->y1));
  PetscCall(VecCUDAResetArray(shell->y2));
  PetscCall(VecCUDARestoreArrayRead(x,&xarr));
  PetscCall(VecCUDARestoreArrayWrite(y,&yarr));
  PetscFunctionReturn(PETSC_SUCCESS);
}
/*
   MatMult_ECross_CUDA - Shell-matrix product y = H*x for the extended cross
   operator, working directly on the CUDA arrays of x and y.

   The local part of x is viewed as [x1; x2] and the local part of y as [y1; y2].
   The top part has length (local size of y) - (local column count of ctx->A).
   The product computed is
       y1 = x1,    y2 = AT*(A*x2)
   with ctx->w holding the intermediate A*x2. (ctx->AT presumably holds the
   (conjugate) transpose of A -- confirm where the shell context is built.)

   The halves are aliased with VecCUDAPlaceArray. When ctx->misaligned is set,
   the bottom halves are staged through the work vectors wx2/wy2 to prevent CUDA
   errors caused by the offset bottom part.

   Input:  B - shell matrix whose context is an SVD_CYCLIC_SHELL
           x - input vector
   Output: y - result vector
*/
PetscErrorCode MatMult_ECross_CUDA(Mat B,Vec x,Vec y)
{
  SVD_CYCLIC_SHELL  *shell;
  PetscInt          ntot,ncols,ntop;
  const PetscScalar *xarr;
  PetscScalar       *yarr;
  Vec               xbot,ybot;

  PetscFunctionBegin;
  PetscCall(MatShellGetContext(B,&shell));
  PetscCall(MatGetLocalSize(shell->A,NULL,&ncols));
  PetscCall(VecGetLocalSize(y,&ntot));
  ntop = ntot-ncols;

  /* alias top/bottom halves of x and y onto the context vectors */
  PetscCall(VecCUDAGetArrayRead(x,&xarr));
  PetscCall(VecCUDAGetArrayWrite(y,&yarr));
  PetscCall(VecCUDAPlaceArray(shell->x1,xarr));
  PetscCall(VecCUDAPlaceArray(shell->x2,xarr+ntop));
  PetscCall(VecCUDAPlaceArray(shell->y1,yarr));
  PetscCall(VecCUDAPlaceArray(shell->y2,yarr+ntop));

  /* identity block on the top half */
  PetscCall(VecCopy(shell->x1,shell->y1));

  /* misaligned bottom halves are routed through separate work vectors */
  xbot = shell->misaligned ? shell->wx2 : shell->x2;
  ybot = shell->misaligned ? shell->wy2 : shell->y2;
  if (shell->misaligned) PetscCall(VecCopy(shell->x2,xbot));
  PetscCall(MatMult(shell->A,xbot,shell->w));
  PetscCall(MatMult(shell->AT,shell->w,ybot));
  if (shell->misaligned) PetscCall(VecCopy(ybot,shell->y2));

  /* detach the aliased arrays before handing x and y back */
  PetscCall(VecCUDAResetArray(shell->x1));
  PetscCall(VecCUDAResetArray(shell->x2));
  PetscCall(VecCUDAResetArray(shell->y1));
  PetscCall(VecCUDAResetArray(shell->y2));
  PetscCall(VecCUDARestoreArrayRead(x,&xarr));
  PetscCall(VecCUDARestoreArrayWrite(y,&yarr));
  PetscFunctionReturn(PETSC_SUCCESS);
}