Merge pull request #8217 from LenaMartens:changelist/403115357

PiperOrigin-RevId: 403402974
This commit is contained in:
jax authors 2021-10-15 10:10:06 -07:00
commit 8f1d7beace
2 changed files with 14 additions and 21 deletions

View File

@ -1979,6 +1979,19 @@ def _memcpy(axis, num, src, dst, offset):
# Register `_concat_masking_rule` as the masking rule for `lax.concatenate_p`.
masking.masking_rules[lax.concatenate_p] = _concat_masking_rule # type: ignore
def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm):
  """Batching rule for ``lax.rng_bit_generator_p``.

  Runs the generator once per key along the batched axis and stacks the
  per-key outputs.

  Args:
    batched_args: one-element sequence holding the (possibly batched) key.
    batch_dims: one-element sequence holding the key's batch axis, or
      ``batching.not_mapped`` when the key is not batched.
    shape: forwarded unchanged to every bind of the primitive.
    dtype: forwarded unchanged to every bind of the primitive.
    algorithm: forwarded unchanged to every bind of the primitive.

  Returns:
    A ``(outputs, out_batch_dims)`` pair. For an unbatched key this is the
    primitive's own result paired with ``(None, None)``; otherwise it is
    ``(stacked_keys, stacked_bits)`` paired with ``(0, 0)``.
  """
  (key,) = batched_args
  (bd,) = batch_dims

  def generate(single_key):
    # One invocation of the primitive for a single key.
    return lax.rng_bit_generator_p.bind(
        single_key, shape=shape, dtype=dtype, algorithm=algorithm)

  # Nothing is batched: defer to the primitive and report unbatched outputs.
  if bd is batching.not_mapped:
    return generate(key), (None, None)

  # Bring the batch axis to the front so the mapping runs over axis 0.
  keys_leading = batching.moveaxis(key, bd, 0)
  # NOTE(review): `map` is assumed to be the mapping helper defined in this
  # module rather than the builtin (whose result could only be unpacked into
  # two names when there are exactly two keys) -- confirm.
  stacked_keys, stacked_bits = map(generate, keys_leading)
  return (stacked_keys, stacked_bits), (0, 0)
# Install `_rng_bit_generator_batching_rule` as the batching rule used for
# `lax.rng_bit_generator_p`.
batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule
def _check_tree_and_avals(what, tree1, avals1, tree2, avals2):
"""Raises TypeError if (tree1, avals1) does not match (tree2, avals2).

View File

@ -6780,7 +6780,7 @@ def _rng_bit_generator_translation_rule(backend_is_gpu, c, key, *, shape, dtype,
# need to convert u32[4] -> u64[2] here in the translation rule. However, we
# also polymorphically allow a u64[2] for backward compatibility.
assert ((key_shape == (4,) and key_dtype == dtypes.dtype('uint32')) or
(key_shape == (2,) and key_dtype == dtypes.dtype('uint64')))
(key_shape == (2,) and key_dtype == dtypes.dtype('uint64'))), (key_shape, key_dtype)
xla_shape = xc.Shape.array_shape(np.dtype(dtype), shape)
if key_dtype == dtypes.dtype('uint32'):
# TODO(mattjj): the BitcastConvertType segfaults on GPU
@ -6827,25 +6827,6 @@ def _convert_2xU64_to_4xU32_without_bitcast(c, key):
def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
  """Return a two-element list whose entries are both ``key.named_shape``.

  ``shape``, ``dtype`` and ``algorithm`` are accepted for signature
  compatibility but are not used.
  """
  # Read the attribute once per entry, exactly as many times as entries.
  return [key.named_shape for _ in range(2)]
def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, algorithm):
  """Batching rule for ``rng_bit_generator_p`` using a Python-level loop.

  Binds the primitive once per key along the batched axis, gives every
  per-key output a leading axis of size 1, and concatenates the pieces
  along that axis.

  Args:
    batched_args: one-element sequence holding the (possibly batched) key.
    batch_dims: one-element sequence holding the key's batch axis, or
      ``batching.not_mapped`` when the key is not batched.
    shape: forwarded unchanged to every bind of the primitive.
    dtype: forwarded unchanged to every bind of the primitive.
    algorithm: forwarded unchanged to every bind of the primitive.

  Returns:
    A ``(outputs, out_batch_dims)`` pair. For an unbatched key this is the
    primitive's own result paired with ``(None, None)``; otherwise it is
    ``(stacked_keys, stacked_bits)`` paired with ``(0, 0)``.
  """
  (key,) = batched_args
  (bd,) = batch_dims
  params = dict(shape=shape, dtype=dtype, algorithm=algorithm)

  # Nothing is batched: defer to the primitive and report unbatched outputs.
  if bd is batching.not_mapped:
    return rng_bit_generator_p.bind(key, **params), (None, None)

  def with_leading_unit_axis(x):
    # Prepend an axis of size 1 so the pieces can be concatenated on axis 0.
    return reshape(x, (1,) + x.shape)

  key_pieces, bit_pieces = [], []
  # Iterate over the batch axis after moving it to the front.
  for single_key in batching.moveaxis(key, bd, 0):
    new_key, new_bits = rng_bit_generator_p.bind(single_key, **params)
    key_pieces.append(with_leading_unit_axis(new_key))
    bit_pieces.append(with_leading_unit_axis(new_bits))

  stacked_keys = concatenate(key_pieces, 0)
  stacked_bits = concatenate(bit_pieces, 0)
  return (stacked_keys, stacked_bits), (0, 0)
# Create the "rng_bit_generator" primitive and flag it as producing multiple
# results.
rng_bit_generator_p = Primitive("rng_bit_generator")
rng_bit_generator_p.multiple_results = True
rng_bit_generator_p.def_impl(
@ -6855,7 +6836,6 @@ rng_bit_generator_p.def_abstract_eval(
_rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
_rng_bit_generator_weak_type_rule,
_rng_bit_generator_named_shape_rule))
# Install `_rng_bit_generator_batching_rule` as the batching rule for the
# primitive.
batching.primitive_batchers[rng_bit_generator_p] = _rng_bit_generator_batching_rule
# Default XLA translation: the translation rule with its first positional
# argument (`backend_is_gpu`) bound to False.
xla.translations[rng_bit_generator_p] = \
partial(_rng_bit_generator_translation_rule, False)
xla.backend_specific_translations['gpu'][rng_bit_generator_p] = \