mirror of
https://github.com/ROCm/jax.git
synced 2025-04-23 22:26:04 +00:00

* change the xla representation of JAX's unit

Previously the representation of JAX's unit value (a sentinel / placeholder) was an empty tuple, but by changing the representation to something else we can further reduce our dependence on runtime tuples.

This commit makes the representation fairly easy to change. There are three functions in xla.py that define the representation. Here are versions that would keep the old XLA representation as an empty tuple:

```
def _make_unit(c): return c.Tuple()
def _make_abstract_unit(_): return xc.Shape.tuple_shape(())
def _device_put_unit(_, device):
  return xc.Buffer.make_tuple((), device, backend=xb.get_device_backend(device))
```

The new representation is as a trivial array.

An alternative representation would be nothing at all: we don't need to generate XLA computations that have representations of JAX units. While that alternative is probably the best choice, it seemed like it would require a bit more refactoring/bookkeeping (e.g. to allow XLA computations to have a smaller number of outputs than the corresponding JAX function), and would also mean the XLA representation would be a step further removed from the jaxpr representation. So I stuck with a trivial array for now.

The mapping from JAX types to XLA types need not be invertible. However, XLA translation rules currently don't take as arguments the corresponding JAX types (abstract values), and there were a few cases where we relied on checking whether an argument's XLA type was that of an empty tuple so as to determine if we were effectively operating on a JAX unit.

In particular, the AD-related primitive add_jaxvals_p could in principle add two units, and get lowered to an XLA addition on the unit representation. Previously, the translation rule for add_jaxvals_p checked the XLA type so that adding two empty tuples didn't produce any XLA operation; now it adds its inputs, and so if unit is represented as a trivial array we could be inserting trivial scalar adds where we had none before. However, if that case is ever possible, it doesn't come up in our tests (which I checked by keeping the representation as an empty tuple and then asserting an XLA tuple type is never seen by that translation rule).

* add comment about JAX<->XLA array types assumption
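The unit-addition case that the commit message discusses for add_jaxvals_p can be sketched in isolation. What follows is a minimal standalone illustration, not JAX's actual code: `Unit`, `unit`, `adders`, and `add_any` are hypothetical stand-ins that mimic a type-keyed table of addition rules, and the `float` entry is added only for contrast. It shows that adding two units returns the unit sentinel without performing any arithmetic.

```python
# Standalone sketch; every name here is an illustrative stand-in, not JAX's API.
class Unit:
    """Placeholder value carrying no data."""
    def __repr__(self):
        return "unit"

unit = Unit()

# Type-keyed table of addition rules, one entry per value type.
adders = {
    Unit: lambda _, __: unit,   # unit + unit is unit; no arithmetic happens
    float: lambda x, y: x + y,  # ordinary numeric addition, for contrast
}

def add_any(x, y):
    """Pick the addition rule based on the type of the first argument."""
    return adders[type(x)](x, y)

result_unit = add_any(unit, unit)
result_float = add_any(1.5, 2.0)
```

Keying the table on the Python type keeps the unit case out of the numeric path entirely. This is the same reason the old translation rule could skip emitting an XLA operation when both inputs were empty tuples.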
64 lines
1.6 KiB
Python
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .core import lattice_join, Primitive, Unit, unit, AbstractUnit
from .tree_util import register_pytree_node
from .util import safe_map

map = safe_map

jaxval_adders = {}
jaxval_adders[Unit] = lambda _, __: unit

def add_jaxvals(x, y):
  return add_jaxvals_p.bind(x, y)

add_jaxvals_p = Primitive('add_any')

@add_jaxvals_p.def_impl
def add_impl(xs, ys):
  return jaxval_adders[type(xs)](xs, ys)

@add_jaxvals_p.def_abstract_eval
def add_abstract(xs, ys):
  return lattice_join(xs, ys)

jaxval_zeros_likers = {}

def zeros_like_aval(aval):
  return aval_zeros_likers[type(aval)](aval)

aval_zeros_likers = {}
aval_zeros_likers[AbstractUnit] = lambda _: unit

def zeros_like_jaxval(val):
  return zeros_like_p.bind(val)

zeros_like_p = Primitive('zeros_like')

@zeros_like_p.def_impl
def zeros_like_impl(example):
  return jaxval_zeros_likers[type(example)](example)

zeros_like_p.def_abstract_eval(lambda x: x)

class Zero(object):
  def __repr__(self):
    return "Zero"

zero = Zero()

register_pytree_node(Zero, lambda z: ((), None), lambda _, xs: zero)