Actual source code: cupmblasinterface.hpp
1: #ifndef PETSCCUPMBLASINTERFACE_HPP
2: #define PETSCCUPMBLASINTERFACE_HPP
4: #if defined(__cplusplus)
#include <petsc/private/cupminterface.hpp>
#include <petsc/private/petscadvancedmacros.h>

#include <limits>      // std::numeric_limits
#include <type_traits> // std::is_same
10: namespace Petsc
11: {
13: namespace device
14: {
16: namespace cupm
17: {
19: namespace impl
20: {
// PetscCallCUPMBLAS() - Check the status returned by a cuBLAS/hipBLAS call and raise a PETSc
// error with SETERRQ() if it is not CUPMBLAS_STATUS_SUCCESS
//
// notes:
// Must be expanded inside a function returning PetscErrorCode whose scope provides
// cupmBlasError_t, the CUPMBLAS_STATUS_* values, cupmBlasName() and cupmBlasGetErrorName()
// (e.g. a class using PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING()).
//
// A NOT_INITIALIZED or ALLOC_FAILED status while PetscDeviceInitialized(PETSC_DEVICE_CUPM())
// is true is reported as PETSC_ERR_GPU_RESOURCE; every other failure as PETSC_ERR_GPU. The raw
// status is printed through "%d" as an int, since it is a library status and not a
// PetscErrorCode.
#define PetscCallCUPMBLAS(...) \
  do { \
    const cupmBlasError_t cberr_p_ = __VA_ARGS__; \
    if (PetscUnlikely(cberr_p_ != CUPMBLAS_STATUS_SUCCESS)) { \
      if (((cberr_p_ == CUPMBLAS_STATUS_NOT_INITIALIZED) || (cberr_p_ == CUPMBLAS_STATUS_ALLOC_FAILED)) && PetscDeviceInitialized(PETSC_DEVICE_CUPM())) { \
        SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, \
                "%s error %d (%s). Reports not initialized or alloc failed; " \
                "this indicates the GPU may have run out of resources", \
                cupmBlasName(), static_cast<int>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \
      } \
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "%s error %d (%s)", cupmBlasName(), static_cast<int>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \
    } \
  } while (0)
// PetscCallCUPMBLASAbort() - Like PetscCallCUPMBLAS(), but reports a failing cuBLAS/hipBLAS
// status with SETERRABORT() on comm instead of SETERRQ(), for use where an error code cannot
// be returned
//
// input params:
// comm - the communicator handed to SETERRABORT()
// ...  - the expression producing the cupmBlasError_t status
//
// notes:
// A NOT_INITIALIZED or ALLOC_FAILED status while PetscDeviceInitialized(PETSC_DEVICE_CUPM())
// is true is reported as PETSC_ERR_GPU_RESOURCE; every other failure as PETSC_ERR_GPU. The raw
// status is printed through "%d" as an int, since it is a library status and not a
// PetscErrorCode.
#define PetscCallCUPMBLASAbort(comm, ...) \
  do { \
    const cupmBlasError_t cberr_abort_p_ = __VA_ARGS__; \
    if (PetscUnlikely(cberr_abort_p_ != CUPMBLAS_STATUS_SUCCESS)) { \
      if (((cberr_abort_p_ == CUPMBLAS_STATUS_NOT_INITIALIZED) || (cberr_abort_p_ == CUPMBLAS_STATUS_ALLOC_FAILED)) && PetscDeviceInitialized(PETSC_DEVICE_CUPM())) { \
        SETERRABORT(comm, PETSC_ERR_GPU_RESOURCE, \
                    "%s error %d (%s). Reports not initialized or alloc failed; " \
                    "this indicates the GPU may have run out of resources", \
                    cupmBlasName(), static_cast<int>(cberr_abort_p_), cupmBlasGetErrorName(cberr_abort_p_)); \
      } \
      SETERRABORT(comm, PETSC_ERR_GPU, "%s error %d (%s)", cupmBlasName(), static_cast<int>(cberr_abort_p_), cupmBlasGetErrorName(cberr_abort_p_)); \
    } \
  } while (0)
50: // given cupmBlas<T>axpy() then
51: // T = PETSC_CUPBLAS_FP_TYPE
52: // given cupmBlas<T><u>nrm2() then
53: // T = PETSC_CUPMBLAS_FP_INPUT_TYPE
54: // u = PETSC_CUPMBLAS_FP_RETURN_TYPE
55: #if PetscDefined(USE_COMPLEX)
56: #if PetscDefined(USE_REAL_SINGLE)
57: #define PETSC_CUPMBLAS_FP_TYPE_U C
58: #define PETSC_CUPMBLAS_FP_TYPE_L c
59: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U S
60: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L s
61: #elif PetscDefined(USE_REAL_DOUBLE)
62: #define PETSC_CUPMBLAS_FP_TYPE_U Z
63: #define PETSC_CUPMBLAS_FP_TYPE_L z
64: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U D
65: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L d
66: #endif
67: #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
68: #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
69: #else
70: #if PetscDefined(USE_REAL_SINGLE)
71: #define PETSC_CUPMBLAS_FP_TYPE_U S
72: #define PETSC_CUPMBLAS_FP_TYPE_L s
73: #elif PetscDefined(USE_REAL_DOUBLE)
74: #define PETSC_CUPMBLAS_FP_TYPE_U D
75: #define PETSC_CUPMBLAS_FP_TYPE_L d
76: #endif
77: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
78: #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
79: #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U
80: #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L
81: #endif // USE_COMPLEX
83: #if !defined(PETSC_CUPMBLAS_FP_TYPE_U) && !PetscDefined(USE_REAL___FLOAT128)
84: #error "Unsupported floating-point type for CUDA/HIP BLAS"
85: #endif
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED() - Helper macro to build the name of a
// "modified" blas function, i.e. one whose name carries two type letters because its result
// type does not match the type of its vector argument
//
// input param:
// func - base suffix of the blas function, e.g. nrm2
//
// notes:
// requires PETSC_CUPMBLAS_FP_INPUT_TYPE to be defined as the leading blas type letter ("S" for
// real/complex single, "D" for real/complex double).
//
// requires PETSC_CUPMBLAS_FP_RETURN_TYPE to be defined as the second blas type letter ("c" for
// complex single, "z" for complex double and <absolutely nothing> for real single/double).
//
// Despite the macro names, in the vendor naming scheme the leading letter is the (real) type
// of the result and the second letter is the (complex) type of the vector, e.g. Dznrm2
// computes a double precision norm of a double complex vector.
//
// nvidia/amd are not consistent about upper-case vs lower-case letters across routines, which
// is why each letter is chosen separately.
//
// example usage:
// #define PETSC_CUPMBLAS_FP_INPUT_TYPE S
// #define PETSC_CUPMBLAS_FP_RETURN_TYPE
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Snrm2
//
// #define PETSC_CUPMBLAS_FP_INPUT_TYPE D
// #define PETSC_CUPMBLAS_FP_RETURN_TYPE z
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Dznrm2
#define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(func) PetscConcat(PetscConcat(PETSC_CUPMBLAS_FP_INPUT_TYPE, PETSC_CUPMBLAS_FP_RETURN_TYPE), func)

// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE() - Helper macro to build Iamax and Iamin
// because they are both extra special
//
// input param:
// func - base suffix of the blas function, either amax or amin
//
// notes:
// The macro name stands for "I" ## "floating point type": the name is the letter I followed
// by the lower-case scalar type letter and func.
//
// requires PETSC_CUPMBLAS_FP_TYPE_L to be defined as the lower-case blas scalar type letter
// ("c" for complex single, "z" for complex double, "s" for real single, and "d" for real
// double).
//
// example usage:
// #define PETSC_CUPMBLAS_FP_TYPE_L s
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amax) -> Isamax
//
// #define PETSC_CUPMBLAS_FP_TYPE_L z
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amin) -> Izamin
#define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(func) PetscConcat(I, PetscConcat(PETSC_CUPMBLAS_FP_TYPE_L, func))

// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD() - Helper macro to build a "standard"
// blas function name
//
// input param:
// func - base suffix of the blas function, e.g. axpy, scal
//
// notes:
// requires PETSC_CUPMBLAS_FP_TYPE to be defined as the blas floating-point letter ("C" for
// complex single, "Z" for complex double, "S" for real single, "D" for real double).
//
// example usage:
// #define PETSC_CUPMBLAS_FP_TYPE S
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Saxpy
//
// #define PETSC_CUPMBLAS_FP_TYPE Z
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Zaxpy
#define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(func) PetscConcat(PETSC_CUPMBLAS_FP_TYPE, func)
// PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT() - In case CUDA/HIP don't agree with our suffix
// one can provide both here
//
// input params:
// MACRO_SUFFIX - suffix to one of the above blas function builder macros
//                (PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_<MACRO_SUFFIX>), e.g. STANDARD,
//                MODIFIED or IFPTYPE
// our_suffix   - the suffix of the alias function, which is named cupmBlasX<our_suffix>
// their_suffix - the suffix of the function being aliased
//
// notes:
// requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas function
// prefix. Requires any other definitions needed by the selected builder macro to also be
// defined. See PETSC_CUPM_ALIAS_FUNCTION() for the exact expansion of the function alias.
//
// example usage:
// #define PETSC_CUPMBLAS_PREFIX cublas
// #define PETSC_CUPMBLAS_FP_TYPE C
// PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD,dot,dotc) ->
// template <typename... T>
// static constexpr auto cupmBlasXdot(T&&... args) *noexcept and returntype detection*
// {
//   return cublasCdotc(std::forward<T>(args)...);
// }
#define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, our_suffix, their_suffix) \
  PETSC_CUPM_ALIAS_FUNCTION(PetscConcat(cupmBlasX, our_suffix), PetscConcat(PETSC_CUPMBLAS_PREFIX, PetscConcat(PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_, MACRO_SUFFIX)(their_suffix)))

// PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION() - Alias a CUDA/HIP blas function
//
// input params:
// MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or
//                IFPTYPE
// suffix       - the common suffix between CUDA and HIP of the alias function
//
// notes:
// see PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(); this macro just calls that one with "suffix"
// as both "our_suffix" and "their_suffix"
#define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MACRO_SUFFIX, suffix) PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, suffix, suffix)

// PETSC_CUPMBLAS_ALIAS_FUNCTION() - Alias a CUDA/HIP library function
//
// input params:
// suffix - the common suffix between CUDA and HIP of the alias function
//
// notes:
// requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas library
// prefix. See PETSC_CUPM_ALIAS_FUNCTION() for the precise expansion of this macro.
//
// example usage:
// #define PETSC_CUPMBLAS_PREFIX hipblas
// PETSC_CUPMBLAS_ALIAS_FUNCTION(Create) ->
// template <typename... T>
// static constexpr auto cupmBlasCreate(T&&... args) *noexcept and returntype detection*
// {
//   return hipblasCreate(std::forward<T>(args)...);
// }
#define PETSC_CUPMBLAS_ALIAS_FUNCTION(suffix) PETSC_CUPM_ALIAS_FUNCTION(PetscConcat(cupmBlas, suffix), PetscConcat(PETSC_CUPMBLAS_PREFIX, suffix))
212: template <DeviceType T>
213: struct BlasInterfaceBase : Interface<T> {
214: PETSC_NODISCARD static constexpr const char *cupmBlasName() noexcept { return T == DeviceType::CUDA ? "cuBLAS" : "hipBLAS"; }
215: };
217: #define PETSC_CUPMBLAS_BASE_CLASS_HEADER(DEV_TYPE) \
218: using base_type = ::Petsc::device::cupm::impl::BlasInterfaceBase<DEV_TYPE>; \
219: using base_type::cupmBlasName; \
220: PETSC_CUPM_ALIAS_FUNCTION(cupmBlasGetErrorName, PetscConcat(PetscConcat(Petsc, PETSC_CUPMBLAS_PREFIX_U), GetErrorName)) \
221: PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(interface_type, DEV_TYPE)
223: template <DeviceType>
224: struct BlasInterfaceImpl;
226: #if PetscDefined(HAVE_CUDA)
227: #define PETSC_CUPMBLAS_PREFIX cublas
228: #define PETSC_CUPMBLAS_PREFIX_U CUBLAS
229: #define PETSC_CUPMBLAS_FP_TYPE PETSC_CUPMBLAS_FP_TYPE_U
230: #define PETSC_CUPMBLAS_FP_INPUT_TYPE PETSC_CUPMBLAS_FP_INPUT_TYPE_U
231: #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
232: template <>
233: struct BlasInterfaceImpl<DeviceType::CUDA> : BlasInterfaceBase<DeviceType::CUDA> {
234: PETSC_CUPMBLAS_BASE_CLASS_HEADER(DeviceType::CUDA);
236: // typedefs
237: using cupmBlasHandle_t = cublasHandle_t;
238: using cupmBlasError_t = cublasStatus_t;
239: using cupmBlasInt_t = int;
240: using cupmSolverHandle_t = cusolverDnHandle_t;
241: using cupmSolverError_t = cusolverStatus_t;
242: using cupmBlasPointerMode_t = cublasPointerMode_t;
244: // values
245: static const auto CUPMBLAS_STATUS_SUCCESS = CUBLAS_STATUS_SUCCESS;
246: static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = CUBLAS_STATUS_NOT_INITIALIZED;
247: static const auto CUPMBLAS_STATUS_ALLOC_FAILED = CUBLAS_STATUS_ALLOC_FAILED;
248: static const auto CUPMBLAS_POINTER_MODE_HOST = CUBLAS_POINTER_MODE_HOST;
249: static const auto CUPMBLAS_POINTER_MODE_DEVICE = CUBLAS_POINTER_MODE_DEVICE;
251: // utility functions
252: PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
253: PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
254: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
255: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
256: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
257: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)
259: // level 1 BLAS
260: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
261: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
262: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
263: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
264: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
265: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
266: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
267: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)
269: // level 2 BLAS
270: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)
272: // level 3 BLAS
273: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)
275: // BLAS extensions
276: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)
278: static PetscErrorCode InitializeHandle(cupmSolverHandle_t &handle) noexcept
279: {
280: PetscFunctionBegin;
281: if (handle) PetscFunctionReturn(PETSC_SUCCESS);
282: for (auto i = 0; i < 3; ++i) {
283: const auto cerr = cusolverDnCreate(&handle);
284: if (PetscLikely(cerr == CUSOLVER_STATUS_SUCCESS)) break;
285: if ((cerr != CUSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUSOLVER(cerr);
286: if (i < 2) {
287: PetscCall(PetscSleep(3));
288: continue;
289: }
290: PetscCheck(cerr == CUSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize cuSolverDn");
291: }
292: PetscFunctionReturn(PETSC_SUCCESS);
293: }
295: static PetscErrorCode SetHandleStream(const cupmSolverHandle_t &handle, const cupmStream_t &stream) noexcept
296: {
297: cupmStream_t cupmStream;
299: PetscFunctionBegin;
300: PetscCallCUSOLVER(cusolverDnGetStream(handle, &cupmStream));
301: if (cupmStream != stream) PetscCallCUSOLVER(cusolverDnSetStream(handle, stream));
302: PetscFunctionReturn(PETSC_SUCCESS);
303: }
305: static PetscErrorCode DestroyHandle(cupmSolverHandle_t &handle) noexcept
306: {
307: PetscFunctionBegin;
308: if (handle) {
309: PetscCallCUSOLVER(cusolverDnDestroy(handle));
310: handle = nullptr;
311: }
312: PetscFunctionReturn(PETSC_SUCCESS);
313: }
314: };
315: #undef PETSC_CUPMBLAS_PREFIX
316: #undef PETSC_CUPMBLAS_PREFIX_U
317: #undef PETSC_CUPMBLAS_FP_TYPE
318: #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
319: #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
320: #endif // PetscDefined(HAVE_CUDA)
322: #if PetscDefined(HAVE_HIP)
323: #define PETSC_CUPMBLAS_PREFIX hipblas
324: #define PETSC_CUPMBLAS_PREFIX_U HIPBLAS
325: #define PETSC_CUPMBLAS_FP_TYPE PETSC_CUPMBLAS_FP_TYPE_U
326: #define PETSC_CUPMBLAS_FP_INPUT_TYPE PETSC_CUPMBLAS_FP_INPUT_TYPE_U
327: #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
328: template <>
329: struct BlasInterfaceImpl<DeviceType::HIP> : BlasInterfaceBase<DeviceType::HIP> {
330: PETSC_CUPMBLAS_BASE_CLASS_HEADER(DeviceType::HIP);
332: // typedefs
333: using cupmBlasHandle_t = hipblasHandle_t;
334: using cupmBlasError_t = hipblasStatus_t;
335: using cupmBlasInt_t = int; // rocblas will have its own
336: using cupmSolverHandle_t = hipsolverHandle_t;
337: using cupmSolverError_t = hipsolverStatus_t;
338: using cupmBlasPointerMode_t = hipblasPointerMode_t;
340: // values
341: static const auto CUPMBLAS_STATUS_SUCCESS = HIPBLAS_STATUS_SUCCESS;
342: static const auto CUPMBLAS_STATUS_NOT_INITIALIZED = HIPBLAS_STATUS_NOT_INITIALIZED;
343: static const auto CUPMBLAS_STATUS_ALLOC_FAILED = HIPBLAS_STATUS_ALLOC_FAILED;
344: static const auto CUPMBLAS_POINTER_MODE_HOST = HIPBLAS_POINTER_MODE_HOST;
345: static const auto CUPMBLAS_POINTER_MODE_DEVICE = HIPBLAS_POINTER_MODE_DEVICE;
347: // utility functions
348: PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
349: PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
350: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
351: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
352: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
353: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)
355: // level 1 BLAS
356: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
357: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
358: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
359: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
360: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
361: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
362: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
363: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)
365: // level 2 BLAS
366: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)
368: // level 3 BLAS
369: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)
371: // BLAS extensions
372: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)
374: static PetscErrorCode InitializeHandle(cupmSolverHandle_t &handle) noexcept
375: {
376: PetscFunctionBegin;
377: if (!handle) PetscCallHIPSOLVER(hipsolverCreate(&handle));
378: PetscFunctionReturn(PETSC_SUCCESS);
379: }
381: static PetscErrorCode SetHandleStream(cupmSolverHandle_t handle, cupmStream_t stream) noexcept
382: {
383: PetscFunctionBegin;
384: PetscCallHIPSOLVER(hipsolverSetStream(handle, stream));
385: PetscFunctionReturn(PETSC_SUCCESS);
386: }
388: static PetscErrorCode DestroyHandle(cupmSolverHandle_t &handle) noexcept
389: {
390: PetscFunctionBegin;
391: if (handle) {
392: PetscCallHIPSOLVER(hipsolverDestroy(handle));
393: handle = nullptr;
394: }
395: PetscFunctionReturn(PETSC_SUCCESS);
396: }
397: };
398: #undef PETSC_CUPMBLAS_PREFIX
399: #undef PETSC_CUPMBLAS_PREFIX_U
400: #undef PETSC_CUPMBLAS_FP_TYPE
401: #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
402: #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
403: #endif // PetscDefined(HAVE_HIP)
#undef PETSC_CUPMBLAS_BASE_CLASS_HEADER

// PETSC_CUPMBLAS_IMPL_CLASS_HEADER() - Re-declare every member of BlasInterfaceImpl<T> inside a
// class template deriving from it
//
// input params:
// base_name - name given to the BlasInterfaceImpl<T> alias
// T         - the DeviceType
//
// notes:
// Inside a class template, names from a dependent base are not found by unqualified lookup,
// hence the explicit using-declarations. The last declaration has no trailing semicolon; the
// caller supplies it.
#define PETSC_CUPMBLAS_IMPL_CLASS_HEADER(base_name, T) \
  PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(cupmInterface_t, T); \
  using base_name = ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>; \
  /* vendor types */ \
  using cupmBlasHandle_t      = typename base_name::cupmBlasHandle_t; \
  using cupmBlasError_t       = typename base_name::cupmBlasError_t; \
  using cupmBlasInt_t         = typename base_name::cupmBlasInt_t; \
  using cupmSolverHandle_t    = typename base_name::cupmSolverHandle_t; \
  using cupmSolverError_t     = typename base_name::cupmSolverError_t; \
  using cupmBlasPointerMode_t = typename base_name::cupmBlasPointerMode_t; \
  /* status codes and pointer modes */ \
  using base_name::CUPMBLAS_STATUS_SUCCESS; \
  using base_name::CUPMBLAS_STATUS_NOT_INITIALIZED; \
  using base_name::CUPMBLAS_STATUS_ALLOC_FAILED; \
  using base_name::CUPMBLAS_POINTER_MODE_HOST; \
  using base_name::CUPMBLAS_POINTER_MODE_DEVICE; \
  /* library name and status-to-string conversion */ \
  using base_name::cupmBlasName; \
  using base_name::cupmBlasGetErrorName; \
  /* handle, stream and pointer-mode management */ \
  using base_name::cupmBlasCreate; \
  using base_name::cupmBlasDestroy; \
  using base_name::cupmBlasGetStream; \
  using base_name::cupmBlasSetStream; \
  using base_name::cupmBlasGetPointerMode; \
  using base_name::cupmBlasSetPointerMode; \
  /* BLAS level 1 */ \
  using base_name::cupmBlasXaxpy; \
  using base_name::cupmBlasXscal; \
  using base_name::cupmBlasXdot; \
  using base_name::cupmBlasXdotu; \
  using base_name::cupmBlasXswap; \
  using base_name::cupmBlasXnrm2; \
  using base_name::cupmBlasXamax; \
  using base_name::cupmBlasXasum; \
  /* BLAS level 2 and 3 */ \
  using base_name::cupmBlasXgemv; \
  using base_name::cupmBlasXgemm; \
  /* BLAS extensions */ \
  using base_name::cupmBlasXgeam
449: // The actual interface class
450: template <DeviceType T>
451: struct BlasInterface : BlasInterfaceImpl<T> {
452: PETSC_CUPMBLAS_IMPL_CLASS_HEADER(blasinterface_type, T);
454: static PetscErrorCode PetscCUPMBlasSetPointerModeFromPointer(cupmBlasHandle_t handle, const void *ptr) noexcept
455: {
456: auto mtype = PETSC_MEMTYPE_HOST;
458: PetscFunctionBegin;
459: PetscCall(PetscCUPMGetMemType(ptr, &mtype));
460: PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, PetscMemTypeDevice(mtype) ? CUPMBLAS_POINTER_MODE_DEVICE : CUPMBLAS_POINTER_MODE_HOST));
461: PetscFunctionReturn(PETSC_SUCCESS);
462: }
464: static PetscErrorCode checkCupmBlasIntCast(PetscInt x) noexcept
465: {
466: PetscFunctionBegin;
467: PetscCheck((std::is_same<PetscInt, cupmBlasInt_t>::value) || (x <= std::numeric_limits<cupmBlasInt_t>::max()), PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "%" PetscInt_FMT " is too big for %s, which may be restricted to 32 bit integers", x, cupmBlasName());
468: PetscCheck(x >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "Passing negative integer to %s routine: %" PetscInt_FMT, cupmBlasName(), x);
469: PetscFunctionReturn(PETSC_SUCCESS);
470: }
472: PETSC_NODISCARD static cupmBlasInt_t cupmBlasIntCast(PetscInt x) noexcept
473: {
474: PetscFunctionBegin;
475: PetscCallAbort(PETSC_COMM_SELF, checkCupmBlasIntCast(x));
476: PetscFunctionReturn(static_cast<cupmBlasInt_t>(x));
477: }
478: };
480: #define PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(base_name, T) \
481: PETSC_CUPMBLAS_IMPL_CLASS_HEADER(PetscConcat(base_name, _impl), T); \
482: using base_name = ::Petsc::device::cupm::impl::BlasInterface<T>; \
483: using base_name::PetscCUPMBlasSetPointerModeFromPointer; \
484: using base_name::checkCupmBlasIntCast; \
485: using base_name::cupmBlasIntCast
487: #if PetscDefined(HAVE_CUDA)
488: extern template struct BlasInterface<DeviceType::CUDA>;
489: #endif
491: #if PetscDefined(HAVE_HIP)
492: extern template struct BlasInterface<DeviceType::HIP>;
493: #endif
495: } // namespace impl
497: } // namespace cupm
499: } // namespace device
501: } // namespace Petsc
503: #endif // defined(__cplusplus)
505: #endif // PETSCCUPMBLASINTERFACE_HPP