5161 Commits

Author SHA1 Message Date
Peter Hawkins
e558feaa5e Deprecate support for the mhlo dialect.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing.

PiperOrigin-RevId: 598550225
2024-01-15 02:13:40 -08:00
Sergei Lebedev
2d8a20c413 Do not load Triton bindings eagerly in jax/lib/__init__.py
Triton is only used by Pallas, so it makes sense to delay loading until Pallas
is imported.

PiperOrigin-RevId: 598131836
2024-01-13 03:01:02 -08:00
jax authors
b8b119d9b9 Cleanup deprecated compilation cache APIs.
Since the compilation cache is now initialized lazily,
existing APIs initialize_cache() and is_initialized()
are confusing. Deprecate these APIs.

Introduce a new API set_cache_dir() to explicitly set the
cache directory path in code.

Testing: revised unit tests, test workload.
PiperOrigin-RevId: 598073423
2024-01-12 22:44:48 -08:00
jax authors
7418f55987 Merge pull request #19007 from sshahrokhi:enhanced
PiperOrigin-RevId: 598002075
2024-01-12 17:12:16 -08:00
jax authors
1ab5a3579a Merge pull request #19348 from jakevdp:scipy-version
PiperOrigin-RevId: 597978339
2024-01-12 15:48:20 -08:00
Shiva Shahrokhi
65f3e4fffd making sure enhanced barrier only turns on when there is a supported TPU available. 2024-01-12 23:47:37 +00:00
Jake VanderPlas
1870eee062 Test: make scipy version parsing compatible with pre-releases 2024-01-12 14:35:28 -08:00
Jake VanderPlas
989618c5f7 [array api] add jax.numpy.concat 2024-01-12 13:12:09 -08:00
Mohammed Anany
31efc2cc6a Import openai/triton from GitHub.
PiperOrigin-RevId: 597848585
2024-01-12 08:37:54 -08:00
Adam Paszke
f625fb69da [Mosaic] Add support for tile-aligned dynamic offsets in loads, stores and ref slices
PiperOrigin-RevId: 597798116
2024-01-12 03:42:58 -08:00
jax authors
adbbe69cc2 Add option to share compiled module between hosts.
PiperOrigin-RevId: 597754861
2024-01-11 23:38:02 -08:00
George Necula
3b7917a56e [shape_poly] Improve and rename export.args_specs.
We rename it to `symbolic_args_specs` in line with the other
public APIs related to shape polymorphism. The function used to
be in _export.py for historical reasons, we now move it to
shape_poly.py but we export the `symbolci_args_specs` from
the public `jax.experimental.export`.

The improvement is that for the case when the `args` passed in
are TF arrays, we move the logic to extract the shapes and dtypes
from this function to the callers. This achieves a better
separation of the JAX and TF use cases.
2024-01-12 08:11:03 +02:00
jax authors
8a8cd6d01a Merge pull request #19321 from jakevdp:diagonal
PiperOrigin-RevId: 597660818
2024-01-11 14:55:50 -08:00
Sharad Vikram
5258f7cdcc Fix all_gather_test
PiperOrigin-RevId: 597657438
2024-01-11 14:38:43 -08:00
Sharad Vikram
548bdd02a8 Add verbose kwarg to assertArraysEqual
PiperOrigin-RevId: 597650736
2024-01-11 14:11:53 -08:00
Sergei Lebedev
5b7a0d9c91 Pallas now uses MLIR Python builders to lower to Triton IR
This allows us to drop a dependency on the Triton Python package in the future,
and delegate ->ptx compilation to XLA.

PiperOrigin-RevId: 597640756
2024-01-11 13:33:26 -08:00
Jake VanderPlas
b08a010949 [array API] add jnp.linalg.diagonal 2024-01-11 12:52:15 -08:00
Adam Paszke
ce00e10d9b [Pallas][Mosaic] Add support for nontrivial semaphore memrefs
The previous patch simply changed the type we use to represent semaphores,
but didn't actually add support for any more operations. With this one,
semaphore memrefs can be allocated and (dynamically) indexed.

PiperOrigin-RevId: 597538913
2024-01-11 06:33:49 -08:00
Jake VanderPlas
1a39d8fdb2 [array API] implement jnp.pow; alias for jnp.power 2024-01-10 14:59:46 -08:00
Jake VanderPlas
4e55086dfb array api: add jnp.bitwise_* aliases 2024-01-10 14:22:20 -08:00
Jake VanderPlas
9890b23b0a Add jnp.vecdot 2024-01-10 13:11:37 -08:00
jax authors
7ad53a6e67 Merge pull request #19276 from jakevdp:mypy-version
PiperOrigin-RevId: 597332904
2024-01-10 13:09:23 -08:00
Tomás Longeri
027c24e602 [Mosaic] Remove Python implementation of apply_vector_layout and infer_memref_layout.
PiperOrigin-RevId: 597332393
2024-01-10 13:00:21 -08:00
Eugene Zhulenev
ba4c2b1c7d [pjrt:cpu] Add CpuTopology to TfrtCpuClient and enable persistent compilation cache for cpu backend
PiperOrigin-RevId: 597327136
2024-01-10 12:40:57 -08:00
Jake VanderPlas
6569b320b2 CI: bump mypy to version 1.8.0 2024-01-10 10:20:55 -08:00
Sergei Lebedev
ba10775eda Added a compatibility overlay for Triton Python APIs
Follow up changes will gradually re-implement these APIs using the MLIR
builders added in google/jax#19159.

PiperOrigin-RevId: 597023799
2024-01-09 13:13:56 -08:00
Jake VanderPlas
be8183d746 pallas: improve indexing trace time 2024-01-09 11:32:00 -08:00
jax authors
e76f514b49 Merge pull request #19270 from jakevdp:permute-dims-sig
PiperOrigin-RevId: 596992011
2024-01-09 11:20:20 -08:00
jax authors
2356e57af4 [Mosaic] Add lowering rule for jax.lax.logical_not
jax.lax.logical_not(x) is lowered into: xor x, 0xffffffff

PiperOrigin-RevId: 596965921
2024-01-09 09:58:43 -08:00
Jake VanderPlas
707657e5b7 Adjust permute_dims signature to match NumPy
This really doesn't matter because it's a position-only argument, but this
change satisfies our tests and is easier than making the tests smarter.
2024-01-09 09:56:19 -08:00
jax authors
7ad7890068 Merge pull request #18953 from gnecula:poly_min_max
PiperOrigin-RevId: 596768330
2024-01-08 19:37:00 -08:00
jax authors
6a99e38a82 Merge pull request #19252 from jakevdp:fix-vecdot
PiperOrigin-RevId: 596762791
2024-01-08 18:54:07 -08:00
jax authors
9556b09349 Merge pull request #19253 from jakevdp:fix-bitwise-count
PiperOrigin-RevId: 596756207
2024-01-08 18:20:13 -08:00
Jake VanderPlas
f901bead48 jnp.linalg.vecdot: fix broadcasting & conjugation semantics 2024-01-08 18:03:44 -08:00
jax authors
856915f3c4 Merge pull request #19244 from jakevdp:permute-dims
PiperOrigin-RevId: 596741266
2024-01-08 17:02:46 -08:00
Jake VanderPlas
60665f90ec jnp.bitwise_count: promote input to numeric 2024-01-08 15:19:54 -08:00
jax authors
da96633f11 Correct the cache miss metric instrumentation due to the new min cache entry size flag
Since introduction of the min cache entry size check for compilation cache, the cache miss metric overcounts the skipped caches whose sizes are smaller than the min cache entry size. After moving the metric instrumentation to compilation_cache.put_executable_and_time, the cache miss metric will be incremented if both compile time and cache entry size are greater than the minimum thresholds.

PiperOrigin-RevId: 596696013
2024-01-08 14:03:33 -08:00
George Necula
6b7b3a3902 [shape_poly] Replace non_negative_dim with max_dim and min_dim.
Previously, we had `core.non_negative_dim` and we used it to
express `max(d, 0)`. This is needed in several places internally
to express index computations involving clamping (for numpy
indexing), or striding and dilation (which have a conditional
semantics). It seemed that this special case was sufficient,
and we expressed `max(a, b)` as `a + non_negative(b - a)` and
`min(a, b)` as `a - non_negative(a - b)`.

One drawback was that `non_negative` can be a surprising
construct when it appears in error messages. Also, users need
`max` and `min` computations with dimensions. It is clearer if
we use `max` and `min` directly instead of rewriting these to
use `non_negative`. The drawback is that we now have to duplicate
some internal logic to for `max` and `min`, but overall I feel
this is worth it for the better error messages we get.
2024-01-08 20:54:18 +02:00
jax authors
f33f0e4337 Merge pull request #19220 from jakevdp:dep-kind
PiperOrigin-RevId: 596636214
2024-01-08 10:32:34 -08:00
Jake VanderPlas
6ab9398574 Remove reference to np.ComplexWarning
This is not available in NumPy 2.0.
2024-01-08 09:53:06 -08:00
Jake VanderPlas
d673b9bf5c [array api] add jax.numpy.permute_dims function 2024-01-08 09:30:51 -08:00
Adam Paszke
2a60c184af Use grid_mapping instead of mosaic_grid_mapping to query scalar prefetch
mosaic_grid_mapping sometimes extend the scalar prefetch operands with arrays
describing the mesh, throwing off the math in the recent patch.

PiperOrigin-RevId: 596596640
2024-01-08 08:14:21 -08:00
George Necula
69788d18b6 [export] Refactor the imports for the public API of jax.experimental.export
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1: export.Exported = export.deserialized(ser)
export.call(exp1)
```

This change requires changing the type of
`jax.experimental.export.export` from a
module to a function. This confuses
pytype for the targets with strict type checking,
which is why I attempt to make this change
atomically throughout the internal code base.

In order to support backwards compatibility with
OSS packages, this change also includes explicit
JAX version checks in several OSS packages, and
also adds to the `export` function the attributes
that the old export module had.

PiperOrigin-RevId: 596563481
2024-01-08 05:29:56 -08:00
George Necula
45fa5021db [shape_poly] Minor improvements for handling of dilation in convolutions.
Use core.dilate_dim, it is clearer and especially when using
shape polymoprhism it will result in better error messages.

Replace the use of np.maximum with core.max_dim, because the
latter will result in fewer errors in presence of shape polymorphism
(the core.max_dim is deferring the actual computation to shape
refinement time).
2024-01-08 08:56:53 +02:00
Tomás Longeri
88542f0e56 [Mosaic] Run C++ passes from within custom_call_emitter.cc
PiperOrigin-RevId: 596464480
2024-01-07 20:25:45 -08:00
jax authors
16699e4e78 Merge pull request #19224 from jakevdp:batched-solve
PiperOrigin-RevId: 596313955
2024-01-06 21:23:13 -08:00
Jake VanderPlas
20a1972d27 jnp.linalg.solve: fully implement batched cases 2024-01-05 15:21:49 -08:00
Yash Katariya
b8098b1782 Remove indices and devices from shard_arg_handlers and shard_args.
This only affects python dispatch path. This has no impact on the speed of cpp dispatch (which is why benchmarks are **not** regressing).

If your code ends up taking the python dispatch, then something is going wrong anyways.

PiperOrigin-RevId: 596081987
2024-01-05 14:17:14 -08:00
jax authors
ed62f28164 Triton integrate 2023-12-18
PiperOrigin-RevId: 596038346
2024-01-05 11:09:08 -08:00
Jake VanderPlas
6278363e25 jnp.argsort/sort: explicitly deprecate the kind argument
This argument is a carry-over from NumPy, and has never had any effect (all jax.numpy
sorts were stable by default). Now that the new stable parameter is supported, it will
be clearer if we explicitly deprecate and eventually remove this argument.
2024-01-05 09:19:36 -08:00