mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[shard-map] better errors for not-implemented-in-eager features
This commit is contained in:
parent
bdacc32bda
commit
9dabb6fa59
@ -2441,8 +2441,8 @@ def extend_axis_env(axis_name: AxisName, size: int, tag: Any):
|
||||
if f.name is not no_axis_name))
|
||||
|
||||
@contextmanager
|
||||
def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]]):
|
||||
frames = [AxisEnvFrame(axis_name, size, None) for axis_name, size in axes]
|
||||
def extend_axis_env_nd(axes: Iterable[Tuple[AxisName, int]], tag: Any = None):
|
||||
frames = [AxisEnvFrame(axis_name, size, tag) for axis_name, size in axes]
|
||||
ts = thread_local_state.trace_state
|
||||
ts.axis_env.extend(frames)
|
||||
jax_config.update_thread_local_jit_state(
|
||||
|
@ -601,13 +601,13 @@ class MapTrace(core.Trace):
|
||||
def process_call(self, call_primitive, fun, tracers, params):
  """Reject call primitives; no handling is implemented here.

  None of the arguments are inspected.

  Raises:
    NotImplementedError: always, with no message.
  """
  raise NotImplementedError()
|
||||
|
||||
def process_map(self, call_primitive, fun, tracers, params):
|
||||
def process_map(self, map_primitive, fun, tracers, params):
|
||||
if params['devices'] is not None:
|
||||
raise ValueError("Nested pmap with explicit devices argument.")
|
||||
if not config.jax_disable_jit:
|
||||
bind = HashableFunction(
|
||||
lambda *args, **kwargs: call_primitive.bind(fun, *args, **kwargs),
|
||||
(call_primitive, fun))
|
||||
lambda *args, **kwargs: map_primitive.bind(fun, *args, **kwargs),
|
||||
(map_primitive, fun))
|
||||
fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
|
||||
return self.process_primitive(fake_primitive, tracers, params)
|
||||
axis_name, in_axes, out_axes_thunk, axis_size = (params["axis_name"],
|
||||
@ -627,12 +627,11 @@ class MapTrace(core.Trace):
|
||||
for v, s, dst in zip(out, outaxes, out_axes_thunk()))
|
||||
return map(partial(MapTracer, self), out, outaxes)
|
||||
|
||||
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
  """Handle a custom-JVP call by routing it through `process_primitive`.

  The call is wrapped in a `FakePrimitive` whose `bind` re-binds `prim` with
  `fun`, `jvp` and `symbolic_zeros`, so that `process_primitive` can treat it
  like any other multiple-result primitive.

  Args:
    prim: the custom-JVP primitive being bound.
    fun: the wrapped primal function forwarded to `prim.bind`.
    jvp: the wrapped JVP rule forwarded to `prim.bind`.
    tracers: the input tracers for the call.
    symbolic_zeros: forwarded unchanged to `prim.bind`.

  Returns:
    Whatever `self.process_primitive` returns for the fake primitive.
  """
  # NOTE(review): the tuple is presumably the key HashableFunction uses for
  # hashing/equality -- confirm against its definition. `symbolic_zeros` is
  # captured by the lambda, so it is part of the key as well; otherwise two
  # binds differing only in that flag would share a key.
  bind = HashableFunction(
      lambda *args, **kwargs: prim.bind(
          fun, jvp, *args, symbolic_zeros=symbolic_zeros, **kwargs),
      (prim, fun, jvp, symbolic_zeros))
  fake_primitive = FakePrimitive(multiple_results=True, bind=bind)
  return self.process_primitive(fake_primitive, tracers, {})
|
||||
|
||||
|
@ -358,7 +358,7 @@ def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep,
|
||||
# Staging
|
||||
|
||||
def _shard_map_staging(
|
||||
trace: pe.DynamicJaxprTrace, prim: core.Primitive, fun: lu.WrappedFun,
|
||||
trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun,
|
||||
in_tracers: Sequence[pe.DynamicJaxprTracer], *, mesh: Mesh,
|
||||
in_names: Tuple[AxisNames, ...],
|
||||
out_names_thunk: Callable[[], Tuple[AxisNames, ...]],
|
||||
@ -366,9 +366,9 @@ def _shard_map_staging(
|
||||
) -> Sequence[pe.DynamicJaxprTracer]:
|
||||
in_avals = [t.aval for t in in_tracers]
|
||||
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
|
||||
main = trace.main
|
||||
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
|
||||
jaxpr, out_avals_, consts = pe.trace_to_subjaxpr_dynamic(
|
||||
fun, trace.main, in_avals_)
|
||||
jaxpr, out_avals_, consts = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_)
|
||||
_check_names(out_names_thunk(), out_avals_)
|
||||
if check_rep:
|
||||
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
|
||||
@ -510,7 +510,7 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
|
||||
args = map(partial(_unmatch_spec, mesh), in_names, args)
|
||||
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
|
||||
with core.new_base_main(ShardMapTrace, mesh=mesh, check=check_rep) as main:
|
||||
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
|
||||
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items(), main):
|
||||
t = main.with_cur_sublevel()
|
||||
in_tracers = map(partial(ShardMapTracer, t), in_rep, args)
|
||||
ans = fun.call_wrapped(*in_tracers)
|
||||
@ -595,7 +595,52 @@ class ShardMapTrace(core.Trace):
|
||||
return ShardMapTracer(self, out_rep, out_vals)
|
||||
|
||||
def process_call(self, call_primitive, fun, tracers, params):
  """Reject eager evaluation of a call primitive inside a `shard_map`.

  Only `call_primitive` is used, to name the offending primitive in the
  error message; the remaining arguments are ignored.

  Raises:
    NotImplementedError: always, with a message suggesting wrapping the
      `shard_map`-decorated function in `jax.jit`.
  """
  # A bare `raise NotImplementedError` used to precede this statement, which
  # made the descriptive message unreachable; only the informative raise stays.
  raise NotImplementedError(
      f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't "
      "yet supported. Put a `jax.jit` around the `shard_map`-decorated "
      "function, and open a feature request at "
      "https://github.com/google/jax/issues !")
|
||||
|
||||
def process_map(self, map_primitive, fun, tracers, params):
  """Reject eager evaluation of `pmap` inside a `shard_map`.

  None of the arguments are inspected.

  Raises:
    NotImplementedError: always, with a message suggesting wrapping the
      `shard_map`-decorated function in `jax.jit`.
  """
  # Adjacent string literals are concatenated verbatim, so the first piece
  # must end with a space to keep the two sentences apart.
  raise NotImplementedError(
      "Eager evaluation of `pmap` inside a `shard_map` isn't yet supported. "
      "Put a `jax.jit` around the `shard_map`-decorated function, and open "
      "a feature request at https://github.com/google/jax/issues !")
|
||||
|
||||
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
  """Reject eager evaluation of a `custom_jvp` inside a `shard_map`.

  None of the arguments are inspected.

  Raises:
    NotImplementedError: always, with a message suggesting wrapping the
      `shard_map`-decorated function in `jax.jit`.
  """
  message = (
      "Eager evaluation of a `custom_jvp` inside a `shard_map` "
      "isn't yet supported. Put a `jax.jit` around the "
      "`shard_map`-decorated function, and open a feature request at "
      "https://github.com/google/jax/issues !")
  raise NotImplementedError(message)
|
||||
|
||||
def post_process_custom_jvp_call(self, out_tracers, _):
  """Reject post-processing of a `custom_jvp` call under eager `shard_map`.

  Neither argument is inspected.

  Raises:
    NotImplementedError: always, with a message suggesting wrapping the
      `shard_map`-decorated function in `jax.jit`.
  """
  message = (
      "Eager evaluation of a `custom_jvp` inside a `shard_map` "
      "isn't yet supported. Put a `jax.jit` around the "
      "`shard_map`-decorated function, and open a feature request at "
      "https://github.com/google/jax/issues !")
  raise NotImplementedError(message)
|
||||
|
||||
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
  """Reject eager evaluation of a `custom_vjp` inside a `shard_map`.

  None of the arguments are inspected.

  Raises:
    NotImplementedError: always, with a message suggesting wrapping the
      `shard_map`-decorated function in `jax.jit`.
  """
  message = (
      "Eager evaluation of a `custom_vjp` inside a `shard_map` "
      "isn't yet supported. Put a `jax.jit` around the "
      "`shard_map`-decorated function, and open a feature request at "
      "https://github.com/google/jax/issues !")
  raise NotImplementedError(message)
|
||||
|
||||
def post_process_custom_vjp_call(self, out_tracers, _):
  """Reject post-processing of a `custom_vjp` call under eager `shard_map`.

  Neither argument is inspected.

  Raises:
    NotImplementedError: always, with a message suggesting wrapping the
      `shard_map`-decorated function in `jax.jit`.
  """
  message = (
      "Eager evaluation of a `custom_vjp` inside a `shard_map` "
      "isn't yet supported. Put a `jax.jit` around the "
      "`shard_map`-decorated function, and open a feature request at "
      "https://github.com/google/jax/issues !")
  raise NotImplementedError(message)
|
||||
|
||||
def process_axis_index(self, frame):
  """Reject eager evaluation of an `axis_index` inside a `shard_map`.

  `frame` is not inspected.

  Raises:
    NotImplementedError: always, with a message suggesting wrapping the
      `shard_map`-decorated function in `jax.jit`.
  """
  message = (
      "Eager evaluation of an `axis_index` inside a `shard_map` "
      "isn't yet supported. Put a `jax.jit` around the "
      "`shard_map`-decorated function, and open a feature request at "
      "https://github.com/google/jax/issues !")
  raise NotImplementedError(message)
|
||||
|
||||
|
||||
class ShardMapTracer(core.Tracer):
|
||||
|
@ -639,6 +639,51 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))),
|
||||
out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x)
|
||||
|
||||
def test_eager_notimplemented_error_message_custom_jvp(self):
  """Eagerly calling a shard_map'd `custom_jvp` function must raise a
  NotImplementedError whose message mentions `custom_jvp`."""
  @jax.custom_jvp
  def doubled(x):
    return 2. * x

  @doubled.defjvp
  def doubled_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    return doubled(x), 2. * x_dot

  mesh = jtu.create_global_mesh((4,), ('x',))
  sharded = shard_map(doubled, mesh, in_specs=(P('x'),), out_specs=P('x'))
  operand = jnp.arange(4.)
  # No jax.jit wrapper here, so the call runs eagerly.
  with self.assertRaisesRegex(NotImplementedError, 'custom_jvp'):
    sharded(operand)
|
||||
|
||||
def test_eager_notimplemented_error_message_custom_vjp(self):
  """Eagerly calling a shard_map'd `custom_vjp` function must raise a
  NotImplementedError whose message mentions `custom_vjp`."""
  @jax.custom_vjp
  def doubled(x):
    return 2. * x

  def doubled_fwd(x):
    return x, None

  def doubled_bwd(_, y_bar):
    return (2. * y_bar,)

  doubled.defvjp(doubled_fwd, doubled_bwd)

  mesh = jtu.create_global_mesh((4,), ('x',))
  sharded = shard_map(doubled, mesh, in_specs=(P('x'),), out_specs=P('x'))
  operand = jnp.arange(4.)
  # No jax.jit wrapper here, so the call runs eagerly.
  with self.assertRaisesRegex(NotImplementedError, 'custom_vjp'):
    sharded(operand)
|
||||
|
||||
def test_eager_notimplemented_error_message_axis_index(self):
  """Eagerly calling a shard_map'd function that uses `axis_index` must raise
  a NotImplementedError whose message mentions `axis_index`."""
  def add_axis_index(x):
    return x + jax.lax.axis_index('x')

  mesh = jtu.create_global_mesh((4,), ('x',))
  sharded = shard_map(
      add_axis_index, mesh, in_specs=(P('x'),), out_specs=P('x'))
  operand = jnp.arange(4.)
  # No jax.jit wrapper here, so the call runs eagerly.
  with self.assertRaisesRegex(NotImplementedError, 'axis_index'):
    sharded(operand)
|
||||
|
||||
|
||||
class FunSpec(NamedTuple):
|
||||
name: str
|
||||
|
Loading…
x
Reference in New Issue
Block a user