resolve opaque dtypes in MLIR callback lowering and in XLA shape translation

This commit is contained in:
Roy Frostig 2023-05-01 07:53:40 -07:00
parent 63d87c6c3d
commit 8d4d520933
3 changed files with 28 additions and 16 deletions

View File

@ -1671,9 +1671,13 @@ def _layout_to_mlir_layout(minor_to_major: Optional[Sequence[int]]):
layout = np.array(minor_to_major, dtype="int64")
return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get())
def _aval_to_default_layout(aval):
def _aval_to_default_layouts(aval):
if core.is_opaque_dtype(aval.dtype):
avals = aval.dtype._rules.physical_avals(aval)
else:
avals = [aval]
# Row major order is default for `NumPy`.
return list(range(aval.ndim - 1, -1, -1))
return [list(range(aval.ndim - 1, -1, -1)) for aval in avals]
def emit_python_callback(
ctx: LoweringRuleContext, callback, token: Optional[Any],
@ -1695,17 +1699,12 @@ def emit_python_callback(
[xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals])
# Handling layouts
if operand_layouts is None:
operand_layouts = map(_aval_to_default_layout, operand_avals)
operand_mlir_layouts = [
_layout_to_mlir_layout(_aval_to_default_layout(layout)) if layout is None
else _layout_to_mlir_layout(layout) for layout, aval
in zip(operand_layouts, operand_avals)]
operand_layouts = util.concatenate(
map(_aval_to_default_layouts, operand_avals))
operand_mlir_layouts = map(_layout_to_mlir_layout, operand_layouts)
if result_layouts is None:
result_layouts = map(_aval_to_default_layout, result_avals)
result_mlir_layouts = [
_layout_to_mlir_layout(_aval_to_default_layout(aval)) if layout is None
else _layout_to_mlir_layout(layout) for layout, aval
in zip(result_layouts, result_avals)]
result_layouts = util.concatenate(map(_aval_to_default_layouts, result_avals))
result_mlir_layouts = map(_layout_to_mlir_layout, result_layouts)
# First we apply checks to ensure output shapes and dtypes match the expected
# ones.

View File

@ -54,11 +54,15 @@ def identity(x): return x
_scalar_types = dtypes.python_scalar_dtypes.keys()
def _make_array_shape(a: ShapedArray) -> Sequence[xc.Shape]:
if a.dtype == dtypes.float0:
return (xc.Shape.array_shape(np.dtype('bool'), a.shape),)
def _make_array_shape(aval: ShapedArray) -> Sequence[xc.Shape]:
def dt(aval):
return np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype
if core.is_opaque_dtype(aval.dtype):
avals = aval.dtype._rules.physical_avals(aval)
else:
return (xc.Shape.array_shape(a.dtype, a.shape),)
avals = [aval]
return tuple(xc.Shape.array_shape(dt(a), a.shape) for a in avals)
def get_canonical_source_file(frame: source_info_util.Frame):
source_file = frame.file_name

View File

@ -1908,6 +1908,15 @@ class HostCallbackTapTest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError, r"Support for \*\*kwargs in ``id_tap``"):
hcb.id_tap(func, 1, y=2)
def test_tap_id_tap_random_key(self):
# See https://github.com/google/jax/issues/13949
with jax.enable_custom_prng():
@jax.jit
def f(x):
def tap(tap_x, _): pass
return hcb.id_tap(tap, x, result=x)
f(jax.random.PRNGKey(123))
def test_tap_odeint(self):
# TODO: find a smaller repro for bug #4015
# Seems to be xla_call(scan(xla_call)), all under grad.