45 lines
1.2 KiB
Diff
45 lines
1.2 KiB
Diff
From d448321436e8314d3e2a6a09d4017c4bc10f612d Mon Sep 17 00:00:00 2001
|
|
From: Gaetan Lepage <gaetan@glepage.com>
|
|
Date: Sat, 17 Feb 2024 17:37:22 +0100
|
|
Subject: [PATCH] fix-dlopen-cuda
|
|
|
|
---
|
|
gpuctypes/cuda.py | 20 ++++++++++++++++++--
|
|
1 file changed, 18 insertions(+), 2 deletions(-)
|
|
|
|
diff --git a/gpuctypes/cuda.py b/gpuctypes/cuda.py
|
|
index acba81c..091f7f7 100644
|
|
--- a/gpuctypes/cuda.py
|
|
+++ b/gpuctypes/cuda.py
|
|
@@ -143,9 +143,25 @@ def char_pointer_cast(string, encoding='utf-8'):
|
|
|
|
|
|
|
|
+NAME_TO_PATHS = {
|
|
+ "libcuda.so": ["@driverLink@/lib/libcuda.so"],
|
|
+ "libnvrtc.so": ["@libnvrtc@"],
|
|
+}
|
|
+def _try_dlopen(name):
|
|
+ try:
|
|
+ return ctypes.CDLL(name)
|
|
+ except OSError:
|
|
+ pass
|
|
+ for candidate in NAME_TO_PATHS.get(name, []):
|
|
+ try:
|
|
+ return ctypes.CDLL(candidate)
|
|
+ except OSError:
|
|
+ pass
|
|
+ raise RuntimeError(f"{name} not found")
|
|
+
|
|
_libraries = {}
|
|
-_libraries['libcuda.so'] = ctypes.CDLL(ctypes.util.find_library('cuda'))
|
|
-_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc'))
|
|
+_libraries['libcuda.so'] = _try_dlopen('libcuda.so')
|
|
+_libraries['libnvrtc.so'] = _try_dlopen('libnvrtc.so')
|
|
|
|
|
|
cuuint32_t = ctypes.c_uint32
|
|
--
|
|
2.43.0
|
|
|