Actual source code: cupmblasinterface.hpp

  1: #ifndef PETSCCUPMBLASINTERFACE_HPP
  2: #define PETSCCUPMBLASINTERFACE_HPP

  4: #if defined(__cplusplus)
  5: #include <petsc/private/cupminterface.hpp>
  6: #include <petsc/private/petscadvancedmacros.h>

  8: namespace Petsc
  9: {

 11: namespace device
 12: {

 14: namespace cupm
 15: {

 17: namespace impl
 18: {

 20:   #define PetscCallCUPMBLAS(...) /* check the status returned by a cupmBlas call and raise a PETSc error through SETERRQ on failure */ \
 21:     do { \
 22:       const cupmBlasError_t cberr_p_ = __VA_ARGS__; /* the wrapped expression is evaluated exactly once */ \
 23:       if (PetscUnlikely(cberr_p_ != CUPMBLAS_STATUS_SUCCESS)) { \
 24:         if (((cberr_p_ == CUPMBLAS_STATUS_NOT_INITIALIZED) || (cberr_p_ == CUPMBLAS_STATUS_ALLOC_FAILED)) && PetscDeviceInitialized(PETSC_DEVICE_CUPM())) { /* these two codes are reported as PETSC_ERR_GPU_RESOURCE when PetscDeviceInitialized() holds */ \
 25:           SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, \
 26:                   "%s error %d (%s). Reports not initialized or alloc failed; " \
 27:                   "this indicates the GPU may have run out resources", \
 28:                   cupmBlasName(), static_cast<PetscErrorCode>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \
 29:         } \
 30:         SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "%s error %d (%s)", cupmBlasName(), static_cast<PetscErrorCode>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); /* every other failure is a generic PETSC_ERR_GPU; cupmBlasError_t, CUPMBLAS_STATUS_*, cupmBlasName() and cupmBlasGetErrorName() are used unqualified and must be visible at the call site */ \
 31:       } \
 32:     } while (0)

 34:   // given cupmBlas<T>axpy() then
 35:   // T = PETSC_CUPMBLAS_FP_TYPE
 36:   // given cupmBlas<T><u>nrm2() then
 37:   // T = PETSC_CUPMBLAS_FP_INPUT_TYPE
 38:   // u = PETSC_CUPMBLAS_FP_RETURN_TYPE
       //
       // Every letter is provided in an upper-case (_U) and a lower-case (_L) flavor; the values
       // are selected from USE_COMPLEX, USE_REAL_SINGLE and USE_REAL_DOUBLE below.
 39:   #if PetscDefined(USE_COMPLEX)
         // complex builds: the type letter is C/Z while the input letter is the matching real
         // letter S/D, and the return letter equals the type letter
 40:     #if PetscDefined(USE_REAL_SINGLE)
 41:       #define PETSC_CUPMBLAS_FP_TYPE_U       C
 42:       #define PETSC_CUPMBLAS_FP_TYPE_L       c
 43:       #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U S
 44:       #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L s
 45:     #elif PetscDefined(USE_REAL_DOUBLE)
 46:       #define PETSC_CUPMBLAS_FP_TYPE_U       Z
 47:       #define PETSC_CUPMBLAS_FP_TYPE_L       z
 48:       #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U D
 49:       #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L d
 50:     #endif
 51:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
 52:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
 53:   #else
         // real builds: the input letter equals the type letter and the return letter is
         // deliberately defined as empty
 54:     #if PetscDefined(USE_REAL_SINGLE)
 55:       #define PETSC_CUPMBLAS_FP_TYPE_U S
 56:       #define PETSC_CUPMBLAS_FP_TYPE_L s
 57:     #elif PetscDefined(USE_REAL_DOUBLE)
 58:       #define PETSC_CUPMBLAS_FP_TYPE_U D
 59:       #define PETSC_CUPMBLAS_FP_TYPE_L d
 60:     #endif
 61:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
 62:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
 63:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE_U
 64:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE_L
 65:   #endif // USE_COMPLEX

       // Neither branch above defines the type letters for any other precision. Only
       // USE_REAL___FLOAT128 builds are allowed to continue without them.
 67:   #if !defined(PETSC_CUPMBLAS_FP_TYPE_U) && !PetscDefined(USE_REAL___FLOAT128)
 68:     #error "Unsupported floating-point type for CUDA/HIP BLAS"
 69:   #endif

 71:   // PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE_EXACT() - declaration to alias a CUDA/HIP BLAS integral
 72:   // constant value
 73:   //
 74:   // input params:
 75:   // OUR_PREFIX   - prefix of the alias
 76:   // OUR_SUFFIX   - suffix of the alias
 77:   // THEIR_PREFIX - prefix of the variable being aliased
 78:   // THEIR_SUFFIX - suffix of the variable being aliased
 79:   //
       // notes:
       // forwards all four arguments unchanged to PETSC_CUPM_ALIAS_INTEGRAL_VALUE_EXACT()
       //
 80:   // example usage:
 81:   // PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE_EXACT(CUPMBLAS,_STATUS_SUCCESS,CUBLAS,_STATUS_SUCCESS) ->
 82:   // static const auto CUPMBLAS_STATUS_SUCCESS = CUBLAS_STATUS_SUCCESS
 83:   #define PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE_EXACT(OUR_PREFIX, OUR_SUFFIX, THEIR_PREFIX, THEIR_SUFFIX) PETSC_CUPM_ALIAS_INTEGRAL_VALUE_EXACT(OUR_PREFIX, OUR_SUFFIX, THEIR_PREFIX, THEIR_SUFFIX)

 85:   // PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE() - declaration to alias a CUDA/HIP BLAS integral
 86:   // constant value whose suffix is common to both libraries
 87:   //
 88:   // input param:
 89:   // COMMON - common suffix of the CUDA/HIP blas variable being aliased
 90:   //
 91:   // notes:
 92:   // requires PETSC_CUPMBLAS_PREFIX_U to be defined as the specific UPPERCASE prefix of the
 93:   // variable being aliased; the alias itself always uses the CUPMBLAS prefix
 94:   //
 95:   // example usage:
 96:   // #define PETSC_CUPMBLAS_PREFIX_U CUBLAS
 97:   // PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_STATUS_SUCCESS) ->
 98:   // static const auto CUPMBLAS_STATUS_SUCCESS = CUBLAS_STATUS_SUCCESS
 99:   //
100:   // #define PETSC_CUPMBLAS_PREFIX_U HIPBLAS
101:   // PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_STATUS_SUCCESS) ->
102:   // static const auto CUPMBLAS_STATUS_SUCCESS = HIPBLAS_STATUS_SUCCESS
103:   #define PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(COMMON) PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE_EXACT(CUPMBLAS, COMMON, PETSC_CUPMBLAS_PREFIX_U, COMMON)

105:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED() - Helper macro to build a "modified"
106:   // blas function whose return type does not match the input type
107:   //
108:   // input param:
109:   // func - base suffix of the blas function, e.g. nrm2
110:   //
111:   // notes:
112:   // requires PETSC_CUPMBLAS_FP_INPUT_TYPE to be defined as the blas floating point input type
113:   // letter ("S" for real/complex single, "D" for real/complex double).
114:   //
115:   // requires PETSC_CUPMBLAS_FP_RETURN_TYPE to be defined as the blas floating point output type
116:   // letter ("c" for complex single, "z" for complex double and <absolutely nothing> for real
117:   // single/double).
118:   //
       // The CUDA and HIP sections of this file define PETSC_CUPMBLAS_FP_INPUT_TYPE from the
       // upper-case (_U) letter and PETSC_CUPMBLAS_FP_RETURN_TYPE from the lower-case (_L) letter.
       //
119:   // In their infinite wisdom nvidia/amd have made the upper-case vs lower-case scheme
120:   // infuriatingly inconsistent...
121:   //
122:   // example usage:
123:   // #define PETSC_CUPMBLAS_FP_INPUT_TYPE  S
124:   // #define PETSC_CUPMBLAS_FP_RETURN_TYPE
125:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Snrm2
126:   //
127:   // #define PETSC_CUPMBLAS_FP_INPUT_TYPE  D
128:   // #define PETSC_CUPMBLAS_FP_RETURN_TYPE z
129:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Dznrm2
130:   #define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(func) PetscConcat(PetscConcat(PETSC_CUPMBLAS_FP_INPUT_TYPE, PETSC_CUPMBLAS_FP_RETURN_TYPE), func)

132:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE() - Helper macro to build Iamax and Iamin
133:   // because they are both extra special
134:   //
135:   // input param:
136:   // func - base suffix of the blas function, either amax or amin
137:   //
138:   // notes:
139:   // The macro name literally stands for "I" ## "floating point type" because shockingly enough,
140:   // that's what it does.
141:   //
142:   // requires PETSC_CUPMBLAS_FP_TYPE_L to be defined as the lower-case blas floating point input type
143:   // letter ("c" for complex single, "z" for complex double, "s" for real single, and "d" for
144:   // real double).
145:   //
146:   // example usage:
147:   // #define PETSC_CUPMBLAS_FP_TYPE_L s
148:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amax) -> Isamax
149:   //
150:   // #define PETSC_CUPMBLAS_FP_TYPE_L z
151:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amin) -> Izamin
152:   #define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(func) PetscConcat(I, PetscConcat(PETSC_CUPMBLAS_FP_TYPE_L, func))

154:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD() - Helper macro to build a "standard"
155:   // blas function name
156:   //
157:   // input param:
158:   // func - base suffix of the blas function, e.g. axpy, scal
159:   //
160:   // notes:
161:   // requires PETSC_CUPMBLAS_FP_TYPE to be defined as the blas floating-point letter ("C" for
162:   // complex single, "Z" for complex double, "S" for real single, "D" for real double).
       // The CUDA and HIP sections of this file define it as PETSC_CUPMBLAS_FP_TYPE_U.
163:   //
164:   // example usage:
165:   // #define PETSC_CUPMBLAS_FP_TYPE S
166:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Saxpy
167:   //
168:   // #define PETSC_CUPMBLAS_FP_TYPE Z
169:   // PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Zaxpy
170:   #define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(func) PetscConcat(PETSC_CUPMBLAS_FP_TYPE, func)

172:   // PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT() - In case CUDA/HIP don't agree with our suffix
173:   // one can provide both here
174:   //
175:   // input params:
176:   // MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or
177:   // IFPTYPE
178:   // our_suffix   - the suffix of the alias function
179:   // their_suffix - the suffix of the function being aliased
180:   //
181:   // notes:
182:   // requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas function
183:   // prefix. requires any other specific definitions required by the specific builder macro to
184:   // also be defined. See PETSC_CUPM_ALIAS_FUNCTION_EXACT() for the exact expansion of the
185:   // function alias.
186:   //
187:   // example usage:
188:   // #define PETSC_CUPMBLAS_PREFIX  cublas
189:   // #define PETSC_CUPMBLAS_FP_TYPE C
190:   // PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD,dot,dotc) ->
191:   // template <typename... T>
192:   // static constexpr auto cupmBlasXdot(T&&... args) *noexcept and returntype detection*
193:   // {
194:   //   return cublasCdotc(std::forward<T>(args)...);
195:   // }
196:   #define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, our_suffix, their_suffix) \
197:     PETSC_CUPM_ALIAS_FUNCTION_EXACT(cupmBlasX, our_suffix, PETSC_CUPMBLAS_PREFIX, PetscConcat(PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_, MACRO_SUFFIX)(their_suffix))

199:   // PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION() - Alias a CUDA/HIP blas function
200:   //
201:   // input params:
202:   // MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or
203:   // IFPTYPE
204:   // suffix       - the common suffix between CUDA and HIP of the alias function
205:   //
206:   // notes:
207:   // see PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(), this macro just calls that one with "suffix"
208:   // as both "our_suffix" and "their_suffix"
209:   #define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MACRO_SUFFIX, suffix) PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, suffix, suffix)

211:   // PETSC_CUPMBLAS_ALIAS_FUNCTION() - Alias a CUDA/HIP library function
212:   //
213:   // input params:
214:   // suffix - the common suffix between CUDA and HIP of the alias function
215:   //
216:   // notes:
217:   // requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas library
218:   // prefix. see PETSC_CUPM_ALIAS_FUNCTION_EXACT() for the precise expansion of this macro.
219:   //
220:   // example usage:
221:   // #define PETSC_CUPMBLAS_PREFIX hipblas
222:   // PETSC_CUPMBLAS_ALIAS_FUNCTION(Create) ->
223:   // template <typename... T>
224:   // static constexpr auto cupmBlasCreate(T&&... args) *noexcept and returntype detection*
225:   // {
226:   //   return hipblasCreate(std::forward<T>(args)...);
227:   // }
228:   #define PETSC_CUPMBLAS_ALIAS_FUNCTION(suffix) PETSC_CUPM_ALIAS_FUNCTION_EXACT(cupmBlas, suffix, PETSC_CUPMBLAS_PREFIX, suffix)

230: template <DeviceType T>
     // Shared base of every BLAS interface specialization; only adds the human-readable name of
     // the BLAS library backing device type T on top of Interface<T>.
231: struct BlasInterfaceBase : Interface<T> {
       // "cuBLAS" for DeviceType::CUDA, "hipBLAS" for every other device type
232:   PETSC_NODISCARD static constexpr const char *cupmBlasName() noexcept
       {
         return (T != DeviceType::CUDA) ? "hipBLAS" : "cuBLAS";
       }
233: };

235:   #define PETSC_CUPMBLAS_BASE_CLASS_HEADER(DEV_TYPE) /* common preamble used at the top of each BlasInterfaceImpl specialization; requires PETSC_CUPMBLAS_PREFIX_U to be defined */ \
236:     using base_type = ::Petsc::device::cupm::impl::BlasInterfaceBase<DEV_TYPE>; \
237:     using base_type::cupmBlasName; \
238:     PETSC_CUPM_ALIAS_FUNCTION_EXACT(cupmBlas, GetErrorName, PetscConcat(Petsc, PETSC_CUPMBLAS_PREFIX_U), GetErrorName) /* presumably yields cupmBlasGetErrorName() forwarding to Petsc<PREFIX_U>GetErrorName() - confirm against PETSC_CUPM_ALIAS_FUNCTION_EXACT */ \
239:     PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(interface_type, DEV_TYPE)

241: template <DeviceType>
     // Primary template is declared only; this file defines explicit specializations for
     // DeviceType::CUDA and DeviceType::HIP, guarded by HAVE_CUDA and HAVE_HIP respectively.
242: struct BlasInterfaceImpl;

244:   #if PetscDefined(HAVE_CUDA)
245:     #define PETSC_CUPMBLAS_PREFIX         cublas
246:     #define PETSC_CUPMBLAS_PREFIX_U       CUBLAS
247:     #define PETSC_CUPMBLAS_FP_TYPE        PETSC_CUPMBLAS_FP_TYPE_U
248:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE  PETSC_CUPMBLAS_FP_INPUT_TYPE_U
249:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
     // CUDA specialization: maps the generic cupmBlas*/cupmSolver* names onto cuBLAS and the
     // cuSOLVER dense API, and provides create/stream/destroy helpers for the solver handle.
     // The helper macros defined above are only valid inside this section and are #undef'd
     // again right after the struct.
250: template <>
251: struct BlasInterfaceImpl<DeviceType::CUDA> : BlasInterfaceBase<DeviceType::CUDA> {
252:   PETSC_CUPMBLAS_BASE_CLASS_HEADER(DeviceType::CUDA);

254:   // typedefs
255:   using cupmBlasHandle_t      = cublasHandle_t;
256:   using cupmBlasError_t       = cublasStatus_t;
257:   using cupmBlasInt_t         = int;
258:   using cupmSolverHandle_t    = cusolverDnHandle_t;
259:   using cupmSolverError_t     = cusolverStatus_t;
260:   using cupmBlasPointerMode_t = cublasPointerMode_t;

262:   // values
263:   PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_STATUS_SUCCESS);
264:   PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_STATUS_NOT_INITIALIZED);
265:   PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_STATUS_ALLOC_FAILED);
266:   PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_POINTER_MODE_HOST);
267:   PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_POINTER_MODE_DEVICE);

269:   // utility functions
270:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
271:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
272:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
273:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
274:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
275:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)

277:   // level 1 BLAS
278:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
279:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
280:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
281:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
282:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
283:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
284:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
285:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)

287:   // level 2 BLAS
288:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)

290:   // level 3 BLAS
291:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)

293:   // BLAS extensions
294:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)

       // Creates the cuSOLVER dense handle unless `handle` is already non-null.
       //
       // Creation is attempted up to three times. A retry (after PetscSleep(3)) only happens when
       // cusolverDnCreate() reports CUSOLVER_STATUS_NOT_INITIALIZED or CUSOLVER_STATUS_ALLOC_FAILED;
       // any other failure is returned immediately as PETSC_ERR_GPU, and exhausting all attempts
       // is returned as PETSC_ERR_GPU_RESOURCE, so a zero return always means a usable handle.
296:   PETSC_NODISCARD static PetscErrorCode InitializeHandle(cupmSolverHandle_t &handle) noexcept
297:   {
298:     if (handle) return 0;
299:     for (auto i = 0; i < 3; ++i) {
300:       const auto cerr = cusolverDnCreate(&handle);
301:       if (PetscLikely(cerr == CUSOLVER_STATUS_SUCCESS)) break;
           // not a transient resource problem: report it instead of silently discarding the status
302:       if ((cerr != CUSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUSOLVER_STATUS_ALLOC_FAILED)) SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "cuSOLVER error %d", static_cast<int>(cerr));
303:       if (i < 2) {
304:         PetscSleep(3);
305:         continue;
306:       }
           // last attempt failed as well, do not return success with a null handle
           SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize cuSolverDn");
308:     }
309:     return 0;
310:   }

       // Attaches `stream` to the solver handle if the handle is not already using it. Failures
       // of either cuSOLVER call are returned as PETSC_ERR_GPU.
312:   PETSC_NODISCARD static PetscErrorCode SetHandleStream(const cupmSolverHandle_t &handle, const cupmStream_t &stream) noexcept
313:   {
314:     cupmStream_t cupmStream;

316:     const auto gerr = cusolverDnGetStream(handle, &cupmStream);
         // cupmStream is only meaningful if the query succeeded
         if (PetscUnlikely(gerr != CUSOLVER_STATUS_SUCCESS)) SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "cuSOLVER error %d", static_cast<int>(gerr));
317:     if (cupmStream != stream) {
           const auto serr = cusolverDnSetStream(handle, stream);
           if (PetscUnlikely(serr != CUSOLVER_STATUS_SUCCESS)) SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "cuSOLVER error %d", static_cast<int>(serr));
         }
318:     return 0;
319:   }

       // Destroys the solver handle (no-op for a null handle) and resets it to nullptr. If
       // cusolverDnDestroy() fails the error is returned and `handle` is left untouched.
321:   PETSC_NODISCARD static PetscErrorCode DestroyHandle(cupmSolverHandle_t &handle) noexcept
322:   {
323:     if (handle) {
324:       const auto cerr = cusolverDnDestroy(handle);
           if (PetscUnlikely(cerr != CUSOLVER_STATUS_SUCCESS)) SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "cuSOLVER error %d", static_cast<int>(cerr));
325:       handle = nullptr;
326:     }
327:     return 0;
328:   }
329: };
330:     #undef PETSC_CUPMBLAS_PREFIX
331:     #undef PETSC_CUPMBLAS_PREFIX_U
332:     #undef PETSC_CUPMBLAS_FP_TYPE
333:     #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
334:     #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
335:   #endif // PetscDefined(HAVE_CUDA)

337:   #if PetscDefined(HAVE_HIP)
338:     #define PETSC_CUPMBLAS_PREFIX         hipblas
339:     #define PETSC_CUPMBLAS_PREFIX_U       HIPBLAS
340:     #define PETSC_CUPMBLAS_FP_TYPE        PETSC_CUPMBLAS_FP_TYPE_U
341:     #define PETSC_CUPMBLAS_FP_INPUT_TYPE  PETSC_CUPMBLAS_FP_INPUT_TYPE_U
342:     #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
     // HIP specialization: maps the generic cupmBlas*/cupmSolver* names onto hipBLAS and
     // hipSOLVER, and provides create/stream/destroy helpers for the solver handle. The helper
     // macros defined above are only valid inside this section and are #undef'd again right
     // after the struct.
343: template <>
344: struct BlasInterfaceImpl<DeviceType::HIP> : BlasInterfaceBase<DeviceType::HIP> {
345:   PETSC_CUPMBLAS_BASE_CLASS_HEADER(DeviceType::HIP);

347:   // typedefs
348:   using cupmBlasHandle_t      = hipblasHandle_t;
349:   using cupmBlasError_t       = hipblasStatus_t;
350:   using cupmBlasInt_t         = int; // rocblas will have its own
351:   using cupmSolverHandle_t    = hipsolverHandle_t;
352:   using cupmSolverError_t     = hipsolverStatus_t;
353:   using cupmBlasPointerMode_t = hipblasPointerMode_t;

355:   // values
356:   PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_STATUS_SUCCESS);
357:   PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_STATUS_NOT_INITIALIZED);
358:   PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_STATUS_ALLOC_FAILED);
359:   PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_POINTER_MODE_HOST);
360:   PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_POINTER_MODE_DEVICE);

362:   // utility functions
363:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
364:   PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
365:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
366:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
367:   PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
368:   PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)

370:   // level 1 BLAS
371:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
372:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
373:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
374:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
375:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
376:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
377:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
378:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)

380:   // level 2 BLAS
381:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)

383:   // level 3 BLAS
384:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)

386:   // BLAS extensions
387:   PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)

       // Creates the hipSOLVER handle unless `handle` is already non-null. A failing
       // hipsolverCreate() is returned as PETSC_ERR_GPU instead of being ignored.
389:   PETSC_NODISCARD static PetscErrorCode InitializeHandle(cupmSolverHandle_t &handle) noexcept
390:   {
391:     if (!handle) {
           const auto herr = hipsolverCreate(&handle);
           if (PetscUnlikely(herr != HIPSOLVER_STATUS_SUCCESS)) SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "hipSOLVER error %d", static_cast<int>(herr));
         }
392:     return 0;
393:   }

       // Attaches `stream` to the solver handle; a failing hipsolverSetStream() is returned as
       // PETSC_ERR_GPU.
395:   PETSC_NODISCARD static PetscErrorCode SetHandleStream(cupmSolverHandle_t handle, cupmStream_t stream) noexcept
396:   {
397:     const auto herr = hipsolverSetStream(handle, stream);
         if (PetscUnlikely(herr != HIPSOLVER_STATUS_SUCCESS)) SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "hipSOLVER error %d", static_cast<int>(herr));
398:     return 0;
399:   }

       // Destroys the solver handle (no-op for a null handle) and resets it to nullptr. If
       // hipsolverDestroy() fails the error is returned and `handle` is left untouched.
401:   PETSC_NODISCARD static PetscErrorCode DestroyHandle(cupmSolverHandle_t &handle) noexcept
402:   {
403:     if (handle) {
404:       const auto herr = hipsolverDestroy(handle);
           if (PetscUnlikely(herr != HIPSOLVER_STATUS_SUCCESS)) SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "hipSOLVER error %d", static_cast<int>(herr));
405:       handle = nullptr;
406:     }
407:     return 0;
408:   }
409: };
410:     #undef PETSC_CUPMBLAS_PREFIX
411:     #undef PETSC_CUPMBLAS_PREFIX_U
412:     #undef PETSC_CUPMBLAS_FP_TYPE
413:     #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
414:     #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
415:   #endif // PetscDefined(HAVE_HIP)

417:   #undef PETSC_CUPMBLAS_BASE_CLASS_HEADER

419:   #define PETSC_CUPMBLAS_IMPL_CLASS_HEADER(base_name, T) /* declares base_name as an alias of BlasInterfaceImpl<T> and re-exports its names into the current class scope; no trailing semicolon, the user supplies it */ \
420:     PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(cupmInterface_t, T); \
421:     using base_name = ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>; \
422:     /* introspection */ \
423:     using base_name::cupmBlasName; \
424:     using base_name::cupmBlasGetErrorName; \
425:     /* types */ \
426:     using typename base_name::cupmBlasHandle_t; \
427:     using typename base_name::cupmBlasError_t; \
428:     using typename base_name::cupmBlasInt_t; \
429:     using typename base_name::cupmSolverHandle_t; \
430:     using typename base_name::cupmSolverError_t; \
431:     using typename base_name::cupmBlasPointerMode_t; \
432:     /* values */ \
433:     using base_name::CUPMBLAS_STATUS_SUCCESS; \
434:     using base_name::CUPMBLAS_STATUS_NOT_INITIALIZED; \
435:     using base_name::CUPMBLAS_STATUS_ALLOC_FAILED; \
436:     using base_name::CUPMBLAS_POINTER_MODE_HOST; \
437:     using base_name::CUPMBLAS_POINTER_MODE_DEVICE; \
438:     /* utility functions */ \
439:     using base_name::cupmBlasCreate; \
440:     using base_name::cupmBlasDestroy; \
441:     using base_name::cupmBlasGetStream; \
442:     using base_name::cupmBlasSetStream; \
443:     using base_name::cupmBlasGetPointerMode; \
444:     using base_name::cupmBlasSetPointerMode; \
445:     /* level 1 BLAS */ \
446:     using base_name::cupmBlasXaxpy; \
447:     using base_name::cupmBlasXscal; \
448:     using base_name::cupmBlasXdot; \
449:     using base_name::cupmBlasXdotu; \
450:     using base_name::cupmBlasXswap; \
451:     using base_name::cupmBlasXnrm2; \
452:     using base_name::cupmBlasXamax; \
453:     using base_name::cupmBlasXasum; \
454:     /* level 2 BLAS */ \
455:     using base_name::cupmBlasXgemv; \
456:     /* level 3 BLAS */ \
457:     using base_name::cupmBlasXgemm; \
458:     /* BLAS extensions */ \
459:     using base_name::cupmBlasXgeam

461: // The actual interface class
     //
     // Extends BlasInterfaceImpl<T> with helpers that are written once against the generic
     // cupmBlas* names and therefore work for every supported device type.
462: template <DeviceType T>
463: struct BlasInterface : BlasInterfaceImpl<T> {
464:   PETSC_CUPMBLAS_IMPL_CLASS_HEADER(blasinterface_type, T);

       // Sets the pointer mode of `handle` according to where `ptr` lives:
       // CUPMBLAS_POINTER_MODE_DEVICE if PetscMemTypeDevice() holds for the memory type reported
       // by PetscCUPMGetMemType(), CUPMBLAS_POINTER_MODE_HOST otherwise.
       //
       // Both the memory-type query and the cupmBlas call are checked; a failure of either is
       // returned to the caller rather than being dropped.
466:   PETSC_NODISCARD static PetscErrorCode PetscCUPMBlasSetPointerModeFromPointer(cupmBlasHandle_t handle, const void *ptr) noexcept
467:   {
468:     auto mtype = PETSC_MEMTYPE_HOST;

470:     PetscCall(PetscCUPMGetMemType(ptr, &mtype));
471:     PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, PetscMemTypeDevice(mtype) ? CUPMBLAS_POINTER_MODE_DEVICE : CUPMBLAS_POINTER_MODE_HOST));
472:     return 0;
473:   }
474: };

476:   #define PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(base_name, T) /* re-exports everything from BlasInterface<T> into the current class scope; base_name becomes the alias of BlasInterface<T>, base_name##_impl the alias of BlasInterfaceImpl<T> */ \
477:     PETSC_CUPMBLAS_IMPL_CLASS_HEADER(PetscConcat(base_name, _impl), T); \
478:     using base_name = ::Petsc::device::cupm::impl::BlasInterface<T>; \
479:     using base_name::PetscCUPMBlasSetPointerModeFromPointer

481:   #if PetscDefined(HAVE_CUDA)
     // explicit instantiation declaration: BlasInterface<DeviceType::CUDA> is not implicitly
     // instantiated by includers of this header. NOTE(review): the matching explicit
     // instantiation definition is presumably in a source file - confirm.
482: extern template struct BlasInterface<DeviceType::CUDA>;
483:   #endif

485:   #if PetscDefined(HAVE_HIP)
     // same as above for the HIP backend
486: extern template struct BlasInterface<DeviceType::HIP>;
487:   #endif

489: } // namespace impl

491: } // namespace cupm

493: } // namespace device

495: } // namespace Petsc

497: #endif // defined(__cplusplus)

499: #endif // PETSCCUPMBLASINTERFACE_HPP