Merge pull request #304069 from autrimpo/python-onnxruntime-cuda

onnxruntime: CUDA fixes
This commit is contained in:
Someone 2024-04-15 08:04:15 +00:00 committed by GitHub
commit c708923629
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 4 deletions

View File

@ -133,7 +133,6 @@ effectiveStdenv.mkDerivation rec {
nlohmann_json
microsoft-gsl
] ++ lib.optionals pythonSupport (with python3Packages; [
gtest
numpy
pybind11
packaging
@ -150,7 +149,9 @@ effectiveStdenv.mkDerivation rec {
cuda_cudart
]);
nativeCheckInputs = lib.optionals pythonSupport (with python3Packages; [
nativeCheckInputs = [
gtest
] ++ lib.optionals pythonSupport (with python3Packages; [
pytest
sympy
onnx
@ -179,7 +180,7 @@ effectiveStdenv.mkDerivation rec {
"-DFETCHCONTENT_SOURCE_DIR_SAFEINT=${safeint}"
"-DFETCHCONTENT_TRY_FIND_PACKAGE_MODE=ALWAYS"
"-Donnxruntime_BUILD_SHARED_LIB=ON"
"-Donnxruntime_BUILD_UNIT_TESTS=ON"
(lib.cmakeBool "onnxruntime_BUILD_UNIT_TESTS" doCheck)
"-Donnxruntime_ENABLE_LTO=ON"
"-Donnxruntime_USE_FULL_PROTOBUF=OFF"
(lib.cmakeBool "onnxruntime_USE_CUDA" cudaSupport)
@ -190,6 +191,7 @@ effectiveStdenv.mkDerivation rec {
(lib.cmakeFeature "FETCHCONTENT_SOURCE_DIR_CUTLASS" "${cutlass}")
(lib.cmakeFeature "onnxruntime_CUDNN_HOME" "${cudaPackages.cudnn}")
(lib.cmakeFeature "CMAKE_CUDA_ARCHITECTURES" cudaArchitecturesString)
(lib.cmakeFeature "onnxruntime_NVCC_THREADS" "1")
];
env = lib.optionalAttrs effectiveStdenv.cc.isClang {
@ -224,6 +226,7 @@ effectiveStdenv.mkDerivation rec {
'';
passthru = {
inherit cudaSupport cudaPackages; # for the python module
protobuf = protobuf_21;
tests = lib.optionalAttrs pythonSupport {
python = python3Packages.onnxruntime;

View File

@ -53,7 +53,13 @@ buildPythonPackage {
oneDNN
re2
onnxruntime.protobuf
];
] ++ lib.optionals onnxruntime.passthru.cudaSupport (with onnxruntime.passthru.cudaPackages; [
libcublas # libcublasLt.so.XX libcublas.so.XX
libcurand # libcurand.so.XX
libcufft # libcufft.so.XX
cudnn # libcudnn.soXX
cuda_cudart # libcudart.so.XX
]);
propagatedBuildInputs = [
coloredlogs