63 Commits

Author SHA1 Message Date
Peter Hawkins
7de9eb20df Reverts 525b646c0ebd5205f4fa0639c94adb2de47e1cf0
PiperOrigin-RevId: 707146329
2024-12-17 10:12:34 -08:00
Yash Katariya
944d822ce6 Add a no-op batching rule for optimization_barrier_p
PiperOrigin-RevId: 704507586
2024-12-09 19:21:07 -08:00
Peter Hawkins
525b646c0e Reverts 2075b091c4e83f0bdbd0d47812a72114fb8b937a
PiperOrigin-RevId: 698152759
2024-11-19 14:47:24 -08:00
Peter Hawkins
2c80d1af50 Add a new API jax.lax.split.
This API does not add expressive power, since it is already possible to split arrays by repeated slicing. Its purpose is to be a primitive that is the transpose of `lax.concatenate`, so that primitives like `jnp.unstack` can be differentiatied more efficiently.

Before:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[5,3] = pad[padding_config=((4, 0, 0), (0, 0, 0))] l 0.0
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          o:f32[5,3] = pad[padding_config=((3, 1, 0), (0, 0, 0))] n 0.0
          p:f32[5,3] = add_any m o
          q:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          r:f32[5,3] = pad[padding_config=((2, 2, 0), (0, 0, 0))] q 0.0
          s:f32[5,3] = add_any p r
          t:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          u:f32[5,3] = pad[padding_config=((1, 3, 0), (0, 0, 0))] t 0.0
          v:f32[5,3] = add_any s u
          w:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          x:f32[5,3] = pad[padding_config=((0, 4, 0), (0, 0, 0))] w 0.0
          y:f32[5,3] = add_any v x
        in (y,) }
    ] a b c d e
  in (f,) }
```

Note in particular the `pad` calls, which are the transpose of `slice`. Transposing the split has the effect of forming many dense intermediate cotangents.

After:
```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jnp.ones((3,))

In [3]: jax.jit(jax.linear_transpose(lambda xs: jnp.unstack(xs), jnp.ones((5, 3)))).trace((x,)*5).jaxpr
Out[3]:
{ lambda ; a:f32[3] b:f32[3] c:f32[3] d:f32[3] e:f32[3]. let
    f:f32[5,3] = pjit[
      name=unstack
      jaxpr={ lambda ; g:f32[3] h:f32[3] i:f32[3] j:f32[3] k:f32[3]. let
          l:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] k
          m:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] j
          n:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] i
          o:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] h
          p:f32[1,3] = broadcast_in_dim[
            broadcast_dimensions=(1,)
            shape=(1, 3)
            sharding=None
          ] g
          q:f32[5,3] = concatenate[dimension=0] p o n m l
        in (q,) }
    ] a b c d e
  in (f,) }
```
2024-11-19 15:25:47 -05:00
Peter Hawkins
94abaf430e Add lax.FftType.
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.

We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.

PiperOrigin-RevId: 684447186
2024-10-10 08:07:35 -07:00
Tom Natan
ed5ba633d4 Reverts 6cf09f8c24c67ff650b95d174501fff3cb59db0d
PiperOrigin-RevId: 682440543
2024-10-04 13:56:27 -07:00
Tom Natan
6cf09f8c24 Reverts eff00cc4499cfe3f3f24bafda6c1ecf908232ff3
PiperOrigin-RevId: 678756266
2024-09-25 10:33:53 -07:00
Tom Natan
eff00cc449 [JAX] add support for gather/scatter batching dims following the new attributes in stablehlo.
This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See https://github.com/openxla/stablehlo/pull/2259

PiperOrigin-RevId: 678649138
2024-09-25 04:53:11 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Jake VanderPlas
1b3da85758 Fix scatter batching rule for scatter_apply
The issue is that the batching rule assumes that each scatter variant
always has the same update_jaxpr. This is not true of scatter_apply, which
lowers to scatter with a custom update_jaxpr. To address this, we change
the batching rule such that it re-uses the input jaxpr rather than always
re-generating it.
2023-07-10 16:42:45 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
pizzud
22cbf95e07 lax_vmap_test: Extend timeout so that the TPU variant can run in ASAN.
Unfortunately we can't conditionally change the timeout, as size and timeout
are both non-configurable even if jax_test supported setting the size.

PiperOrigin-RevId: 514745247
2023-03-07 08:49:42 -08:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
pizzud
631e4ed7e0 lax_test: Create a separate module for lax-specific test utils in a new package.
These utils are currently shared with lax_vmap_test by importing lax_test as a
library, which is an odd thing to do.

The new package and the module within it are not built into the wheel, as these
are internal utilities for JAX's tests, not utilities for JAX users writing
their own tests.

Followup changes will add additional existing internal test utilities to this
package. This will allow removing sys.path manipulation from
deprecation_module_test and hopefully lazy_loader_test, as well as removing
the non-public test_util.py from _src to make it clearer that it should not be
used from outside JAX.

PiperOrigin-RevId: 510260230
2023-02-16 15:29:41 -08:00
Peter Hawkins
b80af85298 Reenable GPU int matmul test since the XLA bug is fixed.
PiperOrigin-RevId: 495117439
2022-12-13 13:50:42 -08:00
Jake VanderPlas
f09fd8a4e9 [x64] minor test-only updates for better type safety 2022-11-30 15:18:40 -08:00
Peter Hawkins
e42e52d4aa Rename test flag --num_generated_cases to --jax_num_generated_cases.
parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again.

It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change.

Fix many test cases that were shown to be broken with a larger number of test cases enabled.

PiperOrigin-RevId: 487406670
2022-11-09 18:58:05 -08:00
Peter Hawkins
2c946b3b56 Migrate api_test, lax_numpy_test, and lax_vmap_test to
jtu.sample_product.

Gives a ~2x improvement in pytest --collect-only timing for
lax_numpy_test.
2022-10-05 13:46:19 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
jax authors
498fd2083e Merge pull request #12122 from hawkinsp:fft
PiperOrigin-RevId: 470294824
2022-08-26 11:32:07 -07:00
Jake VanderPlas
7d21d6ef9d [x64] make vmap_test compatible with strict dtype promotion 2022-06-17 15:55:08 -07:00
Peter Hawkins
ece9b999fb Fix batching rule for convolution for batch dimensions of size 0. 2022-06-01 14:18:16 -04:00
Matthew Johnson
39c2f8b051 fixup from 5415306: remove extraneous lines
also add test
2022-03-11 15:19:10 -08:00
Roy Frostig
64572795b7 remove _select_and_{gather,scatter}_add from public jax.lax module 2022-03-10 10:43:42 -08:00
jax authors
2a3f936ffa Merge pull request #9576 from nicholasjng:broadcast-validation
PiperOrigin-RevId: 432531230
2022-03-04 14:21:17 -08:00
Nicholas Junge
56546d3e73 Validate lax.broadcast_shape inputs before control flow execution
This commit addresses previously unvalidated inputs to `jax.lax.broadcast_shapes` by adding a small validation check before control flow execution. A legal input to `lax.broadcast_shapes` hereafter is defined as an input that
1) is a sequence (i.e., implements for..in iteration) of integers and
2) said integers are all non-negative.

In addition, two tests were added to `tests.lax_vmap_test` to check that proper errors are raised when attempting to use illegal inputs with `lax.broadcast_shapes`.
2022-03-04 19:27:52 +01:00
Peter Hawkins
ba4db33f48 Implement lax.reduce() batching rule case where batch dimensions differ between operands. 2021-11-03 09:42:01 -04:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Peter Hawkins
9f083d11da Use jax.* APIs rather than api.* names in tests.
Tests should use our own public APIs where they exist.
2021-09-13 16:01:32 -04:00
Stephan Hoyer
22943ef839 Add jit to lax.fft
The main motivation here is ensuring that FFTs are always marked in
profiler results, which is not necessarily the case where running on
TPUs.

I would jit decorate the user facing functions in jax.numpy.fft, but
these functions also accept parameters as lists, e.g., for axes, which
are mutable and hence not valid as direct input into jit decorated
functions. This might be worth doing, but would be a breaking change.
2021-08-30 09:28:35 -07:00
Yash Katariya
80f9942b66 Fix the import. tests.lax_test -> lax_test since they are in the same directory.
PiperOrigin-RevId: 388813151
2021-08-04 16:44:52 -07:00
David Majnemer
f0c30492dc Remove stale limitations
PiperOrigin-RevId: 383652551
2021-07-08 09:42:26 -07:00
Qiao Zhang
e8e6138c75 Add vmap rule for lax.clamp. 2021-06-21 14:59:34 -07:00
jax authors
6cc4bb0476 Merge pull request #6420 from apaszke:faster-tests
PiperOrigin-RevId: 368519866
2021-04-14 15:24:06 -07:00
Adam Paszke
e0357283a6 Speed up test generation 7x
There are a few test cases that generate millions of configurations,
only to have a handful of them selected by `cases_form_list`. I've
found all tests that spend over 100ms in case generation and
converted them to a new "test sampler" approach. The result: test
generation time drops from 15s to around 2s. Doesn't sound like much,
but I expect that we all run tests many times daily, so it seems like a
useful thing to have.

The rough idea is that the sampling generators get parameterized by a
sampler function that should be applied to the range of every `for` loop.
This allows us to sample runs of the generator through different
configurations by restricting each loop to a smaller subset. Right now
we always narrow it down to a single randomly selected instance. But,
we still retain the possibility of adding exhaustive testing in the
future, which can be achieved by passing in an identity sampling
function that wouldn't modify any loop ranges.
2021-04-14 15:58:05 +00:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
Jake VanderPlas
93a1882e86 lax.reduce_precision: add basic batching & masking rule 2021-04-05 14:16:50 -07:00
Peter Hawkins
0dd1b5516d Implement lax.pad batching rule for batched padding values. 2021-02-17 13:30:20 -05:00
Peter Hawkins
48393c1def Relax some test tolerances of tests that fail on GPU in -x32 mode. 2021-02-16 22:14:43 -05:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Peter Hawkins
308e7f95b0 Fix batching rule for gather where the batch dimension has size 0. 2020-12-14 22:27:34 -05:00
Jake VanderPlas
4ac1c60179 Cleanup: remove rng_factory boilerplate in lax_vmap_test.py 2020-12-08 11:39:36 -08:00
Jake VanderPlas
4e738bf56b Cleanup: remove obsolete TODOs 2020-11-20 10:15:48 -08:00
Peter Hawkins
e863103b0e Add a reverse=... argument to lax.cumsum/cumprod/...
This allows us to lower to a more efficient implementation on TPU.
2020-10-16 10:09:11 -04:00
Peter Hawkins
2d74037d42 Change select_and_scatter_add's batching rule to avoid unrolling.
Add a test for select_and_scatter_add batching.
2020-10-02 12:02:34 -04:00