mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
shmap in_spec None shouldn't require hashability
Co-authored-by: Roy Frostig <frostig@google.com>
This commit is contained in:
parent
255c30303d
commit
358f00d5e0
@ -166,7 +166,7 @@ def _shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs,
|
||||
raise e('shard_map in_specs') from None
|
||||
dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat)
|
||||
if s is not None)
|
||||
fun, args_flat = argnums_partial(fun, dyn_argnums, args_flat)
|
||||
fun, args_flat = argnums_partial(fun, dyn_argnums, args_flat, False)
|
||||
_check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, args_flat)
|
||||
in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat))
|
||||
|
||||
|
@ -2164,6 +2164,19 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
with config.disable_vmap_shmap_error():
|
||||
_ = jax.vmap(f, in_axes=(0, None), spmd_axis_name='i')(xs, y)
|
||||
|
||||
def test_in_spec_none_hashability(self):
|
||||
mesh = jtu.create_mesh((2,), ('i',))
|
||||
|
||||
class A:
|
||||
def __hash__(self):
|
||||
raise Exception
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=(None,), out_specs=())
|
||||
def f(a):
|
||||
return ()
|
||||
|
||||
f(A()) # don't crash
|
||||
|
||||
|
||||
class FunSpec(NamedTuple):
|
||||
name: str
|
||||
|
Loading…
x
Reference in New Issue
Block a user