Actual source code: slepccupmblas.h

slepc-3.22.0 2024-09-28
Report Typos and Errors
  1: /*
  2:    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  3:    SLEPc - Scalable Library for Eigenvalue Problem Computations
  4:    Copyright (c) 2002-, Universitat Politecnica de Valencia, Spain

  6:    This file is part of SLEPc.
  7:    SLEPc is distributed under a 2-clause BSD license (see LICENSE).
  8:    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
  9: */
 10: /*
 11:    Macro definitions to use cuBLAS and hipBLAS functionality
 12: */

 14: #pragma once

 16: #include <petscdevice.h>
 17: #include <petsc/private/petsclegacycupmblas.h>

 19: #ifdef PETSC_HAVE_CUDA

 21: /* cublasX* wrappers chosen by scalar type and precision: complex builds map to the C/Z routines (with casts to the cuBLAS complex types), real builds alias the S/D routines */
 22: #ifdef PETSC_USE_COMPLEX
 23: #ifdef PETSC_USE_REAL_SINGLE /* complex scalars, single precision: cuComplex */
 24: #define cublasXdotc(handle,n,x,incx,y,incy,result) cublasCdotc((handle),(n),(const cuComplex*)(x),(incx),(const cuComplex*)(y),(incy),(cuComplex*)(result))
 25: #define cublasXgetrfBatched(handle,n,Aarray,lda,ipiv,info,batch) cublasCgetrfBatched((handle),(n),(cuComplex**)(Aarray),(lda),(ipiv),(info),(batch))
 26: #define cublasXgetrsBatched(handle,trans,n,nrhs,Aarray,lda,ipiv,Barray,ldb,info,batch) cublasCgetrsBatched((handle),(trans),(n),(nrhs),(const cuComplex**)(Aarray),(lda),(ipiv),(cuComplex**)(Barray),(ldb),(info),(batch))
 27: #else /* complex scalars, PETSC_USE_REAL_SINGLE not defined: cuDoubleComplex */
 28: #define cublasXdotc(handle,n,x,incx,y,incy,result) cublasZdotc((handle),(n),(const cuDoubleComplex*)(x),(incx),(const cuDoubleComplex*)(y),(incy),(cuDoubleComplex*)(result))
 29: #define cublasXgetrfBatched(handle,n,Aarray,lda,ipiv,info,batch) cublasZgetrfBatched((handle),(n),(cuDoubleComplex**)(Aarray),(lda),(ipiv),(info),(batch))
 30: #define cublasXgetrsBatched(handle,trans,n,nrhs,Aarray,lda,ipiv,Barray,ldb,info,batch) cublasZgetrsBatched((handle),(trans),(n),(nrhs),(const cuDoubleComplex**)(Aarray),(lda),(ipiv),(cuDoubleComplex**)(Barray),(ldb),(info),(batch))
 31: #endif /* PETSC_USE_REAL_SINGLE (complex scalars) */
 32: #else /* real scalars: plain aliases, no casts needed */
 33: #ifdef PETSC_USE_REAL_SINGLE /* real scalars, single precision */
 34: #define cublasXdotc cublasSdot
 35: #define cublasXgetrfBatched cublasSgetrfBatched
 36: #define cublasXgetrsBatched cublasSgetrsBatched
 37: #else /* real scalars, PETSC_USE_REAL_SINGLE not defined: double precision */
 38: #define cublasXdotc cublasDdot
 39: #define cublasXgetrfBatched cublasDgetrfBatched
 40: #define cublasXgetrsBatched cublasDgetrsBatched
 41: #endif /* PETSC_USE_REAL_SINGLE (real scalars) */
 42: #endif /* PETSC_USE_COMPLEX */

 44: /* cublasXC* wrappers always map to complex routines and are selected by precision only, so they can be used on PetscComplex data in both real-scalar and complex-scalar builds */
 45: #ifdef PETSC_USE_REAL_SINGLE /* single precision: cuComplex */
 46: #define cublasXCaxpy(handle,n,alpha,x,incx,y,incy) cublasCaxpy((handle),(n),(const cuComplex*)(alpha),(const cuComplex*)(x),(incx),(cuComplex*)(y),(incy))
 47: #define cublasXCgemm(handle,transa,transb,m,n,k,alpha,A,lda,B,ldb,beta,C,ldc) cublasCgemm((handle),(transa),(transb),(m),(n),(k),(const cuComplex*)(alpha),(const cuComplex*)(A),(lda),(const cuComplex*)(B),(ldb),(const cuComplex*)(beta),(cuComplex*)(C),(ldc))
 48: #define cublasXCscal(handle,n,alpha,x,incx) cublasCscal((handle),(n),(const cuComplex*)(alpha),(cuComplex*)(x),(incx))
 49: #else /* PETSC_USE_REAL_SINGLE not defined: cuDoubleComplex */
 50: #define cublasXCaxpy(handle,n,alpha,x,incx,y,incy) cublasZaxpy((handle),(n),(const cuDoubleComplex*)(alpha),(const cuDoubleComplex*)(x),(incx),(cuDoubleComplex*)(y),(incy))
 51: #define cublasXCgemm(handle,transa,transb,m,n,k,alpha,A,lda,B,ldb,beta,C,ldc) cublasZgemm((handle),(transa),(transb),(m),(n),(k),(const cuDoubleComplex*)(alpha),(const cuDoubleComplex*)(A),(lda),(const cuDoubleComplex*)(B),(ldb),(const cuDoubleComplex*)(beta),(cuDoubleComplex*)(C),(ldc))
 52: #define cublasXCscal(handle,n,alpha,x,incx) cublasZscal((handle),(n),(const cuDoubleComplex*)(alpha),(cuDoubleComplex*)(x),(incx))
 53: #endif /* PETSC_USE_REAL_SINGLE (cublasXC*) */

 55: #endif /* PETSC_HAVE_CUDA */

 57: #ifdef PETSC_HAVE_HIP

 59: /* hipblasXdotc wrapper chosen by scalar type and precision: complex builds map to hipblasCdotc/hipblasZdotc (with casts to the HIP complex types), real builds alias hipblasSdot/hipblasDdot */
 60: #ifdef PETSC_USE_COMPLEX
 61: #ifdef PETSC_USE_REAL_SINGLE /* complex scalars, single precision: hipComplex */
 62: #define hipblasXdotc(handle,n,x,incx,y,incy,result) hipblasCdotc((handle),(n),(const hipComplex*)(x),(incx),(const hipComplex*)(y),(incy),(hipComplex*)(result))
 63: #else /* complex scalars, PETSC_USE_REAL_SINGLE not defined: hipDoubleComplex */
 64: #define hipblasXdotc(handle,n,x,incx,y,incy,result) hipblasZdotc((handle),(n),(const hipDoubleComplex*)(x),(incx),(const hipDoubleComplex*)(y),(incy),(hipDoubleComplex*)(result))
 65: #endif /* PETSC_USE_REAL_SINGLE (complex scalars) */
 66: #else /* real scalars: plain aliases, no casts needed */
 67: #ifdef PETSC_USE_REAL_SINGLE /* real scalars, single precision */
 68: #define hipblasXdotc hipblasSdot
 69: #else /* real scalars, PETSC_USE_REAL_SINGLE not defined: double precision */
 70: #define hipblasXdotc hipblasDdot
 71: #endif /* PETSC_USE_REAL_SINGLE (real scalars) */
 72: #endif /* PETSC_USE_COMPLEX */

 74: #endif /* PETSC_HAVE_HIP */