[jax2tf] Use the JAX impl rule for threefry instead of writing our own (#4204)

* performance is the same
This commit is contained in:
George Necula 2020-09-07 11:26:52 +03:00 committed by GitHub
parent 96278e67a2
commit c6e6ee2dcb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 77 deletions

View File

@ -1141,68 +1141,18 @@ def _select_and_scatter_add(
scatter_fn)
tf_impl[lax.select_and_scatter_add_p] = _select_and_scatter_add
def _threefry2x32_jax_impl(*args: TfValOrUnit):
  """Implements threefry2x32 by converting JAX's own lowering rule.

  The operands are cast to int32 before the converted
  `random._threefry2x32_lowering` is applied, and every result is cast back
  to uint32 afterwards, because add is not implemented for uint32.

  Args:
    *args: the operands of the threefry2x32 primitive.

  Returns:
    A tuple with the results of the lowering, each cast to uint32.
  """
  signed_args = tuple(tf.cast(arg, tf.int32) for arg in args)
  # Reuse the JAX lowering (with unrolled loops) instead of a hand-written
  # TensorFlow version of the algorithm.
  jax_lowering = functools.partial(random._threefry2x32_lowering,
                                   use_rolled_loops=False)
  converted = _convert_jax_impl(jax_lowering, multiple_results=True)
  signed_results = converted(*signed_args)
  return tuple(tf.cast(result, tf.uint32) for result in signed_results)


tf_impl[jax.random.threefry2x32_p] = _threefry2x32_jax_impl
def uadd(a, *b):
  """Workaround to support + with uint32 (not supported in TF).

  Args:
    a: the first operand; its shape is the shape every other operand is
      broadcast to.
    *b: any number of additional operands.

  Returns:
    The result of `tf.add_n` over `a` and the broadcast operands.
  """
  # Note: Tensorflow's add_n doesn't support broadcasting, so each extra
  # operand is explicitly broadcast to the shape of `a` first.
  terms = [a]
  for extra in b:
    terms.append(tf.broadcast_to(extra, tf.shape(a)))
  return tf.add_n(terms)
# TODO(necula): do not repeat the definition of threefry here. Note that on
# CPU we don't have a direct definition of the primitive; we expand it
# using xla.lower_fun. Could we do something similar here rather than
# repeating its definition?
def _threefry2x32(key1, key2, x1, x2):
  """Tensorflow implementation of the jax PRNG.

  The key words are first added into (x1, x2). Then five groups of four
  mixing rounds are applied; after each group another pair of key words,
  plus the 1-based group index, is added in.

  Args:
    key1: first key word.
    key2: second key word.
    x1: first data word.
    x2: second data word.

  Returns:
    The pair (x1, x2) after all rounds.
  """
  def rotate_left(x, d):
    """Rotate left."""
    high_part = tf.bitwise.left_shift(x, np.uint32(d))
    low_part = tf.bitwise.right_shift(x, np.uint32(32 - d))
    return tf.bitwise.bitwise_or(high_part, low_part)

  def apply_round(v1, v2, rot):
    # One mixing round: add, rotate the second word, xor with the new sum.
    v1 = uadd(v1, v2)
    return v1, tf.bitwise.bitwise_xor(v1, rotate_left(v2, rot))

  rotations = [[13, 15, 26, 6], [17, 29, 16, 24]]
  magic_number = tf.constant(np.uint32(0x1BD11BDA), dtype=tf.uint32)
  # Third key word: key1 ^ key2 ^ magic_number.
  key3 = tf.bitwise.bitwise_xor(key1,
                                tf.bitwise.bitwise_xor(key2, magic_number))
  key_words = (key1, key2, key3)

  x1 = uadd(x1, key1)
  x2 = uadd(x2, key2)

  # The rotation sets alternate between groups, and the injected key words
  # cycle through (key2, key3), (key3, key1), (key1, key2), ...
  for group in range(1, 6):
    for rot in rotations[(group - 1) % 2]:
      x1, x2 = apply_round(x1, x2, rot)
    x1 = uadd(x1, key_words[group % 3])
    x2 = uadd(x2, key_words[(group + 1) % 3], np.uint32(group))

  return x1, x2


tf_impl[jax.random.threefry2x32_p] = _threefry2x32
# Use the vmap implementation; otherwise the performance on TPU is really bad.
# With use_vmap=True, we get about the same performance for JAX and jax2tf.

View File

@ -23,6 +23,7 @@ from typing import Any, Callable, Dict, Iterable, Optional, NamedTuple, Sequence
from functools import partial
from absl import testing
import jax
from jax import config
from jax import dtypes
from jax import test_util as jtu
@ -843,3 +844,22 @@ lax_reduce_window = tuple(
for base_dilation in [(2, 1, 3, 2)]
for window_dilation in [(1, 2, 2, 1)]
)
# Harnesses for the jitted jax.random.gamma: one per combination of argument
# shape and dtype, each using PRNGKey(42) as the key.
random_gamma = tuple(
    Harness(f"_shape={jtu.format_shape_dtype_string(arg_shape, arg_dtype)}",
            jax.jit(jax.random.gamma),
            [jax.random.PRNGKey(42), RandArg(arg_shape, arg_dtype)])
    for arg_shape in ((), (3,))
    for arg_dtype in (np.float32, np.float64)
)
# Harnesses for the jitted jax.random.split(key, 2): one per input key. The
# harness name carries the position of the key in the list below.
random_split = tuple(
    Harness(f"_i={position}",
            jax.jit(lambda key: jax.random.split(key, 2)),
            [input_key])
    for position, input_key in enumerate([
        jax.random.PRNGKey(42),
        np.array([0, 0], dtype=np.uint32),
        np.array([0xFFFFFFFF, 0], dtype=np.uint32),
        np.array([0, 0xFFFFFFFF], dtype=np.uint32),
        np.array([0xFFFFFFFF, 0xFFFFFFFF], dtype=np.uint32),
    ])
)

View File

@ -612,23 +612,15 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
values = np.array([True, False, True], dtype=np.bool_)
self.ConvertAndCompare(f_jax, values)
def test_random_gamma(self):
  # Runs ConvertAndCompare on the jitted jax.random.gamma for a scalar alpha
  # and for float32/float64 vectors of alphas, always with PRNGKey(42).
  f_jax = jax.jit(jax.random.gamma)
  alphas = [
      1.0,
      np.array([1.0, 0.2, 1.2], np.float32),
      np.array([1.0, 0.2, 1.2], np.float64),
  ]
  for alpha in alphas:
    for rng_key in [jax.random.PRNGKey(42)]:
      self.ConvertAndCompare(f_jax, rng_key, alpha)
@primitive_harness.parameterized(primitive_harness.random_gamma)
def test_random_gamma(self, harness: primitive_harness.Harness):
  # One test case per harness in primitive_harness.random_gamma; the
  # comparison is made with rtol=1e-5.
  fun = harness.dyn_fun
  dyn_args = harness.dyn_args_maker(self.rng())
  self.ConvertAndCompare(fun, *dyn_args, rtol=1e-5)
def test_prngsplit(self):
  # Runs ConvertAndCompare on a jitted two-way jax.random.split for a regular
  # PRNGKey and for keys built from all-zero / all-one 32-bit words.
  f_jax = jax.jit(lambda key: jax.random.split(key, 2))
  rng_keys = [
      jax.random.PRNGKey(42),
      np.array([0, 0], dtype=np.uint32),
      np.array([0xFFFFFFFF, 0], dtype=np.uint32),
      np.array([0, 0xFFFFFFFF], dtype=np.uint32),
      np.array([0xFFFFFFFF, 0xFFFFFFFF], dtype=np.uint32),
  ]
  for rng_key in rng_keys:
    self.ConvertAndCompare(f_jax, rng_key)
@primitive_harness.parameterized(primitive_harness.random_split,
                                 one_containing="i=0")
def test_random_split(self, harness: primitive_harness.Harness):
  # Harness-driven test over primitive_harness.random_split; the decorator is
  # given one_containing="i=0".
  fun = harness.dyn_fun
  dyn_args = harness.dyn_args_maker(self.rng())
  self.ConvertAndCompare(fun, *dyn_args)
def test_zeros_like(self):
v = np.float32(2.)