From 9dabb6fa59e7bbbefa552bef944afb72c9f1d29d Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sat, 8 Apr 2023 21:12:40 -0700 Subject: [PATCH] [shard-map] better errors for not-implemented-in-eager features --- jax/_src/core.py | 4 +-- jax/_src/interpreters/pxla.py | 13 ++++----- jax/experimental/shard_map.py | 55 +++++++++++++++++++++++++++++++---- tests/shard_map_test.py | 45 ++++++++++++++++++++++++++++ 4 files changed, 103 insertions(+), 14 deletions(-) diff --git a/jax/_src/core.py b/jax/_src/core.py index bff50f793..8d4973c56 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -2441,8 +2441,8 @@ def extend_axis_env(axis_name: AxisName, size: int, tag: Any): if f.name is not no_axis_name)) @contextmanager -def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]]): - frames = [AxisEnvFrame(axis_name, size, None) for axis_name, size in axes] +def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]], tag: Any = None): + frames = [AxisEnvFrame(axis_name, size, tag) for axis_name, size in axes] ts = thread_local_state.trace_state ts.axis_env.extend(frames) jax_config.update_thread_local_jit_state( diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 366c8d06d..e8e81bc11 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -601,13 +601,13 @@ class MapTrace(core.Trace): def process_call(self, call_primitive, fun, tracers, params): raise NotImplementedError - def process_map(self, call_primitive, fun, tracers, params): + def process_map(self, map_primitive, fun, tracers, params): if params['devices'] is not None: raise ValueError("Nested pmap with explicit devices argument.") if not config.jax_disable_jit: bind = HashableFunction( - lambda *args, **kwargs: call_primitive.bind(fun, *args, **kwargs), - (call_primitive, fun)) + lambda *args, **kwargs: map_primitive.bind(fun, *args, **kwargs), + (map_primitive, fun)) fake_primitive = FakePrimitive(multiple_results=True, bind=bind) return 
self.process_primitive(fake_primitive, tracers, params) axis_name, in_axes, out_axes_thunk, axis_size = (params["axis_name"], @@ -627,12 +627,11 @@ class MapTrace(core.Trace): for v, s, dst in zip(out, outaxes, out_axes_thunk())) return map(partial(MapTracer, self), out, outaxes) - def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, - symbolic_zeros): + def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): bind = HashableFunction( - lambda *args, **kwargs: primitive.bind( + lambda *args, **kwargs: prim.bind( fun, jvp, *args, symbolic_zeros=symbolic_zeros, **kwargs), - (primitive, fun, jvp)) + (prim, fun, jvp, symbolic_zeros)) fake_primitive = FakePrimitive(multiple_results=True, bind=bind) return self.process_primitive(fake_primitive, tracers, {}) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 6ab93d5b7..81fdc7db3 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -358,7 +358,7 @@ def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep, # Staging def _shard_map_staging( - trace: pe.DynamicJaxprTrace, prim: core.Primitive, fun: lu.WrappedFun, + trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun, in_tracers: Sequence[pe.DynamicJaxprTracer], *, mesh: Mesh, in_names: Tuple[AxisNames, ...], out_names_thunk: Callable[[], Tuple[AxisNames, ...]], @@ -366,9 +366,9 @@ def _shard_map_staging( ) -> Sequence[pe.DynamicJaxprTracer]: in_avals = [t.aval for t in in_tracers] in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals) + main = trace.main with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()): - jaxpr, out_avals_, consts = pe.trace_to_subjaxpr_dynamic( - fun, trace.main, in_avals_) + jaxpr, out_avals_, consts = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_) _check_names(out_names_thunk(), out_avals_) if check_rep: in_rep = map(partial(_in_names_to_rep, mesh), in_names) @@ -510,7 +510,7 @@ def 
_shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, args = map(partial(_unmatch_spec, mesh), in_names, args) in_rep = map(partial(_in_names_to_rep, mesh), in_names) with core.new_base_main(ShardMapTrace, mesh=mesh, check=check_rep) as main: - with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()): + with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items(), main): t = main.with_cur_sublevel() in_tracers = map(partial(ShardMapTracer, t), in_rep, args) ans = fun.call_wrapped(*in_tracers) @@ -595,7 +595,52 @@ class ShardMapTrace(core.Trace): return ShardMapTracer(self, out_rep, out_vals) def process_call(self, call_primitive, fun, tracers, params): - raise NotImplementedError + raise NotImplementedError( + f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't " + "yet supported. Put a `jax.jit` around the `shard_map`-decorated " + "function, and open a feature request at " + "https://github.com/google/jax/issues !") + + def process_map(self, map_primitive, fun, tracers, params): + raise NotImplementedError( + "Eager evaluation of `pmap` inside a `shard_map` isn't yet supported. " + "Put a `jax.jit` around the `shard_map`-decorated function, and open " + "a feature request at https://github.com/google/jax/issues !") + + def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): + raise NotImplementedError( + "Eager evaluation of a `custom_jvp` inside a `shard_map` isn't yet " + "supported. " + "Put a `jax.jit` around the `shard_map`-decorated function, and open " + "a feature request at https://github.com/google/jax/issues !") + + def post_process_custom_jvp_call(self, out_tracers, _): + raise NotImplementedError( + "Eager evaluation of a `custom_jvp` inside a `shard_map` isn't yet " + "supported. 
" + "Put a `jax.jit` around the `shard_map`-decorated function, and open " + "a feature request at https://github.com/google/jax/issues !") + + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees): + raise NotImplementedError( + "Eager evaluation of a `custom_vjp` inside a `shard_map` isn't yet " + "supported. " + "Put a `jax.jit` around the `shard_map`-decorated function, and open " + "a feature request at https://github.com/google/jax/issues !") + + def post_process_custom_vjp_call(self, out_tracers, _): + raise NotImplementedError( + "Eager evaluation of a `custom_vjp` inside a `shard_map` isn't yet " + "supported. " + "Put a `jax.jit` around the `shard_map`-decorated function, and open " + "a feature request at https://github.com/google/jax/issues !") + + def process_axis_index(self, frame): + raise NotImplementedError( + "Eager evaluation of an `axis_index` inside a `shard_map` isn't yet " + "supported. " + "Put a `jax.jit` around the `shard_map`-decorated function, and open " + "a feature request at https://github.com/google/jax/issues !") class ShardMapTracer(core.Tracer): diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 7ee0ac56b..ccdec63b8 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -639,6 +639,51 @@ class ShardMapTest(jtu.JaxTestCase): shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x) + def test_eager_notimplemented_error_message_custom_jvp(self): + @jax.custom_jvp + def foo(x): + return 2. * x + + @foo.defjvp + def foo_jvp(primals, tangents): + (x,), (x_dot,) = primals, tangents + return foo(x), 2. * x_dot + + mesh = jtu.create_global_mesh((4,), ('x',)) + g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x')) + x = jnp.arange(4.) + with self.assertRaisesRegex(NotImplementedError, 'custom_jvp'): + g(x) + + def test_eager_notimplemented_error_message_custom_vjp(self): + @jax.custom_vjp + def foo(x): + return 2. 
* x + + def foo_fwd(x): + return x, None + + def foo_bwd(_, y_bar): + return 2. * y_bar, + + foo.defvjp(foo_fwd, foo_bwd) + + mesh = jtu.create_global_mesh((4,), ('x',)) + g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x')) + x = jnp.arange(4.) + with self.assertRaisesRegex(NotImplementedError, 'custom_vjp'): + g(x) + + def test_eager_notimplemented_error_message_axis_index(self): + def foo(x): + return x + jax.lax.axis_index('x') + + mesh = jtu.create_global_mesh((4,), ('x',)) + g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x')) + x = jnp.arange(4.) + with self.assertRaisesRegex(NotImplementedError, 'axis_index'): + g(x) + class FunSpec(NamedTuple): name: str