Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
Merge pull request #11387 from mattjj:djax-bint
PiperOrigin-RevId: 459430960
This commit is contained in: 5270cb1c1f
@@ -299,6 +299,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
  if config.jax_dynamic_shapes:
    keep_unused = True
    has_outfeed = False
    donated_invars = [False] * len(fun.in_type)
  else:
    has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
    jaxpr = apply_outfeed_rewriter(jaxpr)
@@ -318,8 +319,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
  device = _xla_callable_device(nreps, backend, device, arg_devices)
  backend = xb.get_device_backend(device) if device else xb.get_backend(backend)

  if (config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr) and
      not _backend_supports_unbounded_dynamic_shapes(backend)):
  if config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr):
    jaxpr, consts = pe.pad_jaxpr(jaxpr, consts)

  map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))
@@ -520,6 +520,7 @@ num_buffers_handlers[core.AbstractToken] = lambda _: 1
num_buffers_handlers[core.ShapedArray] = lambda _: 1
num_buffers_handlers[core.DShapedArray] = lambda _: 1
num_buffers_handlers[core.ConcreteArray] = lambda _: 1
num_buffers_handlers[core.AbstractBInt] = lambda _: 1


def _input_handler(backend: Backend,
@@ -652,17 +653,22 @@ def dynamic_array_result_handler(sticky_device: Optional[Device],
  return partial(_dynamic_array_result_handler, sticky_device, aval)

def _dynamic_array_result_handler(sticky_device, aval, env, buf):
  if all(type(d) is int for d in aval.shape):
    del env
    return _maybe_create_array_from_da(buf, aval, sticky_device)
  else:
    assert env is not None
    in_env, out_env = env
    shape = [in_env[d.val] if type(d) is core.InDBIdx else
             out_env[d.val] if type(d) is core.OutDBIdx else d
             for d in aval.shape]
    in_env, out_env = env or (None, None)
    shape = [in_env[d.val] if type(d) is core.InDBIdx else
             out_env[d.val] if type(d) is core.OutDBIdx else d
             for d in aval.shape]
    if all(type(d) is int for d in shape):
      aval = core.ShapedArray(tuple(shape), aval.dtype)
      return _maybe_create_array_from_da(buf, aval, sticky_device)
    elif any(type(d) is core.BInt for d in shape):
      padded_shape = [d.bound if type(d) is core.BInt else d for d in shape]
      buf_aval = core.ShapedArray(tuple(padded_shape), aval.dtype, aval.weak_type)
      data = _maybe_create_array_from_da(buf, buf_aval, sticky_device)
      return core.PaddedArray(aval.update(shape=tuple(shape)), data)
    else:
      aval = core.ShapedArray(tuple(shape), aval.dtype)
      return _maybe_create_array_from_da(buf, aval, sticky_device)


result_handlers: Dict[
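The handler above resolves symbolic output sizes by looking up InDBIdx/OutDBIdx references in the input/output environments, then picks between a plain array and a PaddedArray (defined in jax/core.py later in this diff). A minimal standalone sketch of the resolution step, using simplified stand-in classes rather than the real jax.core types:

# Standalone sketch of the shape-resolution step; InDBIdx/OutDBIdx/BInt here
# are simplified stand-ins, not the real jax.core classes.
from dataclasses import dataclass

@dataclass
class InDBIdx:   # refers to the runtime size carried by the idx-th call input
    val: int

@dataclass
class OutDBIdx:  # refers to the runtime size carried by the idx-th call output
    val: int

@dataclass
class BInt:      # a bounded integer: runtime value plus static bound
    val: int
    bound: int

def resolve_shape(shape, in_env, out_env):
    # Replace index placeholders with the runtime sizes they point to.
    return [in_env[d.val] if isinstance(d, InDBIdx) else
            out_env[d.val] if isinstance(d, OutDBIdx) else d
            for d in shape]

in_env = [BInt(3, 5)]            # input 0 carried the bint value 3 with bound 5
out_env = []
print(resolve_shape((InDBIdx(0), 4), in_env, out_env))  # [BInt(val=3, bound=5), 4]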
@@ -672,6 +678,8 @@ result_handlers[core.AbstractToken] = lambda _, __: lambda _, __: core.token
result_handlers[core.ShapedArray] = array_result_handler
result_handlers[core.DShapedArray] = dynamic_array_result_handler
result_handlers[core.ConcreteArray] = array_result_handler
result_handlers[core.AbstractBInt] = \
    lambda _, a: lambda _, b: core.BInt(int(b), a.bound)


def needs_check_special():
@@ -1014,6 +1022,7 @@ device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]], Tuple[Any]]] =
device_put_handlers.update((t, _device_put_array) for t in array_types)
device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)
device_put_handlers[core.Token] = _device_put_token
device_put_handlers[core.BInt] = lambda x, d: _device_put_scalar(x.val, d)


def _device_put_device_array(x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray], device: Optional[Device]):
@@ -1021,6 +1030,7 @@ def _device_put_device_array(x: Union[device_array.DeviceArrayProtocol, device_a
  return (x.device_buffer,)
for t in device_array.device_array_types:
  device_put_handlers[t] = _device_put_device_array
device_put_handlers[core.PaddedArray] = lambda x, d: device_put(x._data, d)

def _copy_device_array_to_device(
    x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray],
@@ -104,8 +104,8 @@ class IreeBuffer(xla_client.DeviceArrayBase):
    return self  # no async

  # overrides repr on base class which expects _value and aval attributes
  def __repr__(self):
    return f'IreeBuffer({self.to_py()})'
  def __repr__(self): return f'IreeBuffer({self.to_py()})'
  _value = property(to_py)

class IreeExecutable:
@@ -1440,6 +1440,7 @@ def unop(result_dtype, accepted_dtypes, name):
                            weak_type_rule=weak_type_rule)
  batching.defvectorized(prim)
  masking.defvectorized(prim)
  pe.padding_rules[prim] = lambda _, __, x, **kw: [prim.bind(x, **kw)]
  return prim
standard_unop = partial(unop, _identity)
_attrgetter = lambda name: lambda x, **kwargs: getattr(x, name)
@@ -1515,6 +1516,7 @@ def naryop(result_dtype, accepted_dtypes, name):
                            weak_type_rule=weak_type_rule)
  batching.defbroadcasting(prim)
  masking.defnaryop(prim)
  pe.padding_rules[prim] = lambda _, __, *xs, **kw: [prim.bind(*xs, **kw)]
  return prim
standard_naryop = partial(naryop, _input_dtype)
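A padding rule rewrites a primitive application so it can run on bound-padded operands; for an elementwise primitive the rule can simply apply the op to the padded buffers, because padding lanes are masked off where it matters (see _replace_masked_values further down). A standalone numpy sketch of that idea, not JAX code:

# Elementwise ops are safe on padded data: the logical prefix stays correct.
import numpy as np

bound, logical = 5, 3
x = np.zeros(bound); x[:logical] = [1., 2., 3.]     # padded operand, bound 5
y = np.zeros(bound); y[:logical] = [10., 20., 30.]

z = x + y                                 # op applied to the padded buffers
mask = np.arange(bound) < logical
print(z[:logical])                        # [11. 22. 33.] -- logical part correct
print(np.where(mask, z, 0.))              # padding lanes re-masked when needed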
@@ -2080,7 +2082,6 @@ add_p: Primitive = standard_naryop([_num, _num], 'add')
ad.primitive_jvps[add_p] = _add_jvp
ad.primitive_transposes[add_p] = _add_transpose
mlir.register_lowering(add_p, partial(_nary_lower_mhlo, mhlo.AddOp))
pe.padding_rules[add_p] = lambda _, __, x, y: [add(x, y)]

def _sub_jvp(primals, tangents):
  x, y = primals
@@ -2110,7 +2111,6 @@ sub_p = standard_naryop([_num, _num], 'sub')
ad.primitive_jvps[sub_p] = _sub_jvp
ad.primitive_transposes[sub_p] = _sub_transpose
mlir.register_lowering(sub_p, partial(_nary_lower_mhlo, mhlo.SubOp))
pe.padding_rules[sub_p] = lambda _, __, x, y: [sub(x, y)]


def _mul_transpose(ct, x, y):
@@ -2137,7 +2137,6 @@ ad.defjvp(mul_p,
          lambda ydot, x, y: mul(x, ydot))
ad.primitive_transposes[mul_p] = _mul_transpose
mlir.register_lowering(mul_p, partial(_nary_lower_mhlo, mhlo.MulOp))
pe.padding_rules[mul_p] = lambda _, __, x, y: [mul(x, y)]

def _div_transpose_rule(cotangent, x, y):
  assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y)
@@ -2174,7 +2173,6 @@ ad.defjvp2(max_p,
           lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
           lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
mlir.register_lowering(max_p, partial(_nary_lower_mhlo, mlir.max_mhlo))
pe.padding_rules[max_p] = lambda _, __, x, y: [max(x, y)]

min_p: core.Primitive = standard_naryop([_any, _any], 'min')
ad.defjvp2(min_p,
@@ -2297,9 +2295,13 @@ def _convert_elt_type_pp_rule(eqn, context, settings):
  printed_params = {}
  if eqn.params['weak_type']:
    printed_params['weak_type'] = True
  return [pp.text(eqn.primitive.name),
          core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
          pp.text(" ") + core.pp_vars(eqn.invars, context)]
  lhs = core.pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
  rhs = [pp.text(eqn.primitive.name),
         core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
         pp.text(" ") + core.pp_vars(eqn.invars, context)]
  annotation = (source_info_util.summarize(eqn.source_info)
                if settings.source_info else None)
  return [lhs, pp.text(" = ", annotation=annotation), *rhs]


convert_element_type_p = Primitive('convert_element_type')
@@ -2756,7 +2758,7 @@ def _broadcast_in_dim_padding_rule(in_avals, out_avals, x, *dyn_shape,
      assert isinstance(d, core.Tracer)
      new_shape.append(None)
      new_dyn_shape.append(d)
  return [broadcast_in_dim_p.bind(x, *new_dyn_shape, shape=new_shape,
  return [broadcast_in_dim_p.bind(x, *new_dyn_shape, shape=tuple(new_shape),
                                  broadcast_dimensions=broadcast_dimensions)]

def _broadcast_in_dim_jvp_rule(primals, tangents, *, shape, broadcast_dimensions):
@@ -2820,8 +2822,18 @@ def _broadcast_in_dim_pp_rule(eqn, context, settings):
                if settings.source_info else None)
  return [lhs, pp.text(" = ", annotation=annotation), *rhs]

def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions):
  if not any(isinstance(d, core.BInt) for d in shape):
    shape = _broadcast_in_dim_shape_rule(  # error checking
        x, shape=shape, broadcast_dimensions=broadcast_dimensions)
    return core.ShapedArray(shape, x.dtype, x.weak_type, x.named_shape)
  # If any BInts in shape, produce a DShapedArray (even if x is a ShapedArray)
  # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
  return core.DShapedArray(shape, x.dtype, x.weak_type)

broadcast_in_dim_p = standard_primitive(
    _broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
broadcast_in_dim_p.def_abstract_eval(_broadcast_in_dim_abstract_eval)
ad.primitive_jvps[broadcast_in_dim_p] = _broadcast_in_dim_jvp_rule
ad.primitive_transposes[broadcast_in_dim_p] = _broadcast_in_dim_transpose_rule
batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
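The abstract-eval rule above keys off whether any axis size is a BInt: an all-int shape stays on the ShapedArray path, while a bounded size switches the result to a DShapedArray. A standalone sketch of just that dispatch, with a stand-in BInt class:

# Stand-in dispatch: any bounded axis size makes the result "dynamic".
from dataclasses import dataclass

@dataclass
class BInt:
    val: int
    bound: int

def aval_kind(shape):
    return 'DShapedArray' if any(isinstance(d, BInt) for d in shape) else 'ShapedArray'

print(aval_kind((3, 4)))           # ShapedArray
print(aval_kind((BInt(3, 5), 4)))  # DShapedArray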
@@ -3605,8 +3617,8 @@ def _reduce_sum_padding_rule(in_avals, out_avals, operand, *, axes):

def _replace_masked_values(x, val, padded_axes):
  if not padded_axes: return x
  masks = [broadcasted_iota(np.dtype('int32'), x.shape, i) < d
           for i, d in padded_axes]
  dtype = dtypes._scalar_type_to_dtype(int)
  masks = [broadcasted_iota(dtype, x.shape, i) < d for i, d in padded_axes]
  return select(_reduce(operator.and_, masks), x, full_like(x, val))
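A standalone numpy sketch of the masking idea in _replace_masked_values: an iota along each padded axis, compared against the logical size, keeps the real elements and overwrites padding lanes with a neutral value (0 for a sum). This is a simplified model, not the lax implementation:

import numpy as np
from functools import reduce

def replace_masked_values(x, val, padded_axes):
    # padded_axes pairs an axis index with that axis's logical (unpadded) size.
    if not padded_axes:
        return x
    masks = [np.expand_dims(np.arange(x.shape[i]),
                            tuple(a for a in range(x.ndim) if a != i)) < d
             for i, d in padded_axes]
    return np.where(reduce(np.logical_and, masks), x, val)

x = np.arange(10.).reshape(2, 5)                           # padded buffer: bound 5 on axis 1
print(replace_masked_values(x, 0., [(1, 3)]).sum(axis=1))  # [ 3. 18.]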
@@ -4388,7 +4400,10 @@ def _iota_abstract_eval(*, dtype, shape, dimension):
  if not 0 <= dimension < len(shape):
    raise ValueError("iota dimension must be between 0 and len(shape), got "
                     f"dimension={dimension} for shape {shape}")
  return ShapedArray(shape, dtype)
  if not any(isinstance(d, core.BInt) for d in shape):
    return ShapedArray(shape, dtype)
  # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
  return core.DShapedArray(shape, dtype, False)

iota_p = Primitive('iota')
iota_p.def_impl(partial(xla.apply_primitive, iota_p))
@@ -4429,12 +4444,46 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension):
      mlir.i64_attr(dimension)).results
mlir.register_lowering(iota_p, _iota_lower)

def _iota_pp_rule(eqn, context, settings):
  printed_params = {}
  if len(eqn.params['shape']) > 1:
    printed_params['dimension'] = eqn.params['dimension']
  lhs = core.pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
  rhs = [pp.text(eqn.primitive.name),
         core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
         pp.text(" ") + core.pp_vars(eqn.invars, context)]
  annotation = (source_info_util.summarize(eqn.source_info)
                if settings.source_info else None)
  return [lhs, pp.text(" = ", annotation=annotation), *rhs]
# core.pp_eqn_rules[iota_p] = _iota_pp_rule

def _iota_padding_rule(in_avals, out_avals, *dyn_shape, dtype, shape, dimension):
  out_aval, = out_avals
  new_shape = []
  new_dyn_shape = []
  for d in out_aval.shape:
    if type(d) is pe.BoundedAxisSize:
      new_shape.append(d.bound)
    elif type(d) is int:
      new_shape.append(d)
    else:
      assert isinstance(d, core.Tracer)
      new_shape.append(None)
      new_dyn_shape.append(d)
  return [iota_p.bind(*new_dyn_shape, shape=tuple(new_shape),
                      dtype=dtype, dimension=dimension)]
pe.padding_rules[iota_p] = _iota_padding_rule


def make_bint(i, bd: int):
  return bint_p.bind(i, bd=bd)

bint_p = core.Primitive('bint')

@bint_p.def_impl
def _bint_impl(i, *, bd):
  return core.BInt(i, bd)

@bint_p.def_abstract_eval
def bint_abstract_eval(_, *, bd: int):
  return core.AbstractBInt(bound=bd)
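A condensed usage sketch of the new primitive, based on the DynamicShapeTest cases added at the end of this diff; this is experimental API introduced by this change and assumes the jax_dynamic_shapes flag is enabled:

# Experimental: requires a build containing this change plus jax_dynamic_shapes.
import jax, jax.numpy as jnp
from jax import lax

jax.config.update('jax_dynamic_shapes', True)

d = lax.make_bint(3, 5)             # value 3, carried with static bound 5
print(d)                            # 3{≤5}

x = jnp.arange(d, dtype='int32')    # logically 3 elements, stored padded to 5
print(x._data.shape)                # (5,) -- the padded backing buffer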
@@ -4570,7 +4619,7 @@ def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False):
  if not len(obj):  # pylint: disable=g-explicit-length-test
    return
  if (config.jax_dynamic_shapes and isinstance(obj, (tuple, list)) and
      any(isinstance(d, core.Tracer) for d in obj)):
      any(isinstance(d, (core.Tracer, core.BInt)) for d in obj)):
    return  # TODO(mattjj): handle more checks in the dynamic shape case
  obj_arr = np.array(obj)
  if obj_arr.ndim != 1:
@@ -903,7 +903,7 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes):
dynamic_slice_p = standard_primitive(
    _dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice',
    weak_type_rule=_argnum_weak_type(0))
ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp  # TODO
ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp
ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule
batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule
@@ -2101,8 +2101,12 @@ def arange(start: core.DimSize, stop: Optional[core.DimSize]=None,
  dtype = _jnp_dtype(dtype)
  if stop is None and step is None:
    if (jax.config.jax_dynamic_shapes and
        not isinstance(core.get_aval(start), core.AbstractBInt) and
        not isinstance(core.get_aval(start), core.ConcreteArray)):
      start = ceil(start).astype(int)  # note using jnp here
    elif (isinstance(start, core.BInt) or isinstance(start, core.Tracer) and
          isinstance(core.get_aval(start), core.AbstractBInt)):
      pass
    else:
      start = require(start, msg("stop"))
      start = np.ceil(start).astype(int)
jax/core.py
@@ -1104,22 +1104,9 @@ def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType:
         for v in jaxpr.invars]
  return tuple(out)


class Bot(AbstractValue): pass

bot = Bot()

class AbstractBInt(AbstractValue):
  __slots__ = ['bound']
  bound: int
  def __init__(self, bound):
    self.bound = bound
  def str_short(self, short_dtypes=False) -> str:
    return f'bint{{≤{self.bound}}}[]'
  def __eq__(self, other):
    return type(other) is AbstractBInt and self.bound == other.bound
  def __hash__(self) -> int:
    return hash((type(self), self.bound))

def lattice_join(x: Optional[AbstractValue],
                 y: Optional[AbstractValue]) -> AbstractValue:
@@ -1171,9 +1158,6 @@ def get_aval(x):
    return concrete_aval(x)


pytype_aval_mappings: Dict[type, Callable[[Any], AbstractValue]] = {}


def concretization_function_error(fun, suggest_astype=False):
  fname = getattr(fun, "__name__", fun)
  fname_context = f"The problem arose with the `{fname}` function. "
@@ -1204,7 +1188,7 @@ def _short_dtype_name(dtype):

class UnshapedArray(AbstractValue):
  __slots__ = ['dtype', 'weak_type']
  array_abstraction_level = 3
  array_abstraction_level = 4

  def __init__(self, dtype, weak_type=False):
    self.dtype = np.dtype(dtype)
@@ -1269,77 +1253,9 @@ class UnshapedArray(AbstractValue):
    raise TypeError(msg)


# We have a convention of reusing AbsractValues as types, in particular reusing
# ShapedArrays as types, even though we could make a distinction and use
# abstract values during tracing only. This reuse becomes a bit more extreme
# with DShapedArrays. A DShapedArray's shape attribute is a tuple which can
# contain several different types: ints, other AbstractValues (specifically at
# the input and output to pe.trace_to_jaxpr_dynamic), Tracers (while tracing),
# or Vars (when used as jaxpr type annotations). We could reduce this
# polymorphism if it seems cleaner, though it's kind of convenient!
AxisSizeForTracing = Union[int, Tracer]
AxisSizeForJaxprType = Union[int, Var]
AxisSizeForJaxprTracingSpec = Union[int, AbstractValue]
AxisSize = Union[AxisSizeForTracing, AxisSizeForJaxprType,
                 AxisSizeForJaxprTracingSpec]

class DShapedArray(UnshapedArray):
  __slots__ = ['shape']
  shape: Tuple[AxisSize, ...]  # noqa: F821
  array_abstraction_level: int = 2

  def __init__(self, shape, dtype, weak_type):
    self.shape = shape
    self.dtype = dtype
    self.weak_type = weak_type

  ndim = property(lambda self: len(self.shape))
  size = property(lambda self: prod(self.shape))

  def str_short(self, short_dtypes=False) -> str:
    del short_dtypes  # ignored
    shape = f'{",".join(str(d) for d in self.shape)}' if self.shape else ''
    dtype = _short_dtype_name(self.dtype)
    return f'{dtype}[{shape}]'
  __str__ = __repr__ = str_short

  def update(self, shape=None, dtype=None, weak_type=None):
    if shape is None:
      shape = self.shape
    if dtype is None:
      dtype = self.dtype
    if weak_type is None:
      weak_type = self.weak_type
    return DShapedArray(shape, dtype, weak_type)

  def __eq__(self, other):
    return (type(self) is type(other)
            and self.dtype == other.dtype and self.shape == other.shape
            and self.weak_type == other.weak_type)

  def __hash__(self):
    return hash((self.shape, self.dtype, self.weak_type))

  def join(self, other):
    if (symbolic_equal_shape(self.shape, other.shape) and
        self.dtype == other.dtype):
      weak_type = self.weak_type and other.weak_type
      return self.update(weak_type=weak_type)
    elif self.dtype == other.dtype:
      return UnshapedArray(self.dtype)
    else:
      raise TypeError(self, other)

  def at_least_vspace(self):
    return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
                        self.weak_type)

del AxisSize, AxisSizeForTracing, AxisSizeForJaxprType, \
    AxisSizeForJaxprTracingSpec

class ShapedArray(UnshapedArray):
  __slots__ = ['shape', 'named_shape']
  array_abstraction_level = 1
  array_abstraction_level = 2

  def __init__(self, shape, dtype, weak_type=False, named_shape=None):
    self.shape = canonicalize_shape(shape)
@@ -1415,6 +1331,7 @@ class ShapedArray(UnshapedArray):
  def _forward_to_value(self, fun, ignored_tracer, *args):
    return fun(self.val, *args)


class ConcreteArray(ShapedArray):
  __slots__ = ['val']
  array_abstraction_level = 0
@@ -1477,6 +1394,135 @@ def primal_dtype_to_tangent_dtype(primal_dtype):
  else:
    return primal_dtype


# Dynamic shape stuff below here! We keep the abstract values distinct just so
# as not to interfere with any static shape machinery.

# We have a convention of reusing AbsractValues as types, even though we could
# make a distinction and use abstract values during tracing only. This reuse
# becomes a bit more extreme with DShapedArrays. A DShapedArray's shape
# attribute is a tuple which can contain several different types: int, BInt,
# Tracer (while tracing), Var (when used as jaxpr type annotations), or
# DBIdx/InDBIdx/OutDBIdx (when used in InputType or OutputType). We could reduce
# this polymorphism if it seems cleaner, though it's kind of convenient!
AxisSize = Any

class DShapedArray(UnshapedArray):
  __slots__ = ['shape']
  shape: Tuple[AxisSize, ...]  # noqa: F821
  array_abstraction_level: int = 3

  def __init__(self, shape, dtype, weak_type):
    self.shape = shape
    self.dtype = dtype
    self.weak_type = weak_type

  ndim = property(lambda self: len(self.shape))
  size = property(lambda self: prod(self.shape))

  def str_short(self, short_dtypes=False) -> str:
    del short_dtypes  # ignored
    shape = f'{",".join(str(d) for d in self.shape)}' if self.shape else ''
    dtype = _short_dtype_name(self.dtype)
    return f'{dtype}[{shape}]'
  __str__ = __repr__ = str_short

  def update(self, shape=None, dtype=None, weak_type=None):
    if shape is None:
      shape = self.shape
    if dtype is None:
      dtype = self.dtype
    if weak_type is None:
      weak_type = self.weak_type
    return DShapedArray(shape, dtype, weak_type)

  def __eq__(self, other):
    return (type(self) is type(other)
            and self.dtype == other.dtype and self.shape == other.shape
            and self.weak_type == other.weak_type)

  def __hash__(self):
    return hash((self.shape, self.dtype, self.weak_type))

  def join(self, other):
    if (symbolic_equal_shape(self.shape, other.shape) and
        self.dtype == other.dtype):
      weak_type = self.weak_type and other.weak_type
      return self.update(weak_type=weak_type)
    elif self.dtype == other.dtype:
      return UnshapedArray(self.dtype)
    else:
      raise TypeError(self, other)

  def at_least_vspace(self):
    return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
                        self.weak_type)

class DConcreteArray(DShapedArray):
  __slots__ = ['val']
  array_abstraction_level = 1
  def __init__(self, shape, dtype, weak_type, val):
    super().__init__(shape, dtype, weak_type)
    self.val = val


pytype_aval_mappings: Dict[type, Callable[[Any], AbstractValue]] = {}

class AbstractBInt(AbstractValue):
  __slots__ = ['bound']
  bound: int
  def __init__(self, bound):
    self.bound = bound
  def str_short(self, short_dtypes=False) -> str:
    return f'bint{{≤{self.bound}}}[]'
  __repr__ = str_short
  def __eq__(self, other):
    return type(other) is AbstractBInt and self.bound == other.bound
  def __hash__(self) -> int:
    return hash((type(self), self.bound))

class BInt:
  val: Any  # Union[int, Array]
  bound: int
  def __init__(self, val, bound):
    self.val = val
    self.bound = bound
  def __repr__(self) -> str:
    return f'{self.val}{{≤{self.bound}}}'
  def __int__(self) -> int:
    return self.val
  def __eq__(self, other) -> bool:
    return (isinstance(other, BInt) and
            (self.val, self.bound) == (other.val, other.bound))
  def __hash__(self):
    return hash((self.val, self.bound))
pytype_aval_mappings[BInt] = lambda x: AbstractBInt(x.bound)
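One consequence of these definitions: two BInts with different values but the same bound map to equal abstract values, which is what lets jit reuse a compilation across values (see test_bint_compilation_cache at the end of this diff). A standalone sketch with simplified stand-ins, not the jax.core classes:

class AbstractBInt:
    def __init__(self, bound): self.bound = bound
    def __eq__(self, other): return type(other) is AbstractBInt and self.bound == other.bound
    def __hash__(self): return hash((type(self), self.bound))

class BInt:
    def __init__(self, val, bound): self.val, self.bound = val, bound

aval = lambda x: AbstractBInt(x.bound)        # mirrors pytype_aval_mappings[BInt]
print(aval(BInt(3, 5)) == aval(BInt(4, 5)))   # True: same bound, same aval
print(aval(BInt(3, 5)) == aval(BInt(3, 6)))   # False: different bound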

# DShapedArray w/ BInt in shapes => PaddedArray runtime representation
class PaddedArray:
  _aval: DShapedArray
  _data: Any  # standard array type
  def __init__(self, aval, data):
    padded_shape = tuple(d.bound if type(d) is BInt else d for d in aval.shape)
    assert data.shape == padded_shape
    self._aval = aval
    self._data = data
  shape = property(lambda self: self._aval.shape)
  dtype = property(lambda self: self._aval.dtype)
  def __repr__(self) -> str:
    dtypestr = _short_dtype_name(self._aval.dtype)
    shapestr = ','.join(map(str, self.shape))
    slices = tuple(slice(d.val) if type(d) is BInt else slice(None)
                   for d in self.shape)
    data = self._data[slices]
    return f'{dtypestr}[{shapestr}] with value:\n{data}'
pytype_aval_mappings[PaddedArray] = \
    lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type,
                             x._data)

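A standalone numpy sketch of the PaddedArray invariant: the backing buffer is allocated at the bound, the aval keeps the logical bint sizes, and display slices out the logical prefix. Stand-in class names, not jax.core:

import numpy as np

class FakeBInt:
    def __init__(self, val, bound): self.val, self.bound = val, bound

logical_shape = (FakeBInt(3, 5),)                  # logical size 3, bound 5
data = np.arange(5)                                # padded backing buffer
padded_shape = tuple(d.bound for d in logical_shape)
assert data.shape == padded_shape                  # same check as PaddedArray.__init__
view = data[tuple(slice(d.val) for d in logical_shape)]
print(view)                                        # [0 1 2] -- what __repr__ displays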
class AbstractToken(AbstractValue):
  def join(self, other):
    if isinstance(other, AbstractToken):
@@ -1485,7 +1531,6 @@ class AbstractToken(AbstractValue):
    assert False, f"Cannot join {self} with {other}"
  def str_short(self, short_dtypes=False): return 'Tok'
  def at_least_vspace(self): return self

abstract_token: AbstractToken = AbstractToken()

# Concrete token object
@@ -1759,6 +1804,21 @@ def _invalid_shape_error(shape: Shape, context: str=""):
                 "smaller subfunctions.")
  return TypeError(msg)

class BIntDimensionHandler(DimensionHandler):
  def symbolic_equal(self, d1, d2) -> bool:
    return isinstance(d2, BInt) and d1.val == d2.val and d1.bound == d2.bound
  def sum(self, *ds) -> BInt:
    if not all(isinstance(d, BInt) for d in ds):
      raise InconclusiveDimensionOperation
    if len({d.bound for d in ds}) != 1:
      raise InconclusiveDimensionOperation
    return BInt(sum(d.val for d in ds), ds[0].bound)
  def fail(self, *_): raise InconclusiveDimensionOperation
  greater_equal = diff = divide_shape_sizes = stride = dilate = as_value = fail
_SPECIAL_DIMENSION_HANDLERS[BInt] = BIntDimensionHandler()



# ------------------- Named shapes -------------------
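A standalone usage sketch of the handler's sum rule: bints add value-wise only when every operand is a bint sharing the same bound, and anything else raises InconclusiveDimensionOperation. Stand-in classes, not the jax.core ones:

class InconclusiveDimensionOperation(Exception): pass

class BInt:
    def __init__(self, val, bound): self.val, self.bound = val, bound
    def __repr__(self): return f'{self.val}{{<={self.bound}}}'

def bint_sum(*ds):
    if not all(isinstance(d, BInt) for d in ds):
        raise InconclusiveDimensionOperation
    if len({d.bound for d in ds}) != 1:
        raise InconclusiveDimensionOperation
    return BInt(sum(d.val for d in ds), ds[0].bound)

print(bint_sum(BInt(2, 5), BInt(1, 5)))   # 3{<=5}
try:
    bint_sum(BInt(2, 5), BInt(1, 7))      # mixed bounds: inconclusive
except InconclusiveDimensionOperation:
    print('inconclusive')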
@@ -2436,7 +2496,7 @@ def check_type(
  if isinstance(ty, DShapedArray):
    # Check all elements in the shape tuple are well-typed.
    for d in ty.shape:
      if isinstance(d, int):
      if isinstance(d, (int, BInt)):
        continue
      elif isinstance(d, Var):
        if d not in env:
@@ -122,15 +122,19 @@ def dtype_to_ir_type(dtype: Union[np.dtype, np.generic]) -> ir.Type:
        f"No dtype_to_ir_type handler for dtype: {dtype}") from err
  return ir_type_factory()

def _array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
def _array_ir_types(aval: Union[core.ShapedArray, core.DShapedArray]
                    ) -> Sequence[ir.Type]:
  return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),)

def _dynamic_array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
  shape = [d if type(d) is int else -1 for d in aval.shape]
  # in the MHLO builder, -1 indicates a '?' axis size
  shape = [d if type(d) is int else d.bound if type(d) is core.BInt else -1
           for d in aval.shape]
  return (ir.RankedTensorType.get(shape, dtype_to_ir_type(aval.dtype)),)

def _bint_ir_types(aval: core.AbstractBInt) -> Sequence[ir.Type]:
  return (ir.RankedTensorType.get((), dtype_to_ir_type(dtypes.dtype('int32'))),)
  dtype = dtypes._scalar_type_to_dtype(int)
  return (ir.RankedTensorType.get((), dtype_to_ir_type(dtype)),)

ir_type_handlers: Dict[Type[core.AbstractValue],
                       Callable[[Any], Sequence[ir.Type]]] = {}
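The interesting change is the dimension mapping: a plain int passes through, a BInt lowers to its static bound, and anything else becomes a dynamic size (-1, MHLO's '?'). A standalone sketch of just that mapping, with a stand-in BInt:

class BInt:
    def __init__(self, val, bound): self.val, self.bound = val, bound

def ir_dims(shape):
    # ints pass through, bints use their bound, anything else is dynamic (-1)
    return [d if type(d) is int else d.bound if type(d) is BInt else -1
            for d in shape]

print(ir_dims((2, BInt(3, 5), object())))   # [2, 5, -1]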
@@ -2414,7 +2414,7 @@ def pad_jaxpr(jaxpr: Jaxpr, consts: Sequence[Const]

  def substitute(aval: AbstractValue) -> AbstractValue:
    if isinstance(aval, AbstractBInt):
      return ShapedArray((), np.dtype('int32'))
      return ShapedArray((), dtypes._scalar_type_to_dtype(int))
    elif isinstance(aval, DShapedArray):
      shape = [bounds.get(d, idxs.get(d, d)) for d in aval.shape]  # type: ignore
      typ = ShapedArray if all(type(d) is int for d in shape) else DShapedArray
@@ -253,12 +253,15 @@ def _canonicalize_python_scalar_dtype(typ, x):

canonicalize_dtype_handlers: Dict[Any, Callable] = {}
for t in device_array.device_array_types:
  canonicalize_dtype_handlers[t] = lambda x: x
  canonicalize_dtype_handlers[t] = identity
canonicalize_dtype_handlers.update(
    (t, _canonicalize_ndarray_dtype) for t in array_types)
canonicalize_dtype_handlers.update(
    (t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types)
canonicalize_dtype_handlers[core.Token] = lambda x: x
canonicalize_dtype_handlers[core.Token] = identity
canonicalize_dtype_handlers[core.PaddedArray] = identity
canonicalize_dtype_handlers[core.BInt] = \
    lambda x: core.BInt(_canonicalize_python_scalar_dtype(int, x.val), x.bound)

def abstractify(x) -> core.AbstractValue:
  typ = type(x)
@@ -277,6 +280,8 @@ def _make_abstract_python_scalar(typ, val):
pytype_aval_mappings: Dict[Any, Callable[[Any], core.AbstractValue]] = {}
for t in device_array.device_array_types:
  pytype_aval_mappings[t] = operator.attrgetter('aval')
pytype_aval_mappings[core.BInt] = lambda x: core.AbstractBInt(x.bound)
pytype_aval_mappings[core.PaddedArray] = operator.attrgetter('_aval')
pytype_aval_mappings[core.Token] = lambda _: core.abstract_token
pytype_aval_mappings.update((t, make_shaped_array) for t in array_types)
pytype_aval_mappings.update(
@@ -9142,6 +9142,7 @@ class DynamicShapeTest(jtu.JaxTestCase):
    self.assertAllClose(y, jnp.arange(2 * 4), check_dtypes=False)
    self.assertEqual(count, 1)

  @jtu.skip_on_devices('iree')  # TODO(mattjj): update getslice, no bints
  def test_slicing_basic(self):
    f = jax.jit(lambda x, n: jnp.sum(x[:n]))
    # TODO(mattjj): revise getslice, add typecheck rule for it, enable checks
@@ -9509,6 +9510,98 @@ class DynamicShapeTest(jtu.JaxTestCase):
    expected = grad(loss_ref)(params, batch1)
    self.assertAllClose(ans, expected)

  def test_bint_basic(self):
    d = lax.make_bint(3, 5)
    self.assertEqual(str(d), '3{≤5}')

    @jax.jit
    def f(d):
      jnp.sin(3.)  # don't have an empty jaxpr
      return d
    f(d)  # doesn't crash

  def test_bint_broadcast(self):
    d = lax.make_bint(3, 5)

    x = lax.broadcast_in_dim(0, (d,), ())  # doesn't crash
    self.assertIsInstance(x, core.PaddedArray)
    self.assertAllClose(x._data, np.zeros(5, dtype='int32'), check_dtypes=False)
    self.assertEqual(
        x._aval, core.DShapedArray((core.BInt(3, 5),), x._data.dtype, True))

    def f(n):
      return jnp.zeros(n)
    x = jax.jit(f)(d)
    self.assertIsInstance(x, core.PaddedArray)
    self.assertAllClose(x._data, np.zeros(5, dtype='int32'), check_dtypes=False)
    self.assertEqual(
        x._aval, core.DShapedArray((core.BInt(3, 5),), x._data.dtype, False))

    jaxpr = jax.make_jaxpr(f)(d).jaxpr
    # { lambda ; a:bint{≤5}[]. let
    #     b:f32[a] = broadcast_in_dim[...] 0.0 a
    #   in (b,) }
    self.assertLen(jaxpr.invars, 1)
    a, = jaxpr.invars
    self.assertEqual(a.aval, core.AbstractBInt(5))
    self.assertLen(jaxpr.eqns, 1)
    eqn, = jaxpr.eqns
    self.assertLen(eqn.outvars, 1)
    b, = eqn.outvars
    self.assertEqual(b.aval.shape, (a,))

  def test_bint_iota(self):
    def f(d):
      return jnp.arange(d, dtype='int32')

    y = f(lax.make_bint(3, 5))
    self.assertIsInstance(y, core.PaddedArray)
    self.assertAllClose(y._data, np.arange(5), check_dtypes=False)

    d = lax.make_bint(3, 5)
    y = jax.jit(f)(d)
    self.assertIsInstance(y, core.PaddedArray)
    self.assertAllClose(y._data, np.arange(5), check_dtypes=False)

  def test_bint_compilation_cache(self):
    count = 0

    @jax.jit
    def f(n):
      nonlocal count
      count += 1
      return jnp.zeros(n)
    f(lax.make_bint(3, 5))
    f(lax.make_bint(4, 5))
    self.assertEqual(count, 1)

  def test_bint_compilation_cache2(self):
    count = 0

    @partial(jax.jit, abstracted_axes=('n',))
    def f(x):
      nonlocal count
      count += 1
      return x.sum()

    d = lax.make_bint(3, 5)
    x = jnp.arange(d)
    y = f(x)
    self.assertEqual(y, 3)
    self.assertEqual(count, 1)

    d = lax.make_bint(4, 5)
    x = jnp.arange(d)
    y = f(x)
    self.assertEqual(y, 6)
    self.assertEqual(count, 1)

    d = lax.make_bint(4, 6)
    x = jnp.arange(d)
    y = f(x)
    self.assertEqual(y, 6)
    self.assertEqual(count, 2)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())