Actual source code: vecseqcupm.hpp

  1: #ifndef PETSCVECSEQCUPM_HPP
  2: #define PETSCVECSEQCUPM_HPP

  4: #include <petsc/private/veccupmimpl.h>
  5: #include <petsc/private/randomimpl.h>
  6: #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
  7: #include "../src/sys/objects/device/impls/cupm/cupmallocator.hpp"
  8: #include "../src/sys/objects/device/impls/cupm/kernels.hpp"

 10: #if defined(__cplusplus)
 11:   #include <thrust/transform_reduce.h>
 12:   #include <thrust/reduce.h>
 13:   #include <thrust/functional.h>
 14:   #include <thrust/iterator/counting_iterator.h>
 15:   #include <thrust/inner_product.h>

 17: namespace Petsc
 18: {

 20: namespace vec
 21: {

 23: namespace cupm
 24: {

 26: namespace impl
 27: {

 29: // ==========================================================================================
 30: // VecSeq_CUPM
 31: // ==========================================================================================

// Implementation class for sequential CUPM vectors, templated on the device backend T. It
// inherits from Vec_CUPMBase, handing itself over as the derived-type template argument, and
// every member declared here is a static function.
template <device::cupm::DeviceType T>
class VecSeq_CUPM : Vec_CUPMBase<T, VecSeq_CUPM<T>> {
public:
  PETSC_VEC_CUPM_BASE_CLASS_HEADER(base_type, T, VecSeq_CUPM<T>);

private:
  // thin hooks that forward to the host Vec_Seq implementation (presumably consumed by the
  // base class -- confirm against Vec_CUPMBase)
  PETSC_NODISCARD static Vec_Seq          *VecIMPLCast_(Vec) noexcept;
  PETSC_NODISCARD static constexpr VecType VECIMPLCUPM_() noexcept;

  static PetscErrorCode VecDestroy_IMPL_(Vec) noexcept;
  static PetscErrorCode VecResetArray_IMPL_(Vec) noexcept;
  static PetscErrorCode VecPlaceArray_IMPL_(Vec, const PetscScalar *) noexcept;
  static PetscErrorCode VecCreate_IMPL_Private_(Vec, PetscBool *, PetscInt, PetscScalar *) noexcept;

  // bumps the object state of a vector that is locally empty but globally non-empty
  static PetscErrorCode MaybeIncrementEmptyLocalVec(Vec) noexcept;

  // common core for min and max
  template <typename TupleFuncT, typename UnaryFuncT>
  static PetscErrorCode minmax_(TupleFuncT &&, UnaryFuncT &&, Vec, PetscInt *, PetscReal *) noexcept;
  // common core for pointwise binary and pointwise unary thrust functions
  template <typename BinaryFuncT>
  static PetscErrorCode pointwisebinary_(BinaryFuncT &&, Vec, Vec, Vec) noexcept;
  template <typename UnaryFuncT>
  static PetscErrorCode pointwiseunary_(UnaryFuncT &&, Vec, Vec /*out*/ = nullptr) noexcept;
  // mdot dispatchers
  static PetscErrorCode mdot_(/* use complex = */ std::true_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  static PetscErrorCode mdot_(/* use complex = */ std::false_type, Vec, PetscInt, const Vec[], PetscScalar *, PetscDeviceContext) noexcept;
  // kernel launchers: the index_sequence overload expands one kernel argument per index, the
  // <int> overload takes a running PetscInt offset by reference
  template <std::size_t... Idx>
  static PetscErrorCode mdot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode mdot_kernel_dispatch_(PetscDeviceContext, cupmStream_t, const PetscScalar *, const Vec[], PetscInt, PetscScalar *, PetscInt &) noexcept;
  template <std::size_t... Idx>
  static PetscErrorCode maxpy_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, util::index_sequence<Idx...>) noexcept;
  template <int>
  static PetscErrorCode maxpy_kernel_dispatch_(PetscDeviceContext, cupmStream_t, PetscScalar *, const PetscScalar *, const Vec *, PetscInt, PetscInt &) noexcept;
  // common core for the various create routines
  static PetscErrorCode createseqcupm_(Vec, PetscDeviceContext, PetscScalar * /*host_ptr*/ = nullptr, PetscScalar * /*device_ptr*/ = nullptr) noexcept;

public:
  // callable directly via a bespoke function
  static PetscErrorCode createseqcupm(MPI_Comm, PetscInt, PetscInt, Vec *, PetscBool) noexcept;
  static PetscErrorCode createseqcupmwithbotharrays(MPI_Comm, PetscInt, PetscInt, const PetscScalar[], const PetscScalar[], Vec *) noexcept;

  // callable indirectly via function pointers
  static PetscErrorCode duplicate(Vec, Vec *) noexcept;
  static PetscErrorCode aypx(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode axpy(Vec, PetscScalar, Vec) noexcept;
  static PetscErrorCode pointwisedivide(Vec, Vec, Vec) noexcept;
  static PetscErrorCode pointwisemult(Vec, Vec, Vec) noexcept;
  static PetscErrorCode reciprocal(Vec) noexcept;
  static PetscErrorCode waxpy(Vec, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode maxpy(Vec, PetscInt, const PetscScalar[], Vec *) noexcept;
  static PetscErrorCode dot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode mdot(Vec, PetscInt, const Vec[], PetscScalar *) noexcept;
  static PetscErrorCode set(Vec, PetscScalar) noexcept;
  static PetscErrorCode scale(Vec, PetscScalar) noexcept;
  static PetscErrorCode tdot(Vec, Vec, PetscScalar *) noexcept;
  static PetscErrorCode copy(Vec, Vec) noexcept;
  static PetscErrorCode swap(Vec, Vec) noexcept;
  static PetscErrorCode axpby(Vec, PetscScalar, PetscScalar, Vec) noexcept;
  static PetscErrorCode axpbypcz(Vec, PetscScalar, PetscScalar, PetscScalar, Vec, Vec) noexcept;
  static PetscErrorCode norm(Vec, NormType, PetscReal *) noexcept;
  static PetscErrorCode dotnorm2(Vec, Vec, PetscScalar *, PetscScalar *) noexcept;
  static PetscErrorCode destroy(Vec) noexcept;
  static PetscErrorCode conjugate(Vec) noexcept;
  // the access mode template argument selects between the read-only and read-write variants
  template <PetscMemoryAccessMode>
  static PetscErrorCode getlocalvector(Vec, Vec) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode restorelocalvector(Vec, Vec) noexcept;
  static PetscErrorCode max(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode min(Vec, PetscInt *, PetscReal *) noexcept;
  static PetscErrorCode sum(Vec, PetscScalar *) noexcept;
  static PetscErrorCode shift(Vec, PetscScalar) noexcept;
  static PetscErrorCode setrandom(Vec, PetscRandom) noexcept;
  static PetscErrorCode bindtocpu(Vec, PetscBool) noexcept;
  static PetscErrorCode setpreallocationcoo(Vec, PetscCount, const PetscInt[]) noexcept;
  static PetscErrorCode setvaluescoo(Vec, const PetscScalar[], InsertMode) noexcept;
};

112: // ==========================================================================================
113: // VecSeq_CUPM - Private API
114: // ==========================================================================================

// Returns v->data cast to the host-side Vec_Seq implementation struct. No checking is done
// on the vector's actual type.
template <device::cupm::DeviceType T>
inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept
{
  return static_cast<Vec_Seq *>(v->data);
}

// Returns the VecType produced by VECSEQCUPM(), i.e. the type name of this implementation.
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept
{
  return VECSEQCUPM();
}

// Destroy hook: forwards directly to the host implementation VecDestroy_Seq().
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
{
  return VecDestroy_Seq(v);
}

// Reset-array hook: forwards directly to the host implementation VecResetArray_Seq().
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
{
  return VecResetArray_Seq(v);
}

// Place-array hook: forwards the vector and the array a directly to the host implementation
// VecPlaceArray_Seq().
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
{
  return VecPlaceArray_Seq(v, a);
}

146: template <device::cupm::DeviceType T>
147: inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept
148: {
149:   PetscMPIInt size;

151:   PetscFunctionBegin;
152:   if (alloc_missing) *alloc_missing = PETSC_FALSE;
153:   PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
154:   PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size);
155:   PetscCall(VecCreate_Seq_Private(v, host_array));
156:   PetscFunctionReturn(PETSC_SUCCESS);
157: }

// for functions with an early return based on vec size we still need to artificially bump the
160: // object state. This is to prevent the following:
161: //
162: // 0. Suppose you have a Vec {
163: //   rank 0: [0],
164: //   rank 1: [<empty>]
165: // }
166: // 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0
167: // 2. Vec enters e.g. VecSet(10)
168: // 3. rank 1 has local size 0 and bails immediately
169: // 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite()
170: // 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1
171: // 6. Vec enters VecNorm(), and calls VecNormAvailable()
172: // 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0
173: // 8. rank 0 has object state = 1, not equal to stash, continues to impl function
174: // 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early
175: template <device::cupm::DeviceType T>
176: inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept
177: {
178:   PetscFunctionBegin;
179:   if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
180:   PetscFunctionReturn(PETSC_SUCCESS);
181: }

// Common core of the create routines: first runs the base-class creation hook with
// host_array, then Initialize_CUPMBase() with both the host and device arrays and the device
// context. Either array may be nullptr (the declaration defaults both to nullptr).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::createseqcupm_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array));
  PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

192: template <device::cupm::DeviceType T>
193: template <typename BinaryFuncT>
194: inline PetscErrorCode VecSeq_CUPM<T>::pointwisebinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout) noexcept
195: {
196:   PetscFunctionBegin;
197:   if (const auto n = zout->map->n) {
198:     PetscDeviceContext dctx;

200:     PetscCall(GetHandles_(&dctx));
201:     PetscCall(device::cupm::impl::ThrustApplyPointwise<T>(dctx, std::forward<BinaryFuncT>(binary), n, DeviceArrayRead(dctx, xin).data(), DeviceArrayRead(dctx, yin).data(), DeviceArrayWrite(dctx, zout).data()));
202:     PetscCall(PetscDeviceContextSynchronize(dctx));
203:   } else {
204:     PetscCall(MaybeIncrementEmptyLocalVec(zout));
205:   }
206:   PetscFunctionReturn(PETSC_SUCCESS);
207: }

209: template <device::cupm::DeviceType T>
210: template <typename UnaryFuncT>
211: inline PetscErrorCode VecSeq_CUPM<T>::pointwiseunary_(UnaryFuncT &&unary, Vec xinout, Vec yin) noexcept
212: {
213:   const auto inplace = !yin || (xinout == yin);

215:   PetscFunctionBegin;
216:   if (const auto n = xinout->map->n) {
217:     PetscDeviceContext dctx;

219:     PetscCall(GetHandles_(&dctx));
220:     if (inplace) {
221:       PetscCall(device::cupm::impl::ThrustApplyPointwise<T>(dctx, std::forward<UnaryFuncT>(unary), n, DeviceArrayReadWrite(dctx, xinout).data()));
222:     } else {
223:       PetscCall(device::cupm::impl::ThrustApplyPointwise<T>(dctx, std::forward<UnaryFuncT>(unary), n, DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yin).data()));
224:     }
225:     PetscCall(PetscDeviceContextSynchronize(dctx));
226:   } else {
227:     if (inplace) {
228:       PetscCall(MaybeIncrementEmptyLocalVec(xinout));
229:     } else {
230:       PetscCall(MaybeIncrementEmptyLocalVec(yin));
231:     }
232:   }
233:   PetscFunctionReturn(PETSC_SUCCESS);
234: }

236: // ==========================================================================================
237: // VecSeq_CUPM - Public API - Constructors
238: // ==========================================================================================

// VecCreateSeqCUPM()
//
// Forwards to Create_CUPMBase(), passing n twice (once for each of the two size arguments,
// presumably local and global size -- confirm against Create_CUPMBase). call_set_type is
// passed through unchanged.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::createseqcupm(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept
{
  PetscFunctionBegin;
  PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type));
  PetscFunctionReturn(PETSC_SUCCESS);
}

249: // VecCreateSeqCUPMWithArrays()
250: template <device::cupm::DeviceType T>
251: inline PetscErrorCode VecSeq_CUPM<T>::createseqcupmwithbotharrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
252: {
253:   PetscDeviceContext dctx;

255:   PetscFunctionBegin;
256:   PetscCall(GetHandles_(&dctx));
257:   // do NOT call VecSetType(), otherwise ops->create() -> create() ->
258:   // createseqcupm_() is called!
259:   PetscCall(createseqcupm(comm, bs, n, v, PETSC_FALSE));
260:   PetscCall(createseqcupm_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
261:   PetscFunctionReturn(PETSC_SUCCESS);
262: }

// v->ops->duplicate
//
// Fetches the current device context and delegates the duplication of v into *y to
// Duplicate_CUPMBase().
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::duplicate(Vec v, Vec *y) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(Duplicate_CUPMBase(v, y, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

276: // ==========================================================================================
277: // VecSeq_CUPM - Public API - Utility
278: // ==========================================================================================

// v->ops->bindtocpu
//
// Calls BindToCPU_CUPMBase() and then re-populates the operation table of v. Each
// VecSetOp_CUPM(op, host_fn, device_fn) line names a host (_Seq) and a device implementation
// for one operation.
// NOTE(review): VecSetOp_CUPM is a macro defined elsewhere; presumably it picks host_fn when
// usehost is true and device_fn otherwise, and refers to the local names `v` and `usehost`
// implicitly -- confirm before renaming either.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::bindtocpu(Vec v, PetscBool usehost) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));

  // REVIEW ME: this absolutely should be some sort of bulk memcpy rather than this mess
  VecSetOp_CUPM(dot, VecDot_Seq, dot);
  VecSetOp_CUPM(norm, VecNorm_Seq, norm);
  VecSetOp_CUPM(tdot, VecTDot_Seq, tdot);
  VecSetOp_CUPM(mdot, VecMDot_Seq, mdot);
  VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template resetarray<PETSC_MEMTYPE_HOST>);
  VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template placearray<PETSC_MEMTYPE_HOST>);
  // mtdot always uses the host implementation, regardless of usehost
  v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq;
  VecSetOp_CUPM(conjugate, VecConjugate_Seq, conjugate);
  VecSetOp_CUPM(max, VecMax_Seq, max);
  VecSetOp_CUPM(min, VecMin_Seq, min);
  VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, setpreallocationcoo);
  VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, setvaluescoo);
  PetscFunctionReturn(PETSC_SUCCESS);
}

306: // ==========================================================================================
307: // VecSeq_CUPM - Public API - Mutators
308: // ==========================================================================================

// v->ops->getlocalvector or v->ops->getlocalvectorread
//
// Makes w a local view of v. The template argument `access` selects read-only versus
// read-write access to v in the fallback path.
//
// 1. If w is a VECSEQCUPM, any storage it currently owns is released: its allocated host
//    array, its device array and its spptr struct.
// 2. If v is petscnative and w is a VECSEQCUPM, w takes over v's data, offloadmask,
//    pinned_memory and spptr pointers directly (no copy) and w's object state is increased.
// 3. Otherwise w's host array pointer is obtained from v through VecGetArrayRead() or
//    VecGetArray(), w is marked as CPU-resident, and (for a VECSEQCUPM w) the device
//    allocation check is run on w.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::getlocalvector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (wisseqcupm) {
    if (const auto wseq = VecIMPLCast(w)) {
      if (auto &alloced = wseq->array_allocated) {
        // NOTE(review): `useit` is never referenced again; presumably it is a scope guard that
        // selects the allocator used by the PetscFree() below -- confirm. Note that it also
        // resets w->pinned_memory to PETSC_FALSE via util::exchange().
        const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE));

        PetscCall(PetscFree(alloced));
      }
      wseq->array         = nullptr;
      wseq->unplacedarray = nullptr;
    }
    if (const auto wcu = VecCUPMCast(w)) {
      if (auto &device_array = wcu->array_d) {
        cupmStream_t stream;

        PetscCall(GetHandles_(&stream));
        PetscCallCUPM(cupmFreeAsync(device_array, stream));
      }
      PetscCall(PetscFree(w->spptr /* wcu */));
    }
  }
  if (v->petscnative && wisseqcupm) {
    // w aliases v from here on; restorelocalvector() hands the pointers back
    PetscCall(PetscFree(w->data));
    w->data          = v->data;
    w->offloadmask   = v->offloadmask;
    w->pinned_memory = v->pinned_memory;
    w->spptr         = v->spptr;
    PetscCall(PetscObjectStateIncrease(PetscObjectCast(w)));
  } else {
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecGetArray(v, array));
    }
    w->offloadmask = PETSC_OFFLOAD_CPU;
    if (wisseqcupm) {
      PetscDeviceContext dctx;

      PetscCall(GetHandles_(&dctx));
      PetscCall(DeviceAllocateCheck_(dctx, w));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->restorelocalvector or v->ops->restorelocalvectorread
//
// Counterpart of getlocalvector(). If v is petscnative and w is a VECSEQCUPM, the
// pinned_memory flag, offloadmask, data and spptr are moved from w back to v, and w is left
// with an unallocated offloadmask and null data/spptr. Otherwise the host array held by w is
// restored to v with VecRestoreArrayRead() or VecRestoreArray() (selected by `access`), and
// for a VECSEQCUPM w with a non-null spptr the device array and the spptr struct are freed.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::restorelocalvector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (v->petscnative && wisseqcupm) {
    // the assignments to nullptr are __critical__, as w may persist after this call returns
    // and shouldn't share data with v!
    v->pinned_memory = w->pinned_memory;
    v->offloadmask   = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED);
    v->data          = util::exchange(w->data, nullptr);
    v->spptr         = util::exchange(w->spptr, nullptr);
  } else {
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecRestoreArray(v, array));
    }
    if (w->spptr && wisseqcupm) {
      cupmStream_t stream;

      // release the device-side storage that getlocalvector() set up for w
      PetscCall(GetHandles_(&stream));
      PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream));
      PetscCall(PetscFree(w->spptr));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

402: // ==========================================================================================
403: // VecSeq_CUPM - Public API - Compute Methods
404: // ==========================================================================================

406: // v->ops->aypx
407: template <device::cupm::DeviceType T>
408: inline PetscErrorCode VecSeq_CUPM<T>::aypx(Vec yin, PetscScalar alpha, Vec xin) noexcept
409: {
410:   const auto         n    = static_cast<cupmBlasInt_t>(yin->map->n);
411:   const auto         sync = n != 0;
412:   PetscDeviceContext dctx;

414:   PetscFunctionBegin;
415:   PetscCall(GetHandles_(&dctx));
416:   if (alpha == PetscScalar(0.0)) {
417:     cupmStream_t stream;

419:     PetscCall(GetHandlesFrom_(dctx, &stream));
420:     PetscCall(PetscLogGpuTimeBegin());
421:     PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream));
422:     PetscCall(PetscLogGpuTimeEnd());
423:   } else if (n) {
424:     const auto       alphaIsOne = alpha == PetscScalar(1.0);
425:     const auto       calpha     = cupmScalarPtrCast(&alpha);
426:     cupmBlasHandle_t cupmBlasHandle;

428:     PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
429:     {
430:       const auto yptr = DeviceArrayReadWrite(dctx, yin);
431:       const auto xptr = DeviceArrayRead(dctx, xin);

433:       PetscCall(PetscLogGpuTimeBegin());
434:       if (alphaIsOne) {
435:         PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
436:       } else {
437:         const auto one = cupmScalarCast(1.0);

439:         PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1));
440:         PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
441:       }
442:       PetscCall(PetscLogGpuTimeEnd());
443:     }
444:     PetscCall(PetscLogGpuFlops((alphaIsOne ? 1 : 2) * n));
445:   }
446:   if (sync) PetscCall(PetscDeviceContextSynchronize(dctx));
447:   PetscFunctionReturn(PETSC_SUCCESS);
448: }

450: // v->ops->axpy
451: template <device::cupm::DeviceType T>
452: inline PetscErrorCode VecSeq_CUPM<T>::axpy(Vec yin, PetscScalar alpha, Vec xin) noexcept
453: {
454:   PetscBool xiscupm;

456:   PetscFunctionBegin;
457:   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
458:   if (xiscupm) {
459:     const auto         n = static_cast<cupmBlasInt_t>(yin->map->n);
460:     PetscDeviceContext dctx;
461:     cupmBlasHandle_t   cupmBlasHandle;

463:     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
464:     PetscCall(PetscLogGpuTimeBegin());
465:     PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
466:     PetscCall(PetscLogGpuTimeEnd());
467:     PetscCall(PetscLogGpuFlops(2 * n));
468:     PetscCall(PetscDeviceContextSynchronize(dctx));
469:   } else {
470:     PetscCall(VecAXPY_Seq(yin, alpha, xin));
471:   }
472:   PetscFunctionReturn(PETSC_SUCCESS);
473: }

475: // v->ops->pointwisedivide
476: template <device::cupm::DeviceType T>
477: inline PetscErrorCode VecSeq_CUPM<T>::pointwisedivide(Vec win, Vec xin, Vec yin) noexcept
478: {
479:   PetscFunctionBegin;
480:   if (xin->boundtocpu || yin->boundtocpu) {
481:     PetscCall(VecPointwiseDivide_Seq(win, xin, yin));
482:   } else {
483:     // note order of arguments! xin and yin are read, win is written!
484:     PetscCall(pointwisebinary_(thrust::divides<PetscScalar>{}, xin, yin, win));
485:   }
486:   PetscFunctionReturn(PETSC_SUCCESS);
487: }

489: // v->ops->pointwisemult
490: template <device::cupm::DeviceType T>
491: inline PetscErrorCode VecSeq_CUPM<T>::pointwisemult(Vec win, Vec xin, Vec yin) noexcept
492: {
493:   PetscFunctionBegin;
494:   if (xin->boundtocpu || yin->boundtocpu) {
495:     PetscCall(VecPointwiseMult_Seq(win, xin, yin));
496:   } else {
497:     // note order of arguments! xin and yin are read, win is written!
498:     PetscCall(pointwisebinary_(thrust::multiplies<PetscScalar>{}, xin, yin, win));
499:   }
500:   PetscFunctionReturn(PETSC_SUCCESS);
501: }

503: namespace detail
504: {

506: struct reciprocal {
507:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar s) const noexcept
508:   {
509:     // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex
510:     // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap
511:     // everything in PetscScalar...
512:     return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s;
513:   }
514: };

516: } // namespace detail

// v->ops->reciprocal
//
// Applies detail::reciprocal to every entry of xin through pointwiseunary_(). No output
// vector is passed, so the operation is performed in place.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::reciprocal(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(pointwiseunary_(detail::reciprocal{}, xin));
  PetscFunctionReturn(PETSC_SUCCESS);
}

527: // v->ops->waxpy
528: template <device::cupm::DeviceType T>
529: inline PetscErrorCode VecSeq_CUPM<T>::waxpy(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept
530: {
531:   PetscFunctionBegin;
532:   if (alpha == PetscScalar(0.0)) {
533:     PetscCall(copy(yin, win));
534:   } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) {
535:     PetscDeviceContext dctx;
536:     cupmBlasHandle_t   cupmBlasHandle;
537:     cupmStream_t       stream;

539:     PetscCall(GetHandles_(&dctx, &cupmBlasHandle, &stream));
540:     {
541:       const auto wptr = DeviceArrayWrite(dctx, win);

543:       PetscCall(PetscLogGpuTimeBegin());
544:       PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true));
545:       PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1));
546:       PetscCall(PetscLogGpuTimeEnd());
547:     }
548:     PetscCall(PetscLogGpuFlops(2 * n));
549:     PetscCall(PetscDeviceContextSynchronize(dctx));
550:   }
551:   PetscFunctionReturn(PETSC_SUCCESS);
552: }

namespace kernels
{

// Device kernel for maxpy: for every index i visited by grid_stride_1D() it computes
//
//   xptr[i] += sum_{j < N} aptr[j] * yptr_j[i]
//
// where N is the number of y-pointers passed in the parameter pack.
//
// size - number of entries handed to grid_stride_1D()
// xptr - the vector that is updated in place
// aptr - the N coefficients, first staged into shared memory
// yptr - N read-only vectors
//
// NOTE(review): the coefficients are loaded by the threads with threadIdx.x < N, so this
// assumes every block has at least N threads -- confirm against the launch configuration.
template <typename... Args>
PETSC_KERNEL_DECL static void maxpy_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr)
{
  constexpr int      N        = sizeof...(Args);
  const auto         tx       = threadIdx.x;
  const PetscScalar *yptr_p[] = {yptr...};

  PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N];

  // load a to shared memory
  if (tx < N) aptr_shmem[tx] = aptr[tx];
  __syncthreads();

  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    // these may look the same but give different results!
    // (the disabled variant sums the products first and adds the total to x, the active one
    // starts from x[i] and accumulates onto it, i.e. the additions happen in a different order)
  #if 0
    PetscScalar sum = 0.0;

    #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] += sum;
  #else
    auto sum = xptr[i];

    #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] = sum;
  #endif
  });
  return;
}

} // namespace kernels

namespace detail
{

// Maps any (type, index) pair to the type alone, discarding the index. Its purpose is to
// repeat a type once per element of an index parameter pack:
//
//   typename repeat_type<MyType, IdxParamPack>::type...
//
// expands to
//
//   MyType, MyType, MyType, ... [sizeof...(IdxParamPack) times]
template <typename Repeated, std::size_t /* ignored */>
struct repeat_type {
  using type = Repeated;
};

} // namespace detail

// Launches kernels::maxpy_kernel for sizeof...(Idx) y-vectors at once.
//
// The index pack is expanded twice: once to instantiate the kernel with one
// `const PetscScalar *` parameter per index, and once to obtain read access to the device
// array of yin[Idx] for each index. xptr and aptr are passed to the kernel as given; aptr
// must therefore already be readable by the kernel.
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::maxpy_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // clang-format off
  PetscCall(
    PetscCUPMLaunchKernel1D(
      size, 0, stream,
      kernels::maxpy_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Processes the next N vectors of a maxpy: offsets both the coefficient array and the vector
// array by yidx, launches the kernel for an index sequence of length N, and then advances
// yidx by N (yidx is an in-out parameter).
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::maxpy_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(maxpy_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}

633: // v->ops->maxpy
634: template <device::cupm::DeviceType T>
635: inline PetscErrorCode VecSeq_CUPM<T>::maxpy(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept
636: {
637:   const auto         n = xin->map->n;
638:   PetscDeviceContext dctx;
639:   cupmStream_t       stream;

641:   PetscFunctionBegin;
642:   PetscCall(GetHandles_(&dctx, &stream));
643:   {
644:     const auto   xptr    = DeviceArrayReadWrite(dctx, xin);
645:     PetscScalar *d_alpha = nullptr;
646:     PetscInt     yidx    = 0;

648:     // placement of early-return is deliberate, we would like to capture the
649:     // DeviceArrayReadWrite() call (which calls PetscObjectStateIncreate()) before we bail
650:     if (!n || !nv) PetscFunctionReturn(PETSC_SUCCESS);
651:     PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha));
652:     PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream));
653:     PetscCall(PetscLogGpuTimeBegin());
654:     do {
655:       switch (nv - yidx) {
656:       case 7:
657:         PetscCall(maxpy_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
658:         break;
659:       case 6:
660:         PetscCall(maxpy_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
661:         break;
662:       case 5:
663:         PetscCall(maxpy_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
664:         break;
665:       case 4:
666:         PetscCall(maxpy_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
667:         break;
668:       case 3:
669:         PetscCall(maxpy_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
670:         break;
671:       case 2:
672:         PetscCall(maxpy_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
673:         break;
674:       case 1:
675:         PetscCall(maxpy_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
676:         break;
677:       default: // 8 or more
678:         PetscCall(maxpy_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
679:         break;
680:       }
681:     } while (yidx < nv);
682:     PetscCall(PetscLogGpuTimeEnd());
683:     PetscCall(PetscDeviceFree(dctx, d_alpha));
684:   }
685:   PetscCall(PetscLogGpuFlops(nv * 2 * n));
686:   PetscCall(PetscDeviceContextSynchronize(dctx));
687:   PetscFunctionReturn(PETSC_SUCCESS);
688: }

690: template <device::cupm::DeviceType T>
691: inline PetscErrorCode VecSeq_CUPM<T>::dot(Vec xin, Vec yin, PetscScalar *z) noexcept
692: {
693:   PetscFunctionBegin;
694:   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
695:     PetscDeviceContext dctx;
696:     cupmBlasHandle_t   cupmBlasHandle;

698:     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
699:     // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the
700:     // second
701:     PetscCall(PetscLogGpuTimeBegin());
702:     PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z)));
703:     PetscCall(PetscLogGpuTimeEnd());
704:     PetscCall(PetscLogGpuFlops(2 * n - 1));
705:   } else {
706:     *z = 0.0;
707:   }
708:   PetscFunctionReturn(PETSC_SUCCESS);
709: }

  // Tuning constants for the mdot kernels below.
  // MDOT_WORKGROUP_NUM:  per the comment in sum_kernel, the number of partial results each
  //                      thread sums up (presumably the number of thread blocks launched by
  //                      mdot -- confirm at the launch site).
  // MDOT_WORKGROUP_SIZE: stride between the per-vector sections of the shared-memory buffer in
  //                      mdot_kernel; a block must therefore not have more threads than this.
  #define MDOT_WORKGROUP_NUM  128
  #define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM

714: namespace kernels
715: {

// Returns how many entries each thread block is responsible for: size divided by the number
// of blocks (gridDim.x), rounded up, and never less than 1.
// NOTE(review): the expression mixes PetscInt with gridDim.x; what happens for size <= 0
// depends on the integer conversions involved -- callers presumably only launch with
// size > 0, confirm.
PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept
{
  const auto group_entries = (size - 1) / gridDim.x + 1;
  // for very small vectors, a group should still do some work
  return group_entries ? group_entries : 1;
}

// Device kernel computing per-block partial sums for N simultaneous dot products of x with
// y_0 ... y_{N-1}, where N is the number of pointers in the parameter pack.
//
// x       - the common vector
// size    - number of entries in x and in each y_j
// results - output, written as results[blockIdx.x + j * gridDim.x] for the j-th dot product,
//           i.e. one partial sum per block per vector; the partial sums still have to be
//           combined afterwards
// y       - the N vectors multiplied against x
//
// Each block handles the contiguous slice [bx * worksize, min((bx + 1) * worksize, size)) of
// the input, with its threads striding over the slice by blockDim.x. Per-thread sums are put
// into shared memory and combined with a halving-stride reduction.
//
// NOTE(review): the products are accumulated as y_j[i] * x[i] without any conjugation here.
// NOTE(review): the halving reduction covers all threads only if blockDim.x is a power of two,
// and the shared buffer requires blockDim.x <= MDOT_WORKGROUP_SIZE -- confirm against the
// launch configuration.
template <typename... ConstPetscScalarPointer>
PETSC_KERNEL_DECL static void mdot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y)
{
  constexpr int      N        = sizeof...(ConstPetscScalarPointer);
  const PetscScalar *ylocal[] = {y...};
  PetscScalar        sumlocal[N];

  PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE];

  // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate
  // types, so each of these go on separate lines...
  const auto tx       = threadIdx.x;
  const auto bx       = blockIdx.x;
  const auto bdx      = blockDim.x;
  const auto gdx      = gridDim.x;
  const auto worksize = EntriesPerGroup(size);
  const auto begin    = tx + bx * worksize;
  const auto end      = min((bx + 1) * worksize, size);

  #pragma unroll
  for (auto i = 0; i < N; ++i) sumlocal[i] = 0;

  // per-thread accumulation over this block's slice
  for (auto i = begin; i < end; i += bdx) {
    const auto xi = x[i]; // load only once from global memory!

  #pragma unroll
    for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi;
  }

  // section i of shmem (stride MDOT_WORKGROUP_SIZE) holds the per-thread sums of vector i
  #pragma unroll
  for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i];

  // parallel reduction
  for (auto stride = bdx / 2; stride > 0; stride /= 2) {
    __syncthreads();
    if (tx < stride) {
  #pragma unroll
      for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE];
    }
  }
  // bottom N threads per block write to global memory
  // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread
  // writes to the same sections in the above loop that it is about to read from below, but
  // running this under the racecheck tool of cuda-memcheck reports a write-after-write hazard.
  __syncthreads();
  if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE];
  return;
}

773: namespace
774: {

776: PETSC_KERNEL_DECL void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results)
777: {
778:   int         local_i = 0;
779:   PetscScalar local_results[8];

781:   // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer
782:   //
783:   // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
784:   // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ...
785:   // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
786:   //  |  ______________________________________________________/
787:   //  | /            <- MDOT_WORKGROUP_NUM ->
788:   //  |/
789:   //  +
790:   //  v
791:   // *-*-*
792:   // | | | ...
793:   // *-*-*
794:   //
795:   ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
796:     PetscScalar z_sum = 0;

798:     for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j];
799:     local_results[local_i++] = z_sum;
800:   });
801:   // if we needed more than 1 workgroup to handle the vector we should sync since other threads
802:   // may currently be reading from results
803:   if (size >= MDOT_WORKGROUP_SIZE) __syncthreads();
804:   // Local buffer is now written to global memory
805:   ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
806:     const auto j = --local_i;

808:     if (j >= 0) results[i] = local_results[j];
809:   });
810:   return;
811: }

813: } // namespace

815: } // namespace kernels

// Launches kernels::mdot_kernel for the sizeof...(Idx) vectors yin[0] ... yin[sizeof...(Idx)-1].
//
// dctx    - device context used to obtain read access to each yin[Idx]
// stream  - stream the kernel is launched on
// xarr    - device array of the common vector
// yin     - first vector of the batch
// size    - number of entries per vector
// results - scratch area that receives the per-block partial sums
//
// The index_sequence argument carries no data; it only supplies the pack Idx that is expanded
// into one kernel pointer argument per vector.
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches
  // 128 blocks of 128 threads every time which may be wasteful
  // clang-format off
  PetscCallCUPM(
    cupmLaunchKernel(
      kernels::mdot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream,
      xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}

836: template <device::cupm::DeviceType T>
837: template <int N>
838: inline PetscErrorCode VecSeq_CUPM<T>::mdot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept
839: {
840:   PetscFunctionBegin;
841:   PetscCall(mdot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{}));
842:   yidx += N;
843:   PetscFunctionReturn(PETSC_SUCCESS);
844: }

// Multi-dot implementation selected by the std::false_type tag (mdot() passes
// std::integral_constant<bool, PetscDefined(USE_COMPLEX)>, so this is the non-complex path).
//
// xin  - the common vector
// nv   - number of vectors in yin
// yin  - the vectors to dot against xin
// z    - host array of length nv receiving the results
// dctx - device context whose stream runs the final reduction and the copy back to the host
//
// The vectors are processed in batches of up to 8 by mdot_kernel, each batch writing
// MDOT_WORKGROUP_NUM partial sums per vector into a device scratch buffer. A single left-over
// vector is instead handled by a BLAS dot that writes straight into z. Afterwards sum_kernel
// reduces the partial sums and the first nvbatch results are copied into z. Note that the copy
// is only enqueued here; the caller is expected to synchronize dctx before reading z.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::mdot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // the largest possible size of a batch
  constexpr PetscInt batchsize = 8;
  // how many sub streams to create, if nv <= batchsize we can do this without looping, so we
  // do not create substreams. Note we don't create more than 8 streams, in practice we could
  // not get more parallelism with higher numbers.
  const auto num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0;
  const auto n               = xin->map->n;
  // number of vectors that we handle via the batches. note any singletons are handled by
  // cublas, hence the nv-1.
  const auto   nvbatch = ((nv % batchsize) == 1) ? nv - 1 : nv;
  // total number of partial sums produced by the batched kernels
  const auto   nwork   = nvbatch * MDOT_WORKGROUP_NUM;
  PetscScalar *d_results;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  // allocate scratchpad memory for the results of individual work groups
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results));
  {
    const auto          xptr       = DeviceArrayRead(dctx, xin);
    PetscInt            yidx       = 0;
    auto                subidx     = 0;
    auto                cur_stream = stream;
    auto                cur_ctx    = dctx;
    PetscDeviceContext *sub        = nullptr;
    PetscStreamType     stype;

    // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of
    // sub. Ideally the parent context should also join in on the fork, but it is extremely
    // fiddly to do so presently
    PetscCall(PetscDeviceContextGetStreamType(dctx, &stype));
    if (stype == PETSC_STREAM_GLOBAL_BLOCKING) stype = PETSC_STREAM_DEFAULT_BLOCKING;
    // If we have a globally blocking stream create nonblocking streams instead (as we can
    // locally exploit the parallelism). Otherwise use the prescribed stream type.
    PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      // with sub streams, hand successive batches to the sub contexts in round-robin order;
      // without them everything stays on dctx and its stream
      if (num_sub_streams) {
        cur_ctx = sub[subidx++ % num_sub_streams];
        PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream));
      }
      // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9;
      // it is very likely better to do 4+5 rather than 8+1
      // dispatch on the number of vectors still to do; every dispatch advances yidx
      switch (nv - yidx) {
      case 7:
        PetscCall(mdot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 6:
        PetscCall(mdot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 5:
        PetscCall(mdot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 4:
        PetscCall(mdot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 3:
        PetscCall(mdot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 2:
        PetscCall(mdot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 1: {
        // a single remaining vector: plain BLAS dot, result goes directly to z[yidx]
        cupmBlasHandle_t cupmBlasHandle;

        PetscCall(GetHandlesFrom_(cur_ctx, &cupmBlasHandle));
        PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, static_cast<cupmBlasInt_t>(n), DeviceArrayRead(cur_ctx, yin[yidx]).cupmdata(), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(z + yidx)));
        ++yidx;
      } break;
      default: // 8 or more
        PetscCall(mdot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
  }

  // reduce the MDOT_WORKGROUP_NUM partial sums of every batched vector to a single value
  PetscCall(PetscCUPMLaunchKernel1D(nvbatch, 0, stream, kernels::sum_kernel, nvbatch, d_results));
  // copy result of device reduction to host
  PetscCall(PetscCUPMMemcpyAsync(z, d_results, nvbatch, cupmMemcpyDeviceToHost, stream));
  // do these now while final reduction is in flight
  PetscCall(PetscLogFlops(nwork));
  PetscCall(PetscDeviceFree(dctx, d_results));
  PetscFunctionReturn(PETSC_SUCCESS);
}

936:   #undef MDOT_WORKGROUP_NUM
937:   #undef MDOT_WORKGROUP_SIZE

939: template <device::cupm::DeviceType T>
940: inline PetscErrorCode VecSeq_CUPM<T>::mdot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
941: {
942:   // probably not worth it to run more than 8 of these at a time?
943:   const auto          n_sub = PetscMin(nv, 8);
944:   const auto          n     = static_cast<cupmBlasInt_t>(xin->map->n);
945:   const auto          xptr  = DeviceArrayRead(dctx, xin);
946:   PetscScalar        *d_z;
947:   PetscDeviceContext *subctx;
948:   cupmStream_t        stream;

950:   PetscFunctionBegin;
951:   PetscCall(GetHandlesFrom_(dctx, &stream));
952:   PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z));
953:   PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx));
954:   PetscCall(PetscLogGpuTimeBegin());
955:   for (PetscInt i = 0; i < nv; ++i) {
956:     const auto            sub = subctx[i % n_sub];
957:     cupmBlasHandle_t      handle;
958:     cupmBlasPointerMode_t old_mode;

960:     PetscCall(GetHandlesFrom_(sub, &handle));
961:     PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode));
962:     if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE));
963:     PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i)));
964:     if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode));
965:   }
966:   PetscCall(PetscLogGpuTimeEnd());
967:   PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx));
968:   PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream));
969:   PetscCall(PetscDeviceFree(dctx, d_z));
970:   // REVIEW ME: flops?????
971:   PetscFunctionReturn(PETSC_SUCCESS);
972: }

974: // v->ops->mdot
975: template <device::cupm::DeviceType T>
976: inline PetscErrorCode VecSeq_CUPM<T>::mdot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept
977: {
978:   PetscFunctionBegin;
979:   if (PetscUnlikely(nv == 1)) {
980:     // dot handles nv = 0 correctly
981:     PetscCall(dot(xin, const_cast<Vec>(yin[0]), z));
982:   } else if (const auto n = xin->map->n) {
983:     PetscDeviceContext dctx;

985:     PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv);
986:     PetscCall(GetHandles_(&dctx));
987:     PetscCall(mdot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx));
988:     // REVIEW ME: double count of flops??
989:     PetscCall(PetscLogGpuFlops(nv * (2 * n - 1)));
990:     PetscCall(PetscDeviceContextSynchronize(dctx));
991:   } else {
992:     PetscCall(PetscArrayzero(z, nv));
993:   }
994:   PetscFunctionReturn(PETSC_SUCCESS);
995: }

997: // v->ops->set
998: template <device::cupm::DeviceType T>
999: inline PetscErrorCode VecSeq_CUPM<T>::set(Vec xin, PetscScalar alpha) noexcept
1000: {
1001:   const auto         n = xin->map->n;
1002:   PetscDeviceContext dctx;
1003:   cupmStream_t       stream;

1005:   PetscFunctionBegin;
1006:   PetscCall(GetHandles_(&dctx, &stream));
1007:   {
1008:     const auto xptr = DeviceArrayWrite(dctx, xin);

1010:     if (alpha == PetscScalar(0.0)) {
1011:       PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream));
1012:     } else {
1013:       PetscCall(device::cupm::impl::ThrustSet<T>(stream, n, xptr.data(), &alpha));
1014:     }
1015:     if (n) PetscCall(PetscDeviceContextSynchronize(dctx)); // don't sync if we did nothing
1016:   }
1017:   PetscFunctionReturn(PETSC_SUCCESS);
1018: }

1020: // v->ops->scale
1021: template <device::cupm::DeviceType T>
1022: inline PetscErrorCode VecSeq_CUPM<T>::scale(Vec xin, PetscScalar alpha) noexcept
1023: {
1024:   PetscFunctionBegin;
1025:   if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(PETSC_SUCCESS);
1026:   if (PetscUnlikely(alpha == PetscScalar(0.0))) {
1027:     PetscCall(set(xin, alpha));
1028:   } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1029:     PetscDeviceContext dctx;
1030:     cupmBlasHandle_t   cupmBlasHandle;

1032:     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1033:     PetscCall(PetscLogGpuTimeBegin());
1034:     PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1));
1035:     PetscCall(PetscLogGpuTimeEnd());
1036:     PetscCall(PetscLogGpuFlops(n));
1037:     PetscCall(PetscDeviceContextSynchronize(dctx));
1038:   } else {
1039:     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1040:   }
1041:   PetscFunctionReturn(PETSC_SUCCESS);
1042: }

1044: // v->ops->tdot
1045: template <device::cupm::DeviceType T>
1046: inline PetscErrorCode VecSeq_CUPM<T>::tdot(Vec xin, Vec yin, PetscScalar *z) noexcept
1047: {
1048:   PetscFunctionBegin;
1049:   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1050:     PetscDeviceContext dctx;
1051:     cupmBlasHandle_t   cupmBlasHandle;

1053:     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1054:     PetscCall(PetscLogGpuTimeBegin());
1055:     PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z)));
1056:     PetscCall(PetscLogGpuTimeEnd());
1057:     PetscCall(PetscLogGpuFlops(2 * n - 1));
1058:   } else {
1059:     *z = 0.0;
1060:   }
1061:   PetscFunctionReturn(PETSC_SUCCESS);
1062: }

1064: // v->ops->copy
// Copies the entries of xin into yout.
//
// Copying a vector onto itself is a no-op, and an empty xin only results in a call to
// MaybeIncrementEmptyLocalVec(yout). Otherwise the offload masks of the two vectors are used
// to pick the direction of the memcpy (device/host source, device/host destination), the copy
// is enqueued on the stream of the current device context, and the context is synchronized
// before returning.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::copy(Vec xin, Vec yout) noexcept
{
  PetscFunctionBegin;
  if (xin == yout) PetscFunctionReturn(PETSC_SUCCESS);
  if (const auto n = xin->map->n) {
    const auto xmask = xin->offloadmask;
    // silence buggy gcc warning: mode may be used uninitialized in this function
    auto               mode = cupmMemcpyDeviceToDevice;
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    // translate from PetscOffloadMask to cupmMemcpyKind
    switch (const auto ymask = yout->offloadmask) {
    case PETSC_OFFLOAD_UNALLOCATED: {
      PetscBool yiscupm;

      PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
      if (yiscupm) {
        // an unallocated cupm vector can go wherever the source currently is
        mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost;
        break;
      }
    } // fall-through if unallocated and not cupm
  #if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
  #endif
    case PETSC_OFFLOAD_CPU:
      // destination is on the host
      mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost;
      break;
    case PETSC_OFFLOAD_BOTH:
    case PETSC_OFFLOAD_GPU:
      // destination is on the device
      mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
      break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask));
    }

    PetscCall(GetHandles_(&dctx, &stream));
    switch (mode) {
    case cupmMemcpyDeviceToDevice: // the best case
    case cupmMemcpyHostToDevice: { // not terrible
      // destination on the device: write-only device access to yout, source read from
      // wherever mode says it is
      const auto yptr = DeviceArrayWrite(dctx, yout);
      const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream));
      PetscCall(PetscLogGpuTimeEnd());
    } break;
    case cupmMemcpyDeviceToHost: // not great
    case cupmMemcpyHostToHost: { // worst case
      // destination on the host: the array of yout is obtained through VecGetArrayWrite();
      // GPU time is only logged when the source is on the device
      const auto   xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
      PetscScalar *yptr;

      PetscCall(VecGetArrayWrite(yout, &yptr));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd());
      PetscCall(VecRestoreArrayWrite(yout, &yptr));
    } break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode));
    }
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(yout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

1134: // v->ops->swap
1135: template <device::cupm::DeviceType T>
1136: inline PetscErrorCode VecSeq_CUPM<T>::swap(Vec xin, Vec yin) noexcept
1137: {
1138:   PetscFunctionBegin;
1139:   if (xin == yin) PetscFunctionReturn(PETSC_SUCCESS);
1140:   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1141:     PetscDeviceContext dctx;
1142:     cupmBlasHandle_t   cupmBlasHandle;

1144:     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1145:     PetscCall(PetscLogGpuTimeBegin());
1146:     PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
1147:     PetscCall(PetscLogGpuTimeEnd());
1148:     PetscCall(PetscDeviceContextSynchronize(dctx));
1149:   } else {
1150:     PetscCall(MaybeIncrementEmptyLocalVec(xin));
1151:     PetscCall(MaybeIncrementEmptyLocalVec(yin));
1152:   }
1153:   PetscFunctionReturn(PETSC_SUCCESS);
1154: }

1156: // v->ops->axpby
1157: template <device::cupm::DeviceType T>
1158: inline PetscErrorCode VecSeq_CUPM<T>::axpby(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept
1159: {
1160:   PetscFunctionBegin;
1161:   if (alpha == PetscScalar(0.0)) {
1162:     PetscCall(scale(yin, beta));
1163:   } else if (beta == PetscScalar(1.0)) {
1164:     PetscCall(axpy(yin, alpha, xin));
1165:   } else if (alpha == PetscScalar(1.0)) {
1166:     PetscCall(aypx(yin, beta, xin));
1167:   } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) {
1168:     const auto         betaIsZero = beta == PetscScalar(0.0);
1169:     const auto         aptr       = cupmScalarPtrCast(&alpha);
1170:     PetscDeviceContext dctx;
1171:     cupmBlasHandle_t   cupmBlasHandle;

1173:     PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1174:     {
1175:       const auto xptr = DeviceArrayRead(dctx, xin);

1177:       if (betaIsZero /* beta = 0 */) {
1178:         // here we can get away with purely write-only as we memcpy into it first
1179:         const auto   yptr = DeviceArrayWrite(dctx, yin);
1180:         cupmStream_t stream;

1182:         PetscCall(GetHandlesFrom_(dctx, &stream));
1183:         PetscCall(PetscLogGpuTimeBegin());
1184:         PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream));
1185:         PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1));
1186:       } else {
1187:         const auto yptr = DeviceArrayReadWrite(dctx, yin);

1189:         PetscCall(PetscLogGpuTimeBegin());
1190:         PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1));
1191:         PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
1192:       }
1193:     }
1194:     PetscCall(PetscLogGpuTimeEnd());
1195:     PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n));
1196:     PetscCall(PetscDeviceContextSynchronize(dctx));
1197:   } else {
1198:     PetscCall(MaybeIncrementEmptyLocalVec(yin));
1199:   }
1200:   PetscFunctionReturn(PETSC_SUCCESS);
1201: }

1203: // v->ops->axpbypcz
1204: template <device::cupm::DeviceType T>
1205: inline PetscErrorCode VecSeq_CUPM<T>::axpbypcz(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept
1206: {
1207:   PetscFunctionBegin;
1208:   if (gamma != PetscScalar(1.0)) PetscCall(scale(zin, gamma));
1209:   PetscCall(axpy(zin, alpha, xin));
1210:   PetscCall(axpy(zin, beta, yin));
1211:   PetscFunctionReturn(PETSC_SUCCESS);
1212: }

1214: // v->ops->norm
1215: template <device::cupm::DeviceType T>
1216: inline PetscErrorCode VecSeq_CUPM<T>::norm(Vec xin, NormType type, PetscReal *z) noexcept
1217: {
1218:   PetscDeviceContext dctx;
1219:   cupmBlasHandle_t   cupmBlasHandle;

1221:   PetscFunctionBegin;
1222:   PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
1223:   if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
1224:     const auto xptr      = DeviceArrayRead(dctx, xin);
1225:     PetscInt   flopCount = 0;

1227:     PetscCall(PetscLogGpuTimeBegin());
1228:     switch (type) {
1229:     case NORM_1_AND_2:
1230:     case NORM_1:
1231:       PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
1232:       flopCount = std::max(n - 1, 0);
1233:       if (type == NORM_1) break;
1234:       ++z; // fall-through
1235:   #if PETSC_CPP_VERSION >= 17
1236:       [[fallthrough]];
1237:   #endif
1238:     case NORM_2:
1239:     case NORM_FROBENIUS:
1240:       PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
1241:       flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2
1242:       break;
1243:     case NORM_INFINITY: {
1244:       cupmBlasInt_t max_loc = 0;
1245:       PetscScalar   xv      = 0.;
1246:       cupmStream_t  stream;

1248:       PetscCall(GetHandlesFrom_(dctx, &stream));
1249:       PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc));
1250:       PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream));
1251:       *z = PetscAbsScalar(xv);
1252:       // REVIEW ME: flopCount = ???
1253:     } break;
1254:     }
1255:     PetscCall(PetscLogGpuTimeEnd());
1256:     PetscCall(PetscLogGpuFlops(flopCount));
1257:   } else {
1258:     z[0]                    = 0.0;
1259:     z[type == NORM_1_AND_2] = 0.0;
1260:   }
1261:   PetscFunctionReturn(PETSC_SUCCESS);
1262: }

1264: namespace detail
1265: {

1267: struct dotnorm2_mult {
1268:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept
1269:   {
1270:     const auto conjt = PetscConj(t);

1272:     return {s * conjt, t * conjt};
1273:   }
1274: };

1276: // it is positively __bananas__ that thrust does not define default operator+ for tuples... I
1277: // would do it myself but now I am worried that they do so on purpose...
1278: struct dotnorm2_tuple_plus {
1279:   using value_type = thrust::tuple<PetscScalar, PetscScalar>;

1281:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {lhs.get<0>() + rhs.get<0>(), lhs.get<1>() + rhs.get<1>()}; }
1282: };

1284: } // namespace detail

1286: // v->ops->dotnorm2
1287: template <device::cupm::DeviceType T>
1288: inline PetscErrorCode VecSeq_CUPM<T>::dotnorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept
1289: {
1290:   PetscDeviceContext dctx;
1291:   cupmStream_t       stream;

1293:   PetscFunctionBegin;
1294:   PetscCall(GetHandles_(&dctx, &stream));
1295:   {
1296:     PetscScalar dpt = 0.0, nmt = 0.0;
1297:     const auto  sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data());

1299:     // clang-format off
1300:     PetscCallThrust(
1301:       thrust::tie(*dp, *nm) = THRUST_CALL(
1302:         thrust::inner_product,
1303:         stream,
1304:         sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()),
1305:         thrust::make_tuple(dpt, nmt),
1306:         detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{}
1307:       );
1308:     );
1309:     // clang-format on
1310:   }
1311:   PetscFunctionReturn(PETSC_SUCCESS);
1312: }

1314: namespace detail
1315: {

1317: struct conjugate {
1318:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(PetscScalar x) const noexcept { return PetscConj(x); }
1319: };

1321: } // namespace detail

1323: // v->ops->conjugate
1324: template <device::cupm::DeviceType T>
1325: inline PetscErrorCode VecSeq_CUPM<T>::conjugate(Vec xin) noexcept
1326: {
1327:   PetscFunctionBegin;
1328:   if (PetscDefined(USE_COMPLEX)) PetscCall(pointwiseunary_(detail::conjugate{}, xin));
1329:   PetscFunctionReturn(PETSC_SUCCESS);
1330: }

1332: namespace detail
1333: {

1335: struct real_part {
1336:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const { return {PetscRealPart(x.get<0>()), x.get<1>()}; }

1338:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(PetscScalar x) const { return PetscRealPart(x); }
1339: };

1341: // deriving from Operator allows us to "store" an instance of the operator in the class but
1342: // also take advantage of empty base class optimization if the operator is stateless
1343: template <typename Operator>
1344: class tuple_compare : Operator {
1345: public:
1346:   using tuple_type    = thrust::tuple<PetscReal, PetscInt>;
1347:   using operator_type = Operator;

1349:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept
1350:   {
1351:     if (op_()(y.get<0>(), x.get<0>())) {
1352:       // if y is strictly greater/less than x, return y
1353:       return y;
1354:     } else if (y.get<0>() == x.get<0>()) {
1355:       // if equal, prefer lower index
1356:       return y.get<1>() < x.get<1>() ? y : x;
1357:     }
1358:     // otherwise return x
1359:     return x;
1360:   }

1362: private:
1363:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; }
1364: };

1366: } // namespace detail

// Common core of min() and max().
//
// tuple_ftr - binary functor combining two (value, index) pairs; used when p is non-null
// unary_ftr - binary functor combining two values; used when p is null
// v         - the vector to reduce
// p         - optional; receives the index belonging to the result. Set to -1 up front, and
//             left at that value if v is empty
// m         - in: the starting value of the reduction (set by the caller); out: the result
//
// On complex builds the entries are first mapped through detail::real_part (via
// thrust::transform_reduce); on real builds a plain thrust::reduce is used and the real_part
// argument is dropped by the macro below.
template <device::cupm::DeviceType T>
template <typename TupleFuncT, typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::minmax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept
{
  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  if (p) *p = -1;
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
      // needed to:
      // 1. switch between transform_reduce and reduce
      // 2. strip the real_part functor from the arguments
  #if PetscDefined(USE_COMPLEX)
    #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__)
  #else
    #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__)
  #endif
    {
      const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());

      if (p) {
        // the caller wants the location too: reduce over (value, index) pairs, with the index
        // supplied by a counting iterator starting at zero
        // clang-format off
        const auto zip = thrust::make_zip_iterator(
          thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0}))
        );
        // clang-format on
        // need to use preprocessor conditionals since otherwise thrust complains about not being
        // able to convert a thrust::device_reference<PetscScalar> to a PetscReal on complex
        // builds...
        // clang-format off
        PetscCallThrust(
          thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE(
            stream, zip, zip + n, detail::real_part{},
            thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr)
          );
        );
        // clang-format on
      } else {
        // value only: reduce over the entries directly
        // clang-format off
        PetscCallThrust(
          *m = THRUST_MINMAX_REDUCE(
            stream, vptr, vptr + n, detail::real_part{},
            *m, std::forward<UnaryFuncT>(unary_ftr)
          );
        );
        // clang-format on
      }
    }
  #undef THRUST_MINMAX_REDUCE
  }
  // REVIEW ME: flops?
  PetscFunctionReturn(PETSC_SUCCESS);
}

1425: // v->ops->max
1426: template <device::cupm::DeviceType T>
1427: inline PetscErrorCode VecSeq_CUPM<T>::max(Vec v, PetscInt *p, PetscReal *m) noexcept
1428: {
1429:   using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>;
1430:   using unary_functor = thrust::maximum<PetscReal>;

1432:   PetscFunctionBegin;
1433:   *m = PETSC_MIN_REAL;
1434:   // use {} constructor syntax otherwise most vexing parse
1435:   PetscCall(minmax_(tuple_functor{}, unary_functor{}, v, p, m));
1436:   PetscFunctionReturn(PETSC_SUCCESS);
1437: }

1439: // v->ops->min
1440: template <device::cupm::DeviceType T>
1441: inline PetscErrorCode VecSeq_CUPM<T>::min(Vec v, PetscInt *p, PetscReal *m) noexcept
1442: {
1443:   using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>;
1444:   using unary_functor = thrust::minimum<PetscReal>;

1446:   PetscFunctionBegin;
1447:   *m = PETSC_MAX_REAL;
1448:   // use {} constructor syntax otherwise most vexing parse
1449:   PetscCall(minmax_(tuple_functor{}, unary_functor{}, v, p, m));
1450:   PetscFunctionReturn(PETSC_SUCCESS);
1451: }

1453: // v->ops->sum
1454: template <device::cupm::DeviceType T>
1455: inline PetscErrorCode VecSeq_CUPM<T>::sum(Vec v, PetscScalar *sum) noexcept
1456: {
1457:   PetscFunctionBegin;
1458:   if (const auto n = v->map->n) {
1459:     PetscDeviceContext dctx;
1460:     cupmStream_t       stream;

1462:     PetscCall(GetHandles_(&dctx, &stream));
1463:     const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
1464:     // REVIEW ME: why not cupmBlasXasum()?
1465:     PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0}););
1466:     // REVIEW ME: must be at least n additions
1467:     PetscCall(PetscLogGpuFlops(n));
1468:   } else {
1469:     *sum = 0.0;
1470:   }
1471:   PetscFunctionReturn(PETSC_SUCCESS);
1472: }

1474: namespace detail
1475: {

1477: template <typename T>
1478: class plus_equals {
1479: public:
1480:   using value_type = T;

1482:   PETSC_HOSTDEVICE_DECL constexpr explicit plus_equals(value_type v = value_type{}) noexcept : v_(std::move(v)) { }

1484:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL constexpr T operator()(const T &val) const noexcept { return val + v_; }

1486: private:
1487:   value_type v_;
1488: };

1490: } // namespace detail

1492: template <device::cupm::DeviceType T>
1493: inline PetscErrorCode VecSeq_CUPM<T>::shift(Vec v, PetscScalar shift) noexcept
1494: {
1495:   PetscFunctionBegin;
1496:   PetscCall(pointwiseunary_(detail::plus_equals<PetscScalar>{shift}, v));
1497:   PetscFunctionReturn(PETSC_SUCCESS);
1498: }

1500: template <device::cupm::DeviceType T>
1501: inline PetscErrorCode VecSeq_CUPM<T>::setrandom(Vec v, PetscRandom rand) noexcept
1502: {
1503:   PetscFunctionBegin;
1504:   if (const auto n = v->map->n) {
1505:     PetscBool          iscurand;
1506:     PetscDeviceContext dctx;

1508:     PetscCall(GetHandles_(&dctx));
1509:     PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand));
1510:     if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v)));
1511:     else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v)));
1512:   } else {
1513:     PetscCall(MaybeIncrementEmptyLocalVec(v));
1514:   }
1515:   // REVIEW ME: flops????
1516:   // REVIEW ME: Timing???
1517:   PetscFunctionReturn(PETSC_SUCCESS);
1518: }

1520: // v->ops->setpreallocation
1521: template <device::cupm::DeviceType T>
1522: inline PetscErrorCode VecSeq_CUPM<T>::setpreallocationcoo(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept
1523: {
1524:   PetscDeviceContext dctx;

1526:   PetscFunctionBegin;
1527:   PetscCall(GetHandles_(&dctx));
1528:   PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i));
1529:   PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx));
1530:   PetscFunctionReturn(PETSC_SUCCESS);
1531: }

1533: namespace kernels
1534: {

1536: template <typename F>
1537: PETSC_DEVICE_INLINE_DECL void add_coo_values_impl(const PetscScalar *PETSC_RESTRICT vv, PetscCount n, const PetscCount *PETSC_RESTRICT jmap, const PetscCount *PETSC_RESTRICT perm, InsertMode imode, PetscScalar *PETSC_RESTRICT xv, F &&xvindex)
1538: {
1539:   ::Petsc::device::cupm::kernels::util::grid_stride_1D(n, [=](PetscCount i) {
1540:     const auto  end = jmap[i + 1];
1541:     const auto  idx = xvindex(i);
1542:     PetscScalar sum = 0.0;

1544:     for (auto k = jmap[i]; k < end; ++k) sum += vv[perm[k]];

1546:     if (imode == INSERT_VALUES) {
1547:       xv[idx] = sum;
1548:     } else {
1549:       xv[idx] += sum;
1550:     }
1551:   });
1552:   return;
1553: }

1555: namespace
1556: {

1558: PETSC_KERNEL_DECL void add_coo_values(const PetscScalar *PETSC_RESTRICT v, PetscCount n, const PetscCount *PETSC_RESTRICT jmap1, const PetscCount *PETSC_RESTRICT perm1, InsertMode imode, PetscScalar *PETSC_RESTRICT xv)
1559: {
1560:   add_coo_values_impl(v, n, jmap1, perm1, imode, xv, [](PetscCount i) { return i; });
1561:   return;
1562: }

1564: } // namespace

1566:   #if PetscDefined(USING_HCC)
1567: namespace do_not_use
1568: {

1570: // Needed to silence clang warning:
1571: //
1572: // warning: function 'FUNCTION NAME' is not needed and will not be emitted
1573: //
1574: // The warning is silly, since the function *is* used, however the host compiler does not
1575: // appear see this. Likely because the function using it is in a template.
1576: //
1577: // This warning appeared in clang-11, and still persists until clang-15 (21/02/2023)
1578: inline void silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted()
1579: {
1580:   (void)sum_kernel;
1581: }

1583: inline void silence_warning_function_add_coo_values_is_not_needed_and_will_not_be_emitted()
1584: {
1585:   (void)add_coo_values;
1586: }

1588: } // namespace do_not_use
1589:   #endif

1591: } // namespace kernels

1593: // v->ops->setvaluescoo
1594: template <device::cupm::DeviceType T>
1595: inline PetscErrorCode VecSeq_CUPM<T>::setvaluescoo(Vec x, const PetscScalar v[], InsertMode imode) noexcept
1596: {
1597:   auto               vv = const_cast<PetscScalar *>(v);
1598:   PetscMemType       memtype;
1599:   PetscDeviceContext dctx;
1600:   cupmStream_t       stream;

1602:   PetscFunctionBegin;
1603:   PetscCall(GetHandles_(&dctx, &stream));
1604:   PetscCall(PetscGetMemType(v, &memtype));
1605:   if (PetscMemTypeHost(memtype)) {
1606:     const auto size = VecIMPLCast(x)->coo_n;

1608:     // If user gave v[] in host, we might need to copy it to device if any
1609:     PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
1610:     PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
1611:   }

1613:   if (const auto n = x->map->n) {
1614:     const auto vcu = VecCUPMCast(x);

1616:     PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data()));
1617:   } else {
1618:     PetscCall(MaybeIncrementEmptyLocalVec(x));
1619:   }

1621:   if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv));
1622:   PetscCall(PetscDeviceContextSynchronize(dctx));
1623:   PetscFunctionReturn(PETSC_SUCCESS);
1624: }

1626: // ==========================================================================================
1627: // VecSeq_CUPM - Implementations
1628: // ==========================================================================================

1630: namespace
1631: {

1633: template <typename T>
1634: inline PetscErrorCode VecCreateSeqCUPMAsync(T &&VecSeq_CUPM_Impls, MPI_Comm comm, PetscInt n, Vec *v) noexcept
1635: {
1636:   PetscFunctionBegin;
1638:   PetscCall(VecSeq_CUPM_Impls.createseqcupm(comm, 0, n, v, PETSC_TRUE));
1639:   PetscFunctionReturn(PETSC_SUCCESS);
1640: }

1642: template <typename T>
1643: inline PetscErrorCode VecCreateSeqCUPMWithArraysAsync(T &&VecSeq_CUPM_Impls, MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar cpuarray[], const PetscScalar gpuarray[], Vec *v) noexcept
1644: {
1645:   PetscFunctionBegin;
1648:   PetscCall(VecSeq_CUPM_Impls.createseqcupmwithbotharrays(comm, bs, n, cpuarray, gpuarray, v));
1649:   PetscFunctionReturn(PETSC_SUCCESS);
1650: }

1652: template <PetscMemoryAccessMode mode, typename T>
1653: inline PetscErrorCode VecCUPMGetArrayAsync_Private(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1654: {
1655:   PetscFunctionBegin;
1658:   PetscCall(VecSeq_CUPM_Impls.template getarray<PETSC_MEMTYPE_DEVICE, mode>(v, a));
1659:   PetscFunctionReturn(PETSC_SUCCESS);
1660: }

1662: template <PetscMemoryAccessMode mode, typename T>
1663: inline PetscErrorCode VecCUPMRestoreArrayAsync_Private(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1664: {
1665:   PetscFunctionBegin;
1667:   PetscCall(VecSeq_CUPM_Impls.template restorearray<PETSC_MEMTYPE_DEVICE, mode>(v, a));
1668:   PetscFunctionReturn(PETSC_SUCCESS);
1669: }

1671: template <typename T>
1672: inline PetscErrorCode VecCUPMGetArrayAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1673: {
1674:   PetscFunctionBegin;
1675:   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
1676:   PetscFunctionReturn(PETSC_SUCCESS);
1677: }

1679: template <typename T>
1680: inline PetscErrorCode VecCUPMRestoreArrayAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1681: {
1682:   PetscFunctionBegin;
1683:   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
1684:   PetscFunctionReturn(PETSC_SUCCESS);
1685: }

1687: template <typename T>
1688: inline PetscErrorCode VecCUPMGetArrayReadAsync(T &&VecSeq_CUPM_Impls, Vec v, const PetscScalar **a) noexcept
1689: {
1690:   PetscFunctionBegin;
1691:   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_READ>(std::forward<T>(VecSeq_CUPM_Impls), v, const_cast<PetscScalar **>(a)));
1692:   PetscFunctionReturn(PETSC_SUCCESS);
1693: }

1695: template <typename T>
1696: inline PetscErrorCode VecCUPMRestoreArrayReadAsync(T &&VecSeq_CUPM_Impls, Vec v, const PetscScalar **a) noexcept
1697: {
1698:   PetscFunctionBegin;
1699:   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_READ>(std::forward<T>(VecSeq_CUPM_Impls), v, const_cast<PetscScalar **>(a)));
1700:   PetscFunctionReturn(PETSC_SUCCESS);
1701: }

1703: template <typename T>
1704: inline PetscErrorCode VecCUPMGetArrayWriteAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1705: {
1706:   PetscFunctionBegin;
1707:   PetscCall(VecCUPMGetArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
1708:   PetscFunctionReturn(PETSC_SUCCESS);
1709: }

1711: template <typename T>
1712: inline PetscErrorCode VecCUPMRestoreArrayWriteAsync(T &&VecSeq_CUPM_Impls, Vec v, PetscScalar **a) noexcept
1713: {
1714:   PetscFunctionBegin;
1715:   PetscCall(VecCUPMRestoreArrayAsync_Private<PETSC_MEMORY_ACCESS_WRITE>(std::forward<T>(VecSeq_CUPM_Impls), v, a));
1716:   PetscFunctionReturn(PETSC_SUCCESS);
1717: }

1719: template <typename T>
1720: inline PetscErrorCode VecCUPMPlaceArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin, const PetscScalar a[]) noexcept
1721: {
1722:   PetscFunctionBegin;
1724:   PetscCall(VecSeq_CUPM_Impls.template placearray<PETSC_MEMTYPE_DEVICE>(vin, a));
1725:   PetscFunctionReturn(PETSC_SUCCESS);
1726: }

1728: template <typename T>
1729: inline PetscErrorCode VecCUPMReplaceArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin, const PetscScalar a[]) noexcept
1730: {
1731:   PetscFunctionBegin;
1733:   PetscCall(VecSeq_CUPM_Impls.template replacearray<PETSC_MEMTYPE_DEVICE>(vin, a));
1734:   PetscFunctionReturn(PETSC_SUCCESS);
1735: }

1737: template <typename T>
1738: inline PetscErrorCode VecCUPMResetArrayAsync(T &&VecSeq_CUPM_Impls, Vec vin) noexcept
1739: {
1740:   PetscFunctionBegin;
1742:   PetscCall(VecSeq_CUPM_Impls.template resetarray<PETSC_MEMTYPE_DEVICE>(vin));
1743:   PetscFunctionReturn(PETSC_SUCCESS);
1744: }

1746: } // anonymous namespace

1748: } // namespace impl

1750: } // namespace cupm

1752: } // namespace vec

1754: } // namespace Petsc

1756: #endif // __cplusplus

1758: #endif // PETSCVECSEQCUPM_HPP