mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix host_callback for pjit which was using REPLICATED which was a CanonicalizedParsedPspec
PiperOrigin-RevId: 501713533
This commit is contained in:
parent
936247a7e5
commit
94f0ccc54a
@ -936,9 +936,6 @@ class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec):
|
||||
f"sync={self.sync})")
|
||||
|
||||
|
||||
REPLICATED = CanonicalizedParsedPartitionSpec(ParsedPartitionSpec(None, ()))
|
||||
|
||||
|
||||
def _prepare_axis_resources(axis_resources,
|
||||
arg_name,
|
||||
allow_unconstrained_dims=False):
|
||||
@ -1997,7 +1994,7 @@ def parse_flatten_op_sharding(op_sharding: xc.OpSharding,
|
||||
out.extend(parse_flatten_op_sharding(s, mesh))
|
||||
return out
|
||||
elif op_sharding.type == xc.OpSharding.Type.REPLICATED:
|
||||
return [REPLICATED]
|
||||
return [CanonicalizedParsedPartitionSpec(ParsedPartitionSpec(None, ()))]
|
||||
elif op_sharding.type == xc.OpSharding.Type.OTHER:
|
||||
mesh_shape = mesh.shape
|
||||
mesh_axis_order = unflatten_array(mesh.shape, op_sharding.tile_assignment_devices)
|
||||
|
@ -1716,9 +1716,9 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
jaxpr=_rewrite_closed_jaxpr(jaxpr, True, True),
|
||||
donated_invars=eqn.params["donated_invars"] + (False, False),
|
||||
in_shardings=(eqn.params["in_shardings"] +
|
||||
(pjit.REPLICATED, pjit.REPLICATED)),
|
||||
(pjit._UNSPECIFIED, pjit._UNSPECIFIED)),
|
||||
out_shardings=(eqn.params["out_shardings"] +
|
||||
(pjit.REPLICATED, pjit.REPLICATED)),
|
||||
(pjit._UNSPECIFIED, pjit._UNSPECIFIED)),
|
||||
)))
|
||||
elif eqn.primitive is ad_checkpoint.remat_p:
|
||||
jaxpr_ = cast(core.Jaxpr, eqn.params["jaxpr"])
|
||||
|
@ -883,6 +883,7 @@ jax_test(
|
||||
name = "host_callback_test",
|
||||
srcs = ["host_callback_test.py"],
|
||||
args = ["--jax_host_callback_outfeed=true"],
|
||||
enable_configs = ["cpu_jit_pjit_api_merge"],
|
||||
pjrt_c_api_bypass = True,
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
@ -899,6 +900,7 @@ jax_test(
|
||||
"gpu",
|
||||
"tpu", # On TPU we always use outfeed
|
||||
],
|
||||
enable_configs = ["cpu_jit_pjit_api_merge"],
|
||||
main = "host_callback_test.py",
|
||||
shard_count = {
|
||||
"gpu": 5,
|
||||
|
Loading…
x
Reference in New Issue
Block a user