add a jax2tf translation rule for the shaped-iota primitive

This allows for jax2tf conversion of the partitionable Threefry RNG.
This commit is contained in:
Roy Frostig 2022-12-05 09:16:19 -08:00
parent a3483dbe32
commit 75af6b58d9
4 changed files with 39 additions and 12 deletions

View File

@ -1041,6 +1041,10 @@ iota_32x2_shape_p.def_impl(partial(xla.apply_primitive, iota_32x2_shape_p))
def iota_32x2_shape_abstract_eval(*, shape):
  """Abstract evaluation rule for the ``iota_32x2_shape`` primitive.

  Args:
    shape: shape of each of the two outputs.

  Returns:
    A pair whose two entries are the same ``uint32`` ``ShapedArray`` of the
    requested shape.
  """
  aval_u32 = core.ShapedArray(shape, np.dtype('uint32'))
  return (aval_u32, aval_u32)
def bcast_iotas_to_reshaped_iota(add, mul, shape, iotas):
  """Combine per-dimension iotas into a single linear-index value.

  Each iota is scaled by the stride of its dimension (the product of the sizes
  of all later dimensions, with the last dimension having stride 1) and the
  scaled terms are summed.

  Args:
    add: binary callable used to sum two terms, called as ``add(x, y)``.
    mul: binary callable used to scale an iota, called as ``mul(stride, iota)``
      where ``stride`` is a Python ``int``.
    shape: sequence of dimension sizes.
    iotas: one iota per dimension of ``shape``, in dimension order.

  Returns:
    The left fold of ``add`` over ``mul(strides[k], iotas[k])``.

  Raises:
    TypeError: if ``iotas`` is empty (``reduce`` has no initial value).
  """
  # Accumulate strides in Python ints rather than with np.cumprod: NumPy runs
  # the product in a fixed-width integer type (32-bit on some platforms) and
  # wraps silently on overflow, which would corrupt the large counts this
  # helper exists to produce.
  strides = [1]
  for dim in reversed(shape[1:]):
    strides.append(strides[-1] * int(dim))
  strides.reverse()
  return reduce(add, [mul(s, i) for i, s in zip(iotas, strides)])  # type: ignore
def iota_32x2_shape_lowering(ctx, *, shape):
def _add(x, y):
return mlir.mhlo.AddOp(x, y).result
@ -1051,17 +1055,13 @@ def iota_32x2_shape_lowering(ctx, *, shape):
x_bcast = mlir.mhlo.BroadcastOp(x_const, mlir.dense_int_elements(shape))
return mlir.mhlo.MulOp(x_bcast, y).result
def _sum(xs):
return reduce(_add, xs)
assert len(shape) > 0
aval_out, _ = ctx.avals_out
aval_u64 = core.ShapedArray(shape, np.dtype('uint64'))
iotas = [mlir.mhlo.IotaOp(mlir.aval_to_ir_type(aval_u64),
mlir.i64_attr(dimension)).result
mlir.i64_attr(dimension)).result
for dimension in range(len(shape))]
strides = (*map(int, np.cumprod(shape[1:][::-1])[::-1]), 1)
counts = _sum(_mul(s, i) for i, s in zip(iotas, strides)) # type: ignore
counts = bcast_iotas_to_reshaped_iota(_add, _mul, shape, iotas)
shift = mlir.ir_constant(np.array(32, np.dtype('uint64')),
canonicalize_types=False)
shift = mlir.mhlo.BroadcastOp(shift, mlir.dense_int_elements(shape)).result
@ -1069,7 +1069,7 @@ def iota_32x2_shape_lowering(ctx, *, shape):
counts_lo = mlir.mhlo.ConvertOp(mlir.aval_to_ir_type(aval_out), counts).result
counts_hi = mlir.mhlo.ConvertOp(mlir.aval_to_ir_type(aval_out),
counts_shifted).result
return (counts_hi, counts_lo)
return counts_hi, counts_lo
mlir.register_lowering(iota_32x2_shape_p, iota_32x2_shape_lowering)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental module transforms JAX functions to be executed by TensorFlow."""
from functools import partial
from functools import partial, reduce
import contextlib
import os
import re
@ -1214,7 +1214,6 @@ tf_not_yet_impl = [
"pure_callback",
"for",
"inspect_sharding",
"iota_32x2_shape",
# Not high priority?
"after_all",
@ -2444,6 +2443,20 @@ def _rng_uniform(minval: TfVal, maxval: TfVal, *, shape) -> TfVal:
tf_impl[lax.rng_uniform_p] = _rng_uniform
def _iota_32x2_shape(*, shape):
  """TF implementation of the ``iota_32x2_shape`` primitive.

  Builds one uint64 iota per dimension of ``shape``, combines them into a
  single count array via ``prng.bcast_iotas_to_reshaped_iota``, and returns
  that count as two uint32 arrays.

  Args:
    shape: shape of the outputs.

  Returns:
    A pair ``(hi, lo)``: ``hi`` is the count right-shifted by 32 bits and cast
    to uint32; ``lo`` is the count cast directly to uint32.
  """
  def _as_u32(value):
    return tf.dtypes.cast(value, _to_tf_dtype(jnp.uint32))

  per_axis_iotas = []
  for axis in range(len(shape)):
    per_axis_iotas.append(_iota(dtype=jnp.uint64, shape=shape, dimension=axis))

  linear_index = prng.bcast_iotas_to_reshaped_iota(
      lambda lhs, rhs: lhs + rhs,
      lambda lhs, rhs: lhs * rhs,
      shape,
      per_axis_iotas)

  low_word = _as_u32(linear_index)
  high_word = _as_u32(tf.bitwise.right_shift(linear_index, 32))
  return high_word, low_word

tf_impl[prng.iota_32x2_shape_p] = _iota_32x2_shape
def _gather_dimensions_proto(indices_shape, dimension_numbers):
proto = xla_data_pb2.GatherDimensionNumbers()
proto.offset_dims.extend(dimension_numbers.offset_dims)

View File

@ -130,9 +130,9 @@ class Jax2TfLimitation(primitive_harness.Limitation):
"broadcast", "broadcast_in_dim", "cbrt", "ceil", "clamp", "concatenate",
"cos", "cosh", "complex", "conj", "convert_element_type", "cummax",
"cummin", "device_put", "dynamic_slice", "dynamic_update_slice", "exp",
"eq", "floor", "gather", "ge", "gt", "imag", "iota", "is_finite", "le",
"logistic", "lt", "log", "mul", "ne", "neg", "not", "or", "pad",
"population_count", "random_categorical", "random_uniform",
"eq", "floor", "gather", "ge", "gt", "imag", "iota", "iota_32x2_shape",
"is_finite", "le", "logistic", "lt", "log", "mul", "ne", "neg", "not",
"or", "pad", "population_count", "random_categorical", "random_uniform",
"random_randint", "reduce", "reduce_and", "reduce_prod", "reduce_or",
"reduce_sum", "reduce_window_mul", "reduce_window_min",
"reduce_window_max", "real", "reshape", "rev", "rsqrt", "select_n",

View File

@ -54,6 +54,7 @@ from jax import numpy as jnp
from jax._src import ad_util
from jax._src import dispatch
from jax._src import prng
from jax._src import test_util as jtu
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax import windowed_reductions as lax_windowed_reductions
@ -3241,3 +3242,16 @@ for algorithm in [lax.RandomAlgorithm.RNG_THREE_FRY,
shape=shape,
dtype=dtype,
algorithm=algorithm)
def _make_iota_32x2_shape_harness(shape):
  """Define a harness for ``prng.iota_32x2_shape_p`` at the given shape.

  The harness name embeds the comma-separated dimensions, and the shape is
  passed to the primitive as a static argument.

  Args:
    shape: tuple of dimension sizes to bind the primitive with.
  """
  dims_text = ','.join(map(str, shape))
  harness_name = f"shape=({dims_text})"
  define(
      prng.iota_32x2_shape_p,
      harness_name,
      lambda shape: prng.iota_32x2_shape_p.bind(shape=shape),
      [StaticArg(shape)],
      dtype=jnp.uint32,
      shape=shape)

# Cover a rank-1 shape, a rank-3 shape and a larger rank-2 shape.
for shape in ((3,), (5, 7, 4), (100, 100)):
  _make_iota_32x2_shape_harness(shape)