Make the vmap(jit) or vmap(wsc) with a concrete layout error more informative

PiperOrigin-RevId: 656176702
This commit is contained in:
Yash Katariya 2024-07-25 18:31:50 -07:00 committed by jax authors
parent 6f68887e0d
commit 2eb1888c98
3 changed files with 5 additions and 4 deletions

View File

@ -201,7 +201,6 @@ Parallel operators
all_gather
all_to_all
pdot
psum
psum_scatter
pmax

View File

@ -1967,7 +1967,8 @@ def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type,
# TODO(yashkatariya): Figure out layouts should change under vmap.
if not (all(l is None for l in in_layouts) and
all(l is None for l in out_layouts)):
raise NotImplementedError
raise NotImplementedError(
'Concrete layouts are not supported for vmap(jit).')
vals_out = pjit_p.bind(
*vals_in,
@ -2539,7 +2540,9 @@ def _sharding_constraint_batcher(
# TODO(yashkatariya): Figure out layouts should change under vmap.
if layout is not None:
raise NotImplementedError
raise NotImplementedError(
'Concrete layout is not supported for vmap(with_sharding_constraint). '
f'Got layout {layout}')
y = sharding_constraint_p.bind(
x,

View File

@ -1535,7 +1535,6 @@ tf_not_yet_impl = [
"pgather",
"reduce_scatter",
"axis_index",
"pdot",
"all_gather",
"lu_pivots_to_permutation",
"xla_pmap",