mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
cleanup now that we depend on ml_dtypes>=0.5
This commit is contained in:
parent e679811c4a
commit 431c2c0807
@@ -90,19 +90,18 @@ class ExtendedDType(StrictABC):

 # fp8 support
-# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
-float8_e3m4: type[np.generic] | None = None
-float8_e4m3: type[np.generic] | None = None
-float8_e8m0fnu: type[np.generic] | None = None
+float8_e3m4: type[np.generic] = ml_dtypes.float8_e3m4
+float8_e4m3: type[np.generic] = ml_dtypes.float8_e4m3
+float8_e8m0fnu: type[np.generic] = ml_dtypes.float8_e8m0fnu
 float8_e4m3b11fnuz: type[np.generic] = ml_dtypes.float8_e4m3b11fnuz
 float8_e4m3fn: type[np.generic] = ml_dtypes.float8_e4m3fn
 float8_e4m3fnuz: type[np.generic] = ml_dtypes.float8_e4m3fnuz
 float8_e5m2: type[np.generic] = ml_dtypes.float8_e5m2
 float8_e5m2fnuz: type[np.generic] = ml_dtypes.float8_e5m2fnuz

-_float8_e3m4_dtype: np.dtype | None = None
-_float8_e4m3_dtype: np.dtype | None = None
-_float8_e8m0fnu_dtype: np.dtype | None = None
+_float8_e3m4_dtype: np.dtype = np.dtype(float8_e3m4)
+_float8_e4m3_dtype: np.dtype = np.dtype(float8_e4m3)
+_float8_e8m0fnu_dtype: np.dtype = np.dtype(float8_e8m0fnu)
 _float8_e4m3b11fnuz_dtype: np.dtype = np.dtype(float8_e4m3b11fnuz)
 _float8_e4m3fn_dtype: np.dtype = np.dtype(float8_e4m3fn)
 _float8_e4m3fnuz_dtype: np.dtype = np.dtype(float8_e4m3fnuz)
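With ml_dtypes>=0.5 now a hard requirement, the hunk above turns the Optional placeholders into direct assignments. The snippet below is a minimal illustrative sketch, not part of the diff; it assumes only ml_dtypes>=0.5 and NumPy are installed:

import ml_dtypes
import numpy as np

# The fp8 scalar types are always present, so no hasattr()/None guards are
# needed before building arrays or dtype objects from them.
x = np.array([0.25, 1.0, 3.0], dtype=ml_dtypes.float8_e4m3)
print(x.dtype)                                    # float8_e4m3
print(np.dtype(ml_dtypes.float8_e3m4).itemsize)   # 1 (one byte per element)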
@@ -111,9 +110,9 @@ _float8_e5m2fnuz_dtype: np.dtype = np.dtype(float8_e5m2fnuz)

 # fp4 support
-# TODO: remove Optional when minimum ml_dtypes version >= 0.5.0
-float4_e2m1fn: type[np.generic] | None = None
+float4_e2m1fn: type[np.generic] = ml_dtypes.float4_e2m1fn

-_float4_e2m1fn_dtype: np.dtype | None = None
+_float4_e2m1fn_dtype: np.dtype = np.dtype(float4_e2m1fn)

 def supports_inf(dtype: DTypeLike) -> bool:
   """Return true if the dtype supports infinity, else return False."""
@@ -127,6 +126,10 @@ bfloat16: type[np.generic] = ml_dtypes.bfloat16
 _bfloat16_dtype: np.dtype = np.dtype(bfloat16)

 _custom_float_scalar_types = [
+    float4_e2m1fn,
+    float8_e3m4,
+    float8_e4m3,
+    float8_e8m0fnu,
     float8_e4m3b11fnuz,
     float8_e4m3fn,
     float8_e4m3fnuz,
@@ -135,6 +138,10 @@ _custom_float_scalar_types = [
     bfloat16,
 ]
 _custom_float_dtypes = [
+    _float4_e2m1fn_dtype,
+    _float8_e3m4_dtype,
+    _float8_e4m3_dtype,
+    _float8_e8m0fnu_dtype,
     _float8_e4m3b11fnuz_dtype,
     _float8_e4m3fn_dtype,
     _float8_e4m3fnuz_dtype,
@@ -143,6 +150,9 @@ _custom_float_dtypes = [
     _bfloat16_dtype,
 ]
 _float8_dtypes = [
+    _float8_e3m4_dtype,
+    _float8_e4m3_dtype,
+    _float8_e8m0fnu_dtype,
     _float8_e4m3b11fnuz_dtype,
     _float8_e4m3fn_dtype,
     _float8_e4m3fnuz_dtype,
@@ -150,58 +160,28 @@ _float8_dtypes = [
     _float8_e5m2fnuz_dtype,
 ]

-_float4_dtypes: list[np.dtype] = []
+_float4_dtypes: list[np.dtype] = [
+    _float4_e2m1fn_dtype,
+]

-# TODO: remove the if statements below when minimum ml_dtypes version >= 0.5.0
-if hasattr(ml_dtypes, "float8_e4m3"):
-  float8_e4m3 = ml_dtypes.float8_e4m3
-  _float8_e4m3_dtype = np.dtype(float8_e4m3)
-  _custom_float_scalar_types.insert(0, float8_e4m3)  # type: ignore[arg-type]
-  _custom_float_dtypes.insert(0, _float8_e4m3_dtype)
-  _float8_dtypes.insert(0, _float8_e4m3_dtype)
-if hasattr(ml_dtypes, "float8_e3m4"):
-  float8_e3m4 = ml_dtypes.float8_e3m4
-  _float8_e3m4_dtype = np.dtype(float8_e3m4)
-  _custom_float_scalar_types.insert(0, float8_e3m4)  # type: ignore[arg-type]
-  _custom_float_dtypes.insert(0, _float8_e3m4_dtype)
-  _float8_dtypes.insert(0, _float8_e3m4_dtype)
-if hasattr(ml_dtypes, "float8_e8m0fnu"):
-  float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
-  _float8_e8m0fnu_dtype = np.dtype(float8_e8m0fnu)
-  _custom_float_scalar_types.insert(0, float8_e8m0fnu)  # type: ignore[arg-type]
-  _custom_float_dtypes.insert(0, _float8_e8m0fnu_dtype)
-  _float8_dtypes.insert(0, _float8_e8m0fnu_dtype)
-if hasattr(ml_dtypes, "float4_e2m1fn"):
-  float4_e2m1fn = ml_dtypes.float4_e2m1fn
-  _float4_e2m1fn_dtype = np.dtype(float4_e2m1fn)
-  _custom_float_scalar_types.insert(0, float4_e2m1fn)  # type: ignore[arg-type]
-  _custom_float_dtypes.insert(0, _float4_e2m1fn_dtype)
-  _float4_dtypes.insert(0, _float4_e2m1fn_dtype)
-
 # 2-bit integer support
-int2: type[np.generic] | None = None
-uint2: type[np.generic] | None = None
-
-_int2_dtype: np.dtype | None = None
-_uint2_dtype: np.dtype | None = None
-
-_intn_dtypes = []
-
-# Remove the condition once the minimum ml_dtypes version required by JAX
-# contains https://github.com/jax-ml/ml_dtypes/pull/154.
-if hasattr(ml_dtypes, 'int2'):
-  int2 = ml_dtypes.int2
-  uint2 = ml_dtypes.uint2
-  _int2_dtype = np.dtype(int2)
-  _uint2_dtype = np.dtype(uint2)
-  _intn_dtypes.extend([_int2_dtype, _uint2_dtype])
+int2: type[np.generic] = ml_dtypes.int2
+uint2: type[np.generic] = ml_dtypes.uint2
+
+_int2_dtype: np.dtype = np.dtype(int2)
+_uint2_dtype: np.dtype = np.dtype(uint2)

 # 4-bit integer support
 int4: type[np.generic] = ml_dtypes.int4
 uint4: type[np.generic] = ml_dtypes.uint4
 _int4_dtype = np.dtype(int4)
 _uint4_dtype = np.dtype(uint4)
-_intn_dtypes.extend([_int4_dtype, _uint4_dtype])
+
+_intn_dtypes = [
+    _int2_dtype,
+    _uint2_dtype,
+    _int4_dtype,
+    _uint4_dtype,
+]

 # Default types.
 bool_ = np.bool_
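As a point of reference, here is a small sketch (not from the diff) of the now-unconditional 2-bit integer types; it assumes ml_dtypes>=0.5, which ships int2/uint2:

import ml_dtypes
import numpy as np

# int2 is a signed 2-bit integer (values -2..1); uint2 is unsigned (0..3).
# Both are stored one element per byte when wrapped in a NumPy array.
print(np.dtype(ml_dtypes.int2).itemsize)              # 1
print(np.array([0, 1, 2, 3], dtype=ml_dtypes.uint2))  # [0 1 2 3]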
@@ -472,9 +452,9 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
   # to the normal scalar type hierarchy.
   if a_sctype in _custom_float_scalar_types:
     return b_sctype in {a_sctype, np.floating, np.inexact, np.number, np.generic}
-  if (int2 is not None and a_sctype == int2) or a_sctype == int4:
+  if a_sctype in [int2, int4]:
     return b_sctype in {a_sctype, np.signedinteger, np.integer, np.number, np.generic}
-  if (uint2 is not None and a_sctype == uint2) or a_sctype == uint4:
+  if a_sctype in [uint2, uint4]:
     return b_sctype in {a_sctype, np.unsignedinteger, np.integer, np.number, np.generic}

   # Otherwise, fall back to numpy.issubdtype
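The simplified membership test above relies on int2/uint2 always being defined. A hedged sketch (not part of the diff) of the public behavior it backs:

import jax.numpy as jnp
import ml_dtypes
import numpy as np

# The sub-byte integer scalar types slot into the NumPy hierarchy the same
# way int4/uint4 already did.
print(jnp.issubdtype(ml_dtypes.int2, np.signedinteger))     # True
print(jnp.issubdtype(ml_dtypes.uint2, np.unsignedinteger))  # True
print(jnp.issubdtype(ml_dtypes.int2, np.floating))          # False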
@@ -491,6 +471,7 @@ _signed_types: list[JAXType]
 _unsigned_types: list[JAXType]
 _int_types: list[JAXType]
 _unsigned_types = [
+    np.dtype(uint2),
     np.dtype(uint4),
     np.dtype('uint8'),
     np.dtype('uint16'),
@@ -498,6 +479,7 @@ _unsigned_types = [
     np.dtype('uint64'),
 ]
 _signed_types = [
+    np.dtype(int2),
     np.dtype(int4),
     np.dtype('int8'),
     np.dtype('int16'),
@@ -505,11 +487,6 @@ _signed_types = [
     np.dtype('int64'),
 ]

-if _int2_dtype is not None:
-  _signed_types.insert(0, _int2_dtype)
-if _uint2_dtype is not None:
-  _unsigned_types.insert(0, _uint2_dtype)
-
 _int_types = _unsigned_types + _signed_types

 _float_types: list[JAXType] = [
@@ -622,11 +599,7 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis
   This DAG maps each type to its immediately higher type on the lattice.
   """
   b1, = _bool_types
-  if _int2_dtype is not None:
-    assert _uint2_dtype is not None
-    _uint2, uint4, u1, u2, u4, u8, _int2, int4, i1, i2, i4, i8 = _int_types
-  else:
-    uint4, u1, u2, u4, u8, int4, i1, i2, i4, i8 = _int_types
+  uint2, uint4, u1, u2, u4, u8, int2, int4, i1, i2, i4, i8 = _int_types
   *f1_types, bf, f2, f4, f8 = _float_types
   c4, c8 = _complex_types
   i_, f_, c_ = _weak_types
@@ -634,19 +607,13 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis
     out: dict[JAXType, list[JAXType]]
     out = {
       b1: [i_],
-      i_: [u1, uint4, i1, int4],
-      uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
-      int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
+      i_: [u1, uint2, uint4, i1, int2, int4],
+      uint2: [], uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
+      int2: [], int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
       f_: [*f1_types, bf, f2, c_],
       **{t: [] for t in f1_types}, bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8],
       c_: [c4], c4: [c8], c8: [],
     }
-    if _int2_dtype is not None:
-      out[i_].append(_int2_dtype)
-      out[_int2_dtype] = []
-    if _uint2_dtype is not None:
-      out[i_].append(_uint2_dtype)
-      out[_uint2_dtype] = []
     return out
   elif jax_numpy_dtype_promotion == 'strict':
     return {
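A short sketch (not in the diff) of how the lattice above is consulted through the public promotion API; the results follow from the DAG edges listed in the hunk:

import jax.numpy as jnp

# u1 and i1 meet at i2 (u1: [i2, u2], i1: [i2]), so uint8 + int8 -> int16.
print(jnp.promote_types(jnp.uint8, jnp.int8))        # int16
# bf and f2 both point to f4, so bfloat16 + float16 -> float32.
print(jnp.promote_types(jnp.bfloat16, jnp.float16))  # float32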
@@ -357,16 +357,12 @@ _dtype_to_dtype_kind = {
     dtypes._float8_e4m3fnuz_dtype: ser_flatbuf.DType.f8_e4m3fnuz,
     dtypes._float8_e5m2_dtype: ser_flatbuf.DType.f8_e5m2,
     dtypes._float8_e5m2fnuz_dtype: ser_flatbuf.DType.f8_e5m2fnuz,
+    dtypes._float8_e3m4_dtype: ser_flatbuf.DType.f8_e3m4,
+    dtypes._float8_e4m3_dtype: ser_flatbuf.DType.f8_e4m3,
+    dtypes._float8_e8m0fnu_dtype: ser_flatbuf.DType.f8_e8m0fnu,
+    dtypes._float4_e2m1fn_dtype: ser_flatbuf.DType.f4_e2m1fn,
 }

-if dtypes._float8_e3m4_dtype is not None:
-  _dtype_to_dtype_kind[dtypes._float8_e3m4_dtype] = ser_flatbuf.DType.f8_e3m4
-if dtypes._float8_e4m3_dtype is not None:
-  _dtype_to_dtype_kind[dtypes._float8_e4m3_dtype] = ser_flatbuf.DType.f8_e4m3
-if dtypes._float8_e8m0fnu_dtype is not None:
-  _dtype_to_dtype_kind[dtypes._float8_e8m0fnu_dtype] = ser_flatbuf.DType.f8_e8m0fnu
-if dtypes._float4_e2m1fn_dtype is not None:
-  _dtype_to_dtype_kind[dtypes._float4_e2m1fn_dtype] = ser_flatbuf.DType.f4_e2m1fn
 _dtype_kind_to_dtype = {
     kind: dtype for dtype, kind in _dtype_to_dtype_kind.items()
 }
@@ -185,24 +185,14 @@ _dtype_to_ir_type : dict[np.dtype, Callable[[], ir.Type]] = {
     np.dtype(np.float64): ir.F64Type.get,
     np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
     np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
+    np.dtype(dtypes.int2): partial(ir.IntegerType.get_signless, 2),
+    np.dtype(dtypes.uint2): partial(ir.IntegerType.get_unsigned, 2),
+    np.dtype(dtypes.float8_e3m4): ir.Float8E3M4Type.get,
+    np.dtype(dtypes.float8_e4m3): ir.Float8E4M3Type.get,
+    np.dtype(dtypes.float8_e8m0fnu): ir.Float8E8M0FNUType.get,
+    np.dtype(dtypes.float4_e2m1fn): ir.Float4E2M1FNType.get,
 }

-
-if dtypes.int2 is not None:
-  assert dtypes.uint2 is not None
-  _dtype_to_ir_type[np.dtype(dtypes.int2)] = partial(ir.IntegerType.get_signless, 2)
-  _dtype_to_ir_type[np.dtype(dtypes.uint2)] = partial(ir.IntegerType.get_unsigned, 2)
-
-if dtypes.float8_e3m4 is not None:
-  _dtype_to_ir_type[np.dtype(dtypes.float8_e3m4)] = ir.Float8E3M4Type.get
-if dtypes.float8_e4m3 is not None:
-  _dtype_to_ir_type[np.dtype(dtypes.float8_e4m3)] = ir.Float8E4M3Type.get
-if dtypes.float8_e8m0fnu is not None:
-  _dtype_to_ir_type[np.dtype(dtypes.float8_e8m0fnu)] = ir.Float8E8M0FNUType.get
-
-if dtypes.float4_e2m1fn is not None:
-  _dtype_to_ir_type[np.dtype(dtypes.float4_e2m1fn)] = ir.Float4E2M1FNType.get

 def dtype_to_ir_type(dtype: core.bint | np.dtype | np.generic) -> ir.Type:
   if isinstance(dtype, core.bint):
     # TODO Support different-size underlying dtypes to take advantage of the
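The table above maps np.dtype keys to zero-argument callables so MLIR types are only constructed inside an active IR context. A toy sketch of the same lookup pattern; the names below are illustrative stand-ins, not JAX internals:

from functools import partial
import ml_dtypes
import numpy as np

def int_type(width: int, signed: bool = True) -> str:
  # Stand-in for ir.IntegerType.get_signless/get_unsigned: construction is
  # deferred until the callable is invoked.
  return f"{'s' if signed else 'u'}i{width}"

dtype_to_type = {
    np.dtype(ml_dtypes.int2): partial(int_type, 2),
    np.dtype(ml_dtypes.uint2): partial(int_type, 2, signed=False),
    np.dtype(np.float32): lambda: "f32",
}
print(dtype_to_type[np.dtype(ml_dtypes.int2)]())  # si2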
@@ -2346,13 +2346,10 @@ class DotAlgorithmPreset(enum.Enum):
         np.dtype(dtypes.float8_e4m3fnuz),
         np.dtype(dtypes.float8_e5m2),
         np.dtype(dtypes.float8_e5m2fnuz),
+        np.dtype(dtypes.float8_e3m4),
+        np.dtype(dtypes.float8_e4m3),
+        np.dtype(dtypes.float8_e8m0fnu),
       ]
-      if dtypes.float8_e3m4 is not None:
-        fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
-      if dtypes.float8_e4m3 is not None:
-        fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
-      if dtypes.float8_e8m0fnu is not None:
-        fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)]
       if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes:
         raise ValueError(
             f"The dot algorithm '{self}' requires both inputs to have float8 "
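The preset above now lists the whole fp8 family unconditionally before validating the operands. A hedged sketch of that membership check, rebuilt with plain ml_dtypes names rather than the JAX internals:

import numpy as np
import ml_dtypes

fp8_dtypes = [np.dtype(t) for t in (
    ml_dtypes.float8_e4m3b11fnuz, ml_dtypes.float8_e4m3fn,
    ml_dtypes.float8_e4m3fnuz, ml_dtypes.float8_e5m2,
    ml_dtypes.float8_e5m2fnuz, ml_dtypes.float8_e3m4,
    ml_dtypes.float8_e4m3, ml_dtypes.float8_e8m0fnu,
)]
lhs, rhs = np.dtype(ml_dtypes.float8_e4m3fn), np.dtype(np.float16)
if lhs not in fp8_dtypes or rhs not in fp8_dtypes:
  print("both operands must be fp8 for this algorithm")  # taken for this pair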
@@ -5602,13 +5599,9 @@ def accuracy_attr(accuracy) -> hlo.ResultAccuracyAttr:
 def _handle_dot_precision(ctx, lhs, rhs, precision, platform):
   def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
     fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
-                  dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
-    if dtypes.float8_e3m4 is not None:
-      fp8_dtypes += (dtypes.float8_e3m4,)
-    if dtypes.float8_e4m3 is not None:
-      fp8_dtypes += (dtypes.float8_e4m3,)
-    if dtypes.float8_e8m0fnu is not None:
-      fp8_dtypes += (dtypes.float8_e8m0fnu,)
+                  dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz,
+                  dtypes.float8_e3m4, dtypes.float8_e4m3,
+                  dtypes.float8_e8m0fnu)
     return _lhs_dtypes in fp8_dtypes and _rhs_dtypes in fp8_dtypes

   # The *_ lets us reuse this for ragged_dot_general, which has group_sizes.
@@ -68,33 +68,27 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
   return meta

 bool_ = _make_scalar_type(np.bool_)
-if dtypes.uint2 is not None:
-  uint2 = _make_scalar_type(dtypes.uint2)
+uint2 = _make_scalar_type(dtypes.uint2)
 uint4 = _make_scalar_type(dtypes.uint4)
 uint8 = _make_scalar_type(np.uint8)
 uint16 = _make_scalar_type(np.uint16)
 uint32 = _make_scalar_type(np.uint32)
 uint64 = _make_scalar_type(np.uint64)
-if dtypes.int2 is not None:
-  int2 = _make_scalar_type(dtypes.int2)
+int2 = _make_scalar_type(dtypes.int2)
 int4 = _make_scalar_type(dtypes.int4)
 int8 = _make_scalar_type(np.int8)
 int16 = _make_scalar_type(np.int16)
 int32 = _make_scalar_type(np.int32)
 int64 = _make_scalar_type(np.int64)
-if dtypes.float8_e3m4 is not None:
-  float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4)
-if dtypes.float8_e4m3 is not None:
-  float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3)
-if dtypes.float8_e8m0fnu is not None:
-  float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu)
+float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn)
+float8_e3m4 = _make_scalar_type(dtypes.float8_e3m4)
+float8_e4m3 = _make_scalar_type(dtypes.float8_e4m3)
+float8_e8m0fnu = _make_scalar_type(dtypes.float8_e8m0fnu)
 float8_e4m3fn = _make_scalar_type(dtypes.float8_e4m3fn)
 float8_e4m3fnuz = _make_scalar_type(dtypes.float8_e4m3fnuz)
 float8_e5m2 = _make_scalar_type(dtypes.float8_e5m2)
 float8_e5m2fnuz = _make_scalar_type(dtypes.float8_e5m2fnuz)
 float8_e4m3b11fnuz = _make_scalar_type(dtypes.float8_e4m3b11fnuz)
-if dtypes.float4_e2m1fn is not None:
-  float4_e2m1fn = _make_scalar_type(dtypes.float4_e2m1fn)
 bfloat16 = _make_scalar_type(dtypes.bfloat16)
 float16 = _make_scalar_type(np.float16)
 float32 = single = _make_scalar_type(np.float32)
@@ -46,16 +46,22 @@ ToleranceDict: TypeAlias = dict[np.dtype, int | float]
 _default_tolerance: ToleranceDict = {
     _dtypes.float0: 0,
     np.dtype(np.bool_): 0,
+    np.dtype(_dtypes.int2): 0,
     np.dtype(_dtypes.int4): 0,
     np.dtype(np.int8): 0,
     np.dtype(np.int16): 0,
     np.dtype(np.int32): 0,
     np.dtype(np.int64): 0,
+    np.dtype(_dtypes.uint2): 0,
     np.dtype(_dtypes.uint4): 0,
     np.dtype(np.uint8): 0,
     np.dtype(np.uint16): 0,
     np.dtype(np.uint32): 0,
     np.dtype(np.uint64): 0,
+    np.dtype(_dtypes.float4_e2m1fn): 1e0,
+    np.dtype(_dtypes.float8_e3m4): 1e-1,
+    np.dtype(_dtypes.float8_e4m3): 1e-1,
+    np.dtype(_dtypes.float8_e8m0fnu): 1e0,
     np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1,
     np.dtype(_dtypes.float8_e4m3fn): 1e-1,
     np.dtype(_dtypes.float8_e4m3fnuz): 1e-1,
@@ -69,16 +75,15 @@ _default_tolerance: ToleranceDict = {
     np.dtype(np.complex128): 1e-15,
 }

-if _dtypes.int2 is not None:
-  assert _dtypes.uint2 is not None
-  _default_tolerance[np.dtype(_dtypes.int2)] = 0
-  _default_tolerance[np.dtype(_dtypes.uint2)] = 0
-
 def default_tolerance():
   return _default_tolerance


 default_gradient_tolerance: ToleranceDict = {
+    np.dtype(_dtypes.float4_e2m1fn): 1e0,
+    np.dtype(_dtypes.float8_e3m4): 1e-1,
+    np.dtype(_dtypes.float8_e4m3): 1e-1,
+    np.dtype(_dtypes.float8_e8m0fnu): 1e0,
     np.dtype(_dtypes.float8_e4m3b11fnuz): 1e-1,
     np.dtype(_dtypes.float8_e4m3fn): 1e-1,
     np.dtype(_dtypes.float8_e4m3fnuz): 1e-1,
@@ -92,19 +97,6 @@ default_gradient_tolerance: ToleranceDict = {
     np.dtype(np.complex128): 1e-5,
 }

-# TODO: make this unconditional when ml_dtypes>=0.5.0 is required
-if _dtypes.float8_e3m4 is not None:
-  _default_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1
-  default_gradient_tolerance[np.dtype(_dtypes.float8_e3m4)] = 1e-1
-if _dtypes.float8_e4m3 is not None:
-  _default_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
-  default_gradient_tolerance[np.dtype(_dtypes.float8_e4m3)] = 1e-1
-if _dtypes.float8_e8m0fnu is not None:
-  _default_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
-  default_gradient_tolerance[np.dtype(_dtypes.float8_e8m0fnu)] = 1e0
-if _dtypes.float4_e2m1fn is not None:
-  _default_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0
-  default_gradient_tolerance[np.dtype(_dtypes.float4_e2m1fn)] = 1e0

 def is_python_scalar(val: Any) -> bool:
   return not isinstance(val, np.generic) and isinstance(val, (bool, int, float, complex))
@@ -115,6 +107,10 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
     return

   custom_float_dtypes = [
+      _dtypes.float4_e2m1fn,
+      _dtypes.float8_e8m0fnu,
+      _dtypes.float8_e3m4,
+      _dtypes.float8_e4m3,
       _dtypes.float8_e4m3b11fnuz,
       _dtypes.float8_e4m3fn,
       _dtypes.float8_e4m3fnuz,
@@ -123,15 +119,6 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
       _dtypes.bfloat16,
   ]

-  if _dtypes.float8_e4m3 is not None:
-    custom_float_dtypes.insert(0, _dtypes.float8_e4m3)
-  if _dtypes.float8_e3m4 is not None:
-    custom_float_dtypes.insert(0, _dtypes.float8_e3m4)
-  if _dtypes.float8_e8m0fnu is not None:
-    custom_float_dtypes.insert(0, _dtypes.float8_e8m0fnu)
-  if _dtypes.float4_e2m1fn is not None:
-    custom_float_dtypes.insert(0, _dtypes.float4_e2m1fn)
-
   def maybe_upcast(x):
     if x.dtype in custom_float_dtypes:
       return x.astype(np.float32)
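The helper above widens low-precision results before comparing them. A hedged sketch of the same idea outside the test utility (the tolerance value is illustrative):

import numpy as np
import ml_dtypes

def upcast_if_custom(x: np.ndarray) -> np.ndarray:
  # Mirrors the maybe_upcast idea: compare fp8/bfloat16 results in float32.
  if x.dtype in (np.dtype(ml_dtypes.float8_e4m3fn), np.dtype(ml_dtypes.bfloat16)):
    return x.astype(np.float32)
  return x

a = np.array([0.3, 0.7], dtype=ml_dtypes.float8_e4m3fn)
b = np.array([0.3, 0.7], dtype=np.float32)
np.testing.assert_allclose(upcast_if_custom(a), b, atol=1e-1)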
@@ -1632,15 +1632,11 @@ class _LazyDtypes:
         _dtypes.float8_e4m3fnuz,
         _dtypes.float8_e5m2,
         _dtypes.float8_e5m2fnuz,
+        _dtypes.float8_e3m4,
+        _dtypes.float8_e4m3,
+        _dtypes.float8_e8m0fnu,
+        _dtypes.float4_e2m1fn,
     ]
-    if _dtypes.float8_e3m4 is not None:
-      float_dtypes += [_dtypes.float8_e3m4]
-    if _dtypes.float8_e4m3 is not None:
-      float_dtypes += [_dtypes.float8_e4m3]
-    if _dtypes.float8_e8m0fnu is not None:
-      float_dtypes += [_dtypes.float8_e8m0fnu]
-    if _dtypes.float4_e2m1fn is not None:
-      float_dtypes += [_dtypes.float4_e2m1fn]
     return self.supported(float_dtypes)

   @_cached_property
|
||||
double as double,
|
||||
float16 as float16,
|
||||
float32 as float32,
|
||||
float4_e2m1fn as float4_e2m1fn,
|
||||
float64 as float64,
|
||||
float8_e3m4 as float8_e3m4,
|
||||
float8_e4m3 as float8_e4m3,
|
||||
float8_e4m3b11fnuz as float8_e4m3b11fnuz,
|
||||
float8_e4m3fn as float8_e4m3fn,
|
||||
float8_e4m3fnuz as float8_e4m3fnuz,
|
||||
float8_e5m2 as float8_e5m2,
|
||||
float8_e5m2fnuz as float8_e5m2fnuz,
|
||||
float8_e8m0fnu as float8_e8m0fnu,
|
||||
float_ as float_,
|
||||
int2 as int2,
|
||||
int4 as int4,
|
||||
int8 as int8,
|
||||
int16 as int16,
|
||||
@@ -226,6 +231,7 @@ from jax._src.numpy.scalar_types import (
     int_ as int_,
     single as single,
     uint as uint,
+    uint2 as uint2,
     uint4 as uint4,
     uint8 as uint8,
     uint16 as uint16,
@@ -295,26 +301,6 @@ from numpy import (
     unsignedinteger as unsignedinteger,
 )

-# TODO(slebedev): Remove the try-except once we upgrade to ml_dtypes 0.4.1.
-try:
-  from jax._src.numpy.scalar_types import (
-    int2 as int2,
-    uint2 as uint2,
-  )
-except ImportError:
-  pass
-
-# TODO: Remove the try-except once we upgrade to ml_dtypes 0.5.0
-try:
-  from jax._src.numpy.scalar_types import (
-    float8_e3m4 as float8_e3m4,
-    float8_e4m3 as float8_e4m3,
-    float8_e8m0fnu as float8_e8m0fnu,
-    float4_e2m1fn as float4_e2m1fn,
-  )
-except ImportError:
-  pass
-
 from jax._src.numpy.array_api_metadata import (
     __array_api_version__ as __array_api_version__,
     __array_namespace_info__ as __array_namespace_info__,
@@ -240,16 +240,12 @@ def parse_shape_str(s):

 _DT = {
     'pred': jnp.bool_,
-    'u4': jnp.uint4, 'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64,
-    's4': jnp.int4, 's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64,
+    'u2': jnp.uint2, 'u4': jnp.uint4, 'u8': jnp.uint8, 'u16': jnp.uint16, 'u32': jnp.uint32, 'u64': jnp.uint64,
+    's2': jnp.int2, 's4': jnp.int4, 's8': jnp.int8, 's16': jnp.int16, 's32': jnp.int32, 's64': jnp.int64,
     'bf16': jnp.bfloat16,
     'f16': jnp.float16, 'f32': jnp.float32, 'f64': jnp.float64,
     'c64': jnp.complex64, 'c128': jnp.complex128
 }
-if hasattr(jnp, 'int2'):
-  _DT['s2'] = jnp.int2
-if hasattr(jnp, 'uint2'):
-  _DT['u2'] = jnp.uint2

 _SHAPE_RE = re.compile(f"^({'|'.join(_DT)})\\[\\s*(\\d*[\\s*,\\d+]*)\\s*\\]$")
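The _SHAPE_RE pattern above is built from the _DT keys, which now include 's2' and 'u2'. A small standalone sketch (using a subset of the keys) of how a shape string such as 's2[1,2]' is parsed:

import re

keys = ['pred', 's2', 'u2', 'f32']  # subset of the _DT keys, for illustration
shape_re = re.compile(f"^({'|'.join(keys)})\\[\\s*(\\d*[\\s*,\\d+]*)\\s*\\]$")

m = shape_re.match('s2[1,2]')
dtype_token, dims = m.group(1), m.group(2)
print(dtype_token)                             # s2
print([int(d) for d in dims.split(',') if d])  # [1, 2]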
@@ -238,13 +238,10 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
 PrimitiveType = _xla.PrimitiveType

 bfloat16 = ml_dtypes.bfloat16
-# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.
-# Also, it would be better to conditionally import these based on whether they
-# are in the current version of ml_dtypes.
-# float4_e2m1fn = ml_dtypes.float4_e2m1fn
-# float8_e3m4 = ml_dtypes.float8_e3m4
-# float8_e4m3 = ml_dtypes.float8_e4m3
-# float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
+float4_e2m1fn = ml_dtypes.float4_e2m1fn
+float8_e3m4 = ml_dtypes.float8_e3m4
+float8_e4m3 = ml_dtypes.float8_e4m3
+float8_e8m0fnu = ml_dtypes.float8_e8m0fnu
 float8_e4m3fn = ml_dtypes.float8_e4m3fn
 float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz
 float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz
@@ -46,30 +46,19 @@ np_unsigned_dtypes = [np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'),
                       np.dtype('uint64')]
 unsigned_dtypes = list(np_unsigned_dtypes)

-intn_dtypes = [np.dtype('int4'), np.dtype('uint4')]
-signed_dtypes += [np.dtype('int4')]
-unsigned_dtypes += [np.dtype('uint4')]
-if dtypes.int2 is not None:
-  assert dtypes.uint2 is not None
-  intn_dtypes[:0] = [np.dtype('int2'), np.dtype('uint2')]
-  signed_dtypes[:0] = [np.dtype('int2')]
-  unsigned_dtypes[:0] = [np.dtype('uint2')]
+intn_dtypes = [np.dtype('int2'), np.dtype('uint2'), np.dtype('int4'), np.dtype('uint4')]
+signed_dtypes += [np.dtype('int2'), np.dtype('int4')]
+unsigned_dtypes += [np.dtype('uint2'), np.dtype('uint4')]

-np_float_dtypes = [np.dtype('float16'), np.dtype('float32'),
-                   np.dtype('float64')]
+np_float_dtypes = [np.dtype('float16'), np.dtype('float32'), np.dtype('float64')]

 float_dtypes = [np.dtype(dtypes.bfloat16)] + np_float_dtypes
 custom_float_dtypes = [np.dtype(dtypes.bfloat16)]

 fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz), np.dtype(dtypes.float8_e4m3fn),
               np.dtype(dtypes.float8_e4m3fnuz), np.dtype(dtypes.float8_e5m2),
-              np.dtype(dtypes.float8_e5m2fnuz)]
-if dtypes.float8_e3m4 is not None:
-  fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
-if dtypes.float8_e4m3 is not None:
-  fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
-if dtypes.float8_e8m0fnu is not None:
-  fp8_dtypes += [np.dtype(dtypes.float8_e8m0fnu)]
+              np.dtype(dtypes.float8_e5m2fnuz), np.dtype(dtypes.float8_e3m4),
+              np.dtype(dtypes.float8_e4m3), np.dtype(dtypes.float8_e8m0fnu)]
 float_dtypes += fp8_dtypes
 custom_float_dtypes += fp8_dtypes
@@ -114,15 +114,13 @@ class JaxToIRTest(absltest.TestCase):
     self.assertParsedShape('f32[]', [], jnp.float32)
     self.assertParsedShape('f32[1,2,3]', [1, 2, 3], jnp.float32)
    self.assertParsedShape('pred[1]', [1], jnp.bool_)
-    if hasattr(jnp, 'int2'):
-      self.assertParsedShape('s2[1]', [1], jnp.int2)
+    self.assertParsedShape('s2[1]', [1], jnp.int2)
     self.assertParsedShape('s4[1]', [1], jnp.int4)
     self.assertParsedShape('s8[1]', [1], jnp.int8)
     self.assertParsedShape('s16[1]', [1], jnp.int16)
     self.assertParsedShape('s32[1]', [1], jnp.int32)
     self.assertParsedShape('s64[1]', [1], jnp.int64)
-    if hasattr(jnp, 'uint2'):
-      self.assertParsedShape('u2[1]', [1], jnp.uint2)
+    self.assertParsedShape('u2[1]', [1], jnp.uint2)
     self.assertParsedShape('u4[1]', [1], jnp.uint4)
     self.assertParsedShape('u8[1]', [1], jnp.uint8)
     self.assertParsedShape('u16[1]', [1], jnp.uint16)