rocm_jax/jax/ad_util.py
Matthew Johnson 47df7b95c4
change the xla representation of JAX's unit (#2416)
* change the xla representation of JAX's unit

Previously the representation of JAX's unit value (a sentinel /
placeholder) was an empty tuple, but by changing the representation to
something else we can further reduce our dependence on runtime tuples.

This commit makes the representation fairly easy to change. There are
three functions in xla.py that define the representation. Here are
versions that would keep the old XLA representation as an empty tuple:

```
def _make_unit(c): return c.Tuple()
def _make_abstract_unit(_): return xc.Shape.tuple_shape(())
def _device_put_unit(_, device):
  return xc.Buffer.make_tuple((), device, backend=xb.get_device_backend(device))
```

The new representation is as a trivial array. An alternative
representation would be nothing at all: we don't need to generate XLA
computations that have representations of JAX units. While that
alternative is probably the best choice, it seemed like it would require
a bit more refactoring/bookkeeping (e.g. to allow XLA computations to
have a smaller number of outputs than the corresponding JAX function),
and would also mean the XLA representation would be a step further
removed from the jaxpr representation. So I stuck with a trivial array
for now.
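
The bookkeeping that the "nothing at all" alternative would need can be sketched roughly as below. This is a plain-Python illustration, not code from this commit: `Unit`, `strip_units`, and `restore_units` are hypothetical names, and ordinary lists stand in for the outputs of a JAX function and of an XLA computation.

```python
class Unit:
  """Hypothetical stand-in for JAX's unit sentinel."""
  def __repr__(self):
    return "*"

unit = Unit()

def strip_units(outs):
  """Drop units from a list of outputs, remembering where they were."""
  is_unit = [isinstance(o, Unit) for o in outs]
  kept = [o for o in outs if not isinstance(o, Unit)]
  return kept, is_unit

def restore_units(kept, is_unit):
  """Re-insert units so the result has the arity the JAX function expects."""
  it = iter(kept)
  return [unit if u else next(it) for u in is_unit]

# The "XLA side" would see two outputs, the "JAX side" three.
kept, mask = strip_units([1.0, unit, 2.0])
restored = restore_units(kept, mask)
```

The mask is the extra state that would have to be threaded through every call site, which is the refactoring cost mentioned above.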

The mapping from JAX types to XLA types need not be invertible. However,
XLA translation rules currently don't take as arguments the
corresponding JAX types (abstract values), and there were a few cases
where we relied on checking whether an argument's XLA type was that of
an empty tuple so as to determine if we were effectively operating on a
JAX unit.

In particular, the AD-related primitive add_jaxvals_p could in principle
add two units, and get lowered to an XLA addition on the unit
representation. Previously, the translation rule for add_jaxvals_p
checked the XLA type so that adding two empty tuples didn't produce any
XLA operation; now it adds its inputs, and so if unit is represented as
a trivial array we could be inserting trivial scalar adds where we had
none before. However, if that case is ever possible, it doesn't come up
in our tests (which I checked by keeping the representation as an empty
tuple and then asserting an XLA tuple type is never seen by that
translation rule).
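
To make the behavioral difference concrete, here is a toy model of the two translation rules. Everything in it is an assumption made for illustration: `ToyBuilder`, `Operand`, `old_add_rule`, and `new_add_rule` are hypothetical and only record which operations would be emitted; they are not the XLA client API or the rules in xla.py.

```python
from collections import namedtuple

# kind is "empty_tuple" (old unit representation) or "array".
Operand = namedtuple("Operand", ["kind", "name"])

class ToyBuilder:
  """Hypothetical builder that records every op it is asked to emit."""
  def __init__(self):
    self.ops = []

  def add(self, x, y):
    self.ops.append(("add", x.name, y.name))
    return Operand("array", "add(%s,%s)" % (x.name, y.name))

def old_add_rule(c, x, y):
  # Old behavior: inspect the XLA type and emit nothing for units.
  if x.kind == "empty_tuple":
    return x
  return c.add(x, y)

def new_add_rule(c, x, y):
  # New behavior: always add, with no check on the XLA type.
  return c.add(x, y)

c_old = ToyBuilder()
old_add_rule(c_old, Operand("empty_tuple", "u1"), Operand("empty_tuple", "u2"))

c_new = ToyBuilder()
new_add_rule(c_new, Operand("array", "u1"), Operand("array", "u2"))
```

In this model `c_old.ops` stays empty while `c_new.ops` holds one add, which is the extra trivial operation described above.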

* add comment about JAX<->XLA array types assumption
2020-03-14 12:33:14 -07:00


# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .core import lattice_join, Primitive, Unit, unit, AbstractUnit
from .tree_util import register_pytree_node
from .util import safe_map

map = safe_map

# Registry: type of a value -> function that adds two values of that type.
jaxval_adders = {}
jaxval_adders[Unit] = lambda _, __: unit

def add_jaxvals(x, y):
  return add_jaxvals_p.bind(x, y)

add_jaxvals_p = Primitive('add_any')

@add_jaxvals_p.def_impl
def add_impl(xs, ys):
  return jaxval_adders[type(xs)](xs, ys)

@add_jaxvals_p.def_abstract_eval
def add_abstract(xs, ys):
  return lattice_join(xs, ys)

# Registry: type of a value -> function that builds a matching zero value.
jaxval_zeros_likers = {}

def zeros_like_aval(aval):
  return aval_zeros_likers[type(aval)](aval)

# Registry: type of an abstract value -> function that builds a matching zero.
aval_zeros_likers = {}
aval_zeros_likers[AbstractUnit] = lambda _: unit

def zeros_like_jaxval(val):
  return zeros_like_p.bind(val)

zeros_like_p = Primitive('zeros_like')

@zeros_like_p.def_impl
def zeros_like_impl(example):
  return jaxval_zeros_likers[type(example)](example)

zeros_like_p.def_abstract_eval(lambda x: x)

# Singleton sentinel, registered as a pytree node with no children.
class Zero(object):
  def __repr__(self):
    return "Zero"

zero = Zero()

register_pytree_node(Zero, lambda z: ((), None), lambda _, xs: zero)
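
The type-keyed registries in this file can be exercised in isolation with a sketch like the one below. It is a standalone approximation under stated assumptions: `adders` and `add_values` are hypothetical stand-ins for `jaxval_adders` and `add_impl`, `Unit` is redefined locally instead of imported from `.core`, and the `float` entry is an invented example of what another module might register.

```python
class Unit:
  """Local stand-in for core.Unit (assumption for this sketch)."""
  pass

unit = Unit()

# Registry keyed by Python type, mirroring jaxval_adders.
adders = {}
adders[Unit] = lambda _, __: unit
adders[float] = lambda x, y: x + y  # hypothetical extra registration

def add_values(x, y):
  # Dispatch on the type of the first argument only, as add_impl does.
  return adders[type(x)](x, y)

class Zero:
  """Singleton sentinel with a fixed repr, like the Zero class above."""
  def __repr__(self):
    return "Zero"

zero = Zero()

unit_sum = add_values(unit, unit)
float_sum = add_values(1.5, 2.0)
```

Because dispatch is a plain dictionary lookup on `type(x)`, a type with no registered adder raises `KeyError` rather than falling back to a default.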