Change case of typing.Dtype -> typing.DType

This follows the convention used in numpy.typing.DType.
This commit is contained in:
Jake VanderPlas 2022-09-14 15:03:55 -07:00
parent 1338864c1f
commit 5829c6ae9d
4 changed files with 41 additions and 41 deletions

View File

@ -165,11 +165,11 @@ JAX type annotations should in general indicate the **intent** of APIs, rather t
Inputs to JAX functions and methods should be typed as permissively as is reasonable: for example, while shapes are typically tuples, functions that accept a shape should accept arbitrary sequences. Similarly, functions that accept a dtype need not require an instance of class `np.dtype`, but rather any dtype-convertible object. This might include strings, built-in scalar types, or scalar object constructors such as `np.float64` and `jnp.float64`. In order to make this as uniform as possible across the package, we will add a {mod}`jax.typing` module with common type specifications, starting with broad categories such as:
- `ArrayLike` would be a union of anything that can be implicitly converted into an array: for example, jax arrays, numpy arrays, JAX tracers, and python or numpy scalars
- `DtypeLike` would be a union of anything that can be implicitly converted into a dtype: for example, numpy dtypes, numpy dtype objects, jax dtype objects, strings, and built-in types.
- `DTypeLike` would be a union of anything that can be implicitly converted into a dtype: for example, numpy dtypes, numpy dtype objects, jax dtype objects, strings, and built-in types.
- `ShapeLike` would be a union of anything that could be converted into a shape: for example, sequences of integer or integer-like objecs.
- etc.
Note that these will in general be simpler than the equivalent protocols used in {mod}`numpy.typing`. For example, in the case of `DtypeLike`, JAX does not support structured dtypes, so JAX can use a simpler implementation. Similarly, in `ArrayLike`, JAX generally does not support list or tuple inputs in place of arrays, so the type definition will be simpler than the NumPy analog.
Note that these will in general be simpler than the equivalent protocols used in {mod}`numpy.typing`. For example, in the case of `DTypeLike`, JAX does not support structured dtypes, so JAX can use a simpler implementation. Similarly, in `ArrayLike`, JAX generally does not support list or tuple inputs in place of arrays, so the type definition will be simpler than the NumPy analog.
#### Outputs should be strictly-typed
@ -313,7 +313,7 @@ To move forward with type annotations, we will do the following:
- Alias `Array = Any` for the time being, as this will take a bit more thought.
- `ArrayLike`: a Union of types valid as inputs to normal `jax.numpy` functions
- `Dtype` / `DtypeLike` (Check on capitalization of the `t`: what do other projects use?)
- `DType` / `DTypeLike` (Note: numpy uses camel-cased `DType`; we should follow this convention for ease of use)
- `Shape` / `NamedShape` / `ShapeLike`
The beginnings of this are done in {jax-issue}`#12300`.

View File

@ -66,7 +66,7 @@ from jax._src.lax.utils import (
standard_translate,
)
from jax._src.lax import slicing
from jax._src.typing import Array, ArrayLike, DtypeLike, Shape
from jax._src.typing import Array, ArrayLike, DTypeLike, Shape
xb = xla_bridge
xc = xla_client
@ -532,7 +532,7 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array:
r"""Elementwise less-than: :math:`x < y`."""
return lt_p.bind(x, y)
def convert_element_type(operand: ArrayLike, new_dtype: DtypeLike) -> Array:
def convert_element_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
"""Elementwise cast.
Wraps XLA's `ConvertElementType
@ -551,7 +551,7 @@ def convert_element_type(operand: ArrayLike, new_dtype: DtypeLike) -> Array:
operand = operand.__jax_array__() # type: ignore
return _convert_element_type(operand, new_dtype, weak_type=False)
def _convert_element_type(operand: ArrayLike, new_dtype: Optional[DtypeLike] = None,
def _convert_element_type(operand: ArrayLike, new_dtype: Optional[DTypeLike] = None,
weak_type: bool = False):
# Don't canonicalize old_dtype because x64 context might cause
# un-canonicalized operands to be passed in.
@ -585,7 +585,7 @@ def _convert_element_type(operand: ArrayLike, new_dtype: Optional[DtypeLike] = N
return convert_element_type_p.bind(operand, new_dtype=new_dtype,
weak_type=new_weak_type)
def bitcast_convert_type(operand: ArrayLike, new_dtype: DtypeLike) -> Array:
def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
"""Elementwise bitcast.
Wraps XLA's `BitcastConvertType
@ -693,7 +693,7 @@ PrecisionLike = Union[None, str, PrecisionType, Tuple[str, str],
Tuple[PrecisionType, PrecisionType]]
def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
preferred_element_type: Optional[DtypeLike] = None) -> Array:
preferred_element_type: Optional[DTypeLike] = None) -> Array:
"""Vector/vector, matrix/vector, and matrix/matrix multiplication.
Wraps XLA's `Dot
@ -730,7 +730,7 @@ DotDimensionNumbers = Tuple[Tuple[Sequence[int], Sequence[int]],
def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
precision: PrecisionLike = None,
preferred_element_type: Optional[DtypeLike] = None) -> Array:
preferred_element_type: Optional[DTypeLike] = None) -> Array:
"""More general contraction operator.
Wraps XLA's `DotGeneral
@ -953,13 +953,13 @@ def transpose(operand: ArrayLike, permutation: Sequence[int]) -> Array:
return transpose_p.bind(operand, permutation=permutation)
def argmin(operand: ArrayLike, axis: int,
index_dtype: DtypeLike) -> Tuple[Array, Array]:
index_dtype: DTypeLike) -> Tuple[Array, Array]:
"""Computes the index of the minimum element along ``axis``."""
return argmin_p.bind(operand, axes=(axis,),
index_dtype=dtypes.canonicalize_dtype(index_dtype))
def argmax(operand: ArrayLike, axis: int,
index_dtype: DtypeLike) -> Tuple[Array, Array]:
index_dtype: DTypeLike) -> Tuple[Array, Array]:
"""Computes the index of the maximum element along ``axis``."""
return argmax_p.bind(operand, axes=(axis,),
index_dtype=dtypes.canonicalize_dtype(index_dtype))
@ -1058,13 +1058,13 @@ def _get_monoid_reducer(monoid_op: Callable,
return _reduce_min if np.equal(aval.val, _get_min_identity(dtype)) else None
return None
def _get_bitwise_and_identity(dtype: DtypeLike) -> np.ndarray:
def _get_bitwise_and_identity(dtype: DTypeLike) -> np.ndarray:
return np.array(-1, dtype)
def _get_bitwise_or_identity(dtype: DtypeLike) -> np.ndarray:
def _get_bitwise_or_identity(dtype: DTypeLike) -> np.ndarray:
return np.array(0, dtype)
def _get_max_identity(dtype: DtypeLike) -> np.ndarray:
def _get_max_identity(dtype: DTypeLike) -> np.ndarray:
if dtypes.issubdtype(dtype, np.inexact):
return np.array(-np.inf, dtype)
elif dtypes.issubdtype(dtype, np.integer):
@ -1074,7 +1074,7 @@ def _get_max_identity(dtype: DtypeLike) -> np.ndarray:
else:
raise ValueError(f"Unsupported dtype for max: {dtype}")
def _get_min_identity(dtype: DtypeLike) -> np.ndarray:
def _get_min_identity(dtype: DTypeLike) -> np.ndarray:
if dtypes.issubdtype(dtype, np.inexact):
return np.array(np.inf, dtype)
elif dtypes.issubdtype(dtype, np.integer):
@ -1159,7 +1159,7 @@ def tie_in(x: Any, y: T) -> T:
"""Deprecated. Ignores ``x`` and returns ``y``."""
return y
def full(shape: Shape, fill_value: ArrayLike, dtype: Optional[DtypeLike] = None) -> Array:
def full(shape: Shape, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None) -> Array:
"""Returns an array of `shape` filled with `fill_value`.
Args:
@ -1187,14 +1187,14 @@ def zeros_like_shaped_array(aval: ArrayLike) -> Array:
ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
def iota(dtype: DtypeLike, size: int) -> Array:
def iota(dtype: DTypeLike, size: int) -> Array:
"""Wraps XLA's `Iota
<https://www.tensorflow.org/xla/operation_semantics#iota>`_
operator.
"""
return broadcasted_iota(dtype, (size,), 0)
def broadcasted_iota(dtype: DtypeLike, shape: Shape, dimension: int) -> Array:
def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int) -> Array:
"""Convenience wrapper around ``iota``."""
dtype = dtypes.canonicalize_dtype(dtype)
shape = canonicalize_shape(shape)
@ -1205,7 +1205,7 @@ def broadcasted_iota(dtype: DtypeLike, shape: Shape, dimension: int) -> Array:
return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
dimension=dimension)
def _eye(dtype: DtypeLike, shape: Shape, offset: int) -> Array:
def _eye(dtype: DTypeLike, shape: Shape, offset: int) -> Array:
"""Like numpy.eye, create a 2D array with ones on a diagonal."""
offset = int(offset)
dtype = dtypes.canonicalize_dtype(dtype)
@ -1213,7 +1213,7 @@ def _eye(dtype: DtypeLike, shape: Shape, offset: int) -> Array:
broadcasted_iota(np.int32, shape, 1))
return convert_element_type_p.bind(bool_eye, new_dtype=dtype, weak_type=False)
def _delta(dtype: DtypeLike, shape: Shape, axes: Sequence[int]) -> Array:
def _delta(dtype: DTypeLike, shape: Shape, axes: Sequence[int]) -> Array:
"""This utility function exists for creating Kronecker delta arrays."""
axes = map(int, axes)
dtype = dtypes.canonicalize_dtype(dtype)
@ -1225,7 +1225,7 @@ def _delta(dtype: DtypeLike, shape: Shape, axes: Sequence[int]) -> Array:
new_dtype=dtype, weak_type=False)
return broadcast_in_dim(result, shape, axes)
def _tri(dtype: DtypeLike, shape: Shape, offset: int) -> Array:
def _tri(dtype: DTypeLike, shape: Shape, offset: int) -> Array:
"""Like numpy.tri, create a 2D array with ones below a diagonal."""
offset = int(offset)
dtype = dtypes.canonicalize_dtype(dtype)
@ -1303,7 +1303,7 @@ def expand_dims(array: ArrayLike, dimensions: Sequence[int]) -> Array:
### convenience wrappers around traceables
def full_like(x: Array, fill_value: ArrayLike, dtype: Optional[DtypeLike] = None,
def full_like(x: Array, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None,
shape: Optional[Shape] = None) -> Array:
"""Create a full array like np.full based on the example array `x`.
@ -2466,7 +2466,7 @@ def _masked(padded_value, logical_shape, dimensions, value=0):
def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
preferred_element_type: Optional[DtypeLike]):
preferred_element_type: Optional[DTypeLike]):
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
for d in (lhs_contracting, lhs_batch)):
@ -2542,7 +2542,7 @@ def tuple_delete(tup, idx):
def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
preferred_element_type: Optional[DtypeLike]):
preferred_element_type: Optional[DTypeLike]):
input_dtype = naryop_dtype_rule(_input_dtype, [_any, _any], 'dot_general', lhs, rhs)
if preferred_element_type is None:
return input_dtype
@ -2550,7 +2550,7 @@ def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
return preferred_element_type
def _dot_general_transpose_lhs(g, y, *, dimension_numbers, precision,
preferred_element_type: Optional[DtypeLike],
preferred_element_type: Optional[DTypeLike],
swap_ans=False):
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
x_ndim = g.ndim - y.ndim + len(x_batch) + 2 * len(x_contract)
@ -2567,7 +2567,7 @@ def _dot_general_transpose_lhs(g, y, *, dimension_numbers, precision,
tuple(out_axes))
def _dot_general_transpose_rhs(g, x, *, dimension_numbers, precision,
preferred_element_type: Optional[DtypeLike]):
preferred_element_type: Optional[DTypeLike]):
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
return _dot_general_transpose_lhs(
@ -2578,7 +2578,7 @@ def _dot_general_transpose_rhs(g, x, *, dimension_numbers, precision,
def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
precision,
preferred_element_type: Optional[DtypeLike]):
preferred_element_type: Optional[DTypeLike]):
lhs, rhs = batched_args
new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
(lhs.ndim, rhs.ndim), batch_dims, dimension_numbers)

View File

@ -30,18 +30,18 @@ from typing import Any, Sequence, Union
from typing_extensions import Protocol
import numpy as np
class HasDtypeAttribute(Protocol):
dtype: Dtype
class HasDTypeAttribute(Protocol):
dtype: DType
Dtype = np.dtype
DType = np.dtype
# DtypeLike is meant to annotate inputs to np.dtype that return
# DTypeLike is meant to annotate inputs to np.dtype that return
# a valid JAX dtype. It's different than numpy.typing.DTypeLike
# because JAX doesn't support objects or structured dtypes.
# It does not include JAX dtype extensions such as KeyType and others.
# For now, we use Any to allow scalar types like np.int32 & jnp.int32.
# TODO(jakevdp) specify these more strictly.
DtypeLike = Union[Any, str, np.dtype, HasDtypeAttribute]
DTypeLike = Union[Any, str, np.dtype, HasDTypeAttribute]
# Shapes are tuples of dimension sizes, which are normally integers. We allow
# modules to extend the set of dimension sizes to contain other types, e.g.,

View File

@ -29,9 +29,9 @@ from absl.testing import absltest
import numpy as np
# DtypeLike is meant to annotate inputs to np.dtype that return
# DTypeLike is meant to annotate inputs to np.dtype that return
# a valid JAX dtype, so we test with np.dtype.
def dtypelike_to_dtype(x: typing.DtypeLike) -> typing.Dtype:
def dtypelike_to_dtype(x: typing.DTypeLike) -> typing.DType:
return np.dtype(x)
@ -42,7 +42,7 @@ def arraylike_to_array(x: typing.ArrayLike) -> typing.Array:
return lax.convert_element_type(x, np.result_type(x))
class HasDtype:
class HasDType:
dtype: np.dtype
def __init__(self, dt):
self.dtype = np.dtype(dt)
@ -53,20 +53,20 @@ float32_dtype = np.dtype("float32")
# Avoid test parameterization because we want to statically check these annotations.
class TypingTest(jtu.JaxTestCase):
def testDtypeLike(self) -> None:
out1: typing.Dtype = dtypelike_to_dtype("float32")
def testDTypeLike(self) -> None:
out1: typing.DType = dtypelike_to_dtype("float32")
self.assertEqual(out1, float32_dtype)
out2: typing.Dtype = dtypelike_to_dtype(np.float32)
out2: typing.DType = dtypelike_to_dtype(np.float32)
self.assertEqual(out2, float32_dtype)
out3: typing.Dtype = dtypelike_to_dtype(jnp.float32)
out3: typing.DType = dtypelike_to_dtype(jnp.float32)
self.assertEqual(out3, float32_dtype)
out4: typing.Dtype = dtypelike_to_dtype(np.dtype('float32'))
out4: typing.DType = dtypelike_to_dtype(np.dtype('float32'))
self.assertEqual(out4, float32_dtype)
out5: typing.Dtype = dtypelike_to_dtype(HasDtype("float32"))
out5: typing.DType = dtypelike_to_dtype(HasDType("float32"))
self.assertEqual(out5, float32_dtype)
def testArrayLike(self) -> None: