Actual source code: cupmcontext.hpp

  1: #ifndef PETSCDEVICECONTEXTCUPM_HPP
  2: #define PETSCDEVICECONTEXTCUPM_HPP

  4: #include <petsc/private/deviceimpl.h>
  5: #include <petsc/private/cupmblasinterface.hpp>
  6: #include <petsc/private/logimpl.h>

  8: #include <petsc/private/cpp/array.hpp>

 10: #include "../segmentedmempool.hpp"
 11: #include "cupmallocator.hpp"
 12: #include "cupmstream.hpp"
 13: #include "cupmevent.hpp"

 15: #if defined(__cplusplus)

 17: namespace Petsc
 18: {

 20: namespace device
 21: {

 23: namespace cupm
 24: {

 26: namespace impl
 27: {

 29: template <DeviceType T>
 30: class DeviceContext : BlasInterface<T> {
 31: public:
 32:   PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(cupmBlasInterface_t, T);

 34: private:
 35:   template <typename H, std::size_t>
 36:   struct HandleTag {
 37:     using type = H;
 38:   };

 40:   using stream_tag = HandleTag<cupmStream_t, 0>;
 41:   using blas_tag   = HandleTag<cupmBlasHandle_t, 1>;
 42:   using solver_tag = HandleTag<cupmSolverHandle_t, 2>;

 44:   using stream_type = CUPMStream<T>;
 45:   using event_type  = CUPMEvent<T>;

 47: public:
 48:   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
 49:   // header, but since we are using the power of templates it must be declared part of
 50:   // this class to have easy access the same typedefs. Technically one can make a
 51:   // templated struct outside the class but it's more code for the same result.
 52:   struct PetscDeviceContext_IMPLS : memory::PoolAllocated<PetscDeviceContext_IMPLS> {
 53:     stream_type stream{};
 54:     cupmEvent_t event{};
 55:     cupmEvent_t begin{}; // timer-only
 56:     cupmEvent_t end{};   // timer-only
 57:   #if PetscDefined(USE_DEBUG)
 58:     PetscBool timerInUse{};
 59:   #endif
 60:     cupmBlasHandle_t   blas{};
 61:     cupmSolverHandle_t solver{};

 63:     constexpr PetscDeviceContext_IMPLS() noexcept = default;

 65:     PETSC_NODISCARD cupmStream_t get(stream_tag) const noexcept { return this->stream.get_stream(); }

 67:     PETSC_NODISCARD cupmBlasHandle_t get(blas_tag) const noexcept { return this->blas; }

 69:     PETSC_NODISCARD cupmSolverHandle_t get(solver_tag) const noexcept { return this->solver; }
 70:   };

 72: private:
 73:   static bool initialized_;

 75:   static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES>   blashandles_;
 76:   static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;

 78:   PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }

 80:   PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }

 82:   PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }

 84:   PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }

 86:   // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
 87:   // handles
 88:   static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; }

 90:   static PetscErrorCode create_handle_(blas_tag, cupmBlasHandle_t &handle) noexcept
 91:   {
 92:     PetscLogEvent event;

 94:     PetscFunctionBegin;
 95:     if (PetscLikely(handle)) PetscFunctionReturn(PETSC_SUCCESS);
 96:     PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
 97:     PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
 98:     for (auto i = 0; i < 3; ++i) {
 99:       auto cberr = cupmBlasCreate(&handle);
100:       if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
101:       if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
102:       if (i != 2) {
103:         PetscCall(PetscSleep(3));
104:         continue;
105:       }
106:       PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
107:     }
108:     PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
109:     PetscCall(PetscLogEventResume_Internal(event));
110:     PetscFunctionReturn(PETSC_SUCCESS);
111:   }

113:   static PetscErrorCode initialize_handle_(blas_tag tag, PetscDeviceContext dctx) noexcept
114:   {
115:     const auto dci    = impls_cast_(dctx);
116:     auto      &handle = blashandles_[dctx->device->deviceId];

118:     PetscFunctionBegin;
119:     PetscCall(create_handle_(tag, handle));
120:     PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream()));
121:     dci->blas = handle;
122:     PetscFunctionReturn(PETSC_SUCCESS);
123:   }

125:   static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept
126:   {
127:     const auto    dci    = impls_cast_(dctx);
128:     auto         &handle = solverhandles_[dctx->device->deviceId];
129:     PetscLogEvent event;

131:     PetscFunctionBegin;
132:     PetscCall(PetscLogPauseCurrentEvent_Internal(&event));
133:     PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
134:     PetscCall(cupmBlasInterface_t::InitializeHandle(handle));
135:     PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
136:     PetscCall(PetscLogEventResume_Internal(event));
137:     PetscCall(cupmBlasInterface_t::SetHandleStream(handle, dci->stream.get_stream()));
138:     dci->solver = handle;
139:     PetscFunctionReturn(PETSC_SUCCESS);
140:   }

142:   static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
143:   {
144:     const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;

146:     PetscFunctionBegin;
147:     PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")",
148:                PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
149:     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl));
150:     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr));
151:     PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl)));
152:     PetscFunctionReturn(PETSC_SUCCESS);
153:   }

155:   static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }

157:   static PetscErrorCode finalize_() noexcept
158:   {
159:     PetscFunctionBegin;
160:     for (auto &&handle : blashandles_) {
161:       if (handle) {
162:         PetscCallCUPMBLAS(cupmBlasDestroy(handle));
163:         handle = nullptr;
164:       }
165:     }

167:     for (auto &&handle : solverhandles_) {
168:       if (handle) {
169:         PetscCall(cupmBlasInterface_t::DestroyHandle(handle));
170:         handle = nullptr;
171:       }
172:     }
173:     initialized_ = false;
174:     PetscFunctionReturn(PETSC_SUCCESS);
175:   }

177:   template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
178:   PETSC_NODISCARD static PoolType &default_pool_() noexcept
179:   {
180:     static PoolType pool;
181:     return pool;
182:   }

184:   static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
185:   {
186:     PetscFunctionBegin;
187:     PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess);
188:     PetscFunctionReturn(PETSC_SUCCESS);
189:   }

191: public:
192:   // All of these functions MUST be static in order to be callable from C, otherwise they
193:   // get the implicit 'this' pointer tacked on
194:   static PetscErrorCode destroy(PetscDeviceContext) noexcept;
195:   static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept;
196:   static PetscErrorCode setUp(PetscDeviceContext) noexcept;
197:   static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept;
198:   static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept;
199:   static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
200:   template <typename Handle_t>
201:   static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept;
202:   static PetscErrorCode beginTimer(PetscDeviceContext) noexcept;
203:   static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept;
204:   static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept;
205:   static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept;
206:   static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept;
207:   static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept;
208:   static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept;
209:   static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept;
210:   static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept;

212:   // not a PetscDeviceContext method, this registers the class
213:   static PetscErrorCode initialize(PetscDevice) noexcept;

215:   // clang-format off
216:   const _DeviceContextOps ops = {
217:     destroy,
218:     changeStreamType,
219:     setUp,
220:     query,
221:     waitForContext,
222:     synchronize,
223:     getHandle<blas_tag>,
224:     getHandle<solver_tag>,
225:     getHandle<stream_tag>,
226:     beginTimer,
227:     endTimer,
228:     memAlloc,
229:     memFree,
230:     memCopy,
231:     memSet,
232:     createEvent,
233:     recordEvent,
234:     waitForEvent
235:   };
236:   // clang-format on
237: };

239: // not a PetscDeviceContext method, this initializes the CLASS
240: template <DeviceType T>
241: inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept
242: {
243:   PetscFunctionBegin;
244:   if (PetscUnlikely(!initialized_)) {
245:     uint64_t      threshold = UINT64_MAX;
246:     cupmMemPool_t mempool;

248:     initialized_ = true;
249:     PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId)));
250:     PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold));
251:     blashandles_.fill(nullptr);
252:     solverhandles_.fill(nullptr);
253:     PetscCall(PetscRegisterFinalize(finalize_));
254:   }
255:   PetscFunctionReturn(PETSC_SUCCESS);
256: }

258: template <DeviceType T>
259: inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept
260: {
261:   PetscFunctionBegin;
262:   if (const auto dci = impls_cast_(dctx)) {
263:     PetscCall(dci->stream.destroy());
264:     if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event));
265:     if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
266:     if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
267:     delete dci;
268:     dctx->data = nullptr;
269:   }
270:   PetscFunctionReturn(PETSC_SUCCESS);
271: }

273: template <DeviceType T>
274: inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept
275: {
276:   const auto dci = impls_cast_(dctx);

278:   PetscFunctionBegin;
279:   PetscCall(dci->stream.destroy());
280:   // set these to null so they aren't usable until setup is called again
281:   dci->blas   = nullptr;
282:   dci->solver = nullptr;
283:   PetscFunctionReturn(PETSC_SUCCESS);
284: }

286: template <DeviceType T>
287: inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept
288: {
289:   const auto dci   = impls_cast_(dctx);
290:   auto      &event = dci->event;

292:   PetscFunctionBegin;
293:   PetscCall(check_current_device_(dctx));
294:   PetscCall(dci->stream.change_type(dctx->streamType));
295:   if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event));
296:   #if PetscDefined(USE_DEBUG)
297:   dci->timerInUse = PETSC_FALSE;
298:   #endif
299:   PetscFunctionReturn(PETSC_SUCCESS);
300: }

302: template <DeviceType T>
303: inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
304: {
305:   PetscFunctionBegin;
306:   PetscCall(check_current_device_(dctx));
307:   switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
308:   case cupmSuccess:
309:     *idle = PETSC_TRUE;
310:     break;
311:   case cupmErrorNotReady:
312:     *idle = PETSC_FALSE;
313:     // reset the error
314:     cerr = cupmGetLastError();
315:     static_cast<void>(cerr);
316:     break;
317:   default:
318:     PetscCallCUPM(cerr);
319:     PetscUnreachable();
320:   }
321:   PetscFunctionReturn(PETSC_SUCCESS);
322: }

324: template <DeviceType T>
325: inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
326: {
327:   const auto dcib  = impls_cast_(dctxb);
328:   const auto event = dcib->event;

330:   PetscFunctionBegin;
331:   PetscCall(check_current_device_(dctxa, dctxb));
332:   PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream()));
333:   PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0));
334:   PetscFunctionReturn(PETSC_SUCCESS);
335: }

337: template <DeviceType T>
338: inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept
339: {
340:   auto idle = PETSC_TRUE;

342:   PetscFunctionBegin;
343:   PetscCall(query(dctx, &idle));
344:   if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream()));
345:   PetscFunctionReturn(PETSC_SUCCESS);
346: }

348: template <DeviceType T>
349: template <typename handle_t>
350: inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept
351: {
352:   PetscFunctionBegin;
353:   PetscCall(initialize_handle_(handle_t{}, dctx));
354:   *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
355:   PetscFunctionReturn(PETSC_SUCCESS);
356: }

358: template <DeviceType T>
359: inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept
360: {
361:   const auto dci = impls_cast_(dctx);

363:   PetscFunctionBegin;
364:   PetscCall(check_current_device_(dctx));
365:   #if PetscDefined(USE_DEBUG)
366:   PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
367:   dci->timerInUse = PETSC_TRUE;
368:   #endif
369:   if (!dci->begin) {
370:     PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
371:     PetscCallCUPM(cupmEventCreate(&dci->begin));
372:     PetscCallCUPM(cupmEventCreate(&dci->end));
373:   }
374:   PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream()));
375:   PetscFunctionReturn(PETSC_SUCCESS);
376: }

378: template <DeviceType T>
379: inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept
380: {
381:   float      gtime;
382:   const auto dci = impls_cast_(dctx);
383:   const auto end = dci->end;

385:   PetscFunctionBegin;
386:   PetscCall(check_current_device_(dctx));
387:   #if PetscDefined(USE_DEBUG)
388:   PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
389:   dci->timerInUse = PETSC_FALSE;
390:   #endif
391:   PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream()));
392:   PetscCallCUPM(cupmEventSynchronize(end));
393:   PetscCallCUPM(cupmEventElapsedTime(&gtime, dci->begin, end));
394:   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
395:   PetscFunctionReturn(PETSC_SUCCESS);
396: }

398: template <DeviceType T>
399: inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept
400: {
401:   const auto &stream = impls_cast_(dctx)->stream;

403:   PetscFunctionBegin;
404:   PetscCall(check_current_device_(dctx));
405:   PetscCall(check_memtype_(mtype, "allocating"));
406:   if (PetscMemTypeHost(mtype)) {
407:     PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
408:   } else {
409:     PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
410:   }
411:   if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream()));
412:   PetscFunctionReturn(PETSC_SUCCESS);
413: }

415: template <DeviceType T>
416: inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept
417: {
418:   const auto &stream = impls_cast_(dctx)->stream;

420:   PetscFunctionBegin;
421:   PetscCall(check_current_device_(dctx));
422:   PetscCall(check_memtype_(mtype, "freeing"));
423:   if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS);
424:   if (PetscMemTypeHost(mtype)) {
425:     PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
426:     // if ptr exists still exists the pool didn't own it
427:     if (*ptr) {
428:       auto registered = PETSC_FALSE, managed = PETSC_FALSE;

430:       PetscCall(PetscCUPMGetMemType(*ptr, nullptr, &registered, &managed));
431:       if (registered) {
432:         PetscCallCUPM(cupmFreeHost(*ptr));
433:       } else if (managed) {
434:         PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
435:       }
436:     }
437:   } else {
438:     PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
439:     // if ptr still exists the pool didn't own it
440:     if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
441:   }
442:   PetscFunctionReturn(PETSC_SUCCESS);
443: }

445: template <DeviceType T>
446: inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept
447: {
448:   const auto stream = impls_cast_(dctx)->stream.get_stream();

450:   PetscFunctionBegin;
451:   // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
452:   if (mode == PETSC_DEVICE_COPY_HTOH) {
453:     const auto cerr = cupmStreamQuery(stream);

455:     // yes this is faster
456:     if (cerr == cupmSuccess) {
457:       PetscCall(PetscMemcpy(dest, src, n));
458:       PetscFunctionReturn(PETSC_SUCCESS);
459:     } else if (cerr == cupmErrorNotReady) {
460:       auto PETSC_UNUSED unused = cupmGetLastError();

462:       static_cast<void>(unused);
463:     } else {
464:       PetscCallCUPM(cerr);
465:     }
466:   }
467:   PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream));
468:   PetscFunctionReturn(PETSC_SUCCESS);
469: }

471: template <DeviceType T>
472: inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept
473: {
474:   PetscFunctionBegin;
475:   PetscCall(check_current_device_(dctx));
476:   PetscCall(check_memtype_(mtype, "zeroing"));
477:   PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream()));
478:   PetscFunctionReturn(PETSC_SUCCESS);
479: }

481: template <DeviceType T>
482: inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept
483: {
484:   PetscFunctionBegin;
485:   PetscCallCXX(event->data = new event_type());
486:   event->destroy = [](PetscEvent event) {
487:     PetscFunctionBegin;
488:     delete event_cast_(event);
489:     event->data = nullptr;
490:     PetscFunctionReturn(PETSC_SUCCESS);
491:   };
492:   PetscFunctionReturn(PETSC_SUCCESS);
493: }

495: template <DeviceType T>
496: inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
497: {
498:   PetscFunctionBegin;
499:   PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event)));
500:   PetscFunctionReturn(PETSC_SUCCESS);
501: }

503: template <DeviceType T>
504: inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
505: {
506:   PetscFunctionBegin;
507:   PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event)));
508:   PetscFunctionReturn(PETSC_SUCCESS);
509: }

511: // initialize the static member variables
512: template <DeviceType T>
513: bool DeviceContext<T>::initialized_ = false;

515: template <DeviceType T>
516: std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};

518: template <DeviceType T>
519: std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};

521: } // namespace impl

523: // shorten this one up a bit (and instantiate the templates)
524: using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
525: using CUPMContextHip  = impl::DeviceContext<DeviceType::HIP>;

527:   // shorthand for what is an EXTREMELY long name
528:   #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS

530: } // namespace cupm

532: } // namespace device

534: } // namespace Petsc

536: #endif // __cplusplus

538: #endif // PETSCDEVICECONTEXTCUDA_HPP