Mirror of https://github.com/ROCm/jax.git
[export] Fix calling under pmap of exported computation with polymorphic shapes
If we call a computation with shape polymorphism under pmap, we must refine the shapes before we compile. We follow the same pattern for `UnloadedPmapExecutable` as for `UnloadedMeshExecutable`: we store the `shape_poly_state` from the `LoweringResult` into the `compile_args` and call `refine_polymorphic_shapes` before compilation. Without this fix we may end up trying to compile HLO with dynamic shapes.
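A minimal sketch of the scenario being fixed, adapted from the test added in this commit (the `jax.experimental.export` import path is an assumption; the diff below shows only the test body):

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import export  # assumed import path

def f(x):  # x: f32[a, 4]
  return x + jnp.arange(x.shape[0], dtype=x.dtype).reshape((x.shape[0], 1))

# Export f with a symbolic leading dimension `a`.
a, = export.symbolic_shape("a")
exp = export.export(f)(jax.ShapeDtypeStruct((a, 4), np.float32))

# Calling the exported computation under pmap; before this fix, the HLO
# handed to the compiler could still contain dynamic (symbolic) shapes.
# Requires at least 2 local devices, matching the leading axis of x.
x = np.arange(24, dtype=np.float32).reshape((2, 3, 4))
res = jax.pmap(export.call_exported(exp))(x)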
commit acb56a2909
parent 72b111afe6
@@ -793,7 +793,8 @@ def lower_parallel_callable(
                          ordered_effects=ordered_effects,
                          keepalive=lowering_result.keepalive,
                          host_callbacks=lowering_result.host_callbacks,
-                         jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)
+                         jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
+                         shape_poly_state=lowering_result.shape_poly_state)


 def _pmap_unmap_shaped_array(
@@ -906,7 +907,10 @@ class UnloadedPmapExecutable:
               host_callbacks: list[Any],
               keepalive: Any,
               jaxpr_debug_info: core.JaxprDebugInfo,
+              shape_poly_state: mlir.ShapePolyLoweringState | None = None,
               compiler_options=None):
+    if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
+      hlo = mlir.refine_polymorphic_shapes(hlo)
     devices = pci.devices
     if devices is None:
       if shards.num_global_shards > xb.device_count(pci.backend):
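The commit message notes this follows the same guard already used for `UnloadedMeshExecutable`. A minimal sketch of that shared pattern, factored into a hypothetical helper (`refine_if_needed` is an illustrative name, not part of the change; `mlir` is the same module referenced in the diff above):

def refine_if_needed(hlo, shape_poly_state):
  # Hypothetical helper: refine only when the lowering actually used
  # dimension variables, so the compiler never sees symbolic shapes.
  if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
    hlo = mlir.refine_polymorphic_shapes(hlo)
  return hlo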
@@ -848,6 +848,23 @@ class JaxExportTest(jtu.JaxTestCase):
     a = exp2.in_avals[0].shape[0]
     self.assertEqual(exp2.out_avals[0].shape, output_shape(a))

+  def test_poly_call_pmap(self):
+    if len(jax.devices()) < 2:
+      self.skipTest("Need at least 2 devices")
+    def f(x):  # x: f32[a, 4]
+      return x + jnp.arange(x.shape[0], dtype=x.dtype).reshape((x.shape[0], 1))
+
+    a, = export.symbolic_shape("a")
+    exp = export.export(f)(
+        jax.ShapeDtypeStruct((a, 4), np.float32))
+    f_exp = export.call_exported(exp)
+    x_jit = np.arange(12, dtype=np.float32).reshape((3, 4))
+    res_jit = jax.jit(f_exp)(x_jit)
+    self.assertAllClose(res_jit, f(x_jit))
+    x_pmap = np.arange(24, dtype=np.float32).reshape((2, 3, 4))
+    res_pmap = jax.pmap(f_exp)(x_pmap)
+    self.assertAllClose(res_pmap, jnp.stack([f(x) for x in x_pmap]))
+
   def test_with_sharding(self):
     nr_devices = 2
     if len(jax.devices()) < nr_devices:
@@ -1204,7 +1221,6 @@ class JaxExportTest(jtu.JaxTestCase):
     g_rev = jax.grad(export.call(exp))(input)
     self.assertAllClose(g, g_rev)

-
   def test_multi_platform(self):
     x = np.arange(8, dtype=np.float32)
     exp = get_exported(_testing_multi_platform_func,