[jax2tf] Fix handling of float0

George Necula 2021-04-07 11:24:31 +03:00
parent 42e01ee2fa
commit dce31e9631
5 changed files with 66 additions and 6 deletions

View File

@@ -9,9 +9,13 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
-->
## jax 0.2.13 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.12...master).
* Bug fixes:
  * {func}`jax2tf.convert` now works in the presence of gradients for functions
    with integer inputs ({jax-issue}`#6360`).
## jax 0.2.12 (April 1 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...master).
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...v0.2.12).
* New features:
  * New profiling APIs: {func}`jax.profiler.start_trace`,
    {func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`

View File

@@ -153,6 +153,23 @@ is attempted. The plan is to fix this. Note that if no gradients are requested,
the PreventGradient ops will be saved along with the converted code and will
give a nice error if differentiation of the converted code is attempted.
### Converting gradients for integer-argument functions
When JAX differentiates functions with integer arguments, the gradients will
be zero vectors with a special `float0` type (see [PR 4039](https://github.com/google/jax/pull/4039)).
This type is translated to `bfloat16` when converting to TF. For example,
```python
def f_jax(x):  # x: int32
  return x * 2.

jax.grad(f_jax, allow_int=True)(2)
# returns a special `float0`: array((b'',), dtype=[('float0', 'V')])
jax2tf.convert(jax.grad(f_jax, allow_int=True))(2)
# returns a `bfloat16` zero: tf.Tensor(0, shape=(), dtype=bfloat16)
```
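For functions that mix integer and floating-point arguments, only the integer cotangents are affected. A minimal sketch (not part of this commit; `f_mixed` is a hypothetical example):
```python
def f_mixed(x, y):  # x: int32, y: float32
  return x * y

grads = jax2tf.convert(jax.grad(f_mixed, argnums=(0, 1), allow_int=True))
dx, dy = grads(2, 3.)
# dx: tf.Tensor(0, shape=(), dtype=bfloat16)   -- the float0 cotangent of x
# dy: tf.Tensor(2.0, shape=(), dtype=float32)  -- the ordinary gradient w.r.t. y
```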
### TensorFlow XLA ops
For most JAX primitives there is a natural TF op that fits the needed semantics.

View File

@@ -81,8 +81,9 @@ def _is_tfval(v: TfVal) -> bool:
def _safe_convert_to_tensor(val, dtype=None) -> TfVal:
  dtype = dtype if dtype else (val.dtype if hasattr(val, "dtype") else None)
  conversion_type = to_tf_dtype(dtype) if dtype else None
  # We can convert directly, because all dtypes (even bfloat16) are the same
  # in JAX and TF.
  # The float0 type is not known to TF.
  if dtype and dtype == dtypes.float0:
    val = np.zeros(np.shape(val), conversion_type.as_numpy_dtype)
  return tf.convert_to_tensor(val, dtype=conversion_type)
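The effect of the `float0` branch above, as a standalone sketch (assuming the structured-dtype representation of `float0` shown in the README; not code from this commit):
```python
import numpy as np
import tensorflow as tf

# A float0 value arrives as an array of empty void records, which
# tf.convert_to_tensor cannot ingest directly...
val = np.zeros((), dtype=[("float0", "V")])
conversion_type = tf.bfloat16  # the dtype that to_tf_dtype maps float0 to
# ...so it is first replaced by zeros of the target dtype.
zeros = np.zeros(np.shape(val), conversion_type.as_numpy_dtype)
t = tf.convert_to_tensor(zeros, dtype=conversion_type)
# t: tf.Tensor(0, shape=(), dtype=bfloat16)
```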
@@ -774,9 +775,8 @@ class TensorFlowTrace(core.Trace):
def to_tf_dtype(jax_dtype):
  if jax_dtype == dtypes.float0:
    return tf.float32
  else:
    return tf.dtypes.as_dtype(jax_dtype)
    jax_dtype = dtypes.bfloat16
  return tf.dtypes.as_dtype(jax_dtype)
def to_jax_dtype(tf_dtype):
  return tf_dtype.as_numpy_dtype
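A quick sanity check of the new mapping (a sketch, not part of the diff); note that the round trip is lossy, since `float0` becomes `bfloat16` and never comes back:
```python
to_tf_dtype(dtypes.float0)                # tf.bfloat16 with the new code
to_tf_dtype(np.dtype("int32"))            # tf.int32, unchanged behavior
to_jax_dtype(to_tf_dtype(dtypes.float0))  # bfloat16, not float0
```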

View File

@@ -284,6 +284,43 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
    self.assertAllClose(4. * 4., y)
    self.assertAllClose(3. * 4., tape.gradient(y, x))

  def test_gradient_with_float0_intermediate(self):
    # Gradient over integer-argument functions
    def f(x, y):  # x is an int, y is a float
      return 2 * x + y
    def g(x):  # x: f32
      return 2. * f(3 * x.astype("int32"), x * 4.)

    x = np.float_(2.)
    grad_g = jax.grad(g)
    self.ConvertAndCompare(grad_g, x)

  def test_gradient_with_float0_result(self):
    # Gradient over integer-argument functions, with float0 result
    def f(x, y):  # x is an int, y is a float
      return 2 * x + y
    def g(x):  # x: i32
      return jnp.sum(2. * f(3 * x, 4. * x.astype("float32")))

    grad_g = jax.grad(g, allow_int=True)
    x = 2
    d_dx_jax = grad_g(x)
    d_dx_tf = jax2tf.convert(grad_g)(x)
    self.assertEqual(d_dx_jax.dtype, dtypes.float0)
    self.assertAllClose(jnp.zeros(np.shape(d_dx_jax), dtypes.bfloat16),
                        d_dx_tf.numpy())

    shape = (3, 4)
    x = np.ones(shape, dtype=np.int32)
    d_dx_jax = grad_g(x)
    d_dx_tf = jax2tf.convert(grad_g)(x)
    self.assertEqual(d_dx_jax.dtype, dtypes.float0)
    self.assertAllClose(jnp.zeros(np.shape(d_dx_jax), dtypes.bfloat16),
                        d_dx_tf.numpy())

  def test_convert_argument_non_callable_error(self):
    with self.assertRaisesRegex(TypeError, "Expected a callable value"):
      jax2tf.convert(5.)

View File

@@ -44,6 +44,8 @@ def _make_tf_args(args):
def _make_tf_input_signature(*tf_args) -> List[tf.TensorSpec]:
  # tf_args can be PyTrees
  def _make_one_arg_signature(tf_arg):
    if np.isscalar(tf_arg):
      tf_arg = np.array(tf_arg)
    return tf.TensorSpec(np.shape(tf_arg), tf_arg.dtype)
  return tf.nest.map_structure(_make_one_arg_signature, list(tf_args))
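A sketch of how the helper behaves with the new scalar handling (hypothetical call, not from the commit): a raw Python scalar has no `.dtype`, so it is wrapped in an `np.array` first.
```python
sig = _make_tf_input_signature(np.ones((3, 4), np.int32), 2.0)
# [TensorSpec(shape=(3, 4), dtype=tf.int32),
#  TensorSpec(shape=(), dtype=tf.float64)]   # 2.0 -> np.array(2.0) -> float64
```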