From acb56a2909bf3d4193703c8eef3efee2c975f42d Mon Sep 17 00:00:00 2001
From: George Necula
Date: Tue, 28 May 2024 16:15:30 +0300
Subject: [PATCH] [export] Fix calling under pmap of exported computation with polymorphic shapes

If we call a computation with shape polymorphism under pmap we must
refine the shapes before we compile. We follow the same pattern for
`UnloadedPmapExecutable` as for `UnloadedMeshExecutable`: we store the
`shape_poly_state` from the `LoweringResult` into the `compile_args`
and we call `refine_polymorphic_shapes`.

Without this fix we may end up trying to compile HLO with dynamic
shapes.
---
 jax/_src/interpreters/pxla.py |  6 +++++-
 tests/export_test.py          | 18 +++++++++++++++++-
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index eee0d4f04..3fb6e41eb 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -793,7 +793,8 @@ def lower_parallel_callable(
       ordered_effects=ordered_effects,
       keepalive=lowering_result.keepalive,
       host_callbacks=lowering_result.host_callbacks,
-      jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info)
+      jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
+      shape_poly_state=lowering_result.shape_poly_state)
 
 
 def _pmap_unmap_shaped_array(
@@ -906,7 +907,10 @@ class UnloadedPmapExecutable:
                host_callbacks: list[Any],
                keepalive: Any,
                jaxpr_debug_info: core.JaxprDebugInfo,
+               shape_poly_state: mlir.ShapePolyLoweringState | None = None,
                compiler_options=None):
+    if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
+      hlo = mlir.refine_polymorphic_shapes(hlo)
     devices = pci.devices
     if devices is None:
       if shards.num_global_shards > xb.device_count(pci.backend):
diff --git a/tests/export_test.py b/tests/export_test.py
index 96dd1d6ac..bb025d6b7 100644
--- a/tests/export_test.py
+++ b/tests/export_test.py
@@ -848,6 +848,23 @@ class JaxExportTest(jtu.JaxTestCase):
     a = exp2.in_avals[0].shape[0]
     self.assertEqual(exp2.out_avals[0].shape, output_shape(a))
 
+  def test_poly_call_pmap(self):
+    if len(jax.devices()) < 2:
+      self.skipTest("Need at least 2 devices")
+    def f(x):  # x: f32[a, 4]
+      return x + jnp.arange(x.shape[0], dtype=x.dtype).reshape((x.shape[0], 1))
+
+    a, = export.symbolic_shape("a")
+    exp = export.export(f)(
+        jax.ShapeDtypeStruct((a, 4), np.float32))
+    f_exp = export.call_exported(exp)
+    x_jit = np.arange(12, dtype=np.float32).reshape((3, 4))
+    res_jit = jax.jit(f_exp)(x_jit)
+    self.assertAllClose(res_jit, f(x_jit))
+    x_pmap = np.arange(24, dtype=np.float32).reshape((2, 3, 4))
+    res_pmap = jax.pmap(f_exp)(x_pmap)
+    self.assertAllClose(res_pmap, jnp.stack([f(x) for x in x_pmap]))
+
   def test_with_sharding(self):
     nr_devices = 2
     if len(jax.devices()) < nr_devices:
@@ -1204,7 +1221,6 @@ class JaxExportTest(jtu.JaxTestCase):
       g_rev = jax.grad(export.call(exp))(input)
       self.assertAllClose(g, g_rev)
 
-
   def test_multi_platform(self):
     x = np.arange(8, dtype=np.float32)
     exp = get_exported(_testing_multi_platform_func,
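
For reference, a minimal sketch of the usage exercised by the new test_poly_call_pmap
test: export a function with a symbolic leading dimension, then call the exported
computation under jax.pmap. This is a sketch, not part of the commit; it assumes at
least two local devices and the export API as available at the time of this patch
(the import path jax.experimental.export is an assumption, and the helper f and the
concrete shapes are taken from the test for illustration):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.experimental import export  # assumed import path for this JAX version

    def f(x):  # x: f32[a, 4], with `a` a symbolic dimension
      return x + jnp.arange(x.shape[0], dtype=x.dtype).reshape((x.shape[0], 1))

    # Export with a polymorphic leading dimension `a`.
    a, = export.symbolic_shape("a")
    exp = export.export(f)(jax.ShapeDtypeStruct((a, 4), np.float32))
    f_exp = export.call_exported(exp)

    # Call the exported computation under pmap: each device receives an
    # f32[3, 4] shard, and the lowered HLO must have its dynamic shapes
    # refined before compilation (the code path fixed by this patch).
    x = np.arange(24, dtype=np.float32).reshape((2, 3, 4))
    res = jax.pmap(f_exp)(x)  # f32[2, 3, 4]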