Actual source code: cupmthrustutility.hpp
1: #ifndef PETSC_CUPM_THRUST_UTILITY_HPP
2: #define PETSC_CUPM_THRUST_UTILITY_HPP
#include <petsc/private/deviceimpl.h>
#include <petsc/private/cupminterface.hpp>

#if defined(__cplusplus)
  #include <exception>
  #include <type_traits>
  #include <utility>

  #include <thrust/device_ptr.h>
  #include <thrust/fill.h>
  #include <thrust/transform.h>
11: namespace Petsc
12: {
14: namespace device
15: {
17: namespace cupm
18: {
20: namespace impl
21: {
23: #if PetscDefined(USING_NVCC)
24: #if !defined(THRUST_VERSION)
25: #error "THRUST_VERSION not defined!"
26: #endif
27: #if !PetscDefined(USE_DEBUG) && (THRUST_VERSION >= 101600)
28: #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par_nosync.on(s), __VA_ARGS__)
29: #else
30: #define thrust_call_par_on(func, s, ...) func(thrust::cuda::par.on(s), __VA_ARGS__)
31: #endif
32: #elif PetscDefined(USING_HCC) // rocThrust has no par_nosync
33: #define thrust_call_par_on(func, s, ...) func(thrust::hip::par.on(s), __VA_ARGS__)
34: #else
35: #define thrust_call_par_on(func, s, ...) func(__VA_ARGS__)
36: #endif
38: namespace detail
39: {
41: struct PetscLogGpuTimer {
42: PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeBegin()); }
43: ~PetscLogGpuTimer() noexcept { PetscCallAbort(PETSC_COMM_SELF, PetscLogGpuTimeEnd()); }
44: };
46: struct private_tag { };
48: } // namespace detail
// THRUST_CALL(func, stream, args...) evaluates thrust_call_par_on(func, stream, args...) inside an
// immediately-invoked lambda and returns its result. A detail::PetscLogGpuTimer lives for the
// duration of the call, so the Thrust work is bracketed by PetscLogGpuTimeBegin()/End().
// The timer is direct-initialized, not copy-initialized from a temporary. That way the begin/end
// pairing does not depend on copy elision: before C++17, a non-elided copy would let the
// temporary's destructor call PetscLogGpuTimeEnd() an extra time.
#define THRUST_CALL(...) \
  [&] { \
    const ::Petsc::device::cupm::impl::detail::PetscLogGpuTimer timer{}; \
    return thrust_call_par_on(__VA_ARGS__); \
  }()
// PetscCallThrust(expr) evaluates a Thrust expression and turns any C++ exception it throws into a
// PETSc error. SETERRQ returns from the enclosing function, so the macro is only usable inside
// functions returning PetscErrorCode.
// thrust::system_error is handled first with the original message. Any other std::exception
// (for instance an allocation failure) is caught as well instead of propagating out of a
// PetscErrorCode-returning function.
#define PetscCallThrust(...) \
  do { \
    try { \
      __VA_ARGS__; \
    } catch (const thrust::system_error &ex) { \
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Thrust error: %s", ex.what()); \
    } catch (const std::exception &ex) { \
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Exception thrown during Thrust call: %s", ex.what()); \
    } \
  } while (0)
67: // actual implementation that calls thrust, 2 argument version
68: template <DeviceType DT, typename FunctorType, typename T>
69: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xinout, T *yin = nullptr))
70: {
71: const auto xptr = thrust::device_pointer_cast(xinout);
72: const auto retptr = (yin && (yin != xinout)) ? thrust::device_pointer_cast(yin) : xptr;
74: PetscFunctionBegin;
76: PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, retptr, std::forward<FunctorType>(functor)));
77: PetscFunctionReturn(PETSC_SUCCESS);
78: }
80: // actual implementation that calls thrust, 3 argument version
81: template <DeviceType DT, typename FunctorType, typename T>
82: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(detail::private_tag, typename Interface<DT>::cupmStream_t stream, FunctorType &&functor, PetscInt n, T *xin, T *yin, T *zin))
83: {
84: const auto xptr = thrust::device_pointer_cast(xin);
86: PetscFunctionBegin;
90: PetscCallThrust(THRUST_CALL(thrust::transform, stream, xptr, xptr + n, thrust::device_pointer_cast(yin), thrust::device_pointer_cast(zin), std::forward<FunctorType>(functor)));
91: PetscFunctionReturn(PETSC_SUCCESS);
92: }
94: // one last intermediate function to check n, and log flops for everything
95: template <DeviceType DT, typename F, typename... T>
96: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(typename Interface<DT>::cupmStream_t stream, F &&functor, PetscInt n, T &&...rest))
97: {
98: PetscFunctionBegin;
99: PetscAssert(n >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "n %" PetscInt_FMT " must be >= 0", n);
100: if (PetscLikely(n)) {
101: PetscCall(ThrustApplyPointwise<DT>(detail::private_tag{}, stream, std::forward<F>(functor), n, std::forward<T>(rest)...));
102: PetscCall(PetscLogGpuFlops(n));
103: }
104: PetscFunctionReturn(PETSC_SUCCESS);
105: }
107: // serves as setup to the real implementation above
108: template <DeviceType T, typename F, typename... Args>
109: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustApplyPointwise(PetscDeviceContext dctx, F &&functor, PetscInt n, Args &&...rest))
110: {
111: typename Interface<T>::cupmStream_t stream;
113: PetscFunctionBegin;
114: static_assert(sizeof...(Args) <= 3, "");
116: PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream));
117: PetscCall(ThrustApplyPointwise<T>(stream, std::forward<F>(functor), n, std::forward<Args>(rest)...));
118: PetscFunctionReturn(PETSC_SUCCESS);
119: }
// Local wrapper around PetscCallCUPM() for use in free function templates. It requires a
// DeviceType named DT in the enclosing scope and declares cupmSuccess, cupmName,
// cupmGetErrorName and cupmGetErrorString as locals forwarding to Interface<DT>, presumably
// because PetscCallCUPM() refers to those names unqualified (confirm against its definition).
#define PetscCallCUPM_(...) \
  do { \
    using interface               = Interface<DT>; \
    using cupmError_t             = typename interface::cupmError_t; \
    const auto cupmSuccess        = interface::cupmSuccess; \
    const auto cupmName           = [] { return interface::cupmName(); }; \
    const auto cupmGetErrorString = [](cupmError_t e) { return interface::cupmGetErrorString(e); }; \
    const auto cupmGetErrorName   = [](cupmError_t e) { return interface::cupmGetErrorName(e); }; \
    PetscCallCUPM(__VA_ARGS__); \
  } while (0)
132: template <DeviceType DT, typename T>
133: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(typename Interface<DT>::cupmStream_t stream, PetscInt n, T *ptr, const T *val))
134: {
135: PetscFunctionBegin;
137: if (n) {
138: const auto size = n * sizeof(T);
141: if (*val == T{0}) {
142: PetscCallCUPM_(Interface<DT>::cupmMemsetAsync(ptr, 0, size, stream));
143: } else {
144: const auto xptr = thrust::device_pointer_cast(ptr);
146: PetscCallThrust(THRUST_CALL(thrust::fill, stream, xptr, xptr + n, *val));
147: if (std::is_same<util::remove_cv_t<T>, PetscScalar>::value) {
148: PetscCall(PetscLogCpuToGpuScalar(size));
149: } else {
150: PetscCall(PetscLogCpuToGpu(size));
151: }
152: }
153: }
154: PetscFunctionReturn(PETSC_SUCCESS);
155: }
157: #undef PetscCallCUPM_
160: template <DeviceType DT, typename T>
161: PETSC_CXX_COMPAT_DEFN(PetscErrorCode ThrustSet(PetscDeviceContext dctx, PetscInt n, T *ptr, const T *val))
162: {
163: typename Interface<DT>::cupmStream_t stream;
165: PetscFunctionBegin;
167: PetscCall(PetscDeviceContextGetStreamHandle_Internal(dctx, &stream));
168: PetscCall(ThrustSet(stream, n, ptr, val));
169: PetscFunctionReturn(PETSC_SUCCESS);
170: }
172: } // namespace impl
174: } // namespace cupm
176: } // namespace device
178: } // namespace Petsc
180: #endif // __cplusplus
182: #endif // PETSC_CUPM_THRUST_UTILITY_HPP