Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00

Merge pull request #21069 from mattjj:remove-named-shapes

PiperOrigin-RevId: 655766534

This commit is contained in:
commit 086b500da6
@@ -381,9 +381,9 @@ def xla_computation(fun: Callable,
return_shape: Optional boolean, defaults to ``False``. If ``True``, the
wrapped function returns a pair where the first element is the XLA
computation and the second element is a pytree with the same structure as
the output of ``fun`` and where the leaves are objects with ``shape``,
``dtype``, and ``named_shape`` attributes representing the corresponding
types of the output leaves.
the output of ``fun`` and where the leaves are objects with ``shape`` and
``dtype`` attributes representing the corresponding types of the output
leaves.
donate_argnums: Specify which arguments are "donated" to the computation.
It is safe to donate arguments if you no longer need them once the
computation has finished. In some cases XLA can make use of donated

@@ -557,8 +557,8 @@ def xla_computation(fun: Callable,
m = mlir.module_to_bytecode(lowering_result.module)
built = xc._xla.mlir.mlir_module_to_xla_computation(
m, use_tuple_args=tuple_args, return_tuple=True)
out_shapes_flat = [
ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
out_shape = tree_unflatten(out_tree(), out_shapes_flat)
for out_aval in out_avals:
if not isinstance(out_aval, ShapedArray):

@@ -2337,8 +2337,8 @@ def make_jaxpr(fun: Callable,
wrapped function returns a pair where the first element is the
``ClosedJaxpr`` representation of ``fun`` and the second element is a
pytree with the same structure as the output of ``fun`` and where the
leaves are objects with ``shape``, ``dtype``, and ``named_shape``
attributes representing the corresponding types of the output leaves.
leaves are objects with ``shape`` and ``dtype`` attributes representing
the corresponding types of the output leaves.

Returns:
A wrapped version of ``fun`` that when applied to example arguments returns

@@ -2400,8 +2400,7 @@ def make_jaxpr(fun: Callable,
else:
jaxpr = traced.jaxpr
if return_shape:
out = [ShapeDtypeStruct(o.shape, o.dtype, getattr(o, 'named_shape', None))
for o in jaxpr.out_avals]
out = [ShapeDtypeStruct(o.shape, o.dtype) for o in jaxpr.out_avals]
return jaxpr, tree_unflatten(tree_structure(traced.out_info), out)
return jaxpr
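With named shapes removed, the shape structs returned by `make_jaxpr(..., return_shape=True)` (and `jax.eval_shape`) describe each output purely by `shape` and `dtype`. A minimal sketch, assuming a JAX build that includes this change:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x).sum()

# return_shape=True now yields ShapeDtypeStruct leaves carrying only
# shape and dtype, matching the docstring change above.
jaxpr, out_shape = jax.make_jaxpr(f, return_shape=True)(jnp.ones((3, 4)))
print(out_shape)  # ShapeDtypeStruct(shape=(), dtype=float32)
```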
@@ -2691,12 +2690,13 @@ class ShapeDtypeStruct:
Args:
shape: a sequence of integers representing an array shape
dtype: a dtype-like object
named_shape: (optional) a dictionary representing a named shape
sharding: (optional) a :class:`jax.Sharding` object
"""
__slots__ = ["shape", "dtype", "named_shape", "sharding", "_dll"]
__slots__ = ["shape", "dtype", "sharding", "_dll"]
named_shape = {}

def __init__(self, shape, dtype, named_shape=None, sharding=None):
def __init__(self, shape, dtype, sharding=None, named_shape=None):
del named_shape # ignored, vestigial
self.shape = tuple(shape)
if dtype is None:
raise ValueError("ShapeDtypeStruct: dtype must be specified.")

@@ -2713,7 +2713,6 @@ class ShapeDtypeStruct:
f" layout in a `ShapeDtypeStruct`. Got {sharding}")
self.sharding = sharding.sharding if isinstance(sharding, Layout) else sharding
self._dll = sharding.device_local_layout if isinstance(sharding, Layout) else None
self.named_shape = {} if named_shape is None else dict(named_shape)

size = property(lambda self: math.prod(self.shape))
ndim = property(lambda self: len(self.shape))

@@ -2729,11 +2728,10 @@ class ShapeDtypeStruct:
raise TypeError("len() of unsized object") from e # same as numpy error

def __repr__(self):
ns = f", named_shape={self.named_shape}" if self.named_shape else ""
sh = f", sharding={self.sharding}" if self.sharding is not None else ""
l = f", layout={self.layout}" if self._dll is not None else ""
return (f"{type(self).__name__}(shape={self.shape}, "
f"dtype={self.dtype.name}{ns}{sh}{l})")
f"dtype={self.dtype.name}{sh}{l})")

__str__ = __repr__

@@ -2741,19 +2739,17 @@ class ShapeDtypeStruct:
if not isinstance(other, ShapeDtypeStruct):
return False
else:
return ((other.shape, other.dtype, other.named_shape, other.sharding, other.layout) ==
(self.shape, self.dtype, self.named_shape, self.sharding, self.layout))
return ((other.shape, other.dtype, other.sharding, other.layout) ==
(self.shape, self.dtype, self.sharding, self.layout))

def __hash__(self):
# TODO(frostig): avoid the conversion from dict by addressing
# https://github.com/google/jax/issues/8182
named = frozenset(self.named_shape.items())
return hash((self.shape, self.dtype, named, self.sharding, self.layout))

return hash((self.shape, self.dtype, self.sharding, self.layout))

core.pytype_aval_mappings[ShapeDtypeStruct] = (
lambda x: ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
weak_type=False, named_shape=x.named_shape))
weak_type=False))


@api_boundary
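`jax.ShapeDtypeStruct` is now described only by `shape`, `dtype`, and optionally `sharding`; a `named_shape` argument is still accepted but ignored as vestigial. A minimal usage sketch (standard public API; nothing here beyond the dropped attribute is specific to this commit):

```python
import jax
import jax.numpy as jnp

# Shape/dtype specs no longer carry a named_shape component.
spec = jax.ShapeDtypeStruct((4, 8), jnp.float32)
print(spec)  # ShapeDtypeStruct(shape=(4, 8), dtype=float32)

# Typical use: abstract evaluation without running the computation.
out = jax.eval_shape(lambda x: x @ x.T, spec)
print(out.shape, out.dtype)  # (4, 4) float32
```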
@@ -570,15 +570,13 @@ def _shaped_abstractify_slow(x):
pass

weak_type = getattr(x, 'weak_type', False)
named_shape = getattr(x, 'named_shape', {})
if hasattr(x, 'dtype'):
dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
else:
raise TypeError(
f"Cannot interpret value of type {type(x)} as an abstract array; it "
"does not have a dtype attribute")
return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
named_shape=named_shape)
return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type)

# TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior
def shaped_abstractify(x):

@@ -466,8 +466,7 @@ class ArrayImpl(basearray.Array):

def __reduce__(self):
fun, args, arr_state = self._value.__reduce__()
aval_state = {'weak_type': self.aval.weak_type,
'named_shape': self.aval.named_shape}
aval_state = {'weak_type': self.aval.weak_type}
return (_reconstruct_array, (fun, args, arr_state, aval_state))

@use_cpp_method()

jax/_src/core.py (176 changed lines)
@@ -1418,9 +1418,6 @@ class AbstractValue:
def strip_weak_type(self) -> AbstractValue:
return self

def strip_named_shape(self) -> AbstractValue:
return self

def join(self, other):
raise NotImplementedError("must override")

@@ -1695,6 +1692,8 @@ def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]:
Returns:
A tuple of canonical dimension values.
"""
if isinstance(shape, int):
shape = shape,
try:
return tuple(unsafe_map(_canonicalize_dimension, shape))
except TypeError:

@@ -1733,25 +1732,25 @@ def _invalid_shape_error(shape: Shape, context: str=""):
return TypeError(msg)

class ShapedArray(UnshapedArray):
__slots__ = ['shape', 'named_shape']
__slots__ = ['shape']
array_abstraction_level = 2
named_shape = {}

def __init__(self, shape, dtype, weak_type=False, named_shape=None):
del named_shape # unused, vestigial
self.shape = canonicalize_shape(shape)
self.dtype = _dtype_object(dtype)
self.weak_type = weak_type
self.named_shape = {} if named_shape is None else dict(named_shape)

def update(self, shape=None, dtype=None, weak_type=None, named_shape=None):
del named_shape # unused, vestigial
if shape is None:
shape = self.shape
if dtype is None:
dtype = self.dtype
if weak_type is None:
weak_type = self.weak_type
if named_shape is None:
named_shape = self.named_shape
return ShapedArray(shape, dtype, weak_type, named_shape)
return ShapedArray(shape, dtype, weak_type)

ndim = property(lambda self: len(self.shape))
size = property(lambda self:

@@ -1766,25 +1765,22 @@ class ShapedArray(UnshapedArray):
def __eq__(self, other):
return (type(self) is type(other)
and self.dtype == other.dtype and self.shape == other.shape
and self.weak_type == other.weak_type
and self.named_shape == other.named_shape)
and self.weak_type == other.weak_type)

def __hash__(self):
# can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
# objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
# the unique character code via hash(self.dtype.char)
return hash((self.shape, self.dtype, self.weak_type,
tuple(self.named_shape.items())))
return hash((self.shape, self.dtype, self.weak_type))

def at_least_vspace(self):
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type, self.named_shape)
self.weak_type)

def join(self, other):
if definitely_equal_shape(self.shape, other.shape) and self.dtype == other.dtype:
weak_type = self.weak_type and other.weak_type
named_shape = join_named_shapes(self.named_shape, other.named_shape)
return self.update(weak_type=weak_type, named_shape=named_shape)
return self.update(weak_type=weak_type)
elif self.dtype == other.dtype:
return UnshapedArray(self.dtype)
else:

@@ -1794,14 +1790,7 @@ class ShapedArray(UnshapedArray):
dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
dt_str = dt_str.replace('void', 'float0')
shapestr = ','.join(map(str, self.shape))
if self.named_shape:
named_shapestr = ','.join(f'{k}:{v}' for k, v in self.named_shape.items())
return f'{dt_str}[{shapestr};{named_shapestr}]'
else:
return f'{dt_str}[{shapestr}]'

def strip_named_shape(self):
return self.update(named_shape={})
return f'{dt_str}[{shapestr}]'

def _len(self, ignored_tracer):
try:
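A `ShapedArray` aval is now identified by `shape`, `dtype`, and `weak_type` alone; the `named_shape` keyword survives only as an ignored, vestigial argument. A small sketch against the internal `jax._src.core` module (internal API, details may shift between releases):

```python
import jax.numpy as jnp
from jax._src import core

a = core.ShapedArray((2, 3), jnp.float32)
b = core.ShapedArray((2, 3), jnp.float32, named_shape={'i': 8})  # kwarg is ignored

print(a.str_short(short_dtypes=True))  # f32[2,3], with no ";i:8" suffix anymore
print(a == b)  # True: equality compares only shape, dtype, and weak_type
```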
@@ -1849,12 +1838,9 @@ class ConcreteArray(ShapedArray):
return self
elif self.shape == other.shape and self.dtype == other.dtype:
weak_type = self.weak_type and other.weak_type
named_shape = join_named_shapes(self.named_shape, other.named_shape)
return ShapedArray(
self.shape, self.dtype, weak_type=weak_type, named_shape=named_shape)
return ShapedArray(self.shape, self.dtype, weak_type=weak_type)
elif self.dtype == other.dtype:
return UnshapedArray(self.dtype,
weak_type=self.weak_type and other.weak_type)
return UnshapedArray(self.dtype, weak_type=self.weak_type and other.weak_type)
else:
raise TypeError(self, other)

@@ -2090,8 +2076,7 @@ raise_to_shaped_mappings: dict[type, Callable] = {
Bot: lambda aval, _: aval,
UnshapedArray: lambda aval, _: aval,
ShapedArray: lambda aval, weak_type: ShapedArray(
aval.shape, aval.dtype, weak_type, aval.named_shape
),
aval.shape, aval.dtype, weak_type),
DConcreteArray: lambda aval, weak_type: DShapedArray(
aval.shape, aval.dtype, weak_type
),

@@ -2282,94 +2267,6 @@ def dim_constant(ct: int):
def dim_value_aval() -> AbstractValue:
return ShapedArray((), dim_value_dtype(), weak_type=True)

# ------------------- Named shapes -------------------


class NamedShape:
def __init__(self, *args, **kwargs):
self.__positional = canonicalize_shape(args)
# TODO: Assert that kwargs match axis env?
self.__named = dict(kwargs)

@property
def rank(self):
return len(self.__positional) + len(self.__named)

@property
def positional_rank(self):
return len(self.__positional)

@property
def named_rank(self):
return len(self.__named)

@property
def positional(self):
return self.__positional

@property
def names(self):
return self.__named.keys()

@property
def named_sizes(self):
return self.__named.values()

@property
def named_items(self):
return self.__named.items()

def __getitem__(self, idx):
try:
idx = operator.index(idx)
return self.__positional[idx]
except TypeError:
pass
return self.__named[idx]

@property
def total(self):
total = 1
for s in self.__positional: total *= s
for s in self.__named.values(): total *= s
return total

def __str__(self):
# TODO(mattjj,frostig): revise not to miss commas
if not self.__named:
return str(self.__positional)
return (f"({', '.join(map(str, self.__positional))}{', ' if self.__named else ''}"
f"{', '.join(f'{k}={v}' for k, v in self.__named.items())})")

def __eq__(self, other):
if isinstance(other, NamedShape):
return (self.__positional, self.__named) == (other.__positional, other.__named)
if isinstance(other, tuple):
return not self.__named and self.__positional == other
return False

def __hash__(self):
named = frozenset(self.__named.items())
return hash((self.__positional, named))

def join_named_shapes(*named_shapes):
result = {}
for named_shape in named_shapes:
for name, size in named_shape.items():
if result.setdefault(name, size) != size:
raise TypeError(
f"Axis name {name} used with inconsistent sizes: {result[name]} != {size}")
return result

# TODO: Make canonicalize_shape return named shapes?
def as_named_shape(shape) -> NamedShape:
if isinstance(shape, int):
shape = (shape,)
if isinstance(shape, NamedShape):
return shape
return NamedShape(*shape)


# ------------------- Call -------------------

class CallPrimitive(Primitive):
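For context, the deleted `NamedShape` mixed positional dimensions with keyword-specified named axes, and `join_named_shapes` merged named-axis dicts while rejecting inconsistent sizes. A sketch of the removed behavior, reconstructed from the deleted code above; it only runs against a JAX release from before this commit:

```python
# Pre-change JAX only: these names are deleted by this commit.
from jax._src.core import NamedShape, join_named_shapes

ns = NamedShape(2, 3, i=4)      # positional dims (2, 3) plus named axis i=4
assert ns.positional == (2, 3)
assert ns.rank == 3 and ns.total == 24
assert ns['i'] == 4

assert join_named_shapes({'i': 4}, {'j': 2}) == {'i': 4, 'j': 2}
# join_named_shapes({'i': 4}, {'i': 5}) raised TypeError: inconsistent sizes
```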
@@ -2574,17 +2471,15 @@ def _map_shaped_array(
# TODO: Extend the named shape
if axis is None: return aval
return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype,
named_shape=aval.named_shape, weak_type=aval.weak_type)
weak_type=aval.weak_type)

def _unmap_shaped_array(
size: int, axis_name: AxisName, axis: int | None, aval: ShapedArray
) -> ShapedArray:
named_shape = dict(aval.named_shape)
named_shape.pop(axis_name, None) # TODO: make this mandatory
if axis is None: return aval.update(named_shape=named_shape)
if axis is None: return aval
elif type(axis) is int:
return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
named_shape=named_shape, weak_type=aval.weak_type)
weak_type=aval.weak_type)
else: raise TypeError(axis)

def _map_dshaped_array(

@@ -2780,16 +2675,8 @@ def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: dict[Var, Var]) -> V
# Var identity is load-bearing, so we can't have duplicates!
if isinstance(v, DropVar): return v
assert v not in var_map
if not hasattr(v.aval, 'named_shape'):
var_map[v] = v
return v
names = tuple(it.chain.from_iterable(subst(name) for name in v.aval.named_shape))
named_shape = {name: axis_frame(name).size for name in names}
if len(named_shape) != len(names):
raise DuplicateAxisNameError(v)
new_v = Var(v.suffix, v.aval.update(named_shape=named_shape))
var_map[v] = new_v
return new_v
var_map[v] = v
return v

def subst_axis_names_eqn(eqn: JaxprEqn, subst: AxisSubst, var_map: dict[Var, Var]) -> JaxprEqn:
invars: list[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in eqn.invars]

@@ -2857,31 +2744,20 @@ def typecheck(aval: AbstractValue, x) -> bool:
return typecompat(aval, get_aval(x))

def typecompat(aval_ref: AbstractValue, aval: AbstractValue) -> bool:
"""Determine whether `aval` conforms to `aval_ref`.

Ignores weak_type and named_shape, other than to check that an axis name isn't
used with different sizes.
"""
"""Determine whether `aval` conforms to `aval_ref`. Ignores weak_type."""
try:
return typematch(aval_ref, lattice_join(aval_ref, aval))
except TypeError:
return False

def typematch(aval1: AbstractValue, aval2: AbstractValue) -> bool:
"""Determine whether `aval1` and `aval2` are equivalent.

Ignores weak_type and named_shape, other than to check that an axis name isn't
used with different sizes.
"""
"""Determine whether `aval1` and `aval2` are equivalent. Ignores weak_type."""
if aval1 == aval2: return True
# unequal avals may still represent the same type, because type is represented
# by avals at the shaped level, and because weak type tags and (for now) named
# shape components aren't considered part of the type
if isinstance(aval1, ShapedArray) and isinstance(aval2, ShapedArray):
# a bonus check for whether any named axes have inconsistent sizes
join_named_shapes(aval1.named_shape, aval2.named_shape)
return (raise_to_shaped(aval1, weak_type=False).strip_named_shape() ==
raise_to_shaped(aval2, weak_type=False).strip_named_shape())
# by avals at the shaped level, and because weak type tags aren't considered
# part of the type
return (raise_to_shaped(aval1, weak_type=False) ==
raise_to_shaped(aval2, weak_type=False))

class JaxprTypeError(TypeError): pass
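With no named-shape component to reconcile, `typematch` and `typecompat` reduce to comparing shaped avals with the weak-type tag stripped. A small sketch against the internal core module (internal API, subject to change):

```python
import jax.numpy as jnp
from jax._src import core

a = core.ShapedArray((2, 3), jnp.float32, weak_type=True)
b = core.ShapedArray((2, 3), jnp.float32, weak_type=False)

print(core.typematch(a, b))  # True: equivalent up to weak_type
print(core.typematch(a, core.ShapedArray((2, 4), jnp.float32)))  # False
```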
@@ -327,13 +327,11 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
f""" {str(ty_tree_).replace("'", "")}""")
raise TypeError(m)
# TODO(mattjj): compare primals' tangent types to tangent objects' types
primal_avals_out = [
raise_to_shaped(core.get_aval(x), weak_type=False).strip_named_shape()
for x in primals_out]
tangent_avals_out = [
raise_to_shaped(core.get_aval(t), weak_type=False).strip_named_shape()
if type(t) is not SymbolicZero else t.aval.strip_weak_type()
for t in tangents_out]
primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False)
for x in primals_out]
tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False)
if type(t) is not SymbolicZero else t.aval.strip_weak_type()
for t in tangents_out]
if primal_avals_out != tangent_avals_out:
if len(primal_avals_out) == 1:
(av1,), (av2,) = primal_avals_out, tangent_avals_out

@@ -1237,8 +1237,7 @@ def _call_exported_abstract_eval(
out_avals = tuple(
core.ShapedArray(core.evaluate_shape(out_aval.shape, exported_dim_vars,
*exported_dim_values),
dtype=out_aval.dtype, weak_type=out_aval.weak_type,
named_shape=out_aval.named_shape)
dtype=out_aval.dtype, weak_type=out_aval.weak_type)
for out_aval in exported.out_avals)
return out_avals, set(exported.ordered_effects + exported.unordered_effects)

@@ -200,8 +200,8 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack,
# TODO(mattjj): add back these checks for dynamic shapes
# if config.enable_checks.value:
# ct_aval = core.get_aval(ct_env[v])
# joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type().strip_named_shape()
# assert v.aval.strip_weak_type().strip_named_shape() == joined_aval, (prim, v.aval, ct_aval)
# joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type()
# assert v.aval.strip_weak_type() == joined_aval, (prim, v.aval, ct_aval)

def read_cotangent(v):
return ct_env.pop(v, Zero(v.aval.at_least_vspace()))

@@ -1091,7 +1091,7 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, donated_args,
input_output_aliases = list(input_output_aliases)
# To match-up in-avals to out-avals we only care about the number of
# bytes, so we strip off unrelated aval metadata (eg. the named shape)
strip_metadata = lambda a: a.strip_named_shape().strip_weak_type()
strip_metadata = lambda a: a.strip_weak_type()
avals_in = map(strip_metadata, avals_in)
avals_out = map(strip_metadata, avals_out)

@@ -854,12 +854,10 @@ def lower_parallel_callable(
def _pmap_unmap_shaped_array(
size: int, axis_name: core.AxisName, axis: int | None, aval: ShapedArray
) -> ShapedArray:
named_shape = dict(aval.named_shape)
named_shape.pop(axis_name, None) # TODO: make this mandatory
if axis is None: return aval.update(named_shape=named_shape)
if axis is None: return aval
elif type(axis) is int:
return ShapedArray(tuple_update(aval.shape, axis, size), aval.dtype,
named_shape=named_shape, weak_type=aval.weak_type)
weak_type=aval.weak_type)
else: raise TypeError(axis)

@@ -1507,22 +1505,17 @@ mlir.register_lowering(xla_pmap_p, _pmap_lowering)
def tile_aval_nd(axis_sizes, in_axes: ArrayMapping, aval):
assert isinstance(aval, ShapedArray)
shape = list(aval.shape)
named_shape = dict(aval.named_shape)
for name, axis in in_axes.items():
assert shape[axis] % axis_sizes[name] == 0
assert name not in named_shape
named_shape[name] = axis_sizes[name]
shape[axis] //= axis_sizes[name]
return aval.update(shape=tuple(shape), named_shape=named_shape)
return aval.update(shape=tuple(shape))

def untile_aval_nd(axis_sizes, out_axes: ArrayMapping, aval):
assert isinstance(aval, ShapedArray)
shape = list(aval.shape)
named_shape = dict(aval.named_shape)
for name, axis in out_axes.items():
shape[axis] *= axis_sizes[name]
named_shape.pop(name, None) # The name might be missing --- it's a broadcast.
return aval.update(shape=tuple(shape), named_shape=named_shape)
return aval.update(shape=tuple(shape))


def mesh_local_to_global(mesh, axes: ArrayMapping, aval):

@@ -1321,7 +1321,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
raise TypeError(msg.format(cond_tree))
pred_aval = cond_jaxpr.out_avals[0]
if (not isinstance(pred_aval, ShapedArray)
or pred_aval.strip_weak_type().strip_named_shape() != ShapedArray((), np.bool_)):
or pred_aval.strip_weak_type() != ShapedArray((), np.bool_)):
msg = "cond_fun must return a boolean scalar, but got output type(s) {}."
raise TypeError(msg.format(cond_jaxpr.out_avals))
return init_vals, init_avals, body_jaxpr, in_tree, cond_jaxpr, cond_consts, body_consts, body_tree
@@ -59,8 +59,7 @@ from jax._src.interpreters.batching import RaggedAxis
from jax._src.lax import slicing
from jax._src.lax.utils import (
_input_dtype, dtype_to_string, standard_abstract_eval,
standard_multi_result_abstract_eval, standard_named_shape_rule,
standard_primitive)
standard_multi_result_abstract_eval, standard_primitive)
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir

@@ -2563,7 +2562,7 @@ convert_element_type_p.def_impl(partial(dispatch.apply_primitive, convert_elemen
convert_element_type_p.def_abstract_eval(
partial(standard_abstract_eval, convert_element_type_p,
_convert_element_type_shape_rule, _convert_element_type_dtype_rule,
_convert_element_type_weak_type_rule, standard_named_shape_rule))
_convert_element_type_weak_type_rule))
ad.defjvp(convert_element_type_p, _convert_element_type_jvp_rule)
ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule
batching.defvectorized(convert_element_type_p)

@@ -3360,7 +3359,7 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions):
type(core.get_aval(d).dtype) is core.bint for d in shape)):
shape = _broadcast_in_dim_shape_rule( # error checking
x, shape=shape, broadcast_dimensions=broadcast_dimensions)
return core.ShapedArray(shape, x.dtype, x.weak_type, x.named_shape)
return core.ShapedArray(shape, x.dtype, x.weak_type)
# If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray
# (even if x is a ShapedArray)
# TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code

@@ -4057,25 +4056,12 @@ def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr, dimensions):
reducer = core.jaxpr_as_fun(jaxpr)
return _reduce_jvp(reducer, init_values, primal_xs, tangent_xs, dimensions)

def _reduce_named_shape_rule(*avals, computation, jaxpr, dimensions):
# TODO(mattjj,frostig): see the TODOs noting limitations/assumptions in
# _reduce_batching_rule. We're making the same assumptions here for now.
num_operands = len(avals) // 2
operand_avals, init_avals = split_list(avals, [num_operands])
if any(a.named_shape for a in init_avals):
raise NotImplementedError
named_shapes = [a.named_shape for a in operand_avals]
join = core.join_named_shapes(*(a.named_shape for a in operand_avals))
return [join] * len(named_shapes)


reduce_p = core.Primitive('reduce')
reduce_p.multiple_results = True
reduce_p.def_impl(partial(dispatch.apply_primitive, reduce_p))
reduce_p.def_abstract_eval(
partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule,
_reduce_dtype_rule, _reduce_weak_type_rule,
_reduce_named_shape_rule))
_reduce_dtype_rule, _reduce_weak_type_rule))
batching.primitive_batchers[reduce_p] = _reduce_batch_rule
ad.primitive_jvps[reduce_p] = _reduce_jvp_rule

@@ -4839,9 +4825,6 @@ def _rng_bit_generator_lowering(
return [out_key, out_vals]


def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
return [key.named_shape, key.named_shape]

rng_bit_generator_p = Primitive("rng_bit_generator")
rng_bit_generator_p.multiple_results = True
rng_bit_generator_p.def_impl(

@@ -4849,8 +4832,7 @@ rng_bit_generator_p.def_impl(
rng_bit_generator_p.def_abstract_eval(
partial(standard_multi_result_abstract_eval, rng_bit_generator_p,
_rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
_rng_bit_generator_weak_type_rule,
_rng_bit_generator_named_shape_rule))
_rng_bit_generator_weak_type_rule))
mlir.register_lowering(rng_bit_generator_p,
_rng_bit_generator_lowering)


@@ -738,21 +738,15 @@ def _allreduce_impl(pos_reducer, *args, axes, axis_index_groups):
return [pos_reducer(arg, axes) for arg in args]

def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
# TODO(frostig,mattjj,jekbradbury): maybe check aval names here
named_axes = tuple(axis for axis in axes if not isinstance(axis, int))
pos_axes = tuple(axis for axis in axes if isinstance(axis, int))
named_shapes = [arg.named_shape for arg in args]
named_axes = {axis for axis in axes if not isinstance(axis, int)}
if axis_index_groups is None:
named_shapes = [{name: size for name, size in arg.named_shape.items()
if name not in named_axes} for arg in args]
else:
if axis_index_groups is not None:
if len(pos_axes) != 0:
raise ValueError(f"axis_index_groups can only be used with reductions over "
f"named axes, but got: {axes}")
out_avals = [
ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes),
arg.dtype, named_shape=named_shape)
for arg, named_shape in zip(args, named_shapes)]
arg.dtype) for arg in args]
return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes}

def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
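Collective results no longer carry a named_shape in their avals; the dependence on an axis name is tracked only through `NamedAxisEffect`. A small sketch, using `make_jaxpr`'s `axis_env` argument to supply the axis size when tracing outside `pmap`:

```python
import jax
import jax.numpy as jnp
from jax import lax

f = lambda x: lax.psum(x, 'i')
jaxpr = jax.make_jaxpr(f, axis_env=[('i', 4)])(jnp.float32(1.0))
print(jaxpr.out_avals)  # a plain float32[] aval, with no named component
print(jaxpr.effects)    # includes the named-axis effect for 'i'
```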
@@ -1301,12 +1295,7 @@ def _all_gather_effectful_abstract_eval(
new_shape[all_gather_dimension] *= axis_size
else:
new_shape.insert(all_gather_dimension, axis_size)
new_named_shape = {name: size for name, size in x_aval.named_shape.items()
if name not in axis_name}
out_aval = x_aval.update(shape=new_shape, named_shape=new_named_shape)
effects = {*map(core.NamedAxisEffect, axis_name)}
return out_aval, effects

return x_aval.update(shape=new_shape), {*map(core.NamedAxisEffect, axis_name)}

def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled):
return (psum_scatter(cts, axis_name=axis_name,

@@ -1437,15 +1426,7 @@ def _reduce_scatter_effectful_abstract_eval(
f"{scatter_dim_input_size} must match shard count "
f"{axis_size}")
del new_shape[scatter_dimension]

new_named_shape = {
name: size
for name, size in x_aval.named_shape.items()
if name not in axis_name
}
out_aval = x_aval.update(shape=new_shape, named_shape=new_named_shape)
effects = {*map(core.NamedAxisEffect, axis_name)}
return out_aval, effects
return x_aval.update(shape=new_shape), {*map(core.NamedAxisEffect, axis_name)}


def _reduce_scatter_transpose_rule(cts, x, *, axis_name, scatter_dimension,

@@ -1633,9 +1614,7 @@ def _axis_index_lowering(ctx, *, axis_name):

def _axis_index_effectful_abstract_eval(*, axis_name):
frame = core.axis_frame(axis_name)
out_aval = ShapedArray((), np.int32, named_shape={axis_name: frame.size})
return out_aval, {core.NamedAxisEffect(axis_name)}

return ShapedArray((), np.int32), {core.NamedAxisEffect(axis_name)}

axis_index_p = core.Primitive('axis_index')
mlir.register_lowering(axis_index_p, _axis_index_lowering)

@@ -1691,14 +1670,7 @@ def _pdot_effectful_abstract_eval(
pos_aval = lax.dot_general_p.abstract_eval(
x, y, dimension_numbers=[pos_contract, pos_batch],
precision=precision, preferred_element_type=None)[0]
common_named_shape = core.join_named_shapes(x.named_shape, y.named_shape)
named_shape = {name: size
for name, size in common_named_shape.items()
if name not in axis_name}
out_aval = pos_aval.update(named_shape=named_shape)
effects = {*map(core.NamedAxisEffect, axis_name)}
return out_aval, effects

return pos_aval, {*map(core.NamedAxisEffect, axis_name)}

def _pdot_vmap_collective_rule(axis_size, frame_name, _, vals_in, dims_in, *, axis_name,
pos_contract, pos_batch, precision):
@@ -24,6 +24,8 @@ from jax._src import dtypes
from jax._src.util import safe_zip
from jax._src.lib import xla_client

zip, unsafe_zip = safe_zip, zip

import numpy as np

xops = xla_client.ops

@@ -35,20 +37,19 @@ def _argnum_weak_type(*argnums):
return lambda *args, **_: all(args[i].weak_type for i in argnums)

def standard_primitive(shape_rule, dtype_rule, name,
weak_type_rule=None, named_shape_rule=None):
weak_type_rule=None):
weak_type_rule = weak_type_rule or _standard_weak_type_rule
named_shape_rule = named_shape_rule or standard_named_shape_rule
prim = core.Primitive(name)
prim.def_impl(partial(dispatch.apply_primitive, prim))
prim.def_abstract_eval(
partial(standard_abstract_eval, prim, shape_rule, dtype_rule,
weak_type_rule, named_shape_rule))
weak_type_rule))
return prim

def _get_array_abstraction_level(a): return a.array_abstraction_level

def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
named_shape_rule, *avals, **kwargs):
*avals, **kwargs):
assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
assert not prim.multiple_results
weak_type = weak_type_rule(*avals, **kwargs)

@@ -58,8 +59,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
return core.ConcreteArray(out.dtype, out, weak_type=weak_type)
elif least_specialized is core.ShapedArray:
return core.ShapedArray(shape_rule(*avals, **kwargs),
dtype_rule(*avals, **kwargs), weak_type=weak_type,
named_shape=named_shape_rule(*avals, **kwargs))
dtype_rule(*avals, **kwargs), weak_type=weak_type)
elif least_specialized is core.DShapedArray:
shape = shape_rule(*avals, **kwargs)
ty = (core.ShapedArray if all(type(d) is int for d in shape)

@@ -71,8 +71,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
raise TypeError(avals, least_specialized)

def standard_multi_result_abstract_eval(
prim, shape_rule, dtype_rule, weak_type_rule,
named_shape_rule, *avals, **kwargs):
prim, shape_rule, dtype_rule, weak_type_rule, *avals, **kwargs):
assert prim.multiple_results
assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
least_specialized = max(map(type, avals), key=_get_array_abstraction_level)

@@ -80,18 +79,16 @@ def standard_multi_result_abstract_eval(
if least_specialized is core.ConcreteArray:
out_vals = prim.impl(*[x.val for x in avals], **kwargs)
return [core.ConcreteArray(val.dtype, val, weak_type=weak_type)
for val, weak_type in safe_zip(out_vals, weak_types)]
for val, weak_type in zip(out_vals, weak_types)]
elif least_specialized is core.ShapedArray:
out_shapes = shape_rule(*avals, **kwargs)
out_dtypes = dtype_rule(*avals, **kwargs)
out_named_shapes = named_shape_rule(*avals, **kwargs)
return [core.ShapedArray(s, d, weak_type=weak_type, named_shape=named_shape)
for s, d, weak_type, named_shape
in safe_zip(out_shapes, out_dtypes, weak_types, out_named_shapes)]
return [core.ShapedArray(s, d, weak_type=weak_type)
for s, d, weak_type in zip(out_shapes, out_dtypes, weak_types)]
elif least_specialized is core.UnshapedArray:
out_dtypes = dtype_rule(*avals, **kwargs)
return [core.UnshapedArray(dtype, weak_type=weak_type)
for dtype, weak_type in safe_zip(out_dtypes, weak_types)]
for dtype, weak_type in zip(out_dtypes, weak_types)]
else:
raise TypeError(avals, least_specialized)

@@ -103,9 +100,6 @@ def standard_translate(prim):
return [op(*args, **kwargs)]
return translation_rule

def standard_named_shape_rule(*avals, **kwargs):
return core.join_named_shapes(*(a.named_shape for a in avals))

def _standard_weak_type_rule(*avals, **kwargs):
return all(aval.weak_type for aval in avals)
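A primitive registered through `standard_primitive` now supplies only shape, dtype, and (optionally) weak-type rules. A hedged sketch using the internal `jax._src.lax.utils` module; `my_unary_p`, `_my_shape_rule`, and `_my_dtype_rule` are illustrative names, not part of JAX:

```python
from jax._src.lax.utils import standard_primitive

def _my_shape_rule(x_aval):
    # elementwise: output shape matches the input shape
    return x_aval.shape

def _my_dtype_rule(x_aval):
    # elementwise: output dtype matches the input dtype
    return x_aval.dtype

# No named_shape_rule argument anymore; shape/dtype/weak_type rules suffice.
my_unary_p = standard_primitive(_my_shape_rule, _my_dtype_rule, 'my_unary')
```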
@@ -198,19 +198,19 @@ def truncated_normal(stddev: RealNumeric = 1e-2,
return init

@export
def _compute_fans(shape: core.NamedShape,
def _compute_fans(shape: Sequence[int],
in_axis: int | Sequence[int] = -2,
out_axis: int | Sequence[int] = -1,
batch_axis: int | Sequence[int] = ()
) -> tuple[Array, Array]:
) -> tuple[float, float]:
"""
Compute effective input and output sizes for a linear or convolutional layer.

Axes not in in_axis, out_axis, or batch_axis are assumed to constitute the
"receptive field" of a convolution (kernel spatial dimensions).
"""
if shape.rank <= 1:
raise ValueError(f"Can't compute input and output sizes of a {shape.rank}"
if len(shape) <= 1:
raise ValueError(f"Can't compute input and output sizes of a {len(shape)}"
"-dimensional weights tensor. Must be at least 2D.")

if isinstance(in_axis, int):

@@ -225,13 +225,13 @@ def _compute_fans(shape: core.NamedShape,
batch_size = shape[batch_axis]
else:
batch_size = math.prod([shape[i] for i in batch_axis])
receptive_field_size = shape.total / in_size / out_size / batch_size
receptive_field_size = math.prod(shape) / in_size / out_size / batch_size
fan_in = in_size * receptive_field_size
fan_out = out_size * receptive_field_size
return fan_in, fan_out

def _complex_uniform(key: KeyArray,
shape: Sequence[int] | core.NamedShape,
shape: Sequence[int],
dtype: DTypeLikeInexact) -> Array:
"""
Sample uniform random values within a disk on the complex plane,
|
||||
return r * jnp.exp(1j * theta)
|
||||
|
||||
def _complex_truncated_normal(key: KeyArray, upper: ArrayLike,
|
||||
shape: Sequence[int] | core.NamedShape,
|
||||
shape: Sequence[int],
|
||||
dtype: DTypeLikeInexact) -> Array:
|
||||
"""
|
||||
Sample random values from a centered normal distribution on the complex plane,
|
||||
@ -317,9 +317,9 @@ def variance_scaling(
|
||||
def init(key: KeyArray,
|
||||
shape: core.Shape,
|
||||
dtype: DTypeLikeInexact = dtype) -> Array:
|
||||
shape = core.canonicalize_shape(shape)
|
||||
dtype = dtypes.canonicalize_dtype(dtype)
|
||||
named_shape = core.as_named_shape(shape)
|
||||
fan_in, fan_out = _compute_fans(named_shape, in_axis, out_axis, batch_axis)
|
||||
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis, batch_axis)
|
||||
if mode == "fan_in": denominator = fan_in
|
||||
elif mode == "fan_out": denominator = fan_out
|
||||
elif mode == "fan_avg": denominator = (fan_in + fan_out) / 2
|
||||
@ -332,18 +332,18 @@ def variance_scaling(
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
# constant is stddev of standard normal truncated to (-2, 2)
|
||||
stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype)
|
||||
return random.truncated_normal(key, -2, 2, named_shape, dtype) * stddev
|
||||
return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
|
||||
else:
|
||||
# constant is stddev of complex standard normal truncated to 2
|
||||
stddev = jnp.sqrt(variance) / jnp.array(.95311164380491208, dtype)
|
||||
return _complex_truncated_normal(key, 2, named_shape, dtype) * stddev
|
||||
return _complex_truncated_normal(key, 2, shape, dtype) * stddev
|
||||
elif distribution == "normal":
|
||||
return random.normal(key, named_shape, dtype) * jnp.sqrt(variance)
|
||||
return random.normal(key, shape, dtype) * jnp.sqrt(variance)
|
||||
elif distribution == "uniform":
|
||||
if jnp.issubdtype(dtype, jnp.floating):
|
||||
return random.uniform(key, named_shape, dtype, -1) * jnp.sqrt(3 * variance)
|
||||
return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
|
||||
else:
|
||||
return _complex_uniform(key, named_shape, dtype) * jnp.sqrt(variance)
|
||||
return _complex_uniform(key, shape, dtype) * jnp.sqrt(variance)
|
||||
else:
|
||||
raise ValueError(f"invalid distribution for variance scaling initializer: {distribution}")
|
||||
|
||||
|
@@ -148,9 +148,7 @@ def _logical_aval_to_interpret_mode_aval(aval):
return aval.update(inner_aval=inner_aval)
if isinstance(aval, jax_core.ShapedArray):
inner_dtype = _logical_to_interpret_mode_dtype(aval.dtype)
return jax_core.ShapedArray(aval.shape,
inner_dtype,
weak_type=aval.weak_type, named_shape=aval.named_shape)
return jax_core.ShapedArray(aval.shape, inner_dtype, weak_type=aval.weak_type)
return aval

def _get_next_indices(grid, indices):

@@ -490,7 +490,7 @@ def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
p, _ = _infer_params(fun, jit_info, args, kwargs)
out_s = [None if is_unspecified(s) else s for s in p.params['out_shardings']]
# TODO(yashkatariya): Add `Layout` to SDS.
out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape, sharding=s)
out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s)
for x, s in zip(p.params['jaxpr'].out_avals, out_s)]
return tree_unflatten(p.out_tree, out)
@@ -44,7 +44,6 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.lax import lax as lax_internal
from jax._src.lax import utils as lax_utils
from jax._src.lib import gpu_prng
from jax._src.lib import xla_client as xc
from jax._src.lib import version as jaxlib_version

@@ -611,8 +610,7 @@ batching.defbroadcasting(random_fold_in_p)
def random_fold_in_abstract_eval(keys_aval, msgs_aval):
shape = lax_internal.broadcasting_shape_rule(
'random_fold_in', keys_aval, msgs_aval)
named_shape = lax_utils.standard_named_shape_rule(keys_aval, msgs_aval)
return core.ShapedArray(shape, keys_aval.dtype, named_shape=named_shape)
return core.ShapedArray(shape, keys_aval.dtype)

@random_fold_in_p.def_impl
def random_fold_in_impl(keys, msgs):

@@ -640,19 +638,7 @@ mlir.register_lowering(random_fold_in_p, random_fold_in_lowering)


def random_bits(keys, bit_width, shape):
shape = core.as_named_shape(shape)
for name, size in shape.named_items:
# TODO(frostig,mattjj,apaszke): Is this real_size check necessary,
# and is it meant to raise a user-facing ValueError? Should it be
# an `assert` (or RuntimeError) instead? Why do we check it in
# calls to `random_bits` instead of a more common paralleism path?
real_size = lax.psum(1, name)
if real_size != size:
raise ValueError(f"The shape of axis {name} was specified as {size}, "
f"but it really is {real_size}")
axis_index = lax.axis_index(name)
keys = random_fold_in(keys, axis_index)
return random_bits_p.bind(keys, bit_width=bit_width, shape=shape.positional)
return random_bits_p.bind(keys, bit_width=bit_width, shape=shape)

random_bits_p = core.Primitive('random_bits')
ad.defjvp_zero(random_bits_p)

@@ -822,8 +808,7 @@ def _threefry2x32_abstract_eval(*args):
.format(args))
if all(isinstance(arg, core.ShapedArray) for arg in args):
shape = lax_internal.broadcasting_shape_rule(*args)
named_shape = core.join_named_shapes(*(a.named_shape for a in args))
aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32), named_shape=named_shape)
aval = core.ShapedArray(shape, jnp.dtype(jnp.uint32))
else:
aval = core.UnshapedArray(jnp.dtype(jnp.uint32))
return (aval,) * 2
@@ -35,7 +35,6 @@ from jax._src import dtypes
from jax._src import prng
from jax._src import xla_bridge
from jax._src.api import jit, vmap
from jax._src.core import NamedShape
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir

@@ -319,12 +318,10 @@ def wrap_key_data(key_bits_array: Array, *,
### random samplers


def _check_shape(name: str, shape: Shape | NamedShape, *param_shapes) -> None:
shape = core.as_named_shape(shape)

def _check_shape(name: str, shape: Shape, *param_shapes) -> None:
if param_shapes:
shape_ = lax.broadcast_shapes(shape.positional, *param_shapes)
if shape.positional != shape_:
shape_ = lax.broadcast_shapes(shape, *param_shapes) # type: ignore
if shape != shape_:
msg = ("{} parameter shapes must be broadcast-compatible with shape "
"argument, and the result of broadcasting the shapes must equal "
"the shape argument, but got result {} for shape argument {}.")

@@ -361,7 +358,7 @@ def bits(key: KeyArrayLike,


def uniform(key: KeyArrayLike,
shape: Shape | NamedShape = (),
shape: Shape = (),
dtype: DTypeLikeFloat = float,
minval: RealArray = 0.,
maxval: RealArray = 1.) -> Array:

@@ -381,12 +378,12 @@ def uniform(key: KeyArrayLike,
"""
key, _ = _check_prng_key("uniform", key)
dtypes.check_user_dtype_supported(dtype)
shape = core.canonicalize_shape(shape)

if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `uniform` must be a float dtype, "
f"got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
shape = core.as_named_shape(shape)
return _uniform(key, shape, dtype, minval, maxval)

@partial(jit, static_argnums=(1, 2))

@@ -397,8 +394,8 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array:

minval = lax.convert_element_type(minval, dtype)
maxval = lax.convert_element_type(maxval, dtype)
minval = lax.broadcast_to_rank(minval, shape.positional_rank)
maxval = lax.broadcast_to_rank(maxval, shape.positional_rank)
minval = lax.broadcast_to_rank(minval, len(shape))
maxval = lax.broadcast_to_rank(maxval, len(shape))

finfo = jnp.finfo(dtype)
nbits, nmant = finfo.bits, finfo.nmant

@@ -427,7 +424,7 @@ def _uniform(key, shape, dtype, minval, maxval) -> Array:
floats = lax.bitcast_convert_type(float_bits, dtype) - np.array(1., dtype)
return lax.max(
minval,
lax.reshape(floats * (maxval - minval) + minval, shape.positional))
lax.reshape(floats * (maxval - minval) + minval, shape))

def randint(key: KeyArrayLike,
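The public samplers now type `shape` as a plain tuple of ints; `NamedShape` arguments are no longer accepted. A minimal sketch:

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
x = jax.random.uniform(key, (2, 3), jnp.float32, minval=-1.0, maxval=1.0)
y = jax.random.normal(key)  # shape defaults to the empty tuple (a scalar)
print(x.shape, y.shape)  # (2, 3) ()
```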
@@ -674,7 +671,7 @@ def choice(key: KeyArrayLike,


def normal(key: KeyArrayLike,
shape: Shape | NamedShape = (),
shape: Shape = (),
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample standard normal random values with given shape and float dtype.

@@ -696,12 +693,12 @@ def normal(key: KeyArrayLike,
A random array with the specified shape and dtype.
"""
key, _ = _check_prng_key("normal", key)
shape = core.canonicalize_shape(shape)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.inexact):
raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, "
f"got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
shape = core.as_named_shape(shape)
return _normal(key, shape, dtype)

@partial(jit, static_argnums=(1, 2))

@@ -812,7 +809,7 @@ def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array:
def truncated_normal(key: KeyArrayLike,
lower: RealArray,
upper: RealArray,
shape: Shape | NamedShape | None = None,
shape: Shape | None = None,
dtype: DTypeLikeFloat = float) -> Array:
r"""Sample truncated standard normal random values with given shape and dtype.

@@ -841,14 +838,14 @@ def truncated_normal(key: KeyArrayLike,
``shape`` is not None, or else by broadcasting ``lower`` and ``upper``.
Returns values in the open interval ``(lower, upper)``.
"""
if shape is not None:
shape = core.canonicalize_shape(shape)
key, _ = _check_prng_key("truncated_normal", key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `truncated_normal` must be a float "
f"dtype, got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = core.as_named_shape(shape)
return _truncated_normal(key, lower, upper, shape, dtype)

@partial(jit, static_argnums=(3, 4))

@@ -877,7 +874,7 @@ def _truncated_normal(key, lower, upper, shape, dtype) -> Array:

def bernoulli(key: KeyArrayLike,
p: RealArray = np.float32(0.5),
shape: Shape | NamedShape | None = None) -> Array:
shape: Shape | None = None) -> Array:
r"""Sample Bernoulli random values with given shape and mean.

The values are distributed according to the probability mass function:

@@ -899,10 +896,10 @@ def bernoulli(key: KeyArrayLike,
A random array with boolean dtype and shape given by ``shape`` if ``shape``
is not None, or else ``p.shape``.
"""
if shape is not None:
shape = core.canonicalize_shape(shape)
key, _ = _check_prng_key("bernoulli", key)
dtype = dtypes.canonicalize_dtype(lax.dtype(p))
if shape is not None:
shape = core.as_named_shape(shape)
if not jnp.issubdtype(dtype, np.floating):
msg = "bernoulli probability `p` must have a floating dtype, got {}."
raise TypeError(msg.format(dtype))

@@ -1559,7 +1556,7 @@ def categorical(key: KeyArrayLike,
if shape is None:
shape = batch_shape
else:
shape = tuple(shape)
shape = core.canonicalize_shape(shape)
_check_shape("categorical", shape, batch_shape)

shape_prefix = shape[:len(shape)-len(batch_shape)]

@@ -2053,6 +2050,7 @@ def orthogonal(
Returns:
A random array of shape `(*shape, n, n)` and specified dtype.
"""
shape = core.canonicalize_shape(shape)
key, _ = _check_prng_key("orthogonal", key)
dtypes.check_user_dtype_supported(dtype)
_check_shape("orthogonal", shape)

@@ -2088,6 +2086,7 @@ def generalized_normal(
Returns:
A random array with the specified shape and dtype.
"""
shape = core.canonicalize_shape(shape)
key, _ = _check_prng_key("generalized_normal", key)
dtypes.check_user_dtype_supported(dtype)
_check_shape("generalized_normal", shape)

@@ -2118,6 +2117,7 @@ def ball(
Returns:
A random array of shape `(*shape, d)` and specified dtype.
"""
shape = core.canonicalize_shape(shape)
key, _ = _check_prng_key("ball", key)
dtypes.check_user_dtype_supported(dtype)
_check_shape("ball", shape)
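The broadcast check in `_check_shape` still applies, just on tuples: parameter shapes must be broadcast-compatible with the requested `shape`, and broadcasting them must reproduce it. For example:

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
p = jnp.full((3,), 0.5)
ok = jax.random.bernoulli(key, p, shape=(2, 3))  # (3,) broadcasts to (2, 3)
print(ok.shape, ok.dtype)  # (2, 3) bool
# jax.random.bernoulli(key, p, shape=(2, 4)) would raise a shape error instead.
```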
@@ -218,11 +218,9 @@ def get_ref_state_effects(
if isinstance(eff, (ReadEffect, WriteEffect, AccumEffect))
and eff.input_index == i} for i, _ in enumerate(avals)]

def shaped_array_ref(shape: tuple[int, ...], dtype,
weak_type: bool = False,
named_shape = None) -> AbstractRef:
return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type,
named_shape=named_shape))
def shaped_array_ref(
shape: tuple[int, ...], dtype, weak_type: bool = False) -> AbstractRef:
return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type))

def _shard_ref(mesh, names, ref_aval: AbstractRef):
del mesh

@@ -42,7 +42,6 @@ from jax._src.core import (
MainTrace as MainTrace,
MapPrimitive as MapPrimitive,
NameGatheringSubst as NameGatheringSubst,
NamedShape as NamedShape,
OutDBIdx as OutDBIdx,
OutputType as OutputType,
ParamDict as ParamDict,

@@ -61,7 +60,6 @@ from jax._src.core import (
Var as Var,
abstract_token as abstract_token,
apply_todos as apply_todos,
as_named_shape as as_named_shape,
aval_mapping_handlers as aval_mapping_handlers,
axis_frame as axis_frame,
call as call,

@@ -97,7 +95,6 @@ from jax._src.core import (
jaxpr_uses_outfeed as jaxpr_uses_outfeed,
jaxprs_in_params as jaxprs_in_params,
join_effects as join_effects,
join_named_shapes as join_named_shapes,
lattice_join as lattice_join,
leaked_tracer_error as leaked_tracer_error,
literalable_types as literalable_types,

@@ -572,9 +572,7 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames,
aval: core.AbstractValue,) -> core.AbstractValue:
assert isinstance(aval, core.ShapedArray)
return aval.update(tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape)),
named_shape={k: v for k, v in aval.named_shape.items()
if k not in mesh.shape})
for i, sz in enumerate(aval.shape)))
core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array

# Type-checking
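The `jax.core` re-exports of the named-shape API (`NamedShape`, `as_named_shape`, `join_named_shapes`) are deleted along with the implementation, so downstream code must stop importing them. A hedged sketch of a defensive import for libraries that need to straddle JAX versions; the fallback branch is an assumption, not part of this diff:

```python
# Older code sometimes did `from jax.core import NamedShape`.
# After this commit that import fails; guard it if both versions must be supported.
try:
    from jax.core import NamedShape  # pre-change JAX only
except ImportError:
    NamedShape = None  # named shapes are gone; shapes are plain tuples now
```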
@@ -2706,6 +2706,8 @@ class APITest(jtu.JaxTestCase):
self.assertEqual(out_shape.shape, (3,))

def test_eval_shape_names(self):
raise unittest.SkipTest("named shape are deprecated")

def fun(x, y):
return lax.psum(x, 'i') + y

@@ -6571,6 +6573,7 @@ class JaxprTest(jtu.JaxTestCase):
self.assertIn('psum', str(jaxpr))

def test_make_jaxpr_named(self):
raise unittest.SkipTest("named shape are deprecated")
def f(x):
return x - lax.psum(x, 'i')

@@ -550,32 +550,6 @@ class JaxprTypeChecks(jtu.JaxTestCase):
aval = core.raise_to_shaped(core.get_aval(value))
self.assertEqual(aval.weak_type, weak_type)

def test_lattice_join_named_shape(self):
aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10})
self.assertEqual(core.lattice_join(aval1, aval1), aval1)

aval2 = core.ShapedArray((2, 3), np.float32, False, {'j': 5})
expected = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5})
self.assertEqual(core.lattice_join(aval1, aval2), expected)

aval3 = core.ShapedArray((2, 3), np.float32, False, {'i': 5})
self.assertRaises(TypeError, lambda: core.lattice_join(aval1, aval3))

def test_typecompat_named_shape(self):
aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10})
aval2 = core.ShapedArray((2, 3), np.float32, False, {'j': 5})
self.assertTrue(core.typecompat(aval1, aval2))

aval3 = core.ShapedArray((2, 3), np.float32, False, {'i': 5})
self.assertFalse(core.typecompat(aval1, aval3))

def test_named_shape_comparision(self):
self.assertTrue(core.NamedShape(2, 3) == (2, 3))
self.assertFalse(core.NamedShape(2, i=3) == (2,))
self.assertFalse(core.NamedShape(2, i=3) == (2, 3))
self.assertFalse(core.NamedShape(2, i=3) == None)
self.assertFalse(core.NamedShape() == [])


@jtu.with_config(jax_dynamic_shapes=True)
class DynamicShapesTest(jtu.JaxTestCase):

@@ -3287,26 +3287,6 @@ class LazyConstantTest(jtu.JaxTestCase):
expected.astype(np.complex64), lax.log1p(np.complex64(1e-5)))


class LaxNamedShapeTest(jtu.JaxTestCase):

def test_abstract_eval(self):
aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10})
out, _ = lax.sin_p.abstract_eval(aval1)
self.assertEqual(out, aval1)

aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10})
aval2 = core.ShapedArray((2, 3), np.float32, False, {'j': 5})
expected = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5})
out, _ = lax.add_p.abstract_eval(aval1, aval2)
self.assertEqual(out, expected)

def test_abstract_eval_collective(self):
with core.extend_axis_env('i', 10, None):
aval1 = core.ShapedArray((2, 3), np.float32, False, {'i': 10, 'j': 5})
expected = core.ShapedArray((2, 3), np.float32, False, {'j': 5})
(out,), _ = lax.psum_p.abstract_eval(aval1, axes=('i',), axis_index_groups=None)
self.assertEqual(out, expected)

class FooTyRules:
# handlers