Merge pull request #21069 from mattjj:remove-named-shapes

PiperOrigin-RevId: 655766534
jax authors committed 2024-07-24 18:20:50 -07:00
commit 086b500da6
24 changed files with 128 additions and 388 deletions

View File

@ -381,9 +381,9 @@ def xla_computation(fun: Callable,
return_shape: Optional boolean, defaults to ``False``. If ``True``, the
wrapped function returns a pair where the first element is the XLA
computation and the second element is a pytree with the same structure as
the output of ``fun`` and where the leaves are objects with ``shape``,
``dtype``, and ``named_shape`` attributes representing the corresponding
types of the output leaves.
the output of ``fun`` and where the leaves are objects with ``shape`` and
``dtype`` attributes representing the corresponding types of the output
leaves.
donate_argnums: Specify which arguments are "donated" to the computation.
It is safe to donate arguments if you no longer need them once the
computation has finished. In some cases XLA can make use of donated
@ -557,8 +557,8 @@ def xla_computation(fun: Callable,
m = mlir.module_to_bytecode(lowering_result.module)
built = xc._xla.mlir.mlir_module_to_xla_computation(
m, use_tuple_args=tuple_args, return_tuple=True)
out_shapes_flat = [
ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
out_shape = tree_unflatten(out_tree(), out_shapes_flat)
for out_aval in out_avals:
if not isinstance(out_aval, ShapedArray):
@ -2337,8 +2337,8 @@ def make_jaxpr(fun: Callable,
wrapped function returns a pair where the first element is the
``ClosedJaxpr`` representation of ``fun`` and the second element is a
pytree with the same structure as the output of ``fun`` and where the
leaves are objects with ``shape``, ``dtype``, and ``named_shape``
attributes representing the corresponding types of the output leaves.
leaves are objects with ``shape`` and ``dtype`` attributes representing
the corresponding types of the output leaves.
Returns:
A wrapped version of ``fun`` that when applied to example arguments returns
@ -2400,8 +2400,7 @@ def make_jaxpr(fun: Callable,
else:
jaxpr = traced.jaxpr
if return_shape:
out = [ShapeDtypeStruct(o.shape, o.dtype, getattr(o, 'named_shape', None))
for o in jaxpr.out_avals]
out = [ShapeDtypeStruct(o.shape, o.dtype) for o in jaxpr.out_avals]
return jaxpr, tree_unflatten(tree_structure(traced.out_info), out)
return jaxpr
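For reference, a minimal usage sketch (illustrative, not part of the diff) of the updated ``return_shape`` behavior, where the returned leaves now expose only ``shape`` and ``dtype``:

import jax
import jax.numpy as jnp

# make_jaxpr(..., return_shape=True) returns the jaxpr plus a pytree of
# ShapeDtypeStruct leaves carrying only shape and dtype.
jaxpr, out_shape = jax.make_jaxpr(lambda x: jnp.sin(x) * 2, return_shape=True)(
    jnp.ones((3, 4), jnp.float32))
print(out_shape.shape, out_shape.dtype)  # (3, 4) float32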
@ -2691,12 +2690,13 @@ class ShapeDtypeStruct:
Args:
shape: a sequence of integers representing an array shape
dtype: a dtype-like object
named_shape: (optional) a dictionary representing a named shape
sharding: (optional) a :class:`jax.Sharding` object
"""
__slots__ = ["shape", "dtype", "named_shape", "sharding", "_dll"]
__slots__ = ["shape", "dtype", "sharding", "_dll"]
named_shape = {}
def __init__(self, shape, dtype, named_shape=None, sharding=None):
def __init__(self, shape, dtype, sharding=None, named_shape=None):
del named_shape # ignored, vestigial
self.shape = tuple(shape)
if dtype is None:
raise ValueError("ShapeDtypeStruct: dtype must be specified.")
@ -2713,7 +2713,6 @@ class ShapeDtypeStruct:
f" layout in a `ShapeDtypeStruct`. Got {sharding}")
self.sharding = sharding.sharding if isinstance(sharding, Layout) else sharding
self._dll = sharding.device_local_layout if isinstance(sharding, Layout) else None
self.named_shape = {} if named_shape is None else dict(named_shape)
size = property(lambda self: math.prod(self.shape))
ndim = property(lambda self: len(self.shape))
@ -2729,11 +2728,10 @@ class ShapeDtypeStruct:
raise TypeError("len() of unsized object") from e # same as numpy error
def __repr__(self):
ns = f", named_shape={self.named_shape}" if self.named_shape else ""
sh = f", sharding={self.sharding}" if self.sharding is not None else ""
l = f", layout={self.layout}" if self._dll is not None else ""
return (f"{type(self).__name__}(shape={self.shape}, "
f"dtype={self.dtype.name}{ns}{sh}{l})")
f"dtype={self.dtype.name}{sh}{l})")
__str__ = __repr__
@ -2741,19 +2739,17 @@ class ShapeDtypeStruct:
if not isinstance(other, ShapeDtypeStruct):
return False
else:
return ((other.shape, other.dtype, other.named_shape, other.sharding, other.layout) ==
(self.shape, self.dtype, self.named_shape, self.sharding, self.layout))
return ((other.shape, other.dtype, other.sharding, other.layout) ==
(self.shape, self.dtype, self.sharding, self.layout))
def __hash__(self):
# TODO(frostig): avoid the conversion from dict by addressing
# https://github.com/google/jax/issues/8182
named = frozenset(self.named_shape.items())
return hash((self.shape, self.dtype, named, self.sharding, self.layout))
return hash((self.shape, self.dtype, self.sharding, self.layout))
core.pytype_aval_mappings[ShapeDtypeStruct] = (
lambda x: ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
weak_type=False, named_shape=x.named_shape))
weak_type=False))
@api_boundary
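A small sketch of the updated ``ShapeDtypeStruct`` (assuming only the public ``jax.ShapeDtypeStruct`` and ``jax.eval_shape`` APIs): a ``named_shape`` keyword is still accepted but silently ignored, and the struct carries only shape, dtype, and optionally sharding.

import jax
import jax.numpy as jnp

spec = jax.ShapeDtypeStruct((8, 128), jnp.bfloat16)
out = jax.eval_shape(lambda x: x @ x.T, spec)
print(out.shape, out.dtype)  # (8, 8) bfloat16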

View File

@ -570,15 +570,13 @@ def _shaped_abstractify_slow(x):
pass
weak_type = getattr(x, 'weak_type', False)
named_shape = getattr(x, 'named_shape', {})
if hasattr(x, 'dtype'):
dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
else:
raise TypeError(
f"Cannot interpret value of type {type(x)} as an abstract array; it "
"does not have a dtype attribute")
return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
named_shape=named_shape)
return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type)
# TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior
def shaped_abstractify(x):

View File

@ -466,8 +466,7 @@ class ArrayImpl(basearray.Array):
def __reduce__(self):
fun, args, arr_state = self._value.__reduce__()
aval_state = {'weak_type': self.aval.weak_type,
'named_shape': self.aval.named_shape}
aval_state = {'weak_type': self.aval.weak_type}
return (_reconstruct_array, (fun, args, arr_state, aval_state))
@use_cpp_method()
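A quick pickling sketch (standard ``pickle`` round-trip, not from the diff): the reduced aval state above now records only ``weak_type``.

import pickle
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)
y = pickle.loads(pickle.dumps(x))  # round-trips through __reduce__ above
assert (x == y).all() and x.dtype == y.dtype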

View File

@ -1418,9 +1418,6 @@ class AbstractValue:
def strip_weak_type(self) -> AbstractValue:
return self
def strip_named_shape(self) -> AbstractValue:
return self
def join(self, other):
raise NotImplementedError("must override")
@ -1695,6 +1692,8 @@ def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]:
Returns:
A tuple of canonical dimension values.
"""
if isinstance(shape, int):
shape = shape,
try:
return tuple(unsafe_map(_canonicalize_dimension, shape))
except TypeError:
@ -1733,25 +1732,25 @@ def _invalid_shape_error(shape: Shape, context: str=""):
return TypeError(msg)
class ShapedArray(UnshapedArray):
__slots__ = ['shape', 'named_shape']
__slots__ = ['shape']
array_abstraction_level = 2
named_shape = {}
def __init__(self, shape, dtype, weak_type=False, named_shape=None):
del named_shape # unused, vestigial
self.shape = canonicalize_shape(shape)
self.dtype = _dtype_object(dtype)
self.weak_type = weak_type
self.named_shape = {} if named_shape is None else dict(named_shape)
def update(self, shape=None, dtype=None, weak_type=None, named_shape=None):
del named_shape # unused, vestigial
if shape is None:
shape = self.shape
if dtype is None:
dtype = self.dtype
if weak_type is None:
weak_type = self.weak_type
if named_shape is None:
named_shape = self.named_shape
return ShapedArray(shape, dtype, weak_type, named_shape)
return ShapedArray(shape, dtype, weak_type)
ndim = property(lambda self: len(self.shape))
size = property(lambda self:
@ -1766,25 +1765,22 @@ class ShapedArray(UnshapedArray):
def __eq__(self, other):
return (type(self) is type(other)
and self.dtype == other.dtype and self.shape == other.shape
and self.weak_type == other.weak_type
and self.named_shape == other.named_shape)
and self.weak_type == other.weak_type)
def __hash__(self):
# can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
# objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
# the unique character code via hash(self.dtype.char)
return hash((self.shape, self.dtype, self.weak_type,
tuple(self.named_shape.items())))
return hash((self.shape, self.dtype, self.weak_type))
def at_least_vspace(self):
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type, self.named_shape)
self.weak_type)
def join(self, other):
if definitely_equal_shape(self.shape, other.shape) and self.dtype == other.dtype:
weak_type = self.weak_type and other.weak_type
named_shape = join_named_shapes(self.named_shape, other.named_shape)
return self.update(weak_type=weak_type, named_shape=named_shape)
return self.update(weak_type=weak_type)
elif self.dtype == other.dtype:
return UnshapedArray(self.dtype)
else:
@ -1794,14 +1790,7 @@ class ShapedArray(UnshapedArray):
dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
dt_str = dt_str.replace('void', 'float0')
shapestr = ','.join(map(str, self.shape))
if self.named_shape:
named_shapestr = ','.join(f'{k}:{v}' for k, v in self.named_shape.items())
return f'{dt_str}[{shapestr};{named_shapestr}]'
else:
return f'{dt_str}[{shapestr}]'
def strip_named_shape(self):
return self.update(named_shape={})
return f'{dt_str}[{shapestr}]'
def _len(self, ignored_tracer):
try:
@ -1849,12 +1838,9 @@ class ConcreteArray(ShapedArray):
return self
elif self.shape == other.shape and self.dtype == other.dtype:
weak_type = self.weak_type and other.weak_type
named_shape = join_named_shapes(self.named_shape, other.named_shape)
return ShapedArray(
self.shape, self.dtype, weak_type=weak_type, named_shape=named_shape)
return ShapedArray(self.shape, self.dtype, weak_type=weak_type)
elif self.dtype == other.dtype:
return UnshapedArray(self.dtype,
weak_type=self.weak_type and other.weak_type)
return UnshapedArray(self.dtype, weak_type=self.weak_type and other.weak_type)
else:
raise TypeError(self, other)
@ -2090,8 +2076,7 @@ raise_to_shaped_mappings: dict[type, Callable] = {
Bot: lambda aval, _: aval,
UnshapedArray: lambda aval, _: aval,
ShapedArray: lambda aval, weak_type: ShapedArray(
aval.shape, aval.dtype, weak_type, aval.named_shape
),
aval.shape, aval.dtype, weak_type),
DConcreteArray: lambda aval, weak_type: DShapedArray(
aval.shape, aval.dtype, weak_type
),
@ -2282,94 +2267,6 @@ def dim_constant(ct: int):
def dim_value_aval() -> AbstractValue:
return ShapedArray((), dim_value_dtype(), weak_type=True)
# ------------------- Named shapes -------------------
class NamedShape:
def __init__(self, *args, **kwargs):
self.__positional = canonicalize_shape(args)
# TODO: Assert that kwargs match axis env?
self.__named = dict(kwargs)
@property
def rank(self):
return len(self.__positional) + len(self.__named)
@property
def positional_rank(self):
return len(self.__positional)
@property
def named_rank(self):
return len(self.__named)
@property
def positional(self):
return self.__positional
@property
def names(self):
return self.__named.keys()
@property
def named_sizes(self):
return self.__named.values()
@property
def named_items(self):
return self.__named.items()
def __getitem__(self, idx):
try:
idx = operator.index(idx)
return self.__positional[idx]
except TypeError:
pass
return self.__named[idx]
@property
def total(self):
total = 1
for s in self.__positional: total *= s
for s in self.__named.values(): total *= s
return total
def __str__(self):
# TODO(mattjj,frostig): revise not to miss commas
if not self.__named:
return str(self.__positional)
return (f"({', '.join(map(str, self.__positional))}{', ' if self.__named else ''}"
f"{', '.join(f'{k}={v}' for k, v in self.__named.items())})")
def __eq__(self, other):
if isinstance(other, NamedShape):
return (self.__positional, self.__named) == (other.__positional, other.__named)
if isinstance(other, tuple):
return not self.__named and self.__positional == other
return False
def __hash__(self):
named = frozenset(self.__named.items())
return hash((self.__positional, named))
def join_named_shapes(*named_shapes):
result = {}
for named_shape in named_shapes:
for name, size in named_shape.items():
if result.setdefault(name, size) != size:
raise TypeError(
f"Axis name {name} used with inconsistent sizes: {result[name]} != {size}")
return result
# TODO: Make canonicalize_shape return named shapes?
def as_named_shape(shape) -> NamedShape:
if isinstance(shape, int):
shape = (shape,)
if isinstance(shape, NamedShape):
return shape
return NamedShape(*shape)
# ------------------- Call -------------------
class CallPrimitive(Primitive):
@ -2574,17 +2471,15 @@ def _map_shaped_array(
# TODO: Extend the named shape
if axis is None: return aval
return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype,
named_shape=aval.named_shape, weak_type=aval.weak_type)
weak_type=aval.weak_type)
def _unmap_shaped_array(
size: int, axis_name: AxisName, axis: int | None, aval: ShapedArray
) -> ShapedArray:
named_shape = dict(aval.named_shape)
named_shape.pop(axis_name, None) # TODO: make this mandatory
if axis is None: return aval.update(named_shape=named_shape)
if axis is None: return aval
elif type(axis) is int:
return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
named_shape=named_shape, weak_type=aval.weak_type)
weak_type=aval.weak_type)
else: raise TypeError(axis)
def _map_dshaped_array(
@ -2780,16 +2675,8 @@ def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: dict[Var, Var]) -> V
# Var identity is load-bearing, so we can't have duplicates!
if isinstance(v, DropVar): return v
assert v not in var_map
if not hasattr(v.aval, 'named_shape'):
var_map[v] = v
return v
names = tuple(it.chain.from_iterable(subst(name) for name in v.aval.named_shape))
named_shape = {name: axis_frame(name).size for name in names}
if len(named_shape) != len(names):
raise DuplicateAxisNameError(v)
new_v = Var(v.suffix, v.aval.update(named_shape=named_shape))
var_map[v] = new_v
return new_v
var_map[v] = v
return v
def subst_axis_names_eqn(eqn: JaxprEqn, subst: AxisSubst, var_map: dict[Var, Var]) -> JaxprEqn:
invars: list[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in eqn.invars]
@ -2857,31 +2744,20 @@ def typecheck(aval: AbstractValue, x) -> bool:
return typecompat(aval, get_aval(x))
def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool:
"""Determine whether `aval` conforms to `aval_ref`.
Ignores weak_type and named_shape, other than to check that an axis name isn't
used with different sizes.
"""
"""Determine whether `aval` conforms to `aval_ref`. Ignores weak_type."""
try:
return typematch(aval_ref, lattice_join(aval_ref, aval))
except TypeError:
return False
def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool:
"""Determine whether `aval1` and `aval2` are equivalent.
Ignores weak_type and named_shape, other than to check that an axis name isn't
used with different sizes.
"""
"""Determine whether `aval1` and `aval2` are equivalent. Ignores weak_type."""
if aval1 == aval2: return True
# unequal avals may still represent the same type, because type is represented
# by avals at the shaped level, and because weak type tags and (for now) named
# shape components aren't considered part of the type
if isinstance(aval1, ShapedArray) and isinstance(aval2, ShapedArray):
# a bonus check for whether any named axes have inconsistent sizes
join_named_shapes(aval1.named_shape, aval2.named_shape)
return (raise_to_shaped(aval1, weak_type=False).strip_named_shape() ==
raise_to_shaped(aval2, weak_type=False).strip_named_shape())
# by avals at the shaped level, and because weak type tags aren't considered
# part of the type
return (raise_to_shaped(aval1, weak_type=False) ==
raise_to_shaped(aval2, weak_type=False))
class JaxprTypeError(TypeError): pass
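A sketch of the simplified equality and ``typematch`` semantics, using internal ``jax._src.core`` APIs (private and subject to change): ``typematch`` ignores ``weak_type``, and the vestigial ``named_shape`` keyword is accepted but discarded.

import numpy as np
from jax._src import core  # internal APIs; may change between releases

a = core.ShapedArray((2, 3), np.float32, weak_type=True)
b = core.ShapedArray((2, 3), np.float32, weak_type=False)
assert a != b                   # __eq__ still distinguishes weak_type
assert core.typematch(a, b)     # typematch ignores weak_type
c = core.ShapedArray((2, 3), np.float32, named_shape={'i': 8})
assert c == b                   # named_shape no longer affects the type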

View File

@ -327,13 +327,11 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
f""" {str(ty_tree_).replace("'", "")}""")
raise TypeError(m)
# TODO(mattjj): compare primals' tangent types to tangent objects' types
primal_avals_out = [
raise_to_shaped(core.get_aval(x), weak_type=False).strip_named_shape()
for x in primals_out]
tangent_avals_out = [
raise_to_shaped(core.get_aval(t), weak_type=False).strip_named_shape()
if type(t) is not SymbolicZero else t.aval.strip_weak_type()
for t in tangents_out]
primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False)
for x in primals_out]
tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False)
if type(t) is not SymbolicZero else t.aval.strip_weak_type()
for t in tangents_out]
if primal_avals_out != tangent_avals_out:
if len(primal_avals_out) == 1:
(av1,), (av2,) = primal_avals_out, tangent_avals_out

View File

@ -1237,8 +1237,7 @@ def _call_exported_abstract_eval(
out_avals = tuple(
core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars,
*exported_dim_values),
dtype=out_aval.dtype, weak_type=out_aval.weak_type,
named_shape=out_aval.named_shape)
dtype=out_aval.dtype, weak_type=out_aval.weak_type)
for out_aval in exported.out_avals)
return out_avals, set(exported.ordered_effects + exported.unordered_effects)

View File

@ -200,8 +200,8 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack,
# TODO(mattjj): add back these checks for dynamic shapes
# if config.enable_checks.value:
# ct_aval = core.get_aval(ct_env[v])
# joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type().strip_named_shape()
# assert v.aval.strip_weak_type().strip_named_shape() == joined_aval, (prim, v.aval, ct_aval)
# joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type()
# assert v.aval.strip_weak_type() == joined_aval, (prim, v.aval, ct_aval)
def read_cotangent(v):
return ct_env.pop(v, Zero(v.aval.at_least_vspace()))

View File

@ -1091,7 +1091,7 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, donated_args,
input_output_aliases = list(input_output_aliases)
# To match-up in-avals to out-avals we only care about the number of
# bytes, so we strip off unrelated aval metadata (eg. the named shape)
strip_metadata = lambda a: a.strip_named_shape().strip_weak_type()
strip_metadata = lambda a: a.strip_weak_type()
avals_in = map(strip_metadata, avals_in)
avals_out = map(strip_metadata, avals_out)

View File

@ -854,12 +854,10 @@ def lower_parallel_callable(
def _pmap_unmap_shaped_array(
size: int, axis_name: core.AxisName, axis: int | None, aval: ShapedArray
) -> ShapedArray:
named_shape = dict(aval.named_shape)
named_shape.pop(axis_name, None) # TODO: make this mandatory
if axis is None: return aval.update(named_shape=named_shape)
if axis is None: return aval
elif type(axis) is int:
return ShapedArray(tuple_update(aval.shape, axis, size), aval.dtype,
named_shape=named_shape, weak_type=aval.weak_type)
weak_type=aval.weak_type)
else: raise TypeError(axis)
@ -1507,22 +1505,17 @@ mlir.register_lowering(xla_pmap_p, _pmap_lowering)
def tile_aval_nd(axis_sizes, in_axes: ArrayMapping, aval):
assert isinstance(aval, ShapedArray)
shape = list(aval.shape)
named_shape = dict(aval.named_shape)
for name, axis in in_axes.items():
assert shape[axis] % axis_sizes[name] == 0
assert name not in named_shape
named_shape[name] = axis_sizes[name]
shape[axis] //= axis_sizes[name]
return aval.update(shape=tuple(shape), named_shape=named_shape)
return aval.update(shape=tuple(shape))
def untile_aval_nd(axis_sizes, out_axes: ArrayMapping, aval):
assert isinstance(aval, ShapedArray)
shape = list(aval.shape)
named_shape = dict(aval.named_shape)
for name, axis in out_axes.items():
shape[axis] *= axis_sizes[name]
named_shape.pop(name, None) # The name might be missing --- it's a broadcast.
return aval.update(shape=tuple(shape), named_shape=named_shape)
return aval.update(shape=tuple(shape))
def mesh_local_to_global(mesh, axes: ArrayMapping, aval):

View File

@ -1321,7 +1321,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
raise TypeError(msg.format(cond_tree))
pred_aval = cond_jaxpr.out_avals[0]
if (not isinstance(pred_aval, ShapedArray)
or pred_aval.strip_weak_type().strip_named_shape() != ShapedArray((), np.bool_)):
or pred_aval.strip_weak_type() != ShapedArray((), np.bool_)):
msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
raise TypeError(msg.format(cond_jaxpr.out_avals))
return init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts, body_consts, body_tree
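A brief example of the checked contract (ordinary ``lax.while_loop`` usage, not from the diff): ``cond_fun`` must return a boolean scalar, and the check now compares against ``ShapedArray((), bool)`` with only ``weak_type`` stripped.

import jax.numpy as jnp
from jax import lax

# cond_fun returns a boolean scalar; body adds 3 until the condition fails.
result = lax.while_loop(lambda x: x < 10, lambda x: x + 3, jnp.int32(0))
print(result)  # 12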

View File

@ -59,8 +59,7 @@ from jax._src.interpreters.batching import RaggedAxis
from jax._src.lax import slicing
from jax._src.lax.utils import (
_input_dtype, dtype_to_string, standard_abstract_eval,
standard_multi_result_abstract_eval, standard_named_shape_rule,
standard_primitive)
standard_multi_result_abstract_eval, standard_primitive)
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
@ -2563,7 +2562,7 @@ convert_element_type_p.def_impl(partial(dispatch.apply_primitive, convert_elemen
convert_element_type_p.def_abstract_eval(
partial(standard_abstract_eval, convert_element_type_p,
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
_convert_element_type_weak_type_rule, standard_named_shape_rule))
_convert_element_type_weak_type_rule))
ad.defjvp(convert_element_type_p, _convert_element_type_jvp_rule)
ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule
batching.defvectorized(convert_element_type_p)
@ -3360,7 +3359,7 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions):
type(core.get_aval(d).dtype) is core.bint for d in shape)):
shape = _broadcast_in_dim_shape_rule( # error checking
x, shape=shape, broadcast_dimensions=broadcast_dimensions)
return core.ShapedArray(shape, x.dtype, x.weak_type, x.named_shape)
return core.ShapedArray(shape, x.dtype, x.weak_type)
# If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray
# (even if x is a ShapedArray)
# TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
@ -4057,25 +4056,12 @@ def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr, dimensions):
reducer = core.jaxpr_as_fun(jaxpr)
return _reduce_jvp(reducer, init_values, primal_xs, tangent_xs, dimensions)
def _reduce_named_shape_rule(*avals, computation, jaxpr, dimensions):
# TODO(mattjj,frostig): see the TODOs noting limitations/assumptions in
# _reduce_batching_rule. We're making the same assumptions here for now.
num_operands = len(avals) // 2
operand_avals, init_avals = split_list(avals, [num_operands])
if any(a.named_shape for a in init_avals):
raise NotImplementedError
named_shapes = [a.named_shape for a in operand_avals]
join = core.join_named_shapes(*(a.named_shape for a in operand_avals))
return [join] * len(named_shapes)
reduce_p = core.Primitive('reduce')
reduce_p.multiple_results = True
reduce_p.def_impl(partial(dispatch.apply_primitive, reduce_p))
reduce_p.def_abstract_eval(
partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule,
_reduce_dtype_rule, _reduce_weak_type_rule,
_reduce_named_shape_rule))
_reduce_dtype_rule, _reduce_weak_type_rule))
batching.primitive_batchers[reduce_p] = _reduce_batch_rule
ad.primitive_jvps[reduce_p] = _reduce_jvp_rule
@ -4839,9 +4825,6 @@ def _rng_bit_generator_lowering(
return [out_key, out_vals]
def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
return [key.named_shape, key.named_shape]
rng_bit_generator_p = Primitive("rng_bit_generator")
rng_bit_generator_p.multiple_results = True
rng_bit_generator_p.def_impl(
@ -4849,8 +4832,7 @@ rng_bit_generator_p.def_impl(
rng_bit_generator_p.def_abstract_eval(
partial(standard_multi_result_abstract_eval, rng_bit_generator_p,
_rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
_rng_bit_generator_weak_type_rule,
_rng_bit_generator_named_shape_rule))
_rng_bit_generator_weak_type_rule))
mlir.register_lowering(rng_bit_generator_p,
_rng_bit_generator_lowering)
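An illustrative sketch (public API only): abstract evaluation of ops such as ``convert_element_type`` now derives output avals from shape/dtype/weak_type rules alone, with no named-shape rule.

import jax
import jax.numpy as jnp
from jax import lax

f = lambda x: lax.convert_element_type(x, jnp.bfloat16)
print(jax.eval_shape(f, jax.ShapeDtypeStruct((4, 4), jnp.float32)))
# ShapeDtypeStruct(shape=(4, 4), dtype=bfloat16)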

View File

@ -738,21 +738,15 @@ def _allreduce_impl(pos_reducer, *args, axes, axis_index_groups):
return [pos_reducer(arg, axes) for arg in args]
def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
# TODO(frostig,mattjj,jekbradbury): maybe check aval names here
named_axes = tuple(axis for axis in axes if not isinstance(axis, int))
pos_axes = tuple(axis for axis in axes if isinstance(axis, int))
named_shapes = [arg.named_shape for arg in args]
named_axes = {axis for axis in axes if not isinstance(axis, int)}
if axis_index_groups is None:
named_shapes = [{name: size for name, size in arg.named_shape.items()
if name not in named_axes} for arg in args]
else:
if axis_index_groups is not None:
if len(pos_axes) != 0:
raise ValueError(f"axis_index_groups can only be used with reductions over "
f"named axes, but got: {axes}")
out_avals = [
ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes),
arg.dtype, named_shape=named_shape)
for arg, named_shape in zip(args, named_shapes)]
arg.dtype) for arg in args]
return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes}
def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
@ -1301,12 +1295,7 @@ def _all_gather_effectful_abstract_eval(
new_shape[all_gather_dimension] *= axis_size
else:
new_shape.insert(all_gather_dimension, axis_size)
new_named_shape = {name: size for name, size in x_aval.named_shape.items()
if name not in axis_name}
out_aval = x_aval.update(shape=new_shape, named_shape=new_named_shape)
effects = {*map(core.NamedAxisEffect, axis_name)}
return out_aval, effects
return x_aval.update(shape=new_shape), {*map(core.NamedAxisEffect, axis_name)}
def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled):
return (psum_scatter(cts, axis_name=axis_name,
@ -1437,15 +1426,7 @@ def _reduce_scatter_effectful_abstract_eval(
f"{scatter_dim_input_size} must match shard count "
f"{axis_size}")
del new_shape[scatter_dimension]
new_named_shape = {
name: size
for name, size in x_aval.named_shape.items()
if name not in axis_name
}
out_aval = x_aval.update(shape=new_shape, named_shape=new_named_shape)
effects = {*map(core.NamedAxisEffect, axis_name)}
return out_aval, effects
return x_aval.update(shape=new_shape), {*map(core.NamedAxisEffect, axis_name)}
def _reduce_scatter_transpose_rule(cts, x, *, axis_name, scatter_dimension,
@ -1633,9 +1614,7 @@ def _axis_index_lowering(ctx, *, axis_name):
def _axis_index_effectful_abstract_eval(*, axis_name):
frame = core.axis_frame(axis_name)
out_aval = ShapedArray((), np.int32, named_shape={axis_name: frame.size})
return out_aval, {core.NamedAxisEffect(axis_name)}
return ShapedArray((), np.int32), {core.NamedAxisEffect(axis_name)}
axis_index_p = core.Primitive('axis_index')
mlir.register_lowering(axis_index_p, _axis_index_lowering)
@ -1691,14 +1670,7 @@ def _pdot_effectful_abstract_eval(
pos_aval = lax.dot_general_p.abstract_eval(
x, y, dimension_numbers=[pos_contract, pos_batch],
precision=precision, preferred_element_type=None)[0]
common_named_shape = core.join_named_shapes(x.named_shape, y.named_shape)
named_shape = {name: size
for name, size in common_named_shape.items()
if name not in axis_name}
out_aval = pos_aval.update(named_shape=named_shape)
effects = {*map(core.NamedAxisEffect, axis_name)}
return out_aval, effects
return pos_aval, {*map(core.NamedAxisEffect, axis_name)}
def _pdot_vmap_collective_rule(axis_size, frame_name, _, vals_in, dims_in, *, axis_name,
pos_contract, pos_batch, precision):
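A usage sketch (standard ``pmap``/``psum``, not from the diff): collectives still record named-axis effects, but their output avals no longer carry a named shape, so results are ordinary positional-shape arrays.

import jax
import jax.numpy as jnp

n = jax.local_device_count()
f = jax.pmap(lambda x: x - jax.lax.psum(x, 'i'), axis_name='i')
print(f(jnp.arange(n, dtype=jnp.float32)).shape)  # (n,)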

View File

@ -24,6 +24,8 @@ from jax._src import dtypes
from jax._src.util import safe_zip
from jax._src.lib import xla_client
zip, unsafe_zip = safe_zip, zip
import numpy as np
xops = xla_client.ops
@ -35,20 +37,19 @@ def _argnum_weak_type(*argnums):
return lambda *args, **_: all(args[i].weak_type for i in argnums)
def standard_primitive(shape_rule, dtype_rule, name,
weak_type_rule=None, named_shape_rule=None):
weak_type_rule=None):
weak_type_rule = weak_type_rule or _standard_weak_type_rule
named_shape_rule = named_shape_rule or standard_named_shape_rule
prim = core.Primitive(name)
prim.def_impl(partial(dispatch.apply_primitive, prim))
prim.def_abstract_eval(
partial(standard_abstract_eval, prim, shape_rule, dtype_rule,
weak_type_rule, named_shape_rule))
weak_type_rule))
return prim
def _get_array_abstraction_level(a): return a.array_abstraction_level
def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
named_shape_rule, *avals, **kwargs):
*avals, **kwargs):
assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
assert not prim.multiple_results
weak_type = weak_type_rule(*avals, **kwargs)
@ -58,8 +59,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
return core.ConcreteArray(out.dtype, out, weak_type=weak_type)
elif least_specialized is core.ShapedArray:
return core.ShapedArray(shape_rule(*avals, **kwargs),
dtype_rule(*avals, **kwargs), weak_type=weak_type,
named_shape=named_shape_rule(*avals, **kwargs))
dtype_rule(*avals, **kwargs), weak_type=weak_type)
elif least_specialized is core.DShapedArray:
shape = shape_rule(*avals, **kwargs)
ty = (core.ShapedArray if all(type(d) is int for d in shape)
@ -71,8 +71,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
raise TypeError(avals, least_specialized)
def standard_multi_result_abstract_eval(
prim, shape_rule, dtype_rule, weak_type_rule,
named_shape_rule, *avals, **kwargs):
prim, shape_rule, dtype_rule, weak_type_rule, *avals, **kwargs):
assert prim.multiple_results
assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
least_specialized = max(map(type, avals), key=_get_array_abstraction_level)
@ -80,18 +79,16 @@ def standard_multi_result_abstract_eval(
if least_specialized is core.ConcreteArray:
out_vals = prim.impl(*[x.val for x in avals], **kwargs)
return [core.ConcreteArray(val.dtype, val, weak_type=weak_type)
for val, weak_type in safe_zip(out_vals, weak_types)]
for val, weak_type in zip(out_vals, weak_types)]
elif least_specialized is core.ShapedArray:
out_shapes = shape_rule(*avals, **kwargs)
out_dtypes = dtype_rule(*avals, **kwargs)
out_named_shapes = named_shape_rule(*avals, **kwargs)
return [core.ShapedArray(s, d, weak_type=weak_type, named_shape=named_shape)
for s, d, weak_type, named_shape
in safe_zip(out_shapes, out_dtypes, weak_types, out_named_shapes)]
return [core.ShapedArray(s, d, weak_type=weak_type)
for s, d, weak_type in zip(out_shapes, out_dtypes, weak_types)]
elif least_specialized is core.UnshapedArray:
out_dtypes = dtype_rule(*avals, **kwargs)
return [core.UnshapedArray(dtype, weak_type=weak_type)
for dtype, weak_type in safe_zip(out_dtypes, weak_types)]
for dtype, weak_type in zip(out_dtypes, weak_types)]
else:
raise TypeError(avals, least_specialized)
@ -103,9 +100,6 @@ def standard_translate(prim):
return [op(*args, **kwargs)]
return translation_rule
def standard_named_shape_rule(*avals, **kwargs):
return core.join_named_shapes(*(a.named_shape for a in avals))
def _standard_weak_type_rule(*avals, **kwargs):
return all(aval.weak_type for aval in avals)
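A hedged sketch using the internal helpers with the signatures shown in this hunk (private APIs; names here are only for illustration): defining a primitive through ``standard_primitive`` no longer involves a ``named_shape_rule``.

from jax._src.lax.utils import standard_primitive, _input_dtype

def _identity_shape_rule(x, **params):
  return x.shape  # output shape equals the input shape

# Constructs a primitive whose abstract eval uses only shape/dtype/weak_type rules.
identity_p = standard_primitive(_identity_shape_rule, _input_dtype, 'my_identity')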

View File

@ -198,19 +198,19 @@ def truncated_normal(stddev: RealNumeric = 1e-2,
return init
@export
def _compute_fans(shape: core.NamedShape,
def _compute_fans(shape: Sequence[int],
in_axis: int | Sequence[int] = -2,
out_axis: int | Sequence[int] = -1,
batch_axis: int | Sequence[int] = ()
) -> tuple[Array, Array]:
) -> tuple[float, float]:
"""
Compute effective input and output sizes for a linear or convolutional layer.
Axes not in in_axis, out_axis, or batch_axis are assumed to constitute the
"receptive field" of a convolution (kernel spatial dimensions).
"""
if shape.rank <= 1:
raise ValueError(f"Can't compute input and output sizes of a {shape.rank}"
if len(shape) <= 1:
raise ValueError(f"Can't compute input and output sizes of a {len(shape)}"
"-dimensional weights tensor. Must be at least 2D.")
if isinstance(in_axis, int):
@ -225,13 +225,13 @@ def _compute_fans(shape: core.NamedShape,
batch_size = shape[batch_axis]
else:
batch_size = math.prod([shape[i] for i in batch_axis])
receptive_field_size = shape.total / in_size / out_size / batch_size
receptive_field_size = math.prod(shape) / in_size / out_size / batch_size
fan_in = in_size * receptive_field_size
fan_out = out_size * receptive_field_size
return fan_in, fan_out
def _complex_uniform(key: KeyArray,
shape: Sequence[int] | core.NamedShape,
shape: Sequence[int],
dtype: DTypeLikeInexact) -> Array:
"""
Sample uniform random values within a disk on the complex plane,
@ -245,7 +245,7 @@ def _complex_uniform(key: KeyArray,
return r * jnp.exp(1j * theta)
def _complex_truncated_normal(key: KeyArray, upper: ArrayLike,
shape: Sequence[int] | core.NamedShape,
shape: Sequence[int],
dtype: DTypeLikeInexact) -> Array:
"""
Sample random values from a centered normal distribution on the complex plane,
@ -317,9 +317,9 @@ def variance_scaling(
def init(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
shape = core.canonicalize_shape(shape)
dtype = dtypes.canonicalize_dtype(dtype)
named_shape = core.as_named_shape(shape)
fan_in, fan_out = _compute_fans(named_shape, in_axis, out_axis, batch_axis)
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis, batch_axis)
if mode == "fan_in": denominator = fan_in
elif mode == "fan_out": denominator = fan_out
elif mode == "fan_avg": denominator = (fan_in + fan_out) / 2
@ -332,18 +332,18 @@ def variance_scaling(
if jnp.issubdtype(dtype, jnp.floating):
# constant is stddev of standard normal truncated to (-2, 2)
stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype)
return random.truncated_normal(key, -2, 2, named_shape, dtype) * stddev
return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
else:
# constant is stddev of complex standard normal truncated to 2
stddev = jnp.sqrt(variance) / jnp.array(.95311164380491208, dtype)
return _complex_truncated_normal(key, 2, named_shape, dtype) * stddev
return _complex_truncated_normal(key, 2, shape, dtype) * stddev
elif distribution == "normal":
return random.normal(key, named_shape, dtype) * jnp.sqrt(variance)
return random.normal(key, shape, dtype) * jnp.sqrt(variance)
elif distribution == "uniform":
if jnp.issubdtype(dtype, jnp.floating):
return random.uniform(key, named_shape, dtype, -1) * jnp.sqrt(3 * variance)
return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
else:
return _complex_uniform(key, named_shape, dtype) * jnp.sqrt(variance)
return _complex_uniform(key, shape, dtype) * jnp.sqrt(variance)
else:
raise ValueError(f"invalid distribution for variance scaling initializer: {distribution}")

View File

@ -148,9 +148,7 @@ def _logical_aval_to_interpret_mode_aval(aval):
return aval.update(inner_aval=inner_aval)
if isinstance(aval, jax_core.ShapedArray):
inner_dtype = _logical_to_interpret_mode_dtype(aval.dtype)
return jax_core.ShapedArray(aval.shape,
inner_dtype,
weak_type=aval.weak_type, named_shape=aval.named_shape)
return jax_core.ShapedArray(aval.shape, inner_dtype, weak_type=aval.weak_type)
return aval
def _get_next_indices(grid, indices):

View File

@ -490,7 +490,7 @@ def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
p, _ = _infer_params(fun, jit_info, args, kwargs)
out_s = [None if is_unspecified(s) else s for s in p.params['out_shardings']]
# TODO(yashkatariya): Add `Layout` to SDS.
out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape, sharding=s)
out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s)
for x, s in zip(p.params['jaxpr'].out_avals, out_s)]
return tree_unflatten(p.out_tree, out)

View File

@ -44,7 +44,6 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.lax import lax as lax_internal
from jax._src.lax import utils as lax_utils
from jax._src.lib import gpu_prng
from jax._src.lib import xla_client as xc
from jax._src.lib import version as jaxlib_version
@ -611,8 +610,7 @@ batching.defbroadcasting(random_fold_in_p)
def random_fold_in_abstract_eval(keys_aval, msgs_aval):
shape = lax_internal.broadcasting_shape_rule(
'random_fold_in', keys_aval, msgs_aval)
named_shape = lax_utils.standard_named_shape_rule(keys_aval, msgs_aval)
return core.ShapedArray(shape, keys_aval.dtype, named_shape=named_shape)
return core.ShapedArray(shape, keys_aval.dtype)
@random_fold_in_p.def_impl
def random_fold_in_impl(keys, msgs):
@ -640,19 +638,7 @@ mlir.register_lowering(random_fold_in_p, random_fold_in_lowering)
def random_bits(keys, bit_width, shape):
shape = core.as_named_shape(shape)
for name, size in shape.named_items:
# TODO(frostig,mattjj,apaszke): Is this real_size check necessary,
# and is it meant to raise a user-facing ValueError? Should it be
# an `assert` (or RuntimeError) instead? Why do we check it in
# calls to `random_bits` instead of a more common parallelism path?
real_size = lax.psum(1, name)
if real_size != size:
raise ValueError(f"The shape of axis {name} was specified as {size}, "
f"but it really is {real_size}")
axis_index = lax.axis_index(name)
keys = random_fold_in(keys, axis_index)
return random_bits_p.bind(keys, bit_width=bit_width, shape=shape.positional)
return random_bits_p.bind(keys, bit_width=bit_width, shape=shape)
random_bits_p = core.Primitive('random_bits')
ad.defjvp_zero(random_bits_p)
@ -822,8 +808,7 @@ def _threefry2x32_abstract_eval(*args):
.format(args))
if all(isinstance(arg, core.ShapedArray) for arg in args):
shape = lax_internal.broadcasting_shape_rule(*args)
named_shape = core.join_named_shapes(*(a.named_shape for a in args))
aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32), named_shape=named_shape)
aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32))
else:
aval = core.UnshapedArray(jnp.dtype(jnp.uint32))
return (aval,) * 2
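A sketch of the public counterparts (not from the diff): ``fold_in`` derives a new key from data, and ``bits`` draws raw unsigned random bits with a plain tuple shape.

import jax
import jax.numpy as jnp

key = jax.random.key(0)
sub = jax.random.fold_in(key, 7)
raw = jax.random.bits(sub, (4,), jnp.uint32)
print(raw.shape, raw.dtype)  # (4,) uint32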

View File

@ -35,7 +35,6 @@ from jax._src import dtypes
from jax._src import prng
from jax._src import xla_bridge
from jax._src.api import jit, vmap
from jax._src.core import NamedShape
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -319,12 +318,10 @@ def wrap_key_data(key_bits_array: Array, *,
### random samplers
def _check_shape(name: str, shape: Shape | NamedShape, *param_shapes) -> None:
shape = core.as_named_shape(shape)
def _check_shape(name: str, shape: Shape, *param_shapes) -> None:
if param_shapes:
shape_ = lax.broadcast_shapes(shape.positional, *param_shapes)
if shape.positional != shape_:
shape_ = lax.broadcast_shapes(shape, *param_shapes) # type: ignore
if shape != shape_:
msg = ("{} parameter shapes must be broadcast-compatible with shape "
"argument, and the result of broadcasting the shapes must equal "
"the shape argument, but got result {} for shape argument {}.")
@ -361,7 +358,7 @@ def bits(key: KeyArrayLike,
def uniform(key: KeyArrayLike,
shape: Shape | NamedShape = (),
shape: Shape = (),
dtype: DTypeLikeFloat = float,
minval: RealArray = 0.,
maxval: RealArray = 1.) -> Array:
@ -381,12 +378,12 @@ def uniform(key: KeyArrayLike,
"""
key, _ = _check_prng_key("uniform", key)
dtypes.check_user_dtype_supported(dtype)
shape = core.canonicalize_shape(shape)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `uniform` must be a float dtype, "
f"got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
shape = core.as_named_shape(shape)
return _uniform(key, shape, dtype, minval, maxval)
@partial(jit, static_argnums=(1, 2))
@ -397,8 +394,8 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array:
minval = lax.convert_element_type(minval, dtype)
maxval = lax.convert_element_type(maxval, dtype)
minval = lax.broadcast_to_rank(minval, shape.positional_rank)
maxval = lax.broadcast_to_rank(maxval, shape.positional_rank)
minval = lax.broadcast_to_rank(minval, len(shape))
maxval = lax.broadcast_to_rank(maxval, len(shape))
finfo = jnp.finfo(dtype)
nbits, nmant = finfo.bits, finfo.nmant
@ -427,7 +424,7 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array:
floats = lax.bitcast_convert_type(float_bits, dtype) - np.array(1., dtype)
return lax.max(
minval,
lax.reshape(floats * (maxval - minval) + minval, shape.positional))
lax.reshape(floats * (maxval - minval) + minval, shape))
def randint(key: KeyArrayLike,
@ -674,7 +671,7 @@ def choice(key: KeyArrayLike,
def normal(key: KeyArrayLike,
shape: Shape | NamedShape = (),
shape: Shape = (),
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample standard normal random values with given shape and float dtype.
@ -696,12 +693,12 @@ def normal(key: KeyArrayLike,
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key("normal", key)
shape = core.canonicalize_shape(shape)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.inexact):
raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, "
f"got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
shape = core.as_named_shape(shape)
return _normal(key, shape, dtype)
@partial(jit, static_argnums=(1, 2))
@ -812,7 +809,7 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array:
def truncated_normal(key: KeyArrayLike,
lower: RealArray,
upper: RealArray,
shape: Shape | NamedShape | None = None,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample truncated standard normal random values with given shape and dtype.
@ -841,14 +838,14 @@ def truncated_normal(key: KeyArrayLike,
``shape`` is not None, or else by broadcasting ``lower`` and ``upper``.
Returns values in the open interval ``(lower, upper)``.
"""
if shape is not None:
shape = core.canonicalize_shape(shape)
key, _ = _check_prng_key("truncated_normal", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `truncated_normal` must be a float "
f"dtype, got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = core.as_named_shape(shape)
return _truncated_normal(key, lower, upper, shape, dtype)
@partial(jit, static_argnums=(3, 4))
@ -877,7 +874,7 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> Array:
def bernoulli(key: KeyArrayLike,
p: RealArray = np.float32(0.5),
shape: Shape | NamedShape | None = None) -> Array:
shape: Shape | None = None) -> Array:
r"""Sample Bernoulli random values with given shape and mean.
The values are distributed according to the probability mass function:
@ -899,10 +896,10 @@ def bernoulli(key: KeyArrayLike,
A random array with boolean dtype and shape given by ``shape`` if ``shape``
is not None, or else ``p.shape``.
"""
if shape is not None:
shape = core.canonicalize_shape(shape)
key, _ = _check_prng_key("bernoulli", key)
dtype = dtypes.canonicalize_dtype(lax.dtype(p))
if shape is not None:
shape = core.as_named_shape(shape)
if not jnp.issubdtype(dtype, np.floating):
msg = "bernoulli probability `p` must have a floating dtype, got {}."
raise TypeError(msg.format(dtype))
@ -1559,7 +1556,7 @@ def categorical(key: KeyArrayLike,
if shape is None:
shape = batch_shape
else:
shape = tuple(shape)
shape = core.canonicalize_shape(shape)
_check_shape("categorical", shape, batch_shape)
shape_prefix = shape[:len(shape)-len(batch_shape)]
@ -2053,6 +2050,7 @@ def orthogonal(
Returns:
A random array of shape `(*shape, n, n)` and specified dtype.
"""
shape = core.canonicalize_shape(shape)
key, _ = _check_prng_key("orthogonal", key)
dtypes.check_user_dtype_supported(dtype)
_check_shape("orthogonal", shape)
@ -2088,6 +2086,7 @@ def generalized_normal(
Returns:
A random array with the specified shape and dtype.
"""
shape = core.canonicalize_shape(shape)
key, _ = _check_prng_key("generalized_normal", key)
dtypes.check_user_dtype_supported(dtype)
_check_shape("generalized_normal", shape)
@ -2118,6 +2117,7 @@ def ball(
Returns:
A random array of shape `(*shape, d)` and specified dtype.
"""
shape = core.canonicalize_shape(shape)
key, _ = _check_prng_key("ball", key)
dtypes.check_user_dtype_supported(dtype)
_check_shape("ball", shape)

View File

@ -218,11 +218,9 @@ def get_ref_state_effects(
if isinstance(eff, (ReadEffect, WriteEffect, AccumEffect))
and eff.input_index == i} for i, _ in enumerate(avals)]
def shaped_array_ref(shape: tuple[int, ...], dtype,
weak_type: bool = False,
named_shape = None) -> AbstractRef:
return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type,
named_shape=named_shape))
def shaped_array_ref(
shape: tuple[int, ...], dtype, weak_type: bool = False) -> AbstractRef:
return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type))
def _shard_ref(mesh, names, ref_aval: AbstractRef):
del mesh

View File

@ -42,7 +42,6 @@ from jax._src.core import (
MainTrace as MainTrace,
MapPrimitive as MapPrimitive,
NameGatheringSubst as NameGatheringSubst,
NamedShape as NamedShape,
OutDBIdx as OutDBIdx,
OutputType as OutputType,
ParamDict as ParamDict,
@ -61,7 +60,6 @@ from jax._src.core import (
Var as Var,
abstract_token as abstract_token,
apply_todos as apply_todos,
as_named_shape as as_named_shape,
aval_mapping_handlers as aval_mapping_handlers,
axis_frame as axis_frame,
call as call,
@ -97,7 +95,6 @@ from jax._src.core import (
jaxpr_uses_outfeed as jaxpr_uses_outfeed,
jaxprs_in_params as jaxprs_in_params,
join_effects as join_effects,
join_named_shapes as join_named_shapes,
lattice_join as lattice_join,
leaked_tracer_error as leaked_tracer_error,
literalable_types as literalable_types,

View File

@ -572,9 +572,7 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames,
aval: core.AbstractValue,) -> core.AbstractValue:
assert isinstance(aval, core.ShapedArray)
return aval.update(tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape)),
named_shape={k: v for k, v in aval.named_shape.items()
if k not in mesh.shape})
for i, sz in enumerate(aval.shape)))
core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array
# Type-checking

View File

@ -2706,6 +2706,8 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(out_shape.shape, (3,))
def test_eval_shape_names(self):
raise unittest.SkipTest("named shapes are deprecated")
def fun(x, y):
return lax.psum(x, 'i') + y
@ -6571,6 +6573,7 @@ class JaxprTest(jtu.JaxTestCase):
self.assertIn('psum', str(jaxpr))
def test_make_jaxpr_named(self):
raise unittest.SkipTest("named shapes are deprecated")
def f(x):
return x - lax.psum(x, 'i')

View File

@ -550,32 +550,6 @@ class JaxprTypeChecks(jtu.JaxTestCase):
aval = core.raise_to_shaped(core.get_aval(value))
self.assertEqual(aval.weak_type, weak_type)
def test_lattice_join_named_shape(self):
aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10})
self.assertEqual(core.lattice_join(aval1, aval1), aval1)
aval2 = core.ShapedArray((2, 3), np.float32, False, {'j': 5})
expected = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5})
self.assertEqual(core.lattice_join(aval1, aval2), expected)
aval3 = core.ShapedArray((2, 3), np.float32, False, {'i': 5})
self.assertRaises(TypeError, lambda: core.lattice_join(aval1, aval3))
def test_typecompat_named_shape(self):
aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10})
aval2 = core.ShapedArray((2, 3), np.float32, False, {'j': 5})
self.assertTrue(core.typecompat(aval1, aval2))
aval3 = core.ShapedArray((2, 3), np.float32, False, {'i': 5})
self.assertFalse(core.typecompat(aval1, aval3))
def test_named_shape_comparision(self):
self.assertTrue(core.NamedShape(2, 3) == (2, 3))
self.assertFalse(core.NamedShape(2, i=3) == (2,))
self.assertFalse(core.NamedShape(2, i=3) == (2, 3))
self.assertFalse(core.NamedShape(2, i=3) == None)
self.assertFalse(core.NamedShape() == [])
@jtu.with_config(jax_dynamic_shapes=True)
class DynamicShapesTest(jtu.JaxTestCase):

View File

@ -3287,26 +3287,6 @@ class LazyConstantTest(jtu.JaxTestCase):
expected.astype(np.complex64), lax.log1p(np.complex64(1e-5)))
class LaxNamedShapeTest(jtu.JaxTestCase):
def test_abstract_eval(self):
aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10})
out, _ = lax.sin_p.abstract_eval(aval1)
self.assertEqual(out, aval1)
aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10})
aval2 = core.ShapedArray((2, 3), np.float32, False, {'j': 5})
expected = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5})
out, _ = lax.add_p.abstract_eval(aval1, aval2)
self.assertEqual(out, expected)
def test_abstract_eval_collective(self):
with core.extend_axis_env('i', 10, None):
aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5})
expected = core.ShapedArray((2, 3), np.float32, False, {'j': 5})
(out,), _ = lax.psum_p.abstract_eval(aval1, axes=('i',), axis_index_groups=None)
self.assertEqual(out, expected)
class FooTyRules:
# handlers