Enable debugging primitives in pjit on CPU/GPU

PiperOrigin-RevId: 464208326
This commit is contained in:
Sharad Vikram 2022-07-29 20:10:01 -07:00 committed by jax authors
parent 2109c6ec8c
commit 11b206a18a
3 changed files with 12 additions and 17 deletions

View File

@ -128,8 +128,7 @@ jax.grad(f)(1.)
#### Printing in other transformations
`jax.debug.print` also works in other transformations like `xmap` and `pjit`
(but `pjit` only works on TPUs for now).
`jax.debug.print` also works in other transformations like `xmap` and `pjit`.
### More control with `jax.debug.callback`

View File

@ -1430,21 +1430,21 @@ def emit_python_callback(
[xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals])
operand_shapes = util.flatten(
[xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals])
if isinstance(ctx.module_context.axis_context,
(SPMDAxisContext, ShardingContext)):
# Apply maximal sharding so pjit only executes the callback on device 0.
sharding = xc.OpSharding()
sharding.type = xc.OpSharding.Type.MAXIMAL
sharding.tile_assignment_dimensions = [1]
sharding.tile_assignment_devices = [0]
else:
sharding = None
if platform == "tpu":
if result_avals:
raise NotImplementedError(
"Callback with return values not supported on TPU.")
token = token or mhlo.CreateTokenOp(mhlo.TokenType.get()).result
send_channels = []
if isinstance(ctx.module_context.axis_context,
(SPMDAxisContext, ShardingContext)):
# Apply maximal sharding so pjit only executes the callback on device 0.
sharding = xc.OpSharding()
sharding.type = xc.OpSharding.Type.MAXIMAL
sharding.tile_assignment_dimensions = [1]
sharding.tile_assignment_devices = [0]
else:
sharding = None
for operand, operand_aval in zip(operands, operand_avals):
channel = ctx.module_context.new_channel()
token = send_to_host(channel, token, operand, operand_aval,
@ -1509,6 +1509,8 @@ def emit_python_callback(
backend_config=ir.StringAttr.get(str(callback_descriptor)),
operand_layouts=None,
result_layouts=None)
if sharding is not None:
set_sharding(result, sharding)
results = [
mhlo.GetTupleElementOp(result, i32_attr(i)).result
for i in range(len(result_types))

View File

@ -504,8 +504,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
@jtu.skip_on_devices(*disabled_backends)
def test_unordered_print_with_pjit(self):
if jax.default_backend() != "tpu":
raise unittest.SkipTest("`pjit` doesn't work with CustomCall.")
def f(x):
debug_print("{}", x, ordered=False)
@ -532,8 +530,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
@jtu.skip_on_devices(*disabled_backends)
def test_unordered_print_of_pjit_of_while(self):
if jax.default_backend() != "tpu":
raise unittest.SkipTest("`pjit` doesn't work with CustomCall.")
def f(x):
def cond(carry):
@ -560,8 +556,6 @@ class DebugPrintParallelTest(jtu.JaxTestCase):
@jtu.skip_on_devices(*disabled_backends)
def test_unordered_print_of_pjit_of_xmap(self):
if jax.default_backend() != "tpu":
raise unittest.SkipTest("`pjit` doesn't work with CustomCall.")
def f(x):
def foo(x):