shmap in_spec None shouldn't require hashability

Co-authored-by: Roy Frostig <frostig@google.com>
This commit is contained in:
Matthew Johnson 2024-09-12 23:03:06 +00:00
parent 255c30303d
commit 358f00d5e0
2 changed files with 14 additions and 1 deletions

View File

@ -166,7 +166,7 @@ def _shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs,
raise e('shard_map in_specs') from None
dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat)
if s is not None)
fun, args_flat = argnums_partial(fun, dyn_argnums, args_flat)
fun, args_flat = argnums_partial(fun, dyn_argnums, args_flat, False)
_check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, args_flat)
in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat))

View File

@ -2164,6 +2164,19 @@ class ShardMapTest(jtu.JaxTestCase):
with config.disable_vmap_shmap_error():
_ = jax.vmap(f, in_axes=(0, None), spmd_axis_name='i')(xs, y)
def test_in_spec_none_hashability(self):
mesh = jtu.create_mesh((2,), ('i',))
class A:
def __hash__(self):
raise Exception
@partial(shard_map, mesh=mesh, in_specs=(None,), out_specs=())
def f(a):
return ()
f(A()) # don't crash
class FunSpec(NamedTuple):
name: str