Merge pull request #304069 from autrimpo/python-onnxruntime-cuda
onnxruntime: CUDA fixes
This commit is contained in:
commit
c708923629
|
@ -133,7 +133,6 @@ effectiveStdenv.mkDerivation rec {
|
|||
nlohmann_json
|
||||
microsoft-gsl
|
||||
] ++ lib.optionals pythonSupport (with python3Packages; [
|
||||
gtest
|
||||
numpy
|
||||
pybind11
|
||||
packaging
|
||||
|
@ -150,7 +149,9 @@ effectiveStdenv.mkDerivation rec {
|
|||
cuda_cudart
|
||||
]);
|
||||
|
||||
nativeCheckInputs = lib.optionals pythonSupport (with python3Packages; [
|
||||
nativeCheckInputs = [
|
||||
gtest
|
||||
] ++ lib.optionals pythonSupport (with python3Packages; [
|
||||
pytest
|
||||
sympy
|
||||
onnx
|
||||
|
@ -179,7 +180,7 @@ effectiveStdenv.mkDerivation rec {
|
|||
"-DFETCHCONTENT_SOURCE_DIR_SAFEINT=${safeint}"
|
||||
"-DFETCHCONTENT_TRY_FIND_PACKAGE_MODE=ALWAYS"
|
||||
"-Donnxruntime_BUILD_SHARED_LIB=ON"
|
||||
"-Donnxruntime_BUILD_UNIT_TESTS=ON"
|
||||
(lib.cmakeBool "onnxruntime_BUILD_UNIT_TESTS" doCheck)
|
||||
"-Donnxruntime_ENABLE_LTO=ON"
|
||||
"-Donnxruntime_USE_FULL_PROTOBUF=OFF"
|
||||
(lib.cmakeBool "onnxruntime_USE_CUDA" cudaSupport)
|
||||
|
@ -190,6 +191,7 @@ effectiveStdenv.mkDerivation rec {
|
|||
(lib.cmakeFeature "FETCHCONTENT_SOURCE_DIR_CUTLASS" "${cutlass}")
|
||||
(lib.cmakeFeature "onnxruntime_CUDNN_HOME" "${cudaPackages.cudnn}")
|
||||
(lib.cmakeFeature "CMAKE_CUDA_ARCHITECTURES" cudaArchitecturesString)
|
||||
(lib.cmakeFeature "onnxruntime_NVCC_THREADS" "1")
|
||||
];
|
||||
|
||||
env = lib.optionalAttrs effectiveStdenv.cc.isClang {
|
||||
|
@ -224,6 +226,7 @@ effectiveStdenv.mkDerivation rec {
|
|||
'';
|
||||
|
||||
passthru = {
|
||||
inherit cudaSupport cudaPackages; # for the python module
|
||||
protobuf = protobuf_21;
|
||||
tests = lib.optionalAttrs pythonSupport {
|
||||
python = python3Packages.onnxruntime;
|
||||
|
|
|
@ -53,7 +53,13 @@ buildPythonPackage {
|
|||
oneDNN
|
||||
re2
|
||||
onnxruntime.protobuf
|
||||
];
|
||||
] ++ lib.optionals onnxruntime.passthru.cudaSupport (with onnxruntime.passthru.cudaPackages; [
|
||||
libcublas # libcublasLt.so.XX libcublas.so.XX
|
||||
libcurand # libcurand.so.XX
|
||||
libcufft # libcufft.so.XX
|
||||
cudnn # libcudnn.soXX
|
||||
cuda_cudart # libcudart.so.XX
|
||||
]);
|
||||
|
||||
propagatedBuildInputs = [
|
||||
coloredlogs
|
||||
|
|
Loading…
Reference in New Issue
Block a user