Merge pull request #19591 from jakevdp:key-reuse-slice

PiperOrigin-RevId: 602868125
This commit is contained in:
jax authors 2024-01-30 16:07:43 -08:00
commit af2292aa4e
2 changed files with 12 additions and 14 deletions

View File

@ -193,17 +193,16 @@ def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> KeyReuseSignature
def _slice_signature(eqn, args_consumed):
del args_consumed # unused here
in_aval = eqn.invars[0].aval
if not jax.dtypes.issubdtype(in_aval.dtype, jax.dtypes.prng_key):
return KeyReuseSignatureWithForwards([], [], [Forward(0, 0)])
if any(core.is_symbolic_dim(s) for s in in_aval.shape):
return KeyReuseSignatureWithForwards([], [], [Forward(0, 0)])
start_indices = eqn.params['start_indices']
limit_indices = eqn.params['limit_indices']
strides = eqn.params['strides'] or (1,) * len(start_indices)
idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides))
if any(core.is_symbolic_dim(s) for s in in_aval.shape):
sink = True
else:
# TODO(jakevdp): should we avoid constructing the mask array if the input
# does not have a key dtype?
sink = np.zeros(in_aval.shape, dtype=bool)
sink[idx] = True
sink = np.zeros(in_aval.shape, dtype=bool)
sink[idx] = True
return KeyReuseSignatureWithForwards([Sink(0, sink)], [Source(0)])
key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature

View File

@ -164,17 +164,16 @@ def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> KeyReuseSignature
def _slice_signature(eqn, args_consumed):
del args_consumed # unused here
in_aval = eqn.invars[0].aval
if not jax.dtypes.issubdtype(in_aval.dtype, jax.dtypes.prng_key):
return KeyReuseSignature([Sink(0)], [Source(0)])
if any(core.is_symbolic_dim(s) for s in in_aval.shape):
return KeyReuseSignature([Sink(0)], [Source(0)])
start_indices = eqn.params['start_indices']
limit_indices = eqn.params['limit_indices']
strides = eqn.params['strides'] or (1,) * len(start_indices)
idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides))
if any(core.is_symbolic_dim(s) for s in in_aval.shape):
sink = True
else:
# TODO(jakevdp): should we avoid constructing the mask array if the input
# does not have a key dtype?
sink = np.zeros(in_aval.shape, dtype=bool)
sink[idx] = True
sink = np.zeros(in_aval.shape, dtype=bool)
sink[idx] = True
return KeyReuseSignature([Sink(0, sink)], [Source(0)])
key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature