Fix host_callback for pjit which was using REPLICATED which was a CanonicalizedParsedPspec

PiperOrigin-RevId: 501713533
This commit is contained in:
Yash Katariya 2023-01-12 17:59:58 -08:00 committed by jax authors
parent 936247a7e5
commit 94f0ccc54a
3 changed files with 5 additions and 6 deletions

View File

@ -936,9 +936,6 @@ class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec):
f"sync={self.sync})")
REPLICATED = CanonicalizedParsedPartitionSpec(ParsedPartitionSpec(None, ()))
def _prepare_axis_resources(axis_resources,
arg_name,
allow_unconstrained_dims=False):
@ -1997,7 +1994,7 @@ def parse_flatten_op_sharding(op_sharding: xc.OpSharding,
out.extend(parse_flatten_op_sharding(s, mesh))
return out
elif op_sharding.type == xc.OpSharding.Type.REPLICATED:
return [REPLICATED]
return [CanonicalizedParsedPartitionSpec(ParsedPartitionSpec(None, ()))]
elif op_sharding.type == xc.OpSharding.Type.OTHER:
mesh_shape = mesh.shape
mesh_axis_order = unflatten_array(mesh.shape, op_sharding.tile_assignment_devices)

View File

@ -1716,9 +1716,9 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
jaxpr=_rewrite_closed_jaxpr(jaxpr, True, True),
donated_invars=eqn.params["donated_invars"] + (False, False),
in_shardings=(eqn.params["in_shardings"] +
(pjit.REPLICATED, pjit.REPLICATED)),
(pjit._UNSPECIFIED, pjit._UNSPECIFIED)),
out_shardings=(eqn.params["out_shardings"] +
(pjit.REPLICATED, pjit.REPLICATED)),
(pjit._UNSPECIFIED, pjit._UNSPECIFIED)),
)))
elif eqn.primitive is ad_checkpoint.remat_p:
jaxpr_ = cast(core.Jaxpr, eqn.params["jaxpr"])

View File

@ -883,6 +883,7 @@ jax_test(
name = "host_callback_test",
srcs = ["host_callback_test.py"],
args = ["--jax_host_callback_outfeed=true"],
enable_configs = ["cpu_jit_pjit_api_merge"],
pjrt_c_api_bypass = True,
deps = [
"//jax:experimental",
@ -899,6 +900,7 @@ jax_test(
"gpu",
"tpu", # On TPU we always use outfeed
],
enable_configs = ["cpu_jit_pjit_api_merge"],
main = "host_callback_test.py",
shard_count = {
"gpu": 5,