mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Make the vmap(jit) or vmap(wsc) with a concrete layout error more informative
PiperOrigin-RevId: 656176702
This commit is contained in:
parent
6f68887e0d
commit
2eb1888c98
@ -201,7 +201,6 @@ Parallel operators
|
||||
|
||||
all_gather
|
||||
all_to_all
|
||||
pdot
|
||||
psum
|
||||
psum_scatter
|
||||
pmax
|
||||
|
@ -1967,7 +1967,8 @@ def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type,
|
||||
# TODO(yashkatariya): Figure out layouts should change under vmap.
|
||||
if not (all(l is None for l in in_layouts) and
|
||||
all(l is None for l in out_layouts)):
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError(
|
||||
'Concrete layouts are not supported for vmap(jit).')
|
||||
|
||||
vals_out = pjit_p.bind(
|
||||
*vals_in,
|
||||
@ -2539,7 +2540,9 @@ def _sharding_constraint_batcher(
|
||||
|
||||
# TODO(yashkatariya): Figure out layouts should change under vmap.
|
||||
if layout is not None:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError(
|
||||
'Concrete layout is not supported for vmap(with_sharding_constraint). '
|
||||
f'Got layout {layout}')
|
||||
|
||||
y = sharding_constraint_p.bind(
|
||||
x,
|
||||
|
@ -1535,7 +1535,6 @@ tf_not_yet_impl = [
|
||||
"pgather",
|
||||
"reduce_scatter",
|
||||
"axis_index",
|
||||
"pdot",
|
||||
"all_gather",
|
||||
"lu_pivots_to_permutation",
|
||||
"xla_pmap",
|
||||
|
Loading…
x
Reference in New Issue
Block a user