15844 Commits

Author SHA1 Message Date
Jake VanderPlas
a5737f82af custom prng: remove stackable override for jnp.concatenate 2023-04-24 12:26:58 -07:00
jax authors
075bbe3203 Merge pull request #15706 from jakevdp:prngkey-asarray
PiperOrigin-RevId: 526667357
2023-04-24 09:31:52 -07:00
jax authors
25cd55fb24 Merge pull request #15704 from jakevdp:lax-numpy-builtins
PiperOrigin-RevId: 526661476
2023-04-24 09:18:14 -07:00
jax authors
1b2423a4a8 Merge pull request #15701 from jakevdp:promo-error
PiperOrigin-RevId: 526661327
2023-04-24 09:10:19 -07:00
Jake VanderPlas
7f2e724703 Make PRNGKeyArray compatible with jnp.array 2023-04-24 09:00:00 -07:00
Jake VanderPlas
e50138608a PRNGKeyArrayImpl: add aval property
This makes it more readily compatible with jax.numpy routines.
2023-04-24 08:59:14 -07:00
jax authors
035f585e43 Merge pull request #15707 from JiaYaobo:fix_random_docstring_math_domain
PiperOrigin-RevId: 526277662
2023-04-22 07:00:12 -07:00
Jake VanderPlas
39adec8eb5 internal: remove aliasing of builtins from lax_numpy 2023-04-22 06:55:25 -07:00
jiayaobo
30a4b7be04 fix pareto docstring x domain 2023-04-22 11:36:16 +08:00
jax authors
13fe3810d2 Merge pull request #15694 from mattjj:djax-reshape
PiperOrigin-RevId: 526194423
2023-04-21 19:42:27 -07:00
Matthew Johnson
84ae14e7d3 [djax] handle simple reshapes and size-0 checks
One of the main changes here is that we don't do division in handling
x.reshape(..., -1) unless we have to.
2023-04-21 19:20:48 -07:00
jax authors
520b751f42 Merge pull request #15702 from jakevdp:jnp-array-cleanup
PiperOrigin-RevId: 526140815
2023-04-21 14:45:15 -07:00
Peter Hawkins
23d1640eac Small cleanups to pxla.py.
Remove stale references to XlaComputation and code left over from handling both XlaComputations and ir.Modules.

No functional changes intended.

PiperOrigin-RevId: 526139679
2023-04-21 14:38:27 -07:00
jax authors
e5faab8783 Merge pull request #15700 from jakevdp:module-tests
PiperOrigin-RevId: 526137627
2023-04-21 14:31:11 -07:00
Jake VanderPlas
efdf3e0a51 jnp.array: internal cleanup related to DeviceArray removal 2023-04-21 13:43:33 -07:00
Jake VanderPlas
1c7f8efce6 Add test framework for module attribute 2023-04-21 13:20:16 -07:00
jax authors
da2057a544 Merge pull request #15699 from jakevdp:config-import
PiperOrigin-RevId: 526109752
2023-04-21 12:39:41 -07:00
Jake VanderPlas
5628676460 type promotion: better error for unrecognized types 2023-04-21 12:35:20 -07:00
jax authors
3010611101 Merge pull request #15698 from jakevdp:error-mod
PiperOrigin-RevId: 526101203
2023-04-21 12:05:42 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Jake VanderPlas
9b6714c9d7 jax.errors: set obj.__module__ to public module 2023-04-21 09:31:02 -07:00
James Keeling
5647d5db98 Fix typo: Arrgs -> Args
This was breaking docstring parsing in IDEs.

PiperOrigin-RevId: 526019643
2023-04-21 06:44:43 -07:00
Yash Katariya
3722d7066a Add jax_pmap_shmap_merge flag to begin the process of merging pmap and shard_map
After the changes in shard_map, there are 75 failures left to be resolved (not counting the EagerPmap tests).

TODO:
* Move shard_map to _src so that the circular import can be removed from api.py
PiperOrigin-RevId: 525930416
2023-04-20 21:22:48 -07:00
Yash Katariya
dcd9127598 Make CI nightly faster which tests numpy and scipy nightly builds.
PiperOrigin-RevId: 525817202
2023-04-20 12:22:32 -07:00
jax authors
19ae886bdd Merge pull request #15671 from jakevdp:delete-traced
PiperOrigin-RevId: 525806300
2023-04-20 11:42:12 -07:00
jax authors
59f33a4338 Expose JaxprDebugInfo so others can use it for pytyping.
PiperOrigin-RevId: 525749186
2023-04-20 08:09:26 -07:00
Peter Hawkins
1d63d9b833 Include the device_kind in the compilation cache key.
PiperOrigin-RevId: 525726898
2023-04-20 06:16:45 -07:00
Parker Schuh
87c328864b Improve testing for custom_partitioning.
Add a test to demonstrate how to force XLA to choose
a different sharding.

Also it is possible to return the wrong
shape from a partition function. We should error in this case.

PiperOrigin-RevId: 525606690
2023-04-19 18:26:51 -07:00
jax authors
975e76ef76 Merge pull request #15664 from skye:tpu_install
PiperOrigin-RevId: 525605301
2023-04-19 18:18:32 -07:00
jax authors
db2cbd4ae8 Merge pull request #15665 from hawkinsp:sourceinfo
PiperOrigin-RevId: 525581713
2023-04-19 16:30:23 -07:00
Peter Hawkins
34fd4a1562 Add version guard to compilation cache test.
PiperOrigin-RevId: 525572568
2023-04-19 15:50:33 -07:00
Jake Vanderplas
fb5664d580 Copybara import of the project:
--
1f0eaa0059321f0b9301012d3bae7921056b5c9d by Jake VanderPlas <jakevdp@google.com>:

Test: fix TPU tolerance for Beta test
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/15674 from jakevdp:beta-tpu-test 1f0eaa0059321f0b9301012d3bae7921056b5c9d
PiperOrigin-RevId: 525568586
2023-04-19 15:35:51 -07:00
jax authors
1de4d14da8 Merge pull request #15656 from laqua-stack:add-special-gamma-fcn
PiperOrigin-RevId: 525566749
2023-04-19 15:28:36 -07:00
Yash Katariya
53e6382f4a Add arg_names to aval mismatch error raised during AOT compilation to raise better error messages
PiperOrigin-RevId: 525561905
2023-04-19 15:08:53 -07:00
jax authors
968dbaf8f3 Merge pull request #15673 from jakevdp:fix-i0e-test
PiperOrigin-RevId: 525556536
2023-04-19 14:48:29 -07:00
Jake VanderPlas
1b0106fd1e Make i0e gradient test more robust 2023-04-19 14:41:44 -07:00
Parker Schuh
4750ce7c87 Add an experimental API that allows compiling AOT for TPUs.
PiperOrigin-RevId: 525536075
2023-04-19 13:33:59 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
jax authors
54c9205493 Merge pull request #15666 from jakevdp:tracer-faq
PiperOrigin-RevId: 525525709
2023-04-19 12:52:40 -07:00
Yash Katariya
0a19638490 Plumb debug_info to meshExecutable as a optional arg to raise better error messages.
PiperOrigin-RevId: 525521694
2023-04-19 12:35:49 -07:00
Jake VanderPlas
a86142b446 jnp.delete: add assume_unique_indices for JIT-compatibility 2023-04-19 12:33:59 -07:00
Jake VanderPlas
a083ba7853 DOC: explicitly mention io_callback in FAQ 2023-04-19 12:30:53 -07:00
Peter Hawkins
a3b262c379 Use the traceback of the call site when assigning a source location to an inlined function.
Improves but does not completely fix https://github.com/google/jax/issues/15663 . The non-inlined case still has similar problems.
2023-04-19 13:56:53 -04:00
Skye Wanderman-Milne
b917a31f56 Update TPU install on main docs page 2023-04-19 17:52:16 +00:00
jax authors
a2fbd59e63 Merge pull request #15662 from nouiz:ci
PiperOrigin-RevId: 525473560
2023-04-19 09:44:43 -07:00
Frederic Bastien
bc0c25c4b5 pytest* just got removed. The CI don't need them anymore, so remove that requirement. 2023-04-19 08:26:04 -07:00
jax authors
c844464888 Merge pull request #15658 from jakevdp:fix-xlogy-grad
PiperOrigin-RevId: 525287070
2023-04-18 16:45:45 -07:00
Jake VanderPlas
dd023e266e jax.scipy.special: fix gradient for xlogy & xlog1py 2023-04-18 15:56:32 -07:00
jax authors
933d695170 Merge pull request #15610 from jakevdp:array-methods
PiperOrigin-RevId: 525238498
2023-04-18 13:38:21 -07:00
Jake VanderPlas
72bb8ab753 jax.Array: dynamically define abstract methods 2023-04-18 13:08:32 -07:00