mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #14788 from patrick-kidger:patch-1
PiperOrigin-RevId: 514437372
This commit is contained in:
commit
0d24d79453
@ -815,7 +815,7 @@ def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False)
|
||||
|
||||
### utilities for defining primitives' batching rules
|
||||
|
||||
BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]]
|
||||
BatchingRule = Callable[..., Tuple[Any, Union[None, int, Tuple[Union[None, int], ...]]]]
|
||||
primitive_batchers : Dict[core.Primitive, BatchingRule] = {}
|
||||
axis_primitive_batchers: Dict[core.Primitive, Callable] = {}
|
||||
spmd_axis_primitive_batchers: Dict[core.Primitive, Callable] = {}
|
||||
|
Loading…
x
Reference in New Issue
Block a user