19462 Commits

Author SHA1 Message Date
Sergei Lebedev
57e59eb6c3 Removed deprecated jax.config methods and jax.config.config
Reverts dcc65e621ea3a68fdc79fa9f2c995743a7b3faf7

PiperOrigin-RevId: 608676645
2024-02-20 11:25:16 -08:00
Peter Hawkins
f1ea67117e Split name_stack out of mlir.ModuleContext.
A unique name_stack is built for every equation, which means that we're constantly rebuilding ModuleContext objects, even though the lifetime of almost everything else (naturally) is the Module scope. Split name_stack into an object that is threaded separately, including as part of mlir.LoweringRuleContext.

PiperOrigin-RevId: 608594374
2024-02-20 07:17:23 -08:00
Peter Hawkins
2165611584 Fix code to populate defaults for boolean flags from environment variables.
PiperOrigin-RevId: 608574620
2024-02-20 05:57:23 -08:00
jax authors
c69a5daca3 Merge pull request #19871 from gnecula:poly_stride_sym
PiperOrigin-RevId: 608566786
2024-02-20 05:16:47 -08:00
Sergei Lebedev
37f313ab22 Fixed internal CI builds
* Added a noop config_tags_overrides parameter to jax_test()
* Updated BUILD files necessary to run Pallas tests via Bazel
* Changed PallasTest to skip "large" test cases

PiperOrigin-RevId: 608534008
2024-02-20 02:42:14 -08:00
jax authors
014d10a01f Update XLA dependency to use revision
5cf2e12e3a.

PiperOrigin-RevId: 608467246
2024-02-19 20:58:23 -08:00
Sergei Lebedev
22a8da5e76 Allow dynamic grid on GPU in interpret mode
PiperOrigin-RevId: 608348239
2024-02-19 08:59:57 -08:00
Sergei Lebedev
46ec581c55 Added a few missing compute capability checks to Pallas:GPU tests
PiperOrigin-RevId: 608348004
2024-02-19 08:51:03 -08:00
Thomas Köppe
dcc65e621e Reverts b506fee9e389391efb1336bc7575dba913e75cdf
PiperOrigin-RevId: 608319964
2024-02-19 06:23:00 -08:00
Sergei Lebedev
2101725d31 Generate per-backend tests in //tests/pallas
PiperOrigin-RevId: 608288828
2024-02-19 06:22:48 -08:00
Sergei Lebedev
07d6b81326 Fix an AttributeError when importing jax.experiemntal.pallas with CPU jaxlib
PiperOrigin-RevId: 608281921
2024-02-19 06:22:35 -08:00
jax authors
9721a1b6f0 Update XLA dependency to use revision
4f945362ca.

PiperOrigin-RevId: 608217823
2024-02-19 06:22:24 -08:00
jax authors
8b83c7b58a Update XLA dependency to use revision
d6cfa2451e.

PiperOrigin-RevId: 608061229
2024-02-19 06:22:12 -08:00
jax authors
4ddce023e2 Update XLA dependency to use revision
840f8e0405.

PiperOrigin-RevId: 607893507
2024-02-19 06:22:00 -08:00
Enrique Piqueras
3fcde4baec Fix pipeline AG test.
PiperOrigin-RevId: 607813218
2024-02-19 06:21:47 -08:00
Jake VanderPlas
a6732f93ef Remove unnecessary jax.config imports
PiperOrigin-RevId: 607806346
2024-02-19 06:21:33 -08:00
Sergei Lebedev
b506fee9e3 Removed deprecated jax.config methods and jax.config.config
Reverts eb0343683547b6e2d29245f3ab6c91037c0cff81

PiperOrigin-RevId: 607803834
2024-02-19 06:21:15 -08:00
Rebecca Chen
6b7dd6ff38 Internal change
PiperOrigin-RevId: 607803148
2024-02-19 14:09:25 +00:00
Anselm Levskaya
772743e6a4 Internal change
Reverts 330afdc8bebe900d999202c4d59613e99cadb0ad

PiperOrigin-RevId: 607783139
2024-02-19 14:03:09 +00:00
jax authors
aa13be1ea7 Internal change
PiperOrigin-RevId: 607767219
2024-02-19 14:01:44 +00:00
George Necula
30ddc400b8 [shape_poly] Fix handling of stride_in_dim with symbolic stride.
The fix is simple, just avoid using `int(stride)`.
While fixing this I discovered some issues with a test
being disabled and handling of division by 0 when
computing the bounds of floordiv.
2024-02-19 12:36:26 +01:00
jax authors
ceb198582b Merge pull request #19181 from jakevdp:scalar-conversion
PiperOrigin-RevId: 607763395
2024-02-16 12:14:13 -08:00
Jake VanderPlas
1fe46aa8be Error for deprecated scalar conversions of non-scalar arrays 2024-02-16 11:26:30 -08:00
jax authors
0c92f55048 Merge pull request #19842 from helpingstar:fix_typo_faq0
PiperOrigin-RevId: 607742453
2024-02-16 11:14:32 -08:00
jax authors
06809b8812 Merge pull request #19832 from jakevdp:tree-transpose
PiperOrigin-RevId: 607742432
2024-02-16 11:05:46 -08:00
jax authors
83a041ef52 Merge pull request #19845 from gnecula:poly_speed3
PiperOrigin-RevId: 607709100
2024-02-16 09:19:33 -08:00
jax authors
eb03436835 Reverts 318a19a89387caebd116168c4e47592e7d71ca65
PiperOrigin-RevId: 607708463
2024-02-16 09:11:05 -08:00
Tomás Longeri
14474acf76 [Mosaic] Fix mistake in error message
PiperOrigin-RevId: 607700109
2024-02-16 08:38:30 -08:00
George Necula
bb57fb71e2 [shape_poly] Performance improvements for symbolic dimension manipulations (step 3)
We make the following improvements:

  * Add a `linear_combination` function to use for computing
    linear combinations fo symbolic expressions. E.g, `a - b` used
    to involve 2 operations: "-1 * b" and "a + -1*b".
  * Change the representation of terms (_DimMon) from a dictionary
    mapping factors (_DimAtom) to exponents, into a sorted tuple of
    pairs (factor, exponent). This is worthwhile because in almost
    all cases a term contains a single factor. Everywhere we used
    `term.items()` now we use `term._factors`.
  * Make the computation of `._hash` lazy. Previously, we used dictionaries
    heavily for symbolic expressions and we always needed the hash value,
    now we use dictionaries less.
  * Replace `t.degree` with `t.is_constant`.
  * Add `__slots__` to the representation of symbolic expressions

Micro benchmark: `a * 2 - b * 2 - a * 3 + c * 4`

After: 12.51 μsec (mean 12.6 μsec ± 105.2 nsec, of 7 runs, 20000 loops each)
Before: 40.33 μsec (mean 40.5 μsec ± 247.6 nsec, of 7 runs, 5000 loops each)
2024-02-16 17:33:34 +01:00
Sergei Lebedev
31a4921b29 Do not skip Pallas/Triton tests when compiling via XLA
PiperOrigin-RevId: 607689347
2024-02-16 07:55:29 -08:00
Sergei Lebedev
318a19a893 Removed deprecated jax.config methods
PiperOrigin-RevId: 607675571
2024-02-16 06:49:13 -08:00
Peter Hawkins
67df647988 Reland https://github.com/google/jax/pull/10573.
The original PR was reverted because of downstream breakage.

Originally we used the `Var.count` attribute to ensure `Var` instances were printed consistently regardless of context, even though only their object id was load-bearing. That is, `Var.count` was only used for pretty printing. (#1949 added a total_ordering on `Var` for reasons out of scope of JAX's core code. I'm going to figure out if that's still needed... Haiku tests all seem to pass without it.)

But #8019 revised our pretty-printing so as not to use `Var.count`. Instead it chose how to pretty-print Var instances based on their order of appearance in a jaxpr. That meant `Var.count` really wasn't useful anymore.

So this PR removes `Var.count`. Since we no longer have `Var.count`, we also don't need core.gensym to take an optional sequence of jaxprs, since that was just used to set the starting count index for new `Var`s.

In fact, `Var.__repr__` and `JaxprEqn.__repr__` were made confusing after #8019, since they could print variable names totally different from the names that would appear when the same `JaxprEqn` or `Var` objects were printed as part of a jaxpr. That is, before this PR we might have a jaxpr which printed like:

```
import jax

def f(x):
  for _ in range(3):
    x = jax.numpy.sin(x)
  return x

jaxpr = jax.make_jaxpr(f)(3.)
print(jaxpr)
# { lambda ; a:f32[]. let
#     b:f32[] = sin a
#     c:f32[] = sin b
#     d:f32[] = sin c
#   in (d,) }

_, eqn, _ = jaxpr.jaxpr.eqns
print(eqn)
# a:f32[] = sin b
```

Notice the variable names in the equation pretty-print don't correspond to any in the jaxpr pretty-print!

So this PR changes `JaxprEqn.__repr__` and `Var.__repr__` to show `Var` object ids, and in general just do less formatting (which seems consistent with the spirit of `__repr__`):
```
JaxprEqn(invars=[Var(id=140202705341552):float32[]], outvars=[Var(id=140202705339584):float32[]], primitive=sin, params={}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7f837c73d770>, name_stack=NameStack(stack=())))
```

PiperOrigin-RevId: 607664497
2024-02-16 05:57:12 -08:00
Sergei Lebedev
02e23e841e Added a flag enabling Pallas kernel compilation via XLA:GPU
PiperOrigin-RevId: 607649047
2024-02-16 04:37:45 -08:00
Sergei Lebedev
8026d198b1 Use ir.FloatType instead of a Pallas-local shim
PiperOrigin-RevId: 607635063
2024-02-16 03:31:43 -08:00
helpingstar
4297e0a24c fix typo2 2024-02-16 20:14:26 +09:00
helpingstar
fbccdb29f8 fix typo 2024-02-16 17:30:04 +09:00
jax authors
86db2de522 Merge pull request #19793 from mattjj:remat-in-hlo-metadata
PiperOrigin-RevId: 607575175
2024-02-15 22:56:55 -08:00
jax authors
330afdc8be Merge pull request #19819 from mattjj:scan-dont-traverse-body-jaxpr-in-lowering
PiperOrigin-RevId: 607570860
2024-02-15 22:36:35 -08:00
Matthew Johnson
08dfb11da2 prototype 'remat' in hlo metadata 2024-02-15 22:27:42 -08:00
Matthew Johnson
5ead7a6ef2 [scan] don't traverse body jaxpr in lowering
The main changes are:
 1. simplify the scan impl that we trace through to get the lowering, and
 2. ensure that when tracing it to a jaxpr, we don't rebuild the scan body
    jaxpr we already have in hand.

The main motivation was (2), but (1) seems like a useful win too.

The way we achieve (2) is with a new trick: in our scan_impl function, which is
only ever traced to a jaxpr, instead of calling
`core.jaxpr_as_fun(jaxpr)(*args)` we call a new primitive
`eval_jaxpr_p.bind(*args, jaxpr=jaxpr)`. This new primitive only has a staging
rule defined for it (i.e. all we can do with it is stage it into a jaxpr), and
that rule just generates a call into the jaxpr of interest. Therefore we will
not traverse into the jaxpr just to rebuild it inline (as before.
2024-02-15 22:13:22 -08:00
jax authors
ff3247e713 Update XLA dependency to use revision
2856cb22ec.

PiperOrigin-RevId: 607560674
2024-02-15 21:48:46 -08:00
Yash Katariya
0b542ff585 Add a benchmark measuring device_put's speed for a 4GB input array
```
---------------------------------------------------------
Benchmark               Time             CPU   Iterations
---------------------------------------------------------
device_put_big        419 ms        0.363 ms           10
```

PiperOrigin-RevId: 607512568
2024-02-15 17:47:10 -08:00
jax authors
7e7094c82d [JAX] Add an option subset_by_index that allows computing a contiguous subset of singular components from svd.
PiperOrigin-RevId: 607493941
2024-02-15 16:33:09 -08:00
jax authors
0203d15485 Merge pull request #19837 from jakevdp:key-reuse-clerrs
PiperOrigin-RevId: 607489807
2024-02-15 16:18:59 -08:00
Jake VanderPlas
8eab599530 [key reuse] simplify key reuse logic through context-free jaxpr evaluation
The args_consumed and forwarded_inputs context is not actually needed, because it can be checked
afterward. The only reason for this was to have more granular errors, but arguably it's better
to error on jaxpr input.
2024-02-15 15:50:50 -08:00
Tomás Longeri
243e7edc56 [Mosaic] In apply_vector_layout.cc, check layout validity when reading the attribute
This allows us to rely on this throughout the code and replace some checks with TPU_ASSERT_*. They have the semantics of an assert and make it clearer that it is an unexpected internal error (instead of unimplemented or invalid user input that we should handle).

Note: the original error messages for some of these checks were using the wrong input names.
PiperOrigin-RevId: 607463728
2024-02-15 14:51:45 -08:00
Yash Katariya
f12550964d Update the cuda 12 dependencies since we upgraded to cuda 12.3
PiperOrigin-RevId: 607453817
2024-02-15 14:29:45 -08:00
Peter Hawkins
4834423e17 Micro-optimization: use .dtype rather than dtypes.result_type when forming MLIR ndarray constants.
We don't need .result_type: we are guaranteed the value is a numpy scalar or ndarray.

PiperOrigin-RevId: 607453728
2024-02-15 14:20:50 -08:00
Peter Hawkins
b5e4ba4900 Don't call inspect.signature() each time we trace a jit().
We can just call it once when jit itself is called.

While we're here, also don't recompute api_util.fun_sourceinfo.

PiperOrigin-RevId: 607443283
2024-02-15 13:49:27 -08:00
Yash Katariya
8888006a86 Partial rollback xla translation APIs that were removed in 0.4.24 release
PiperOrigin-RevId: 607437887
2024-02-15 13:32:33 -08:00