[shard-map] better errors for not-implemented-in-eager features

This commit is contained in:
Matthew Johnson 2023-04-08 21:12:40 -07:00
parent bdacc32bda
commit 9dabb6fa59
4 changed files with 103 additions and 14 deletions

View File

@ -2441,8 +2441,8 @@ def extend_axis_env(axis_name: AxisName, size: int, tag: Any):
if f.name is not no_axis_name))
@contextmanager
def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]]):
frames = [AxisEnvFrame(axis_name, size, None) for axis_name, size in axes]
def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]], tag: Any = None):
frames = [AxisEnvFrame(axis_name, size, tag) for axis_name, size in axes]
ts = thread_local_state.trace_state
ts.axis_env.extend(frames)
jax_config.update_thread_local_jit_state(

View File

@ -601,13 +601,13 @@ class MapTrace(core.Trace):
def process_call(self, call_primitive, fun, tracers, params):
raise NotImplementedError
def process_map(self, call_primitive, fun, tracers, params):
def process_map(self, map_primitive, fun, tracers, params):
if params['devices'] is not None:
raise ValueError("Nested pmap with explicit devices argument.")
if not config.jax_disable_jit:
bind = HashableFunction(
lambda *args, **kwargs: call_primitive.bind(fun, *args, **kwargs),
(call_primitive, fun))
lambda *args, **kwargs: map_primitive.bind(fun, *args, **kwargs),
(map_primitive, fun))
fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
return self.process_primitive(fake_primitive, tracers, params)
axis_name, in_axes, out_axes_thunk, axis_size = (params["axis_name"],
@ -627,12 +627,11 @@ class MapTrace(core.Trace):
for v, s, dst in zip(out, outaxes, out_axes_thunk()))
return map(partial(MapTracer, self), out, outaxes)
def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *,
symbolic_zeros):
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
bind = HashableFunction(
lambda *args, **kwargs: primitive.bind(
lambda *args, **kwargs: prim.bind(
fun, jvp, *args, symbolic_zeros=symbolic_zeros, **kwargs),
(primitive, fun, jvp))
(prim, fun, jvp, symbolic_zeros))
fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
return self.process_primitive(fake_primitive, tracers, {})

View File

@ -358,7 +358,7 @@ def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep,
# Staging
def _shard_map_staging(
trace: pe.DynamicJaxprTrace, prim: core.Primitive, fun: lu.WrappedFun,
trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun,
in_tracers: Sequence[pe.DynamicJaxprTracer], *, mesh: Mesh,
in_names: Tuple[AxisNames, ...],
out_names_thunk: Callable[[], Tuple[AxisNames, ...]],
@ -366,9 +366,9 @@ def _shard_map_staging(
) -> Sequence[pe.DynamicJaxprTracer]:
in_avals = [t.aval for t in in_tracers]
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
main = trace.main
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, out_avals_, consts = pe.trace_to_subjaxpr_dynamic(
fun, trace.main, in_avals_)
jaxpr, out_avals_, consts = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_)
_check_names(out_names_thunk(), out_avals_)
if check_rep:
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
@ -510,7 +510,7 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
args = map(partial(_unmatch_spec, mesh), in_names, args)
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
with core.new_base_main(ShardMapTrace, mesh=mesh, check=check_rep) as main:
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items(), main):
t = main.with_cur_sublevel()
in_tracers = map(partial(ShardMapTracer, t), in_rep, args)
ans = fun.call_wrapped(*in_tracers)
@ -595,7 +595,52 @@ class ShardMapTrace(core.Trace):
return ShardMapTracer(self, out_rep, out_vals)
def process_call(self, call_primitive, fun, tracers, params):
raise NotImplementedError
raise NotImplementedError(
f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't "
"yet supported. Put a `jax.jit` around the `shard_map`-decorated "
"function, and open a feature request at "
"https://github.com/google/jax/issues !")
def process_map(self, map_primitive, fun, tracers, params):
raise NotImplementedError(
"Eager evaluation of `pmap` inside a `shard_map` isn't yet supported."
"Put a `jax.jit` around the `shard_map`-decorated function, and open "
"a feature request at https://github.com/google/jax/issues !")
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
raise NotImplementedError(
"Eager evaluation of a `custom_jvp` inside a `shard_map` isn't yet "
"supported. "
"Put a `jax.jit` around the `shard_map`-decorated function, and open "
"a feature request at https://github.com/google/jax/issues !")
def post_process_custom_jvp_call(self, out_tracers, _):
raise NotImplementedError(
"Eager evaluation of a `custom_jvp` inside a `shard_map` isn't yet "
"supported. "
"Put a `jax.jit` around the `shard_map`-decorated function, and open "
"a feature request at https://github.com/google/jax/issues !")
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
raise NotImplementedError(
"Eager evaluation of a `custom_vjp` inside a `shard_map` isn't yet "
"supported. "
"Put a `jax.jit` around the `shard_map`-decorated function, and open "
"a feature request at https://github.com/google/jax/issues !")
def post_process_custom_vjp_call(self, out_tracers, _):
raise NotImplementedError(
"Eager evaluation of a `custom_vjp` inside a `shard_map` isn't yet "
"supported. "
"Put a `jax.jit` around the `shard_map`-decorated function, and open "
"a feature request at https://github.com/google/jax/issues !")
def process_axis_index(self, frame):
raise NotImplementedError(
"Eager evaluation of an `axis_index` inside a `shard_map` isn't yet "
"supported. "
"Put a `jax.jit` around the `shard_map`-decorated function, and open "
"a feature request at https://github.com/google/jax/issues !")
class ShardMapTracer(core.Tracer):

View File

@ -639,6 +639,51 @@ class ShardMapTest(jtu.JaxTestCase):
shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x)
def test_eager_notimplemented_error_message_custom_jvp(self):
@jax.custom_jvp
def foo(x):
return 2. * x
@foo.defjvp
def foo_jvp(primals, tangents):
(x,), (x_dot,) = primals, tangents
return foo(x), 2. * x_dot
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x'))
x = jnp.arange(4.)
with self.assertRaisesRegex(NotImplementedError, 'custom_jvp'):
g(x)
def test_eager_notimplemented_error_message_custom_vjp(self):
@jax.custom_vjp
def foo(x):
return 2. * x
def foo_fwd(x):
return x, None
def foo_bwd(_, y_bar):
return 2. * y_bar,
foo.defvjp(foo_fwd, foo_bwd)
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x'))
x = jnp.arange(4.)
with self.assertRaisesRegex(NotImplementedError, 'custom_vjp'):
g(x)
def test_eager_notimplemented_error_message_axis_index(self):
def foo(x):
return x + jax.lax.axis_index('x')
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x'))
x = jnp.arange(4.)
with self.assertRaisesRegex(NotImplementedError, 'axis_index'):
g(x)
class FunSpec(NamedTuple):
name: str