[jax2tf] Implement inequalities and friends for complex numbers.
This requires re-using JAX's lowering rule for comparisons of complex numbers, which orders them lexicographically (compare the real parts first, then break ties on the imaginary parts).
This commit is contained in:
parent de9f55720d
commit d243258b86
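As a reading aid, and not part of the commit, here is a minimal NumPy sketch of the lexicographic ordering mentioned above: compare real parts first, and break ties on the imaginary parts. The helper name complex_lt and the sample arrays are made up for illustration.

import numpy as np

def complex_lt(x, y):
  # Lexicographic "less than" for complex values: order by real part,
  # then by imaginary part when the real parts are equal.
  rx, ry = np.real(x), np.real(y)
  return np.where(rx == ry, np.imag(x) < np.imag(y), rx < ry)

a = np.array([1 + 2j, 3 + 0j])
b = np.array([1 + 5j, 2 + 9j])
print(complex_lt(a, b))  # [ True False]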
@@ -16,6 +16,9 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
   tracebacks.
 * A new traceback filtering mode using `__tracebackhide__` is now enabled by
   default in sufficiently recent versions of IPython.
+* The {func}`jax2tf.convert` supports shape polymorphism even when the
+  unknown dimensions are used in arithmetic operations, e.g., `jnp.reshape(-1)`
+  ({jax-issue}`#6827`).
 
 * Breaking changes:
 
@@ -31,6 +34,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
 * The {func}`jax2tf.convert` now converts `lax.dot_general` using the
   `XlaDot` TensorFlow op, for better fidelity w.r.t. JAX numerical precision
   ({jax-issue}`#6717`).
+* The {func}`jax2tf.convert` now has support for inequality comparisons and
+  min/max for complex numbers ({jax-issue}`#6892`).
 
 ## jaxlib 0.1.67 (unreleased)
 
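A hedged usage sketch of the jax2tf entry above, assuming jax, tensorflow, and numpy are installed; the wrapped function and the inputs are illustrative, not taken from the commit:

import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

# Converting an inequality on complex inputs; jax2tf now follows JAX's
# lexicographic ordering instead of rejecting the dtype.
less_tf = jax2tf.convert(jnp.less)
x = np.array([1 + 2j, 3 + 0j], dtype=np.complex64)
y = np.array([1 + 5j, 2 + 9j], dtype=np.complex64)
print(less_tf(tf.constant(x), tf.constant(y)))  # expect [True, False]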
@@ -2783,27 +2783,35 @@ def _broadcasting_select(c, which, x, y):
   return xops.Select(which, x, y)
 
 
-def _minmax_translation_rule(c, x, y, *, minmax=None, cmp=None):
+def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x):
+  result_shape = broadcast_shapes(np.shape(x), np.shape(y))
+  x = _maybe_broadcast(result_shape, x)
+  y = _maybe_broadcast(result_shape, y)
+  rx = real(x)
+  ry = real(y)
+  pick_x = select(eq(rx, ry), lax_cmp_pick_x(imag(x), imag(y)),
+                  lax_cmp_pick_x(rx, ry))
+  return select(pick_x, x, y)
+
+def _minmax_translation_rule(c, x, y, *, op_minmax=None, lax_cmp_pick_x=None):
   dtype = c.get_shape(x).numpy_dtype()
   if dtypes.issubdtype(dtype, np.complexfloating):
-    rx = xops.Real(x)
-    ry = xops.Real(y)
-    return _broadcasting_select(
-        c, xops.Select(xops.Eq(rx, ry), cmp(xops.Imag(x), xops.Imag(y)),
-                       cmp(rx, ry)),
-        x, y)
-  return minmax(x, y)
+    return xla.lower_fun(partial(_minmax_complex_lowering,
+                                 lax_cmp_pick_x=lax_cmp_pick_x),
+                         multiple_results=False)(c, x, y)
+  else:
+    return op_minmax(x, y)
 
 max_p: core.Primitive = standard_naryop(
   [_any, _any], 'max', translation_rule=partial(
-    _minmax_translation_rule, minmax=xops.Max, cmp=xops.Gt))
+    _minmax_translation_rule, op_minmax=xops.Max, lax_cmp_pick_x=gt))
 ad.defjvp2(max_p,
            lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
            lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
 
 min_p: core.Primitive = standard_naryop(
   [_any, _any], 'min', translation_rule=partial(
-    _minmax_translation_rule, minmax=xops.Min, cmp=xops.Lt))
+    _minmax_translation_rule, op_minmax=xops.Min, lax_cmp_pick_x=lt))
 ad.defjvp2(min_p,
            lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
            lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
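For context, and not part of the diff, the effect of the lowering above can be observed directly from jax.numpy; the values below are illustrative only:

import jax.numpy as jnp

x = jnp.array([1 + 2j, 2 + 1j])
y = jnp.array([1 + 3j, 1 + 9j])
# Equal real parts: the imaginary part decides; otherwise the real part decides.
print(jnp.maximum(x, y))  # expect [1+3j, 2+1j]
print(jnp.minimum(x, y))  # expect [1+2j, 1+9j]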
@@ -1,10 +1,10 @@
 # Primitives with limited JAX support
 
-*Last generated on: 2021-05-17* (YYYY-MM-DD)
+*Last generated on: 2021-06-04* (YYYY-MM-DD)
 
 ## Supported data types for primitives
 
-We use a set of 2507 test harnesses to test
+We use a set of 2570 test harnesses to test
 the implementation of 121 numeric JAX primitives.
 We consider a JAX primitive supported for a particular data
 type if it is supported on at least one device type.
@@ -60,7 +60,7 @@ be updated.
 | broadcast_in_dim | 19 | all | |
 | ceil | 4 | floating | bool, complex, integer |
 | cholesky | 30 | inexact | bool, integer |
-| clamp | 17 | floating, integer | bool, complex |
+| clamp | 20 | all | |
 | complex | 4 | float32, float64 | bfloat16, bool, complex, float16, integer |
 | concatenate | 17 | all | |
 | conj | 5 | complex, float32, float64 | bfloat16, bool, float16, integer |
@@ -90,28 +90,28 @@ be updated.
 | fft | 20 | complex, float32, float64 | bfloat16, bool, float16, integer |
 | floor | 4 | floating | bool, complex, integer |
 | gather | 37 | all | |
-| ge | 15 | bool, floating, integer | complex |
-| gt | 15 | bool, floating, integer | complex |
+| ge | 17 | all | |
+| gt | 17 | all | |
 | igamma | 6 | floating | bool, complex, integer |
 | igammac | 6 | floating | bool, complex, integer |
 | imag | 2 | complex | bool, floating, integer |
 | integer_pow | 108 | inexact, integer | bool |
 | iota | 16 | inexact, integer | bool |
 | is_finite | 4 | floating | bool, complex, integer |
-| le | 15 | bool, floating, integer | complex |
+| le | 17 | all | |
 | lgamma | 4 | floating | bool, complex, integer |
 | log | 6 | inexact | bool, integer |
 | log1p | 6 | inexact | bool, integer |
-| lt | 15 | bool, floating, integer | complex |
+| lt | 17 | all | |
 | lu | 18 | inexact | bool, integer |
-| max | 29 | all | |
-| min | 29 | all | |
+| max | 33 | all | |
+| min | 33 | all | |
 | mul | 16 | inexact, integer | bool |
 | ne | 17 | all | |
 | neg | 14 | inexact, integer | bool |
 | nextafter | 6 | floating | bool, complex, integer |
 | or | 11 | bool, integer | inexact |
-| pad | 90 | all | |
+| pad | 120 | all | |
 | population_count | 8 | integer | bool, inexact |
 | pow | 10 | inexact | bool, integer |
 | qr | 60 | inexact | bool, integer |
@@ -144,7 +144,7 @@ be updated.
 | shift_left | 10 | integer | bool, inexact |
 | shift_right_arithmetic | 10 | integer | bool, inexact |
 | shift_right_logical | 10 | integer | bool, inexact |
-| sign | 14 | inexact, integer | bool |
+| sign | 28 | inexact, integer | bool |
 | sin | 6 | inexact | bool, integer |
 | sinh | 6 | inexact | bool, integer |
 | slice | 24 | all | |
@@ -184,6 +184,7 @@ and search for "limitation".
 | Affected primitive | Description of limitation | Affected dtypes | Affected devices |
 | --- | --- | --- | --- |
 |cholesky|unimplemented|float16|cpu, gpu|
+|clamp|unimplemented|bool, complex|cpu, gpu, tpu|
 |conv_general_dilated|preferred_element_type not implemented for integers|int16, int32, int8|gpu|
 |conv_general_dilated|preferred_element_type=c128 not implemented|complex64|tpu|
 |conv_general_dilated|preferred_element_type=f64 not implemented|bfloat16, float16, float32|tpu|
@@ -1,6 +1,6 @@
 # Primitives with limited support for jax2tf
 
-*Last generated on (YYYY-MM-DD): 2021-06-01*
+*Last generated on (YYYY-MM-DD): 2021-06-04*
 
 This document summarizes known limitations of the jax2tf conversion.
 There are several kinds of limitations.
@@ -34,10 +34,7 @@ On TPU only the "compiled" mode is relevant.
 
 Our priority is to ensure same coverage and numerical behavior with JAX
 in the "compiled" mode, **when using XLA to compile the converted program**.
-We are pretty close to that goal. In addition to a few loose ends, there is a known
-coverage problem due to JAX and XLA supporting inequality comparisons and min/max for
-booleans and complex numbers. It is not clear that TensorFlow will be extended to
-support these.
+We are pretty close to that goal.
 
 This table only shows errors for cases that are working in JAX (see [separate
 list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) )
@@ -67,6 +64,7 @@ More detailed information can be found in the
 | cholesky | TF test skipped: Not implemented in JAX: unimplemented | float16 | cpu, gpu | compiled, eager, graph |
 | cholesky | TF error: function not compilable | complex | cpu, gpu | compiled |
 | cholesky | TF error: op not defined for dtype | complex | tpu | compiled, graph |
+| clamp | TF test skipped: Not implemented in JAX: unimplemented | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
 | clamp | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
 | conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type not implemented for integers | int16, int32, int8 | gpu | compiled, eager, graph |
 | conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type=c128 not implemented | complex64 | tpu | compiled, eager, graph |
@@ -104,17 +102,16 @@ More detailed information can be found in the
 | lt | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | lu | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
 | lu | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
-| max | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
-| min | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
+| max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
+| min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | neg | TF error: op not defined for dtype | unsigned | cpu, gpu, tpu | compiled, eager, graph |
 | nextafter | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
 | qr | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph |
 | qr | TF error: op not defined for dtype | bfloat16 | tpu | compiled, eager, graph |
 | reduce_max | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
 | reduce_min | TF error: op not defined for dtype | complex | cpu, gpu, tpu | compiled, eager, graph |
-| reduce_window_max | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
-| reduce_window_min | TF error: op not defined for dtype | uint64 | cpu, gpu | eager |
-| reduce_window_min | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
+| reduce_window_max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
+| reduce_window_min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | regularized_incomplete_beta | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
 | rem | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
 | round | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
@@ -122,9 +119,9 @@ More detailed information can be found in the
 | scatter_add | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | scatter_add | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
 | scatter_max | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
-| scatter_max | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
+| scatter_max | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | scatter_min | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
-| scatter_min | TF error: op not defined for dtype | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
+| scatter_min | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | scatter_mul | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
 | scatter_mul | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
 | select_and_gather_add | TF error: jax2tf unimplemented for 64-bit inputs because the current implementation relies on packing two values into a single value. This can be fixed by using a variadic XlaReduceWindow, when available | float64 | cpu, gpu | compiled, eager, graph |

@@ -34,10 +34,7 @@ On TPU only the "compiled" mode is relevant.
 
 Our priority is to ensure same coverage and numerical behavior with JAX
 in the "compiled" mode, **when using XLA to compile the converted program**.
-We are pretty close to that goal. In addition to a few loose ends, there is a known
-coverage problem due to JAX and XLA supporting inequality comparisons and min/max for
-booleans and complex numbers. It is not clear that TensorFlow will be extended to
-support these.
+We are pretty close to that goal.
 
 This table only shows errors for cases that are working in JAX (see [separate
 list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/master/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) )
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Experimental module transforms JAX functions to be executed by TensorFlow."""
-import functools
+from functools import partial
 import re
 import string
 from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
@@ -116,7 +116,7 @@ def _xla_disabled_error(primitive_name: str,
     msg += f" {extra_msg}"
   return NotImplementedError(msg)
 
-@functools.partial(api_util.api_hook, tag="jax2tf_convert")
+@partial(api_util.api_hook, tag="jax2tf_convert")
 def convert(fun: Callable,
             *,
             polymorphic_shapes: Optional[Sequence[Any]] = None,
@@ -293,8 +293,7 @@ def convert(fun: Callable,
       out_with_avals = _interpret_fun(flat_fun, args_flat, args_avals_flat)
       outs, out_avals = util.unzip2(out_with_avals)
       return (tuple(outs),
-              functools.partial(
-                  converted_grad_fn, _out_cts_avals=tuple(out_avals)))
+              partial(converted_grad_fn, _out_cts_avals=tuple(out_avals)))
 
     out_flat = converted_fun_flat_with_custom_gradient(*args_flat)
   else:
@@ -828,7 +827,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
 for unexpected in xla.call_translations:  # Call primitives are inlined
   if unexpected is pjit.pjit_p:
     continue
-  tf_impl[unexpected] = functools.partial(_unexpected_primitive, unexpected)
+  tf_impl[unexpected] = partial(_unexpected_primitive, unexpected)
 
 # Primitives that are not yet implemented must be explicitly declared here.
 tf_not_yet_impl = [
@@ -1045,8 +1044,30 @@ def _rem(lhs, rhs):
 tf_impl[lax.div_p] = _div
 tf_impl[lax.rem_p] = _rem
 
-tf_impl[lax.max_p] = tf.math.maximum
-tf_impl[lax.min_p] = tf.math.minimum
+
+def _minmax(x: TfVal, y: TfVal, *, is_min: bool,
+            _in_avals: Sequence[core.AbstractValue],
+            _out_aval: core.AbstractValue,) -> TfVal:
+  # For complex numbers use lexicographic ordering, like JAX
+  if dtypes.issubdtype(x.dtype.as_numpy_dtype, np.complexfloating):
+    return _convert_jax_impl(
+        partial(lax._minmax_complex_lowering,
+                lax_cmp_pick_x=lax.lt if is_min else lax.gt),
+        multiple_results=False)(x, y, _in_avals=_in_avals, _out_aval=_out_aval)
+  else:
+    return (tf.math.minimum if is_min else tf.math.maximum)(x, y)
+
+def _minmax_scalar(x: TfVal, y: TfVal, *, is_min: bool) -> TfVal:
+  # For reducers we will need min/max for scalars only. In that case we
+  # can construct the AbstractValues outselves, even in the presence of
+  # shape polymorphism.
+  assert len(x.shape) == 0 and len(y.shape) == 0, f"x: {x.shape}, y: {y.shape}"
+  aval = core.ShapedArray((), _to_jax_dtype(x.dtype))
+  return _minmax(x, y, is_min=is_min,
+                 _in_avals=[aval, aval], _out_aval=aval)
+
+tf_impl_with_avals[lax.max_p] = partial(_minmax, is_min=False)
+tf_impl_with_avals[lax.min_p] = partial(_minmax, is_min=True)
 
 # Map from TF signed types to TF unsigned types.
 _SIGNED_TO_UNSIGNED_TABLE = {
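A hedged end-to-end sketch of the new _minmax path above, assuming jax and tensorflow are installed; the converted function and the inputs are illustrative, not part of the commit:

import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

# max over complex inputs now routes through lax._minmax_complex_lowering
# instead of tf.math.maximum, which has no complex kernel.
max_tf = jax2tf.convert(lambda a, b: jnp.maximum(a, b))
a = np.array([1 + 2j, 2 + 1j], dtype=np.complex64)
b = np.array([1 + 3j, 1 + 9j], dtype=np.complex64)
print(max_tf(tf.constant(a), tf.constant(b)))  # expect [1+3j, 2+1j]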
@@ -1659,8 +1680,8 @@ def _argminmax(fn, operand, axes, index_dtype):
   return tf.cast(result, _to_tf_dtype(index_dtype))
 
 
-tf_impl[lax.argmin_p] = functools.partial(_argminmax, tf.math.argmin)
-tf_impl[lax.argmax_p] = functools.partial(_argminmax, tf.math.argmax)
+tf_impl[lax.argmin_p] = partial(_argminmax, tf.math.argmin)
+tf_impl[lax.argmax_p] = partial(_argminmax, tf.math.argmax)
 
 _add_fn = tf.function(_add, autograph=False)
 _ge_fn = tf.function(tf.math.greater_equal, autograph=False)
@@ -1947,21 +1968,18 @@ def _get_min_identity(tf_dtype):
 
 # pylint: disable=protected-access
 tf_impl_with_avals[lax.reduce_window_sum_p] = (
-    functools.partial(
-        _specialized_reduce_window, _add, lambda x: 0,
-        name="reduce_window_sum"))
+    partial(_specialized_reduce_window, _add, lambda x: 0,
+            name="reduce_window_sum"))
 tf_impl_with_avals[lax.reduce_window_min_p] = (
-    functools.partial(
-        _specialized_reduce_window,
-        tf.math.minimum,
-        _get_min_identity,
-        name="reduce_window_min"))
+    partial(_specialized_reduce_window,
+            partial(_minmax_scalar, is_min=True),
+            _get_min_identity,
+            name="reduce_window_min"))
 tf_impl_with_avals[lax.reduce_window_max_p] = (
-    functools.partial(
-        _specialized_reduce_window,
-        tf.math.maximum,
-        _get_max_identity,
-        name="reduce_window_max"))
+    partial(_specialized_reduce_window,
+            partial(_minmax_scalar, is_min=False),
+            _get_max_identity,
+            name="reduce_window_max"))
 tf_impl_with_avals[lax.reduce_window_p] = _reduce_window
 # pylint: enable=protected-access
 
@@ -1970,11 +1988,11 @@ tf_impl_with_avals[lax.reduce_window_p] = _reduce_window
 # O(n^2) on other backends. This may be implemented using associative_scan
 # instead to favor different backends.
 tf_impl_with_avals[lax_control_flow.cummin_p] = _convert_jax_impl(
-    functools.partial(lax_control_flow._cumred_tpu_translation_rule,
+    partial(lax_control_flow._cumred_tpu_translation_rule,
             lax._reduce_window_min),
     multiple_results=False)
 tf_impl_with_avals[lax_control_flow.cummax_p] = _convert_jax_impl(
-    functools.partial(lax_control_flow._cumred_tpu_translation_rule,
+    partial(lax_control_flow._cumred_tpu_translation_rule,
             lax._reduce_window_max),
     multiple_results=False)
 # TODO(bchetioui): cumsum and cumprod can be converted using pure TF ops for
@@ -1983,11 +2001,11 @@ tf_impl_with_avals[lax_control_flow.cummax_p] = _convert_jax_impl(
 # the operation. A non-XLA path can thus be defined for all dtypes, though the
 # tests will crash.
 tf_impl_with_avals[lax_control_flow.cumsum_p] = _convert_jax_impl(
-    functools.partial(lax_control_flow._cumred_tpu_translation_rule,
+    partial(lax_control_flow._cumred_tpu_translation_rule,
            lax._reduce_window_sum),
    multiple_results=False)
 tf_impl_with_avals[lax_control_flow.cumprod_p] = _convert_jax_impl(
-    functools.partial(lax_control_flow._cumred_tpu_translation_rule,
+    partial(lax_control_flow._cumred_tpu_translation_rule,
           lax._reduce_window_prod),
    multiple_results=False)
 
@@ -2001,7 +2019,7 @@ def _select_and_scatter(operand, source, init_value, select_jaxpr,
 tf_impl[lax.select_and_scatter_p] = _select_and_scatter
 
 
-@functools.partial(bool_to_int8, argnums=(0, 1))
+@partial(bool_to_int8, argnums=(0, 1))
 def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
                             window_strides, padding, _in_avals, _out_aval):
   if not _enable_xla:
@@ -2023,8 +2041,7 @@ tf_impl_with_avals[lax.select_and_scatter_add_p] = _select_and_scatter_add
 
 def _threefry2x32_jax_impl(*args: TfVal, _in_avals, _out_aval):
   res = _convert_jax_impl(
-      functools.partial(
-          jax._src.random._threefry2x32_lowering, use_rolled_loops=False),
+      partial(jax._src.random._threefry2x32_lowering, use_rolled_loops=False),
       multiple_results=True)(
           *args, _in_avals=_in_avals, _out_aval=_out_aval)
   return res
@@ -2035,7 +2052,7 @@ tf_impl_with_avals[jax.random.threefry2x32_p] = _threefry2x32_jax_impl
 # Use the vmap implementation, otherwise on TPU the performance is really bad
 # With use_vmap=True on, we get about the same performance for JAX and jax2tf.
 tf_impl_with_avals[random.random_gamma_p] = _convert_jax_impl(
-    functools.partial(jax._src.random._gamma_impl, use_vmap=True),
+    partial(jax._src.random._gamma_impl, use_vmap=True),
     multiple_results=False)
 
 
@@ -2049,7 +2066,7 @@ def _gather_dimensions_proto(indices_shape, dimension_numbers):
   return proto
 
 
-@functools.partial(bool_to_int8, argnums=0)
+@partial(bool_to_int8, argnums=0)
 def _gather(operand, start_indices, *, dimension_numbers, slice_sizes,
             _in_avals, _out_aval):
   """Tensorflow implementation of gather."""
@@ -2171,7 +2188,7 @@ def _cond(index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr],
   del linear
   # tf.cond needs lambdas with no arguments.
   branches_tf = [
-      functools.partial(_interpret_jaxpr, jaxpr, *operands)
+      partial(_interpret_jaxpr, jaxpr, *operands)
       for jaxpr in branches
   ]
   return tf.switch_case(index, branches_tf)
@@ -2198,7 +2215,7 @@ def _while(*args: TfVal, cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr,
     pred, = _interpret_jaxpr(cond_jaxpr, *cond_consts, *args)
     return pred
 
-  body_tf_func = functools.partial(_interpret_jaxpr, body_jaxpr, *body_consts)
+  body_tf_func = partial(_interpret_jaxpr, body_jaxpr, *body_consts)
   return tf.while_loop(cond_tf_func, body_tf_func, init_carry)
 
 
@@ -2586,7 +2603,7 @@ def _pjit(*args: TfVal,
           _out_aval: core.ShapedArray) -> TfVal:
   del donated_invars, name
   # TODO: add `name` to the name stack
-  shard_value_for_mesh = functools.partial(_shard_value, resource_env.physical_mesh)
+  shard_value_for_mesh = partial(_shard_value, resource_env.physical_mesh)
   # Apply sharding annotation to the arguments
   sharded_args: Sequence[TfVal] = tuple(
       map(shard_value_for_mesh, args, _in_avals, in_axis_resources))
@@ -882,7 +882,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
 
     return [
         missing_tf_kernel(
-            dtypes=[np.bool_, np.complex64, np.complex128]),
+            dtypes=[np.bool_]),
         custom_numeric(
             custom_assert=custom_assert,
             description=(
@@ -901,7 +901,7 @@ class Jax2TfLimitation(primitive_harness.Limitation):
       tst.assertAllClose(result_jax[~mask], result_tf[~mask], err_msg=err_msg)
 
     return [
-        missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
+        missing_tf_kernel(dtypes=[np.bool_]),
        custom_numeric(
            custom_assert=custom_assert,
            description=(
@@ -966,9 +966,8 @@ class Jax2TfLimitation(primitive_harness.Limitation):
 
   @classmethod
   def reduce_min(cls, harness: primitive_harness.Harness):
-    return [
-        missing_tf_kernel(dtypes=[np.complex64, np.complex128]),
-    ]
+    return cls.reduce_max(harness)
 
   @classmethod
   def reduce_window_add(cls, harness):
@@ -979,14 +978,14 @@ class Jax2TfLimitation(primitive_harness.Limitation):
   def reduce_window_max(cls, harness):
     assert "max" == harness.params["computation"].__name__
     return [
-        missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
+        missing_tf_kernel(dtypes=[np.bool_]),
     ]
 
   @classmethod
   def reduce_window_min(cls, harness):
     assert "min" == harness.params["computation"].__name__
     return [
-        missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
+        missing_tf_kernel(dtypes=[np.bool_]),
     ]
 
   @classmethod
@@ -1045,13 +1044,13 @@ class Jax2TfLimitation(primitive_harness.Limitation):
   @classmethod
   def scatter_max(cls, harness):
     return [
-        missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
+        missing_tf_kernel(dtypes=[np.bool_]),
     ]
 
   @classmethod
   def scatter_min(cls, harness):
     return [
-        missing_tf_kernel(dtypes=[np.bool_, np.complex64, np.complex128]),
+        missing_tf_kernel(dtypes=[np.bool_]),
     ]
 
   @classmethod
@@ -718,32 +718,36 @@ for dtype in jtu.dtypes.all:
       shape=shape,
       dtype=dtype)
 
-_LAX_COMPARATORS = (lax.eq_p, lax.ge_p, lax.gt_p, lax.le_p, lax.lt_p, lax.ne_p)
+_LAX_COMPARATORS = dict(eq=jnp.equal, ne=jnp.not_equal,
+                        ge=jnp.greater_equal, gt=jnp.greater,
+                        le=jnp.less_equal, lt=jnp.less)
 
 
 def _make_comparator_harness(name,
                              *,
                              dtype=np.float32,
                              op=lax.eq_p,
+                             op_name="eq",
                              lhs_shape=(),
                              rhs_shape=()):
   define(
-      op.name,
+      op_name,
      f"{name}_lhs={jtu.format_shape_dtype_string(lhs_shape, dtype)}_rhs={jtu.format_shape_dtype_string(rhs_shape, dtype)}",
-      lambda *args: op.bind(*args),
+      lambda *args: op(*args),
      [RandArg(lhs_shape, dtype),
       RandArg(rhs_shape, dtype)],
      op=op,
+      op_name=op_name,
      lhs_shape=lhs_shape,
      rhs_shape=rhs_shape,
      dtype=dtype)
 
 
-for op in _LAX_COMPARATORS:
+for op_name, op in _LAX_COMPARATORS.items():
   for dtype in (jtu.dtypes.all if op in [lax.eq_p, lax.ne_p] else
-                set(jtu.dtypes.all) - set(jtu.dtypes.complex)):
+                set(jtu.dtypes.all)):
     # Validate dtypes
-    _make_comparator_harness("dtypes", dtype=dtype, op=op)
+    _make_comparator_harness("dtypes", dtype=dtype, op=op, op_name=op_name)
 
     # Validate broadcasting behavior
     for lhs_shape, rhs_shape in [
@@ -751,7 +755,8 @@ for op in _LAX_COMPARATORS:
       ((1, 2), (3, 2)),  # broadcast along specific axis
   ]:
     _make_comparator_harness(
-        "broadcasting", lhs_shape=lhs_shape, rhs_shape=rhs_shape, op=op)
+        "broadcasting", lhs_shape=lhs_shape, rhs_shape=rhs_shape,
+        op=op, op_name=op_name)
 
 for dtype in jtu.dtypes.all:
   shape = (3, 4, 5)
@@ -917,6 +922,7 @@ for prim in [lax.div_p, lax.rem_p]:
 def _make_binary_elementwise_harnesses(prim,
                                        dtypes,
                                        default_dtype=np.float32,
+                                       broadcasting_dtypes=None,
                                        jax_unimplemented=lambda **kwargs: []):
 
   def _make(name, *, shapes=((20, 20), (20, 20)), dtype):
@@ -931,15 +937,18 @@ def _make_binary_elementwise_harnesses(prim,
         prim=prim,
         dtype=dtype,
         shapes=shapes)
 
-  return (tuple(  # Validate dtypes
-      _make("dtypes", dtype=dtype)
-      for dtype in dtypes) + tuple(  # Validate broadcasting
-          _make("broadcasting", dtype=default_dtype, shapes=shapes)
-          for shapes in [
+  broadcasting_dtypes = broadcasting_dtypes or (default_dtype,)
+  return (
+      # Validate dtypes
+      tuple(_make("dtypes", dtype=dtype) for dtype in dtypes) +
+      # Validate broadcasting
+      tuple(_make("broadcasting", dtype=dtype, shapes=shapes)
+            for shapes in [
              ((20, 20), (1, 20)),  # broadcasting rhs
              ((1, 20), (20, 20)),  # broadcasting lhs
-          ]))
+            ]
+            for dtype in broadcasting_dtypes)
+  )
 
 
 _make_binary_elementwise_harnesses(
@@ -1004,7 +1013,9 @@ _min_max_special_cases = tuple(
     (np.array([-np.inf, -np.inf], dtype=dtype),
      np.array([np.nan, np.nan], dtype=dtype))])
 
-_make_binary_elementwise_harnesses(prim=lax.min_p, dtypes=jtu.dtypes.all)
+_make_binary_elementwise_harnesses(
+    prim=lax.min_p, dtypes=jtu.dtypes.all,
+    broadcasting_dtypes=(np.float32, np.complex64, np.complex128))
 # Validate special cases
 for lhs, rhs in _min_max_special_cases:
   define(
@@ -1014,7 +1025,9 @@ for lhs, rhs in _min_max_special_cases:
       prim=lax.min_p,
      dtype=lhs.dtype)
 
-_make_binary_elementwise_harnesses(prim=lax.max_p, dtypes=jtu.dtypes.all)
+_make_binary_elementwise_harnesses(
+    prim=lax.max_p, dtypes=jtu.dtypes.all,
+    broadcasting_dtypes=(np.float32, np.complex64, np.complex128))
 # Validate special cases
 for lhs, rhs in _min_max_special_cases:
   define(
@@ -2336,10 +2349,15 @@ def _make_clamp_harness(name,
       min_shape=min_arr.shape,
       operand_shape=operand_shape,
       max_shape=max_arr.shape,
-      dtype=dtype)
+      dtype=dtype,
+      jax_unimplemented=[
+          Limitation(
+              "unimplemented",
+              dtypes=[np.bool_, np.complex64, np.complex128])],
+  )
 
 
-for dtype in set(jtu.dtypes.all) - set(jtu.dtypes.complex + [np.bool_]):
+for dtype in set(jtu.dtypes.all):
   _make_clamp_harness("dtypes", dtype=dtype)
 
 # Validate broadcasting of min/max arrays
@@ -99,8 +99,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
   # If you want to run this test for only one harness, add parameter
   # `one_containing="foo"` to parameterized below.
   @primitive_harness.parameterized(
-      primitive_harness.all_harnesses, include_jax_unimpl=False,
-  )
+      primitive_harness.all_harnesses, include_jax_unimpl=False)
   @jtu.ignore_warning(
       category=UserWarning, message="Using reduced precision for gradient.*")
   def test_prim(self, harness: primitive_harness.Harness):