mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #8217 from LenaMartens:changelist/403115357
PiperOrigin-RevId: 403402974
This commit is contained in:
commit
8f1d7beace
@ -1979,6 +1979,19 @@ def _memcpy(axis, num, src, dst, offset):
|
||||
|
||||
# Register the masking rule for the concatenate primitive.
masking.masking_rules[lax.concatenate_p] = _concat_masking_rule  # type: ignore
|
||||
|
||||
def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm):
  """Batching rule for rng_bit_generator: applies RBG per key and stacks.

  An unbatched key is forwarded directly to the primitive with no output
  batch dimensions. A batched key is first moved so the batch dimension is
  leading, then RBG is applied to each key via `map`; both outputs (the
  updated keys and the generated bits) carry batch dimension 0.
  """
  key, = batched_args
  bd, = batch_dims
  if bd is batching.not_mapped:
    result = lax.rng_bit_generator_p.bind(key, shape=shape, dtype=dtype,
                                          algorithm=algorithm)
    return result, (None, None)
  key = batching.moveaxis(key, bd, 0)

  def call_rbg(k):
    return lax.rng_bit_generator_p.bind(k, shape=shape, dtype=dtype,
                                        algorithm=algorithm)

  new_keys, bits = map(call_rbg, key)
  return (new_keys, bits), (0, 0)
|
||||
|
||||
# Register the batching (vmap) rule for the rng_bit_generator primitive.
batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule
|
||||
|
||||
def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
|
||||
"""Raises TypeError if (tree1, avals1) does not match (tree2, avals2).
|
||||
|
@ -6780,7 +6780,7 @@ def _rng_bit_generator_translation_rule(backend_is_gpu, c, key, *, shape, dtype,
|
||||
# need to convert u32[4] -> u64[2] here in the translation rule. However, we
|
||||
# also polymorphically allow a u64[2] for backward compatibility.
|
||||
assert ((key_shape == (4,) and key_dtype == dtypes.dtype('uint32')) or
|
||||
(key_shape == (2,) and key_dtype == dtypes.dtype('uint64')))
|
||||
(key_shape == (2,) and key_dtype == dtypes.dtype('uint64'))), (key_shape, key_dtype)
|
||||
xla_shape = xc.Shape.array_shape(np.dtype(dtype), shape)
|
||||
if key_dtype == dtypes.dtype('uint32'):
|
||||
# TODO(mattjj): the BitcastConvertType segfaults on GPU
|
||||
@ -6827,25 +6827,6 @@ def _convert_2xU64_to_4xU32_without_bitcast(c, key):
|
||||
def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
|
||||
return [key.named_shape, key.named_shape]
|
||||
|
||||
def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm):
  """Batching rule for rng_bit_generator: runs RBG per key and stacks results.

  An unbatched key is bound directly with no output batch dimensions. A
  batched key has its batch dimension moved to the front; RBG is then run
  once per key, and each per-key output is given a leading unit axis and
  concatenated along axis 0, so both outputs are batched at dimension 0.
  """
  key, = batched_args
  bd, = batch_dims
  if bd is batching.not_mapped:
    result = rng_bit_generator_p.bind(key, shape=shape, dtype=dtype,
                                      algorithm=algorithm)
    return result, (None, None)
  key = batching.moveaxis(key, bd, 0)
  per_key = [rng_bit_generator_p.bind(k, shape=shape, dtype=dtype,
                                      algorithm=algorithm)
             for k in key]
  stacked_keys = concatenate(
      [reshape(new_key, (1,) + new_key.shape) for new_key, _ in per_key], 0)
  stacked_bits = concatenate(
      [reshape(bits, (1,) + bits.shape) for _, bits in per_key], 0)
  return (stacked_keys, stacked_bits), (0, 0)
|
||||
|
||||
rng_bit_generator_p = Primitive("rng_bit_generator")
|
||||
rng_bit_generator_p.multiple_results = True
|
||||
rng_bit_generator_p.def_impl(
|
||||
@ -6855,7 +6836,6 @@ rng_bit_generator_p.def_abstract_eval(
|
||||
_rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
|
||||
_rng_bit_generator_weak_type_rule,
|
||||
_rng_bit_generator_named_shape_rule))
|
||||
batching.primitive_batchers[rng_bit_generator_p] = _rng_bit_generator_batching_rule
|
||||
xla.translations[rng_bit_generator_p] = \
|
||||
partial(_rng_bit_generator_translation_rule, False)
|
||||
xla.backend_specific_translations['gpu'][rng_bit_generator_p] = \
|
||||
|
Loading…
x
Reference in New Issue
Block a user