mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Enable debugging primitives in pjit on CPU/GPU
PiperOrigin-RevId: 464208326
This commit is contained in:
parent
2109c6ec8c
commit
11b206a18a
@ -128,8 +128,7 @@ jax.grad(f)(1.)
|
||||
|
||||
#### Printing in other transformations
|
||||
|
||||
`jax.debug.print` also works in other transformations like `xmap` and `pjit`
|
||||
(but `pjit` only works on TPUs for now).
|
||||
`jax.debug.print` also works in other transformations like `xmap` and `pjit`.
|
||||
|
||||
### More control with `jax.debug.callback`
|
||||
|
||||
|
@ -1430,21 +1430,21 @@ def emit_python_callback(
|
||||
[xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals])
|
||||
operand_shapes = util.flatten(
|
||||
[xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals])
|
||||
if isinstance(ctx.module_context.axis_context,
|
||||
(SPMDAxisContext, ShardingContext)):
|
||||
# Apply maximal sharding so pjit only executes the callback on device 0.
|
||||
sharding = xc.OpSharding()
|
||||
sharding.type = xc.OpSharding.Type.MAXIMAL
|
||||
sharding.tile_assignment_dimensions = [1]
|
||||
sharding.tile_assignment_devices = [0]
|
||||
else:
|
||||
sharding = None
|
||||
if platform == "tpu":
|
||||
if result_avals:
|
||||
raise NotImplementedError(
|
||||
"Callback with return values not supported on TPU.")
|
||||
token = token or mhlo.CreateTokenOp(mhlo.TokenType.get()).result
|
||||
send_channels = []
|
||||
if isinstance(ctx.module_context.axis_context,
|
||||
(SPMDAxisContext, ShardingContext)):
|
||||
# Apply maximal sharding so pjit only executes the callback on device 0.
|
||||
sharding = xc.OpSharding()
|
||||
sharding.type = xc.OpSharding.Type.MAXIMAL
|
||||
sharding.tile_assignment_dimensions = [1]
|
||||
sharding.tile_assignment_devices = [0]
|
||||
else:
|
||||
sharding = None
|
||||
for operand, operand_aval in zip(operands, operand_avals):
|
||||
channel = ctx.module_context.new_channel()
|
||||
token = send_to_host(channel, token, operand, operand_aval,
|
||||
@ -1509,6 +1509,8 @@ def emit_python_callback(
|
||||
backend_config=ir.StringAttr.get(str(callback_descriptor)),
|
||||
operand_layouts=None,
|
||||
result_layouts=None)
|
||||
if sharding is not None:
|
||||
set_sharding(result, sharding)
|
||||
results = [
|
||||
mhlo.GetTupleElementOp(result, i32_attr(i)).result
|
||||
for i in range(len(result_types))
|
||||
|
@ -504,8 +504,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.skip_on_devices(*disabled_backends)
|
||||
def test_unordered_print_with_pjit(self):
|
||||
if jax.default_backend() != "tpu":
|
||||
raise unittest.SkipTest("`pjit` doesn't work with CustomCall.")
|
||||
|
||||
def f(x):
|
||||
debug_print("{}", x, ordered=False)
|
||||
@ -532,8 +530,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.skip_on_devices(*disabled_backends)
|
||||
def test_unordered_print_of_pjit_of_while(self):
|
||||
if jax.default_backend() != "tpu":
|
||||
raise unittest.SkipTest("`pjit` doesn't work with CustomCall.")
|
||||
|
||||
def f(x):
|
||||
def cond(carry):
|
||||
@ -560,8 +556,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
|
||||
|
||||
@jtu.skip_on_devices(*disabled_backends)
|
||||
def test_unordered_print_of_pjit_of_xmap(self):
|
||||
if jax.default_backend() != "tpu":
|
||||
raise unittest.SkipTest("`pjit` doesn't work with CustomCall.")
|
||||
|
||||
def f(x):
|
||||
def foo(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user