334 Commits

Author SHA1 Message Date
Yash Katariya
c2d5527f72 [Jax cleanup]
* Remove lower_xla_callable and all related functions
* Remove pxla.device_put
* Remove dispatch.device_put_handlers

PiperOrigin-RevId: 517249345
2023-03-16 15:47:28 -07:00
Peter Hawkins
28e4038933 Mark jax.numpy.DeviceArray as deprecated. Use jax.Array instead.
PiperOrigin-RevId: 516835920
2023-03-15 08:50:00 -07:00
Jake VanderPlas
6dd0e0153a jnp.ndarray.at: deprecate passing additional arguments by position 2023-03-13 10:04:39 -07:00
Skye Wanderman-Milne
6560bf8c36 Update versions and changelog for jax + jaxlib 0.4.6 release 2023-03-09 14:50:42 -08:00
George Necula
961e09e614 [shape_poly, call_tf] Some improvements for call_tf in a shape polymorphic program
This is another attempt to land a rolled-back change https://github.com/google/jax/pull/14734 (cl/514070997).
See b/272154366 for more details.

The use case for call_tf with shape polymorphism is when we have a JAX program
that calls into TF function, and we want to serialize the JAX program with
some shapes unknown. Previously this use case did not work, except in the special
case when the output shape of the called TF function returns statically known
shapes.

The idea is that we allow the user of call_tf to specify the output shape.
This can be done even in presence of shape polymorphism, by writing the
output shape as an expression in terms of the input shapes. This is what
other JAX primitives do, e.g., concat, so we are simply enabling call_tf
to get the same behavior.

This change should be enough for old-style jax2tf, but will require more
work for native serialization.

We also removed some old code that was trying to workaround some limitations
in shape inference in TF. I think that those workarounds are ugly, and I am
prepared to give error messages rather than keep that code. So far no
tests fail.

PiperOrigin-RevId: 515137407
2023-03-08 14:10:08 -08:00
Ivy Zheng
46838a4116
fix typo 2023-03-05 13:14:57 -08:00
Ivy Zheng
dfe940de8d
add key path change to changelog 2023-03-03 18:05:37 -08:00
Skye Wanderman-Milne
ed2c5717c5 Bump version and changelog after jax 0.4.5 release 2023-03-02 16:08:34 -08:00
Jake VanderPlas
a283aa0cc3 Deprecate three jax.Array methods:
- jax.Array.broadcast: use lax.broadcast instead
- jax.Array.broadcast_in_dim: use lax.broadcast_in_dim instead
- jax.Array.split: use jnp.split instead
These are removed because they are not part of the np.ndarray API.
2023-02-23 16:15:09 -08:00
Jake VanderPlas
841bdcef5f DOC: add is_ready() to CHANGELOG 2023-02-23 11:56:48 -08:00
Yash Katariya
0ffdeb3de2 Rename jax.sharding.OpShardingSharding to jax.sharding.GSPMDSharding. jax.sharding.OpShardingSharding will be removed in 3 months from Feb 17, 2023.
PiperOrigin-RevId: 510556189
2023-02-17 17:11:06 -08:00
Yash Katariya
941722f7db Finish jax and jaxlib 0.4.4 release
PiperOrigin-RevId: 510234171
2023-02-16 13:54:56 -08:00
Peter Hawkins
b389eed8bf [JAX] Deprecate jax.experimental.maps.Mesh.
PiperOrigin-RevId: 509852142
2023-02-15 09:15:50 -08:00
Peter Hawkins
00d45feee6 Deprecate uses of jax.experimental.pjit.NamedSharding and jax.experimental.pjit.PartitionSpec.
Use the aliases under jax.sharding instead.

PiperOrigin-RevId: 509837529
2023-02-15 08:14:26 -08:00
Jake VanderPlas
dafb88a649 jax.numpy reductions: require initial to be a scalar
This follows the requirements of numpy's reduction API. Non-scalar initial values
can be implemented via .
2023-02-14 15:36:18 -08:00
Jake VanderPlas
7975192f92 Expose jax.typing & update docs 2023-02-13 15:53:08 -08:00
Yash Katariya
2fc64bee13 Change the axis_resources argument of with_sharding_constraint to shardings to match pjit and jit.
PiperOrigin-RevId: 509275107
2023-02-13 10:53:57 -08:00
Peter Hawkins
ec56d71d01 Drop support for NVIDIA Kepler series GPUs in jaxlib builds. 2023-02-10 14:15:15 -05:00
Yash Katariya
6ec9082cf5 Default jax_jit_pjit_api_merge to True. This means that the implementation of jit and pjit have been merged but they still remain separate APIs due to the semantic difference of how they behave under the Mesh context manager.
This changes the internals of JAX without affecting any public API.

Before, `jit` was a final style primitive. This means that the creation
of jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
becomes an initial style primitive which means that we trace to jaxpr
as early as possible. For more information see [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing).

Moving to initial style should simplify JAX's internals and make
development of features like dynamic shapes, etc easier.

PiperOrigin-RevId: 508143501
2023-02-08 11:55:48 -08:00
Skye Wanderman-Milne
21f12183bf Post 0.4.3 release updates 2023-02-08 10:08:59 -08:00
carlosgmartin
8251957025 Added scipy.stats.rankdata 2023-02-07 12:07:00 -05:00
jax authors
e4c18b89fe Merge pull request #14169 from tttc3:polar_unitary
PiperOrigin-RevId: 505143284
2023-01-27 09:57:29 -08:00
jax authors
5234091a43 Merge pull request #14108 from gnecula:poly_div
PiperOrigin-RevId: 505137089
2023-01-27 09:29:07 -08:00
tttc3
96707f09b1 Removed deprecated polar_unitary as per comment. 2023-01-27 07:24:55 +00:00
Skye Wanderman-Milne
b6a8aa6394 Update versions for jaxlib 0.4.2 release.
I also screwed up the CHANGELOG before (I shouldn't have added a
date), so I'm fixing the dates now.
2023-01-26 01:12:06 +00:00
George Necula
d25bcac93d [shape_poly] Add better support for division, and working with strides
Previously, division was only supported in certain situation, and this
led to errors, e.g., when using strides. Now we generalize the polynomials
to also include "floordiv(E, E)" and "mod(E, E)" as atoms, in addition
to dimension variables. A symbolic dimension is now a sum of products
of atoms. (We also changed the documentation to use symbolic dimension
instead of dimension polynomials).
2023-01-25 07:37:54 -08:00
Jake VanderPlas
a0eae5709f Raise an error when attempting to mutate Jaxpr objects 2023-01-23 09:37:58 -08:00
George Necula
8e931a82ff [shape_poly] Generalize binary operations with symbolic dimensions.
Previously binary operations involving symbolic dimensions would
work only when the other operand is convertible to a symbolic dimension,
e.g., an integer. This resulted in errors when trying "x.shape[0] * 3.5"
and the recourse was to ask the user to add an explicit
"jnp.array(x.shape[0])".

Now we allow binary operations with any operand and the
"jnp.array" is added automatically if the other operand is not
an integer or a symbolic dimension. This means that instead
of an error they may be an error downstream if one tries to use
the result as a dimension. There is one known case where
JAX works with static shapes and with the previous behavior,
but will fail now. When you operate on `np.ndarray` and
symbolic dimension, previously this was kept as a `np.ndarray`
but not it is turned into a JAX array. The following
program will now fail if `x.shape[0]` is a symbolic dimension.:

`jnp.ones(np.arange(5) * x.shape[0])`

Instead you should write

`jnp.ones([i * x.shape[0] for i in range(5)])`
2023-01-21 04:26:59 -08:00
Skye Wanderman-Milne
3f4bd5f449 Updates for jax + jaxlib 0.4.2 release 2023-01-20 19:04:46 +00:00
George Necula
ade5691630 [call_tf] Add has_side_effects parameter
The CallTfEffect was added recently as an internal workaround for
DCE removing instances of call_tf. Here we add a parameter to
`call_tf` to be able to declare if the called computation is
effectful and should not be removed by DCE.
2023-01-14 08:12:29 +01:00
Sharad Vikram
c9a57e1b44 Delete jax.experimental.callback
PiperOrigin-RevId: 501760507
2023-01-12 22:58:31 -08:00
Yash Katariya
835d0c979a Finish jax and jaxlib 0.4.2 release
PiperOrigin-RevId: 495068000
2022-12-13 10:51:13 -08:00
Yash Katariya
dc8ead04b4 Update CHANGELOG to indicate that 0.4.0 was yanked.
PiperOrigin-RevId: 494868884
2022-12-12 17:02:24 -08:00
Yash Katariya
0bdb7ec042 Finish jax and jaxlib release 0.4.0
PiperOrigin-RevId: 494833878
2022-12-12 14:43:35 -08:00
Yash Katariya
13c34f9dc5 Move with_sharding_constraint out of experimental into jax.lax namespace.
PiperOrigin-RevId: 494635809
2022-12-11 22:55:21 -08:00
Matthew Johnson
1185c895ca in jax.Array notebook, polish beginning and tweak title and some wording 2022-12-10 22:16:54 -08:00
Jake VanderPlas
09d1b6d8d5 Deprecate jnp.msort following deprecation of numpy.msort 2022-12-07 10:08:18 -08:00
Peter Hawkins
33a1b8866a Mark arguments to ufuncs as positional-only.
PiperOrigin-RevId: 493311821
2022-12-06 08:24:11 -08:00
Peter Hawkins
c2c3669c15 Remove long-deprecated method .block_host_until_ready().
PiperOrigin-RevId: 492571809
2022-12-02 15:18:11 -08:00
Peter Hawkins
f9b5312149 Do not mirror JAX config options back to ABSL flags.
Currently when JAX config values are configured via ABSL, we use the ABSL flags as a source of truth: if we read or write the JAX config option, we read or write the corresponding ABSL flag. This works but has the unfortunate downside that ABSL flags are relatively slow to read, which slows down JAX every time we read a configuration option.

However, there's fundamentally no reason we are mirroring the JAX configuration options back to ABSL in the first place. We can use ABSL flag parsing as a way only to populate the JAX configuration values. The downside is that if someone changes the ABSL flag values after parsing, that change will not be reflected in JAX's config values. JAX config changes after ABSL flags have been parsed must be made via the `jax.config.update()` API.

This gives a decent improvement on the device_put benchmark:

```
name        old cpu/op  new cpu/op  delta
device_put  79.5µs ± 6%  69.4µs ± 7%  -12.73%  (p=0.000 n=10+9)

name        old time/op             new time/op             delta
device_put  79.5µs ± 6%             69.4µs ± 7%  -12.73%         (p=0.000 n=10+9)
```

PiperOrigin-RevId: 492519085
2022-12-02 11:37:22 -08:00
Yash Katariya
934bc4e1b3 Move PartitionSpec and Mesh out of experimental and into the sharding namespace. The new API endpoint is jax.sharding.PartitionSpec and jax.sharding.Mesh.
PiperOrigin-RevId: 492358238
2022-12-01 19:28:32 -08:00
Skye Wanderman-Milne
82b442fa52 [docs] Replace one more jax_Array.html reference
I missed this in #13479, thanks @yashk2810 for flagging!
2022-12-01 20:50:55 +00:00
George Necula
4ca05f428f [call_tf] Use the same platform for TF lowering as the embedding JAX computation
This requires some changes for abstract evaluation, when
JAX does not use a specific platform.

Also attempt to fix the case when the TF lowering fails because the TF computation
uses a tf.Variable on another device as that used for lowering.

PiperOrigin-RevId: 492112847
2022-11-30 23:22:24 -08:00
Jake VanderPlas
cb62a31653 Drop support for Python 3.7 2022-11-29 15:01:47 -08:00
TJ
5fb0215d4d updated jaxlib CHANGELOG 2022-11-28 10:37:42 -08:00
Yash Katariya
9799d5b139 Add the jax.Array change to the changelog.
PiperOrigin-RevId: 488929264
2022-11-16 06:56:09 -08:00
Yash Katariya
8c42edfec1 Finish jax and jaxlib release 0.3.25. The next release will be 0.4.0 (since jax.Array will be enabled in that release)
PiperOrigin-RevId: 488672395
2022-11-15 09:02:53 -08:00
Peter Hawkins
ebd9840e1f Add several recent changes to the CHANGELOG.
PiperOrigin-RevId: 488362198
2022-11-14 07:39:13 -08:00
Sharad Vikram
e15619ceab Convert string axis name into tuple of strings in Mesh constructor
PiperOrigin-RevId: 487930412
2022-11-11 15:27:51 -08:00
Sharad Vikram
4bdfdd7363 Update changelog w/ info about deleting jax_experimental_name_stack 2022-11-11 14:02:30 -08:00