15817 Commits

Author SHA1 Message Date
Peter Hawkins
1d63d9b833 Include the device_kind in the compilation cache key.
PiperOrigin-RevId: 525726898
2023-04-20 06:16:45 -07:00
Parker Schuh
87c328864b Improve testing for custom_partitioning.
Add a test to demonstrate how to force XLA to choose
a different sharding.

Also it is possible to return the wrong
shape from a partition function. We should error in this case.

PiperOrigin-RevId: 525606690
2023-04-19 18:26:51 -07:00
jax authors
975e76ef76 Merge pull request #15664 from skye:tpu_install
PiperOrigin-RevId: 525605301
2023-04-19 18:18:32 -07:00
jax authors
db2cbd4ae8 Merge pull request #15665 from hawkinsp:sourceinfo
PiperOrigin-RevId: 525581713
2023-04-19 16:30:23 -07:00
Peter Hawkins
34fd4a1562 Add version guard to compilation cache test.
PiperOrigin-RevId: 525572568
2023-04-19 15:50:33 -07:00
Jake Vanderplas
fb5664d580 Copybara import of the project:
--
1f0eaa0059321f0b9301012d3bae7921056b5c9d by Jake VanderPlas <jakevdp@google.com>:

Test: fix TPU tolerance for Beta test
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/15674 from jakevdp:beta-tpu-test 1f0eaa0059321f0b9301012d3bae7921056b5c9d
PiperOrigin-RevId: 525568586
2023-04-19 15:35:51 -07:00
jax authors
1de4d14da8 Merge pull request #15656 from laqua-stack:add-special-gamma-fcn
PiperOrigin-RevId: 525566749
2023-04-19 15:28:36 -07:00
Yash Katariya
53e6382f4a Add arg_names to aval mismatch error raised during AOT compilation to raise better error messages
PiperOrigin-RevId: 525561905
2023-04-19 15:08:53 -07:00
jax authors
968dbaf8f3 Merge pull request #15673 from jakevdp:fix-i0e-test
PiperOrigin-RevId: 525556536
2023-04-19 14:48:29 -07:00
Jake VanderPlas
1b0106fd1e Make i0e gradient test more robust 2023-04-19 14:41:44 -07:00
Parker Schuh
4750ce7c87 Add an experimental API that allows compiling AOT for TPUs.
PiperOrigin-RevId: 525536075
2023-04-19 13:33:59 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
jax authors
54c9205493 Merge pull request #15666 from jakevdp:tracer-faq
PiperOrigin-RevId: 525525709
2023-04-19 12:52:40 -07:00
Yash Katariya
0a19638490 Plumb debug_info to meshExecutable as a optional arg to raise better error messages.
PiperOrigin-RevId: 525521694
2023-04-19 12:35:49 -07:00
Jake VanderPlas
a083ba7853 DOC: explicitly mention io_callback in FAQ 2023-04-19 12:30:53 -07:00
Peter Hawkins
a3b262c379 Use the traceback of the call site when assigning a source location to an inlined function.
Improves but does not completely fix https://github.com/google/jax/issues/15663 . The non-inlined case still has similar problems.
2023-04-19 13:56:53 -04:00
Skye Wanderman-Milne
b917a31f56 Update TPU install on main docs page 2023-04-19 17:52:16 +00:00
jax authors
a2fbd59e63 Merge pull request #15662 from nouiz:ci
PiperOrigin-RevId: 525473560
2023-04-19 09:44:43 -07:00
Frederic Bastien
bc0c25c4b5 pytest* just got removed. The CI don't need them anymore, so remove that requirement. 2023-04-19 08:26:04 -07:00
jax authors
c844464888 Merge pull request #15658 from jakevdp:fix-xlogy-grad
PiperOrigin-RevId: 525287070
2023-04-18 16:45:45 -07:00
Jake VanderPlas
dd023e266e jax.scipy.special: fix gradient for xlogy & xlog1py 2023-04-18 15:56:32 -07:00
jax authors
933d695170 Merge pull request #15610 from jakevdp:array-methods
PiperOrigin-RevId: 525238498
2023-04-18 13:38:21 -07:00
Jake VanderPlas
72bb8ab753 jax.Array: dynamically define abstract methods 2023-04-18 13:08:32 -07:00
laqua-stack
d742733bea feat (scipy.special): Add a xla version of scipy.special.gamma function
- Add gamma fcn api in scipy.special
- Add tests for this purpose
- Add function to the docs

Currently, there is no implementation of the gamma function in jax
but there is one in scipy.special. This breaks some higher level
jit-compilation like in the blackjax backend for pymc. This commit
adds the missing gamma function.

Resolves: #15409
2023-04-18 21:10:22 +02:00
jax authors
7cf86d1577 Merge pull request #15639 from jakevdp:checkify-dynamic
PiperOrigin-RevId: 525153864
2023-04-18 08:28:33 -07:00
Peter Hawkins
a377caec3a Import jax.experimental.compilation_cache.compilation_cache by default.
This is to fix users who were relying on this module being imported as part of 'import jax'.

PiperOrigin-RevId: 525151996
2023-04-18 08:19:45 -07:00
Jean-Baptiste Lespiau
6ca249da78 Improve the error message.
PiperOrigin-RevId: 525138471
2023-04-18 07:16:00 -07:00
Peter Hawkins
f8fe5d0542 Import jax.experimental.compilation_cache by default
PiperOrigin-RevId: 525033643
2023-04-17 21:28:35 -07:00
jax authors
98b6fe4676 Optimize sharding check inside jax_export._check_module
Formatting ops/attributes into str could be expensive. Instead, this uses a proper MLIR API to access `StringAttr` without printers.

PiperOrigin-RevId: 524999042
2023-04-17 17:54:26 -07:00
Yash Katariya
6218a6cf1d Don't loop to create replicated indices tuple consisting of slice(None). Use multiplication instead.
PiperOrigin-RevId: 524992811
2023-04-17 17:22:16 -07:00
jax authors
278d642161 Merge pull request #15643 from jakevdp:fix-apply
PiperOrigin-RevId: 524971702
2023-04-17 15:54:41 -07:00
Yash Katariya
75cf3d96d5 Try to preserve shardings with vmap(pjit) by converting the GSPMDShardings to original sharding type via the pxla.py helper
PiperOrigin-RevId: 524966654
2023-04-17 15:32:57 -07:00
Jake VanderPlas
01045b3b42 BUG: fix x.at[i].apply() with non-unit slice sizes 2023-04-17 15:27:03 -07:00
Lena Martens
1277f284ce Checkify: print more stack frames of the location where the error originated.
Can be customized with `jax_tracer_error_num_traceback_frames`.

PiperOrigin-RevId: 524956588
2023-04-17 14:51:57 -07:00
Yash Katariya
38c7939fc0 Cache the entire to_gspmd_sharding function to maximize cache hits even for GSPMDShardings
PiperOrigin-RevId: 524951951
2023-04-17 14:33:21 -07:00
Jake VanderPlas
8d1cf99825 checkify: dynamic_update_slice OOB index check 2023-04-17 13:43:26 -07:00
jax authors
cabf8b7302 Merge pull request #15636 from Vaishaal:idct
PiperOrigin-RevId: 524936117
2023-04-17 13:35:34 -07:00
jax authors
c4c256eef7 Merge pull request #15377 from jakevdp:gather-slice
PiperOrigin-RevId: 524920532
2023-04-17 12:37:51 -07:00
Vaishaal Shankar
add15aca25 implement idct and idctn + add function to scipy.rst 2023-04-17 12:12:51 -07:00
Parker Schuh
97c70f2171 initiate an experimental topologies module
Start off with two functions: one for retrieving the attached topology, and the other for producing a mesh from the topology (modeling how `mesh_utils` might be adapted).

Use as:
```
    topo = jax.topologies.get_attached_topology() // Discovers local devices.
    mesh = jax.topologies.make_mesh(topo, mesh_shape, axis_names) # see mesh_utils.create_device_mesh.
```

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 524909149
2023-04-17 11:54:33 -07:00
jax authors
b035a2b61b Merge pull request #15635 from jakevdp:mypy-ml-dtypes
PiperOrigin-RevId: 524907301
2023-04-17 11:47:24 -07:00
Jake VanderPlas
5310562250 mypy: type-check ml_dtypes 2023-04-17 11:35:59 -07:00
jax authors
be802788ac Merge pull request #15634 from jakevdp:fix-mypy
PiperOrigin-RevId: 524902802
2023-04-17 11:31:24 -07:00
Jake VanderPlas
4f0edc08a3 [typing] ignore zstandard in mypy 2023-04-17 11:22:08 -07:00
jax authors
9c92ec5d90 Merge pull request #15585 from SauravMaheshkar:main
PiperOrigin-RevId: 524896086
2023-04-17 11:08:46 -07:00
Jake VanderPlas
dca23d4d8f jax.numpy indexing: lower to dynamic_slice for more cases 2023-04-17 11:07:18 -07:00
Yash Katariya
bacb12de2b Don't explode into individual shards if an Array is fully replicated and addressable_data is called.
We can just extract 1 pjrt_buffer() and convert it to an Array with SingleDeviceSharding.

PiperOrigin-RevId: 524877300
2023-04-17 10:05:37 -07:00
jax authors
4d86c96b4e Merge pull request #15630 from nouiz:multi-node-nightly
PiperOrigin-RevId: 524871113
2023-04-17 09:42:49 -07:00
Peter Hawkins
112b317df3 Compress serialized executables in the compilation cache.
If the 'zstandard' package is installed, we use it. Otherwise we use 'zlib', which is always available.

PiperOrigin-RevId: 524862738
2023-04-17 09:09:38 -07:00
Peter Hawkins
017548c40b Move implementation of compilation cache out of jax/experimental and into jax/_src.
Use a Protocol instead of an abstract base class for the CacheInterface since it allows us to use one fewer file.

No functional change intended.

PiperOrigin-RevId: 524855263
2023-04-17 08:35:53 -07:00