Actual source code: cupmthrustutility.hpp

  1: #ifndef PETSC_CUPM_THRUST_UTILITY_HPP
  2: #define PETSC_CUPM_THRUST_UTILITY_HPP

#include <petsc/private/deviceimpl.h>
#include <petsc/private/cupminterface.hpp>

#if defined(__cplusplus)
  #include <exception>
  #include <type_traits>
  #include <utility>

  #include <thrust/device_ptr.h>
  #include <thrust/fill.h>
  #include <thrust/system_error.h>
  #include <thrust/transform.h>

 11: namespace Petsc
 12: {

 14: namespace device
 15: {

 17: namespace cupm
 18: {

 20: namespace impl
 21: {

  // thrust_call_par_on(func, s, ...) calls the thrust algorithm `func` with the remaining
  // arguments, prepending an execution policy bound to stream `s` when the compiler provides one:
  //   - NVCC, non-debug build, THRUST_VERSION >= 101600: thrust::cuda::par_nosync.on(s)
  //   - NVCC otherwise (debug build or older thrust):    thrust::cuda::par.on(s)
  //   - HIP compiler (USING_HCC):                        thrust::hip::par.on(s)
  //   - anything else: func(__VA_ARGS__) with no policy, i.e. `s` is ignored
  #if PetscDefined(USING_NVCC)
    // THRUST_VERSION must have been provided by the thrust headers included above
    #if !defined(THRUST_VERSION)
      #error "THRUST_VERSION not defined!"
    #endif
    // par_nosync (thrust >= 1.16) allows thrust to skip synchronizing the stream where the
    // algorithm permits it. NOTE(review): debug builds presumably keep the synchronizing `par`
    // so that asynchronous failures surface at the offending call -- confirm intent.
    #if !PetscDefined(USE_DEBUG) && (THRUST_VERSION >= 101600)
      #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par_nosync.on(s), __VA_ARGS__)
    #else
      #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par.on(s), __VA_ARGS__)
    #endif
  #elif PetscDefined(USING_HCC) // rocThrust has no par_nosync
    #define thrust_call_par_on(func, s, ...) func(thrust::hip::par.on(s), __VA_ARGS__)
  #else
    #define thrust_call_par_on(func, s, ...) func(__VA_ARGS__)
  #endif

 38: namespace detail
 39: {

 41: struct PetscLogGpuTimer {
 42:   PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeBegin()); }
 43:   ~PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeEnd()); }
 44: };

 46: struct private_tag { };

 48: } // namespace detail

 50:   #define THRUST_CALL(...) \
 51:     [&] { \
 52:       const auto timer = ::Petsc::device::cupm::impl::detail::PetscLogGpuTimer{}; \
 53:       return thrust_call_par_on(__VA_ARGS__); \
 54:     }()

 56:   #define PetscCallThrust(...) \
 57:     do { \
 58:       try { \
 59:         __VA_ARGS__; \
 60:       } catch (const thrust::system_error &ex) { \
 61:         SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", ex.what()); \
 62:       } \
 63:     } while (0)


 67: // actual implementation that calls thrust, 2 argument version
 68: template <DeviceType DT, typename FunctorType, typename T>
 69: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xinout, T *yin = nullptr))
 70: {
 71:   const auto xptr   = thrust::device_pointer_cast(xinout);
 72:   const auto retptr = (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr;

 74:   PetscFunctionBegin;
 76:   PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, retptr, std::forward<FunctorType>(functor)));
 77:   PetscFunctionReturn(PETSC_SUCCESS);
 78: }

 80: // actual implementation that calls thrust, 3 argument version
 81: template <DeviceType DT, typename FunctorType, typename T>
 82: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xin, T *yin, T *zin))
 83: {
 84:   const auto xptr = thrust::device_pointer_cast(xin);

 86:   PetscFunctionBegin;
 90:   PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, thrust::device_pointer_cast(yin), thrust::device_pointer_cast(zin), std::forward<FunctorType>(functor)));
 91:   PetscFunctionReturn(PETSC_SUCCESS);
 92: }

 94: // one last intermediate function to check n, and log flops for everything
 95: template <DeviceType DT, typename F, typename... T>
 96: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(typename Interface<DT>::cupmStream_t stream, F &&functor, PetscInt n, T &&...rest))
 97: {
 98:   PetscFunctionBegin;
 99:   PetscAssert(n >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "n %" PetscInt_FMT " must be >= 0", n);
100:   if (PetscLikely(n)) {
101:     PetscCall(ThrustApplyPointwise<DT>(detail::private_tag{}, stream, std::forward<F>(functor), n, std::forward<T>(rest)...));
102:     PetscCall(PetscLogGpuFlops(n));
103:   }
104:   PetscFunctionReturn(PETSC_SUCCESS);
105: }

107: // serves as setup to the real implementation above
108: template <DeviceType T, typename F, typename... Args>
109: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(PetscDeviceContext dctx, F &&functor, PetscInt n, Args &&...rest))
110: {
111:   typename Interface<T>::cupmStream_t stream;

113:   PetscFunctionBegin;
114:   static_assert(sizeof...(Args) <= 3, "");
116:   PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream));
117:   PetscCall(ThrustApplyPointwise<T>(stream, std::forward<F>(functor), n, std::forward<Args>(rest)...));
118:   PetscFunctionReturn(PETSC_SUCCESS);
119: }

  // PetscCallCUPM_(expr): lets a free function template call PetscCallCUPM() outside of an
  // Interface<DT> member. Inside a fresh do/while scope it defines local stand-ins, each forwarding
  // to Interface<DT>, for names that PetscCallCUPM's expansion presumably refers to
  // (cupmError_t, cupmName, cupmGetErrorName, cupmGetErrorString, cupmSuccess). Keep these names
  // exactly as they are.
  // A template parameter named `DT` must be in scope wherever this macro is expanded.
  // NOTE(review): some Windows headers (objbase.h) #define `interface`; if they are ever included
  // before this file, the alias name below would need to change -- confirm.
  #define PetscCallCUPM_(...) \
    do { \
      using interface               = Interface<DT>; \
      using cupmError_t             = typename interface::cupmError_t; \
      const auto cupmName           = []() { return interface::cupmName(); }; \
      const auto cupmGetErrorName   = [](cupmError_t e) { return interface::cupmGetErrorName(e); }; \
      const auto cupmGetErrorString = [](cupmError_t e) { return interface::cupmGetErrorString(e); }; \
      const auto cupmSuccess        = interface::cupmSuccess; \
      PetscCallCUPM(__VA_ARGS__); \
    } while (0)

132: template <DeviceType DT, typename T>
133: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(typename Interface<DT>::cupmStream_t stream, PetscInt n, T *ptr, const T *val))
134: {
135:   PetscFunctionBegin;
137:   if (n) {
138:     const auto size = n * sizeof(T);

141:     if (*val == T{0}) {
142:       PetscCallCUPM_(Interface<DT>::cupmMemsetAsync(ptr, 0, size, stream));
143:     } else {
144:       const auto xptr = thrust::device_pointer_cast(ptr);

146:       PetscCallThrust(THRUST_CALL(thrust::fill, stream, xptr, xptr + n, *val));
147:       if (std::is_same<util::remove_cv_t<T>, PetscScalar>::value) {
148:         PetscCall(PetscLogCpuToGpuScalar(size));
149:       } else {
150:         PetscCall(PetscLogCpuToGpu(size));
151:       }
152:     }
153:   }
154:   PetscFunctionReturn(PETSC_SUCCESS);
155: }

157:   #undef PetscCallCUPM_

160: template <DeviceType DT, typename T>
161: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(PetscDeviceContext dctx, PetscInt n, T *ptr, const T *val))
162: {
163:   typename Interface<DT>::cupmStream_t stream;

165:   PetscFunctionBegin;
167:   PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream));
168:   PetscCall(ThrustSet(stream, n, ptr, val));
169:   PetscFunctionReturn(PETSC_SUCCESS);
170: }

172: } // namespace impl

174: } // namespace cupm

176: } // namespace device

178: } // namespace Petsc

180: #endif // __cplusplus

182: #endif // PETSC_CUPM_THRUST_UTILITY_HPP