cudaPackages.saxpy: switch to cudaAtLeast, cudaOlder, __structuredAttrs, and enable on Jetson post-11.4

This commit is contained in:
Connor Baker 2024-04-03 22:16:24 +00:00
parent e77b24b159
commit 5ed9f23d21

View File

@ -10,8 +10,9 @@ let
cuda_cccl cuda_cccl
cuda_cudart cuda_cudart
cuda_nvcc cuda_nvcc
cudaAtLeast
cudaOlder
cudatoolkit cudatoolkit
cudaVersion
flags flags
libcublas libcublas
setupCudaHook setupCudaHook
@ -24,6 +25,7 @@ backendStdenv.mkDerivation {
src = ./.; src = ./.;
__structuredAttrs = true;
strictDeps = true; strictDeps = true;
nativeBuildInputs = nativeBuildInputs =
@ -31,24 +33,22 @@ backendStdenv.mkDerivation {
cmake cmake
autoAddDriverRunpath autoAddDriverRunpath
] ]
++ lib.optionals (lib.versionOlder cudaVersion "11.4") [ cudatoolkit ] ++ lib.optionals (cudaOlder "11.4") [ cudatoolkit ]
++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [ cuda_nvcc ]; ++ lib.optionals (cudaAtLeast "11.4") [ cuda_nvcc ];
buildInputs = buildInputs =
lib.optionals (lib.versionOlder cudaVersion "11.4") [ cudatoolkit ] lib.optionals (cudaOlder "11.4") [ cudatoolkit ]
++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [ ++ lib.optionals (cudaAtLeast "11.4") [
(getDev libcublas) (getDev libcublas)
(getLib libcublas) (getLib libcublas)
(getOutput "static" libcublas) (getOutput "static" libcublas)
cuda_cudart cuda_cudart
] ]
++ lib.optionals (lib.versionAtLeast cudaVersion "12.0") [ cuda_cccl ]; ++ lib.optionals (cudaAtLeast "12.0") [ cuda_cccl ];
cmakeFlags = [ cmakeFlagsArray = [
(lib.cmakeBool "CMAKE_VERBOSE_MAKEFILE" true) (lib.cmakeBool "CMAKE_VERBOSE_MAKEFILE" true)
(lib.cmakeFeature "CMAKE_CUDA_ARCHITECTURES" ( (lib.cmakeFeature "CMAKE_CUDA_ARCHITECTURES" flags.cmakeCudaArchitecturesString)
with flags; lib.concatStringsSep ";" (lib.lists.map dropDot cudaCapabilities)
))
]; ];
meta = rec { meta = rec {
@ -56,6 +56,6 @@ backendStdenv.mkDerivation {
license = lib.licenses.mit; license = lib.licenses.mit;
maintainers = lib.teams.cuda.members; maintainers = lib.teams.cuda.members;
platforms = lib.platforms.unix; platforms = lib.platforms.unix;
badPlatforms = lib.optionals flags.isJetsonBuild platforms; badPlatforms = lib.optionals (flags.isJetsonBuild && cudaOlder "11.4") platforms;
}; };
} }