From b1f7627c71fb79e2c901946e880a6916bb461eaa Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 12 Jun 2024 14:39:11 -0700 Subject: [PATCH] [Rollback] Bumped the minimum ml_dtypes version to 0.4.0 Reverts e86c436e7f8e4e0546eff8bc2d3756a7c49dc83b PiperOrigin-RevId: 642741832 --- CHANGELOG.md | 6 ++++++ jax/_src/dtypes.py | 17 +++++++++++------ jax/_src/public_test_util.py | 9 +++++++++ jaxlib/setup.py | 2 +- setup.py | 2 +- 5 files changed, 28 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bf9f953eb..b9bf07e4c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,12 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.30 +* Changes + * JAX supports ml_dtypes >= 0.2. In 0.4.29 release, the ml_dtypes version was + bumped to 0.4.0 but this has been rolled back in this release to give users + of both TensorFlow and JAX more time to migrate to a newer TensorFlow + release. + ## jaxlib 0.4.30 ## jax 0.4.29 (June 10, 2024) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index f85e4833e..b4c9cfdf5 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -42,8 +42,8 @@ try: except: pass else: - if _ml_dtypes_version < (0, 4, 0): - raise ValueError("JAX requires ml_dtypes version 0.4.0 or newer; " + if _ml_dtypes_version < (0, 2, 0): + raise ValueError("JAX requires ml_dtypes version 0.2.0 or newer; " f"installed version is {ml_dtypes.__version__}.") export = set_module('jax.dtypes') @@ -500,7 +500,7 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis This DAG maps each type to its immediately higher type on the lattice. """ b1, = _bool_types - uint4, u1, u2, u4, u8, int4, i1, i2, i4, i8 = _int_types + _uint4, u1, u2, u4, u8, _int4, i1, i2, i4, i8 = _int_types *f1_types, bf, f2, f4, f8 = _float_types c4, c8 = _complex_types i_, f_, c_ = _weak_types @@ -508,13 +508,18 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis out: dict[JAXType, list[JAXType]] out = { b1: [i_], - uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_], - i_: [uint4, int4, u1, i1], - int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_], + u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_], + i_: [u1, i1], i1: [i2], i2: [i4], i4: [i8], i8: [f_], f_: [*f1_types, bf, f2, c_], **{t: [] for t in f1_types}, bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8], c_: [c4], c4: [c8], c8: [], } + if _int4_dtype is not None: + out[i_].append(_int4_dtype) + out[_int4_dtype] = [] + if _uint4_dtype is not None: + out[i_].append(_uint4_dtype) + out[_uint4_dtype] = [] return out elif jax_numpy_dtype_promotion == 'strict': return { diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index c7cb4ee20..bc18226f0 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -104,6 +104,15 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): def maybe_upcast(x): if x.dtype in custom_float_dtypes: return x.astype(np.float32) + # TODO(reedwm): Upcasting int4 to int8 will no longer be neccessary once + # ml_dtypes has a stable release with commit + # https://github.com/jax-ml/ml_dtypes/commit/348fd3704306cae97f617c38045cee6bc416bf10. + # Remove these checks once JAX depends on a version on ml_dtypes with that + # commit. + if x.dtype == _dtypes.int4: + return x.astype(np.int8) + if x.dtype == _dtypes.uint4: + return x.astype(np.uint8) return x a = maybe_upcast(a) diff --git a/jaxlib/setup.py b/jaxlib/setup.py index e36996749..c539b75dd 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -64,7 +64,7 @@ setup( 'scipy>=1.9', "scipy>=1.11.1; python_version>='3.12'", 'numpy>=1.22', - 'ml_dtypes>=0.4.0', + 'ml_dtypes>=0.2.0', ], extras_require={ 'cuda12_pip': [ diff --git a/setup.py b/setup.py index e89bf22fd..3845d56ea 100644 --- a/setup.py +++ b/setup.py @@ -54,7 +54,7 @@ setup( package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]}, python_requires='>=3.9', install_requires=[ - 'ml_dtypes>=0.4.0', + 'ml_dtypes>=0.2.0', 'numpy>=1.22', "numpy>=1.23.2; python_version>='3.11'", "numpy>=1.26.0; python_version>='3.12'",