mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Change case of typing.Dtype -> typing.DType
This follows the convention used in numpy.typing.DType.
This commit is contained in:
parent
1338864c1f
commit
5829c6ae9d
@ -165,11 +165,11 @@ JAX type annotations should in general indicate the **intent** of APIs, rather t
|
||||
Inputs to JAX functions and methods should be typed as permissively as is reasonable: for example, while shapes are typically tuples, functions that accept a shape should accept arbitrary sequences. Similarly, functions that accept a dtype need not require an instance of class `np.dtype`, but rather any dtype-convertible object. This might include strings, built-in scalar types, or scalar object constructors such as `np.float64` and `jnp.float64`. In order to make this as uniform as possible across the package, we will add a {mod}`jax.typing` module with common type specifications, starting with broad categories such as:
|
||||
|
||||
- `ArrayLike` would be a union of anything that can be implicitly converted into an array: for example, jax arrays, numpy arrays, JAX tracers, and python or numpy scalars
|
||||
- `DtypeLike` would be a union of anything that can be implicitly converted into a dtype: for example, numpy dtypes, numpy dtype objects, jax dtype objects, strings, and built-in types.
|
||||
- `DTypeLike` would be a union of anything that can be implicitly converted into a dtype: for example, numpy dtypes, numpy dtype objects, jax dtype objects, strings, and built-in types.
|
||||
- `ShapeLike` would be a union of anything that could be converted into a shape: for example, sequences of integer or integer-like objecs.
|
||||
- etc.
|
||||
|
||||
Note that these will in general be simpler than the equivalent protocols used in {mod}`numpy.typing`. For example, in the case of `DtypeLike`, JAX does not support structured dtypes, so JAX can use a simpler implementation. Similarly, in `ArrayLike`, JAX generally does not support list or tuple inputs in place of arrays, so the type definition will be simpler than the NumPy analog.
|
||||
Note that these will in general be simpler than the equivalent protocols used in {mod}`numpy.typing`. For example, in the case of `DTypeLike`, JAX does not support structured dtypes, so JAX can use a simpler implementation. Similarly, in `ArrayLike`, JAX generally does not support list or tuple inputs in place of arrays, so the type definition will be simpler than the NumPy analog.
|
||||
|
||||
#### Outputs should be strictly-typed
|
||||
|
||||
@ -313,7 +313,7 @@ To move forward with type annotations, we will do the following:
|
||||
|
||||
- Alias `Array = Any` for the time being, as this will take a bit more thought.
|
||||
- `ArrayLike`: a Union of types valid as inputs to normal `jax.numpy` functions
|
||||
- `Dtype` / `DtypeLike` (Check on capitalization of the `t`: what do other projects use?)
|
||||
- `DType` / `DTypeLike` (Note: numpy uses camel-cased `DType`; we should follow this convention for ease of use)
|
||||
- `Shape` / `NamedShape` / `ShapeLike`
|
||||
|
||||
The beginnings of this are done in {jax-issue}`#12300`.
|
||||
|
@ -66,7 +66,7 @@ from jax._src.lax.utils import (
|
||||
standard_translate,
|
||||
)
|
||||
from jax._src.lax import slicing
|
||||
from jax._src.typing import Array, ArrayLike, DtypeLike, Shape
|
||||
from jax._src.typing import Array, ArrayLike, DTypeLike, Shape
|
||||
|
||||
xb = xla_bridge
|
||||
xc = xla_client
|
||||
@ -532,7 +532,7 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array:
|
||||
r"""Elementwise less-than: :math:`x < y`."""
|
||||
return lt_p.bind(x, y)
|
||||
|
||||
def convert_element_type(operand: ArrayLike, new_dtype: DtypeLike) -> Array:
|
||||
def convert_element_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
|
||||
"""Elementwise cast.
|
||||
|
||||
Wraps XLA's `ConvertElementType
|
||||
@ -551,7 +551,7 @@ def convert_element_type(operand: ArrayLike, new_dtype: DtypeLike) -> Array:
|
||||
operand = operand.__jax_array__() # type: ignore
|
||||
return _convert_element_type(operand, new_dtype, weak_type=False)
|
||||
|
||||
def _convert_element_type(operand: ArrayLike, new_dtype: Optional[DtypeLike] = None,
|
||||
def _convert_element_type(operand: ArrayLike, new_dtype: Optional[DTypeLike] = None,
|
||||
weak_type: bool = False):
|
||||
# Don't canonicalize old_dtype because x64 context might cause
|
||||
# un-canonicalized operands to be passed in.
|
||||
@ -585,7 +585,7 @@ def _convert_element_type(operand: ArrayLike, new_dtype: Optional[DtypeLike] = N
|
||||
return convert_element_type_p.bind(operand, new_dtype=new_dtype,
|
||||
weak_type=new_weak_type)
|
||||
|
||||
def bitcast_convert_type(operand: ArrayLike, new_dtype: DtypeLike) -> Array:
|
||||
def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
|
||||
"""Elementwise bitcast.
|
||||
|
||||
Wraps XLA's `BitcastConvertType
|
||||
@ -693,7 +693,7 @@ PrecisionLike = Union[None, str, PrecisionType, Tuple[str, str],
|
||||
Tuple[PrecisionType, PrecisionType]]
|
||||
|
||||
def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
|
||||
preferred_element_type: Optional[DtypeLike] = None) -> Array:
|
||||
preferred_element_type: Optional[DTypeLike] = None) -> Array:
|
||||
"""Vector/vector, matrix/vector, and matrix/matrix multiplication.
|
||||
|
||||
Wraps XLA's `Dot
|
||||
@ -730,7 +730,7 @@ DotDimensionNumbers = Tuple[Tuple[Sequence[int], Sequence[int]],
|
||||
|
||||
def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
|
||||
precision: PrecisionLike = None,
|
||||
preferred_element_type: Optional[DtypeLike] = None) -> Array:
|
||||
preferred_element_type: Optional[DTypeLike] = None) -> Array:
|
||||
"""More general contraction operator.
|
||||
|
||||
Wraps XLA's `DotGeneral
|
||||
@ -953,13 +953,13 @@ def transpose(operand: ArrayLike, permutation: Sequence[int]) -> Array:
|
||||
return transpose_p.bind(operand, permutation=permutation)
|
||||
|
||||
def argmin(operand: ArrayLike, axis: int,
|
||||
index_dtype: DtypeLike) -> Tuple[Array, Array]:
|
||||
index_dtype: DTypeLike) -> Tuple[Array, Array]:
|
||||
"""Computes the index of the minimum element along ``axis``."""
|
||||
return argmin_p.bind(operand, axes=(axis,),
|
||||
index_dtype=dtypes.canonicalize_dtype(index_dtype))
|
||||
|
||||
def argmax(operand: ArrayLike, axis: int,
|
||||
index_dtype: DtypeLike) -> Tuple[Array, Array]:
|
||||
index_dtype: DTypeLike) -> Tuple[Array, Array]:
|
||||
"""Computes the index of the maximum element along ``axis``."""
|
||||
return argmax_p.bind(operand, axes=(axis,),
|
||||
index_dtype=dtypes.canonicalize_dtype(index_dtype))
|
||||
@ -1058,13 +1058,13 @@ def _get_monoid_reducer(monoid_op: Callable,
|
||||
return _reduce_min if np.equal(aval.val, _get_min_identity(dtype)) else None
|
||||
return None
|
||||
|
||||
def _get_bitwise_and_identity(dtype: DtypeLike) -> np.ndarray:
|
||||
def _get_bitwise_and_identity(dtype: DTypeLike) -> np.ndarray:
|
||||
return np.array(-1, dtype)
|
||||
|
||||
def _get_bitwise_or_identity(dtype: DtypeLike) -> np.ndarray:
|
||||
def _get_bitwise_or_identity(dtype: DTypeLike) -> np.ndarray:
|
||||
return np.array(0, dtype)
|
||||
|
||||
def _get_max_identity(dtype: DtypeLike) -> np.ndarray:
|
||||
def _get_max_identity(dtype: DTypeLike) -> np.ndarray:
|
||||
if dtypes.issubdtype(dtype, np.inexact):
|
||||
return np.array(-np.inf, dtype)
|
||||
elif dtypes.issubdtype(dtype, np.integer):
|
||||
@ -1074,7 +1074,7 @@ def _get_max_identity(dtype: DtypeLike) -> np.ndarray:
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype for max: {dtype}")
|
||||
|
||||
def _get_min_identity(dtype: DtypeLike) -> np.ndarray:
|
||||
def _get_min_identity(dtype: DTypeLike) -> np.ndarray:
|
||||
if dtypes.issubdtype(dtype, np.inexact):
|
||||
return np.array(np.inf, dtype)
|
||||
elif dtypes.issubdtype(dtype, np.integer):
|
||||
@ -1159,7 +1159,7 @@ def tie_in(x: Any, y: T) -> T:
|
||||
"""Deprecated. Ignores ``x`` and returns ``y``."""
|
||||
return y
|
||||
|
||||
def full(shape: Shape, fill_value: ArrayLike, dtype: Optional[DtypeLike] = None) -> Array:
|
||||
def full(shape: Shape, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None) -> Array:
|
||||
"""Returns an array of `shape` filled with `fill_value`.
|
||||
|
||||
Args:
|
||||
@ -1187,14 +1187,14 @@ def zeros_like_shaped_array(aval: ArrayLike) -> Array:
|
||||
|
||||
ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
|
||||
|
||||
def iota(dtype: DtypeLike, size: int) -> Array:
|
||||
def iota(dtype: DTypeLike, size: int) -> Array:
|
||||
"""Wraps XLA's `Iota
|
||||
<https://www.tensorflow.org/xla/operation_semantics#iota>`_
|
||||
operator.
|
||||
"""
|
||||
return broadcasted_iota(dtype, (size,), 0)
|
||||
|
||||
def broadcasted_iota(dtype: DtypeLike, shape: Shape, dimension: int) -> Array:
|
||||
def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int) -> Array:
|
||||
"""Convenience wrapper around ``iota``."""
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
shape = canonicalize_shape(shape)
|
||||
@ -1205,7 +1205,7 @@ def broadcasted_iota(dtype: DtypeLike, shape: Shape, dimension: int) -> Array:
|
||||
return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
|
||||
dimension=dimension)
|
||||
|
||||
def _eye(dtype: DtypeLike, shape: Shape, offset: int) -> Array:
|
||||
def _eye(dtype: DTypeLike, shape: Shape, offset: int) -> Array:
|
||||
"""Like numpy.eye, create a 2D array with ones on a diagonal."""
|
||||
offset = int(offset)
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
@ -1213,7 +1213,7 @@ def _eye(dtype: DtypeLike, shape: Shape, offset: int) -> Array:
|
||||
broadcasted_iota(np.int32, shape, 1))
|
||||
return convert_element_type_p.bind(bool_eye, new_dtype=dtype, weak_type=False)
|
||||
|
||||
def _delta(dtype: DtypeLike, shape: Shape, axes: Sequence[int]) -> Array:
|
||||
def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array:
|
||||
"""This utility function exists for creating Kronecker delta arrays."""
|
||||
axes = map(int, axes)
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
@ -1225,7 +1225,7 @@ def _delta(dtype: DtypeLike, shape: Shape, axes: Sequence[int]) -> Array:
|
||||
new_dtype=dtype, weak_type=False)
|
||||
return broadcast_in_dim(result, shape, axes)
|
||||
|
||||
def _tri(dtype: DtypeLike, shape: Shape, offset: int) -> Array:
|
||||
def _tri(dtype: DTypeLike, shape: Shape, offset: int) -> Array:
|
||||
"""Like numpy.tri, create a 2D array with ones below a diagonal."""
|
||||
offset = int(offset)
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
@ -1303,7 +1303,7 @@ def expand_dims(array: ArrayLike, dimensions: Sequence[int]) -> Array:
|
||||
|
||||
### convenience wrappers around traceables
|
||||
|
||||
def full_like(x: Array, fill_value: ArrayLike, dtype: Optional[DtypeLike] = None,
|
||||
def full_like(x: Array, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None,
|
||||
shape: Optional[Shape] = None) -> Array:
|
||||
"""Create a full array like np.full based on the example array `x`.
|
||||
|
||||
@ -2466,7 +2466,7 @@ def _masked(padded_value, logical_shape, dimensions, value=0):
|
||||
|
||||
|
||||
def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
|
||||
preferred_element_type: Optional[DtypeLike]):
|
||||
preferred_element_type: Optional[DTypeLike]):
|
||||
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
|
||||
for d in (lhs_contracting, lhs_batch)):
|
||||
@ -2542,7 +2542,7 @@ def tuple_delete(tup, idx):
|
||||
|
||||
|
||||
def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
|
||||
preferred_element_type: Optional[DtypeLike]):
|
||||
preferred_element_type: Optional[DTypeLike]):
|
||||
input_dtype = naryop_dtype_rule(_input_dtype, [_any, _any], 'dot_general', lhs, rhs)
|
||||
if preferred_element_type is None:
|
||||
return input_dtype
|
||||
@ -2550,7 +2550,7 @@ def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
|
||||
return preferred_element_type
|
||||
|
||||
def _dot_general_transpose_lhs(g, y, *, dimension_numbers, precision,
|
||||
preferred_element_type: Optional[DtypeLike],
|
||||
preferred_element_type: Optional[DTypeLike],
|
||||
swap_ans=False):
|
||||
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
|
||||
x_ndim = g.ndim - y.ndim + len(x_batch) + 2 * len(x_contract)
|
||||
@ -2567,7 +2567,7 @@ def _dot_general_transpose_lhs(g, y, *, dimension_numbers, precision,
|
||||
tuple(out_axes))
|
||||
|
||||
def _dot_general_transpose_rhs(g, x, *, dimension_numbers, precision,
|
||||
preferred_element_type: Optional[DtypeLike]):
|
||||
preferred_element_type: Optional[DTypeLike]):
|
||||
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
|
||||
swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
|
||||
return _dot_general_transpose_lhs(
|
||||
@ -2578,7 +2578,7 @@ def _dot_general_transpose_rhs(g, x, *, dimension_numbers, precision,
|
||||
|
||||
def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
|
||||
precision,
|
||||
preferred_element_type: Optional[DtypeLike]):
|
||||
preferred_element_type: Optional[DTypeLike]):
|
||||
lhs, rhs = batched_args
|
||||
new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
|
||||
(lhs.ndim, rhs.ndim), batch_dims, dimension_numbers)
|
||||
|
@ -30,18 +30,18 @@ from typing import Any, Sequence, Union
|
||||
from typing_extensions import Protocol
|
||||
import numpy as np
|
||||
|
||||
class HasDtypeAttribute(Protocol):
|
||||
dtype: Dtype
|
||||
class HasDTypeAttribute(Protocol):
|
||||
dtype: DType
|
||||
|
||||
Dtype = np.dtype
|
||||
DType = np.dtype
|
||||
|
||||
# DtypeLike is meant to annotate inputs to np.dtype that return
|
||||
# DTypeLike is meant to annotate inputs to np.dtype that return
|
||||
# a valid JAX dtype. It's different than numpy.typing.DTypeLike
|
||||
# because JAX doesn't support objects or structured dtypes.
|
||||
# It does not include JAX dtype extensions such as KeyType and others.
|
||||
# For now, we use Any to allow scalar types like np.int32 & jnp.int32.
|
||||
# TODO(jakevdp) specify these more strictly.
|
||||
DtypeLike = Union[Any, str, np.dtype, HasDtypeAttribute]
|
||||
DTypeLike = Union[Any, str, np.dtype, HasDTypeAttribute]
|
||||
|
||||
# Shapes are tuples of dimension sizes, which are normally integers. We allow
|
||||
# modules to extend the set of dimension sizes to contain other types, e.g.,
|
||||
|
@ -29,9 +29,9 @@ from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
||||
|
||||
# DtypeLike is meant to annotate inputs to np.dtype that return
|
||||
# DTypeLike is meant to annotate inputs to np.dtype that return
|
||||
# a valid JAX dtype, so we test with np.dtype.
|
||||
def dtypelike_to_dtype(x: typing.DtypeLike) -> typing.Dtype:
|
||||
def dtypelike_to_dtype(x: typing.DTypeLike) -> typing.DType:
|
||||
return np.dtype(x)
|
||||
|
||||
|
||||
@ -42,7 +42,7 @@ def arraylike_to_array(x: typing.ArrayLike) -> typing.Array:
|
||||
return lax.convert_element_type(x, np.result_type(x))
|
||||
|
||||
|
||||
class HasDtype:
|
||||
class HasDType:
|
||||
dtype: np.dtype
|
||||
def __init__(self, dt):
|
||||
self.dtype = np.dtype(dt)
|
||||
@ -53,20 +53,20 @@ float32_dtype = np.dtype("float32")
|
||||
# Avoid test parameterization because we want to statically check these annotations.
|
||||
class TypingTest(jtu.JaxTestCase):
|
||||
|
||||
def testDtypeLike(self) -> None:
|
||||
out1: typing.Dtype = dtypelike_to_dtype("float32")
|
||||
def testDTypeLike(self) -> None:
|
||||
out1: typing.DType = dtypelike_to_dtype("float32")
|
||||
self.assertEqual(out1, float32_dtype)
|
||||
|
||||
out2: typing.Dtype = dtypelike_to_dtype(np.float32)
|
||||
out2: typing.DType = dtypelike_to_dtype(np.float32)
|
||||
self.assertEqual(out2, float32_dtype)
|
||||
|
||||
out3: typing.Dtype = dtypelike_to_dtype(jnp.float32)
|
||||
out3: typing.DType = dtypelike_to_dtype(jnp.float32)
|
||||
self.assertEqual(out3, float32_dtype)
|
||||
|
||||
out4: typing.Dtype = dtypelike_to_dtype(np.dtype('float32'))
|
||||
out4: typing.DType = dtypelike_to_dtype(np.dtype('float32'))
|
||||
self.assertEqual(out4, float32_dtype)
|
||||
|
||||
out5: typing.Dtype = dtypelike_to_dtype(HasDtype("float32"))
|
||||
out5: typing.DType = dtypelike_to_dtype(HasDType("float32"))
|
||||
self.assertEqual(out5, float32_dtype)
|
||||
|
||||
def testArrayLike(self) -> None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user