mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[jax2tf] Add a harness for broadcast_p and fix conversion.
This commit is contained in:
parent
d4b1215491
commit
07807d76ff
@ -1318,7 +1318,8 @@ tf_impl[lax.dot_general_p] = _dot_general
|
||||
|
||||
|
||||
def _broadcast(operand, *, sizes):
|
||||
return tf.broadcast_to(operand, sizes + tf.shape(operand))
|
||||
result_shape = tf.TensorShape(sizes).concatenate(operand.shape)
|
||||
return tf.broadcast_to(operand, result_shape)
|
||||
tf_impl[lax.broadcast_p] = _broadcast
|
||||
|
||||
|
||||
|
@ -356,6 +356,25 @@ lax_broadcast_in_dim = tuple( # Validate dtypes
|
||||
]
|
||||
)
|
||||
|
||||
def _make_broadcast_harness(name, *, dtype=np.float32, shape=(2,), sizes=()):
|
||||
return Harness(f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_sizes={sizes}",
|
||||
lambda operand: lax.broadcast_p.bind(operand, sizes=sizes),
|
||||
[RandArg(shape, dtype)],
|
||||
shape=shape,
|
||||
dtype=dtype,
|
||||
sizes=sizes)
|
||||
|
||||
lax_broadcast = tuple( # Validate dtypes
|
||||
_make_broadcast_harness("dtypes", dtype=dtype)
|
||||
for dtype in jtu.dtypes.all
|
||||
) + tuple( # Validate sizes
|
||||
_make_broadcast_harness("sizes", sizes=sizes)
|
||||
for sizes in [
|
||||
(2,), # broadcast 1 dim
|
||||
(1, 2, 3), # broadcast n > 1 dims
|
||||
]
|
||||
)
|
||||
|
||||
lax_betainc = tuple(
|
||||
Harness(f"_{jtu.dtype_str(dtype)}",
|
||||
lax.betainc,
|
||||
|
@ -635,6 +635,10 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
|
||||
def test_broadcast_in_dim(self, harness: primitive_harness.Harness):
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_broadcast)
|
||||
def test_broadcast(self, harness: primitive_harness.Harness):
|
||||
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
|
||||
|
||||
@primitive_harness.parameterized(primitive_harness.lax_betainc)
|
||||
def test_betainc(self, harness: primitive_harness.Harness):
|
||||
dtype = harness.params["dtype"]
|
||||
|
@ -60,6 +60,7 @@ from jax._src.lax.lax import (
|
||||
bitwise_or,
|
||||
bitwise_xor,
|
||||
broadcast,
|
||||
broadcast_p,
|
||||
broadcast_in_dim,
|
||||
broadcast_in_dim_p,
|
||||
broadcast_shapes,
|
||||
|
Loading…
x
Reference in New Issue
Block a user