Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
[jax2tf] Fix handling of float0
This commit is contained in:
parent 42e01ee2fa
commit dce31e9631
@@ -9,9 +9,13 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
-->

## jax 0.2.13 (unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.12...master).
* Bug fixes:
  * The {func}`jax2tf.convert` now works in the presence of gradients for functions
    with integer inputs ({jax-issue}`#6360`).

## jax 0.2.12 (April 1 2021)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...master).
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...v0.2.12).
* New features
  * New profiling APIs: {func}`jax.profiler.start_trace`,
    {func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`
@@ -153,6 +153,23 @@ is attempted. The plan is to fix this. Note that if no gradients are requested,
the PreventGradient ops will be saved along with the converted code and will
give a nice error if differentiation of the converted code is attempted.

### Converting gradients for integer-argument functions

When JAX differentiates functions with integer arguments, the gradients will
be zero-vectors with a special `float0` type (see [PR 4039](https://github.com/google/jax/pull/4039)).
This type is translated to `bfloat16` when converting to TF. For example,

```python
def f_jax(x):  # x: int32
  return x * 2.

jax.grad(f_jax, allow_int=True)(2)
# returns a special `float0`: array((b'',), dtype=[('float0', 'V')])

jax2tf.convert(jax.grad(f_jax, allow_int=True))(2)
# returns a `bfloat16` zero: tf.Tensor(0, shape=(), dtype=bfloat16)
```

### TensorFlow XLA ops

For most JAX primitives there is a natural TF op that fits the needed semantics.
@@ -81,8 +81,9 @@ def _is_tfval(v: TfVal) -> bool:
def _safe_convert_to_tensor(val, dtype=None) -> TfVal:
  dtype = dtype if dtype else (val.dtype if hasattr(val, "dtype") else None)
  conversion_type = to_tf_dtype(dtype) if dtype else None
  # We can convert directly, because all dtypes (even bfloat16) are the same
  # in JAX and TF.
  # The float0 type is not known to TF.
  if dtype and dtype == dtypes.float0:
    val = np.zeros(np.shape(val), conversion_type.as_numpy_dtype)
  return tf.convert_to_tensor(val, dtype=conversion_type)

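A hedged usage sketch of the branch above (assuming `numpy as np` and JAX's public `jax.dtypes` for the `dtypes` module referenced in this file): a `float0` value is swapped for a `bfloat16` zero of the same shape before TF ever sees it.

```python
import numpy as np
from jax import dtypes  # assumption: the `dtypes` module referenced above

# float0 arrays carry no payload, so zeros of the mapped dtype stand in for them.
val = np.zeros((2, 3), dtype=dtypes.float0)
t = _safe_convert_to_tensor(val, dtype=dtypes.float0)
# t is a tf.Tensor of zeros with shape (2, 3) and dtype bfloat16
```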
@@ -774,9 +775,8 @@ class TensorFlowTrace(core.Trace):

def to_tf_dtype(jax_dtype):
  if jax_dtype == dtypes.float0:
-    return tf.float32
-  else:
-    return tf.dtypes.as_dtype(jax_dtype)
+    jax_dtype = dtypes.bfloat16
+  return tf.dtypes.as_dtype(jax_dtype)

def to_jax_dtype(tf_dtype):
  return tf_dtype.as_numpy_dtype
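For orientation, a minimal sketch of the dtype mapping that the new `to_tf_dtype` establishes; the imports are assumptions (`tensorflow as tf`, NumPy, and JAX's public `jax.dtypes`, which this module refers to as `dtypes`).

```python
import numpy as np
import tensorflow as tf
from jax import dtypes  # assumption: the `dtypes` module used above

# float0 cotangents now surface as bfloat16 on the TF side ...
assert to_tf_dtype(dtypes.float0) == tf.bfloat16
# ... while every other dtype passes through tf.dtypes.as_dtype unchanged.
assert to_tf_dtype(np.dtype("int32")) == tf.int32
```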
@@ -284,6 +284,43 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
    self.assertAllClose(4. * 4., y)
    self.assertAllClose(3. * 4., tape.gradient(y, x))

  def test_gradient_with_float0_intermediate(self):
    # Gradient over integer-argument functions
    def f(x, y):  # x is an int, y is a float
      return 2 * x + y

    def g(x):  # x: f32
      return 2. * f(3 * x.astype("int32"), x * 4.)

    x = np.float_(2.)
    grad_g = jax.grad(g)
    self.ConvertAndCompare(grad_g, x)

  def test_gradient_with_float0_result(self):
    # Gradient over integer-argument functions, with float0 result
    def f(x, y):  # x is an int, y is a float
      return 2 * x + y

    def g(x):  # x: i32
      return jnp.sum(2. * f(3 * x, 4. * x.astype("float32")))

    grad_g = jax.grad(g, allow_int=True)
    x = 2
    d_dx_jax = grad_g(x)
    d_dx_tf = jax2tf.convert(grad_g)(x)
    self.assertEqual(d_dx_jax.dtype, dtypes.float0)
    self.assertAllClose(jnp.zeros(np.shape(d_dx_jax), dtypes.bfloat16),
                        d_dx_tf.numpy())

    shape = (3, 4)
    x = np.ones(shape, dtype=np.int32)
    d_dx_jax = grad_g(x)
    d_dx_tf = jax2tf.convert(grad_g)(x)
    self.assertEqual(d_dx_jax.dtype, dtypes.float0)
    self.assertAllClose(jnp.zeros(np.shape(d_dx_jax), dtypes.bfloat16),
                        d_dx_tf.numpy())

  def test_convert_argument_non_callable_error(self):
    with self.assertRaisesRegex(TypeError, "Expected a callable value"):
      jax2tf.convert(5.)
@@ -44,6 +44,8 @@ def _make_tf_args(args):
def _make_tf_input_signature(*tf_args) -> List[tf.TensorSpec]:
  # tf_args can be PyTrees
  def _make_one_arg_signature(tf_arg):
    if np.isscalar(tf_arg):
      tf_arg = np.array(tf_arg)
    return tf.TensorSpec(np.shape(tf_arg), tf_arg.dtype)

  return tf.nest.map_structure(_make_one_arg_signature, list(tf_args))
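A hedged usage sketch of the helper above (assuming `numpy as np` and `tensorflow as tf` as imported in this test utility); the new scalar branch is what lets plain scalar arguments, such as the integer inputs used in the float0 tests, get a `TensorSpec`.

```python
import numpy as np
import tensorflow as tf

sig = _make_tf_input_signature(np.int32(2), np.ones((3, 4), np.float32))
# sig is roughly [TensorSpec(shape=(), dtype=tf.int32),
#                 TensorSpec(shape=(3, 4), dtype=tf.float32)]
```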