mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #19591 from jakevdp:key-reuse-slice
PiperOrigin-RevId: 602868125
This commit is contained in:
commit
af2292aa4e
@ -193,17 +193,16 @@ def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> KeyReuseSignature
|
||||
def _slice_signature(eqn, args_consumed):
|
||||
del args_consumed # unused here
|
||||
in_aval = eqn.invars[0].aval
|
||||
if not jax.dtypes.issubdtype(in_aval.dtype, jax.dtypes.prng_key):
|
||||
return KeyReuseSignatureWithForwards([], [], [Forward(0, 0)])
|
||||
if any(core.is_symbolic_dim(s) for s in in_aval.shape):
|
||||
return KeyReuseSignatureWithForwards([], [], [Forward(0, 0)])
|
||||
start_indices = eqn.params['start_indices']
|
||||
limit_indices = eqn.params['limit_indices']
|
||||
strides = eqn.params['strides'] or (1,) * len(start_indices)
|
||||
idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides))
|
||||
if any(core.is_symbolic_dim(s) for s in in_aval.shape):
|
||||
sink = True
|
||||
else:
|
||||
# TODO(jakevdp): should we avoid constructing the mask array if the input
|
||||
# does not have a key dtype?
|
||||
sink = np.zeros(in_aval.shape, dtype=bool)
|
||||
sink[idx] = True
|
||||
sink = np.zeros(in_aval.shape, dtype=bool)
|
||||
sink[idx] = True
|
||||
return KeyReuseSignatureWithForwards([Sink(0, sink)], [Source(0)])
|
||||
|
||||
key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature
|
||||
|
@ -164,17 +164,16 @@ def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> KeyReuseSignature
|
||||
def _slice_signature(eqn, args_consumed):
|
||||
del args_consumed # unused here
|
||||
in_aval = eqn.invars[0].aval
|
||||
if not jax.dtypes.issubdtype(in_aval.dtype, jax.dtypes.prng_key):
|
||||
return KeyReuseSignature([Sink(0)], [Source(0)])
|
||||
if any(core.is_symbolic_dim(s) for s in in_aval.shape):
|
||||
return KeyReuseSignature([Sink(0)], [Source(0)])
|
||||
start_indices = eqn.params['start_indices']
|
||||
limit_indices = eqn.params['limit_indices']
|
||||
strides = eqn.params['strides'] or (1,) * len(start_indices)
|
||||
idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides))
|
||||
if any(core.is_symbolic_dim(s) for s in in_aval.shape):
|
||||
sink = True
|
||||
else:
|
||||
# TODO(jakevdp): should we avoid constructing the mask array if the input
|
||||
# does not have a key dtype?
|
||||
sink = np.zeros(in_aval.shape, dtype=bool)
|
||||
sink[idx] = True
|
||||
sink = np.zeros(in_aval.shape, dtype=bool)
|
||||
sink[idx] = True
|
||||
return KeyReuseSignature([Sink(0, sink)], [Source(0)])
|
||||
|
||||
key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature
|
||||
|
Loading…
x
Reference in New Issue
Block a user