python311Packages.flax: dependencies and tests check up

This commit is contained in:
Gaetan Lepage 2023-10-27 19:18:12 +02:00
parent e556bb0b67
commit 15cd73f569

View File

@ -4,14 +4,17 @@
, jaxlib
, pythonRelaxDepsHook
, setuptools-scm
, cloudpickle
, jax
, matplotlib
, msgpack
, numpy
, optax
, pyyaml
, rich
, tensorstore
, typing-extensions
, matplotlib
, cloudpickle
, einops
, keras
, pytest-xdist
, pytestCheckHook
@ -37,24 +40,27 @@ buildPythonPackage rec {
];
propagatedBuildInputs = [
cloudpickle
jax
matplotlib
msgpack
numpy
optax
pyyaml
rich
tensorstore
typing-extensions
];
# See https://github.com/google/flax/pull/2882.
pythonRemoveDeps = [ "orbax" ];
passthru.optional-dependencies = {
all = [ matplotlib ];
};
pythonImportsCheck = [
"flax"
];
nativeCheckInputs = [
cloudpickle
einops
keras
pytest-xdist
pytestCheckHook
@ -85,22 +91,6 @@ buildPythonPackage rec {
"tests/checkpoints_test.py"
];
disabledTests = [
# See https://github.com/google/flax/issues/2554.
"test_async_save_checkpoints"
"test_jax_array0"
"test_jax_array1"
"test_keep0"
"test_keep1"
"test_optimized_lstm_cell_matches_regular"
"test_overwrite_checkpoints"
"test_save_restore_checkpoints_target_empty"
"test_save_restore_checkpoints_target_none"
"test_save_restore_checkpoints_target_singular"
"test_save_restore_checkpoints_w_float_steps"
"test_save_restore_checkpoints"
];
meta = with lib; {
description = "Neural network library for JAX";
homepage = "https://github.com/google/flax";