Merge pull request #291705 from GaetanLepage/jax

python311Packages.{jax,jaxlib,jaxlib-bin}: 0.4.24 -> 0.4.28
Samuel Ainsworth, 2024-05-13 12:06:07 -04:00 (committed by GitHub)
commit 3a993d3244
15 changed files with 128 additions and 100 deletions

View File

@ -16,7 +16,7 @@
buildPythonPackage rec {
pname = "blackjax";
version = "1.2.0";
version = "1.2.1";
pyproject = true;
disabled = pythonOlder "3.9";
@ -25,7 +25,7 @@ buildPythonPackage rec {
owner = "blackjax-devs";
repo = "blackjax";
rev = "refs/tags/${version}";
hash = "sha256-vXyxK3xALKG61YGK7fmoqQNGfOiagHFrvnU02WKZThw=";
hash = "sha256-VoWBCjFMyE5LVJyf7du/pKlnvDHj22lguiP6ZUzH9ak=";
};
build-system = [
@ -56,6 +56,10 @@ buildPythonPackage rec {
disabledTests = [
# too slow
"test_adaptive_tempered_smc"
] ++ lib.optionals (stdenv.isLinux && stdenv.isAarch64) [
# Numerical test (AssertionError)
# https://github.com/blackjax-devs/blackjax/issues/668
"test_chees_adaptation"
];
pythonImportsCheck = [

View File

@ -48,8 +48,21 @@ buildPythonPackage rec {
pythonImportsCheck = [ "equinox" ];
disabledTests = [
# Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted.
"test_tracetime"
# For simplicity, JAX has removed its internal frames from the traceback of the following exception.
# https://github.com/patrick-kidger/equinox/issues/716
"test_abstract"
"test_complicated"
"test_grad"
"test_jvp"
"test_mlp"
"test_num_traces"
"test_pytree_in"
"test_simple"
"test_vmap"
# AssertionError: assert 'foo:\n pri...pe=float32)\n' == 'foo:\n pri...pe=float32)\n'
# Also reported in patrick-kidger/equinox#716
"test_backward_nan"
];
meta = with lib; {
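
Note (not part of this commit): the "removed its internal frames" message quoted above is produced by JAX's traceback filtering, which hides JAX-internal stack frames when an exception escapes a transformed function. A minimal Python sketch, assuming only the standard `jax_traceback_filtering` config option, of what that filtering does and how to turn it off when debugging failures like the ones disabled here:

import traceback
import jax
import jax.numpy as jnp

# Show full tracebacks instead of JAX's default frame filtering
# ("For simplicity, JAX has removed its internal frames ...").
jax.config.update("jax_traceback_filtering", "off")

@jax.jit
def f(x):
    if x.ndim != 1:  # shapes are static, so this check runs at trace time
        raise ValueError("expected a 1-D array")
    return 2 * x

try:
    f(jnp.zeros((2, 2)))
except ValueError:
    traceback.print_exc()  # now includes JAX-internal frames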

View File

@ -25,7 +25,7 @@
buildPythonPackage rec {
pname = "flax";
version = "0.8.2";
version = "0.8.3";
pyproject = true;
disabled = pythonOlder "3.9";
@ -34,16 +34,16 @@ buildPythonPackage rec {
owner = "google";
repo = "flax";
rev = "refs/tags/v${version}";
hash = "sha256-UABgJGe1grUSkwOJpjeIoFqhXsqG//HlC1YyYPxXV+g=";
hash = "sha256-uDGTyksUZTTL6FiTJP+qteFLOjr75dcTj9yRJ6Jm8xU=";
};
nativeBuildInputs = [
build-system = [
jaxlib
pythonRelaxDepsHook
setuptools-scm
];
propagatedBuildInputs = [
dependencies = [
jax
msgpack
numpy

View File

@ -29,7 +29,7 @@ let
in
buildPythonPackage rec {
pname = "jax";
version = "0.4.25";
version = "0.4.28";
pyproject = true;
disabled = pythonOlder "3.9";
@ -39,7 +39,7 @@ buildPythonPackage rec {
repo = "jax";
# google/jax contains tags for jax and jaxlib. Only use jax tags!
rev = "refs/tags/jax-v${version}";
hash = "sha256-poQQo2ZgEhPYzK3aCs+BjaHTNZbezJAECd+HOdY1Yok=";
hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
};
nativeBuildInputs = [
@ -81,6 +81,14 @@ buildPythonPackage rec {
"tests/"
];
# Prevents `tests/export_back_compat_test.py::CompatTest::test_*` tests from failing on darwin with
# PermissionError: [Errno 13] Permission denied: '/tmp/back_compat_testdata/test_*.py'
# See https://github.com/google/jax/blob/jaxlib-v0.4.27/jax/_src/internal_test_util/export_back_compat_test_util.py#L240-L241
# NOTE: this doesn't seem to be an issue on linux
preCheck = lib.optionalString stdenv.isDarwin ''
export TEST_UNDECLARED_OUTPUTS_DIR=$(mktemp -d)
'';
disabledTests = [
# Exceeds tolerance when the machine is busy
"test_custom_linear_solve_aux"

View File

@ -20,17 +20,17 @@
, stdenv
# Options:
, cudaSupport ? config.cudaSupport
, cudaPackagesGoogle
, cudaPackages
}:
let
inherit (cudaPackagesGoogle) cudaVersion;
inherit (cudaPackages) cudaVersion;
version = "0.4.24";
version = "0.4.28";
inherit (python) pythonVersion;
cudaLibPath = lib.makeLibraryPath (with cudaPackagesGoogle; [
cudaLibPath = lib.makeLibraryPath (with cudaPackages; [
cuda_cudart.lib # libcudart.so
cuda_cupti.lib # libcupti.so
cudnn.lib # libcudnn.so
@ -56,65 +56,65 @@ let
"3.9-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp39";
hash = "sha256-6P5ArMoLZiUkHUoQ/mJccbNj5/7el/op+Qo6cGQ33xE=";
hash = "sha256-Slbr8FtKTBeRaZ2HTgcvP4CPCYa0AQsU+1SaackMqdw=";
};
"3.9-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp39";
hash = "sha256-23JQZRwMLtt7sK/JlCBqqRyfTVIAVJFN2sL+nAkQgvU=";
hash = "sha256-sBVi7IrXVxm30DiXUkiel+trTctMjBE75JFjTVKCrTw=";
};
"3.9-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp39";
hash = "sha256-OgMedn9GHGs5THZf3pkP3Aw/jJ0vL5qK1b+Lzf634Ik=";
hash = "sha256-T5jMg3srbG3P4Kt/+esQkxSSCUYRmqOvn6oTlxj/J4c=";
};
"3.10-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp310";
hash = "sha256-/VwUIIa7mTs/wLz0ArsEfNrz2pGriVVT5GX9XRFRxfY=";
hash = "sha256-47zcb45g+FVPQVwU2TATTmAuPKM8OOVGJ0/VRfh1dps=";
};
"3.10-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp310";
hash = "sha256-LgICOyDGts840SQQJh+yOMobMASb62llvJjpGvhzrSw=";
hash = "sha256-8Djmi9ENGjVUcisLvjbmpEg4RDenWqnSg/aW8O2fjAk=";
};
"3.10-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp310";
hash = "sha256-vhyULw+zBpz1UEi2tqgBMQEzY9a6YBgEIg6A4PPh3bQ=";
hash = "sha256-pCHSN/jCXShQFm0zRgPGc925tsJvUrxJZwS4eCKXvWY=";
};
"3.11-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp311";
hash = "sha256-VJO/VVwBFkOEtq4y/sLVgAV8Cung01JULiuT6W96E/8=";
hash = "sha256-Rc4PPIQM/4I2z/JsN/Jsn/B4aV+T4MFiwyDCgfUEEnU=";
};
"3.11-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp311";
hash = "sha256-VtuwXxurpSp1KI8ty1bizs5cdy8GEBN2MgS227sOCmE=";
hash = "sha256-eThX+vN/Nxyv51L+pfyBH0NeQ7j7S1AgWERKf17M+Ck=";
};
"3.11-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp311";
hash = "sha256-4Dj5dEGKb9hpg3HlVogNO1Gc9UibJhy1eym2mjivxAQ=";
hash = "sha256-L/gpDtx7ksfq5SUX9lSSYz4mey6QZ7rT5MMj0hPnfPU=";
};
"3.12-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp312";
hash = "sha256-TlrGVtb3NTLmhnILWPLJR+jISCZ5SUV4wxNFpSfkCBo=";
hash = "sha256-RqGqhX9P7uikP8upXA4Kti1AwmzJcwtsaWVZCLo1n40=";
};
"3.12-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp312";
hash = "sha256-FIwK5CGykQjteuWzLZnbtAggIxLQeGV96bXlZGEytN0=";
hash = "sha256-jdi//jhTcC9jzZJNoO4lc0pNGc1ckmvgM9dyun0cF10=";
};
"3.12-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp312";
hash = "sha256-9/jw/wr6oUD9pOadVAaMRL086iVMUXwVgnUMcG1UNvE=";
hash = "sha256-1sCaVFMpciRhrwVuc1FG0sjHTCKsdCaoRetp8ya096A=";
};
};
@ -130,35 +130,19 @@ let
gpuSrcs = {
"cuda12.2-3.9" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl";
hash = "sha256-xdJKLPtx+CIza2CrWKM3M0cZJzyNFVTTTsvlgh38bfM=";
hash = "sha256-d8LIl22gIvmWfoyKfXKElZJXicPQIZxdS4HumhwQGCw=";
};
"cuda12.2-3.10" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-QCjrOczD2mp+CDwVXBc0/4rJnAizeV62AK0Dpx9X6TE=";
hash = "sha256-PXtWv+UEcMWF8LhWe6Z1UGkf14PG3dkJ0Iop0LiimnQ=";
};
"cuda12.2-3.11" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl";
hash = "sha256-Ipy3vk1yUplpNzECAFt63aOIhgEWgXG7hkoeTIk9bQQ=";
hash = "sha256-QO2WSOzmJ48VaCha596mELiOfPsAGLpGctmdzcCHE/o=";
};
"cuda12.2-3.12" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl";
hash = "sha256-LSnZHaUga/8Z65iKXWBnZDk4yUpNykFTu3vukCchO6Q=";
};
"cuda11.8-3.9" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl";
hash = "sha256-UmyugL0VjlXkiD7fuDPWgW8XUpr/QaP5ggp6swoZTzU=";
};
"cuda11.8-3.10" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-luKULEiV1t/sO6eckDxddJTiOFa0dtJeDlrvp+WYmHk=";
};
"cuda11.8-3.11" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl";
hash = "sha256-4+uJ8Ij6mFGEmjFEgi3fLnSLZs+v18BRoOt7mZuqydw=";
};
"cuda11.8-3.12" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp312-cp312-manylinux2014_x86_64.whl";
hash = "sha256-bUDFb94Ar/65SzzR9RLIs/SL/HdjaPT1Su5whmjkS00=";
hash = "sha256-ixWMaIChy4Ammsn23/3cCoala0lFibuUxyUr3tjfFKU=";
};
};
@ -213,7 +197,7 @@ buildPythonPackage {
# for more info.
postInstall = lib.optional cudaSupport ''
mkdir -p $out/${python.sitePackages}/jaxlib/cuda/bin
ln -s ${lib.getExe' cudaPackagesGoogle.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas
ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jaxlib/cuda/bin/ptxas
'';
inherit (jaxlib-build) pythonImportsCheck;
@ -227,7 +211,7 @@ buildPythonPackage {
platforms = [ "aarch64-darwin" "x86_64-linux" "x86_64-darwin" ];
broken =
!(cudaSupport -> lib.versionAtLeast cudaVersion "11.1")
|| !(cudaSupport -> lib.versionAtLeast cudaPackagesGoogle.cudnn.version "8.2")
|| !(cudaSupport -> lib.versionAtLeast cudaPackages.cudnn.version "8.2")
|| !(cudaSupport -> stdenv.isLinux)
|| !(cudaSupport -> (gpuSrcs ? "cuda${cudaVersion}-${pythonVersion}"))
# Fails at pythonImportsCheckPhase:

View File

@ -13,7 +13,6 @@
, curl
, cython
, fetchFromGitHub
, fetchpatch
, git
, IOKit
, jsoncpp
@ -45,22 +44,22 @@
, config
# CUDA flags:
, cudaSupport ? config.cudaSupport
, cudaPackagesGoogle
, cudaPackages
# MKL:
, mklSupport ? true
}@inputs:
let
inherit (cudaPackagesGoogle) cudaFlags cudaVersion cudnn nccl;
inherit (cudaPackages) cudaFlags cudaVersion cudnn nccl;
pname = "jaxlib";
version = "0.4.24";
version = "0.4.28";
# It's necessary to consistently use backendStdenv when building with CUDA
# support, otherwise we get libstdc++ errors downstream
stdenv = throw "Use effectiveStdenv instead";
effectiveStdenv = if cudaSupport then cudaPackagesGoogle.backendStdenv else inputs.stdenv;
effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else inputs.stdenv;
meta = with lib; {
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
@ -78,7 +77,7 @@ let
# These are necessary at build time and run time.
cuda_libs_joined = symlinkJoin {
name = "cuda-joined";
paths = with cudaPackagesGoogle; [
paths = with cudaPackages; [
cuda_cudart.lib # libcudart.so
cuda_cudart.static # libcudart_static.a
cuda_cupti.lib # libcupti.so
@ -92,11 +91,11 @@ let
# These are only necessary at build time.
cuda_build_deps_joined = symlinkJoin {
name = "cuda-build-deps-joined";
paths = with cudaPackagesGoogle; [
paths = with cudaPackages; [
cuda_libs_joined
# Binaries
cudaPackagesGoogle.cuda_nvcc.bin # nvcc
cudaPackages.cuda_nvcc.bin # nvcc
# Headers
cuda_cccl.dev # block_load.cuh
@ -181,19 +180,10 @@ let
owner = "openxla";
repo = "xla";
# Update this according to https://github.com/google/jax/blob/jaxlib-v${version}/third_party/xla/workspace.bzl.
rev = "12eee889e1f2ad41e27d7b0e970cb92d282d3ec5";
hash = "sha256-68kjjgwYjRlcT0TVJo9BN6s+WTkdu5UMJqQcfHpBT90=";
rev = "e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4";
hash = "sha256-ZhgMIVs3Z4dTrkRWDqaPC/i7yJz2dsYXrZbjzqvPX3E=";
};
patches = [
# Resolves "could not convert result from SmallVector<[...],6> to
# SmallVector<[...],4>" compilation error. See https://github.com/google/jax/issues/19814#issuecomment-1945141259.
(fetchpatch {
url = "https://github.com/openxla/xla/commit/7a614cd346594fc7ea2fe75570c9c53a4a444f60.patch";
hash = "sha256-RtuQTH8wzNiJcOtISLhf+gMlH1gg8hekvxEB+4wX6BM=";
})
];
dontBuild = true;
# This is necessary for patchShebangs to know the right path to use.
@ -220,7 +210,7 @@ let
repo = "jax";
# google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
rev = "refs/tags/${pname}-v${version}";
hash = "sha256-hmx7eo3pephc6BQfoJ3U0QwWBWmhkAc+7S4QmW32qQs=";
hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
};
nativeBuildInputs = [
@ -364,10 +354,10 @@ let
];
sha256 = (if cudaSupport then {
x86_64-linux = "sha256-8JilAoTbqOjOOJa/Zc/n/quaEDcpdcLXCNb34mfB+OM=";
x86_64-linux = "sha256-VGNMf5/DgXbgsu1w5J1Pmrukw+7UO31BNU+crKVsX5k=";
} else {
x86_64-linux = "sha256-iqS+I1FQLNWXNMsA20cJp7YkyGUeshee5b2QfRBNZtk=";
aarch64-linux = "sha256-qmJ0Fm/VGMTmko4PhKs1P8/GLEJmVxb8xg+ss/HsakY==";
x86_64-linux = "sha256-uOoAyMBLHPX6jzdN43b5wZV5eW0yI8sCDD7BSX2h4oQ=";
aarch64-linux = "sha256-+SnGKY9LIT1Qhu/x6Uh7sHRaAEjlc//qyKj1m4t16PA=";
}).${effectiveStdenv.system} or (throw "jaxlib: unsupported system: ${effectiveStdenv.system}");
};
@ -414,7 +404,7 @@ buildPythonPackage {
# for more info.
postInstall = lib.optionalString cudaSupport ''
mkdir -p $out/bin
ln -s ${cudaPackagesGoogle.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas
ln -s ${cudaPackages.cuda_nvcc.bin}/bin/ptxas $out/bin/ptxas
find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
patchelf --add-rpath "${lib.makeLibraryPath [cuda_libs_joined cudnn nccl]}" "$lib"
@ -423,7 +413,7 @@ buildPythonPackage {
nativeBuildInputs = lib.optionals cudaSupport [ autoAddDriverRunpath ];
propagatedBuildInputs = [
dependencies = [
absl-py
curl
double-conversion

View File

@ -6,6 +6,7 @@
, fetchpatch
, pytest-xdist
, pytestCheckHook
, setuptools
, absl-py
, cvxpy
, jax
@ -20,7 +21,7 @@
buildPythonPackage rec {
pname = "jaxopt";
version = "0.8.3";
format = "setuptools";
pyproject = true;
disabled = pythonOlder "3.8";
@ -41,7 +42,11 @@ buildPythonPackage rec {
})
];
propagatedBuildInputs = [
build-system = [
setuptools
];
dependencies = [
absl-py
jax
jaxlib
@ -66,11 +71,20 @@ buildPythonPackage rec {
"jaxopt.tree_util"
];
disabledTests = lib.optionals (stdenv.isLinux && stdenv.isAarch64) [
disabledTests = [
# https://github.com/google/jaxopt/issues/592
"test_solve_sparse"
] ++ lib.optionals (stdenv.isLinux && stdenv.isAarch64) [
# https://github.com/google/jaxopt/issues/577
"test_binary_logit_log_likelihood"
"test_solve_sparse"
"test_logreg_with_intercept_manual_loop3"
# https://github.com/google/jaxopt/issues/593
# Makes the test suite crash
"test_dtype_consistency"
# AssertionError: Array(0.01411963, dtype=float32) not less than or equal to 0.01
"test_multiclass_logreg6"
];
meta = with lib; {

View File

@ -51,8 +51,10 @@ buildPythonPackage rec {
scipy
torch
tensorflow
jax
jaxlib
# Uncomment at next release (1.9.3)
# See https://github.com/wjakob/nanobind/issues/578
# jax
# jaxlib
];
meta = with lib; {

View File

@ -1,7 +1,6 @@
{ lib
, buildPythonPackage
, fetchFromGitHub
, fetchpatch
, jax
, jaxlib
, keras
@ -30,7 +29,12 @@ buildPythonPackage rec {
hash = "sha256-WD+pmR8cEay4iziRXqF3sHUzCMBjmLJ3wZ3iYOD+hzk=";
};
nativeBuildInputs = [
patches = [
# Issue reported upstream: https://github.com/google/objax/issues/270
./replace-deprecated-device_buffers.patch
];
build-system = [
setuptools
];
@ -40,7 +44,7 @@ buildPythonPackage rec {
jaxlib
];
propagatedBuildInputs = [
dependencies = [
jax
numpy
parameterized

View File

@ -0,0 +1,14 @@
diff --git a/objax/util/util.py b/objax/util/util.py
index c31a356..344cf9a 100644
--- a/objax/util/util.py
+++ b/objax/util/util.py
@@ -117,7 +117,8 @@ def get_local_devices():
if _local_devices is None:
x = jn.zeros((jax.local_device_count(), 1), dtype=jn.float32)
sharded_x = map_to_device(x)
- _local_devices = [b.device() for b in sharded_x.device_buffers]
+ device_buffers = [buf.data for buf in sharded_x.addressable_shards]
+ _local_devices = [list(b.devices())[0] for b in device_buffers]
return _local_devices
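
Note (not part of this commit): a minimal Python sketch of the API migration this patch performs, assuming any JAX 0.4.x install (CPU is fine) where pmap output is a `jax.Array` exposing `addressable_shards` rather than the removed `device_buffers`:

import jax
import jax.numpy as jnp

x = jnp.zeros((jax.local_device_count(), 1), dtype=jnp.float32)
sharded_x = jax.pmap(lambda v: v)(x)  # one shard per local device

# Deprecated/removed API that objax still used:
#   local_devices = [buf.device() for buf in sharded_x.device_buffers]

# Replacement applied by the patch: each addressable shard carries its data as
# a single-device jax.Array, whose devices() set holds exactly one device.
local_devices = [list(shard.data.devices())[0]
                 for shard in sharded_x.addressable_shards]
print(local_devices)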

View File

@ -22,7 +22,7 @@
, tensorboard
, config
, cudaSupport ? config.cudaSupport
, cudaPackagesGoogle
, cudaPackages
, zlib
, python
, keras-applications
@ -43,7 +43,7 @@ assert ! (stdenv.isDarwin && cudaSupport);
let
packages = import ./binary-hashes.nix;
inherit (cudaPackagesGoogle) cudatoolkit cudnn;
inherit (cudaPackages) cudatoolkit cudnn;
in buildPythonPackage {
pname = "tensorflow" + lib.optionalString cudaSupport "-gpu";
inherit (packages) version;
@ -199,10 +199,6 @@ in buildPythonPackage {
"tensorflow.python.framework"
];
passthru = {
cudaPackages = cudaPackagesGoogle;
};
meta = with lib; {
description = "Computation using data flow graphs for scalable machine learning";
homepage = "http://tensorflow.org";

View File

@ -19,8 +19,8 @@
# https://groups.google.com/a/tensorflow.org/forum/#!topic/developers/iRCt5m4qUz0
, config
, cudaSupport ? config.cudaSupport
, cudaPackagesGoogle
, cudaCapabilities ? cudaPackagesGoogle.cudaFlags.cudaCapabilities
, cudaPackages
, cudaCapabilities ? cudaPackages.cudaFlags.cudaCapabilities
, mklSupport ? false, mkl
, tensorboardSupport ? true
# XLA without CUDA is broken
@ -50,15 +50,15 @@ let
# __ZN4llvm11SmallPtrSetIPKNS_10AllocaInstELj8EED1Ev in any of the
# translation units, so the build fails at link time
stdenv =
if cudaSupport then cudaPackagesGoogle.backendStdenv
if cudaSupport then cudaPackages.backendStdenv
else if originalStdenv.isDarwin then llvmPackages.stdenv
else originalStdenv;
inherit (cudaPackagesGoogle) cudatoolkit nccl;
inherit (cudaPackages) cudatoolkit nccl;
# use compatible cuDNN (https://www.tensorflow.org/install/source#gpu)
# cudaPackages.cudnn led to this:
# https://github.com/tensorflow/tensorflow/issues/60398
cudnnAttribute = "cudnn_8_6";
cudnn = cudaPackagesGoogle.${cudnnAttribute};
cudnn = cudaPackages.${cudnnAttribute};
gentoo-patches = fetchzip {
url = "https://dev.gentoo.org/~perfinion/patches/tensorflow-patches-2.12.0.tar.bz2";
hash = "sha256-SCRX/5/zML7LmKEPJkcM5Tebez9vv/gmE4xhT/jyqWs=";
@ -490,8 +490,8 @@ let
broken =
stdenv.isDarwin
|| !(xlaSupport -> cudaSupport)
|| !(cudaSupport -> builtins.hasAttr cudnnAttribute cudaPackagesGoogle)
|| !(cudaSupport -> cudaPackagesGoogle ? cudatoolkit);
|| !(cudaSupport -> builtins.hasAttr cudnnAttribute cudaPackages)
|| !(cudaSupport -> cudaPackages ? cudatoolkit);
} // lib.optionalAttrs stdenv.isDarwin {
timeout = 86400; # 24 hours
maxSilent = 14400; # 4h, double the default of 7200s
@ -594,7 +594,6 @@ in buildPythonPackage {
# Regression test for #77626 removed because not more `tensorflow.contrib`.
passthru = {
cudaPackages = cudaPackagesGoogle;
deps = bazel-build.deps;
libtensorflow = bazel-build.out;
};

View File

@ -3,7 +3,6 @@
recurseIntoAttrs,
cudaPackages,
cudaPackagesGoogle,
cudaPackages_10_0,
cudaPackages_10_1,

View File

@ -7125,10 +7125,6 @@ with pkgs;
cudaPackages_12_3 = callPackage ./cuda-packages.nix { cudaVersion = "12.3"; };
cudaPackages_12 = cudaPackages_12_2; # Latest supported by cudnn
# Use the older cudaPackages for tensorflow and jax, as determined by cudnn
# compatibility: https://www.tensorflow.org/install/source#gpu
cudaPackagesGoogle = cudaPackages_11;
cudaPackages = recurseIntoAttrs cudaPackages_12;
# TODO: move to alias

View File

@ -14885,6 +14885,8 @@ self: super: with self; {
tensorflow-bin = callPackage ../development/python-modules/tensorflow/bin.nix {
inherit (pkgs.config) cudaSupport;
# https://www.tensorflow.org/install/source#gpu
cudaPackages = pkgs.cudaPackages_11;
};
tensorflow-build = let
@ -14892,6 +14894,8 @@ self: super: with self; {
protobufTF = pkgs.protobuf_21.override {
abseil-cpp = pkgs.abseil-cpp_202301;
};
# https://www.tensorflow.org/install/source#gpu
cudaPackagesTF = pkgs.cudaPackages_11;
grpcTF = (pkgs.grpc.overrideAttrs (
oldAttrs: rec {
# nvcc fails on recent grpc versions, so we use the latest patch level
@ -14937,6 +14941,7 @@ self: super: with self; {
inherit (pkgs.darwin.apple_sdk.frameworks) Foundation Security;
flatbuffers-core = pkgs.flatbuffers;
flatbuffers-python = self.flatbuffers;
cudaPackages = compat.cudaPackagesTF;
protobuf-core = compat.protobufTF;
protobuf-python = compat.protobuf-pythonTF;
grpc = compat.grpcTF;