mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Bumped the minimum ml_dtypes version to 0.4.0
This commit is contained in:
parent
ee79d7d12b
commit
0a694a1b42
@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
|
|
||||||
## jax 0.4.29
|
## jax 0.4.29
|
||||||
|
|
||||||
|
* Breaking changes
|
||||||
|
* JAX now requires ml_dtypes version 0.4.0 or newer.
|
||||||
|
|
||||||
* Deprecations
|
* Deprecations
|
||||||
* Removed a number of previously-deprecated APIs:
|
* Removed a number of previously-deprecated APIs:
|
||||||
* from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape`
|
* from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape`
|
||||||
|
@ -42,8 +42,8 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if _ml_dtypes_version < (0, 2, 0):
|
if _ml_dtypes_version < (0, 4, 0):
|
||||||
raise ValueError("JAX requires ml_dtypes version 0.2.0 or newer; "
|
raise ValueError("JAX requires ml_dtypes version 0.4.0 or newer; "
|
||||||
f"installed version is {ml_dtypes.__version__}.")
|
f"installed version is {ml_dtypes.__version__}.")
|
||||||
|
|
||||||
export = set_module('jax.dtypes')
|
export = set_module('jax.dtypes')
|
||||||
@ -500,7 +500,7 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis
|
|||||||
This DAG maps each type to its immediately higher type on the lattice.
|
This DAG maps each type to its immediately higher type on the lattice.
|
||||||
"""
|
"""
|
||||||
b1, = _bool_types
|
b1, = _bool_types
|
||||||
_uint4, u1, u2, u4, u8, _int4, i1, i2, i4, i8 = _int_types
|
uint4, u1, u2, u4, u8, int4, i1, i2, i4, i8 = _int_types
|
||||||
*f1_types, bf, f2, f4, f8 = _float_types
|
*f1_types, bf, f2, f4, f8 = _float_types
|
||||||
c4, c8 = _complex_types
|
c4, c8 = _complex_types
|
||||||
i_, f_, c_ = _weak_types
|
i_, f_, c_ = _weak_types
|
||||||
@ -508,18 +508,13 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis
|
|||||||
out: dict[JAXType, list[JAXType]]
|
out: dict[JAXType, list[JAXType]]
|
||||||
out = {
|
out = {
|
||||||
b1: [i_],
|
b1: [i_],
|
||||||
u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
|
uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
|
||||||
i_: [u1, i1], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
|
i_: [uint4, int4, u1, i1],
|
||||||
|
int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
|
||||||
f_: [*f1_types, bf, f2, c_],
|
f_: [*f1_types, bf, f2, c_],
|
||||||
**{t: [] for t in f1_types}, bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8],
|
**{t: [] for t in f1_types}, bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8],
|
||||||
c_: [c4], c4: [c8], c8: [],
|
c_: [c4], c4: [c8], c8: [],
|
||||||
}
|
}
|
||||||
if _int4_dtype is not None:
|
|
||||||
out[i_].append(_int4_dtype)
|
|
||||||
out[_int4_dtype] = []
|
|
||||||
if _uint4_dtype is not None:
|
|
||||||
out[i_].append(_uint4_dtype)
|
|
||||||
out[_uint4_dtype] = []
|
|
||||||
return out
|
return out
|
||||||
elif jax_numpy_dtype_promotion == 'strict':
|
elif jax_numpy_dtype_promotion == 'strict':
|
||||||
return {
|
return {
|
||||||
|
@ -104,15 +104,6 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
|
|||||||
def maybe_upcast(x):
|
def maybe_upcast(x):
|
||||||
if x.dtype in custom_float_dtypes:
|
if x.dtype in custom_float_dtypes:
|
||||||
return x.astype(np.float32)
|
return x.astype(np.float32)
|
||||||
# TODO(reedwm): Upcasting int4 to int8 will no longer be neccessary once
|
|
||||||
# ml_dtypes has a stable release with commit
|
|
||||||
# https://github.com/jax-ml/ml_dtypes/commit/348fd3704306cae97f617c38045cee6bc416bf10.
|
|
||||||
# Remove these checks once JAX depends on a version on ml_dtypes with that
|
|
||||||
# commit.
|
|
||||||
if x.dtype == _dtypes.int4:
|
|
||||||
return x.astype(np.int8)
|
|
||||||
if x.dtype == _dtypes.uint4:
|
|
||||||
return x.astype(np.uint8)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
a = maybe_upcast(a)
|
a = maybe_upcast(a)
|
||||||
|
@ -64,7 +64,7 @@ setup(
|
|||||||
'scipy>=1.9',
|
'scipy>=1.9',
|
||||||
"scipy>=1.11.1; python_version>='3.12'",
|
"scipy>=1.11.1; python_version>='3.12'",
|
||||||
'numpy>=1.22',
|
'numpy>=1.22',
|
||||||
'ml_dtypes>=0.2.0',
|
'ml_dtypes>=0.4.0',
|
||||||
],
|
],
|
||||||
extras_require={
|
extras_require={
|
||||||
'cuda12_pip': [
|
'cuda12_pip': [
|
||||||
|
2
setup.py
2
setup.py
@ -72,7 +72,7 @@ setup(
|
|||||||
package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]},
|
package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]},
|
||||||
python_requires='>=3.9',
|
python_requires='>=3.9',
|
||||||
install_requires=[
|
install_requires=[
|
||||||
'ml_dtypes>=0.2.0',
|
'ml_dtypes>=0.4.0',
|
||||||
'numpy>=1.22',
|
'numpy>=1.22',
|
||||||
"numpy>=1.23.2; python_version>='3.11'",
|
"numpy>=1.23.2; python_version>='3.11'",
|
||||||
"numpy>=1.26.0; python_version>='3.12'",
|
"numpy>=1.26.0; python_version>='3.12'",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user