Merge pull request #297146 from GaetanLepage/chex

python3Packages.jax: towards fixing dependencies
This commit is contained in:
Someone 2024-04-02 00:51:34 +00:00 committed by GitHub
commit 4c2f2f1a53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 22 additions and 8 deletions

View File

@ -82,5 +82,8 @@ buildPythonPackage rec {
homepage = "https://github.com/deepmind/distrax";
license = licenses.asl20;
maintainers = with maintainers; [ onny ];
# Several tests fail with:
# AssertionError: [Chex] Assertion assert_type failed: Error in type compatibility check
broken = true;
};
}

View File

@ -47,6 +47,11 @@ buildPythonPackage rec {
pythonImportsCheck = [ "equinox" ];
disabledTests = [
# Failed: DID NOT WARN. No warnings of type (<class 'UserWarning'>,) were emitted.
"test_tracetime"
];
meta = with lib; {
description = "A JAX library based around a simple idea: represent parameterised functions (such as neural networks) as PyTrees";
changelog = "https://github.com/patrick-kidger/equinox/releases/tag/v${version}";

View File

@ -25,16 +25,16 @@
buildPythonPackage rec {
pname = "flax";
version = "0.8.1";
version = "0.8.2";
pyproject = true;
disabled = pythonOlder "3.8";
disabled = pythonOlder "3.9";
src = fetchFromGitHub {
owner = "google";
repo = "flax";
rev = "refs/tags/v${version}";
hash = "sha256-3UzMSJoKw+V1WLBJ+Zf7aF7CDNBsvWnRUfNgb3K4v1A=";
hash = "sha256-UABgJGe1grUSkwOJpjeIoFqhXsqG//HlC1YyYPxXV+g=";
};
nativeBuildInputs = [
@ -87,6 +87,7 @@ buildPythonPackage rec {
# `tensorflow_datasets`, `vocabulary`) so the benefits of trying to run them
# would be limited anyway.
"examples/*"
"flax/experimental/nnx/examples/*"
# See https://github.com/google/flax/issues/3232.
"tests/jax_utils_test.py"
# Requires tree

View File

@ -1,6 +1,7 @@
{ lib
, absl-py
, buildPythonPackage
, flit-core
, chex
, fetchFromGitHub
, jaxlib
@ -11,16 +12,16 @@
buildPythonPackage rec {
pname = "optax";
version = "0.2.1";
format = "setuptools";
version = "0.2.2";
pyproject = true;
disabled = pythonOlder "3.7";
disabled = pythonOlder "3.9";
src = fetchFromGitHub {
owner = "deepmind";
repo = pname;
repo = "optax";
rev = "refs/tags/v${version}";
hash = "sha256-vimsVZV5Z11euLxsu998pMQZ0hG3xl96D3h9iONtl/E=";
hash = "sha256-sBiKUuQR89mttc9Njrh1aeUJOYdlcF7Nlj3/+Y7OMb4=";
};
outputs = [
@ -28,6 +29,10 @@ buildPythonPackage rec {
"testsout"
];
nativeBuildInputs = [
flit-core
];
buildInputs = [
jaxlib
];