remove physical_avals rule in favor of physical_element_aval

Roy Frostig 2023-05-10 19:13:29 -07:00
parent f3cecd07c7
commit 180e26dafb
8 changed files with 51 additions and 49 deletions

View File

@@ -30,7 +30,7 @@ import types
 from typing import (Any, Callable, ClassVar, DefaultDict, Dict, FrozenSet,
                     Generator, Generic, Hashable, Iterable, Iterator, List,
                     NamedTuple, Optional, Sequence, Set, Tuple, Type, TypeVar,
-                    Union, cast)
+                    Union, cast, overload)
 import warnings
 from weakref import ref
@@ -1379,6 +1379,25 @@ def concrete_or_error(force: Any, val: Any, context=""):
 def has_opaque_dtype(x: Any) -> bool:
   return dtypes.is_opaque_dtype(get_aval(x).dtype)
 
+@overload
+def physical_aval(aval: ShapedArray) -> ShapedArray: ...
+@overload
+def physical_aval(aval: DShapedArray) -> DShapedArray: ...
+@overload  # TODO(frostig): remove this case
+def physical_aval(aval: AbstractValue) -> AbstractValue: ...
+
+def physical_aval(aval):
+  aval_dtype = getattr(aval, 'dtype', None)
+  if aval_dtype and dtypes.is_opaque_dtype(aval_dtype):
+    ctor = type(aval)
+    aval_shape = getattr(aval, 'shape', None)
+    assert aval_shape is not None, (ctor, aval)
+    elt_aval = aval_dtype._rules.physical_element_aval(aval_dtype)
+    assert type(elt_aval) is ShapedArray
+    return ctor((*aval_shape, *elt_aval.shape), elt_aval.dtype)  # pytype: disable=wrong-arg-count
+  else:
+    return aval
+
 def _short_dtype_name(dtype) -> str:
   if type(dtype) in dtypes.opaque_dtypes:
     return str(dtype)
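
For orientation, a minimal self-contained sketch of the composition that the new physical_aval performs: a dtype's rules now report only the per-element physical aval, and physical_aval appends that element shape to the logical shape. The Aval tuple and element rule below are illustrative stand-ins, not JAX APIs; the (2,)-shaped uint32 element mirrors the FooTy test rules further down.

from typing import NamedTuple, Tuple
import numpy as np

class Aval(NamedTuple):                       # stand-in for core.ShapedArray
  shape: Tuple[int, ...]
  dtype: object

def element_aval_uint32x2(_dtype) -> Aval:    # stand-in physical_element_aval rule
  return Aval((2,), np.dtype('uint32'))

def physical_aval_sketch(aval: Aval) -> Aval:
  elt = element_aval_uint32x2(aval.dtype)
  # logical shape followed by the element shape, with the element dtype
  return Aval((*aval.shape, *elt.shape), elt.dtype)

assert physical_aval_sketch(Aval((3, 4), 'opaque')).shape == (3, 4, 2)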

View File

@@ -143,9 +143,7 @@ def dtype_to_ir_type(dtype: Union[np.dtype, np.generic]) -> ir.Type:
 def _array_ir_types(aval: Union[core.ShapedArray, core.DShapedArray]
                     ) -> Sequence[ir.Type]:
-  if dtypes.is_opaque_dtype(aval.dtype):
-    phys_avals = aval.dtype._rules.physical_avals(aval)
-    return tuple(itertools.chain(*map(_array_ir_types, phys_avals)))
+  aval = core.physical_aval(aval)  # type: ignore
   if not core.is_constant_shape(aval.shape):
     return _dynamic_array_ir_types(aval)  # type: ignore
   return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),)
@@ -1243,10 +1241,7 @@ def multi_broadcast_in_dim(ctx: LoweringRuleContext,
   return out
 
 def reshape(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue) -> ir.Value:
-  if dtypes.is_opaque_dtype(aval_out.dtype):  # type: ignore
-    # TODO(frostig,mattjj,necula): asserts a single physical aval, and a
-    # particular reshape rule (reshape to the output physical aval's shape)
-    aval_out, = aval_out.dtype._rules.physical_avals(aval_out)  # type: ignore
+  aval_out = core.physical_aval(aval_out)
   if not core.is_constant_shape(aval_out.shape):  # type: ignore
     shape = eval_dynamic_shape(ctx, aval_out.shape)  # type: ignore
     return hlo.DynamicReshapeOp(
@@ -1704,10 +1699,7 @@ def _layout_to_mlir_layout(minor_to_major: Optional[Sequence[int]]):
   return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get())
 
 def _aval_to_default_layouts(aval):
-  if dtypes.is_opaque_dtype(aval.dtype):
-    avals = aval.dtype._rules.physical_avals(aval)
-  else:
-    avals = [aval]
+  avals = [core.physical_aval(aval)]
   # Row major order is default for `NumPy`.
   return [list(range(aval.ndim - 1, -1, -1)) for aval in avals]

View File

@@ -55,14 +55,9 @@ def identity(x): return x
 _scalar_types = dtypes.python_scalar_dtypes.keys()
 
 def _make_array_shape(aval: ShapedArray) -> Sequence[xc.Shape]:
-  def dt(aval):
-    return np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype
-  if dtypes.is_opaque_dtype(aval.dtype):
-    avals = aval.dtype._rules.physical_avals(aval)
-  else:
-    avals = [aval]
-  return tuple(xc.Shape.array_shape(dt(a), a.shape) for a in avals)
+  aval = core.physical_aval(aval)
+  dtype = np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype
+  return (xc.Shape.array_shape(dtype, aval.shape),)
 
 def get_canonical_source_file(frame: source_info_util.Frame):
   source_file = frame.file_name
View File

@@ -1647,10 +1647,10 @@ def _pred_bcast_select_hlo(ctx,
   assert x.type == y.type, (x.type, y.type)
   assert (pred_aval.shape == x_y_aval.shape[:len(pred_aval.shape)]), (
       pred_aval.shape, x_y_aval)
-  if dtypes.is_opaque_dtype(x_y_aval.dtype):
-    x_y_aval, = x_y_aval.dtype._rules.physical_avals(x_y_aval)
-  bcast_pred = mlir.broadcast_in_dim(ctx, pred, core.DShapedArray(x_y_aval.shape, np.dtype(np.bool_)),
-                                     broadcast_dimensions=list(range(len(pred_aval.shape))))
+  x_y_aval = core.physical_aval(x_y_aval)
+  bcast_pred = mlir.broadcast_in_dim(
+      ctx, pred, core.DShapedArray(x_y_aval.shape, np.dtype(np.bool_)),
+      broadcast_dimensions=list(range(len(pred_aval.shape))))
   return hlo.SelectOp(bcast_pred, x, y).results
 
 ### fori_loop
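
A small NumPy sketch of the broadcast being lowered here, assuming pred has logical shape (3,) and the operands have physical shape (3, 2) (for example key data with one trailing element dimension): pred is broadcast along the leading, logical dimension only, then the select applies elementwise.

import numpy as np

pred = np.array([True, False, True])                # logical shape (3,)
x = np.arange(6, dtype=np.uint32).reshape(3, 2)     # stand-in physical operand
y = np.zeros((3, 2), dtype=np.uint32)
# broadcast_dimensions = range(len(pred_aval.shape)) == (0,): pred maps to axis 0
bcast_pred = np.broadcast_to(pred[:, None], x.shape)
out = np.where(bcast_pred, x, y)
assert out.shape == (3, 2)
assert (out[1] == 0).all() and (out[0] == x[0]).all()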

View File

@@ -4728,9 +4728,9 @@ mlir.register_lowering(empty_p, _empty_lower)
 
 class BIntRules:
   @staticmethod
-  def physical_avals(aval) -> Sequence[core.AbstractValue]:
-    dtype = dtypes._scalar_type_to_dtype(int)
-    return [core.ShapedArray(aval.shape, dtype)]
+  def physical_element_aval(dtype) -> core.ShapedArray:
+    int_dtype = dtypes._scalar_type_to_dtype(int)
+    return core.ShapedArray((), int_dtype)
 
   @staticmethod
   def result_handler(sticky_device, aval):
@@ -4742,7 +4742,7 @@ class BIntRules:
   @staticmethod
   def global_sharded_result_handler(aval, out_sharding, committed,
                                     is_out_sharding_from_xla):
-    phys_aval, = BIntRules.physical_avals(aval)
+    phys_aval = core.physical_aval(aval)
     phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
 
     if not dispatch.is_single_device_sharding(out_sharding):
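
Note that for bint the element aval is a scalar, so the physical shape equals the logical shape and only the dtype changes. A one-line sketch of that shape arithmetic:

logical_shape, element_shape = (3,), ()   # element shape () per BIntRules above
assert (*logical_shape, *element_shape) == logical_shape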

View File

@@ -372,9 +372,9 @@ def seed_with_impl(impl: PRNGImpl, seed: Union[int, Array]) -> PRNGKeyArrayImpl:
 def keys_shaped_array(impl, shape):
   return core.ShapedArray(shape, KeyTy(impl))
 
+# TODO(frostig): remove in favor of physical_aval call
 def keys_aval_to_base_arr_aval(keys_aval):
-  shape = (*keys_aval.shape, *keys_aval.dtype.impl.key_shape)
-  return core.ShapedArray(shape, np.dtype('uint32'))
+  return core.physical_aval(keys_aval)
 
 def base_arr_shape_to_keys_shape(impl, base_arr_shape):
   base_ndim = len(impl.key_shape)
@@ -419,10 +419,8 @@ class KeyTyRules:
     return random_wrap(key_data, impl=dtype.impl)
 
   @staticmethod
-  def physical_avals(aval) -> Sequence[core.AbstractValue]:  # TODO(frostig): rename to `grounded_avals`
-    # TODO(frostig): dedup with `keys_aval_to_base_arr_aval``
-    return [core.ShapedArray((*aval.shape, *aval.dtype.impl.key_shape),  # type: ignore
-                             jnp.dtype('uint32'))]
+  def physical_element_aval(dtype) -> core.ShapedArray:
+    return core.ShapedArray(dtype.impl.key_shape, jnp.dtype('uint32'))
 
   @staticmethod
   def physical_const(val) -> Array:
@@ -472,7 +470,7 @@ class KeyTyRules:
 
   @staticmethod
   def local_sharded_result_handler(aval, sharding, indices):
-    phys_aval, = KeyTyRules.physical_avals(aval)
+    phys_aval = core.physical_aval(aval)
     key_shape = aval.dtype.impl.key_shape
     phys_handler_maker = pxla.local_result_handlers[core.ShapedArray]
@@ -499,7 +497,7 @@ class KeyTyRules:
   @staticmethod
   def global_sharded_result_handler(aval, out_sharding, committed,
                                     is_out_sharding_from_xla):
-    phys_aval, = KeyTyRules.physical_avals(aval)
+    phys_aval = core.physical_aval(aval)
     phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
 
     phys_sharding = make_key_array_phys_sharding(
@@ -512,7 +510,7 @@ class KeyTyRules:
 
   @staticmethod
   def make_sharded_array(aval, sharding, arrays, committed):
-    phys_aval, = KeyTyRules.physical_avals(aval)
+    phys_aval = core.physical_aval(aval)
     phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
     phys_arrays = [random_unwrap(arr) for arr in arrays]
@@ -536,7 +534,7 @@ class KeyTyRules:
     start_indices = (*start_indices, *trailing_zeros)
     limit_indices = (*limit_indices, *key_shape)
     strides = (*strides, *trailing_ones)
-    physical_aval_out, = KeyTyRules.physical_avals(aval_out)
+    physical_aval_out = core.physical_aval(aval_out)
     return mlir.slice_op(ctx, x, physical_aval_out,
                          start_indices=start_indices, limit_indices=limit_indices, strides=strides)
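
For the slice lowering above, a quick sketch of the index bookkeeping, assuming an element shape of (2,): the logical slice parameters are extended with trailing entries that cover the full element dimensions. (The real lowering builds these as IR constants; plain integers stand in here.)

key_shape = (2,)                                     # assumed element shape
start_indices, limit_indices, strides = (1,), (3,), (1,)
start_indices = (*start_indices, *([0] * len(key_shape)))
limit_indices = (*limit_indices, *key_shape)
strides = (*strides, *([1] * len(key_shape)))
assert (start_indices, limit_indices, strides) == ((1, 0), (3, 2), (1, 1))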
@@ -547,7 +545,7 @@ class KeyTyRules:
     key_shape = aval_out.dtype.impl.key_shape
     trailing_zeros = [mlir.ir_constant(np.array(0, dtype))] * len(key_shape)
     start_indices = (*start_indices, *trailing_zeros)
-    physical_aval_out, = KeyTyRules.physical_avals(aval_out)
+    physical_aval_out = core.physical_aval(aval_out)
     return mlir.dynamic_slice(ctx, physical_aval_out, x,
                               start_indices=start_indices)
@@ -558,7 +556,7 @@ class KeyTyRules:
     key_shape = aval_out.dtype.impl.key_shape
     zeros = [mlir.ir_constant(np.array(0, dtype=dtype))] * len(key_shape)
     start_indices = (*start_indices, *zeros)
-    physical_aval_out, = KeyTyRules.physical_avals(aval_out)
+    physical_aval_out = core.physical_aval(aval_out)
     return mlir.dynamic_update_slice(ctx, physical_aval_out, x, update,
                                      start_indices=start_indices)
@@ -568,7 +566,7 @@ class KeyTyRules:
     key_shape = aval_out.dtype.impl.key_shape
     trailing_dims = [aval_out.ndim + i for i in range(len(key_shape))]
     broadcast_dimensions = [*broadcast_dimensions, *trailing_dims]
-    physical_aval_out, = KeyTyRules.physical_avals(aval_out)
+    physical_aval_out = core.physical_aval(aval_out)
     return mlir.broadcast_in_dim(ctx, x, physical_aval_out, broadcast_dimensions=broadcast_dimensions)
 
   @staticmethod

View File

@@ -990,12 +990,10 @@ def _jax_physical_aval(aval: core.ShapedArray) -> core.ShapedArray:
   physical avals, but we don't support those here. Instead we assert
   there is only one and return it.
   """
-  if dtypes.is_opaque_dtype(aval.dtype):
-    physical_aval, = aval.dtype._rules.physical_avals(aval)
-    assert (len(physical_aval.shape) >= len(aval.shape) and
-            physical_aval.shape[:len(aval.shape)] == aval.shape), (physical_aval, aval)
-    return physical_aval
-  return aval
+  physical_aval = core.physical_aval(aval)
+  assert (len(physical_aval.shape) >= len(aval.shape) and
+          physical_aval.shape[:len(aval.shape)] == aval.shape), (physical_aval, aval)
+  return physical_aval
 
 def _jax_physical_dtype(dtype):
   # assuming () is a fine stand-in shape
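
The assertion above encodes the invariant that a physical aval only appends trailing element dimensions to the logical shape. A tiny sketch of that check, with an assumed (2,)-shaped element:

logical_shape = (3, 4)
element_shape = (2,)                                 # assumed element shape
physical_shape = (*logical_shape, *element_shape)
assert len(physical_shape) >= len(logical_shape)
assert physical_shape[:len(logical_shape)] == logical_shape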

View File

@@ -2838,8 +2838,8 @@ class FooTyRules:
   # handlers
 
   @staticmethod
-  def physical_avals(aval):
-    return [core.ShapedArray((*aval.shape, 2), jnp.dtype('uint32'))]
+  def physical_element_aval(dtype) -> core.ShapedArray:
+    return core.ShapedArray((2,), jnp.dtype('uint32'))
 
   @staticmethod
   def physical_op_sharding(aval, op_sharding_proto):