Merge pull request #288857 from samuela/samuela/jaxlib
python3Packages.jaxlib: fix #282184 and migrate to cuda redist packages
This commit is contained in:
commit
cdd38b2e49
@ -12,6 +12,7 @@
|
||||
, curl
|
||||
, cython
|
||||
, fetchFromGitHub
|
||||
, fetchpatch
|
||||
, git
|
||||
, IOKit
|
||||
, jsoncpp
|
||||
@ -47,14 +48,19 @@
|
||||
|
||||
# MKL:
|
||||
, mklSupport ? true
|
||||
}:
|
||||
}@inputs:
|
||||
|
||||
let
|
||||
inherit (cudaPackagesGoogle) backendStdenv cudatoolkit cudaFlags cudnn nccl;
|
||||
inherit (cudaPackagesGoogle) autoAddOpenGLRunpathHook cudaFlags cudaVersion cudnn nccl;
|
||||
|
||||
pname = "jaxlib";
|
||||
version = "0.4.24";
|
||||
|
||||
# It's necessary to consistently use backendStdenv when building with CUDA
|
||||
# support, otherwise we get libstdc++ errors downstream
|
||||
stdenv = throw "Use effectiveStdenv instead";
|
||||
effectiveStdenv = if cudaSupport then cudaPackagesGoogle.backendStdenv else inputs.stdenv;
|
||||
|
||||
meta = with lib; {
|
||||
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
|
||||
homepage = "https://github.com/google/jax";
|
||||
@ -65,25 +71,51 @@ let
|
||||
# however even with that fix applied, it doesn't work for everyone:
|
||||
# https://github.com/NixOS/nixpkgs/pull/184395#issuecomment-1207287129
|
||||
# NOTE: We always build with NCCL; if it is unsupported, then our build is broken.
|
||||
broken = stdenv.isDarwin || nccl.meta.unsupported;
|
||||
broken = effectiveStdenv.isDarwin || nccl.meta.unsupported;
|
||||
};
|
||||
|
||||
cudatoolkit_joined = symlinkJoin {
|
||||
name = "${cudatoolkit.name}-merged";
|
||||
paths = [
|
||||
cudatoolkit.lib
|
||||
cudatoolkit.out
|
||||
] ++ lib.optionals (lib.versionOlder cudatoolkit.version "11") [
|
||||
# for some reason some of the required libs are in the targets/x86_64-linux
|
||||
# directory; not sure why but this works around it
|
||||
"${cudatoolkit}/targets/${stdenv.system}"
|
||||
# These are necessary at build time and run time.
|
||||
cuda_libs_joined = symlinkJoin {
|
||||
name = "cuda-joined";
|
||||
paths = with cudaPackagesGoogle; [
|
||||
cuda_cudart.lib # libcudart.so
|
||||
cuda_cudart.static # libcudart_static.a
|
||||
cuda_cupti.lib # libcupti.so
|
||||
libcublas.lib # libcublas.so
|
||||
libcufft.lib # libcufft.so
|
||||
libcurand.lib # libcurand.so
|
||||
libcusolver.lib # libcusolver.so
|
||||
libcusparse.lib # libcusparse.so
|
||||
];
|
||||
};
|
||||
# These are only necessary at build time.
|
||||
cuda_build_deps_joined = symlinkJoin {
|
||||
name = "cuda-build-deps-joined";
|
||||
paths = with cudaPackagesGoogle; [
|
||||
cuda_libs_joined
|
||||
|
||||
# Binaries
|
||||
cudaPackagesGoogle.cuda_nvcc.bin # nvcc
|
||||
|
||||
# Headers
|
||||
cuda_cccl.dev # block_load.cuh
|
||||
cuda_cudart.dev # cuda.h
|
||||
cuda_cupti.dev # cupti.h
|
||||
cuda_nvcc.dev # See https://github.com/google/jax/issues/19811
|
||||
cuda_nvml_dev # nvml.h
|
||||
cuda_nvtx.dev # nvToolsExt.h
|
||||
libcublas.dev # cublas_api.h
|
||||
libcufft.dev # cufft.h
|
||||
libcurand.dev # curand.h
|
||||
libcusolver.dev # cusolver_common.h
|
||||
libcusparse.dev # cusparse.h
|
||||
];
|
||||
};
|
||||
|
||||
cudatoolkit_cc_joined = symlinkJoin {
|
||||
name = "${cudatoolkit.cc.name}-merged";
|
||||
backend_cc_joined = symlinkJoin {
|
||||
name = "cuda-cc-joined";
|
||||
paths = [
|
||||
backendStdenv.cc
|
||||
effectiveStdenv.cc
|
||||
binutils.bintools # for ar, dwp, nm, objcopy, objdump, strip
|
||||
];
|
||||
};
|
||||
@ -137,8 +169,44 @@ let
|
||||
|
||||
arch =
|
||||
# KeyError: ('Linux', 'arm64')
|
||||
if stdenv.hostPlatform.isLinux && stdenv.hostPlatform.linuxArch == "arm64" then "aarch64"
|
||||
else stdenv.hostPlatform.linuxArch;
|
||||
if effectiveStdenv.hostPlatform.isLinux && effectiveStdenv.hostPlatform.linuxArch == "arm64" then "aarch64"
|
||||
else effectiveStdenv.hostPlatform.linuxArch;
|
||||
|
||||
xla = effectiveStdenv.mkDerivation {
|
||||
pname = "xla-src";
|
||||
version = "unstable";
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "openxla";
|
||||
repo = "xla";
|
||||
# Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl.
|
||||
rev = "12eee889e1f2ad41e27d7b0e970cb92d282d3ec5";
|
||||
hash = "sha256-68kjjgwYjRlcT0TVJo9BN6s+WTkdu5UMJqQcfHpBT90=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
# Resolves "could not convert ‘result’ from ‘SmallVector<[...],6>’ to
|
||||
# ‘SmallVector<[...],4>’" compilation error. See https://github.com/google/jax/issues/19814#issuecomment-1945141259.
|
||||
(fetchpatch {
|
||||
url = "https://github.com/openxla/xla/commit/7a614cd346594fc7ea2fe75570c9c53a4a444f60.patch";
|
||||
hash = "sha256-RtuQTH8wzNiJcOtISLhf+gMlH1gg8hekvxEB+4wX6BM=";
|
||||
})
|
||||
];
|
||||
|
||||
dontBuild = true;
|
||||
|
||||
# This is necessary for patchShebangs to know the right path to use.
|
||||
nativeBuildInputs = [ python ];
|
||||
|
||||
# Main culprits we're targeting are third_party/tsl/third_party/gpus/crosstool/clang/bin/*.tpl
|
||||
postPatch = ''
|
||||
patchShebangs .
|
||||
'';
|
||||
|
||||
installPhase = ''
|
||||
cp -r . $out
|
||||
'';
|
||||
};
|
||||
|
||||
bazel-build = buildBazelPackage rec {
|
||||
name = "bazel-build-${pname}-${version}";
|
||||
@ -162,7 +230,7 @@ let
|
||||
wheel
|
||||
build
|
||||
which
|
||||
] ++ lib.optionals stdenv.isDarwin [
|
||||
] ++ lib.optionals effectiveStdenv.isDarwin [
|
||||
cctools
|
||||
];
|
||||
|
||||
@ -181,15 +249,13 @@ let
|
||||
six
|
||||
snappy
|
||||
zlib
|
||||
] ++ lib.optionals cudaSupport [
|
||||
cudatoolkit
|
||||
cudnn
|
||||
] ++ lib.optionals stdenv.isDarwin [
|
||||
] ++ lib.optionals effectiveStdenv.isDarwin [
|
||||
IOKit
|
||||
] ++ lib.optionals (!stdenv.isDarwin) [
|
||||
] ++ lib.optionals (!effectiveStdenv.isDarwin) [
|
||||
nsync
|
||||
];
|
||||
|
||||
# We don't want to be quite so picky regarding bazel version
|
||||
postPatch = ''
|
||||
rm -f .bazelversion
|
||||
'';
|
||||
@ -204,50 +270,80 @@ let
|
||||
|
||||
removeRulesCC = false;
|
||||
|
||||
GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin";
|
||||
GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${cudatoolkit_cc_joined}/bin/gcc";
|
||||
GCC_HOST_COMPILER_PREFIX = lib.optionalString cudaSupport "${backend_cc_joined}/bin";
|
||||
GCC_HOST_COMPILER_PATH = lib.optionalString cudaSupport "${backend_cc_joined}/bin/gcc";
|
||||
|
||||
# The version is automatically set to ".dev" if this variable is not set.
|
||||
# https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
|
||||
JAXLIB_RELEASE = "1";
|
||||
|
||||
preConfigure = ''
|
||||
# dummy ldconfig
|
||||
mkdir dummy-ldconfig
|
||||
echo "#!${stdenv.shell}" > dummy-ldconfig/ldconfig
|
||||
chmod +x dummy-ldconfig/ldconfig
|
||||
export PATH="$PWD/dummy-ldconfig:$PATH"
|
||||
cat <<CFG > ./.jax_configure.bazelrc
|
||||
build --strategy=Genrule=standalone
|
||||
build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
|
||||
build --action_env=PYENV_ROOT
|
||||
build --python_path="${python}/bin/python"
|
||||
build --distinct_host_configuration=false
|
||||
build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
|
||||
'' + lib.optionalString (stdenv.hostPlatform.avxSupport && stdenv.hostPlatform.isUnix) ''
|
||||
build --config=avx_posix
|
||||
'' + lib.optionalString mklSupport ''
|
||||
build --config=mkl_open_source_only
|
||||
'' + lib.optionalString cudaSupport ''
|
||||
build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
|
||||
build --action_env CUDNN_INSTALL_PATH="${cudnn}"
|
||||
build --action_env TF_CUDA_PATHS="${cudatoolkit_joined},${cudnn},${nccl}"
|
||||
build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudatoolkit.version}"
|
||||
build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
|
||||
build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}"
|
||||
'' + ''
|
||||
CFG
|
||||
'';
|
||||
preConfigure =
|
||||
# Dummy ldconfig to work around "Can't open cache file /nix/store/<hash>-glibc-2.38-44/etc/ld.so.cache" error
|
||||
''
|
||||
mkdir dummy-ldconfig
|
||||
echo "#!${effectiveStdenv.shell}" > dummy-ldconfig/ldconfig
|
||||
chmod +x dummy-ldconfig/ldconfig
|
||||
export PATH="$PWD/dummy-ldconfig:$PATH"
|
||||
'' +
|
||||
|
||||
# Construct .jax_configure.bazelrc. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L259-L345
|
||||
# for more info. We assume
|
||||
# * `cpu = None`
|
||||
# * `enable_nccl = True`
|
||||
# * `target_cpu_features = "release"`
|
||||
# * `rocm_amdgpu_targets = None`
|
||||
# * `enable_rocm = False`
|
||||
# * `build_gpu_plugin = False`
|
||||
# * `use_clang = False` (Should we use `effectiveStdenv.cc.isClang` instead?)
|
||||
#
|
||||
# Note: We should try just running https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L259-L266
|
||||
# instead of duplicating the logic here. Perhaps we can leverage the
|
||||
# `--configure_only` flag (https://github.com/google/jax/blob/ceb198582b62b9e6f6bdf20ab74839b0cf1db16e/build/build.py#L544-L548)?
|
||||
''
|
||||
cat <<CFG > ./.jax_configure.bazelrc
|
||||
build --strategy=Genrule=standalone
|
||||
build --repo_env PYTHON_BIN_PATH="${python}/bin/python"
|
||||
build --action_env=PYENV_ROOT
|
||||
build --python_path="${python}/bin/python"
|
||||
build --distinct_host_configuration=false
|
||||
build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
|
||||
'' + lib.optionalString cudaSupport ''
|
||||
build --config=cuda
|
||||
build --action_env CUDA_TOOLKIT_PATH="${cuda_build_deps_joined}"
|
||||
build --action_env CUDNN_INSTALL_PATH="${cudnn}"
|
||||
build --action_env TF_CUDA_PATHS="${cuda_build_deps_joined},${cudnn},${nccl}"
|
||||
build --action_env TF_CUDA_VERSION="${lib.versions.majorMinor cudaVersion}"
|
||||
build --action_env TF_CUDNN_VERSION="${lib.versions.major cudnn.version}"
|
||||
build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="${builtins.concatStringsSep "," cudaFlags.realArches}"
|
||||
'' +
|
||||
# Note that upstream conditions this on `wheel_cpu == "x86_64"`. We just
|
||||
# rely on `effectiveStdenv.hostPlatform.avxSupport` instead. So far so
|
||||
# good. See https://github.com/google/jax/blob/b9824d7de3cb30f1df738cc42e486db3e9d915ff/build/build.py#L322
|
||||
# for upstream's version.
|
||||
lib.optionalString (effectiveStdenv.hostPlatform.avxSupport && effectiveStdenv.hostPlatform.isUnix) ''
|
||||
build --config=avx_posix
|
||||
'' + lib.optionalString mklSupport ''
|
||||
build --config=mkl_open_source_only
|
||||
'' +
|
||||
''
|
||||
CFG
|
||||
'';
|
||||
|
||||
# Make sure Bazel knows about our configuration flags during fetching so that the
|
||||
# relevant dependencies can be downloaded.
|
||||
bazelFlags = [
|
||||
"-c opt"
|
||||
] ++ lib.optionals stdenv.cc.isClang [
|
||||
# See https://bazel.build/external/advanced#overriding-repositories for
|
||||
# information on --override_repository flag.
|
||||
"--override_repository=xla=${xla}"
|
||||
] ++ lib.optionals effectiveStdenv.cc.isClang [
|
||||
# bazel depends on the compiler frontend automatically selecting these flags based on file
|
||||
# extension but our clang doesn't.
|
||||
# https://github.com/NixOS/nixpkgs/issues/150655
|
||||
"--cxxopt=-x" "--cxxopt=c++" "--host_cxxopt=-x" "--host_cxxopt=c++"
|
||||
"--cxxopt=-x"
|
||||
"--cxxopt=c++"
|
||||
"--host_cxxopt=-x"
|
||||
"--host_cxxopt=c++"
|
||||
];
|
||||
|
||||
# We intentionally overfetch so we can share the fetch derivation across all the different configurations
|
||||
@ -257,40 +353,34 @@ let
|
||||
bazelTargets = [ bazelRunTarget "@mkl_dnn_v1//:mkl_dnn" ];
|
||||
bazelFlags = bazelFlags ++ [
|
||||
"--config=avx_posix"
|
||||
"--config=mkl_open_source_only"
|
||||
] ++ lib.optionals cudaSupport [
|
||||
# ideally we'd add this unconditionally too, but it doesn't work on darwin
|
||||
# we make this conditional on `cudaSupport` instead of the system, so that the hash for both
|
||||
# the cuda and the non-cuda deps can be computed on linux, since a lot of contributors don't
|
||||
# have access to darwin machines
|
||||
"--config=cuda"
|
||||
] ++ [
|
||||
"--config=mkl_open_source_only"
|
||||
];
|
||||
|
||||
sha256 = (if cudaSupport then {
|
||||
x86_64-linux = "sha256-c0avcURLAYNiLASjIeu5phXX3ze5TR812SW5SCG/iwk=";
|
||||
x86_64-linux = "sha256-IEKoHjCOtKZKvU/DUUjbvXldORFJuyO1R3F6CZZDXxM=";
|
||||
} else {
|
||||
x86_64-linux = "sha256-1hrQ9ehFy3vBJxKNUzi/T0l+eZxo26Th7i5VRd/9U+0=";
|
||||
aarch64-linux = "sha256-3QVYJOj1lNHgYVV9rOzVdfhq5q6GDwpcWCjKNrSZ4aU=";
|
||||
}).${stdenv.system} or (throw "jaxlib: unsupported system: ${stdenv.system}");
|
||||
x86_64-linux = "sha256-IE4+Tk4llo85u3NjakvY04tPw4R1bidyecPpQ4gknR8=";
|
||||
aarch64-linux = "sha256-NehnpA4m+Fynvh0S6WKy/v9ab81487NE9ahvbS70wjY=";
|
||||
}).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}");
|
||||
};
|
||||
|
||||
buildAttrs = {
|
||||
outputs = [ "out" ];
|
||||
|
||||
TF_SYSTEM_LIBS = lib.concatStringsSep "," (tf_system_libs ++ lib.optionals (!stdenv.isDarwin) [
|
||||
TF_SYSTEM_LIBS = lib.concatStringsSep "," (tf_system_libs ++ lib.optionals (!effectiveStdenv.isDarwin) [
|
||||
"nsync" # fails to build on darwin
|
||||
]);
|
||||
|
||||
# Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
|
||||
# 1) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
|
||||
# loading multiple extensions in the same python program due to duplicate protobuf DBs.
|
||||
# 2) Patch python path in the compiler driver.
|
||||
preBuild = lib.optionalString cudaSupport ''
|
||||
patchShebangs ../output/external/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||
'' + lib.optionalString stdenv.isDarwin ''
|
||||
# Framework search paths aren't added by bintools hook
|
||||
# https://github.com/NixOS/nixpkgs/pull/41914
|
||||
# Note: we cannot do most of this patching at `patch` phase as the deps
|
||||
# are not available yet. Framework search paths aren't added by bintools
|
||||
# hook. See https://github.com/NixOS/nixpkgs/pull/41914.
|
||||
preBuild = lib.optionalString effectiveStdenv.isDarwin ''
|
||||
export NIX_LDFLAGS+=" -F${IOKit}/Library/Frameworks"
|
||||
substituteInPlace ../output/external/rules_cc/cc/private/toolchain/osx_cc_wrapper.sh.tpl \
|
||||
--replace "/usr/bin/install_name_tool" "${cctools}/bin/install_name_tool"
|
||||
@ -302,13 +392,13 @@ let
|
||||
inherit meta;
|
||||
};
|
||||
platformTag =
|
||||
if stdenv.hostPlatform.isLinux then
|
||||
if effectiveStdenv.hostPlatform.isLinux then
|
||||
"manylinux2014_${arch}"
|
||||
else if stdenv.system == "x86_64-darwin" then
|
||||
else if effectiveStdenv.system == "x86_64-darwin" then
|
||||
"macosx_10_9_${arch}"
|
||||
else if stdenv.system == "aarch64-darwin" then
|
||||
else if effectiveStdenv.system == "aarch64-darwin" then
|
||||
"macosx_11_0_${arch}"
|
||||
else throw "Unsupported target platform: ${stdenv.hostPlatform}";
|
||||
else throw "Unsupported target platform: ${effectiveStdenv.hostPlatform}";
|
||||
|
||||
in
|
||||
buildPythonPackage {
|
||||
@ -319,20 +409,18 @@ buildPythonPackage {
|
||||
let cp = "cp${builtins.replaceStrings ["."] [""] python.pythonVersion}";
|
||||
in "${bazel-build}/jaxlib-${version}-${cp}-${cp}-${platformTag}.whl";
|
||||
|
||||
# Note that cudatoolkit is necessary since jaxlib looks for "ptxas" in $PATH.
|
||||
# See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
|
||||
# more info.
|
||||
# Note that jaxlib looks for "ptxas" in $PATH. See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
|
||||
# for more info.
|
||||
postInstall = lib.optionalString cudaSupport ''
|
||||
mkdir -p $out/bin
|
||||
ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas
|
||||
ln -s ${cudaPackagesGoogle.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas
|
||||
|
||||
find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
|
||||
addOpenGLRunpath "$lib"
|
||||
patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
|
||||
patchelf --add-rpath "${lib.makeLibraryPath [cuda_libs_joined cudnn nccl]}" "$lib"
|
||||
done
|
||||
'';
|
||||
|
||||
nativeBuildInputs = lib.optional cudaSupport addOpenGLRunpath;
|
||||
nativeBuildInputs = lib.optionals cudaSupport [ autoAddOpenGLRunpathHook ];
|
||||
|
||||
propagatedBuildInputs = [
|
||||
absl-py
|
||||
|
Loading…
Reference in New Issue
Block a user