mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
add a jax2tf translation rule for the shaped-iota primitive
This allows for jax2tf conversion of the partitionable Threefry RNG.
This commit is contained in:
parent
a3483dbe32
commit
75af6b58d9
@ -1041,6 +1041,10 @@ iota_32x2_shape_p.def_impl(partial(xla.apply_primitive, iota_32x2_shape_p))
|
||||
def iota_32x2_shape_abstract_eval(*, shape):
  """Abstract-evaluation rule for the `iota_32x2_shape` primitive.

  Args:
    shape: shape given to both output abstract values.

  Returns:
    A 2-tuple holding the same uint32 `core.ShapedArray` of `shape` twice,
    one entry per primitive output.
  """
  uint32_aval = core.ShapedArray(shape, np.dtype('uint32'))
  return (uint32_aval, uint32_aval)
|
||||
|
||||
def bcast_iotas_to_reshaped_iota(add, mul, shape, iotas):
  """Fold per-dimension iotas into one value using suffix-product strides.

  The stride for position `k` is the product of `shape[k + 1:]`; the last
  stride is 1. Each iota is scaled by its stride via `mul(stride, iota)`
  and the scaled terms are combined left to right with `add`.

  Args:
    add: binary callable used to combine two scaled terms.
    mul: binary callable invoked as `mul(stride, iota)`, where `stride` is a
      Python int.
    shape: sequence of dimension sizes; only `shape[1:]` feeds the strides.
    iotas: operands paired positionally with the strides. Must be non-empty,
      because the fold has no initial value.

  Returns:
    Whatever `add` yields after folding all `mul(stride, iota)` terms.
  """
  # Suffix products of the trailing dimensions, computed by a cumulative
  # product over the reversed tail and then flipped back into order.
  suffix_products = np.cumprod(shape[1:][::-1])[::-1]
  strides = [int(p) for p in suffix_products] + [1]
  scaled_terms = [mul(stride, iota) for iota, stride in zip(iotas, strides)]
  return reduce(add, scaled_terms)  # type: ignore
|
||||
|
||||
def iota_32x2_shape_lowering(ctx, *, shape):
|
||||
def _add(x, y):
|
||||
return mlir.mhlo.AddOp(x, y).result
|
||||
@ -1051,17 +1055,13 @@ def iota_32x2_shape_lowering(ctx, *, shape):
|
||||
x_bcast = mlir.mhlo.BroadcastOp(x_const, mlir.dense_int_elements(shape))
|
||||
return mlir.mhlo.MulOp(x_bcast, y).result
|
||||
|
||||
def _sum(xs):
|
||||
return reduce(_add, xs)
|
||||
|
||||
assert len(shape) > 0
|
||||
aval_out, _ = ctx.avals_out
|
||||
aval_u64 = core.ShapedArray(shape, np.dtype('uint64'))
|
||||
iotas = [mlir.mhlo.IotaOp(mlir.aval_to_ir_type(aval_u64),
|
||||
mlir.i64_attr(dimension)).result
|
||||
mlir.i64_attr(dimension)).result
|
||||
for dimension in range(len(shape))]
|
||||
strides = (*map(int, np.cumprod(shape[1:][::-1])[::-1]), 1)
|
||||
counts = _sum(_mul(s, i) for i, s in zip(iotas, strides)) # type: ignore
|
||||
counts = bcast_iotas_to_reshaped_iota(_add, _mul, shape, iotas)
|
||||
shift = mlir.ir_constant(np.array(32, np.dtype('uint64')),
|
||||
canonicalize_types=False)
|
||||
shift = mlir.mhlo.BroadcastOp(shift, mlir.dense_int_elements(shape)).result
|
||||
@ -1069,7 +1069,7 @@ def iota_32x2_shape_lowering(ctx, *, shape):
|
||||
counts_lo = mlir.mhlo.ConvertOp(mlir.aval_to_ir_type(aval_out), counts).result
|
||||
counts_hi = mlir.mhlo.ConvertOp(mlir.aval_to_ir_type(aval_out),
|
||||
counts_shifted).result
|
||||
return (counts_hi, counts_lo)
|
||||
return counts_hi, counts_lo
|
||||
mlir.register_lowering(iota_32x2_shape_p, iota_32x2_shape_lowering)
|
||||
|
||||
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Experimental module transforms JAX functions to be executed by TensorFlow."""
|
||||
from functools import partial
|
||||
from functools import partial, reduce
|
||||
import contextlib
|
||||
import os
|
||||
import re
|
||||
@ -1214,7 +1214,6 @@ tf_not_yet_impl = [
|
||||
"pure_callback",
|
||||
"for",
|
||||
"inspect_sharding",
|
||||
"iota_32x2_shape",
|
||||
|
||||
# Not high priority?
|
||||
"after_all",
|
||||
@ -2444,6 +2443,20 @@ def _rng_uniform(minval: TfVal, maxval: TfVal, *, shape) -> TfVal:
|
||||
tf_impl[lax.rng_uniform_p] = _rng_uniform
|
||||
|
||||
|
||||
def _iota_32x2_shape(*, shape):
  """TensorFlow implementation of the `iota_32x2_shape` primitive.

  Builds one uint64 iota per dimension of `shape`, combines them with
  `prng.bcast_iotas_to_reshaped_iota`, and splits the combined 64-bit counts
  into two uint32 halves.

  Args:
    shape: shape forwarded to `_iota` and to the iota-combining helper.

  Returns:
    A pair `(counts_hi, counts_lo)`: the counts right-shifted by 32 bits and
    the unshifted counts, each cast to uint32.
  """
  def _to_uint32(values):
    return tf.dtypes.cast(values, _to_tf_dtype(jnp.uint32))

  per_dim_iotas = []
  for axis in range(len(shape)):
    per_dim_iotas.append(_iota(dtype=jnp.uint64, shape=shape, dimension=axis))

  counts = prng.bcast_iotas_to_reshaped_iota(
      lambda x, y: x + y, lambda x, y: x * y, shape, per_dim_iotas)

  # Cast the unshifted counts for the low word; shift the upper 32 bits down
  # before casting to get the high word.
  low_words = _to_uint32(counts)
  high_words = _to_uint32(tf.bitwise.right_shift(counts, 32))
  return high_words, low_words
|
||||
|
||||
tf_impl[prng.iota_32x2_shape_p] = _iota_32x2_shape
|
||||
|
||||
|
||||
def _gather_dimensions_proto(indices_shape, dimension_numbers):
|
||||
proto = xla_data_pb2.GatherDimensionNumbers()
|
||||
proto.offset_dims.extend(dimension_numbers.offset_dims)
|
||||
|
@ -130,9 +130,9 @@ class Jax2TfLimitation(primitive_harness.Limitation):
|
||||
"broadcast", "broadcast_in_dim", "cbrt", "ceil", "clamp", "concatenate",
|
||||
"cos", "cosh", "complex", "conj", "convert_element_type", "cummax",
|
||||
"cummin", "device_put", "dynamic_slice", "dynamic_update_slice", "exp",
|
||||
"eq", "floor", "gather", "ge", "gt", "imag", "iota", "is_finite", "le",
|
||||
"logistic", "lt", "log", "mul", "ne", "neg", "not", "or", "pad",
|
||||
"population_count", "random_categorical", "random_uniform",
|
||||
"eq", "floor", "gather", "ge", "gt", "imag", "iota", "iota_32x2_shape",
|
||||
"is_finite", "le", "logistic", "lt", "log", "mul", "ne", "neg", "not",
|
||||
"or", "pad", "population_count", "random_categorical", "random_uniform",
|
||||
"random_randint", "reduce", "reduce_and", "reduce_prod", "reduce_or",
|
||||
"reduce_sum", "reduce_window_mul", "reduce_window_min",
|
||||
"reduce_window_max", "real", "reshape", "rev", "rsqrt", "select_n",
|
||||
|
@ -54,6 +54,7 @@ from jax import numpy as jnp
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import dispatch
|
||||
from jax._src import prng
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lax import control_flow as lax_control_flow
|
||||
from jax._src.lax import windowed_reductions as lax_windowed_reductions
|
||||
@ -3241,3 +3242,16 @@ for algorithm in [lax.RandomAlgorithm.RNG_THREE_FRY,
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
algorithm=algorithm)
|
||||
|
||||
def _make_iota_32x2_shape_harness(shape):
  """Define a harness that binds `prng.iota_32x2_shape_p` with `shape`.

  The harness name embeds the comma-joined dimensions, e.g. a shape of
  `(5, 7, 4)` is named `shape=(5,7,4)`.

  Args:
    shape: tuple of dimension sizes, passed to the primitive as a static
      argument and recorded on the harness together with the uint32 dtype.
  """
  dims = ','.join(map(str, shape))
  harness_name = f"shape=({dims})"
  static_args = [StaticArg(shape)]
  define(prng.iota_32x2_shape_p,
         harness_name,
         lambda shape: prng.iota_32x2_shape_p.bind(shape=shape),
         static_args,
         dtype=jnp.uint32,
         shape=shape)
|
||||
|
||||
# Register one iota_32x2_shape harness per shape: a 1-D, a 3-D and a 2-D case.
for shape in ((3,), (5, 7, 4), (100, 100)):
  _make_iota_32x2_shape_harness(shape)
|
||||
|
Loading…
x
Reference in New Issue
Block a user