mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
remove physical_avals
rule in favor of physical_element_aval
This commit is contained in:
parent
f3cecd07c7
commit
180e26dafb
@ -30,7 +30,7 @@ import types
|
||||
from typing import (Any, Callable, ClassVar, DefaultDict, Dict, FrozenSet,
|
||||
Generator, Generic, Hashable, Iterable, Iterator, List,
|
||||
NamedTuple, Optional, Sequence, Set, Tuple, Type, TypeVar,
|
||||
Union, cast)
|
||||
Union, cast, overload)
|
||||
import warnings
|
||||
from weakref import ref
|
||||
|
||||
@ -1379,6 +1379,25 @@ def concrete_or_error(force: Any, val: Any, context=""):
|
||||
def has_opaque_dtype(x: Any) -> bool:
|
||||
return dtypes.is_opaque_dtype(get_aval(x).dtype)
|
||||
|
||||
@overload
|
||||
def physical_aval(aval: ShapedArray) -> ShapedArray: ...
|
||||
@overload
|
||||
def physical_aval(aval: DShapedArray) -> DShapedArray: ...
|
||||
@overload # TODO(frostig): remove this case
|
||||
def physical_aval(aval: AbstractValue) -> AbstractValue: ...
|
||||
|
||||
def physical_aval(aval):
|
||||
aval_dtype = getattr(aval, 'dtype', None)
|
||||
if aval_dtype and dtypes.is_opaque_dtype(aval_dtype):
|
||||
ctor = type(aval)
|
||||
aval_shape = getattr(aval, 'shape', None)
|
||||
assert aval_shape is not None, (ctor, aval)
|
||||
elt_aval = aval_dtype._rules.physical_element_aval(aval_dtype)
|
||||
assert type(elt_aval) is ShapedArray
|
||||
return ctor((*aval_shape, *elt_aval.shape), elt_aval.dtype) # pytype: disable=wrong-arg-count
|
||||
else:
|
||||
return aval
|
||||
|
||||
def _short_dtype_name(dtype) -> str:
|
||||
if type(dtype) in dtypes.opaque_dtypes:
|
||||
return str(dtype)
|
||||
|
@ -143,9 +143,7 @@ def dtype_to_ir_type(dtype: Union[np.dtype, np.generic]) -> ir.Type:
|
||||
|
||||
def _array_ir_types(aval: Union[core.ShapedArray, core.DShapedArray]
|
||||
) -> Sequence[ir.Type]:
|
||||
if dtypes.is_opaque_dtype(aval.dtype):
|
||||
phys_avals = aval.dtype._rules.physical_avals(aval)
|
||||
return tuple(itertools.chain(*map(_array_ir_types, phys_avals)))
|
||||
aval = core.physical_aval(aval) # type: ignore
|
||||
if not core.is_constant_shape(aval.shape):
|
||||
return _dynamic_array_ir_types(aval) # type: ignore
|
||||
return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),)
|
||||
@ -1243,10 +1241,7 @@ def multi_broadcast_in_dim(ctx: LoweringRuleContext,
|
||||
return out
|
||||
|
||||
def reshape(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue) -> ir.Value:
|
||||
if dtypes.is_opaque_dtype(aval_out.dtype): # type: ignore
|
||||
# TODO(frostig,mattjj,necula): asserts a single physical aval, and a
|
||||
# particular reshape rule (reshape to the output physical aval's shape)
|
||||
aval_out, = aval_out.dtype._rules.physical_avals(aval_out) # type: ignore
|
||||
aval_out = core.physical_aval(aval_out)
|
||||
if not core.is_constant_shape(aval_out.shape): # type: ignore
|
||||
shape = eval_dynamic_shape(ctx, aval_out.shape) # type: ignore
|
||||
return hlo.DynamicReshapeOp(
|
||||
@ -1704,10 +1699,7 @@ def _layout_to_mlir_layout(minor_to_major: Optional[Sequence[int]]):
|
||||
return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get())
|
||||
|
||||
def _aval_to_default_layouts(aval):
|
||||
if dtypes.is_opaque_dtype(aval.dtype):
|
||||
avals = aval.dtype._rules.physical_avals(aval)
|
||||
else:
|
||||
avals = [aval]
|
||||
avals = [core.physical_aval(aval)]
|
||||
# Row major order is default for `NumPy`.
|
||||
return [list(range(aval.ndim - 1, -1, -1)) for aval in avals]
|
||||
|
||||
|
@ -55,14 +55,9 @@ def identity(x): return x
|
||||
_scalar_types = dtypes.python_scalar_dtypes.keys()
|
||||
|
||||
def _make_array_shape(aval: ShapedArray) -> Sequence[xc.Shape]:
|
||||
def dt(aval):
|
||||
return np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype
|
||||
|
||||
if dtypes.is_opaque_dtype(aval.dtype):
|
||||
avals = aval.dtype._rules.physical_avals(aval)
|
||||
else:
|
||||
avals = [aval]
|
||||
return tuple(xc.Shape.array_shape(dt(a), a.shape) for a in avals)
|
||||
aval = core.physical_aval(aval)
|
||||
dtype = np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype
|
||||
return (xc.Shape.array_shape(dtype, aval.shape),)
|
||||
|
||||
def get_canonical_source_file(frame: source_info_util.Frame):
|
||||
source_file = frame.file_name
|
||||
|
@ -1647,10 +1647,10 @@ def _pred_bcast_select_hlo(ctx,
|
||||
assert x.type == y.type, (x.type, y.type)
|
||||
assert (pred_aval.shape == x_y_aval.shape[:len(pred_aval.shape)]), (
|
||||
pred_aval.shape, x_y_aval)
|
||||
if dtypes.is_opaque_dtype(x_y_aval.dtype):
|
||||
x_y_aval, = x_y_aval.dtype._rules.physical_avals(x_y_aval)
|
||||
bcast_pred = mlir.broadcast_in_dim(ctx, pred, core.DShapedArray(x_y_aval.shape, np.dtype(np.bool_)),
|
||||
broadcast_dimensions=list(range(len(pred_aval.shape))))
|
||||
x_y_aval = core.physical_aval(x_y_aval)
|
||||
bcast_pred = mlir.broadcast_in_dim(
|
||||
ctx, pred, core.DShapedArray(x_y_aval.shape, np.dtype(np.bool_)),
|
||||
broadcast_dimensions=list(range(len(pred_aval.shape))))
|
||||
return hlo.SelectOp(bcast_pred, x, y).results
|
||||
|
||||
### fori_loop
|
||||
|
@ -4728,9 +4728,9 @@ mlir.register_lowering(empty_p, _empty_lower)
|
||||
|
||||
class BIntRules:
|
||||
@staticmethod
|
||||
def physical_avals(aval) -> Sequence[core.AbstractValue]:
|
||||
dtype = dtypes._scalar_type_to_dtype(int)
|
||||
return [core.ShapedArray(aval.shape, dtype)]
|
||||
def physical_element_aval(dtype) -> core.ShapedArray:
|
||||
int_dtype = dtypes._scalar_type_to_dtype(int)
|
||||
return core.ShapedArray((), int_dtype)
|
||||
|
||||
@staticmethod
|
||||
def result_handler(sticky_device, aval):
|
||||
@ -4742,7 +4742,7 @@ class BIntRules:
|
||||
@staticmethod
|
||||
def global_sharded_result_handler(aval, out_sharding, committed,
|
||||
is_out_sharding_from_xla):
|
||||
phys_aval, = BIntRules.physical_avals(aval)
|
||||
phys_aval = core.physical_aval(aval)
|
||||
phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
|
||||
|
||||
if not dispatch.is_single_device_sharding(out_sharding):
|
||||
|
@ -372,9 +372,9 @@ def seed_with_impl(impl: PRNGImpl, seed: Union[int, Array]) -> PRNGKeyArrayImpl:
|
||||
def keys_shaped_array(impl, shape):
|
||||
return core.ShapedArray(shape, KeyTy(impl))
|
||||
|
||||
# TODO(frostig): remove in favor of physical_aval call
|
||||
def keys_aval_to_base_arr_aval(keys_aval):
|
||||
shape = (*keys_aval.shape, *keys_aval.dtype.impl.key_shape)
|
||||
return core.ShapedArray(shape, np.dtype('uint32'))
|
||||
return core.physical_aval(keys_aval)
|
||||
|
||||
def base_arr_shape_to_keys_shape(impl, base_arr_shape):
|
||||
base_ndim = len(impl.key_shape)
|
||||
@ -419,10 +419,8 @@ class KeyTyRules:
|
||||
return random_wrap(key_data, impl=dtype.impl)
|
||||
|
||||
@staticmethod
|
||||
def physical_avals(aval) -> Sequence[core.AbstractValue]: # TODO(frostig): rename to `grounded_avals`
|
||||
# TODO(frostig): dedup with `keys_aval_to_base_arr_aval``
|
||||
return [core.ShapedArray((*aval.shape, *aval.dtype.impl.key_shape), # type: ignore
|
||||
jnp.dtype('uint32'))]
|
||||
def physical_element_aval(dtype) -> core.ShapedArray:
|
||||
return core.ShapedArray(dtype.impl.key_shape, jnp.dtype('uint32'))
|
||||
|
||||
@staticmethod
|
||||
def physical_const(val) -> Array:
|
||||
@ -472,7 +470,7 @@ class KeyTyRules:
|
||||
|
||||
@staticmethod
|
||||
def local_sharded_result_handler(aval, sharding, indices):
|
||||
phys_aval, = KeyTyRules.physical_avals(aval)
|
||||
phys_aval = core.physical_aval(aval)
|
||||
key_shape = aval.dtype.impl.key_shape
|
||||
phys_handler_maker = pxla.local_result_handlers[core.ShapedArray]
|
||||
|
||||
@ -499,7 +497,7 @@ class KeyTyRules:
|
||||
@staticmethod
|
||||
def global_sharded_result_handler(aval, out_sharding, committed,
|
||||
is_out_sharding_from_xla):
|
||||
phys_aval, = KeyTyRules.physical_avals(aval)
|
||||
phys_aval = core.physical_aval(aval)
|
||||
phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
|
||||
|
||||
phys_sharding = make_key_array_phys_sharding(
|
||||
@ -512,7 +510,7 @@ class KeyTyRules:
|
||||
|
||||
@staticmethod
|
||||
def make_sharded_array(aval, sharding, arrays, committed):
|
||||
phys_aval, = KeyTyRules.physical_avals(aval)
|
||||
phys_aval = core.physical_aval(aval)
|
||||
phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
|
||||
phys_arrays = [random_unwrap(arr) for arr in arrays]
|
||||
|
||||
@ -536,7 +534,7 @@ class KeyTyRules:
|
||||
start_indices = (*start_indices, *trailing_zeros)
|
||||
limit_indices = (*limit_indices, *key_shape)
|
||||
strides = (*strides, *trailing_ones)
|
||||
physical_aval_out, = KeyTyRules.physical_avals(aval_out)
|
||||
physical_aval_out = core.physical_aval(aval_out)
|
||||
return mlir.slice_op(ctx, x, physical_aval_out,
|
||||
start_indices=start_indices, limit_indices=limit_indices, strides=strides)
|
||||
|
||||
@ -547,7 +545,7 @@ class KeyTyRules:
|
||||
key_shape = aval_out.dtype.impl.key_shape
|
||||
trailing_zeros = [mlir.ir_constant(np.array(0, dtype))] * len(key_shape)
|
||||
start_indices = (*start_indices, *trailing_zeros)
|
||||
physical_aval_out, = KeyTyRules.physical_avals(aval_out)
|
||||
physical_aval_out = core.physical_aval(aval_out)
|
||||
return mlir.dynamic_slice(ctx, physical_aval_out, x,
|
||||
start_indices=start_indices)
|
||||
|
||||
@ -558,7 +556,7 @@ class KeyTyRules:
|
||||
key_shape = aval_out.dtype.impl.key_shape
|
||||
zeros = [mlir.ir_constant(np.array(0, dtype=dtype))] * len(key_shape)
|
||||
start_indices = (*start_indices, *zeros)
|
||||
physical_aval_out, = KeyTyRules.physical_avals(aval_out)
|
||||
physical_aval_out = core.physical_aval(aval_out)
|
||||
return mlir.dynamic_update_slice(ctx, physical_aval_out, x, update,
|
||||
start_indices=start_indices)
|
||||
|
||||
@ -568,7 +566,7 @@ class KeyTyRules:
|
||||
key_shape = aval_out.dtype.impl.key_shape
|
||||
trailing_dims = [aval_out.ndim + i for i in range(len(key_shape))]
|
||||
broadcast_dimensions = [*broadcast_dimensions, *trailing_dims]
|
||||
physical_aval_out, = KeyTyRules.physical_avals(aval_out)
|
||||
physical_aval_out = core.physical_aval(aval_out)
|
||||
return mlir.broadcast_in_dim(ctx, x, physical_aval_out, broadcast_dimensions=broadcast_dimensions)
|
||||
|
||||
@staticmethod
|
||||
|
@ -990,12 +990,10 @@ def _jax_physical_aval(aval: core.ShapedArray) -> core.ShapedArray:
|
||||
physical avals, but we don't support those here. Instead we assert
|
||||
there is only one and return it.
|
||||
"""
|
||||
if dtypes.is_opaque_dtype(aval.dtype):
|
||||
physical_aval, = aval.dtype._rules.physical_avals(aval)
|
||||
assert (len(physical_aval.shape) >= len(aval.shape) and
|
||||
physical_aval.shape[:len(aval.shape)] == aval.shape), (physical_aval, aval)
|
||||
return physical_aval
|
||||
return aval
|
||||
physical_aval = core.physical_aval(aval)
|
||||
assert (len(physical_aval.shape) >= len(aval.shape) and
|
||||
physical_aval.shape[:len(aval.shape)] == aval.shape), (physical_aval, aval)
|
||||
return physical_aval
|
||||
|
||||
def _jax_physical_dtype(dtype):
|
||||
# assuming () is a fine stand-in shape
|
||||
|
@ -2838,8 +2838,8 @@ class FooTyRules:
|
||||
# handlers
|
||||
|
||||
@staticmethod
|
||||
def physical_avals(aval):
|
||||
return [core.ShapedArray((*aval.shape, 2), jnp.dtype('uint32'))]
|
||||
def physical_element_aval(dtype) -> core.ShapedArray:
|
||||
return core.ShapedArray((2,), jnp.dtype('uint32'))
|
||||
|
||||
@staticmethod
|
||||
def physical_op_sharding(aval, op_sharding_proto):
|
||||
|
Loading…
x
Reference in New Issue
Block a user