14790 Commits

Author SHA1 Message Date
Yash Katariya
2fc64bee13 Change the axis_resources argument of with_sharding_constraint to shardings to match pjit and jit.
PiperOrigin-RevId: 509275107
2023-02-13 10:53:57 -08:00
jax authors
c49af18b9b Merge pull request #14365 from jakevdp:reducers-initial
PiperOrigin-RevId: 509253981
2023-02-13 09:43:46 -08:00
jax authors
83b7ba2ba0 Merge pull request #14444 from jakevdp:fix-csr-lowering
PiperOrigin-RevId: 509241948
2023-02-13 08:59:37 -08:00
Jake VanderPlas
58323d5b40 jax.numpy reductions: better validation of initial value 2023-02-13 08:43:25 -08:00
Jake VanderPlas
ddae1d00ea fix change to csr lowering rule 2023-02-13 08:39:05 -08:00
George Necula
2d47921a34 [shape_poly] Fix some tests for shape pol with native lowering
Native lowering requires that dimension variables be computable from the shapes of the arguments that are kept by lowering. We had some tests that were using dimension variables but were not using the actual inputs. Then lowering removes the inputs and it is not possible anymore to recover the values of the dimension variables at invocation time.

Here we primarily changed the tests to ensure they use not just the shape of the input but the actual value. In some cases we have disabled some testing, until
https://github.com/google/jax/issues/14437 is fixed

PiperOrigin-RevId: 509171805
2023-02-13 02:48:11 -08:00
George Necula
6b70728bc9 Fix sharding_test
PiperOrigin-RevId: 509048406
2023-02-12 10:26:17 -08:00
jax authors
002cdf688d Merge pull request #14432 from gnecula:call_tf_checks_2
PiperOrigin-RevId: 509043710
2023-02-12 09:34:50 -08:00
George Necula
f280d31066 [jax2tf] Minor fix: remove dedundant check 2023-02-12 16:02:53 +01:00
Yash Katariya
6caaffc20c Add in_shardings and out_shardings argument to pjit and jit to start deprecating in_axis_resources and out_axis_resources.
PiperOrigin-RevId: 508934327
2023-02-11 15:30:14 -08:00
Joan Puigcerver
bc1f5f1cbb Register missing standard primitives to shard_map.py.
PiperOrigin-RevId: 508920824
2023-02-11 12:58:22 -08:00
Joan Puigcerver
654c1d3b2b Fix _standard_rep_rule in shard_map.py when in_rep is empty.
set.intersection() with no arguments (in_rep is empty) raises an exception.

PiperOrigin-RevId: 508910287
2023-02-11 11:09:47 -08:00
jax authors
1089992017 Merge pull request #14424 from stellaraccident:devenh
PiperOrigin-RevId: 508902521
2023-02-11 09:50:50 -08:00
Peter Hawkins
612a940160 Minimize the set of names exported from jax.experimental.pjit.
PiperOrigin-RevId: 508889911
2023-02-11 07:37:32 -08:00
Yash Katariya
9316188b3a [Rollback] Convert _arrays to return PyArray instead of PyBuffer.
PiperOrigin-RevId: 508827908
2023-02-10 21:36:56 -08:00
Stella Laurenzo
c1e13bdf3f A few developer workflow enhancements for working with jaxlib.
It seems to me that jaxlib development must be mostly happening on CI, because some basics are pretty essential. Here are a few things I've been typing/carrying for a while in my flow:

* Add .bazelrc.user to .gitignore so it doesn't accidentally get checked in.
* Add configs for 'debug_symbols' and 'debug' that make some things minimally workable under a debugger (or to get backtraces, etc).
* Add `--force-reinstall` to the copy/paste command to update a built jaxlib wheel (without this, if you are iterating, it fairly quietly does nothing).
2023-02-10 21:03:21 -08:00
Peter Hawkins
61da781174 [JAX] Replace uses of jax.interpreters.xla.DeviceArray with jax.Array.
PiperOrigin-RevId: 508822404
2023-02-10 20:56:34 -08:00
jax authors
1bdcd5e138 Merge pull request #14415 from jakevdp:bcsr-matmul
PiperOrigin-RevId: 508785095
2023-02-10 16:55:05 -08:00
Roy Frostig
a262314934 prune unintended exports from jax.interpreters.batching
PiperOrigin-RevId: 508784928
2023-02-10 16:47:28 -08:00
Skye Wanderman-Milne
e54858522c Add back loading TPU plugin for older jaxlib versions.
This was removed in 668b82d529.

PiperOrigin-RevId: 508777939
2023-02-10 16:16:20 -08:00
jax authors
26ddf3b571 Merge pull request #14419 from jakevdp:spsolve-cpu-lowering
PiperOrigin-RevId: 508777573
2023-02-10 16:16:05 -08:00
Jake VanderPlas
de8a77a3eb [sparse] implement BCSR.__matmul__ 2023-02-10 16:11:57 -08:00
jax authors
fc507f2ebe Merge pull request #14418 from mattjj:vmap-spmd-axis-name-tuples
PiperOrigin-RevId: 508777043
2023-02-10 16:08:32 -08:00
Yash Katariya
0d07372995 Point to the exact primitive name nested under jit/pjit instead of mentioning all possible ones.
PiperOrigin-RevId: 508770290
2023-02-10 15:40:25 -08:00
Jake VanderPlas
552fc2c5a3 [sparse] add CPU lowering rule for sparse.linalg.spsolve 2023-02-10 15:35:42 -08:00
Parker Schuh
568a93bcd1 Convert _arrays to return PyArray instead of PyBuffer.
PiperOrigin-RevId: 508769390
2023-02-10 15:32:57 -08:00
Matthew Johnson
9538bc3e73 generalize vmap spmd_axis_name to accept tuples of axis names
This brings the argument more in line with what can appear as positional
arguments to the PartitionSpec constructor.
2023-02-10 15:25:23 -08:00
jax authors
d531ec9743 Merge pull request #14331 from canyon289:add_user_guide_sentence
PiperOrigin-RevId: 508760750
2023-02-10 14:53:17 -08:00
Peter Hawkins
2f80e46f64 [XLA:Python] Fix overly pessimistic handling of singleton dimensions in dlpack code.
Requires an accompanying jaxlib change.

Fixes https://github.com/google/jax/issues/14399

PiperOrigin-RevId: 508757315
2023-02-10 14:44:22 -08:00
jax authors
dc6bf9b725 Merge pull request #14408 from lucashofer:scipy_spence
PiperOrigin-RevId: 508756972
2023-02-10 14:36:15 -08:00
jax authors
4deb12ea1c Merge pull request #14411 from hawkinsp:kepler
PiperOrigin-RevId: 508755353
2023-02-10 14:28:32 -08:00
Ravin Kumar
18b251e251 Add notes and headeers to user guides 2023-02-10 14:17:15 -08:00
Yash Katariya
1526c3e20c Improve the error message which is raised from _get_and_check_device_assignment.
Before:

```
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0] on platform CPU and another array's device ids [0, 1, 2, 3] on platform CPU
```

After:

```
ValueError: Received incompatible devices for jitted computation. Got argument inp of ArrayPjitTest.test_jit_with_sharding_constraint_committed_inp_error.<locals>.sharded_inp with bfloat16[8,2] and device ids [0] on platform CPU and with_sharding_constraint or nested pjit or shard_map with device ids [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)
```
PiperOrigin-RevId: 508746961
2023-02-10 13:54:15 -08:00
Lucas Hofer
4636276214 added scipy special spence
added dtype to arrays in the _spence_poly function
2023-02-10 20:33:47 +00:00
jax authors
57900d7ef2 Merge pull request #14364 from jakevdp:fix-tril-indices
PiperOrigin-RevId: 508723970
2023-02-10 12:25:06 -08:00
Peter Hawkins
6ee67639e2 Split PyTorch interoperability tests into their own test.
PiperOrigin-RevId: 508722180
2023-02-10 12:17:11 -08:00
jax authors
5da5967d08 Merge pull request #14395 from jakevdp:bcsr-dot-general
PiperOrigin-RevId: 508721790
2023-02-10 12:09:29 -08:00
Jake VanderPlas
ac647b9459 [sparse] implement autodiff rules for bcsr_dot_general 2023-02-10 12:00:30 -08:00
Peter Hawkins
ec56d71d01 Drop support for NVIDIA Kepler series GPUs in jaxlib builds. 2023-02-10 14:15:15 -05:00
jax authors
7a864d73bc Merge pull request #14394 from jakevdp:jax-array-methods
PiperOrigin-RevId: 508694486
2023-02-10 10:27:14 -08:00
George Necula
be21404085 [jax2tf] Add shard_map tests
Also fix tests to run on multiple devices in TF

PiperOrigin-RevId: 508691872
2023-02-10 10:18:19 -08:00
jax authors
d09f3c2ee4 Merge pull request #11727 from gnecula:call_tf_checks
PiperOrigin-RevId: 508685246
2023-02-10 09:51:35 -08:00
Jake VanderPlas
60256df668 [typing] define additional methods & properties on jax.Array
These are the methods that are only valid for actual materialized arrays (i.e. not Tracers)
In order to simplify the experience for users, we want to maintain only a single jax.Array
type, so we define all methods here and raise explicit errors on Tracer instances.
2023-02-10 09:42:32 -08:00
John QiangZhang
7659a3a271 Enable call_tf_native_lowering_test.
PiperOrigin-RevId: 508677359
2023-02-10 09:16:53 -08:00
jax authors
9f0783f35d Merge pull request #14403 from gnecula:reduce_precision
PiperOrigin-RevId: 508635187
2023-02-10 05:38:59 -08:00
jax authors
f070557260 Merge pull request #14400 from gnecula:native_bug1
PiperOrigin-RevId: 508635169
2023-02-10 05:30:24 -08:00
George Necula
30fda87142 [call_tf] Improve error reporting
Add more checks to catch early the cases when the called TF function
returns values that are not convertible to JAX values (arrays of
numeric values). All these cases were resulting in errors even before
but sometimes these errors were deep in the stack and harder to
diagnose.
2023-02-10 14:19:49 +01:00
George Necula
48c2538365 [jax2tf] Add support for reduce_precision 2023-02-10 13:29:46 +01:00
George Necula
ff6051fc31 [shape_poly] Better error message for functions that do not use input arguments
Also:
  * fixed some of the tests that were using the shape but not the value
  of the input arguments
  * fix importing of mlir.py due to recent move of interpreters.mlir to
  _src.interpreters.mlir
2023-02-10 10:59:46 +01:00
Peter Hawkins
54ff78dbde Deprecate jax.interpreters.xla.Device and jax.interpreters.xla.DeviceArray.
PiperOrigin-RevId: 508502470
2023-02-09 16:11:48 -08:00