Merge pull request #11387 from mattjj:djax-bint

PiperOrigin-RevId: 459430960
This commit is contained in:
jax authors 2022-07-06 23:00:59 -07:00
commit 5270cb1c1f
10 changed files with 345 additions and 120 deletions

View File

@ -299,6 +299,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
if config.jax_dynamic_shapes:
keep_unused = True
has_outfeed = False
donated_invars = [False] * len(fun.in_type)
else:
has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
jaxpr = apply_outfeed_rewriter(jaxpr)
@ -318,8 +319,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
device = _xla_callable_device(nreps, backend, device, arg_devices)
backend = xb.get_device_backend(device) if device else xb.get_backend(backend)
if (config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr) and
not _backend_supports_unbounded_dynamic_shapes(backend)):
if config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr):
jaxpr, consts = pe.pad_jaxpr(jaxpr, consts)
map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))
@ -520,6 +520,7 @@ num_buffers_handlers[core.AbstractToken] = lambda _: 1
num_buffers_handlers[core.ShapedArray] = lambda _: 1
num_buffers_handlers[core.DShapedArray] = lambda _: 1
num_buffers_handlers[core.ConcreteArray] = lambda _: 1
num_buffers_handlers[core.AbstractBInt] = lambda _: 1
def _input_handler(backend: Backend,
@ -652,17 +653,22 @@ def dynamic_array_result_handler(sticky_device: Optional[Device],
return partial(_dynamic_array_result_handler, sticky_device, aval)
def _dynamic_array_result_handler(sticky_device, aval, env, buf):
if all(type(d) is int for d in aval.shape):
del env
return _maybe_create_array_from_da(buf, aval, sticky_device)
else:
assert env is not None
in_env, out_env = env
shape = [in_env[d.val] if type(d) is core.InDBIdx else
out_env[d.val] if type(d) is core.OutDBIdx else d
for d in aval.shape]
in_env, out_env = env or (None, None)
shape = [in_env[d.val] if type(d) is core.InDBIdx else
out_env[d.val] if type(d) is core.OutDBIdx else d
for d in aval.shape]
if all(type(d) is int for d in shape):
aval = core.ShapedArray(tuple(shape), aval.dtype)
return _maybe_create_array_from_da(buf, aval, sticky_device)
elif any(type(d) is core.BInt for d in shape):
padded_shape = [d.bound if type(d) is core.BInt else d for d in shape]
buf_aval = core.ShapedArray(tuple(padded_shape), aval.dtype, aval.weak_type)
data = _maybe_create_array_from_da(buf, buf_aval, sticky_device)
return core.PaddedArray(aval.update(shape=tuple(shape)), data)
else:
aval = core.ShapedArray(tuple(shape), aval.dtype)
return _maybe_create_array_from_da(buf, aval, sticky_device)
result_handlers: Dict[
@ -672,6 +678,8 @@ result_handlers[core.AbstractToken] = lambda _, __: lambda _, __: core.token
result_handlers[core.ShapedArray] = array_result_handler
result_handlers[core.DShapedArray] = dynamic_array_result_handler
result_handlers[core.ConcreteArray] = array_result_handler
result_handlers[core.AbstractBInt] = \
lambda _, a: lambda _, b: core.BInt(int(b), a.bound)
def needs_check_special():
@ -1014,6 +1022,7 @@ device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]], Tuple[Any]]] =
device_put_handlers.update((t, _device_put_array) for t in array_types)
device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)
device_put_handlers[core.Token] = _device_put_token
device_put_handlers[core.BInt] = lambda x, d: _device_put_scalar(x.val, d)
def _device_put_device_array(x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray], device: Optional[Device]):
@ -1021,6 +1030,7 @@ def _device_put_device_array(x: Union[device_array.DeviceArrayProtocol, device_a
return (x.device_buffer,)
for t in device_array.device_array_types:
device_put_handlers[t] = _device_put_device_array
device_put_handlers[core.PaddedArray] = lambda x, d: device_put(x._data, d)
def _copy_device_array_to_device(
x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray],

View File

@ -104,8 +104,8 @@ class IreeBuffer(xla_client.DeviceArrayBase):
return self # no async
# overrides repr on base class which expects _value and aval attributes
def __repr__(self):
return f'IreeBuffer({self.to_py()})'
def __repr__(self): return f'IreeBuffer({self.to_py()})'
_value = property(to_py)
class IreeExecutable:

View File

@ -1440,6 +1440,7 @@ def unop(result_dtype, accepted_dtypes, name):
weak_type_rule=weak_type_rule)
batching.defvectorized(prim)
masking.defvectorized(prim)
pe.padding_rules[prim] = lambda _, __, x, **kw: [prim.bind(x, **kw)]
return prim
standard_unop = partial(unop, _identity)
_attrgetter = lambda name: lambda x, **kwargs: getattr(x, name)
@ -1515,6 +1516,7 @@ def naryop(result_dtype, accepted_dtypes, name):
weak_type_rule=weak_type_rule)
batching.defbroadcasting(prim)
masking.defnaryop(prim)
pe.padding_rules[prim] = lambda _, __, *xs, **kw: [prim.bind(*xs, **kw)]
return prim
standard_naryop = partial(naryop, _input_dtype)
@ -2080,7 +2082,6 @@ add_p: Primitive = standard_naryop([_num, _num], 'add')
ad.primitive_jvps[add_p] = _add_jvp
ad.primitive_transposes[add_p] = _add_transpose
mlir.register_lowering(add_p, partial(_nary_lower_mhlo, mhlo.AddOp))
pe.padding_rules[add_p] = lambda _, __, x, y: [add(x, y)]
def _sub_jvp(primals, tangents):
x, y = primals
@ -2110,7 +2111,6 @@ sub_p = standard_naryop([_num, _num], 'sub')
ad.primitive_jvps[sub_p] = _sub_jvp
ad.primitive_transposes[sub_p] = _sub_transpose
mlir.register_lowering(sub_p, partial(_nary_lower_mhlo, mhlo.SubOp))
pe.padding_rules[sub_p] = lambda _, __, x, y: [sub(x, y)]
def _mul_transpose(ct, x, y):
@ -2137,7 +2137,6 @@ ad.defjvp(mul_p,
lambda ydot, x, y: mul(x, ydot))
ad.primitive_transposes[mul_p] = _mul_transpose
mlir.register_lowering(mul_p, partial(_nary_lower_mhlo, mhlo.MulOp))
pe.padding_rules[mul_p] = lambda _, __, x, y: [mul(x, y)]
def _div_transpose_rule(cotangent, x, y):
assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y)
@ -2174,7 +2173,6 @@ ad.defjvp2(max_p,
lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
mlir.register_lowering(max_p, partial(_nary_lower_mhlo, mlir.max_mhlo))
pe.padding_rules[max_p] = lambda _, __, x, y: [max(x, y)]
min_p: core.Primitive = standard_naryop([_any, _any], 'min')
ad.defjvp2(min_p,
@ -2297,9 +2295,13 @@ def _convert_elt_type_pp_rule(eqn, context, settings):
printed_params = {}
if eqn.params['weak_type']:
printed_params['weak_type'] = True
return [pp.text(eqn.primitive.name),
core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
pp.text(" ") + core.pp_vars(eqn.invars, context)]
lhs = core.pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
rhs = [pp.text(eqn.primitive.name),
core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
pp.text(" ") + core.pp_vars(eqn.invars, context)]
annotation = (source_info_util.summarize(eqn.source_info)
if settings.source_info else None)
return [lhs, pp.text(" = ", annotation=annotation), *rhs]
convert_element_type_p = Primitive('convert_element_type')
@ -2756,7 +2758,7 @@ def _broadcast_in_dim_padding_rule(in_avals, out_avals, x, *dyn_shape,
assert isinstance(d, core.Tracer)
new_shape.append(None)
new_dyn_shape.append(d)
return [broadcast_in_dim_p.bind(x, *new_dyn_shape, shape=new_shape,
return [broadcast_in_dim_p.bind(x, *new_dyn_shape, shape=tuple(new_shape),
broadcast_dimensions=broadcast_dimensions)]
def _broadcast_in_dim_jvp_rule(primals, tangents, *, shape, broadcast_dimensions):
@ -2820,8 +2822,18 @@ def _broadcast_in_dim_pp_rule(eqn, context, settings):
if settings.source_info else None)
return [lhs, pp.text(" = ", annotation=annotation), *rhs]
def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions):
if not any(isinstance(d, core.BInt) for d in shape):
shape = _broadcast_in_dim_shape_rule( # error checking
x, shape=shape, broadcast_dimensions=broadcast_dimensions)
return core.ShapedArray(shape, x.dtype, x.weak_type, x.named_shape)
# If any BInts in shape, produce a DShapedArray (even if x is a ShapedArray)
# TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
return core.DShapedArray(shape, x.dtype, x.weak_type)
broadcast_in_dim_p = standard_primitive(
_broadcast_in_dim_shape_rule, _input_dtype, 'broadcast_in_dim')
broadcast_in_dim_p.def_abstract_eval(_broadcast_in_dim_abstract_eval)
ad.primitive_jvps[broadcast_in_dim_p] = _broadcast_in_dim_jvp_rule
ad.primitive_transposes[broadcast_in_dim_p] = _broadcast_in_dim_transpose_rule
batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
@ -3605,8 +3617,8 @@ def _reduce_sum_padding_rule(in_avals, out_avals, operand, *, axes):
def _replace_masked_values(x, val, padded_axes):
if not padded_axes: return x
masks = [broadcasted_iota(np.dtype('int32'), x.shape, i) < d
for i, d in padded_axes]
dtype = dtypes._scalar_type_to_dtype(int)
masks = [broadcasted_iota(dtype, x.shape, i) < d for i, d in padded_axes]
return select(_reduce(operator.and_, masks), x, full_like(x, val))
@ -4388,7 +4400,10 @@ def _iota_abstract_eval(*, dtype, shape, dimension):
if not 0 <= dimension < len(shape):
raise ValueError("iota dimension must be between 0 and len(shape), got "
f"dimension={dimension} for shape {shape}")
return ShapedArray(shape, dtype)
if not any(isinstance(d, core.BInt) for d in shape):
return ShapedArray(shape, dtype)
# TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
return core.DShapedArray(shape, dtype, False)
iota_p = Primitive('iota')
iota_p.def_impl(partial(xla.apply_primitive, iota_p))
@ -4429,12 +4444,46 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension):
mlir.i64_attr(dimension)).results
mlir.register_lowering(iota_p, _iota_lower)
def _iota_pp_rule(eqn, context, settings):
printed_params = {}
if len(eqn.params['shape']) > 1:
printed_params['dimension'] = eqn.params['dimension']
lhs = core.pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
rhs = [pp.text(eqn.primitive.name),
core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
pp.text(" ") + core.pp_vars(eqn.invars, context)]
annotation = (source_info_util.summarize(eqn.source_info)
if settings.source_info else None)
return [lhs, pp.text(" = ", annotation=annotation), *rhs]
# core.pp_eqn_rules[iota_p] = _iota_pp_rule
def _iota_padding_rule(in_avals, out_avals, *dyn_shape, dtype, shape, dimension):
out_aval, = out_avals
new_shape = []
new_dyn_shape = []
for d in out_aval.shape:
if type(d) is pe.BoundedAxisSize:
new_shape.append(d.bound)
elif type(d) is int:
new_shape.append(d)
else:
assert isinstance(d, core.Tracer)
new_shape.append(None)
new_dyn_shape.append(d)
return [iota_p.bind(*new_dyn_shape, shape=tuple(new_shape),
dtype=dtype, dimension=dimension)]
pe.padding_rules[iota_p] = _iota_padding_rule
def make_bint(i, bd: int):
return bint_p.bind(i, bd=bd)
bint_p = core.Primitive('bint')
@bint_p.def_impl
def _bint_impl(i, *, bd):
return core.BInt(i, bd)
@bint_p.def_abstract_eval
def bint_abstract_eval(_, *, bd: int):
return core.AbstractBInt(bound=bd)
@ -4570,7 +4619,7 @@ def _check_shapelike(fun_name, arg_name, obj, non_zero_shape=False):
if not len(obj): # pylint: disable=g-explicit-length-test
return
if (config.jax_dynamic_shapes and isinstance(obj, (tuple, list)) and
any(isinstance(d, core.Tracer) for d in obj)):
any(isinstance(d, (core.Tracer, core.BInt)) for d in obj)):
return # TODO(mattjj): handle more checks in the dynamic shape case
obj_arr = np.array(obj)
if obj_arr.ndim != 1:

View File

@ -903,7 +903,7 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes):
dynamic_slice_p = standard_primitive(
_dynamic_slice_shape_rule, _dynamic_slice_dtype_rule, 'dynamic_slice',
weak_type_rule=_argnum_weak_type(0))
ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp # TODO
ad.primitive_jvps[dynamic_slice_p] = _dynamic_slice_jvp
ad.primitive_transposes[dynamic_slice_p] = _dynamic_slice_transpose_rule
batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule

View File

@ -2101,8 +2101,12 @@ def arange(start: core.DimSize, stop: Optional[core.DimSize]=None,
dtype = _jnp_dtype(dtype)
if stop is None and step is None:
if (jax.config.jax_dynamic_shapes and
not isinstance(core.get_aval(start), core.AbstractBInt) and
not isinstance(core.get_aval(start), core.ConcreteArray)):
start = ceil(start).astype(int) # note using jnp here
elif (isinstance(start, core.BInt) or isinstance(start, core.Tracer) and
isinstance(core.get_aval(start), core.AbstractBInt)):
pass
else:
start = require(start, msg("stop"))
start = np.ceil(start).astype(int)

View File

@ -1104,22 +1104,9 @@ def _jaxpr_type_to_callable_annotation(jaxpr: Jaxpr) -> InputType:
for v in jaxpr.invars]
return tuple(out)
class Bot(AbstractValue): pass
bot = Bot()
class AbstractBInt(AbstractValue):
__slots__ = ['bound']
bound: int
def __init__(self, bound):
self.bound = bound
def str_short(self, short_dtypes=False) -> str:
return f'bint{{{self.bound}}}[]'
def __eq__(self, other):
return type(other) is AbstractBInt and self.bound == other.bound
def __hash__(self) -> int:
return hash((type(self), self.bound))
def lattice_join(x: Optional[AbstractValue],
y: Optional[AbstractValue]) -> AbstractValue:
@ -1171,9 +1158,6 @@ def get_aval(x):
return concrete_aval(x)
pytype_aval_mappings: Dict[type, Callable[[Any], AbstractValue]] = {}
def concretization_function_error(fun, suggest_astype=False):
fname = getattr(fun, "__name__", fun)
fname_context = f"The problem arose with the `{fname}` function. "
@ -1204,7 +1188,7 @@ def _short_dtype_name(dtype):
class UnshapedArray(AbstractValue):
__slots__ = ['dtype', 'weak_type']
array_abstraction_level = 3
array_abstraction_level = 4
def __init__(self, dtype, weak_type=False):
self.dtype = np.dtype(dtype)
@ -1269,77 +1253,9 @@ class UnshapedArray(AbstractValue):
raise TypeError(msg)
# We have a convention of reusing AbsractValues as types, in particular reusing
# ShapedArrays as types, even though we could make a distinction and use
# abstract values during tracing only. This reuse becomes a bit more extreme
# with DShapedArrays. A DShapedArray's shape attribute is a tuple which can
# contain several different types: ints, other AbstractValues (specifically at
# the input and output to pe.trace_to_jaxpr_dynamic), Tracers (while tracing),
# or Vars (when used as jaxpr type annotations). We could reduce this
# polymorphism if it seems cleaner, though it's kind of convenient!
AxisSizeForTracing = Union[int, Tracer]
AxisSizeForJaxprType = Union[int, Var]
AxisSizeForJaxprTracingSpec = Union[int, AbstractValue]
AxisSize = Union[AxisSizeForTracing, AxisSizeForJaxprType,
AxisSizeForJaxprTracingSpec]
class DShapedArray(UnshapedArray):
__slots__ = ['shape']
shape: Tuple[AxisSize, ...] # noqa: F821
array_abstraction_level: int = 2
def __init__(self, shape, dtype, weak_type):
self.shape = shape
self.dtype = dtype
self.weak_type = weak_type
ndim = property(lambda self: len(self.shape))
size = property(lambda self: prod(self.shape))
def str_short(self, short_dtypes=False) -> str:
del short_dtypes # ignored
shape = f'{",".join(str(d) for d in self.shape)}' if self.shape else ''
dtype = _short_dtype_name(self.dtype)
return f'{dtype}[{shape}]'
__str__ = __repr__ = str_short
def update(self, shape=None, dtype=None, weak_type=None):
if shape is None:
shape = self.shape
if dtype is None:
dtype = self.dtype
if weak_type is None:
weak_type = self.weak_type
return DShapedArray(shape, dtype, weak_type)
def __eq__(self, other):
return (type(self) is type(other)
and self.dtype == other.dtype and self.shape == other.shape
and self.weak_type == other.weak_type)
def __hash__(self):
return hash((self.shape, self.dtype, self.weak_type))
def join(self, other):
if (symbolic_equal_shape(self.shape, other.shape) and
self.dtype == other.dtype):
weak_type = self.weak_type and other.weak_type
return self.update(weak_type=weak_type)
elif self.dtype == other.dtype:
return UnshapedArray(self.dtype)
else:
raise TypeError(self, other)
def at_least_vspace(self):
return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)
del AxisSize, AxisSizeForTracing, AxisSizeForJaxprType, \
AxisSizeForJaxprTracingSpec
class ShapedArray(UnshapedArray):
__slots__ = ['shape', 'named_shape']
array_abstraction_level = 1
array_abstraction_level = 2
def __init__(self, shape, dtype, weak_type=False, named_shape=None):
self.shape = canonicalize_shape(shape)
@ -1415,6 +1331,7 @@ class ShapedArray(UnshapedArray):
def _forward_to_value(self, fun, ignored_tracer, *args):
return fun(self.val, *args)
class ConcreteArray(ShapedArray):
__slots__ = ['val']
array_abstraction_level = 0
@ -1477,6 +1394,135 @@ def primal_dtype_to_tangent_dtype(primal_dtype):
else:
return primal_dtype
# Dynamic shape stuff below here! We keep the abstract values distinct just so
# as not to interfere with any static shape machinery.
# We have a convention of reusing AbsractValues as types, even though we could
# make a distinction and use abstract values during tracing only. This reuse
# becomes a bit more extreme with DShapedArrays. A DShapedArray's shape
# attribute is a tuple which can contain several different types: int, BInt,
# Tracer (while tracing), Var (when used as jaxpr type annotations), or
# DBIdx/InDBIdx/OutDBIdx (when used in InputType or OutputType). We could reduce
# this polymorphism if it seems cleaner, though it's kind of convenient!
AxisSize = Any
class DShapedArray(UnshapedArray):
__slots__ = ['shape']
shape: Tuple[AxisSize, ...] # noqa: F821
array_abstraction_level: int = 3
def __init__(self, shape, dtype, weak_type):
self.shape = shape
self.dtype = dtype
self.weak_type = weak_type
ndim = property(lambda self: len(self.shape))
size = property(lambda self: prod(self.shape))
def str_short(self, short_dtypes=False) -> str:
del short_dtypes # ignored
shape = f'{",".join(str(d) for d in self.shape)}' if self.shape else ''
dtype = _short_dtype_name(self.dtype)
return f'{dtype}[{shape}]'
__str__ = __repr__ = str_short
def update(self, shape=None, dtype=None, weak_type=None):
if shape is None:
shape = self.shape
if dtype is None:
dtype = self.dtype
if weak_type is None:
weak_type = self.weak_type
return DShapedArray(shape, dtype, weak_type)
def __eq__(self, other):
return (type(self) is type(other)
and self.dtype == other.dtype and self.shape == other.shape
and self.weak_type == other.weak_type)
def __hash__(self):
return hash((self.shape, self.dtype, self.weak_type))
def join(self, other):
if (symbolic_equal_shape(self.shape, other.shape) and
self.dtype == other.dtype):
weak_type = self.weak_type and other.weak_type
return self.update(weak_type=weak_type)
elif self.dtype == other.dtype:
return UnshapedArray(self.dtype)
else:
raise TypeError(self, other)
def at_least_vspace(self):
return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)
class DConcreteArray(DShapedArray):
__slots__ = ['val']
array_abstraction_level = 1
def __init__(self, shape, dtype, weak_type, val):
super().__init__(shape, dtype, weak_type)
self.val = val
pytype_aval_mappings: Dict[type, Callable[[Any], AbstractValue]] = {}
class AbstractBInt(AbstractValue):
__slots__ = ['bound']
bound: int
def __init__(self, bound):
self.bound = bound
def str_short(self, short_dtypes=False) -> str:
return f'bint{{{self.bound}}}[]'
__repr__ = str_short
def __eq__(self, other):
return type(other) is AbstractBInt and self.bound == other.bound
def __hash__(self) -> int:
return hash((type(self), self.bound))
class BInt:
val: Any # Union[int, Array]
bound: int
def __init__(self, val, bound):
self.val = val
self.bound = bound
def __repr__(self) -> str:
return f'{self.val}{{{self.bound}}}'
def __int__(self) -> int:
return self.val
def __eq__(self, other) -> bool:
return (isinstance(other, BInt) and
(self.val, self.bound) == (other.val, other.bound))
def __hash__(self):
return hash((self.val, self.bound))
pytype_aval_mappings[BInt] = lambda x: AbstractBInt(x.bound)
# DShapedArray w/ BInt in shapes => PaddedArray runtime representation
class PaddedArray:
_aval: DShapedArray
_data: Any # standard array type
def __init__(self, aval, data):
padded_shape = tuple(d.bound if type(d) is BInt else d for d in aval.shape)
assert data.shape == padded_shape
self._aval = aval
self._data = data
shape = property(lambda self: self._aval.shape)
dtype = property(lambda self: self._aval.dtype)
def __repr__(self) -> str:
dtypestr = _short_dtype_name(self._aval.dtype)
shapestr = ','.join(map(str, self.shape))
slices = tuple(slice(d.val) if type(d) is BInt else slice(None)
for d in self.shape)
data = self._data[slices]
return f'{dtypestr}[{shapestr}] with value:\n{data}'
pytype_aval_mappings[PaddedArray] = \
lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type,
x._data)
class AbstractToken(AbstractValue):
def join(self, other):
if isinstance(other, AbstractToken):
@ -1485,7 +1531,6 @@ class AbstractToken(AbstractValue):
assert False, f"Cannot join {self} with {other}"
def str_short(self, short_dtypes=False): return 'Tok'
def at_least_vspace(self): return self
abstract_token: AbstractToken = AbstractToken()
# Concrete token object
@ -1759,6 +1804,21 @@ def _invalid_shape_error(shape: Shape, context: str=""):
"smaller subfunctions.")
return TypeError(msg)
class BIntDimensionHandler(DimensionHandler):
def symbolic_equal(self, d1, d2) -> bool:
return isinstance(d2, BInt) and d1.val == d2.val and d1.bound == d2.bound
def sum(self, *ds) -> BInt:
if not all(isinstance(d, BInt) for d in ds):
raise InconclusiveDimensionOperation
if len({d.bound for d in ds}) != 1:
raise InconclusiveDimensionOperation
return BInt(sum(d.val for d in ds), ds[0].bound)
def fail(self, *_): raise InconclusiveDimensionOperation
great_equal = diff = divide_shape_sizes = stride = dilate = as_value = fail
_SPECIAL_DIMENSION_HANDLERS[BInt] = BIntDimensionHandler()
# ------------------- Named shapes -------------------
@ -2436,7 +2496,7 @@ def check_type(
if isinstance(ty, DShapedArray):
# Check all elements in the shape tuple are well-typed.
for d in ty.shape:
if isinstance(d, int):
if isinstance(d, (int, BInt)):
continue
elif isinstance(d, Var):
if d not in env:

View File

@ -122,15 +122,19 @@ def dtype_to_ir_type(dtype: Union[np.dtype, np.generic]) -> ir.Type:
f"No dtype_to_ir_type handler for dtype: {dtype}") from err
return ir_type_factory()
def _array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
def _array_ir_types(aval: Union[core.ShapedArray, core.DShapedArray]
) -> Sequence[ir.Type]:
return (ir.RankedTensorType.get(aval.shape, dtype_to_ir_type(aval.dtype)),)
def _dynamic_array_ir_types(aval: core.ShapedArray) -> Sequence[ir.Type]:
shape = [d if type(d) is int else -1 for d in aval.shape]
# in the MHLO builder, -1 indicates a '?' axis size
shape = [d if type(d) is int else d.bound if type(d) is core.BInt else -1
for d in aval.shape]
return (ir.RankedTensorType.get(shape, dtype_to_ir_type(aval.dtype)),)
def _bint_ir_types(aval: core.AbstractBInt) -> Sequence[ir.Type]:
return (ir.RankedTensorType.get((), dtype_to_ir_type(dtypes.dtype('int32'))),)
dtype = dtypes._scalar_type_to_dtype(int)
return (ir.RankedTensorType.get((), dtype_to_ir_type(dtype)),)
ir_type_handlers: Dict[Type[core.AbstractValue],
Callable[[Any], Sequence[ir.Type]]] = {}

View File

@ -2414,7 +2414,7 @@ def pad_jaxpr(jaxpr: Jaxpr, consts: Sequence[Const]
def substitute(aval: AbstractValue) -> AbstractValue:
if isinstance(aval, AbstractBInt):
return ShapedArray((), np.dtype('int32'))
return ShapedArray((), dtypes._scalar_type_to_dtype(int))
elif isinstance(aval, DShapedArray):
shape = [bounds.get(d, idxs.get(d, d)) for d in aval.shape] # type: ignore
typ = ShapedArray if all(type(d) is int for d in shape) else DShapedArray

View File

@ -253,12 +253,15 @@ def _canonicalize_python_scalar_dtype(typ, x):
canonicalize_dtype_handlers: Dict[Any, Callable] = {}
for t in device_array.device_array_types:
canonicalize_dtype_handlers[t] = lambda x: x
canonicalize_dtype_handlers[t] = identity
canonicalize_dtype_handlers.update(
(t, _canonicalize_ndarray_dtype) for t in array_types)
canonicalize_dtype_handlers.update(
(t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types)
canonicalize_dtype_handlers[core.Token] = lambda x: x
canonicalize_dtype_handlers[core.Token] = identity
canonicalize_dtype_handlers[core.PaddedArray] = identity
canonicalize_dtype_handlers[core.BInt] = \
lambda x: core.BInt(_canonicalize_python_scalar_dtype(int, x.val), x.bound)
def abstractify(x) -> core.AbstractValue:
typ = type(x)
@ -277,6 +280,8 @@ def _make_abstract_python_scalar(typ, val):
pytype_aval_mappings: Dict[Any, Callable[[Any], core.AbstractValue]] = {}
for t in device_array.device_array_types:
pytype_aval_mappings[t] = operator.attrgetter('aval')
pytype_aval_mappings[core.BInt] = lambda x: core.AbstractBInt(x.bound)
pytype_aval_mappings[core.PaddedArray] = operator.attrgetter('_aval')
pytype_aval_mappings[core.Token] = lambda _: core.abstract_token
pytype_aval_mappings.update((t, make_shaped_array) for t in array_types)
pytype_aval_mappings.update(

View File

@ -9142,6 +9142,7 @@ class DynamicShapeTest(jtu.JaxTestCase):
self.assertAllClose(y, jnp.arange(2 * 4), check_dtypes=False)
self.assertEqual(count, 1)
@jtu.skip_on_devices('iree') # TODO(mattjj): update getslice, no bints
def test_slicing_basic(self):
f = jax.jit(lambda x, n: jnp.sum(x[:n]))
# TODO(mattjj): revise getslice, add typecheck rule for it, enable checks
@ -9509,6 +9510,98 @@ class DynamicShapeTest(jtu.JaxTestCase):
expected = grad(loss_ref)(params, batch1)
self.assertAllClose(ans, expected)
def test_bint_basic(self):
d = lax.make_bint(3, 5)
self.assertEqual(str(d), '3{≤5}')
@jax.jit
def f(d):
jnp.sin(3.) # don't have an empty jaxpr
return d
f(d) # doesn't crash
def test_bint_broadcast(self):
d = lax.make_bint(3, 5)
x = lax.broadcast_in_dim(0, (d,), ()) # doesn't crash
self.assertIsInstance(x, core.PaddedArray)
self.assertAllClose(x._data, np.zeros(5, dtype='int32'), check_dtypes=False)
self.assertEqual(
x._aval, core.DShapedArray((core.BInt(3, 5),), x._data.dtype, True))
def f(n):
return jnp.zeros(n)
x = jax.jit(f)(d)
self.assertIsInstance(x, core.PaddedArray)
self.assertAllClose(x._data, np.zeros(5, dtype='int32'), check_dtypes=False)
self.assertEqual(
x._aval, core.DShapedArray((core.BInt(3, 5),), x._data.dtype, False))
jaxpr = jax.make_jaxpr(f)(d).jaxpr
# { lambda ; a:bint{≤5}[]. let
# b:f32[a] = broadcast_in_dim[...] 0.0 a
# in (b,) }
self.assertLen(jaxpr.invars, 1)
a, = jaxpr.invars
self.assertEqual(a.aval, core.AbstractBInt(5))
self.assertLen(jaxpr.eqns, 1)
eqn, = jaxpr.eqns
self.assertLen(eqn.outvars, 1)
b, = eqn.outvars
self.assertEqual(b.aval.shape, (a,))
def test_bint_iota(self):
def f(d):
return jnp.arange(d, dtype='int32')
y = f(lax.make_bint(3, 5))
self.assertIsInstance(y, core.PaddedArray)
self.assertAllClose(y._data, np.arange(5), check_dtypes=False)
d = lax.make_bint(3, 5)
y = jax.jit(f)(d)
self.assertIsInstance(y, core.PaddedArray)
self.assertAllClose(y._data, np.arange(5), check_dtypes=False)
def test_bint_compilation_cache(self):
count = 0
@jax.jit
def f(n):
nonlocal count
count += 1
return jnp.zeros(n)
f(lax.make_bint(3, 5))
f(lax.make_bint(4, 5))
self.assertEqual(count, 1)
def test_bint_compilation_cache2(self):
count = 0
@partial(jax.jit, abstracted_axes=('n',))
def f(x):
nonlocal count
count += 1
return x.sum()
d = lax.make_bint(3, 5)
x = jnp.arange(d)
y = f(x)
self.assertEqual(y, 3)
self.assertEqual(count, 1)
d = lax.make_bint(4, 5)
x = jnp.arange(d)
y = f(x)
self.assertEqual(y, 6)
self.assertEqual(count, 1)
d = lax.make_bint(4, 6)
x = jnp.arange(d)
y = f(x)
self.assertEqual(y, 6)
self.assertEqual(count, 2)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())