mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
resolve opaque dtypes in MLIR callback lowering and in XLA shape translation
This commit is contained in:
parent
63d87c6c3d
commit
8d4d520933
@ -1671,9 +1671,13 @@ def _layout_to_mlir_layout(minor_to_major: Optional[Sequence[int]]):
|
||||
layout = np.array(minor_to_major, dtype="int64")
|
||||
return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get())
|
||||
|
||||
def _aval_to_default_layout(aval):
|
||||
def _aval_to_default_layouts(aval):
|
||||
if core.is_opaque_dtype(aval.dtype):
|
||||
avals = aval.dtype._rules.physical_avals(aval)
|
||||
else:
|
||||
avals = [aval]
|
||||
# Row major order is default for `NumPy`.
|
||||
return list(range(aval.ndim - 1, -1, -1))
|
||||
return [list(range(aval.ndim - 1, -1, -1)) for aval in avals]
|
||||
|
||||
def emit_python_callback(
|
||||
ctx: LoweringRuleContext, callback, token: Optional[Any],
|
||||
@ -1695,17 +1699,12 @@ def emit_python_callback(
|
||||
[xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals])
|
||||
# Handling layouts
|
||||
if operand_layouts is None:
|
||||
operand_layouts = map(_aval_to_default_layout, operand_avals)
|
||||
operand_mlir_layouts = [
|
||||
_layout_to_mlir_layout(_aval_to_default_layout(layout)) if layout is None
|
||||
else _layout_to_mlir_layout(layout) for layout, aval
|
||||
in zip(operand_layouts, operand_avals)]
|
||||
operand_layouts = util.concatenate(
|
||||
map(_aval_to_default_layouts, operand_avals))
|
||||
operand_mlir_layouts = map(_layout_to_mlir_layout, operand_layouts)
|
||||
if result_layouts is None:
|
||||
result_layouts = map(_aval_to_default_layout, result_avals)
|
||||
result_mlir_layouts = [
|
||||
_layout_to_mlir_layout(_aval_to_default_layout(aval)) if layout is None
|
||||
else _layout_to_mlir_layout(layout) for layout, aval
|
||||
in zip(result_layouts, result_avals)]
|
||||
result_layouts = util.concatenate(map(_aval_to_default_layouts, result_avals))
|
||||
result_mlir_layouts = map(_layout_to_mlir_layout, result_layouts)
|
||||
|
||||
# First we apply checks to ensure output shapes and dtypes match the expected
|
||||
# ones.
|
||||
|
@ -54,11 +54,15 @@ def identity(x): return x
|
||||
|
||||
_scalar_types = dtypes.python_scalar_dtypes.keys()
|
||||
|
||||
def _make_array_shape(a: ShapedArray) -> Sequence[xc.Shape]:
|
||||
if a.dtype == dtypes.float0:
|
||||
return (xc.Shape.array_shape(np.dtype('bool'), a.shape),)
|
||||
def _make_array_shape(aval: ShapedArray) -> Sequence[xc.Shape]:
|
||||
def dt(aval):
|
||||
return np.dtype('bool') if aval.dtype == dtypes.float0 else aval.dtype
|
||||
|
||||
if core.is_opaque_dtype(aval.dtype):
|
||||
avals = aval.dtype._rules.physical_avals(aval)
|
||||
else:
|
||||
return (xc.Shape.array_shape(a.dtype, a.shape),)
|
||||
avals = [aval]
|
||||
return tuple(xc.Shape.array_shape(dt(a), a.shape) for a in avals)
|
||||
|
||||
def get_canonical_source_file(frame: source_info_util.Frame):
|
||||
source_file = frame.file_name
|
||||
|
@ -1908,6 +1908,15 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(TypeError, r"Support for \*\*kwargs in ``id_tap``"):
|
||||
hcb.id_tap(func, 1, y=2)
|
||||
|
||||
def test_tap_id_tap_random_key(self):
|
||||
# See https://github.com/google/jax/issues/13949
|
||||
with jax.enable_custom_prng():
|
||||
@jax.jit
|
||||
def f(x):
|
||||
def tap(tap_x, _): pass
|
||||
return hcb.id_tap(tap, x, result=x)
|
||||
f(jax.random.PRNGKey(123))
|
||||
|
||||
def test_tap_odeint(self):
|
||||
# TODO: find a smaller repro for bug #4015
|
||||
# Seems to be xla_call(scan(xla_call)), all under grad.
|
||||
|
Loading…
x
Reference in New Issue
Block a user