Actual source code: cupmblasinterface.hpp

  1: #ifndef PETSCCUPMBLASINTERFACE_HPP
  2: #define PETSCCUPMBLASINTERFACE_HPP

  4: #if defined(__cplusplus)
  5: #include <petsc/private/cupminterface.hpp>
  6: #include <petsc/private/petscadvancedmacros.h>

  8:   #include <limits> // std::numeric_limits

 10: namespace Petsc
 11: {

 13: namespace device
 14: {

 16: namespace cupm
 17: {

 19: namespace impl
 20: {

 22:   #define PetscCallCUPMBLAS(...) \
 23:     do { \
 24:       const cupmBlasError_t cberr_p_ = __VA_ARGS__; \
 25:       if (PetscUnlikely(cberr_p_ != CUPMBLAS_STATUS_SUCCESS)) { \
 26:         if (((cberr_p_ == CUPMBLAS_STATUS_NOT_INITIALIZED) || (cberr_p_ == CUPMBLAS_STATUS_ALLOC_FAILED)) && PetscDeviceInitialized(PETSC_DEVICE_CUPM())) { \
 27:           SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, \
 28:                   "%s error %d (%s). Reports not initialized or alloc failed; " \
 29:                   "this indicates the GPU may have run out resources", \
 30:                   cupmBlasName(), static_cast<PetscErrorCode>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \
 31:         } \
 32:         SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "%s error %d (%s)", cupmBlasName(), static_cast<PetscErrorCode>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \
 33:       } \
 34:     } while (0)

 36:   #define PetscCallCUPMBLASAbort(comm, ...) \
 37:     do { \
 38:       const cupmBlasError_t cberr_abort_p_ = __VA_ARGS__; \
 39:       if (PetscUnlikely(cberr_abort_p_ != CUPMBLAS_STATUS_SUCCESS)) { \
 40:         if (((cberr_abort_p_ == CUPMBLAS_STATUS_NOT_INITIALIZED) || (cberr_abort_p_ == CUPMBLAS_STATUS_ALLOC_FAILED)) && PetscDeviceInitialized(PETSC_DEVICE_CUPM())) { \
 41:           SETERRABORT(comm, PETSC_ERR_GPU_RESOURCE, \
 42:                       "%s error %d (%s). Reports not initialized or alloc failed; " \
 43:                       "this indicates the GPU may have run out resources", \
 44:                       cupmBlasName(), static_cast<PetscErrorCode>(cberr_abort_p_), cupmBlasGetErrorName(cberr_abort_p_)); \
 45:         } \
 46:         SETERRABORT(comm, PETSC_ERR_GPU, "%s error %d (%s)", cupmBlasName(), static_cast<PetscErrorCode>(cberr_abort_p_), cupmBlasGetErrorName(cberr_abort_p_)); \
 47:       } \
 48:     } while (0)

 50:   // given cupmBlas<T>axpy() then
 51:   // T = PETSC_CUPBLAS_FP_TYPE
 52:   // given cupmBlas<T><u>nrm2() then
 53:   // T = PETSC_CUPMBLAS_FP_INPUT_TYPE
 54:   // u = PETSC_CUPMBLAS_FP_RETURN_TYPE
 55:   #if PetscDefined(USE_COMPLEX)
 56:     #if PetscDefined(USE_REAL_SINGLE)
 57:       #define PETSC_CUPMBLAS_FP_TYPE_U       C
 58:       #define PETSC_CUPMBLAS_FP_TYPE_L       c
 59:       #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U S
 60:       #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L s
 61:     #elif PetscDefined(USE_REAL_DOUBLE)
 62:       #define PETSC_CUPMBLAS_FP_TYPE_U       Z
 63:       #define PETSC_CUPMBLAS_FP_TYPE_L       z
 64:       #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U D
 65:       #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L d
 66:     #endif
 67:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
 68:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
 69:   #else
 70:     #if PetscDefined(USE_REAL_SINGLE)
 71:       #define PETSC_CUPMBLAS_FP_TYPE_U S
 72:       #define PETSC_CUPMBLAS_FP_TYPE_L s
 73:     #elif PetscDefined(USE_REAL_DOUBLE)
 74:       #define PETSC_CUPMBLAS_FP_TYPE_U D
 75:       #define PETSC_CUPMBLAS_FP_TYPE_L d
 76:     #endif
 77:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
 78:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
 79:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U
 80:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L
 81:   #endif // USE_COMPLEX

 83:   #if !defined(PETSC_CUPMBLAS_FP_TYPE_U) && !PetscDefined(USE_REAL___FLOAT128)
 84:     #error "Unsupported floating-point type for CUDA/HIP BLAS"
 85:   #endif

  // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED() - Helper macro to build a "modified"
  // blas function whose return type does not match the input type
  //
  // input param:
  // func - base suffix of the blas function, e.g. nrm2 or asum
  //
  // notes:
  // The result is <input letter><return letter><func>, pasted together with PetscConcat().
  //
  // requires PETSC_CUPMBLAS_FP_INPUT_TYPE to be defined as the blas floating point input type
  // letter ("S" for real/complex single, "D" for real/complex double).
  //
  // requires PETSC_CUPMBLAS_FP_RETURN_TYPE to be defined as the blas floating point output type
  // letter ("c" for complex single, "z" for complex double and <absolutely nothing> for real
  // single/double). For real builds the return letter is therefore empty and the name reduces to
  // <input letter><func>.
  //
  // In their infinite wisdom nvidia/amd have made the upper-case vs lower-case scheme
  // infuriatingly inconsistent...
  //
  // example usage:
  // #define PETSC_CUPMBLAS_FP_INPUT_TYPE  S
  // #define PETSC_CUPMBLAS_FP_RETURN_TYPE
  // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Snrm2
  //
  // #define PETSC_CUPMBLAS_FP_INPUT_TYPE  D
  // #define PETSC_CUPMBLAS_FP_RETURN_TYPE z
  // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Dznrm2
  #define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(func) PetscConcat(PetscConcat(PETSC_CUPMBLAS_FP_INPUT_TYPE, PETSC_CUPMBLAS_FP_RETURN_TYPE), func)

114:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE() - Helper macro to build Iamax and Iamin
115:   // because they are both extra special
116:   //
117:   // input param:
118:   // func - base suffix of the blas function, either amax or amin
119:   //
120:   // notes:
121:   // The macro name literally stands for "I" ## "floating point type" because shockingly enough,
122:   // that's what it does.
123:   //
124:   // requires PETSC_CUPMBLAS_FP_TYPE_L to be defined as the lower-case blas floating point input type
125:   // letter ("s" for complex single, "z" for complex double, "s" for real single, and "d" for
126:   // real double).
127:   //
128:   // example usage:
129:   // #define PETSC_CUPMBLAS_FP_TYPE_L s
130:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amax) -> Isamax
131:   //
132:   // #define PETSC_CUPMBLAS_FP_TYPE_L z
133:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amin) -> Izamin
134:   #define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(func) PetscConcat(I, PetscConcat(PETSC_CUPMBLAS_FP_TYPE_L, func))

  // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD() - Helper macro to build a "standard"
  // blas function name
  //
  // input param:
  // func - base suffix of the blas function, e.g. axpy, scal
  //
  // notes:
  // The result is simply <type letter><func>. Functions whose vendor suffix differs between
  // real and complex builds (e.g. dot vs dotc) pick the suffix at the call site through
  // PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(); this macro only prepends the letter.
  //
  // requires PETSC_CUPMBLAS_FP_TYPE to be defined as the blas floating-point letter ("C" for
  // complex single, "Z" for complex double, "S" for real single, "D" for real double).
  //
  // example usage:
  // #define PETSC_CUPMBLAS_FP_TYPE S
  // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Saxpy
  //
  // #define PETSC_CUPMBLAS_FP_TYPE Z
  // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Zaxpy
  #define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(func) PetscConcat(PETSC_CUPMBLAS_FP_TYPE, func)

154:   // PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT() - In case CUDA/HIP don't agree with our suffix
155:   // one can provide both here
156:   //
157:   // input params:
158:   // MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or
159:   // IFPTYPE
160:   // our_suffix   - the suffix of the alias function
161:   // their_suffix - the suffix of the function being aliased
162:   //
163:   // notes:
164:   // requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas function
165:   // prefix. requires any other specific definitions required by the specific builder macro to
166:   // also be defined. See PETSC_CUPM_ALIAS_FUNCTION_EXACT() for the exact expansion of the
167:   // function alias.
168:   //
169:   // example usage:
170:   // #define PETSC_CUPMBLAS_PREFIX  cublas
171:   // #define PETSC_CUPMBLAS_FP_TYPE C
172:   // PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD,dot,dotc) ->
173:   // template <typename... T>
174:   // static constexpr auto cupmBlasXdot(T&&... args) *noexcept and returntype detection*
175:   // {
176:   //   return cublasCdotc(std::forward<T>(args)...);
177:   // }
178:   #define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, our_suffix, their_suffix) \
179:     PETSC_CUPM_ALIAS_FUNCTION(PetscConcat(cupmBlasX, our_suffix), PetscConcat(PETSC_CUPMBLAS_PREFIX, PetscConcat(PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_, MACRO_SUFFIX)(their_suffix)))

181:   // PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION() - Alias a CUDA/HIP blas function
182:   //
183:   // input params:
184:   // MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or
185:   // IFPTYPE
186:   // suffix       - the common suffix between CUDA and HIP of the alias function
187:   //
188:   // notes:
189:   // see PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(), this macro just calls that one with "suffix" as
190:   // "our_prefix" and "their_prefix"
191:   #define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MACRO_SUFFIX, suffix) PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, suffix, suffix)

193:   // PETSC_CUPMBLAS_ALIAS_FUNCTION() - Alias a CUDA/HIP library function
194:   //
195:   // input params:
196:   // suffix - the common suffix between CUDA and HIP of the alias function
197:   //
198:   // notes:
199:   // requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas library
200:   // prefix. see PETSC_CUPMM_ALIAS_FUNCTION_EXACT() for the precise expansion of this macro.
201:   //
202:   // example usage:
203:   // #define PETSC_CUPMBLAS_PREFIX hipblas
204:   // PETSC_CUPMBLAS_ALIAS_FUNCTION(Create) ->
205:   // template <typename... T>
206:   // static constexpr auto cupmBlasCreate(T&&... args) *noexcept and returntype detection*
207:   // {
208:   //   return hipblasCreate(std::forward<T>(args)...);
209:   // }
210:   #define PETSC_CUPMBLAS_ALIAS_FUNCTION(suffix) PETSC_CUPM_ALIAS_FUNCTION(PetscConcat(cupmBlas, suffix), PetscConcat(PETSC_CUPMBLAS_PREFIX, suffix))

212: template <DeviceType T>
213: struct BlasInterfaceBase : Interface<T> {
214:   PETSC_NODISCARD static constexpr const char *cupmBlasName() noexcept { return T == DeviceType::CUDA ? "cuBLAS" : "hipBLAS"; }
215: };

  // PETSC_CUPMBLAS_BASE_CLASS_HEADER() - Boilerplate placed at the top of each
  // BlasInterfaceImpl<> specialization
  //
  // input param:
  // DEV_TYPE - the DeviceType of the specialization
  //
  // notes:
  // Aliases BlasInterfaceBase<DEV_TYPE> as base_type and re-exports cupmBlasName(), defines
  // cupmBlasGetErrorName() as an alias of Petsc<PETSC_CUPMBLAS_PREFIX_U>GetErrorName (e.g.
  // PetscCUBLASGetErrorName for CUDA), and finally expands
  // PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING() (presumably re-exporting the generic
  // Interface<> typedefs such as cupmStream_t -- confirm in cupminterface.hpp). Requires
  // PETSC_CUPMBLAS_PREFIX_U to be defined at the point of use.
  #define PETSC_CUPMBLAS_BASE_CLASS_HEADER(DEV_TYPE) \
    using base_type = ::Petsc::device::cupm::impl::BlasInterfaceBase<DEV_TYPE>; \
    using base_type::cupmBlasName; \
    PETSC_CUPM_ALIAS_FUNCTION(cupmBlasGetErrorName, PetscConcat(PetscConcat(Petsc, PETSC_CUPMBLAS_PREFIX_U), GetErrorName)) \
    PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(interface_type, DEV_TYPE)

// Primary template: only declared here; each supported backend supplies an explicit
// specialization below.
template <DeviceType>
struct BlasInterfaceImpl;

226:   #if PetscDefined(HAVE_CUDA)
227:     #define PETSC_CUPMBLAS_PREFIX         cublas
228:     #define PETSC_CUPMBLAS_PREFIX_U       CUBLAS
229:     #define PETSC_CUPMBLAS_FP_TYPE        PETSC_CUPMBLAS_FP_TYPE_U
230:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE  PETSC_CUPMBLAS_FP_INPUT_TYPE_U
231:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
232: template <>
233: struct BlasInterfaceImpl<DeviceType::CUDA> : BlasInterfaceBase<DeviceType::CUDA> {
234:   PETSC_CUPMBLAS_BASE_CLASS_HEADER(DeviceType::CUDA);

236:   // typedefs
237:   using cupmBlasHandle_t      = cublasHandle_t;
238:   using cupmBlasError_t       = cublasStatus_t;
239:   using cupmBlasInt_t         = int;
240:   using cupmSolverHandle_t    = cusolverDnHandle_t;
241:   using cupmSolverError_t     = cusolverStatus_t;
242:   using cupmBlasPointerMode_t = cublasPointerMode_t;

244:   // values
245:   static const auto CUPMBLAS_STATUS_SUCCESS         = CUBLAS_STATUS_SUCCESS;
246:   static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = CUBLAS_STATUS_NOT_INITIALIZED;
247:   static const auto CUPMBLAS_STATUS_ALLOC_FAILED    = CUBLAS_STATUS_ALLOC_FAILED;
248:   static const auto CUPMBLAS_POINTER_MODE_HOST      = CUBLAS_POINTER_MODE_HOST;
249:   static const auto CUPMBLAS_POINTER_MODE_DEVICE    = CUBLAS_POINTER_MODE_DEVICE;

251:   // utility functions
252:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
253:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
254:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
255:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
256:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
257:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)

259:   // level 1 BLAS
260:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
261:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
262:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
263:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
264:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
265:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
266:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
267:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)

269:   // level 2 BLAS
270:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)

272:   // level 3 BLAS
273:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)

275:   // BLAS extensions
276:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)

278:   static PetscErrorCode InitializeHandle(cupmSolverHandle_t &handle) noexcept
279:   {
280:     PetscFunctionBegin;
281:     if (handle) PetscFunctionReturn(PETSC_SUCCESS);
282:     for (auto i = 0; i < 3; ++i) {
283:       const auto cerr = cusolverDnCreate(&handle);
284:       if (PetscLikely(cerr == CUSOLVER_STATUS_SUCCESS)) break;
285:       if ((cerr != CUSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUSOLVER(cerr);
286:       if (i < 2) {
287:         PetscCall(PetscSleep(3));
288:         continue;
289:       }
290:       PetscCheck(cerr == CUSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize cuSolverDn");
291:     }
292:     PetscFunctionReturn(PETSC_SUCCESS);
293:   }

295:   static PetscErrorCode SetHandleStream(const cupmSolverHandle_t &handle, const cupmStream_t &stream) noexcept
296:   {
297:     cupmStream_t cupmStream;

299:     PetscFunctionBegin;
300:     PetscCallCUSOLVER(cusolverDnGetStream(handle, &cupmStream));
301:     if (cupmStream != stream) PetscCallCUSOLVER(cusolverDnSetStream(handle, stream));
302:     PetscFunctionReturn(PETSC_SUCCESS);
303:   }

305:   static PetscErrorCode DestroyHandle(cupmSolverHandle_t &handle) noexcept
306:   {
307:     PetscFunctionBegin;
308:     if (handle) {
309:       PetscCallCUSOLVER(cusolverDnDestroy(handle));
310:       handle = nullptr;
311:     }
312:     PetscFunctionReturn(PETSC_SUCCESS);
313:   }
314: };
315:     #undef PETSC_CUPMBLAS_PREFIX
316:     #undef PETSC_CUPMBLAS_PREFIX_U
317:     #undef PETSC_CUPMBLAS_FP_TYPE
318:     #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
319:     #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
320:   #endif // PetscDefined(HAVE_CUDA)

322:   #if PetscDefined(HAVE_HIP)
323:     #define PETSC_CUPMBLAS_PREFIX         hipblas
324:     #define PETSC_CUPMBLAS_PREFIX_U       HIPBLAS
325:     #define PETSC_CUPMBLAS_FP_TYPE        PETSC_CUPMBLAS_FP_TYPE_U
326:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE  PETSC_CUPMBLAS_FP_INPUT_TYPE_U
327:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
328: template <>
329: struct BlasInterfaceImpl<DeviceType::HIP> : BlasInterfaceBase<DeviceType::HIP> {
330:   PETSC_CUPMBLAS_BASE_CLASS_HEADER(DeviceType::HIP);

332:   // typedefs
333:   using cupmBlasHandle_t      = hipblasHandle_t;
334:   using cupmBlasError_t       = hipblasStatus_t;
335:   using cupmBlasInt_t         = int; // rocblas will have its own
336:   using cupmSolverHandle_t    = hipsolverHandle_t;
337:   using cupmSolverError_t     = hipsolverStatus_t;
338:   using cupmBlasPointerMode_t = hipblasPointerMode_t;

340:   // values
341:   static const auto CUPMBLAS_STATUS_SUCCESS         = HIPBLAS_STATUS_SUCCESS;
342:   static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = HIPBLAS_STATUS_NOT_INITIALIZED;
343:   static const auto CUPMBLAS_STATUS_ALLOC_FAILED    = HIPBLAS_STATUS_ALLOC_FAILED;
344:   static const auto CUPMBLAS_POINTER_MODE_HOST      = HIPBLAS_POINTER_MODE_HOST;
345:   static const auto CUPMBLAS_POINTER_MODE_DEVICE    = HIPBLAS_POINTER_MODE_DEVICE;

347:   // utility functions
348:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
349:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
350:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
351:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
352:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
353:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)

355:   // level 1 BLAS
356:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
357:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
358:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
359:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
360:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
361:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
362:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
363:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)

365:   // level 2 BLAS
366:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)

368:   // level 3 BLAS
369:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)

371:   // BLAS extensions
372:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)

374:   static PetscErrorCode InitializeHandle(cupmSolverHandle_t &handle) noexcept
375:   {
376:     PetscFunctionBegin;
377:     if (!handle) PetscCallHIPSOLVER(hipsolverCreate(&handle));
378:     PetscFunctionReturn(PETSC_SUCCESS);
379:   }

381:   static PetscErrorCode SetHandleStream(cupmSolverHandle_t handle, cupmStream_t stream) noexcept
382:   {
383:     PetscFunctionBegin;
384:     PetscCallHIPSOLVER(hipsolverSetStream(handle, stream));
385:     PetscFunctionReturn(PETSC_SUCCESS);
386:   }

388:   static PetscErrorCode DestroyHandle(cupmSolverHandle_t &handle) noexcept
389:   {
390:     PetscFunctionBegin;
391:     if (handle) {
392:       PetscCallHIPSOLVER(hipsolverDestroy(handle));
393:       handle = nullptr;
394:     }
395:     PetscFunctionReturn(PETSC_SUCCESS);
396:   }
397: };
398:     #undef PETSC_CUPMBLAS_PREFIX
399:     #undef PETSC_CUPMBLAS_PREFIX_U
400:     #undef PETSC_CUPMBLAS_FP_TYPE
401:     #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
402:     #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
403:   #endif // PetscDefined(HAVE_HIP)

  // the base-class boilerplate is only needed by the BlasInterfaceImpl<> specializations above
  #undef PETSC_CUPMBLAS_BASE_CLASS_HEADER

  // PETSC_CUPMBLAS_IMPL_CLASS_HEADER() - Re-export everything a BlasInterfaceImpl<T> provides
  // into the scope of a class deriving from it
  //
  // input params:
  // base_name - name given to the type alias for BlasInterfaceImpl<T>
  // T         - the DeviceType
  //
  // notes:
  // A class template deriving from BlasInterfaceImpl<T> has a dependent base, whose members are
  // not found by unqualified name lookup; these using-declarations make them usable without
  // qualification. The list must be kept in sync with the members declared by the
  // BlasInterfaceImpl<> specializations. The expansion does not end in a semicolon; the caller
  // supplies it.
  #define PETSC_CUPMBLAS_IMPL_CLASS_HEADER(base_name, T) \
    PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(cupmInterface_t, T); \
    using base_name = ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>; \
    /* introspection */ \
    using base_name::cupmBlasName; \
    using base_name::cupmBlasGetErrorName; \
    /* types */ \
    using cupmBlasHandle_t      = typename base_name::cupmBlasHandle_t; \
    using cupmBlasError_t       = typename base_name::cupmBlasError_t; \
    using cupmBlasInt_t         = typename base_name::cupmBlasInt_t; \
    using cupmSolverHandle_t    = typename base_name::cupmSolverHandle_t; \
    using cupmSolverError_t     = typename base_name::cupmSolverError_t; \
    using cupmBlasPointerMode_t = typename base_name::cupmBlasPointerMode_t; \
    /* values */ \
    using base_name::CUPMBLAS_STATUS_SUCCESS; \
    using base_name::CUPMBLAS_STATUS_NOT_INITIALIZED; \
    using base_name::CUPMBLAS_STATUS_ALLOC_FAILED; \
    using base_name::CUPMBLAS_POINTER_MODE_HOST; \
    using base_name::CUPMBLAS_POINTER_MODE_DEVICE; \
    /* utility functions */ \
    using base_name::cupmBlasCreate; \
    using base_name::cupmBlasDestroy; \
    using base_name::cupmBlasGetStream; \
    using base_name::cupmBlasSetStream; \
    using base_name::cupmBlasGetPointerMode; \
    using base_name::cupmBlasSetPointerMode; \
    /* level 1 BLAS */ \
    using base_name::cupmBlasXaxpy; \
    using base_name::cupmBlasXscal; \
    using base_name::cupmBlasXdot; \
    using base_name::cupmBlasXdotu; \
    using base_name::cupmBlasXswap; \
    using base_name::cupmBlasXnrm2; \
    using base_name::cupmBlasXamax; \
    using base_name::cupmBlasXasum; \
    /* level 2 BLAS */ \
    using base_name::cupmBlasXgemv; \
    /* level 3 BLAS */ \
    using base_name::cupmBlasXgemm; \
    /* BLAS extensions */ \
    using base_name::cupmBlasXgeam

449: // The actual interface class
450: template <DeviceType T>
451: struct BlasInterface : BlasInterfaceImpl<T> {
452:   PETSC_CUPMBLAS_IMPL_CLASS_HEADER(blasinterface_type, T);

454:   static PetscErrorCode PetscCUPMBlasSetPointerModeFromPointer(cupmBlasHandle_t handle, const void *ptr) noexcept
455:   {
456:     auto mtype = PETSC_MEMTYPE_HOST;

458:     PetscFunctionBegin;
459:     PetscCall(PetscCUPMGetMemType(ptr, &mtype));
460:     PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, PetscMemTypeDevice(mtype) ? CUPMBLAS_POINTER_MODE_DEVICE : CUPMBLAS_POINTER_MODE_HOST));
461:     PetscFunctionReturn(PETSC_SUCCESS);
462:   }

464:   static PetscErrorCode checkCupmBlasIntCast(PetscInt x) noexcept
465:   {
466:     PetscFunctionBegin;
467:     PetscCheck((std::is_same<PetscInt, cupmBlasInt_t>::value) || (x <= std::numeric_limits<cupmBlasInt_t>::max()), PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "%" PetscInt_FMT " is too big for %s, which may be restricted to 32 bit integers", x, cupmBlasName());
468:     PetscCheck(x >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Passing negative integer to %s routine: %" PetscInt_FMT, cupmBlasName(), x);
469:     PetscFunctionReturn(PETSC_SUCCESS);
470:   }

472:   PETSC_NODISCARD static cupmBlasInt_t cupmBlasIntCast(PetscInt x) noexcept
473:   {
474:     PetscFunctionBegin;
475:     PetscCallAbort(PETSC_COMM_SELF, checkCupmBlasIntCast(x));
476:     PetscFunctionReturn(static_cast<cupmBlasInt_t>(x));
477:   }
478: };

  // PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING() - Pull the complete BlasInterface<T> into
  // the scope of a class deriving from it
  //
  // input params:
  // base_name - name given to the type alias for BlasInterface<T>; the underlying
  //             BlasInterfaceImpl<T> is aliased as base_name ## _impl
  // T         - the DeviceType
  //
  // notes:
  // Expands PETSC_CUPMBLAS_IMPL_CLASS_HEADER() and additionally re-exports the helpers defined by
  // BlasInterface itself. The expansion does not end in a semicolon; the caller supplies it.
  #define PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(base_name, T) \
    PETSC_CUPMBLAS_IMPL_CLASS_HEADER(PetscConcat(base_name, _impl), T); \
    using base_name = ::Petsc::device::cupm::impl::BlasInterface<T>; \
    using base_name::PetscCUPMBlasSetPointerModeFromPointer; \
    using base_name::checkCupmBlasIntCast; \
    using base_name::cupmBlasIntCast

  // Explicit instantiation declarations: includers do not implicitly instantiate BlasInterface
  // for the enabled backends. NOTE(review): the matching explicit instantiation definitions are
  // presumably in the corresponding source file -- confirm.
  #if PetscDefined(HAVE_CUDA)
extern template struct BlasInterface<DeviceType::CUDA>;
  #endif

  #if PetscDefined(HAVE_HIP)
extern template struct BlasInterface<DeviceType::HIP>;
  #endif

495: } // namespace impl

497: } // namespace cupm

499: } // namespace device

501: } // namespace Petsc

503: #endif // defined(__cplusplus)

505: #endif // PETSCCUPMBLASINTERFACE_HPP