10 Commits

Author SHA1 Message Date
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
George Necula
fd2b9b4759 [call_tf] Add support for DLPack to avoid copying arrays in eager mode 2021-02-13 12:49:51 +02:00
Peter Hawkins
d001ac6b8a [JAX] Add support for retaining ownership of DLPack tensors.
Move dlpack.py contents under jax/_src/dlpack.py.

Add array interoperability test between JAX and TensorFlow using DLPack.

Fixes: https://github.com/google/jax/issues/4636
PiperOrigin-RevId: 338120910
2020-10-20 13:07:07 -07:00
Peter Hawkins
d55ea510e2
Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46. (#3046)
* Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46.

* Bump minimum jaxlib version to 0.1.47.
2020-05-11 17:43:55 -04:00
Peter Hawkins
b0b6cd8e39
Make dlpack code robust against upcoming XLA Python binding change. (#2876) 2020-04-28 18:44:00 -04:00
Peter Hawkins
bbf7a432e8
Update semantics of to_dlpack. (#2707)
to_dlpack now takes ownership of the original buffer, leaving it in an invalid state.
2020-04-13 21:57:54 -04:00
George Necula
428377afb3
Added type annotations and removed unused imports (#2472)
* Added type annotations and removed unused imports

* Adjusted type hints for pytype
2020-03-21 13:54:30 +01:00
Matthew Johnson
47df7b95c4
change the xla representation of JAX's unit (#2416)
* change the xla representation of JAX's unit

Previously the representation of JAX's unit value (a sentinel /
placeholder) was an empty tuple, but by changing the representation to
something else we can further reduce our dependence on runtime tuples.

This commit makes the representation fairly easy to change. There are
three functions in xla.py that define the representation. Here are
versions that would keep the old XLA representation as an empty tuple:

```
def _make_unit(c): return c.Tuple()
def _make_abstract_unit(_): return xc.Shape.tuple_shape(())
def _device_put_unit(_, device):
  return xc.Buffer.make_tuple((), device, backend=xb.get_device_backend(device))
```

The new representation is as a trivial array. An alternative
representation would be nothing at all: we don't need to generate XLA
computations that have representations of JAX units. While that
alterntaive is probably the best choice, it seemed like it would require
a bit more refactoring/bookkeeping (e.g. to allow XLA computations to
have a smaller number of outputs than the corresponding JAX function),
and would also mean the XLA representation would be a step further
removed from the jaxpr representation. So I stuck with a trivial array
for now.

The mapping from JAX types to XLA types need not be invertible. However,
XLA translation rules currently don't take as arguments the
corresponding JAX types (abstract values), and there were a few cases
where we relied on checking whether an argument's XLA type was that of
an empty tuple so as to determine if we were effectively operating on a
JAX unit.

In particular, the AD-related primitive add_jaxvals_p could in principle
add two units, and get lowered to an XLA addition on the unit
representation. Previously, the translation rule for add_jaxvals_p
checked the XLA type so that adding two empty tuples didn't produce any
XLA operation; now it adds its inputs, and so if unit is represented as
a trivial array we could be inserting trivial scalar adds where we had
none before. However, if that case is ever possible, it doesn't come up
in our tests (which I checked by keeping the representation as an empty
tuple and then asserting an XLA tuple type is never seen by that
translation rule).

* add comment about JAX<->XLA array types assumption
2020-03-14 12:33:14 -07:00
Peter Hawkins
511d33983a
Initial implementation of DLPack support. (#2123)
* Initial implementation of DLPack support.

Unfortunately there are still a few bugs in the jaxlib DLPack support, so this code won't be ready to use until jaxlib 0.1.39.

* Fix test failures.

* Update XLA.

Fix failing torch test.
2020-01-29 23:19:14 -05:00