[jax2tf] Add a harness for broadcast_p and fix conversion.

This commit is contained in:
Benjamin Chetioui 2020-11-11 16:41:03 +01:00
parent d4b1215491
commit 07807d76ff
4 changed files with 26 additions and 1 deletion

View File

@@ -1318,7 +1318,8 @@ tf_impl[lax.dot_general_p] = _dot_general
def _broadcast(operand, *, sizes):
  """TF implementation of lax.broadcast_p.

  Prepends the dimensions in `sizes` to the shape of `operand` and broadcasts.

  Args:
    operand: the tf.Tensor to broadcast.
    sizes: a tuple of leading dimension sizes to prepend.

  Returns:
    A tf.Tensor of shape `sizes + operand.shape`.
  """
  # NOTE: `sizes + tf.shape(operand)` would be wrong here — adding a Python
  # tuple to a tf.Tensor performs elementwise addition, not concatenation.
  # TensorShape.concatenate builds the target shape statically and correctly.
  result_shape = tf.TensorShape(sizes).concatenate(operand.shape)
  return tf.broadcast_to(operand, result_shape)
tf_impl[lax.broadcast_p] = _broadcast

View File

@@ -356,6 +356,25 @@ lax_broadcast_in_dim = tuple(  # Validate dtypes
]
)
def _make_broadcast_harness(name, *, dtype=np.float32, shape=(2,), sizes=()):
  """Builds a Harness exercising lax.broadcast_p.

  Args:
    name: a tag distinguishing the harness family (e.g. "dtypes", "sizes").
    dtype: the operand dtype.
    shape: the operand shape.
    sizes: the leading dimensions to broadcast to.
  """
  shape_str = jtu.format_shape_dtype_string(shape, dtype)
  def _fun(operand):
    return lax.broadcast_p.bind(operand, sizes=sizes)
  return Harness(f"{name}_shape={shape_str}_sizes={sizes}",
                 _fun,
                 [RandArg(shape, dtype)],
                 shape=shape,
                 dtype=dtype,
                 sizes=sizes)
# Harnesses for lax.broadcast_p: first validate all dtypes at the default
# shape, then validate broadcasting one and several leading dimensions.
lax_broadcast = tuple(
    [_make_broadcast_harness("dtypes", dtype=dtype)
     for dtype in jtu.dtypes.all] +
    [_make_broadcast_harness("sizes", sizes=sizes)
     for sizes in ((2,),        # broadcast 1 dim
                   (1, 2, 3))]  # broadcast n > 1 dims
)
lax_betainc = tuple(
Harness(f"_{jtu.dtype_str(dtype)}",
lax.betainc,

View File

@@ -635,6 +635,10 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
def test_broadcast_in_dim(self, harness: primitive_harness.Harness):
  # Converts the harness's function with jax2tf and compares JAX vs TF outputs
  # on freshly generated random arguments.
  self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@primitive_harness.parameterized(primitive_harness.lax_broadcast)
def test_broadcast(self, harness: primitive_harness.Harness):
  """Checks the jax2tf conversion of lax.broadcast_p for each harness."""
  dyn_args = harness.dyn_args_maker(self.rng())
  self.ConvertAndCompare(harness.dyn_fun, *dyn_args)
@primitive_harness.parameterized(primitive_harness.lax_betainc)
def test_betainc(self, harness: primitive_harness.Harness):
dtype = harness.params["dtype"]

View File

@@ -60,6 +60,7 @@ from jax._src.lax.lax import (
bitwise_or,
bitwise_xor,
broadcast,
broadcast_p,
broadcast_in_dim,
broadcast_in_dim_p,
broadcast_shapes,