321 Commits

Author SHA1 Message Date
Peter Hawkins
00d45feee6 Deprecate uses of jax.experimental.pjit.NamedSharding and jax.experimental.pjit.PartitionSpec.
Use the aliases under jax.sharding instead.

PiperOrigin-RevId: 509837529
2023-02-15 08:14:26 -08:00
Jake VanderPlas
dafb88a649 jax.numpy reductions: require initial to be a scalar
This follows the requirements of numpy's reduction API. Non-scalar initial values
can be implemented via .
2023-02-14 15:36:18 -08:00
Jake VanderPlas
7975192f92 Expose jax.typing & update docs 2023-02-13 15:53:08 -08:00
Yash Katariya
2fc64bee13 Change the axis_resources argument of with_sharding_constraint to shardings to match pjit and jit.
PiperOrigin-RevId: 509275107
2023-02-13 10:53:57 -08:00
Peter Hawkins
ec56d71d01 Drop support for NVIDIA Kepler series GPUs in jaxlib builds. 2023-02-10 14:15:15 -05:00
Yash Katariya
6ec9082cf5 Default jax_jit_pjit_api_merge to True. This means that the implementation of jit and pjit have been merged but they still remain separate APIs due to the semantic difference of how they behave under the Mesh context manager.
This changes the internals of JAX without affecting any public API.

Before, `jit` was a final style primitive. This means that the creation
of jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
becomes an initial style primitive which means that we trace to jaxpr
as early as possible. For more information see [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing).

Moving to initial style should simplify JAX's internals and make
development of features like dynamic shapes, etc easier.

PiperOrigin-RevId: 508143501
2023-02-08 11:55:48 -08:00
Skye Wanderman-Milne
21f12183bf Post 0.4.3 release updates 2023-02-08 10:08:59 -08:00
carlosgmartin
8251957025 Added scipy.stats.rankdata 2023-02-07 12:07:00 -05:00
jax authors
e4c18b89fe Merge pull request #14169 from tttc3:polar_unitary
PiperOrigin-RevId: 505143284
2023-01-27 09:57:29 -08:00
jax authors
5234091a43 Merge pull request #14108 from gnecula:poly_div
PiperOrigin-RevId: 505137089
2023-01-27 09:29:07 -08:00
tttc3
96707f09b1 Removed deprecated polar_unitary as per comment. 2023-01-27 07:24:55 +00:00
Skye Wanderman-Milne
b6a8aa6394 Update versions for jaxlib 0.4.2 release.
I also screwed up the CHANGELOG before (I shouldn't have added a
date), so I'm fixing the dates now.
2023-01-26 01:12:06 +00:00
George Necula
d25bcac93d [shape_poly] Add better support for division, and working with strides
Previously, division was only supported in certain situation, and this
led to errors, e.g., when using strides. Now we generalize the polynomials
to also include "floordiv(E, E)" and "mod(E, E)" as atoms, in addition
to dimension variables. A symbolic dimension is now a sum of products
of atoms. (We also changed the documentation to use symbolic dimension
instead of dimension polynomials).
2023-01-25 07:37:54 -08:00
Jake VanderPlas
a0eae5709f Raise an error when attempting to mutate Jaxpr objects 2023-01-23 09:37:58 -08:00
George Necula
8e931a82ff [shape_poly] Generalize binary operations with symbolic dimensions.
Previously binary operations involving symbolic dimensions would
work only when the other operand is convertible to a symbolic dimension,
e.g., an integer. This resulted in errors when trying "x.shape[0] * 3.5"
and the recourse was to ask the user to add an explicit
"jnp.array(x.shape[0])".

Now we allow binary operations with any operand and the
"jnp.array" is added automatically if the other operand is not
an integer or a symbolic dimension. This means that instead
of an error they may be an error downstream if one tries to use
the result as a dimension. There is one known case where
JAX works with static shapes and with the previous behavior,
but will fail now. When you operate on `np.ndarray` and
symbolic dimension, previously this was kept as a `np.ndarray`
but not it is turned into a JAX array. The following
program will now fail if `x.shape[0]` is a symbolic dimension.:

`jnp.ones(np.arange(5) * x.shape[0])`

Instead you should write

`jnp.ones([i * x.shape[0] for i in range(5)])`
2023-01-21 04:26:59 -08:00
Skye Wanderman-Milne
3f4bd5f449 Updates for jax + jaxlib 0.4.2 release 2023-01-20 19:04:46 +00:00
George Necula
ade5691630 [call_tf] Add has_side_effects parameter
The CallTfEffect was added recently as an internal workaround for
DCE removing instances of call_tf. Here we add a parameter to
`call_tf` to be able to declare if the called computation is
effectful and should not be removed by DCE.
2023-01-14 08:12:29 +01:00
Sharad Vikram
c9a57e1b44 Delete jax.experimental.callback
PiperOrigin-RevId: 501760507
2023-01-12 22:58:31 -08:00
Yash Katariya
835d0c979a Finish jax and jaxlib 0.4.2 release
PiperOrigin-RevId: 495068000
2022-12-13 10:51:13 -08:00
Yash Katariya
dc8ead04b4 Update CHANGELOG to indicate that 0.4.0 was yanked.
PiperOrigin-RevId: 494868884
2022-12-12 17:02:24 -08:00
Yash Katariya
0bdb7ec042 Finish jax and jaxlib release 0.4.0
PiperOrigin-RevId: 494833878
2022-12-12 14:43:35 -08:00
Yash Katariya
13c34f9dc5 Move with_sharding_constraint out of experimental into jax.lax namespace.
PiperOrigin-RevId: 494635809
2022-12-11 22:55:21 -08:00
Matthew Johnson
1185c895ca in jax.Array notebook, polish beginning and tweak title and some wording 2022-12-10 22:16:54 -08:00
Jake VanderPlas
09d1b6d8d5 Deprecate jnp.msort following deprecation of numpy.msort 2022-12-07 10:08:18 -08:00
Peter Hawkins
33a1b8866a Mark arguments to ufuncs as positional-only.
PiperOrigin-RevId: 493311821
2022-12-06 08:24:11 -08:00
Peter Hawkins
c2c3669c15 Remove long-deprecated method .block_host_until_ready().
PiperOrigin-RevId: 492571809
2022-12-02 15:18:11 -08:00
Peter Hawkins
f9b5312149 Do not mirror JAX config options back to ABSL flags.
Currently when JAX config values are configured via ABSL, we use the ABSL flags as a source of truth: if we read or write the JAX config option, we read or write the corresponding ABSL flag. This works but has the unfortunate downside that ABSL flags are relatively slow to read, which slows down JAX every time we read a configuration option.

However, there's fundamentally no reason we are mirroring the JAX configuration options back to ABSL in the first place. We can use ABSL flag parsing as a way only to populate the JAX configuration values. The downside is that if someone changes the ABSL flag values after parsing, that change will not be reflected in JAX's config values. JAX config changes after ABSL flags have been parsed must be made via the `jax.config.update()` API.

This gives a decent improvement on the device_put benchmark:

```
name        old cpu/op  new cpu/op  delta
device_put  79.5µs ± 6%  69.4µs ± 7%  -12.73%  (p=0.000 n=10+9)

name        old time/op             new time/op             delta
device_put  79.5µs ± 6%             69.4µs ± 7%  -12.73%         (p=0.000 n=10+9)
```

PiperOrigin-RevId: 492519085
2022-12-02 11:37:22 -08:00
Yash Katariya
934bc4e1b3 Move PartitionSpec and Mesh out of experimental and into the sharding namespace. The new API endpoint is jax.sharding.PartitionSpec and jax.sharding.Mesh.
PiperOrigin-RevId: 492358238
2022-12-01 19:28:32 -08:00
Skye Wanderman-Milne
82b442fa52 [docs] Replace one more jax_Array.html reference
I missed this in #13479, thanks @yashk2810 for flagging!
2022-12-01 20:50:55 +00:00
George Necula
4ca05f428f [call_tf] Use the same platform for TF lowering as the embedding JAX computation
This requires some changes for abstract evaluation, when
JAX does not use a specific platform.

Also attempt to fix the case when the TF lowering fails because the TF computation
uses a tf.Variable on another device as that used for lowering.

PiperOrigin-RevId: 492112847
2022-11-30 23:22:24 -08:00
Jake VanderPlas
cb62a31653 Drop support for Python 3.7 2022-11-29 15:01:47 -08:00
TJ
5fb0215d4d updated jaxlib CHANGELOG 2022-11-28 10:37:42 -08:00
Yash Katariya
9799d5b139 Add the jax.Array change to the changelog.
PiperOrigin-RevId: 488929264
2022-11-16 06:56:09 -08:00
Yash Katariya
8c42edfec1 Finish jax and jaxlib release 0.3.25. The next release will be 0.4.0 (since jax.Array will be enabled in that release)
PiperOrigin-RevId: 488672395
2022-11-15 09:02:53 -08:00
Peter Hawkins
ebd9840e1f Add several recent changes to the CHANGELOG.
PiperOrigin-RevId: 488362198
2022-11-14 07:39:13 -08:00
Sharad Vikram
e15619ceab Convert string axis name into tuple of strings in Mesh constructor
PiperOrigin-RevId: 487930412
2022-11-11 15:27:51 -08:00
Sharad Vikram
4bdfdd7363 Update changelog w/ info about deleting jax_experimental_name_stack 2022-11-11 14:02:30 -08:00
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
Peter Hawkins
ab8cde9ed4 Add support for the hermitian option on jnp.linalg.pinv.
Improve the pinv implementation to avoid computing an unnecessary reduction: svd sorts its singular values so we don't need to use amax() to find the largest one.
Avoid explicitly forming the identity matrix in the pinv JVP.
2022-11-08 08:53:00 -05:00
Yash Katariya
1d48c93b0e Finish the release of jax and jaxlib 0.3.24
PiperOrigin-RevId: 486162090
2022-11-04 09:43:12 -07:00
Yash Katariya
cc5af7ed98 Rename ReshapeableDevicesSharding to PositionalSharding and add an alias NamedSharding for MeshPspecSharding.
`MeshPspecSharding` name will be replaced with `NamedSharding` in 3 months.

PiperOrigin-RevId: 485753078
2022-11-02 19:13:13 -07:00
jax authors
ef63f75e39 Merge pull request #13039 from skye:cache_compile_time_heuristic
PiperOrigin-RevId: 485644419
2022-11-02 11:13:52 -07:00
Skye Wanderman-Milne
cc5171034f Add new config jax_persistent_cache_min_compile_time_secs.
This replaces `jax_persistent_cache_min_instruction_count` introduced
in https://github.com/google/jax/pull/12798, since gating on the
compile time seems strictly better than gating on the instruction
count (except maybe that the instruction count is more deterministic,
but I don't think that's a big deal).

I defaulted to 1 second as the minimum threshold based on the same
flax wmt example
(https://github.com/google/flax/tree/main/examples/wmt) numbers from

name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344

Also adds new float config functionality.
2022-11-02 00:56:19 +00:00
Jake VanderPlas
2416d15435 Call _check_arraylike for jnp.linalg & jnp.fft functions 2022-10-31 09:19:53 -07:00
Peter Hawkins
bf21391248 [JAX] Change the default pmap() ordering to match the ordering of jax.devices() for single-process TPU jobs.
PiperOrigin-RevId: 484062717
2022-10-26 13:56:07 -07:00
Peter Hawkins
ce9e009c4c [JAX:CPU] Enable buffer donation on CPU.
Fix a bug in PJRT where if a buffer was not owned (e.g., it aliased a NumPy buffer) it could still be donated and that would lead to a use after free.

PiperOrigin-RevId: 484001545
2022-10-26 10:13:01 -07:00
Jake VanderPlas
2009e65a33 jnp.gradient: call check_arraylike on inputs & clean-up implementation 2022-10-24 15:27:33 -07:00
Jake VanderPlas
4aceb81570 Add docs & changelog for jax.scipy.stats.mode 2022-10-20 15:55:57 -07:00
Skye Wanderman-Milne
81eb3fca55 Add new config jax_persistent_cache_min_instruction_count.
This can be used to limit the number of entries written to the
persistent compilation cache.

I defaulted to setting 6 as the minimum threshold based on running the
flax wmt example
(https://github.com/google/flax/tree/main/examples/wmt) and logging
the instruction counts and complilation time:

name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344

Also adds new int config functionality.

Fixes #12583
2022-10-20 00:17:24 +00:00
Peter Hawkins
9ab88071a7 Avoid loading scipy eagerly.
scipy accounts for around 400ms of the 900ms of JAX's import time. By
loading scipy lazily, we can improve the timing of `import jax` down to
about 500ms.
2022-10-12 19:51:09 +00:00