[jax2tf] Fix implementation and tests of select_and_scatter_add conversion.

This commit is contained in:
Benjamin Chetioui 2020-10-23 12:42:13 +02:00
parent fcaced32aa
commit dfce748274
5 changed files with 92 additions and 14 deletions

View File

@ -1470,18 +1470,20 @@ def _select_and_scatter(
tf_impl[lax.select_and_scatter_p] = _select_and_scatter
# Pre-fix conversion of lax.select_and_scatter_add_p (removed by this commit).
# It discards the select/scatter jaxprs entirely and hard-codes a ">=" select
# with a "+" scatter via the pre-built _ge_fn/_add_fn concrete functions, so
# any primitive whose select is not ">=" would be converted incorrectly.
def _select_and_scatter_add(
operand, source, init_value, select_jaxpr, select_consts, scatter_jaxpr,
scatter_consts, window_dimensions, window_strides, padding):
del select_jaxpr, select_consts, scatter_jaxpr, scatter_consts
# TODO(phawkins): handle the select and scatter jaxprs correctly.
# Scalar of the operand dtype, used only to trace concrete functions.
a = tf.constant(0, operand.dtype)
select_fn = _ge_fn.get_concrete_function(a, a)
scatter_fn = _add_fn.get_concrete_function(a, a)
return tfxla.select_and_scatter(operand, window_dimensions, window_strides,
padding, source, init_value, select_fn,
scatter_fn)
tf_impl[lax.select_and_scatter_add_p] = _select_and_scatter_add
# Fixed conversion of lax.select_and_scatter_add_p. Unlike the previous
# version, the select function is derived from the actual select_prim carried
# by the primitive (via tf_impl), instead of being hard-coded to ">=".
# bool inputs are converted to int8 (argnums 0 and 1) since the XLA op does
# not take booleans directly — presumably; TODO confirm against bool_to_int8.
@functools.partial(bool_to_int8, argnums=(0, 1))
def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
window_strides, padding, _in_avals, _out_aval):
# Zero scalar of the operand dtype: both the scatter's init value and the
# tracing argument for the concrete select/scatter functions.
init_value = tf.zeros((), operand.dtype)
# Trace the converter of select_prim into a concrete TF function of two
# scalars; autograph is disabled since the impl is pure TF ops.
select_fn = (tf.function(tf_impl[select_prim], autograph=False)
.get_concrete_function(init_value, init_value))
# The scatter of select_and_scatter_add is always addition.
scatter_fn = _add_fn.get_concrete_function(init_value, init_value)
out = tfxla.select_and_scatter(operand, window_dimensions, window_strides,
padding, source, init_value, select_fn,
scatter_fn)
# tfxla may lose the static shape; restore it from the JAX output aval.
out.set_shape(_aval_to_tf_shape(_out_aval))
return out
tf_impl_with_avals[lax.select_and_scatter_add_p] = _select_and_scatter_add
def _threefry2x32_jax_impl(*args: TfVal, _in_avals, _out_aval):
# We use the random._threefry2x32_lowering, but since add is not implemented

View File

@ -1,6 +1,6 @@
# Primitives with limited support
*Last generated on (YYYY-MM-DD): 2020-10-19*
*Last generated on (YYYY-MM-DD): 2020-10-23*
## Updating the documentation
@ -71,6 +71,7 @@ conversion to Tensorflow.
| scatter-mul | Missing TF support | Primitive is unimplemented in TF | complex64 | TPU |
| select_and_gather_add | Missing TF support | Primitive is unimplemented in TF | float32, float64 | TPU |
| select_and_gather_add | Missing TF support | Primitive is unimplemented in TF | float64 | CPU, GPU |
| select_and_scatter_add | Missing TF support | Primitive is unimplemented in TF | uint16, uint32, uint64 | CPU, GPU, TPU |
| sinh | Missing TF support | Primitive is unimplemented in TF | float16 | CPU, GPU, TPU |
| sort | Missing TF support | Primitive is unimplemented in TF | complex128, complex64 | CPU, GPU, TPU |
| sort | Missing TF support | Primitive is unimplemented in TF; only sorting on last dimension is supported for XlaSort | ALL | CPU, GPU, TPU |
@ -92,4 +93,4 @@ The conversion of the following JAX primitives is not yet implemented:
The following JAX primitives have a defined conversion but are known to be
missing tests:
`argmin`, `broadcast`, `clamp`, `complex`, `conj`, `custom_lin`, `device_put`, `integer_pow`, `rev`, `select_and_scatter`, `select_and_scatter_add`, `tie_in`
`argmin`, `broadcast`, `clamp`, `complex`, `conj`, `custom_lin`, `device_put`, `integer_pow`, `rev`, `select_and_scatter`, `tie_in`

View File

@ -164,6 +164,10 @@ def categorize(prim: core.Primitive, *args, **kwargs) \
"implementation")
tf_unimpl(np_dtype, additional_msg=additional_msg, devs=["CPU", "GPU"])
if prim is lax.select_and_scatter_add_p:
if np_dtype in [np.uint64, np.uint32, np.uint16]:
tf_unimpl(np_dtype)
if prim is lax.select_and_gather_add_p:
# TODO: the conversion is only supported for float16/float32 on CPU/GPU,
# and float16 on TPU. This is because we do not implement a precision

View File

@ -862,6 +862,68 @@ lax_shift_right_arithmetic = tuple(
for arg, dtype, shift_amount in shift_inputs
)
def _make_select_and_scatter_add_harness(
    name, *, shape=(2, 4, 6), dtype=np.float32, select_prim=lax.ge_p,
    window_dimensions=(2, 2, 2), window_strides=(1, 1, 1),
    padding=((0, 0), (0, 0), (0, 0)), nb_inactive_dims=0):
  """Builds a test Harness for lax._select_and_scatter_add.

  The cotangent (source) shape is not given explicitly; it is derived by
  abstractly evaluating the matching _select_and_gather_add with unit
  dilations, which produces the reduced window output shape.
  """
  unit_dilations = (1,) * len(shape)
  def _windowed_shape_probe(x):
    return lax._select_and_gather_add(x, x, lax.ge_p, window_dimensions,
                                      window_strides, padding, unit_dilations,
                                      unit_dilations)
  cotangent_shape = jax.api.eval_shape(_windowed_shape_probe,
                                       np.ones(shape, dtype)).shape
  harness_name = (
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}"
      f"_selectprim={select_prim}_windowdimensions={window_dimensions}"
      f"_windowstrides={window_strides}_padding={padding}")
  return Harness(harness_name,
                 lax._select_and_scatter_add,
                 [RandArg(cotangent_shape, dtype), RandArg(shape, dtype),
                  StaticArg(select_prim), StaticArg(window_dimensions),
                  StaticArg(window_strides), StaticArg(padding)],
                 shape=shape,
                 dtype=dtype,
                 select_prim=select_prim,
                 window_dimensions=window_dimensions,
                 window_strides=window_strides,
                 padding=padding,
                 # JAX can only run select_and_scatter_add on TPU when 2
                 # or more dimensions are inactive.
                 run_on_tpu=(nb_inactive_dims >= 2))
# Harness suite for lax.select_and_scatter_add, sweeping one parameter axis
# per group: dtypes, select primitive, padding, window dimensions, window
# strides, and TPU-runnable configurations (>= 2 inactive dimensions).
# Complex dtypes are excluded throughout — presumably unsupported by the
# primitive or its conversion; TODO confirm.
lax_select_and_scatter_add = tuple( # Validate dtypes
_make_select_and_scatter_add_harness("dtypes", dtype=dtype)
for dtype in set(jtu.dtypes.all) - set([np.complex64, np.complex128])
) + tuple( # Validate different reduction primitives
_make_select_and_scatter_add_harness("select_prim", select_prim=select_prim)
for select_prim in [lax.le_p]  # default harness already covers lax.ge_p
) + tuple( # Validate padding
_make_select_and_scatter_add_harness("padding", padding=padding)
for padding in [
# TODO(bchetioui): commented out the test based on
# https://github.com/google/jax/issues/4690
#((1, 2), (2, 3), (3, 4)) # non-zero padding
((1, 1), (1, 1), (1, 1)) # non-zero padding (uniform, smaller case)
]
) + tuple( # Validate window_dimensions
_make_select_and_scatter_add_harness("window_dimensions",
window_dimensions=window_dimensions)
for window_dimensions in [
(1, 2, 3) # uneven dimensions
]
) + tuple( # Validate window_strides
_make_select_and_scatter_add_harness("window_strides",
window_strides=window_strides)
for window_strides in [
(1, 2, 3) # smaller than/same as/bigger than corresponding window dimension
]
) + tuple( # Validate dtypes on TPU
_make_select_and_scatter_add_harness("tpu_dtypes", dtype=dtype,
nb_inactive_dims=nb_inactive_dims,
window_strides=window_strides,
window_dimensions=window_dimensions)
for dtype in set(jtu.dtypes.all) - set([np.bool_, np.complex64, np.complex128,
np.int8, np.uint8])
# One TPU-compatible configuration: window is active only on axis 1, so
# 2 dimensions are inactive and run_on_tpu=True.
for window_strides, window_dimensions, nb_inactive_dims in [
((1, 2, 1), (1, 3, 1), 2)
]
)
lax_select_and_gather_add = tuple(
# Tests with 2d shapes (see tests.lax_autodiff_test.testReduceWindowGrad)
Harness(f"2d_shape={jtu.format_shape_dtype_string(shape, dtype)}_selectprim={select_prim}_windowdimensions={window_dimensions}_windowstrides={window_strides}_padding={padding}_basedilation={base_dilation}_windowdilation={window_dilation}",

View File

@ -262,6 +262,15 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
custom_assert=custom_assert,
always_custom_assert=True)
# Runs every lax_select_and_scatter_add harness through JAX and the jax2tf
# conversion and checks both produce the same result. On TPU, harnesses whose
# parameters leave fewer than 2 inactive dimensions are skipped because JAX
# itself cannot run them there (see run_on_tpu in the harness factory).
@primitive_harness.parameterized(primitive_harness.lax_select_and_scatter_add)
def test_select_and_scatter_add(self, harness: primitive_harness.Harness):
if jtu.device_under_test() == "tpu" and not harness.params["run_on_tpu"]:
raise unittest.SkipTest(
"TODO: select_and_scatter on JAX on TPU only works when the parameters "
"define 2 or more inactive dimensions"
)
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
@primitive_harness.parameterized(primitive_harness.lax_select_and_gather_add)
@jtu.ignore_warning(category=UserWarning,
message="Using reduced precision for gradient.*")