[jax2tf] Use the JAX impl rule for threefry instead of writing our own (#4204)
* performance is the same
commit c6e6ee2dcb (parent 96278e67a2)
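The diff below drops jax2tf's hand-written TensorFlow port of Threefry-2x32 and instead reuses JAX's own implementation rule (random._threefry2x32_lowering), converted to TF through the existing _convert_jax_impl helper. The same reuse-the-JAX-impl idea is what the public jax2tf.convert entry point applies to whole functions; a minimal sketch, assuming jax and tensorflow are installed and using the raw uint32[2] PRNG keys of this era:

import numpy as np
import jax
from jax.experimental import jax2tf

# Convert a jitted JAX function (which lowers threefry2x32 internally) to TF.
f_jax = jax.jit(lambda key: jax.random.split(key, 2))
f_tf = jax2tf.convert(f_jax)

key = jax.random.PRNGKey(42)   # uint32[2] key
print(np.asarray(f_tf(key)))   # same uint32[2, 2] result as f_jax(key)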
@@ -1141,68 +1141,18 @@ def _select_and_scatter_add(
                          scatter_fn)
 
 tf_impl[lax.select_and_scatter_add_p] = _select_and_scatter_add
 
+def _threefry2x32_jax_impl(*args: TfValOrUnit):
+  # We use the random._threefry2x32_lowering, but since add is not implemented
+  # for uint32, we cast to int32 and back.
+  args = tuple([tf.cast(a, tf.int32) for a in args])
+  res = _convert_jax_impl(
+    functools.partial(random._threefry2x32_lowering,
+                      use_rolled_loops=False),
+    multiple_results=True)(*args)
+  res = tuple([tf.cast(r, tf.uint32) for r in res])
+  return res
+tf_impl[jax.random.threefry2x32_p] = _threefry2x32_jax_impl
 
-def uadd(a, *b):
-  """Workaround to support + with uint32 (not supported in TF)."""
-  # Note: Tensorflow's add_n doesn't support broadcasting.
-  b = [tf.broadcast_to(b, tf.shape(a)) for b in b]
-  return tf.add_n([a] + b)
-
-# TODO(necula): do not repeat the definition of threefry here. Note that on
-# CPU we don't have a direct definition of the primitive; we expand it
-# using xla.lower_fun. Could we do something similar here rather than
-# repeating its definition?
-def _threefry2x32(key1, key2, x1, x2):
-  """Tensorflow implementation of the jax PRNG."""
-
-  def rotate_left(x, d):
-    """Rotate left."""
-    return tf.bitwise.bitwise_or(
-        tf.bitwise.left_shift(x, np.uint32(d)),
-        tf.bitwise.right_shift(x, np.uint32(32 - d)))
-
-  def apply_round(v1, v2, rot):
-    v1 = uadd(v1, v2)
-    v2 = rotate_left(v2, rot)
-    v2 = tf.bitwise.bitwise_xor(v1, v2)
-    return v1, v2
-
-  rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
-  magic_number = tf.constant(np.uint32(0x1BD11BDA), dtype=tf.uint32)
-
-  key3 = tf.bitwise.bitwise_xor(key1,
-                                tf.bitwise.bitwise_xor(key2, magic_number))
-
-  x1 = uadd(x1, key1)
-  x2 = uadd(x2, key2)
-
-  for r in rotations[0]:
-    x1, x2 = apply_round(x1, x2, r)
-  x1 = uadd(x1, key2)
-  x2 = uadd(x2, key3, np.uint32(1))
-
-  for r in rotations[1]:
-    x1, x2 = apply_round(x1, x2, r)
-  x1 = uadd(x1, key3)
-  x2 = uadd(x2, key1, np.uint32(2))
-
-  for r in rotations[0]:
-    x1, x2 = apply_round(x1, x2, r)
-  x1 = uadd(x1, key1)
-  x2 = uadd(x2, key2, np.uint32(3))
-
-  for r in rotations[1]:
-    x1, x2 = apply_round(x1, x2, r)
-  x1 = uadd(x1, key2)
-  x2 = uadd(x2, key3, np.uint32(4))
-
-  for r in rotations[0]:
-    x1, x2 = apply_round(x1, x2, r)
-  x1 = uadd(x1, key3)
-  x2 = uadd(x2, key1, np.uint32(5))
-
-  return x1, x2
-
-tf_impl[jax.random.threefry2x32_p] = _threefry2x32
 
 # Use the vmap implementation, otherwise on TPU the performance is really bad
 # With use_vmap=True on, we get about the same performance for JAX and jax2tf.
@@ -23,6 +23,7 @@ from typing import Any, Callable, Dict, Iterable, Optional, NamedTuple, Sequence
 from functools import partial
 
 from absl import testing
+import jax
 from jax import config
 from jax import dtypes
 from jax import test_util as jtu
@@ -843,3 +844,22 @@ lax_reduce_window = tuple(
   for base_dilation in [(2, 1, 3, 2)]
   for window_dilation in [(1, 2, 2, 1)]
 )
+
+random_gamma = tuple(
+  Harness(f"_shape={jtu.format_shape_dtype_string(shape, dtype)}",
+          jax.jit(jax.random.gamma),
+          [jax.random.PRNGKey(42), RandArg(shape, dtype)])
+  for shape in ((), (3,))
+  for dtype in (np.float32, np.float64)
+)
+
+random_split = tuple(
+  Harness(f"_i={key_i}",
+          jax.jit(lambda key: jax.random.split(key, 2)),
+          [key])
+  for key_i, key in enumerate([jax.random.PRNGKey(42),
+                               np.array([0, 0], dtype=np.uint32),
+                               np.array([0xFFFFFFFF, 0], dtype=np.uint32),
+                               np.array([0, 0xFFFFFFFF], dtype=np.uint32),
+                               np.array([0xFFFFFFFF, 0xFFFFFFFF], dtype=np.uint32)])
+)
@@ -612,23 +612,15 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
     values = np.array([True, False, True], dtype=np.bool_)
     self.ConvertAndCompare(f_jax, values)
 
-  def test_random_gamma(self):
-    f_jax = jax.jit(jax.random.gamma)
-    for alpha in [1.0,
-                  np.array([1.0, 0.2, 1.2], np.float32),
-                  np.array([1.0, 0.2, 1.2], np.float64)]:
-      for rng_key in [jax.random.PRNGKey(42)]:
-        self.ConvertAndCompare(f_jax, rng_key, alpha)
+  @primitive_harness.parameterized(primitive_harness.random_gamma)
+  def test_random_gamma(self, harness: primitive_harness.Harness):
+    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
+                           rtol=1e-5)
 
-  def test_prngsplit(self):
-    f_jax = jax.jit(lambda key: jax.random.split(key, 2))
-    for rng_key in [jax.random.PRNGKey(42),
-                    np.array([0, 0], dtype=np.uint32),
-                    np.array([0xFFFFFFFF, 0], dtype=np.uint32),
-                    np.array([0, 0xFFFFFFFF], dtype=np.uint32),
-                    np.array([0xFFFFFFFF, 0xFFFFFFFF], dtype=np.uint32)
-                   ]:
-      self.ConvertAndCompare(f_jax, rng_key)
+  @primitive_harness.parameterized(primitive_harness.random_split,
+                                   one_containing="i=0")
+  def test_random_split(self, harness: primitive_harness.Harness):
+    self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
 
   def test_zeros_like(self):
     v = np.float32(2.)
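The one_containing="i=0" argument appears to narrow the parameterized expansion of random_split to the single harness whose name contains "i=0", i.e. the PRNGKey(42) case. Written out as a plain filter (a sketch, assuming the Harness name suffix is exposed as .name):

selected = [h for h in random_split if "i=0" in h.name]
assert len(selected) == 1   # only the key_i == 0 harness matches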