Merge pull request #14788 from patrick-kidger:patch-1

PiperOrigin-RevId: 514437372
This commit is contained in:
jax authors 2023-03-06 09:40:19 -08:00
commit 0d24d79453

View File

@ -815,7 +815,7 @@ def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False)
### utilities for defining primitives' batching rules
BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]]
BatchingRule = Callable[..., Tuple[Any, Union[None, int, Tuple[Union[None, int], ...]]]]
primitive_batchers : Dict[core.Primitive, BatchingRule] = {}
axis_primitive_batchers: Dict[core.Primitive, Callable] = {}
spmd_axis_primitive_batchers: Dict[core.Primitive, Callable] = {}