173 Commits

Author SHA1 Message Date
Jake VanderPlas
18521fef08 Deprecate jax.tree_* aliases 2025-03-27 10:13:14 -07:00
Jake VanderPlas
91a07ea2e8 Clean up a number of finalized deprecations 2025-03-26 09:57:19 -07:00
Yash Katariya
3a26804c68 Rename get_ty to typeof which is an alias of get_aval
PiperOrigin-RevId: 735946640
2025-03-11 17:34:44 -07:00
Tom Hennigan
1becb57ac9 Add jax.copy_to_host_async(tree).
A relatively common pattern I've observed is the following:

```python
_, metrics = some_jax_function()

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

We are missing an opportunity here to more eagerly begin the h2d copy of
the metrics (e.g. overlap it with closing the "compute_metrics" context
manager etc. The intention of `jax.copy_to_host_async(x)` is to make it
simple to begin h2d transfers as early as possible. Adapting the above code:

```python
_, metrics = some_jax_function()

# Begin D2H copies as early as we can.
jax.copy_to_host_async(metrics)

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

PiperOrigin-RevId: 731626446
2025-02-27 01:22:15 -08:00
Yash Katariya
b35083331c Expose get_ty aka get_aval from jax namespace
PiperOrigin-RevId: 728490205
2025-02-18 21:22:19 -08:00
Yash Katariya
799eb98cac Add reshard API in experimental. Currently for sharding_in_types we have 2 APIs: mesh_cast and reshard. Both work in sharding_in_types mode and affect the sharding of the aval. Following are the semantics of both:
* `mesh_cast`: AxisTypes between src and dst mesh **must** differ. There should be **no "visible" data movement**. The shape of the aval doesn't change.

* `reshard`: Mesh should be the **same** between src and dst (same axis_names, axis_sizes and axis_types). **Data movement is allowed**. The shape of the aval doesn't change.

We might make `reshard` == `device_put`, hence the API is in experimental. This decision can be taken at a later point in time. The reason not to just give `device_put` this power is because `device_put` does a lot of stuff right now (and is going to get even more powers in the near future like cross-host transfers) and it's semantics would be very confusing if we keep piling sharding-in-types stuff on it.

PiperOrigin-RevId: 717588253
2025-01-20 11:39:25 -08:00
Yash Katariya
c7f8d17f5a Expose hidden_axes via jax namespace as public API. Also mention it as a workaround for primitives we don't support yet.
PiperOrigin-RevId: 716839003
2025-01-17 16:48:58 -08:00
Dan Foreman-Mackey
cb4d97aa1f Move jex.ffi to jax.ffi. 2024-12-29 13:06:19 +00:00
Jake VanderPlas
f401c97967 finalize deprecation of jax.clear_backends 2024-11-14 09:22:09 -08:00
Jake VanderPlas
2b9c73d10d Remove a number of expired deprecations.
These APIs were all removed 3 or more months ago, and the registrations
here cause them to raise informative AttributeErrors. Enough time has
passed now that we can remove these.
2024-10-31 15:40:54 -07:00
Jake VanderPlas
de3191fab3 Cleanup: fix unused imports & mark exported names 2024-10-16 17:42:41 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Yash Katariya
de9b98e0a8 Delete jax.xla_computation since it's been 3 months since it was deprecated.
PiperOrigin-RevId: 673938336
2024-09-12 11:47:38 -07:00
Yash Katariya
252caebce3 Create jax.make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], devices: Sequence[jax.Device] | None = None) API to make it easier to create a mesh and reduce a ton of boilerplate.
`jax.make_mesh` is the stable API endpoint of `mesh_utils` but without all the extra options. If you want those, you can still use the experimental endpoint in `mesh_utils`.

PiperOrigin-RevId: 670707995
2024-09-03 14:32:03 -07:00
Matthew Johnson
670a648b7b add experimental jax.no_tracing context manager 2024-08-23 21:21:55 +00:00
Jake VanderPlas
2c221f2d5a Register several jax.numpy argument name deprecations 2024-08-22 09:41:53 -07:00
Dan Foreman-Mackey
60bf5b7727 Add a jax.process_indices function.
The `jax.host_ids` function has be long deprecated, but the suggested alternative of `list(range(jax.process_count()))` relies on the current behavior that the list of process indices is always dense. In the future we may want to allow dynamic addition and removal of processes in which case `jax.process_count` and `jax.process_indices` would need to be updated, and it is useful for users to be able to use this forward-compatible interface.

PiperOrigin-RevId: 662142636
2024-08-12 10:30:41 -07:00
Yash Katariya
0d5dae09ff Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
rdyro
c6d6207170 Unifying persistent cache messages
and moving them to WARNING logging when explain_cache_misses is true.
2024-07-16 00:47:53 +00:00
Matthew Johnson
bd166e1d99 add more info to xla_computation deprecation warning 2024-07-02 13:31:07 +00:00
Yash Katariya
e6f26ff256 Deprecate jax.xla_computation. Use JAX AOT APIs to get the equivalent of jax.xla_computation functionality.
PiperOrigin-RevId: 644107276
2024-06-17 13:02:35 -07:00
Jake VanderPlas
0a86e9a929 Deprecate hashing of tracers 2024-06-13 13:14:27 -07:00
Jake VanderPlas
aa1452375b Register beta args deprecation
PiperOrigin-RevId: 642427224
2024-06-11 16:19:14 -07:00
Dan Foreman-Mackey
1e206880d3 Move jax.ffi submodule to jax.extend.ffi 2024-05-31 12:34:59 -04:00
Dan Foreman-Mackey
88790711e8 Package XLA FFI headers with jaxlib wheel
The new "typed" API that XLA provides for foreign function calls is
header-only and packaging it as part of jaxlib could simplify the open
source workflow for building custom calls.

It's not completely obvious that we need to include this, because jaxlib
isn't strictly required as a _build_ dependency for FFI calls, although
it typically will be required as a _run time_ dependency. Also, it
probably wouldn't be too painful for external projects to use the
headers directly from the openxla/xla repo.

All that being said, I wanted to figure out how to do this, and it has
been requested a few times.
2024-05-22 12:28:38 -04:00
Jake VanderPlas
d33a5689de Refactor & test internal deprecation APIs
The names and APIs were previously too similar and therefore somewhat confusing; this will be more clear I think.

PiperOrigin-RevId: 635615163
2024-05-20 17:16:31 -07:00
Jake VanderPlas
4bac10e750 Finalize deprecation of the config module.
To configure JAX, use `import jax` and reference the config object via `jax.config`.

PiperOrigin-RevId: 635430169
2024-05-20 05:49:31 -07:00
Mark Sandler
8f045cafd2 Add jax.make_array_from_process_local_data to create a distributed tensor from host data and supporting scaffolding in sharding to be able to figure out dimensions of host data required.
PiperOrigin-RevId: 634205261
2024-05-15 22:06:45 -07:00
Jake VanderPlas
2daaf49541 Remove extraneous pure_callback_api wrapper 2024-04-25 10:21:49 -07:00
Matthew Johnson
8588d4b747 alias jax.sharding.NamedSharding -> jax.NamedSharding 2024-04-11 16:23:59 -07:00
Jake VanderPlas
8949a63ce1 [key reuse] rename flag to jax_debug_key_reuse 2024-03-22 05:37:30 -07:00
Jake VanderPlas
d8662886d7 Register maps module deprecation outside of module
PiperOrigin-RevId: 617194807
2024-03-19 09:21:48 -07:00
Yue Sheng
147c363ea6 Deprecate jax.clear_backends.
`jax.clear_backends` does not necessarily do what its name suggests and can lead to unexpected consequences, e.g., it will not destroy existing backends and release corresponding owned resources. Use `jax.clear_caches` if you only want to clean up compilation caches. For backward compatibilty or you really need to switch/reinitialize the default backend, use `jax.extend.backend.clear_backends`.

PiperOrigin-RevId: 616946337
2024-03-18 14:23:18 -07:00
Jake VanderPlas
236275ebe1 Deprecate jax.tree_map for jax v0.4.26
Reverts f4045dceb206be1ea10ee651ccc6151809f2d9f3

PiperOrigin-RevId: 611230367
2024-02-28 14:29:01 -08:00
Yash Katariya
f4045dceb2 Remove the deprecation of jax.tree_map for the release of 0.4.25
PiperOrigin-RevId: 610014256
2024-02-24 09:30:06 -08:00
Jake VanderPlas
e59a0506fe Deprecate jax.tree_map in favor of jax.tree.map 2024-02-22 11:35:39 -08:00
Jake VanderPlas
cf80f574b5 Register jax.config module deprecation
PiperOrigin-RevId: 609352291
2024-02-22 06:38:56 -08:00
Sergei Lebedev
57e59eb6c3 Removed deprecated jax.config methods and jax.config.config
Reverts dcc65e621ea3a68fdc79fa9f2c995743a7b3faf7

PiperOrigin-RevId: 608676645
2024-02-20 11:25:16 -08:00
Thomas Köppe
dcc65e621e Reverts b506fee9e389391efb1336bc7575dba913e75cdf
PiperOrigin-RevId: 608319964
2024-02-19 06:23:00 -08:00
Sergei Lebedev
b506fee9e3 Removed deprecated jax.config methods and jax.config.config
Reverts eb0343683547b6e2d29245f3ab6c91037c0cff81

PiperOrigin-RevId: 607803834
2024-02-19 06:21:15 -08:00
jax authors
eb03436835 Reverts 318a19a89387caebd116168c4e47592e7d71ca65
PiperOrigin-RevId: 607708463
2024-02-16 09:11:05 -08:00
Sergei Lebedev
318a19a893 Removed deprecated jax.config methods
PiperOrigin-RevId: 607675571
2024-02-16 06:49:13 -08:00
Jake VanderPlas
6934a4b76b Add jax.tree module with aliases of jax.tree_util 2024-02-12 13:07:59 -08:00
Jake VanderPlas
e356d76913 Remove a number of deprecated APIs
All of these were deprecated prior to the JAX 0.4.16 release, on Sept 18 2023.
As of Monday Dec 18, we have met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 591933493
2023-12-18 10:08:47 -08:00
Jake VanderPlas
a52d18781e Add experimental static key reuse checking 2023-12-11 12:03:48 -08:00
Roy Frostig
ed9a4c2939 add jax.threefry_partitionable context manager 2023-10-31 13:45:55 -07:00
Sergei Lebedev
f2ce5dbd01 MAINT Do not use str() and repr() in f-string replacement fields
`str()` is called by default by the formatting machinery, and `repr()` only
needs `!r`.
2023-10-23 15:12:04 +01:00
Jake VanderPlas
024b1f23d7 Remove deprecated submodule jax.abstract_arrays 2023-09-19 15:40:18 -07:00
Jake VanderPlas
1800015884 Import jax.version first 2023-09-12 12:27:20 -07:00
Jake VanderPlas
ca39457ea9 JEX: move jax.linear_util to jax.extend.linear_util 2023-08-30 18:32:12 -07:00