1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-18 12:56:07 +00:00

54 Commits

Author SHA1 Message Date
George Necula
1e813e1693 [better_errors] Continue adding debug info to Jaxprs (step 4)
This follows after , , , adding `debug_info` to more calls to `lu.wrap_init`.

As part of this I have changed the primitive `custom_transpose` to take the `transpose` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`.

These changes ensure that all the `lu.wrap_init` and `Jaxpr` are called with debug_info in the `api_test.py:CustomTransposeTest`.
2025-02-08 09:13:55 +02:00
Jake VanderPlas
5dc37d3f70 Remove internal uses of api_util.shaped_abstractify 2024-12-19 07:06:36 -08:00
Jake VanderPlas
de3191fab3 Cleanup: fix unused imports & mark exported names 2024-10-16 17:42:41 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Peter Hawkins
e869e5e0f8 Move contents of jax.api_util to jax._src.api_util and add a forwarding shim.
One of many changes to codify the set of exported symbols in the jax.* namespace.

PiperOrigin-RevId: 395484706
2021-09-08 09:00:56 -07:00
Adam Paszke
0c03e98046 Don't cast out_axis_resources to a tuple automatically
It's confusing and makes it impossible to specify non-trivial pytrees of
out_axis_resources for functions that return lists. Also extend the error
messages to be less confusing and hint at potential fixes.

PiperOrigin-RevId: 395246450
2021-09-07 07:54:07 -07:00
Sergei Lebedev
af41a959d3 Most of JAX now uses concrete types for things defined in jaxlib.xla_client
Note that a few call sites in the diff got a ``# type: ignore``, because
the latest jaxlib does not have up-to-date signatures for the correpsonding
callables.
2021-08-16 20:33:36 +01:00
Matthew Johnson
b9d72a480f improve concreteness error from arguments
also tweak some error message wording
2021-05-03 17:37:34 -07:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
jax authors
d2c53e0560 Merge pull request from shoyer:static-kwargs2
PiperOrigin-RevId: 366459406
2021-04-02 09:37:46 -07:00
Peter Hawkins
94e2c9958c Add an api_util hook facility, and add a hook to jax2tf.convert.
The hook API is intended to make it easier to monkey-patch particular APIs, but is an internal, unsupported API.

PiperOrigin-RevId: 366194658
2021-04-01 00:05:26 -07:00
Stephan Hoyer
acb0be9cb7 Add _python_jit_with_static_argnames. 2021-03-31 10:02:16 -07:00
Peter Hawkins
6ee6c59235 Move jax.tree_util implementation to jax._src.tree_util.
NFC intended.

PiperOrigin-RevId: 364857920
2021-03-24 12:00:38 -07:00
Matthew Johnson
5a97eab2a0 improve error messages for grad(..., has_aux=True)
fixes 
2021-02-18 09:46:16 -08:00
Roy Frostig
afecab9ad7 accept any arguments with shape/dtype attributes after make_jaxpr 2021-02-10 17:07:10 -08:00
Matthew Johnson
304685a152 allow vmapped function to accept kwargs
Arguments passed as keywords are always batched along their leading
axis. The in_tree specification must correspond to arguments passed
positionally.

This brings vmap in line with pmap. That is, pmap already followed this
convention for arguments passed via keywords. Consistency is good!

I had to adapt some utility functions so as not to change the error
messages raised. In particular, we have tests for vmap error messages
which report the in_axes and argument tree structure; naively including
keyword arguments changed those error messages. The error messages are
worth preserving. This change also brought the pmap error messages in
line with the vmap ones.

I also did some 80char wrapping of lines and docstring updating.

Fixes . Another user had the same issue and reported the same
expected behavior.
2021-01-12 20:13:23 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Jake VanderPlas
a0562dc9c9 api: handle numpy integers for static argnums 2020-12-14 14:52:51 -08:00
Peter Hawkins
81b6cd29ff [JAX] Move traceback_util.py into jax._src.
traceback_util is a JAX-private API.

PiperOrigin-RevId: 340659195
2020-11-04 09:02:59 -08:00
Matthew Johnson
1c4d873fa1 remove unused import 2020-10-27 21:40:41 -07:00
Jean-Baptiste Lespiau
cb48f42372 Raise an error on non-hashable static arguments for jax.jit and xla_computation.
Up to now, Jax was silently wrapping the object to ensure objects which are not hashable will be hashed using `id` and compared using `is`:

```
class WrapHashably(object):
  __slots__ = ["val"]
  def __init__(self, val):
    self.val = val
  def __hash__(self):
    return id(self.val)
  def __eq__(self, other):
    return self.val is other.val
```

This means that when providing different instances of objects that are non hashable, a recompilation was always occurring. This can be non-intuitive, for example with:

@partial(jax.jit, static_argnums=(1,))
def sum(a, b):
  return a+ b
sum(np.asarray([1,2,3]), np.asarray([4,5,6])
# The next line will recompile, because the 1-indexed argument is non
# hashable and thus compared by identity with different instances
sum(np.asarray([1,2,3]), np.asarray([4,5,6])

or more simply
np.pad(a, [2, 3], 'constant', constant_values=(4, 6))
          ^^^^^^
          non-hashable static argument.

The same problems can occur with any non-hashable types such as lists, dicts, etc. Even JAX itself was having some issues with this (which shows the behaviour was non-trivial to reason about).

If this commit breaks you, you usually have one of the following options:
- If specifying numpy array or jnp arrays arguments as static, you probably simply need to make them non static.
- When using non-hashable values, such as list, dicts or sets, you can simply use non-mutable versions, with tuples, frozendict, and frozenset.
- You can also change the way the function is defined, to capture these non-hashable arguments by closure, returning the jitted function.

PiperOrigin-RevId: 339351798
2020-10-27 16:12:24 -07:00
Roy Frostig
fa1d7ab5fa move wraps from api_util to util to avoid cyclic dependencies 2020-10-26 12:31:19 -07:00
Roy Frostig
5d50e19364 add path exclusion opt-in to filtered stack traces and use it throughout the codebase 2020-10-26 12:31:19 -07:00
Jean-Baptiste Lespiau
98ee69ba1c Emit a warning on non-hashble static arguments in jax.jit.
The message looks like, e.g.:

Static argument (index 1) of type <class 'numpy.ndarray'> for function f is non-hashable. As this can lead to unexpected cache-misses, it will raise an error in a near future.
2020-10-23 22:29:41 +02:00
Matthew Johnson
e88579f22b fix typo 2020-09-18 19:41:59 -07:00
Matthew Johnson
f172fb74e1 plumb donate_argnums into jax.xla_computation 2020-09-18 17:39:05 -07:00
Matthew Johnson
107689e91f
improve vmap axis spec structure mismatch errors ()
* improve vmap axis spec structure mismatch errors

fixes 

* deflake
2020-06-30 22:19:16 -07:00
Jake Vanderplas
2a10dbbf37
deflake remainder of jax () 2020-06-06 10:51:34 -07:00
Tom Hennigan
6124f703af
Add support for buffer donation in jit and pmap. ()
For a computation of the form:

    >>> f = lambda x: x ** 2
    >>> f = jax.jit(f)
    >>> while run:
    ...   x = f(x)

JAX must currently always have two copies of `x` in device memory since there
is no reliable way in Python to determine whether there will be future uses of
`x`. This causes two classes of problem:

  1. Users at the limit of available device are constrained by the additional
     copy of their parameters and other state while they typically only require
     one copy. This typically frees 100M+ of device memory and is a critical
     optimization for larger models to match state of the art performance in
     other frameworks.

  2. This constant alloc/free of the input/output buffers can cause memory
     fragmentation on some platforms (although having a reusing allocator and
     limiting run-ahead may be a better solution for this problem).

We propose fixing this by using input/output aliasing as supported by XLA. We
will support this in JAX by allowing certain arguments of jit/pmap decorated
functions to be donated and reused as outputs:

    >>> f = lambda x: x ** 2
    >>> f = jit(f, donate_argnums=0)
    >>> while run:
    ...   x = f(x)

JAX will determine that the donated input `x` can alias with the output of the
function and it will instruct XLA it _must_ write the result to this buffer.

If a user tries to reuse a buffer after it has been donated they get an error
that the buffer is invalid:

    >>> y = f(x)
    >>> jax.device_get(x)
    ...
    RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.

The semantics of `donate_argnums` follows that of `static_argnums`, namely that
it identifies positional arguments to the computation that are to be donated
to the computation and used as part of the output.

One feature that is also enabled by this is invalidating buffers that should
only be used once, for example PRNGKeys:

    >>> @partial(jit, donate_argnums=0)
    ... def move(x):
    ...   # Do something complex enough for JAX to just optimize it away.
    ...   return tree_map(lambda x: x + x - x, x)

    >>> def safe_eager_uniform(key, *a, **k):
    ...   assert hasattr(key, 'device_buffer'), "random must run eagerly"
    ...   key = move(key)
    ...   return jax.random.uniform(key, *a, **k)

This is not a complete answer to random safety since it is still possible to
reuse a key as part of a traced computation, however it can be used to support
this feature (somewhat inefficiently) in eager mode.
2020-05-31 15:00:16 -07:00
Skye Wanderman-Milne
a5da921f4c
Move _flatten_axes to api_util.py ()
This is in preparation for using it in sharded_jit.py (since sharded_jit isn't included in api.py yet).
2020-05-11 11:04:57 -07:00
Matthew Johnson
7e480fa923 add custom_jvp / vjp, delete custom_transforms 2020-03-21 22:08:03 -07:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. ()
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
George Necula
528a69f32e Added some more documentation to the linear_util module
Also cleaned up the inconsistent way of importing the module.
Prefer importing with qualified name 'lu.transformation' rather
than just 'transformation'.
2020-01-05 16:40:26 +01:00
Matthew Johnson
979b38352f make vmap structured axes work for any pytree 2019-10-31 14:09:12 -07:00
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Dougal Maclaurin
c53c8bbb43 Some progress de-tupling ad.py 2019-08-21 07:01:07 -07:00
Dougal Maclaurin
6d71396d56 Start exploring jaxprs without tuples
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00
Peter Hawkins
476dc3db64 Python changes in preparation for adding a C++ implementation of the PyTree utilities. 2019-07-29 10:57:27 -04:00
Matthew Johnson
0546c94992 speed up pmap axis-size getting
Co-authored-by: Peter Hawkins <phawkins@google.com>
2019-07-25 12:41:31 -07:00
Matthew Johnson
b6031ffdd7 avoid packing leaf outputs for jit/pmap funs 2019-05-17 07:36:52 -07:00
Matthew Johnson
15a4554ffb flatten out pytrees in jit at the api.py level 2019-05-03 11:39:37 -07:00
Matthew Johnson
9c2e1c35b1 prevent jit from treating keyword args as static
fixes 
2019-04-10 22:09:14 -07:00
Matthew Johnson
902c149c47 add partial value lattice join, cond support
This change allows one side of a cond to have a different const-ness
from the other side, from the point-of-view of partial evaluation. In
other words, this now works as expected:

```python
lax.cond(x < 0, x, lambda x: 0., x, lambda x: x)  # relu
```

The partial evaluation logic works with tuples, so this works too:

```python
lax.cond(x < 0,
         x, lambda x: (x, x, 1, 1, 1),
         x, lambda x: (x, 1, x, 1, 2))
```

in that true_fun is resolved to something like `lambda x: (x, x, 1, *, 1)`
and false_fun is resolved to something like `lambda x: (x, 1, x, *, 2)`,
where `*` means unit and corresponds to a known constant that isn't
staged into the computation.

For forward-mode autodiff support, we'll need to add yet another lattice
join on the lattice of symbolic-zero-or-not.
2019-03-02 17:37:38 -08:00
Peter Hawkins
3e25d290be Set __wrapped__ attribute instead of using functools.wraps to fix Python 2.7 problem. 2019-02-14 11:00:40 -05:00
Peter Hawkins
33cd3d0299 Use functools.wraps as the basis for api_util.wraps.
Fixes API signatures in `jax.random` documentation (https://github.com/google/jax/issues/370).
2019-02-14 10:07:47 -05:00
Matthew Johnson
da2d185444 tweak 2019-01-28 09:19:06 -08:00
Matthew Johnson
945fa34e7e tweaks 2019-01-28 09:00:02 -08:00
Matthew Johnson
780106f892 moving pxla flattening/chunking to api.py, wip 2019-01-28 08:38:14 -08:00
Matthew Johnson
0f7c7c4eab generalize jacfwd and jacrev to handle pytrees 2019-01-06 12:49:41 -08:00