cudaPackages.nccl: switch to cudaAtLeast, cudaOlder, and __structuredAttrs

This commit is contained in:
Connor Baker 2024-04-03 22:27:03 +00:00
parent 5ed9f23d21
commit 0494330fad
1 changed files with 16 additions and 14 deletions

View File

@ -17,9 +17,10 @@ let
cuda_cccl
cuda_cudart
cuda_nvcc
cudaAtLeast
cudaFlags
cudaOlder
cudatoolkit
cudaVersion
;
in
backendStdenv.mkDerivation (finalAttrs: {
@ -33,6 +34,7 @@ backendStdenv.mkDerivation (finalAttrs: {
hash = "sha256-ModIjD6RaRD/57a/PA1oTgYhZsAQPrrvhl5sNVXnO6c=";
};
__structuredAttrs = true;
strictDeps = true;
outputs = [
@ -46,12 +48,12 @@ backendStdenv.mkDerivation (finalAttrs: {
autoAddDriverRunpath
python3
]
++ lib.optionals (lib.versionOlder cudaVersion "11.4") [ cudatoolkit ]
++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [ cuda_nvcc ];
++ lib.optionals (cudaOlder "11.4") [ cudatoolkit ]
++ lib.optionals (cudaAtLeast "11.4") [ cuda_nvcc ];
buildInputs =
lib.optionals (lib.versionOlder cudaVersion "11.4") [ cudatoolkit ]
++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [
lib.optionals (cudaOlder "11.4") [ cudatoolkit ]
++ lib.optionals (cudaAtLeast "11.4") [
cuda_nvcc.dev # crt/host_config.h
cuda_cudart
]
@ -59,25 +61,25 @@ backendStdenv.mkDerivation (finalAttrs: {
# against other version, like below, it's important that we use the same format. Otherwise,
# we'll get incorrect results.
# For example, lib.versionAtLeast "12.0" "12.0.0" == false.
++ lib.optionals (lib.versionAtLeast cudaVersion "12.0") [ cuda_cccl ];
++ lib.optionals (cudaAtLeast "12.0") [ cuda_cccl ];
env.NIX_CFLAGS_COMPILE = toString [ "-Wno-unused-function" ];
preConfigure = ''
postPatch = ''
patchShebangs ./src/device/generate.py
makeFlagsArray+=(
"NVCC_GENCODE=${lib.concatStringsSep " " cudaFlags.gencode}"
)
'';
makeFlags =
[ "PREFIX=$(out)" ]
++ lib.optionals (lib.versionOlder cudaVersion "11.4") [
makeFlagsArray =
[
"PREFIX=$(out)"
"NVCC_GENCODE=${cudaFlags.gencodeString}"
]
++ lib.optionals (cudaOlder "11.4") [
"CUDA_HOME=${cudatoolkit}"
"CUDA_LIB=${lib.getLib cudatoolkit}/lib"
"CUDA_INC=${lib.getDev cudatoolkit}/include"
]
++ lib.optionals (lib.versionAtLeast cudaVersion "11.4") [
++ lib.optionals (cudaAtLeast "11.4") [
"CUDA_HOME=${cuda_nvcc}"
"CUDA_LIB=${lib.getLib cuda_cudart}/lib"
"CUDA_INC=${lib.getDev cuda_cudart}/include"