diff --git a/CHANGELOG.md b/CHANGELOG.md index 6978d8ca2..6378c8fc8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,9 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK. tracebacks. * A new traceback filtering mode using `__tracebackhide__` is now enabled by default in sufficiently recent versions of IPython. + * The {func}`jax2tf.convert` supports shape polymorphism even when the + unknown dimensions are used in arithmetic operations, e.g., `jnp.reshape(-1)` + ({jax-issue}`#6827`). * Breaking changes: @@ -31,6 +34,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK. * The {func}`jax2tf.convert` now converts `lax.dot_general` using the `XlaDot` TensorFlow op, for better fidelity w.r.t. JAX numerical precision ({jax-issue}`#6717`). + * The {func}`jax2tf.convert` now has support for inequality comparisons and + min/max for complex numbers ({jax-issue}`#6892`). ## jaxlib 0.1.67 (unreleased) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 425f23c45..bb6dcf041 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2783,27 +2783,35 @@ def _broadcasting_select(c, which, x, y): return xops.Select(which, x, y) -def _minmax_translation_rule(c, x, y, *, minmax=None, cmp=None): +def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x): + result_shape = broadcast_shapes(np.shape(x), np.shape(y)) + x = _maybe_broadcast(result_shape, x) + y = _maybe_broadcast(result_shape, y) + rx = real(x) + ry = real(y) + pick_x = select(eq(rx, ry), lax_cmp_pick_x(imag(x), imag(y)), + lax_cmp_pick_x(rx, ry)) + return select(pick_x, x, y) + +def _minmax_translation_rule(c, x, y, *, op_minmax=None, lax_cmp_pick_x=None): dtype = c.get_shape(x).numpy_dtype() if dtypes.issubdtype(dtype, np.complexfloating): - rx = xops.Real(x) - ry = xops.Real(y) - return _broadcasting_select( - c, xops.Select(xops.Eq(rx, ry), cmp(xops.Imag(x), xops.Imag(y)), - cmp(rx, ry)), - x, y) - return minmax(x, y) + return xla.lower_fun(partial(_minmax_complex_lowering, + lax_cmp_pick_x=lax_cmp_pick_x), + multiple_results=False)(c, x, y) + else: + return op_minmax(x, y) max_p: core.Primitive = standard_naryop( [_any, _any], 'max', translation_rule=partial( - _minmax_translation_rule, minmax=xops.Max, cmp=xops.Gt)) + _minmax_translation_rule, op_minmax=xops.Max, lax_cmp_pick_x=gt)) ad.defjvp2(max_p, lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)), lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x))) min_p: core.Primitive = standard_naryop( [_any, _any], 'min', translation_rule=partial( - _minmax_translation_rule, minmax=xops.Min, cmp=xops.Lt)) + _minmax_translation_rule, op_minmax=xops.Min, lax_cmp_pick_x=lt)) ad.defjvp2(min_p, lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)), lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x))) diff --git a/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md b/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md index a5fae4665..32b6fb076 100644 --- a/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md +++ b/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md @@ -1,10 +1,10 @@ # Primitives with limited JAX support -*Last generated on: 2021-05-17* (YYYY-MM-DD) +*Last generated on: 2021-06-04* (YYYY-MM-DD) ## Supported data types for primitives -We use a set of 2507 test harnesses to test +We use a set of 2570 test harnesses to test the implementation of 121 numeric JAX primitives. We consider a JAX primitive supported for a particular data type if it is supported on at least one device type. 
@@ -60,7 +60,7 @@ be updated. | broadcast_in_dim | 19 | all | | | ceil | 4 | floating | bool, complex, integer | | cholesky | 30 | inexact | bool, integer | -| clamp | 17 | floating, integer | bool, complex | +| clamp | 20 | all | | | complex | 4 | float32, float64 | bfloat16, bool, complex, float16, integer | | concatenate | 17 | all | | | conj | 5 | complex, float32, float64 | bfloat16, bool, float16, integer | @@ -90,28 +90,28 @@ be updated. | fft | 20 | complex, float32, float64 | bfloat16, bool, float16, integer | | floor | 4 | floating | bool, complex, integer | | gather | 37 | all | | -| ge | 15 | bool, floating, integer | complex | -| gt | 15 | bool, floating, integer | complex | +| ge | 17 | all | | +| gt | 17 | all | | | igamma | 6 | floating | bool, complex, integer | | igammac | 6 | floating | bool, complex, integer | | imag | 2 | complex | bool, floating, integer | | integer_pow | 108 | inexact, integer | bool | | iota | 16 | inexact, integer | bool | | is_finite | 4 | floating | bool, complex, integer | -| le | 15 | bool, floating, integer | complex | +| le | 17 | all | | | lgamma | 4 | floating | bool, complex, integer | | log | 6 | inexact | bool, integer | | log1p | 6 | inexact | bool, integer | -| lt | 15 | bool, floating, integer | complex | +| lt | 17 | all | | | lu | 18 | inexact | bool, integer | -| max | 29 | all | | -| min | 29 | all | | +| max | 33 | all | | +| min | 33 | all | | | mul | 16 | inexact, integer | bool | | ne | 17 | all | | | neg | 14 | inexact, integer | bool | | nextafter | 6 | floating | bool, complex, integer | | or | 11 | bool, integer | inexact | -| pad | 90 | all | | +| pad | 120 | all | | | population_count | 8 | integer | bool, inexact | | pow | 10 | inexact | bool, integer | | qr | 60 | inexact | bool, integer | @@ -144,7 +144,7 @@ be updated. | shift_left | 10 | integer | bool, inexact | | shift_right_arithmetic | 10 | integer | bool, inexact | | shift_right_logical | 10 | integer | bool, inexact | -| sign | 14 | inexact, integer | bool | +| sign | 28 | inexact, integer | bool | | sin | 6 | inexact | bool, integer | | sinh | 6 | inexact | bool, integer | | slice | 24 | all | | @@ -184,6 +184,7 @@ and search for "limitation". | Affected primitive | Description of limitation | Affected dtypes | Affected devices | | --- | --- | --- | --- | |cholesky|unimplemented|float16|cpu, gpu| +|clamp|unimplemented|bool, complex|cpu, gpu, tpu| |conv_general_dilated|preferred_element_type not implemented for integers|int16, int32, int8|gpu| |conv_general_dilated|preferred_element_type=c128 not implemented|complex64|tpu| |conv_general_dilated|preferred_element_type=f64 not implemented|bfloat16, float16, float32|tpu| diff --git a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md index 2f55a0919..eccb1765e 100644 --- a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md +++ b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md @@ -1,6 +1,6 @@ # Primitives with limited support for jax2tf -*Last generated on (YYYY-MM-DD): 2021-06-01* +*Last generated on (YYYY-MM-DD): 2021-06-04* This document summarizes known limitations of the jax2tf conversion. There are several kinds of limitations. @@ -34,10 +34,7 @@ On TPU only the "compiled" mode is relevant. Our priority is to ensure same coverage and numerical behavior with JAX in the "compiled" mode, **when using XLA to compile the converted program**. -We are pretty close to that goal. 
In addition to a few loose ends, there is a known -coverage problem due to JAX and XLA supporting inequality comparisons and min/max for -booleans and complex numbers. It is not clear that TensorFlow will be extended to -support these. +We are pretty close to that goal. This table only shows errors for cases that are working in JAX (see [separate list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) ) @@ -67,6 +64,7 @@ More detailed information can be found in the | cholesky | TF test skipped: Not implemented in JAX: unimplemented | float16 | cpu, gpu | compiled, eager, graph | | cholesky | TF error: function not compilable | complex | cpu, gpu | compiled | | cholesky | TF error: op not defined for dtype | complex | tpu | compiled, graph | +| clamp | TF test skipped: Not implemented in JAX: unimplemented | bool, complex | cpu, gpu, tpu | compiled, eager, graph | | clamp | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph | | conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type not implemented for integers | int16, int32, int8 | gpu | compiled, eager, graph | | conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type=c128 not implemented | complex64 | tpu | compiled, eager, graph | @@ -104,17 +102,16 @@ More detailed information can be found in the | lt | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph | | lu | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph | | lu | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph | -| max | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph | -| min | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph | +| max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph | +| min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph | | neg | TF error: op not defined for dtype | unsigned | cpu, gpu, tpu | compiled, eager, graph | | nextafter | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph | | qr | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph | | qr | TF error: op not defined for dtype | bfloat16 | tpu | compiled, eager, graph | | reduce_max | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph | | reduce_min | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph | -| reduce_window_max | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph | -| reduce_window_min | TF error: op not defined for dtype | uint64 | cpu, gpu | eager | -| reduce_window_min | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph | +| reduce_window_max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph | +| reduce_window_min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph | | regularized_incomplete_beta | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph | | rem | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | 
cpu, gpu, tpu | compiled, eager, graph | | round | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph | @@ -122,9 +119,9 @@ More detailed information can be found in the | scatter_add | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph | | scatter_add | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph | | scatter_max | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph | -| scatter_max | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph | +| scatter_max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph | | scatter_min | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph | -| scatter_min | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph | +| scatter_min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph | | scatter_mul | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph | | scatter_mul | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph | | select_and_gather_add | TF error: jax2tf unimplemented for 64-bit inputs because the current implementation relies on packing two values into a single value. This can be fixed by using a variadic XlaReduceWindow, when available | float64 | cpu, gpu | compiled, eager, graph | diff --git a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template index a9177bdb0..17753a9a6 100644 --- a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template +++ b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template @@ -34,10 +34,7 @@ On TPU only the "compiled" mode is relevant. Our priority is to ensure same coverage and numerical behavior with JAX in the "compiled" mode, **when using XLA to compile the converted program**. -We are pretty close to that goal. In addition to a few loose ends, there is a known -coverage problem due to JAX and XLA supporting inequality comparisons and min/max for -booleans and complex numbers. It is not clear that TensorFlow will be extended to -support these. +We are pretty close to that goal. This table only shows errors for cases that are working in JAX (see [separate list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) ) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 14d7d5ea2..ef8beb7fd 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Experimental module transforms JAX functions to be executed by TensorFlow.""" -import functools +from functools import partial import re import string from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union @@ -116,7 +116,7 @@ def _xla_disabled_error(primitive_name: str, msg += f" {extra_msg}" return NotImplementedError(msg) -@functools.partial(api_util.api_hook, tag="jax2tf_convert") +@partial(api_util.api_hook, tag="jax2tf_convert") def convert(fun: Callable, *, polymorphic_shapes: Optional[Sequence[Any]] = None, @@ -293,8 +293,7 @@ def convert(fun: Callable, out_with_avals = _interpret_fun(flat_fun, args_flat, args_avals_flat) outs, out_avals = util.unzip2(out_with_avals) return (tuple(outs), - functools.partial( - converted_grad_fn, _out_cts_avals=tuple(out_avals))) + partial(converted_grad_fn, _out_cts_avals=tuple(out_avals))) out_flat = converted_fun_flat_with_custom_gradient(*args_flat) else: @@ -828,7 +827,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): for unexpected in xla.call_translations: # Call primitives are inlined if unexpected is pjit.pjit_p: continue - tf_impl[unexpected] = functools.partial(_unexpected_primitive, unexpected) + tf_impl[unexpected] = partial(_unexpected_primitive, unexpected) # Primitives that are not yet implemented must be explicitly declared here. tf_not_yet_impl = [ @@ -1045,8 +1044,30 @@ def _rem(lhs, rhs): tf_impl[lax.div_p] = _div tf_impl[lax.rem_p] = _rem -tf_impl[lax.max_p] = tf.math.maximum -tf_impl[lax.min_p] = tf.math.minimum + +def _minmax(x: TfVal, y: TfVal, *, is_min: bool, + _in_avals: Sequence[core.AbstractValue], + _out_aval: core.AbstractValue,) -> TfVal: + # For complex numbers use lexicographic ordering, like JAX + if dtypes.issubdtype(x.dtype.as_numpy_dtype, np.complexfloating): + return _convert_jax_impl( + partial(lax._minmax_complex_lowering, + lax_cmp_pick_x=lax.lt if is_min else lax.gt), + multiple_results=False)(x, y, _in_avals=_in_avals, _out_aval=_out_aval) + else: + return (tf.math.minimum if is_min else tf.math.maximum)(x, y) + +def _minmax_scalar(x: TfVal, y: TfVal, *, is_min: bool) -> TfVal: + # For reducers we will need min/max for scalars only. In that case we + # can construct the AbstractValues outselves, even in the presence of + # shape polymorphism. + assert len(x.shape) == 0 and len(y.shape) == 0, f"x: {x.shape}, y: {y.shape}" + aval = core.ShapedArray((), _to_jax_dtype(x.dtype)) + return _minmax(x, y, is_min=is_min, + _in_avals=[aval, aval], _out_aval=aval) + +tf_impl_with_avals[lax.max_p] = partial(_minmax, is_min=False) +tf_impl_with_avals[lax.min_p] = partial(_minmax, is_min=True) # Map from TF signed types to TF unsigned types. 
_SIGNED_TO_UNSIGNED_TABLE = { @@ -1659,8 +1680,8 @@ def _argminmax(fn, operand, axes, index_dtype): return tf.cast(result, _to_tf_dtype(index_dtype)) -tf_impl[lax.argmin_p] = functools.partial(_argminmax, tf.math.argmin) -tf_impl[lax.argmax_p] = functools.partial(_argminmax, tf.math.argmax) +tf_impl[lax.argmin_p] = partial(_argminmax, tf.math.argmin) +tf_impl[lax.argmax_p] = partial(_argminmax, tf.math.argmax) _add_fn = tf.function(_add, autograph=False) _ge_fn = tf.function(tf.math.greater_equal, autograph=False) @@ -1947,21 +1968,18 @@ def _get_min_identity(tf_dtype): # pylint: disable=protected-access tf_impl_with_avals[lax.reduce_window_sum_p] = ( - functools.partial( - _specialized_reduce_window, _add, lambda x: 0, - name="reduce_window_sum")) + partial(_specialized_reduce_window, _add, lambda x: 0, + name="reduce_window_sum")) tf_impl_with_avals[lax.reduce_window_min_p] = ( - functools.partial( - _specialized_reduce_window, - tf.math.minimum, - _get_min_identity, - name="reduce_window_min")) + partial(_specialized_reduce_window, + partial(_minmax_scalar, is_min=True), + _get_min_identity, + name="reduce_window_min")) tf_impl_with_avals[lax.reduce_window_max_p] = ( - functools.partial( - _specialized_reduce_window, - tf.math.maximum, - _get_max_identity, - name="reduce_window_max")) + partial(_specialized_reduce_window, + partial(_minmax_scalar, is_min=False), + _get_max_identity, + name="reduce_window_max")) tf_impl_with_avals[lax.reduce_window_p] = _reduce_window # pylint: enable=protected-access @@ -1970,11 +1988,11 @@ tf_impl_with_avals[lax.reduce_window_p] = _reduce_window # O(n^2) on other backends. This may be implemented using associative_scan # instead to favor different backends. tf_impl_with_avals[lax_control_flow.cummin_p] = _convert_jax_impl( - functools.partial(lax_control_flow._cumred_tpu_translation_rule, + partial(lax_control_flow._cumred_tpu_translation_rule, lax._reduce_window_min), multiple_results=False) tf_impl_with_avals[lax_control_flow.cummax_p] = _convert_jax_impl( - functools.partial(lax_control_flow._cumred_tpu_translation_rule, + partial(lax_control_flow._cumred_tpu_translation_rule, lax._reduce_window_max), multiple_results=False) # TODO(bchetioui): cumsum and cumprod can be converted using pure TF ops for @@ -1983,11 +2001,11 @@ tf_impl_with_avals[lax_control_flow.cummax_p] = _convert_jax_impl( # the operation. A non-XLA path can thus be defined for all dtypes, though the # tests will crash. 
tf_impl_with_avals[lax_control_flow.cumsum_p] = _convert_jax_impl( - functools.partial(lax_control_flow._cumred_tpu_translation_rule, + partial(lax_control_flow._cumred_tpu_translation_rule, lax._reduce_window_sum), multiple_results=False) tf_impl_with_avals[lax_control_flow.cumprod_p] = _convert_jax_impl( - functools.partial(lax_control_flow._cumred_tpu_translation_rule, + partial(lax_control_flow._cumred_tpu_translation_rule, lax._reduce_window_prod), multiple_results=False) @@ -2001,7 +2019,7 @@ def _select_and_scatter(operand, source, init_value, select_jaxpr, tf_impl[lax.select_and_scatter_p] = _select_and_scatter -@functools.partial(bool_to_int8, argnums=(0, 1)) +@partial(bool_to_int8, argnums=(0, 1)) def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions, window_strides, padding, _in_avals, _out_aval): if not _enable_xla: @@ -2023,8 +2041,7 @@ tf_impl_with_avals[lax.select_and_scatter_add_p] = _select_and_scatter_add def _threefry2x32_jax_impl(*args: TfVal, _in_avals, _out_aval): res = _convert_jax_impl( - functools.partial( - jax._src.random._threefry2x32_lowering, use_rolled_loops=False), + partial(jax._src.random._threefry2x32_lowering, use_rolled_loops=False), multiple_results=True)( *args, _in_avals=_in_avals, _out_aval=_out_aval) return res @@ -2035,7 +2052,7 @@ tf_impl_with_avals[jax.random.threefry2x32_p] = _threefry2x32_jax_impl # Use the vmap implementation, otherwise on TPU the performance is really bad # With use_vmap=True on, we get about the same performance for JAX and jax2tf. tf_impl_with_avals[random.random_gamma_p] = _convert_jax_impl( - functools.partial(jax._src.random._gamma_impl, use_vmap=True), + partial(jax._src.random._gamma_impl, use_vmap=True), multiple_results=False) @@ -2049,7 +2066,7 @@ def _gather_dimensions_proto(indices_shape, dimension_numbers): return proto -@functools.partial(bool_to_int8, argnums=0) +@partial(bool_to_int8, argnums=0) def _gather(operand, start_indices, *, dimension_numbers, slice_sizes, _in_avals, _out_aval): """Tensorflow implementation of gather.""" @@ -2171,7 +2188,7 @@ def _cond(index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr], del linear # tf.cond needs lambdas with no arguments. 
branches_tf = [ - functools.partial(_interpret_jaxpr, jaxpr, *operands) + partial(_interpret_jaxpr, jaxpr, *operands) for jaxpr in branches ] return tf.switch_case(index, branches_tf) @@ -2198,7 +2215,7 @@ def _while(*args: TfVal, cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr, pred, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *args) return pred - body_tf_func = functools.partial(_interpret_jaxpr, body_jaxpr, *body_consts) + body_tf_func = partial(_interpret_jaxpr, body_jaxpr, *body_consts) return tf.while_loop(cond_tf_func, body_tf_func, init_carry) @@ -2586,7 +2603,7 @@ def _pjit(*args: TfVal, _out_aval: core.ShapedArray) -> TfVal: del donated_invars, name # TODO: add `name` to the name stack - shard_value_for_mesh = functools.partial(_shard_value, resource_env.physical_mesh) + shard_value_for_mesh = partial(_shard_value, resource_env.physical_mesh) # Apply sharding annotation to the arguments sharded_args: Sequence[TfVal] = tuple( map(shard_value_for_mesh, args, _in_avals, in_axis_resources)) diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py index fab1715e1..fa876c18c 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py +++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py @@ -882,7 +882,7 @@ class Jax2TfLimitation(primitive_harness.Limitation): return [ missing_tf_kernel( - dtypes=[np.bool_, np.complex64, np.complex128]), + dtypes=[np.bool_]), custom_numeric( custom_assert=custom_assert, description=( @@ -901,7 +901,7 @@ class Jax2TfLimitation(primitive_harness.Limitation): tst.assertAllClose(result_jax[~mask], result_tf[~mask], err_msg=err_msg) return [ - missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]), + missing_tf_kernel(dtypes=[np.bool_]), custom_numeric( custom_assert=custom_assert, description=( @@ -966,9 +966,8 @@ class Jax2TfLimitation(primitive_harness.Limitation): @classmethod def reduce_min(cls, harness: primitive_harness.Harness): - return [ - missing_tf_kernel(dtypes=[np.complex64, np.complex128]), - ] + return cls.reduce_max(harness) + @classmethod def reduce_window_add(cls, harness): @@ -979,14 +978,14 @@ class Jax2TfLimitation(primitive_harness.Limitation): def reduce_window_max(cls, harness): assert "max" == harness.params["computation"].__name__ return [ - missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]), + missing_tf_kernel(dtypes=[np.bool_]), ] @classmethod def reduce_window_min(cls, harness): assert "min" == harness.params["computation"].__name__ return [ - missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]), + missing_tf_kernel(dtypes=[np.bool_]), ] @classmethod @@ -1045,13 +1044,13 @@ class Jax2TfLimitation(primitive_harness.Limitation): @classmethod def scatter_max(cls, harness): return [ - missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]), + missing_tf_kernel(dtypes=[np.bool_]), ] @classmethod def scatter_min(cls, harness): return [ - missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]), + missing_tf_kernel(dtypes=[np.bool_]), ] @classmethod diff --git a/jax/experimental/jax2tf/tests/primitive_harness.py b/jax/experimental/jax2tf/tests/primitive_harness.py index ff3c26a6f..78318cfb7 100644 --- a/jax/experimental/jax2tf/tests/primitive_harness.py +++ b/jax/experimental/jax2tf/tests/primitive_harness.py @@ -718,32 +718,36 @@ for dtype in jtu.dtypes.all: shape=shape, dtype=dtype) -_LAX_COMPARATORS = (lax.eq_p, lax.ge_p, lax.gt_p, lax.le_p, lax.lt_p, lax.ne_p) +_LAX_COMPARATORS = dict(eq=jnp.equal, 
ne=jnp.not_equal, + ge=jnp.greater_equal, gt=jnp.greater, + le=jnp.less_equal, lt=jnp.less) def _make_comparator_harness(name, *, dtype=np.float32, op=lax.eq_p, + op_name="eq", lhs_shape=(), rhs_shape=()): define( - op.name, + op_name, f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}", - lambda *args: op.bind(*args), + lambda *args: op(*args), [RandArg(lhs_shape, dtype), RandArg(rhs_shape, dtype)], op=op, + op_name=op_name, lhs_shape=lhs_shape, rhs_shape=rhs_shape, dtype=dtype) -for op in _LAX_COMPARATORS: +for op_name, op in _LAX_COMPARATORS.items(): for dtype in (jtu.dtypes.all if op in [lax.eq_p, lax.ne_p] else - set(jtu.dtypes.all) - set(jtu.dtypes.complex)): + set(jtu.dtypes.all)): # Validate dtypes - _make_comparator_harness("dtypes", dtype=dtype, op=op) + _make_comparator_harness("dtypes", dtype=dtype, op=op, op_name=op_name) # Validate broadcasting behavior for lhs_shape, rhs_shape in [ @@ -751,7 +755,8 @@ for op in _LAX_COMPARATORS: ((1, 2), (3, 2)), # broadcast along specific axis ]: _make_comparator_harness( - "broadcasting", lhs_shape=lhs_shape, rhs_shape=rhs_shape, op=op) + "broadcasting", lhs_shape=lhs_shape, rhs_shape=rhs_shape, + op=op, op_name=op_name) for dtype in jtu.dtypes.all: shape = (3, 4, 5) @@ -917,6 +922,7 @@ for prim in [lax.div_p, lax.rem_p]: def _make_binary_elementwise_harnesses(prim, dtypes, default_dtype=np.float32, + broadcasting_dtypes=None, jax_unimplemented=lambda **kwargs: []): def _make(name, *, shapes=((20, 20), (20, 20)), dtype): @@ -931,15 +937,18 @@ def _make_binary_elementwise_harnesses(prim, prim=prim, dtype=dtype, shapes=shapes) - - return (tuple( # Validate dtypes - _make("dtypes", dtype=dtype) - for dtype in dtypes) + tuple( # Validate broadcasting - _make("broadcasting", dtype=default_dtype, shapes=shapes) - for shapes in [ + broadcasting_dtypes = broadcasting_dtypes or (default_dtype,) + return ( + # Validate dtypes + tuple(_make("dtypes", dtype=dtype) for dtype in dtypes) + + # Validate broadcasting + tuple(_make("broadcasting", dtype=dtype, shapes=shapes) + for shapes in [ ((20, 20), (1, 20)), # broadcasting rhs ((1, 20), (20, 20)), # broadcasting lhs - ])) + ] + for dtype in broadcasting_dtypes) + ) _make_binary_elementwise_harnesses( @@ -1004,7 +1013,9 @@ _min_max_special_cases = tuple( (np.array([-np.inf, -np.inf], dtype=dtype), np.array([np.nan, np.nan], dtype=dtype))]) -_make_binary_elementwise_harnesses(prim=lax.min_p, dtypes=jtu.dtypes.all) +_make_binary_elementwise_harnesses( + prim=lax.min_p, dtypes=jtu.dtypes.all, + broadcasting_dtypes=(np.float32, np.complex64, np.complex128)) # Validate special cases for lhs, rhs in _min_max_special_cases: define( @@ -1014,7 +1025,9 @@ for lhs, rhs in _min_max_special_cases: prim=lax.min_p, dtype=lhs.dtype) -_make_binary_elementwise_harnesses(prim=lax.max_p, dtypes=jtu.dtypes.all) +_make_binary_elementwise_harnesses( + prim=lax.max_p, dtypes=jtu.dtypes.all, + broadcasting_dtypes=(np.float32, np.complex64, np.complex128)) # Validate special cases for lhs, rhs in _min_max_special_cases: define( @@ -2336,10 +2349,15 @@ def _make_clamp_harness(name, min_shape=min_arr.shape, operand_shape=operand_shape, max_shape=max_arr.shape, - dtype=dtype) + dtype=dtype, + jax_unimplemented=[ + Limitation( + "unimplemented", + dtypes=[np.bool_, np.complex64, np.complex128])], + ) -for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex + [np.bool_]): +for dtype in set(jtu.dtypes.all): _make_clamp_harness("dtypes", dtype=dtype) # Validate 
broadcasting of min/max arrays diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index fa5352134..b38b3e99d 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -99,8 +99,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase): # If you want to run this test for only one harness, add parameter # `one_containing="foo"` to parameterized below. @primitive_harness.parameterized( - primitive_harness.all_harnesses, include_jax_unimpl=False, - ) + primitive_harness.all_harnesses, include_jax_unimpl=False) @jtu.ignore_warning( category=UserWarning, message="Using reduced precision for gradient.*") def test_prim(self, harness: primitive_harness.Harness):
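
For reference, the lexicographic ordering that the new `_minmax_complex_lowering` gives to complex `min`/`max` can be sketched in plain JAX, independent of the TF lowering. The helper name `complex_min` and the sample inputs below are illustrative only and are not part of this patch; the primitives used (`lax.real`, `lax.imag`, `lax.eq`, `lax.lt`, `lax.select`) simply mirror the lowering added in `jax/_src/lax/lax.py`.

import jax.numpy as jnp
from jax import lax

def complex_min(x, y):
  # Pick x when it is lexicographically smaller: compare the real parts
  # first and break ties on the imaginary parts.
  rx, ry = lax.real(x), lax.real(y)
  pick_x = lax.select(lax.eq(rx, ry),
                      lax.lt(lax.imag(x), lax.imag(y)),
                      lax.lt(rx, ry))
  return lax.select(pick_x, x, y)

x = jnp.array([1 + 2j, 3 - 1j])
y = jnp.array([1 + 1j, 2 + 5j])
print(complex_min(x, y))   # [1.+1.j 2.+5.j]
print(jnp.minimum(x, y))   # same result: lax.min uses this ordering for complex

With this ordering available as a JAX-level lowering, the jax2tf `_minmax` wrapper above can reuse it verbatim for complex dtypes and fall back to `tf.math.minimum`/`tf.math.maximum` for everything else.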