[export] Fix calling under pmap of exported computation with polymorphic shapes

If we call a computation with shape polymorphism under pmap,
we must refine the shapes before we compile.
We follow the same pattern for `UnloadedPmapExecutable` as
for `UnloadedMeshExecutable`: we store the `shape_poly_state`
from the `LoweringResult` in the `compile_args`, and, when that
state uses dimension variables, we call
`refine_polymorphic_shapes` on the HLO before compiling.

Without this fix, we may end up trying to compile HLO with
dynamic shapes.
This commit is contained in:
George Necula 2024-05-28 16:15:30 +03:00
parent 72b111afe6
commit acb56a2909
2 changed files with 22 additions and 2 deletions

View File

@ -793,7 +793,8 @@ def lower_parallel_callable(
ordered_effects=ordered_effects,
keepalive=lowering_result.keepalive,
host_callbacks=lowering_result.host_callbacks,
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
shape_poly_state=lowering_result.shape_poly_state)
def _pmap_unmap_shaped_array(
@ -906,7 +907,10 @@ class UnloadedPmapExecutable:
host_callbacks: list[Any],
keepalive: Any,
jaxpr_debug_info: core.JaxprDebugInfo,
shape_poly_state: mlir.ShapePolyLoweringState | None = None,
compiler_options=None):
if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
hlo = mlir.refine_polymorphic_shapes(hlo)
devices = pci.devices
if devices is None:
if shards.num_global_shards > xb.device_count(pci.backend):

View File

@ -848,6 +848,23 @@ class JaxExportTest(jtu.JaxTestCase):
a = exp2.in_avals[0].shape[0]
self.assertEqual(exp2.out_avals[0].shape, output_shape(a))
def test_poly_call_pmap(self):
  """Call an exported shape-polymorphic function under jit and under pmap.

  The function is exported with a symbolic leading dimension `a` and then
  invoked with concrete shapes, first through `jax.jit` and then through
  `jax.pmap`; both results are compared against the un-exported function.
  """
  if len(jax.devices()) < 2:
    self.skipTest("Need at least 2 devices")

  def f(x):  # x: f32[a, 4]
    # Add the row index to every element of that row.
    nr_rows = x.shape[0]
    row_ids = jnp.arange(nr_rows, dtype=x.dtype)
    return x + row_ids.reshape((nr_rows, 1))

  # Export `f` with a symbolic size for the leading dimension.
  a, = export.symbolic_shape("a")
  poly_arg = jax.ShapeDtypeStruct((a, 4), np.float32)
  exp = export.export(f)(poly_arg)
  f_exp = export.call_exported(exp)

  # Under jit: a single (3, 4) input.
  x_jit = np.arange(12, dtype=np.float32).reshape((3, 4))
  res_jit = jax.jit(f_exp)(x_jit)
  self.assertAllClose(res_jit, f(x_jit))

  # Under pmap: the leading axis of size 2 is the mapped axis, so each
  # mapped call sees a (3, 4) input.
  x_pmap = np.arange(24, dtype=np.float32).reshape((2, 3, 4))
  res_pmap = jax.pmap(f_exp)(x_pmap)
  expected_pmap = jnp.stack([f(x_shard) for x_shard in x_pmap])
  self.assertAllClose(res_pmap, expected_pmap)
def test_with_sharding(self):
nr_devices = 2
if len(jax.devices()) < nr_devices:
@ -1204,7 +1221,6 @@ class JaxExportTest(jtu.JaxTestCase):
g_rev = jax.grad(export.call(exp))(input)
self.assertAllClose(g, g_rev)
def test_multi_platform(self):
x = np.arange(8, dtype=np.float32)
exp = get_exported(_testing_multi_platform_func,