cleanup now that we depend on ml_dtypes>=0.5

Jake VanderPlas 2025-03-28 07:44:38 -07:00
parent e679811c4a
commit 431c2c0807
12 changed files with 101 additions and 212 deletions
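Now that the minimum supported ml_dtypes release is 0.5, the hasattr() probes and "| None = None" fallbacks that guarded the newer float8_e3m4 / float8_e4m3 / float8_e8m0fnu / float4_e2m1fn and int2 / uint2 members can go away. A condensed sketch of the pattern this commit removes versus the unconditional form that replaces it (illustrative only, not the exact JAX source):

import ml_dtypes
import numpy as np

# Before: probe for attributes that only exist in newer ml_dtypes releases.
float8_e3m4 = None
if hasattr(ml_dtypes, "float8_e3m4"):
  float8_e3m4 = ml_dtypes.float8_e3m4
  _float8_e3m4_dtype = np.dtype(float8_e3m4)

# After: ml_dtypes>=0.5 guarantees the attribute, so alias it unconditionally.
float8_e3m4 = ml_dtypes.float8_e3m4
_float8_e3m4_dtype = np.dtype(float8_e3m4)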

View File

@@ -90,19 +90,18 @@ class ExtendedDType(StrictABC):
# fp8 support
-# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
-float8_e3m4: type[np.generic] | None = None
-float8_e4m3: type[np.generic] | None = None
-float8_e8m0fnu: type[np.generic] | None = None
+float8_e3m4: type[np.generic] = ml_dtypes.float8_e3m4
+float8_e4m3: type[np.generic] = ml_dtypes.float8_e4m3
+float8_e8m0fnu: type[np.generic] = ml_dtypes.float8_e8m0fnu
float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz
float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn
float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz
float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2
float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz
-_float8_e3m4_dtype: np.dtype | None = None
-_float8_e4m3_dtype: np.dtype | None = None
-_float8_e8m0fnu_dtype: np.dtype | None = None
+_float8_e3m4_dtype: np.dtype = np.dtype(float8_e3m4)
+_float8_e4m3_dtype: np.dtype = np.dtype(float8_e4m3)
+_float8_e8m0fnu_dtype: np.dtype = np.dtype(float8_e8m0fnu)
_float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz)
_float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn)
_float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz)
@@ -111,9 +110,9 @@ _float8_e5m2fnuz_dtype: np.dtype = np.dtype(float8_e5m2fnuz)
# fp4 support
-# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
-float4_e2m1fn: type[np.generic] | None = None
+float4_e2m1fn: type[np.generic] = ml_dtypes.float4_e2m1fn
-_float4_e2m1fn_dtype: np.dtype | None = None
+_float4_e2m1fn_dtype: np.dtype = np.dtype(float4_e2m1fn)
def supports_inf(dtype: DTypeLike) -> bool:
"""Return true if the dtype supports infinity, else return False."""
@@ -127,6 +126,10 @@ bfloat16: type[np.generic] = ml_dtypes.bfloat16
_bfloat16_dtype: np.dtype = np.dtype(bfloat16)
_custom_float_scalar_types = [
+float4_e2m1fn,
+float8_e3m4,
+float8_e4m3,
+float8_e8m0fnu,
float8_e4m3b11fnuz,
float8_e4m3fn,
float8_e4m3fnuz,
@@ -135,6 +138,10 @@ _custom_float_scalar_types = [
bfloat16,
]
_custom_float_dtypes = [
+_float4_e2m1fn_dtype,
+_float8_e3m4_dtype,
+_float8_e4m3_dtype,
+_float8_e8m0fnu_dtype,
_float8_e4m3b11fnuz_dtype,
_float8_e4m3fn_dtype,
_float8_e4m3fnuz_dtype,
@@ -143,6 +150,9 @@ _custom_float_dtypes = [
_bfloat16_dtype,
]
_float8_dtypes = [
+_float8_e3m4_dtype,
+_float8_e4m3_dtype,
+_float8_e8m0fnu_dtype,
_float8_e4m3b11fnuz_dtype,
_float8_e4m3fn_dtype,
_float8_e4m3fnuz_dtype,
@@ -150,58 +160,28 @@ _float8_dtypes = [
_float8_e5m2fnuz_dtype,
]
-_float4_dtypes: list[np.dtype] = []
+_float4_dtypes: list[np.dtype] = [
+_float4_e2m1fn_dtype,
+]
-# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0
-if hasattr(ml_dtypes, "float8_e4m3"):
-float8_e4m3 = ml_dtypes.float8_e4m3
-_float8_e4m3_dtype = np.dtype(float8_e4m3)
-_custom_float_scalar_types.insert(0, float8_e4m3) # type: ignore[arg-type]
-_custom_float_dtypes.insert(0, _float8_e4m3_dtype)
-_float8_dtypes.insert(0, _float8_e4m3_dtype)
-if hasattr(ml_dtypes, "float8_e3m4"):
-float8_e3m4 = ml_dtypes.float8_e3m4
-_float8_e3m4_dtype = np.dtype(float8_e3m4)
-_custom_float_scalar_types.insert(0, float8_e3m4) # type: ignore[arg-type]
-_custom_float_dtypes.insert(0, _float8_e3m4_dtype)
-_float8_dtypes.insert(0, _float8_e3m4_dtype)
-if hasattr(ml_dtypes, "float8_e8m0fnu"):
-float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
-_float8_e8m0fnu_dtype = np.dtype(float8_e8m0fnu)
-_custom_float_scalar_types.insert(0, float8_e8m0fnu) # type: ignore[arg-type]
-_custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype)
-_float8_dtypes.insert(0, _float8_e8m0fnu_dtype)
-if hasattr(ml_dtypes, "float4_e2m1fn"):
-float4_e2m1fn = ml_dtypes.float4_e2m1fn
-_float4_e2m1fn_dtype = np.dtype(float4_e2m1fn)
-_custom_float_scalar_types.insert(0, float4_e2m1fn) # type: ignore[arg-type]
-_custom_float_dtypes.insert(0, _float4_e2m1fn_dtype)
-_float4_dtypes.insert(0, _float4_e2m1fn_dtype)
+int2: type[np.generic] = ml_dtypes.int2
+uint2: type[np.generic] = ml_dtypes.uint2
# 2-bit integer support
-int2: type[np.generic] | None = None
-uint2: type[np.generic] | None = None
-_int2_dtype: np.dtype | None = None
-_uint2_dtype: np.dtype | None = None
-_intn_dtypes = []
-# Remove the condition once the minimum ml_dtypes version required by JAX
-# contains https://github.com/jax-ml/ml_dtypes/pull/154.
-if hasattr(ml_dtypes, 'int2'):
-int2 = ml_dtypes.int2
-uint2 = ml_dtypes.uint2
-_int2_dtype = np.dtype(int2)
-_uint2_dtype = np.dtype(uint2)
-_intn_dtypes.extend([_int2_dtype, _uint2_dtype])
+_int2_dtype: np.dtype = np.dtype(int2)
+_uint2_dtype: np.dtype = np.dtype(uint2)
# 4-bit integer support
int4: type[np.generic] = ml_dtypes.int4
uint4: type[np.generic] = ml_dtypes.uint4
_int4_dtype = np.dtype(int4)
_uint4_dtype = np.dtype(uint4)
-_intn_dtypes.extend([_int4_dtype, _uint4_dtype])
+_intn_dtypes = [
+_int2_dtype,
+_uint2_dtype,
+_int4_dtype,
+_uint4_dtype,
+]
# Default types.
bool_ = np.bool_
@@ -472,9 +452,9 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
# to the normal scalar type hierarchy.
if a_sctype in _custom_float_scalar_types:
return b_sctype in {a_sctype, np.floating, np.inexact, np.number, np.generic}
-if (int2 is not None and a_sctype == int2) or a_sctype == int4:
+if a_sctype in [int2, int4]:
return b_sctype in {a_sctype, np.signedinteger, np.integer, np.number, np.generic}
-if (uint2 is not None and a_sctype == uint2) or a_sctype == uint4:
+if a_sctype in [uint2, uint4]:
return b_sctype in {a_sctype, np.unsignedinteger, np.integer, np.number, np.generic}
# Otherwise, fall back to numpy.issubdtype
@@ -491,6 +471,7 @@ _signed_types: list[JAXType]
_unsigned_types: list[JAXType]
_int_types: list[JAXType]
_unsigned_types = [
+np.dtype(uint2),
np.dtype(uint4),
np.dtype('uint8'),
np.dtype('uint16'),
@@ -498,6 +479,7 @@ _unsigned_types = [
np.dtype('uint64'),
]
_signed_types = [
+np.dtype(int2),
np.dtype(int4),
np.dtype('int8'),
np.dtype('int16'),
@@ -505,11 +487,6 @@ _signed_types = [
np.dtype('int64'),
]
-if _int2_dtype is not None:
-_signed_types.insert(0, _int2_dtype)
-if _uint2_dtype is not None:
-_unsigned_types.insert(0, _uint2_dtype)
_int_types = _unsigned_types + _signed_types
_float_types: list[JAXType] = [
@@ -622,11 +599,7 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis
This DAG maps each type to its immediately higher type on the lattice.
"""
b1, = _bool_types
-if _int2_dtype is not None:
-assert _uint2_dtype is not None
-_uint2, uint4, u1, u2, u4, u8, _int2, int4, i1, i2, i4, i8 = _int_types
-else:
-uint4, u1, u2, u4, u8, int4, i1, i2, i4, i8 = _int_types
+uint2, uint4, u1, u2, u4, u8, int2, int4, i1, i2, i4, i8 = _int_types
*f1_types, bf, f2, f4, f8 = _float_types
c4, c8 = _complex_types
i_, f_, c_ = _weak_types
@@ -634,19 +607,13 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis
out: dict[JAXType, list[JAXType]]
out = {
b1: [i_],
-i_: [u1, uint4, i1, int4],
-uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
-int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
+i_: [u1, uint2, uint4, i1, int2, int4],
+uint2: [], uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
+int2: [], int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
f_: [*f1_types, bf, f2, c_],
**{t: [] for t in f1_types}, bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8],
c_: [c4], c4: [c8], c8: [],
}
-if _int2_dtype is not None:
-out[i_].append(_int2_dtype)
-out[_int2_dtype] = []
-if _uint2_dtype is not None:
-out[i_].append(_uint2_dtype)
-out[_uint2_dtype] = []
return out
elif jax_numpy_dtype_promotion == 'strict':
return {

View File

@@ -357,16 +357,12 @@ _dtype_to_dtype_kind = {
dtypes._float8_e4m3fnuz_dtype: ser_flatbuf.DType.f8_e4m3fnuz,
dtypes._float8_e5m2_dtype: ser_flatbuf.DType.f8_e5m2,
dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz,
+dtypes._float8_e3m4_dtype: ser_flatbuf.DType.f8_e3m4,
+dtypes._float8_e4m3_dtype: ser_flatbuf.DType.f8_e4m3,
+dtypes._float8_e8m0fnu_dtype: ser_flatbuf.DType.f8_e8m0fnu,
+dtypes._float4_e2m1fn_dtype: ser_flatbuf.DType.f4_e2m1fn,
}
-if dtypes._float8_e3m4_dtype is not None:
-_dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4
-if dtypes._float8_e4m3_dtype is not None:
-_dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3
-if dtypes._float8_e8m0fnu_dtype is not None:
-_dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu
-if dtypes._float4_e2m1fn_dtype is not None:
-_dtype_to_dtype_kind[dtypes._float4_e2m1fn_dtype] = ser_flatbuf.DType.f4_e2m1fn
_dtype_kind_to_dtype = {
kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
}

View File

@@ -185,24 +185,14 @@ _dtype_to_ir_type : dict[np.dtype, Callable[[], ir.Type]] = {
np.dtype(np.float64): ir.F64Type.get,
np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
+np.dtype(dtypes.int2): partial(ir.IntegerType.get_signless, 2),
+np.dtype(dtypes.uint2): partial(ir.IntegerType.get_unsigned, 2),
+np.dtype(dtypes.float8_e3m4): ir.Float8E3M4Type.get,
+np.dtype(dtypes.float8_e4m3): ir.Float8E4M3Type.get,
+np.dtype(dtypes.float8_e8m0fnu): ir.Float8E8M0FNUType.get,
+np.dtype(dtypes.float4_e2m1fn): ir.Float4E2M1FNType.get,
}
-if dtypes.int2 is not None:
-assert dtypes.uint2 is not None
-_dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(ir.IntegerType.get_signless, 2)
-_dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(ir.IntegerType.get_unsigned, 2)
-if dtypes.float8_e3m4 is not None:
-_dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get
-if dtypes.float8_e4m3 is not None:
-_dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get
-if dtypes.float8_e8m0fnu is not None:
-_dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get
-if dtypes.float4_e2m1fn is not None:
-_dtype_to_ir_type[np.dtype(dtypes.float4_e2m1fn)] = ir.Float4E2M1FNType.get
def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
if isinstance(dtype, core.bint):
# TODO Support different-size underlying dtypes to take advantage of the

View File

@@ -2346,13 +2346,10 @@ class DotAlgorithmPreset(enum.Enum):
np.dtype(dtypes.float8_e4m3fnuz),
np.dtype(dtypes.float8_e5m2),
np.dtype(dtypes.float8_e5m2fnuz),
+np.dtype(dtypes.float8_e3m4),
+np.dtype(dtypes.float8_e4m3),
+np.dtype(dtypes.float8_e8m0fnu),
]
-if dtypes.float8_e3m4 is not None:
-fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
-if dtypes.float8_e4m3 is not None:
-fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
-if dtypes.float8_e8m0fnu is not None:
-fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)]
if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes:
raise ValueError(
f"The dot algorithm '{self}' requires both inputs to have float8 "
@@ -5602,13 +5599,9 @@ def accuracy_attr(accuracy) -> hlo.ResultAccuracyAttr:
def _handle_dot_precision(ctx, lhs, rhs, precision, platform):
def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
-dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
-if dtypes.float8_e3m4 is not None:
-fp8_dtypes += (dtypes.float8_e3m4,)
-if dtypes.float8_e4m3 is not None:
-fp8_dtypes += (dtypes.float8_e4m3,)
-if dtypes.float8_e8m0fnu is not None:
-fp8_dtypes += (dtypes.float8_e8m0fnu,)
+dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz,
+dtypes.float8_e3m4, dtypes.float8_e4m3,
+dtypes.float8_e8m0fnu)
return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes
# The *_ lets us reuse this for ragged_dot_general, which has group_sizes.

View File

@@ -68,33 +68,27 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
return meta
bool_ = _make_scalar_type(np.bool_)
-if dtypes.uint2 is not None:
-uint2 = _make_scalar_type(dtypes.uint2)
+uint2 = _make_scalar_type(dtypes.uint2)
uint4 = _make_scalar_type(dtypes.uint4)
uint8 = _make_scalar_type(np.uint8)
uint16 = _make_scalar_type(np.uint16)
uint32 = _make_scalar_type(np.uint32)
uint64 = _make_scalar_type(np.uint64)
-if dtypes.int2 is not None:
-int2 = _make_scalar_type(dtypes.int2)
+int2 = _make_scalar_type(dtypes.int2)
int4 = _make_scalar_type(dtypes.int4)
int8 = _make_scalar_type(np.int8)
int16 = _make_scalar_type(np.int16)
int32 = _make_scalar_type(np.int32)
int64 = _make_scalar_type(np.int64)
-if dtypes.float8_e3m4 is not None:
-float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4)
-if dtypes.float8_e4m3 is not None:
-float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3)
-if dtypes.float8_e8m0fnu is not None:
-float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu)
+float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn)
+float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4)
+float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3)
+float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu)
float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn)
float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz)
float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)
float8_e5m2fnuz = _make_scalar_type(dtypes.float8_e5m2fnuz)
float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz)
-if dtypes.float4_e2m1fn is not None:
-float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn)
bfloat16 = _make_scalar_type(dtypes.bfloat16)
float16 = _make_scalar_type(np.float16)
float32 = single = _make_scalar_type(np.float32)

View File

@@ -46,16 +46,22 @@ ToleranceDict: TypeAlias = dict[np.dtype, int | float]
_default_tolerance: ToleranceDict = {
_dtypes.float0: 0,
np.dtype(np.bool_): 0,
+np.dtype(_dtypes.int2): 0,
np.dtype(_dtypes.int4): 0,
np.dtype(np.int8): 0,
np.dtype(np.int16): 0,
np.dtype(np.int32): 0,
np.dtype(np.int64): 0,
+np.dtype(_dtypes.uint2): 0,
np.dtype(_dtypes.uint4): 0,
np.dtype(np.uint8): 0,
np.dtype(np.uint16): 0,
np.dtype(np.uint32): 0,
np.dtype(np.uint64): 0,
+np.dtype(_dtypes.float4_e2m1fn): 1e0,
+np.dtype(_dtypes.float8_e3m4): 1e-1,
+np.dtype(_dtypes.float8_e4m3): 1e-1,
+np.dtype(_dtypes.float8_e8m0fnu): 1e0,
np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1,
np.dtype(_dtypes.float8_e4m3fn): 1e-1,
np.dtype(_dtypes.float8_e4m3fnuz): 1e-1,
@@ -69,16 +75,15 @@ _default_tolerance: ToleranceDict = {
np.dtype(np.complex128): 1e-15,
}
-if _dtypes.int2 is not None:
-assert _dtypes.uint2 is not None
-_default_tolerance[np.dtype(_dtypes.int2)] = 0
-_default_tolerance[np.dtype(_dtypes.uint2)] = 0
def default_tolerance():
return _default_tolerance
default_gradient_tolerance: ToleranceDict = {
+np.dtype(_dtypes.float4_e2m1fn): 1e0,
+np.dtype(_dtypes.float8_e3m4): 1e-1,
+np.dtype(_dtypes.float8_e4m3): 1e-1,
+np.dtype(_dtypes.float8_e8m0fnu): 1e0,
np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1,
np.dtype(_dtypes.float8_e4m3fn): 1e-1,
np.dtype(_dtypes.float8_e4m3fnuz): 1e-1,
@@ -92,19 +97,6 @@ default_gradient_tolerance: ToleranceDict = {
np.dtype(np.complex128): 1e-5,
}
-# TODO: make this unconditional when ml_dtypes>=0.5.0 is required
-if _dtypes.float8_e3m4 is not None:
-_default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1
-default_gradient_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1
-if _dtypes.float8_e4m3 is not None:
-_default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
-default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
-if _dtypes.float8_e8m0fnu is not None:
-_default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
-default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
-if _dtypes.float4_e2m1fn is not None:
-_default_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0
-default_gradient_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0
def is_python_scalar(val: Any) -> bool:
return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))
@@ -115,6 +107,10 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
return
custom_float_dtypes = [
+_dtypes.float4_e2m1fn,
+_dtypes.float8_e8m0fnu,
+_dtypes.float8_e3m4,
+_dtypes.float8_e4m3,
_dtypes.float8_e4m3b11fnuz,
_dtypes.float8_e4m3fn,
_dtypes.float8_e4m3fnuz,
@@ -123,15 +119,6 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
_dtypes.bfloat16,
]
-if _dtypes.float8_e4m3 is not None:
-custom_float_dtypes.insert(0, _dtypes.float8_e4m3)
-if _dtypes.float8_e3m4 is not None:
-custom_float_dtypes.insert(0, _dtypes.float8_e3m4)
-if _dtypes.float8_e8m0fnu is not None:
-custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu)
-if _dtypes.float4_e2m1fn is not None:
-custom_float_dtypes.insert(0, _dtypes.float4_e2m1fn)
def maybe_upcast(x):
if x.dtype in custom_float_dtypes:
return x.astype(np.float32)

View File

@@ -1632,15 +1632,11 @@ class _LazyDtypes:
_dtypes.float8_e4m3fnuz,
_dtypes.float8_e5m2,
_dtypes.float8_e5m2fnuz,
+_dtypes.float8_e3m4,
+_dtypes.float8_e4m3,
+_dtypes.float8_e8m0fnu,
+_dtypes.float4_e2m1fn,
]
-if _dtypes.float8_e3m4 is not None:
-float_dtypes += [_dtypes.float8_e3m4]
-if _dtypes.float8_e4m3 is not None:
-float_dtypes += [_dtypes.float8_e4m3]
-if _dtypes.float8_e8m0fnu is not None:
-float_dtypes += [_dtypes.float8_e8m0fnu]
-if _dtypes.float4_e2m1fn is not None:
-float_dtypes += [_dtypes.float4_e2m1fn]
return self.supported(float_dtypes)
@_cached_property

View File

@@ -211,13 +211,18 @@ from jax._src.numpy.scalar_types import (
double as double,
float16 as float16,
float32 as float32,
+float4_e2m1fn as float4_e2m1fn,
float64 as float64,
+float8_e3m4 as float8_e3m4,
+float8_e4m3 as float8_e4m3,
float8_e4m3b11fnuz as float8_e4m3b11fnuz,
float8_e4m3fn as float8_e4m3fn,
float8_e4m3fnuz as float8_e4m3fnuz,
float8_e5m2 as float8_e5m2,
float8_e5m2fnuz as float8_e5m2fnuz,
+float8_e8m0fnu as float8_e8m0fnu,
float_ as float_,
+int2 as int2,
int4 as int4,
int8 as int8,
int16 as int16,
@@ -226,6 +231,7 @@ from jax._src.numpy.scalar_types import (
int_ as int_,
single as single,
uint as uint,
+uint2 as uint2,
uint4 as uint4,
uint8 as uint8,
uint16 as uint16,
@@ -295,26 +301,6 @@ from numpy import (
unsignedinteger as unsignedinteger,
)
-# TODO(slebedev): Remove the try-except once we upgrade to ml_dtypes 0.4.1.
-try:
-from jax._src.numpy.scalar_types import (
-int2 as int2,
-uint2 as uint2,
-)
-except ImportError:
-pass
-# TODO: Remove the try-except once we upgrade to ml_dtypes 0.5.0
-try:
-from jax._src.numpy.scalar_types import (
-float8_e3m4 as float8_e3m4,
-float8_e4m3 as float8_e4m3,
-float8_e8m0fnu as float8_e8m0fnu,
-float4_e2m1fn as float4_e2m1fn,
-)
-except ImportError:
-pass
from jax._src.numpy.array_api_metadata import (
__array_api_version__ as __array_api_version__,
__array_namespace_info__ as __array_namespace_info__,

View File

@@ -240,16 +240,12 @@ def parse_shape_str(s):
_DT = {
'pred': jnp.bool_,
-'u4': jnp.uint4, 'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64,
-'s4': jnp.int4, 's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64,
+'u2': jnp.uint2, 'u4': jnp.uint4, 'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64,
+'s2': jnp.int2, 's4': jnp.int4, 's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64,
'bf16': jnp.bfloat16,
'f16': jnp.float16, 'f32': jnp.float32, 'f64': jnp.float64,
'c64': jnp.complex64, 'c128': jnp.complex128
}
-if hasattr(jnp, 'int2'):
-_DT['s2'] = jnp.int2
-if hasattr(jnp, 'uint2'):
-_DT['u2'] = jnp.uint2
_SHAPE_RE = re.compile(f"^({'|'.join(_DT)})\\[\\s*(\\d*[\\s*,\\d+]*)\\s*\\]$")

View File

@@ -238,13 +238,10 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
PrimitiveType = _xla.PrimitiveType
bfloat16 = ml_dtypes.bfloat16
-# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
-# Also, it would be better to conditionally import these based on whether they
-# are in the current version of ml_dtypes.
-# float4_e2m1fn = ml_dtypes.float4_e2m1fn
-# float8_e3m4 = ml_dtypes.float8_e3m4
-# float8_e4m3 = ml_dtypes.float8_e4m3
-# float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
+float4_e2m1fn = ml_dtypes.float4_e2m1fn
+float8_e3m4 = ml_dtypes.float8_e3m4
+float8_e4m3 = ml_dtypes.float8_e4m3
+float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
float8_e4m3fn = ml_dtypes.float8_e4m3fn
float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz
float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz

View File

@@ -46,30 +46,19 @@ np_unsigned_dtypes = [np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'),
np.dtype('uint64')]
unsigned_dtypes = list(np_unsigned_dtypes)
-intn_dtypes = [np.dtype('int4'), np.dtype('uint4')]
-signed_dtypes += [np.dtype('int4')]
-unsigned_dtypes += [np.dtype('uint4')]
-if dtypes.int2 is not None:
-assert dtypes.uint2 is not None
-intn_dtypes[:0] = [np.dtype('int2'), np.dtype('uint2')]
-signed_dtypes[:0] = [np.dtype('int2')]
-unsigned_dtypes[:0] = [np.dtype('uint2')]
+intn_dtypes = [np.dtype('int2'), np.dtype('uint2'), np.dtype('int4'), np.dtype('uint4')]
+signed_dtypes += [np.dtype('int2'), np.dtype('int4')]
+unsigned_dtypes += [np.dtype('uint2'), np.dtype('uint4')]
-np_float_dtypes = [np.dtype('float16'), np.dtype('float32'),
-np.dtype('float64')]
+np_float_dtypes = [np.dtype('float16'), np.dtype('float32'), np.dtype('float64')]
float_dtypes = [np.dtype(dtypes.bfloat16)] + np_float_dtypes
custom_float_dtypes = [np.dtype(dtypes.bfloat16)]
fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn),
np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2),
-np.dtype(dtypes.float8_e5m2fnuz)]
-if dtypes.float8_e3m4 is not None:
-fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
-if dtypes.float8_e4m3 is not None:
-fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
-if dtypes.float8_e8m0fnu is not None:
-fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)]
+np.dtype(dtypes.float8_e5m2fnuz), np.dtype(dtypes.float8_e3m4),
+np.dtype(dtypes.float8_e4m3), np.dtype(dtypes.float8_e8m0fnu)]
float_dtypes += fp8_dtypes
custom_float_dtypes += fp8_dtypes

View File

@@ -114,15 +114,13 @@ class JaxToIRTest(absltest.TestCase):
self.assertParsedShape('f32[]', [], jnp.float32)
self.assertParsedShape('f32[1,2,3]', [1, 2, 3], jnp.float32)
self.assertParsedShape('pred[1]', [1], jnp.bool_)
-if hasattr(jnp, 'int2'):
-self.assertParsedShape('s2[1]', [1], jnp.int2)
+self.assertParsedShape('s2[1]', [1], jnp.int2)
self.assertParsedShape('s4[1]', [1], jnp.int4)
self.assertParsedShape('s8[1]', [1], jnp.int8)
self.assertParsedShape('s16[1]', [1], jnp.int16)
self.assertParsedShape('s32[1]', [1], jnp.int32)
self.assertParsedShape('s64[1]', [1], jnp.int64)
-if hasattr(jnp, 'uint2'):
-self.assertParsedShape('u2[1]', [1], jnp.uint2)
+self.assertParsedShape('u2[1]', [1], jnp.uint2)
self.assertParsedShape('u4[1]', [1], jnp.uint4)
self.assertParsedShape('u8[1]', [1], jnp.uint8)
self.assertParsedShape('u16[1]', [1], jnp.uint16)