[jax2tf] Implement inequalities and friends for complex numbers.

This requires reusing JAX's lowering rule for comparisons of complex
numbers, which orders them lexicographically (by real part, then by
imaginary part).
George Necula 2021-06-04 11:02:50 +03:00
parent de9f55720d
commit d243258b86
9 changed files with 141 additions and 100 deletions
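
To illustrate the intended semantics (an editor's sketch, not part of the
commit): JAX orders complex numbers lexicographically, comparing real parts
first and breaking ties on the imaginary parts, and jax2tf now follows the
same rule. Assuming a recent jax with this change applied:

import jax.numpy as jnp

jnp.maximum(jnp.array(1 + 2j), jnp.array(1 + 3j))  # -> 1+3j: real parts tie, larger imaginary part wins
jnp.maximum(jnp.array(2 + 0j), jnp.array(1 + 9j))  # -> 2+0j: the real part dominates
jnp.less(jnp.array(1 + 2j), jnp.array(1 + 3j))     # -> True under the lexicographic ordering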

View File

@ -16,6 +16,9 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
tracebacks.
* A new traceback filtering mode using `__tracebackhide__` is now enabled by
default in sufficiently recent versions of IPython.
* The {func}`jax2tf.convert` supports shape polymorphism even when the
unknown dimensions are used in arithmetic operations, e.g., `jnp.reshape(-1)`
({jax-issue}`#6827`).
* Breaking changes:
@ -31,6 +34,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
* The {func}`jax2tf.convert` now converts `lax.dot_general` using the
`XlaDot` TensorFlow op, for better fidelity w.r.t. JAX numerical precision
({jax-issue}`#6717`).
* The {func}`jax2tf.convert` now has support for inequality comparisons and
min/max for complex numbers ({jax-issue}`#6892`).
## jaxlib 0.1.67 (unreleased)

View File

@ -2783,27 +2783,35 @@ def _broadcasting_select(c, which, x, y):
return xops.Select(which, x, y)
def _minmax_translation_rule(c, x, y, *, minmax=None, cmp=None):
def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x):
result_shape = broadcast_shapes(np.shape(x), np.shape(y))
x = _maybe_broadcast(result_shape, x)
y = _maybe_broadcast(result_shape, y)
rx = real(x)
ry = real(y)
pick_x = select(eq(rx, ry), lax_cmp_pick_x(imag(x), imag(y)),
lax_cmp_pick_x(rx, ry))
return select(pick_x, x, y)
def _minmax_translation_rule(c, x, y, *, op_minmax=None, lax_cmp_pick_x=None):
dtype = c.get_shape(x).numpy_dtype()
if dtypes.issubdtype(dtype, np.complexfloating):
rx = xops.Real(x)
ry = xops.Real(y)
return _broadcasting_select(
c, xops.Select(xops.Eq(rx, ry), cmp(xops.Imag(x), xops.Imag(y)),
cmp(rx, ry)),
x, y)
return minmax(x, y)
return xla.lower_fun(partial(_minmax_complex_lowering,
lax_cmp_pick_x=lax_cmp_pick_x),
multiple_results=False)(c, x, y)
else:
return op_minmax(x, y)
max_p: core.Primitive = standard_naryop(
[_any, _any], 'max', translation_rule=partial(
_minmax_translation_rule, minmax=xops.Max, cmp=xops.Gt))
_minmax_translation_rule, op_minmax=xops.Max, lax_cmp_pick_x=gt))
ad.defjvp2(max_p,
lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
min_p: core.Primitive = standard_naryop(
[_any, _any], 'min', translation_rule=partial(
_minmax_translation_rule, minmax=xops.Min, cmp=xops.Lt))
_minmax_translation_rule, op_minmax=xops.Min, lax_cmp_pick_x=lt))
ad.defjvp2(min_p,
lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))

View File

@ -1,10 +1,10 @@
# Primitives with limited JAX support
*Last generated on: 2021-05-17* (YYYY-MM-DD)
*Last generated on: 2021-06-04* (YYYY-MM-DD)
## Supported data types for primitives
We use a set of 2507 test harnesses to test
We use a set of 2570 test harnesses to test
the implementation of 121 numeric JAX primitives.
We consider a JAX primitive supported for a particular data
type if it is supported on at least one device type.
@ -60,7 +60,7 @@ be updated.
| broadcast_in_dim | 19 | all | |
| ceil | 4 | floating | bool, complex, integer |
| cholesky | 30 | inexact | bool, integer |
| clamp | 17 | floating, integer | bool, complex |
| clamp | 20 | all | |
| complex | 4 | float32, float64 | bfloat16, bool, complex, float16, integer |
| concatenate | 17 | all | |
| conj | 5 | complex, float32, float64 | bfloat16, bool, float16, integer |
@ -90,28 +90,28 @@ be updated.
| fft | 20 | complex, float32, float64 | bfloat16, bool, float16, integer |
| floor | 4 | floating | bool, complex, integer |
| gather | 37 | all | |
| ge | 15 | bool, floating, integer | complex |
| gt | 15 | bool, floating, integer | complex |
| ge | 17 | all | |
| gt | 17 | all | |
| igamma | 6 | floating | bool, complex, integer |
| igammac | 6 | floating | bool, complex, integer |
| imag | 2 | complex | bool, floating, integer |
| integer_pow | 108 | inexact, integer | bool |
| iota | 16 | inexact, integer | bool |
| is_finite | 4 | floating | bool, complex, integer |
| le | 15 | bool, floating, integer | complex |
| le | 17 | all | |
| lgamma | 4 | floating | bool, complex, integer |
| log | 6 | inexact | bool, integer |
| log1p | 6 | inexact | bool, integer |
| lt | 15 | bool, floating, integer | complex |
| lt | 17 | all | |
| lu | 18 | inexact | bool, integer |
| max | 29 | all | |
| min | 29 | all | |
| max | 33 | all | |
| min | 33 | all | |
| mul | 16 | inexact, integer | bool |
| ne | 17 | all | |
| neg | 14 | inexact, integer | bool |
| nextafter | 6 | floating | bool, complex, integer |
| or | 11 | bool, integer | inexact |
| pad | 90 | all | |
| pad | 120 | all | |
| population_count | 8 | integer | bool, inexact |
| pow | 10 | inexact | bool, integer |
| qr | 60 | inexact | bool, integer |
@ -144,7 +144,7 @@ be updated.
| shift_left | 10 | integer | bool, inexact |
| shift_right_arithmetic | 10 | integer | bool, inexact |
| shift_right_logical | 10 | integer | bool, inexact |
| sign | 14 | inexact, integer | bool |
| sign | 28 | inexact, integer | bool |
| sin | 6 | inexact | bool, integer |
| sinh | 6 | inexact | bool, integer |
| slice | 24 | all | |
@ -184,6 +184,7 @@ and search for "limitation".
| Affected primitive | Description of limitation | Affected dtypes | Affected devices |
| --- | --- | --- | --- |
|cholesky|unimplemented|float16|cpu, gpu|
|clamp|unimplemented|bool, complex|cpu, gpu, tpu|
|conv_general_dilated|preferred_element_type not implemented for integers|int16, int32, int8|gpu|
|conv_general_dilated|preferred_element_type=c128 not implemented|complex64|tpu|
|conv_general_dilated|preferred_element_type=f64 not implemented|bfloat16, float16, float32|tpu|

View File

@ -1,6 +1,6 @@
# Primitives with limited support for jax2tf
*Last generated on (YYYY-MM-DD): 2021-06-01*
*Last generated on (YYYY-MM-DD): 2021-06-04*
This document summarizes known limitations of the jax2tf conversion.
There are several kinds of limitations.
@ -34,10 +34,7 @@ On TPU only the "compiled" mode is relevant.
Our priority is to ensure same coverage and numerical behavior with JAX
in the "compiled" mode, **when using XLA to compile the converted program**.
We are pretty close to that goal. In addition to a few loose ends, there is a known
coverage problem due to JAX and XLA supporting inequality comparisons and min/max for
booleans and complex numbers. It is not clear that TensorFlow will be extended to
support these.
We are pretty close to that goal.
This table only shows errors for cases that are working in JAX (see [separate
list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) )
@ -67,6 +64,7 @@ More detailed information can be found in the
| cholesky | TF test skipped: Not implemented in JAX: unimplemented | float16 | cpu, gpu | compiled, eager, graph |
| cholesky | TF error: function not compilable | complex | cpu, gpu | compiled |
| cholesky | TF error: op not defined for dtype | complex | tpu | compiled, graph |
| clamp | TF test skipped: Not implemented in JAX: unimplemented | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
| clamp | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
| conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type not implemented for integers | int16, int32, int8 | gpu | compiled, eager, graph |
| conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type=c128 not implemented | complex64 | tpu | compiled, eager, graph |
@ -104,17 +102,16 @@ More detailed information can be found in the
| lt | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| lu | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
| lu | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
| max | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
| min | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
| max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| neg | TF error: op not defined for dtype | unsigned | cpu, gpu, tpu | compiled, eager, graph |
| nextafter | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
| qr | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph |
| qr | TF error: op not defined for dtype | bfloat16 | tpu | compiled, eager, graph |
| reduce_max | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
| reduce_min | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
| reduce_window_max | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
| reduce_window_min | TF error: op not defined for dtype | uint64 | cpu, gpu | eager |
| reduce_window_min | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
| reduce_window_max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| reduce_window_min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| regularized_incomplete_beta | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
| rem | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
| round | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
@ -122,9 +119,9 @@ More detailed information can be found in the
| scatter_add | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_add | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
| scatter_max | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
| scatter_max | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
| scatter_max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_min | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
| scatter_min | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
| scatter_min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_mul | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_mul | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
| select_and_gather_add | TF error: jax2tf unimplemented for 64-bit inputs because the current implementation relies on packing two values into a single value. This can be fixed by using a variadic XlaReduceWindow, when available | float64 | cpu, gpu | compiled, eager, graph |

View File

@ -34,10 +34,7 @@ On TPU only the "compiled" mode is relevant.
Our priority is to ensure same coverage and numerical behavior with JAX
in the "compiled" mode, **when using XLA to compile the converted program**.
We are pretty close to that goal. In addition to a few loose ends, there is a known
coverage problem due to JAX and XLA supporting inequality comparisons and min/max for
booleans and complex numbers. It is not clear that TensorFlow will be extended to
support these.
We are pretty close to that goal.
This table only shows errors for cases that are working in JAX (see [separate
list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) )

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental module transforms JAX functions to be executed by TensorFlow."""
import functools
from functools import partial
import re
import string
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
@ -116,7 +116,7 @@ def _xla_disabled_error(primitive_name: str,
msg += f" {extra_msg}"
return NotImplementedError(msg)
@functools.partial(api_util.api_hook, tag="jax2tf_convert")
@partial(api_util.api_hook, tag="jax2tf_convert")
def convert(fun: Callable,
*,
polymorphic_shapes: Optional[Sequence[Any]] = None,
@ -293,8 +293,7 @@ def convert(fun: Callable,
out_with_avals = _interpret_fun(flat_fun, args_flat, args_avals_flat)
outs, out_avals = util.unzip2(out_with_avals)
return (tuple(outs),
functools.partial(
converted_grad_fn, _out_cts_avals=tuple(out_avals)))
partial(converted_grad_fn, _out_cts_avals=tuple(out_avals)))
out_flat = converted_fun_flat_with_custom_gradient(*args_flat)
else:
@ -828,7 +827,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
for unexpected in xla.call_translations: # Call primitives are inlined
if unexpected is pjit.pjit_p:
continue
tf_impl[unexpected] = functools.partial(_unexpected_primitive, unexpected)
tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)
# Primitives that are not yet implemented must be explicitly declared here.
tf_not_yet_impl = [
@ -1045,8 +1044,30 @@ def _rem(lhs, rhs):
tf_impl[lax.div_p] = _div
tf_impl[lax.rem_p] = _rem
tf_impl[lax.max_p] = tf.math.maximum
tf_impl[lax.min_p] = tf.math.minimum
def _minmax(x: TfVal, y: TfVal, *, is_min: bool,
_in_avals: Sequence[core.AbstractValue],
_out_aval: core.AbstractValue,) -> TfVal:
# For complex numbers use lexicographic ordering, like JAX
if dtypes.issubdtype(x.dtype.as_numpy_dtype, np.complexfloating):
return _convert_jax_impl(
partial(lax._minmax_complex_lowering,
lax_cmp_pick_x=lax.lt if is_min else lax.gt),
multiple_results=False)(x, y, _in_avals=_in_avals, _out_aval=_out_aval)
else:
return (tf.math.minimum if is_min else tf.math.maximum)(x, y)
def _minmax_scalar(x: TfVal, y: TfVal, *, is_min: bool) -> TfVal:
# For reducers we will need min/max for scalars only. In that case we
# can construct the AbstractValues ourselves, even in the presence of
# shape polymorphism.
assert len(x.shape) == 0 and len(y.shape) == 0, f"x: {x.shape}, y: {y.shape}"
aval = core.ShapedArray((), _to_jax_dtype(x.dtype))
return _minmax(x, y, is_min=is_min,
_in_avals=[aval, aval], _out_aval=aval)
tf_impl_with_avals[lax.max_p] = partial(_minmax, is_min=False)
tf_impl_with_avals[lax.min_p] = partial(_minmax, is_min=True)
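# Editor's illustration (hypothetical helper, not part of this diff): a hedged
# sketch of the end-to-end behavior wired up above, assuming TensorFlow and
# jax2tf are importable in the environment.
def _demo_complex_maximum_via_tf():  # hypothetical name, for illustration only
  import tensorflow as tf
  import jax.numpy as jnp
  from jax.experimental import jax2tf
  f_tf = jax2tf.convert(jnp.maximum)
  # Real parts tie, so the larger imaginary part wins under the lexicographic
  # ordering that _minmax reuses from lax._minmax_complex_lowering.
  return f_tf(tf.constant(1 + 2j, tf.complex64),
              tf.constant(1 + 3j, tf.complex64))  # expected: 1+3j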
# Map from TF signed types to TF unsigned types.
_SIGNED_TO_UNSIGNED_TABLE = {
@ -1659,8 +1680,8 @@ def _argminmax(fn, operand, axes, index_dtype):
return tf.cast(result, _to_tf_dtype(index_dtype))
tf_impl[lax.argmin_p] = functools.partial(_argminmax, tf.math.argmin)
tf_impl[lax.argmax_p] = functools.partial(_argminmax, tf.math.argmax)
tf_impl[lax.argmin_p] = partial(_argminmax, tf.math.argmin)
tf_impl[lax.argmax_p] = partial(_argminmax, tf.math.argmax)
_add_fn = tf.function(_add, autograph=False)
_ge_fn = tf.function(tf.math.greater_equal, autograph=False)
@ -1947,21 +1968,18 @@ def _get_min_identity(tf_dtype):
# pylint: disable=protected-access
tf_impl_with_avals[lax.reduce_window_sum_p] = (
functools.partial(
_specialized_reduce_window, _add, lambda x: 0,
name="reduce_window_sum"))
partial(_specialized_reduce_window, _add, lambda x: 0,
name="reduce_window_sum"))
tf_impl_with_avals[lax.reduce_window_min_p] = (
functools.partial(
_specialized_reduce_window,
tf.math.minimum,
_get_min_identity,
name="reduce_window_min"))
partial(_specialized_reduce_window,
partial(_minmax_scalar, is_min=True),
_get_min_identity,
name="reduce_window_min"))
tf_impl_with_avals[lax.reduce_window_max_p] = (
functools.partial(
_specialized_reduce_window,
tf.math.maximum,
_get_max_identity,
name="reduce_window_max"))
partial(_specialized_reduce_window,
partial(_minmax_scalar, is_min=False),
_get_max_identity,
name="reduce_window_max"))
tf_impl_with_avals[lax.reduce_window_p] = _reduce_window
# pylint: enable=protected-access
@ -1970,11 +1988,11 @@ tf_impl_with_avals[lax.reduce_window_p] = _reduce_window
# O(n^2) on other backends. This may be implemented using associative_scan
# instead to favor different backends.
tf_impl_with_avals[lax_control_flow.cummin_p] = _convert_jax_impl(
functools.partial(lax_control_flow._cumred_tpu_translation_rule,
partial(lax_control_flow._cumred_tpu_translation_rule,
lax._reduce_window_min),
multiple_results=False)
tf_impl_with_avals[lax_control_flow.cummax_p] = _convert_jax_impl(
functools.partial(lax_control_flow._cumred_tpu_translation_rule,
partial(lax_control_flow._cumred_tpu_translation_rule,
lax._reduce_window_max),
multiple_results=False)
# TODO(bchetioui): cumsum and cumprod can be converted using pure TF ops for
@ -1983,11 +2001,11 @@ tf_impl_with_avals[lax_control_flow.cummax_p] = _convert_jax_impl(
# the operation. A non-XLA path can thus be defined for all dtypes, though the
# tests will crash.
tf_impl_with_avals[lax_control_flow.cumsum_p] = _convert_jax_impl(
functools.partial(lax_control_flow._cumred_tpu_translation_rule,
partial(lax_control_flow._cumred_tpu_translation_rule,
lax._reduce_window_sum),
multiple_results=False)
tf_impl_with_avals[lax_control_flow.cumprod_p] = _convert_jax_impl(
functools.partial(lax_control_flow._cumred_tpu_translation_rule,
partial(lax_control_flow._cumred_tpu_translation_rule,
lax._reduce_window_prod),
multiple_results=False)
@ -2001,7 +2019,7 @@ def _select_and_scatter(operand, source, init_value, select_jaxpr,
tf_impl[lax.select_and_scatter_p] = _select_and_scatter
@functools.partial(bool_to_int8, argnums=(0, 1))
@partial(bool_to_int8, argnums=(0, 1))
def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
window_strides, padding, _in_avals, _out_aval):
if not _enable_xla:
@ -2023,8 +2041,7 @@ tf_impl_with_avals[lax.select_and_scatter_add_p] = _select_and_scatter_add
def _threefry2x32_jax_impl(*args: TfVal, _in_avals, _out_aval):
res = _convert_jax_impl(
functools.partial(
jax._src.random._threefry2x32_lowering, use_rolled_loops=False),
partial(jax._src.random._threefry2x32_lowering, use_rolled_loops=False),
multiple_results=True)(
*args, _in_avals=_in_avals, _out_aval=_out_aval)
return res
@ -2035,7 +2052,7 @@ tf_impl_with_avals[jax.random.threefry2x32_p] = _threefry2x32_jax_impl
# Use the vmap implementation, otherwise on TPU the performance is really bad
# With use_vmap=True on, we get about the same performance for JAX and jax2tf.
tf_impl_with_avals[random.random_gamma_p] = _convert_jax_impl(
functools.partial(jax._src.random._gamma_impl, use_vmap=True),
partial(jax._src.random._gamma_impl, use_vmap=True),
multiple_results=False)
@ -2049,7 +2066,7 @@ def _gather_dimensions_proto(indices_shape, dimension_numbers):
return proto
@functools.partial(bool_to_int8, argnums=0)
@partial(bool_to_int8, argnums=0)
def _gather(operand, start_indices, *, dimension_numbers, slice_sizes,
_in_avals, _out_aval):
"""Tensorflow implementation of gather."""
@ -2171,7 +2188,7 @@ def _cond(index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr],
del linear
# tf.cond needs lambdas with no arguments.
branches_tf = [
functools.partial(_interpret_jaxpr, jaxpr, *operands)
partial(_interpret_jaxpr, jaxpr, *operands)
for jaxpr in branches
]
return tf.switch_case(index, branches_tf)
@ -2198,7 +2215,7 @@ def _while(*args: TfVal, cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr,
pred, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *args)
return pred
body_tf_func = functools.partial(_interpret_jaxpr, body_jaxpr, *body_consts)
body_tf_func = partial(_interpret_jaxpr, body_jaxpr, *body_consts)
return tf.while_loop(cond_tf_func, body_tf_func, init_carry)
@ -2586,7 +2603,7 @@ def _pjit(*args: TfVal,
_out_aval: core.ShapedArray) -> TfVal:
del donated_invars, name
# TODO: add `name` to the name stack
shard_value_for_mesh = functools.partial(_shard_value, resource_env.physical_mesh)
shard_value_for_mesh = partial(_shard_value, resource_env.physical_mesh)
# Apply sharding annotation to the arguments
sharded_args: Sequence[TfVal] = tuple(
map(shard_value_for_mesh, args, _in_avals, in_axis_resources))

View File

@ -882,7 +882,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
return [
missing_tf_kernel(
dtypes=[np.bool_, np.complex64, np.complex128]),
dtypes=[np.bool_]),
custom_numeric(
custom_assert=custom_assert,
description=(
@ -901,7 +901,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
tst.assertAllClose(result_jax[~mask], result_tf[~mask], err_msg=err_msg)
return [
missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
missing_tf_kernel(dtypes=[np.bool_]),
custom_numeric(
custom_assert=custom_assert,
description=(
@ -966,9 +966,8 @@ class Jax2TfLimitation(primitive_harness.Limitation):
@classmethod
def reduce_min(cls, harness: primitive_harness.Harness):
return [
missing_tf_kernel(dtypes=[np.complex64, np.complex128]),
]
return cls.reduce_max(harness)
@classmethod
def reduce_window_add(cls, harness):
@ -979,14 +978,14 @@ class Jax2TfLimitation(primitive_harness.Limitation):
def reduce_window_max(cls, harness):
assert "max" == harness.params["computation"].__name__
return [
missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
missing_tf_kernel(dtypes=[np.bool_]),
]
@classmethod
def reduce_window_min(cls, harness):
assert "min" == harness.params["computation"].__name__
return [
missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
missing_tf_kernel(dtypes=[np.bool_]),
]
@classmethod
@ -1045,13 +1044,13 @@ class Jax2TfLimitation(primitive_harness.Limitation):
@classmethod
def scatter_max(cls, harness):
return [
missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
missing_tf_kernel(dtypes=[np.bool_]),
]
@classmethod
def scatter_min(cls, harness):
return [
missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
missing_tf_kernel(dtypes=[np.bool_]),
]
@classmethod

View File

@ -718,32 +718,36 @@ for dtype in jtu.dtypes.all:
shape=shape,
dtype=dtype)
_LAX_COMPARATORS = (lax.eq_p, lax.ge_p, lax.gt_p, lax.le_p, lax.lt_p, lax.ne_p)
_LAX_COMPARATORS = dict(eq=jnp.equal, ne=jnp.not_equal,
ge=jnp.greater_equal, gt=jnp.greater,
le=jnp.less_equal, lt=jnp.less)
def _make_comparator_harness(name,
*,
dtype=np.float32,
op=lax.eq_p,
op_name="eq",
lhs_shape=(),
rhs_shape=()):
define(
op.name,
op_name,
f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}",
lambda *args: op.bind(*args),
lambda *args: op(*args),
[RandArg(lhs_shape, dtype),
RandArg(rhs_shape, dtype)],
op=op,
op_name=op_name,
lhs_shape=lhs_shape,
rhs_shape=rhs_shape,
dtype=dtype)
for op in _LAX_COMPARATORS:
for op_name, op in _LAX_COMPARATORS.items():
for dtype in (jtu.dtypes.all if op in [lax.eq_p, lax.ne_p] else
set(jtu.dtypes.all) - set(jtu.dtypes.complex)):
set(jtu.dtypes.all)):
# Validate dtypes
_make_comparator_harness("dtypes", dtype=dtype, op=op)
_make_comparator_harness("dtypes", dtype=dtype, op=op, op_name=op_name)
# Validate broadcasting behavior
for lhs_shape, rhs_shape in [
@ -751,7 +755,8 @@ for op in _LAX_COMPARATORS:
((1, 2), (3, 2)), # broadcast along specific axis
]:
_make_comparator_harness(
"broadcasting", lhs_shape=lhs_shape, rhs_shape=rhs_shape, op=op)
"broadcasting", lhs_shape=lhs_shape, rhs_shape=rhs_shape,
op=op, op_name=op_name)
for dtype in jtu.dtypes.all:
shape = (3, 4, 5)
@ -917,6 +922,7 @@ for prim in [lax.div_p, lax.rem_p]:
def _make_binary_elementwise_harnesses(prim,
dtypes,
default_dtype=np.float32,
broadcasting_dtypes=None,
jax_unimplemented=lambda **kwargs: []):
def _make(name, *, shapes=((20, 20), (20, 20)), dtype):
@ -931,15 +937,18 @@ def _make_binary_elementwise_harnesses(prim,
prim=prim,
dtype=dtype,
shapes=shapes)
return (tuple( # Validate dtypes
_make("dtypes", dtype=dtype)
for dtype in dtypes) + tuple( # Validate broadcasting
_make("broadcasting", dtype=default_dtype, shapes=shapes)
for shapes in [
broadcasting_dtypes = broadcasting_dtypes or (default_dtype,)
return (
# Validate dtypes
tuple(_make("dtypes", dtype=dtype) for dtype in dtypes) +
# Validate broadcasting
tuple(_make("broadcasting", dtype=dtype, shapes=shapes)
for shapes in [
((20, 20), (1, 20)), # broadcasting rhs
((1, 20), (20, 20)), # broadcasting lhs
]))
]
for dtype in broadcasting_dtypes)
)
_make_binary_elementwise_harnesses(
@ -1004,7 +1013,9 @@ _min_max_special_cases = tuple(
(np.array([-np.inf, -np.inf], dtype=dtype),
np.array([np.nan, np.nan], dtype=dtype))])
_make_binary_elementwise_harnesses(prim=lax.min_p, dtypes=jtu.dtypes.all)
_make_binary_elementwise_harnesses(
prim=lax.min_p, dtypes=jtu.dtypes.all,
broadcasting_dtypes=(np.float32, np.complex64, np.complex128))
# Validate special cases
for lhs, rhs in _min_max_special_cases:
define(
@ -1014,7 +1025,9 @@ for lhs, rhs in _min_max_special_cases:
prim=lax.min_p,
dtype=lhs.dtype)
_make_binary_elementwise_harnesses(prim=lax.max_p, dtypes=jtu.dtypes.all)
_make_binary_elementwise_harnesses(
prim=lax.max_p, dtypes=jtu.dtypes.all,
broadcasting_dtypes=(np.float32, np.complex64, np.complex128))
# Validate special cases
for lhs, rhs in _min_max_special_cases:
define(
@ -2336,10 +2349,15 @@ def _make_clamp_harness(name,
min_shape=min_arr.shape,
operand_shape=operand_shape,
max_shape=max_arr.shape,
dtype=dtype)
dtype=dtype,
jax_unimplemented=[
Limitation(
"unimplemented",
dtypes=[np.bool_, np.complex64, np.complex128])],
)
for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex + [np.bool_]):
for dtype in set(jtu.dtypes.all):
_make_clamp_harness("dtypes", dtype=dtype)
# Validate broadcasting of min/max arrays

View File

@ -99,8 +99,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
# If you want to run this test for only one harness, add parameter
# `one_containing="foo"` to parameterized below.
@primitive_harness.parameterized(
primitive_harness.all_harnesses, include_jax_unimpl=False,
)
primitive_harness.all_harnesses, include_jax_unimpl=False)
@jtu.ignore_warning(
category=UserWarning, message="Using reduced precision for gradient.*")
def test_prim(self, harness: primitive_harness.Harness):