12939 Commits

Author SHA1 Message Date
jax authors
500da57e91 Merge pull request #21077 from merrymercy:patch-1
PiperOrigin-RevId: 631409738
2024-05-07 07:07:04 -07:00
Adam Paszke
326adc01a5 [Mosaic GPU] Adjust memref.expand_shape construction to pass in the new args
PiperOrigin-RevId: 631404097
2024-05-07 06:36:36 -07:00
jax authors
cb0c49850c Merge pull request #21081 from hawkinsp:sourcemap
PiperOrigin-RevId: 631236806
2024-05-06 17:33:12 -07:00
jax authors
4de346485d Fix that the insufficient output HBM buffer init would cause the <unk> token generated for quantized int8 model.
PiperOrigin-RevId: 631235764
2024-05-06 17:28:13 -07:00
jax authors
f6d88525a8 Merge pull request #20327 from selamw1:add_examples
PiperOrigin-RevId: 631186425
2024-05-06 14:30:06 -07:00
Selam Waktola
9caf59d68b improve documentation for ix_ 2024-05-06 13:43:55 -07:00
jax authors
3d3cb0bd2c Merge pull request #20842 from Micky774:array-api-default-promotion
PiperOrigin-RevId: 631168892
2024-05-06 13:39:56 -07:00
Peter Hawkins
d014f5dc5f Compute source maps when pretty-printing jaxprs.
This change is in preparation for adding support for emitting https://tc39.es/source-map/ source map information for jaxprs, so the relationship between a jaxpr and its Python code can be visualized using tooling for that purpose.

This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing.
The change also teaches the core jaxpr pretty printer to populate source map information on each equation.
2024-05-06 15:45:25 -04:00
Jake VanderPlas
4a363156b9 jnp.linalg tensorinv & tensorsolve: improve implementation & docs 2024-05-06 11:08:36 -07:00
jax authors
7e9ef1e4d2 Merge pull request #21078 from jakevdp:numpy-linalg-doc
PiperOrigin-RevId: 631112228
2024-05-06 10:44:42 -07:00
jax authors
fb65ba4adf Add a config for using Clang on Windows.
PiperOrigin-RevId: 631112031
2024-05-06 10:39:28 -07:00
Jake VanderPlas
40b2d4852e jnp.linalg: improve API documentation 2024-05-06 09:22:59 -07:00
Meekail Zain
34c5163fd2 Refactored common upcast for integral-type accumulators 2024-05-06 15:13:10 +00:00
Lianmin Zheng
0eed28a010
Fix a typo in jax.jit docstring 2024-05-06 04:59:23 -07:00
jax authors
7681493760 Don't create temp directory when module is getting imported.
PiperOrigin-RevId: 630958402
2024-05-06 00:58:45 -07:00
jax authors
1b804a7720 Merge pull request #21056 from mattjj:vmap-grad-remat-shmap-bug
PiperOrigin-RevId: 630555588
2024-05-03 19:06:46 -07:00
Matthew Johnson
7a87010f84 [shard_map] better fix for spmd_axis_name issues with shmap residuals
The fix in #21032 was not correct because it assumed that the set of all mesh
axis names appearing in in_specs was an upper bound on the set of mesh axes
over which residuals could be device-varying. But collectives can introduce
device variance! So it's not an upper bound.

We track device variance when check_rep=True, but often people set
check_rep=False (e.g. when using pallas_call in a shard_map). So relying on our
device variance tracking would be limiting. That may be a decent long term
solution, if we can make it easy to annotate pallas_calls with device variance
information. But it's not a great short term one to unblock things.

So instead I temporrarily went with context sensitivity: instead of making
residuals sharded over all mesh.axis_names (as we did before these patches), we
make them sharded over all mesh axis names _excluding_ any spmd_axis_names in
our dynamic context (by looking at the traces in our trace stack). It's illegal
to mention any spmd_axis_names in collectives (indeed anywhere in the body of
the function being vmapped), but I don't think we check it.

TODO(mattjj): add more testing (maybe in follow-ups)
2024-05-04 01:31:15 +00:00
Jake VanderPlas
e95173a4d3 Require arraylike input for several jax.numpy functions
PiperOrigin-RevId: 630532821
2024-05-03 16:55:10 -07:00
Jake VanderPlas
88318e60d2 jnp.delete: better docs 2024-05-03 14:41:06 -07:00
Jake VanderPlas
ff67e51e7e Remove last scipy imports 2024-05-03 10:20:05 -07:00
jax authors
c0cfc7ae9f Merge pull request #21032 from mattjj:vmap-grad-shmap-bug
PiperOrigin-RevId: 630375950
2024-05-03 06:47:22 -07:00
jax authors
d05d29d889 Merge pull request #21050 from rajasekharporeddy:test_branch3
PiperOrigin-RevId: 630339307
2024-05-03 03:24:22 -07:00
rajasekharporeddy
ccabdb29ea Fix typos in docs and an error message 2024-05-03 12:08:22 +05:30
jax authors
3952b7e78a Merge pull request #21042 from jakevdp:scipy-docs
PiperOrigin-RevId: 630212828
2024-05-02 16:36:46 -07:00
jax authors
653ee95193 Merge pull request #21043 from jakevdp:get-window
PiperOrigin-RevId: 630212815
2024-05-02 16:32:07 -07:00
Matthew Johnson
e4c76c97e2 [omnitracing] partially un-regress dispatch time 2024-05-02 22:18:36 +00:00
Matthew Johnson
b7b264571b fix vmap-of-grad-of-shmap axis name reuse bug
When we write `vmap(f, spmd_axis_name=A)`, we require that `f` does not mention
A in specs, like the `PartitionSpec` in a `with_sharding_constraint` or the
`in_specs`/`out_specs` of `shard_map`. Previously, shard_map autodiff violated
that requirement, since we gave residuals sharding over all mesh axes (i.e.
including axis name A present in the mesh). As a result, the vmap rule could
then insert a redundant appearance of A.

This commit fixes the problem by only sharding over mesh axes mentioned in
in_specs; residuals can at most be sharded over those mesh axes. Then the vmap
rule is free to introduce an occurrence of A in the specs.
2024-05-02 21:53:46 +00:00
Jake VanderPlas
27c0c41b44 Remove remaining top-level scipy imports 2024-05-02 14:37:44 -07:00
Sergei Lebedev
03b733bda7 Made has_side_effect= parameter of mlir.emit_python_callback keyword-only
This ensures that the call site always has parameter name and not just
a bare True/False argument.

PiperOrigin-RevId: 630166542
2024-05-02 13:44:54 -07:00
jax authors
187b2ac9a2 Merge pull request #21013 from Micky774:array-api-trim
PiperOrigin-RevId: 630146636
2024-05-02 12:38:04 -07:00
jax authors
d5983e13b2 Merge pull request #21038 from rajasekharporeddy:test_branch1
PiperOrigin-RevId: 630138895
2024-05-02 12:08:22 -07:00
jax authors
70f2ef211f Merge pull request #20971 from google:mutable-array-scan
PiperOrigin-RevId: 630130893
2024-05-02 11:40:54 -07:00
Dougal
e63b35d550 Add discharge rules for scan with mutable arrays. Move mutable array tests to separate file.
Co-authored-by: Matt Johnson <mattjj@google.com>
2024-05-02 14:36:16 -04:00
Meekail Zain
b88e2e808b Refactor array_api namespace, relying more directly on jax.numpy 2024-05-02 18:17:45 +00:00
jax authors
0335487c25 [Pallas TPU] Increase clarity of dot 2D shape enforcement error.
Currently only 2D shapes are supported in dot() lowering; the exception, however, gives a poor understanding of the problem.
The raised exception lists the associated shapes, but without knowing the 2D limitations, it provides little direction to the user on how to remedy the problem.

This change converts the raised exception to read something like:
`Exception: Only 2D tensors supported in dot; received: [ShapedArray(float32[128,128]), ShapedArray(float32[128])]`

rather than:
`Exception: [ShapedArray(float32[128,128]), ShapedArray(float32[128])]`

PiperOrigin-RevId: 630116468
2024-05-02 10:59:35 -07:00
Jake VanderPlas
18e4cfa911 DOC: Improve remaining jax.scipy docstrings 2024-05-02 10:29:00 -07:00
Adam Paszke
8692355220 [Mosaic] Add support for remote DMAs and semaphores in megacore mode
The change to tpu.td is not backwards compatible, but I made it so using the
newly added Mosaic stability layer. It's been a good exercise and it seems to
be working just fine.

Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 630060418
2024-05-02 07:43:36 -07:00
jax authors
582c56a707 Change determination of cloud TPU to check for TPU chips.
This is useful in the case of ahead of time compilation, when libtpu is present but there may not be any TPU chips, so we shouldn't attempt to initialize a TPU backend.

PiperOrigin-RevId: 630055511
2024-05-02 07:22:56 -07:00
George Necula
b40a31006c [export] Add backwards compatibility test for Pallas call on GPUs.
Note that this adds the minimum of safety net to protect against
non-backwards-compatible changes. We really should have more tests
that cover more of the Triton MLIR.

Also enable serialization of such calls.

PiperOrigin-RevId: 630033989
2024-05-02 05:38:33 -07:00
rajasekharporeddy
a0b93153ca Fix Typos and math rendering in jax.random docs 2024-05-02 17:43:37 +05:30
Mohammed Anany
2730cf38c1 Integrate Triton up to [8e0c7b42](8e0c7b425a)
PiperOrigin-RevId: 629987238
2024-05-02 01:42:43 -07:00
jax authors
d02ec9f90d Merge pull request #21018 from jakevdp:scipy-linalg-docs
PiperOrigin-RevId: 629893239
2024-05-01 17:57:29 -07:00
jax authors
0a3f8b71cf Merge pull request #20992 from jakevdp:scipy-signal-doc
PiperOrigin-RevId: 629890150
2024-05-01 17:42:44 -07:00
Jake VanderPlas
d51ccdf628 DOC: Improve docstrings for jax.scipy.linalg 2024-05-01 17:36:24 -07:00
jax authors
57bfe81260 Allow multiple indexers when doing discharge or swap in pallas
PiperOrigin-RevId: 629847808
2024-05-01 14:58:27 -07:00
Daniel Ng
77988ead94 Move dtype settings out of metadata field into the root of Tensorstore spec
Before, dtype used to be in the metadata field of tensorstore spec because of it was the legacy way to config the dtype.  This setting doesn't understand the "str" name, hence, there was special logic to translate bfloat for example.

This CL moves it out of the metadata field and put the dtype directly into the Tensorstore spec to eliminate special dtype translation logic.  This will also add support of other quantized types such as int4.

PiperOrigin-RevId: 629845048
2024-05-01 14:48:55 -07:00
Jake VanderPlas
18703d9385 DOC: Improve docstrings for jax.scipy.signal 2024-05-01 14:26:00 -07:00
jax authors
0ce05af8cb Merge pull request #20937 from jakevdp:fix-pure-callback
PiperOrigin-RevId: 629821826
2024-05-01 13:39:25 -07:00
jax authors
8626b9e171 Merge pull request #20847 from jakevdp:scipy-stats-doc
PiperOrigin-RevId: 629821754
2024-05-01 13:34:12 -07:00
jax authors
58b8e21127 [Pallas TPU] Print the exception when a lowering exception occurs.
Add the exception to the formatted string which is being re-raised, so that we present the problem more clearly.
This makes debugging significantly easier -- for exceptions like "Unimplemented primitive in Pallas TPU lowering: sign", such text currently does not appear in the error output.

PiperOrigin-RevId: 629767145
2024-05-01 10:44:01 -07:00