215 Commits

Author SHA1 Message Date
Jake VanderPlas
40fe4b8797 Finalize deprecation of some symbols from jax.lib.xla_client 2024-12-23 10:14:16 -08:00
Jake VanderPlas
f858a71461 Finalize some deprecations in jax.core, jax.lib.xla_bridge, and jax.lib.xla_client. 2024-12-11 09:50:33 -08:00
Jake VanderPlas
e6d6c4ef8a Delete non-public API jax.lib.xla_bridge._backends
This is doubly non-public: nothing under `jax.lib` is public, and also the object itself has a preceding underscore. Therefore it is safe to remove (chex had referenced this previously, but that's now addressed in adaf1b2b75).

PiperOrigin-RevId: 704825268
2024-12-10 13:25:14 -08:00
Jake VanderPlas
85e2969aea Deprecate several private APIs in jax.lib 2024-11-20 08:48:26 -08:00
Peter Hawkins
e9c7ff0b7d Deprecate a number of APIs in jax.lib.xla_client.
(Technically these aren't public, so they don't need a deprecation period, but this is the polite thing to do.)

PiperOrigin-RevId: 684906277
2024-10-11 11:42:40 -07:00
Peter Hawkins
aa3254d723 Deprecate jax.lib.xla_client.PaddingType.
This type is unused by JAX, so there is no replacement.

(JAX does have an internal PaddingType enum in lax, but it is not present in any APIs, as best I can tell.)

PiperOrigin-RevId: 684451556
2024-10-10 08:22:20 -07:00
Peter Hawkins
94abaf430e Add lax.FftType.
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.

We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that so there's no reason we need to couple our type to XLA's type.

PiperOrigin-RevId: 684447186
2024-10-10 08:07:35 -07:00
Peter Hawkins
fc4f554e09 Delete jax.lib.xla_client.execute_with_python_values.
Nothing under jax.lib.xla_client is public, so there's no deprecation period required.

PiperOrigin-RevId: 681166972
2024-10-01 14:32:22 -07:00
Peter Hawkins
0e082f978b Deprecate jax.lib.xla_client.Device.
jax.Device is a longstanding public name for this class.

PiperOrigin-RevId: 679197718
2024-09-26 10:17:04 -07:00
Peter Hawkins
7b53c2f39d Add jax.errors.JaxRuntimeError as a public alias for the XlaRuntimeError class.
Deprecate jax.lib.xla_client.XlaRuntimeError, which is not a public API.

PiperOrigin-RevId: 679163106
2024-09-26 08:39:30 -07:00
Jake VanderPlas
a009e1cf50 deprecate jax.lib.xla_client.bfloat16 2024-08-06 11:22:27 -07:00
Jake VanderPlas
06f29bbb97 Deprecate jax.lib.xla_client._xla
This is an alias for jax.lib.xla_extension. Why the deprecation warning
for this when #22844 removed other APIs without any warning? This one
is relatively commonly used (I found a few dozen downstream references)
so I feld that a deprecation warning might be helpful.
2024-08-05 16:19:59 -07:00
Jake VanderPlas
3d857b02ac export jax.lib.xla_extension.HloModule
Followup to #22844, because the symbol is used downstream.

PiperOrigin-RevId: 659678623
2024-08-05 14:16:40 -07:00
Jake VanderPlas
521c94c6c6 Tighten the public API for jax.lib.xla_client & xla_extension 2024-08-03 05:26:22 -07:00
Jake VanderPlas
3fa86a9b32 remove jax.extend.backend.default_backend in favor of jax.backend
I added this two days ago before realizing there is already a canonical API
for this in the top-level namespace, so it should be safe to remove.
2024-08-02 07:07:29 -07:00
Jake VanderPlas
3551fcc077 Deprecate several APIs in jax.lib.xla_bridge
PiperOrigin-RevId: 658274719
2024-07-31 23:00:35 -07:00
Neil Girdhar
3c920c0120 Switch from flake8 to Ruff 2023-11-15 22:35:52 -05:00
Peter Hawkins
a259df0d76 Move compiler APIs out of dispatch.py and xla_bridge.py into a new jax._src.compiler module.
Refactoring only, no user-visible changes intended.

PiperOrigin-RevId: 557116160
2023-08-15 06:39:46 -07:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
Peter Hawkins
8dc1dff610 Remove device_count, local_device_count, process_index exports from xla_bridge.
These were accidental exports and have public equivalents under the top-level jax namespace. The deprecation policy does not apply to names under jax.lib, which is intended to be private.

PiperOrigin-RevId: 506088434
2023-01-31 13:01:19 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
a9ae7ca71b Reexport jaxlib.__version as jax.lib.__version__.
PiperOrigin-RevId: 445186919
2022-04-28 10:25:06 -07:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Peter Hawkins
95f47074da Remove xla_bridge.{constant, register_constant_handler, _python_scalar_constant} from API.
An upcoming change will move and rename these functions, and it's not clear they should have been public in the first place.

PiperOrigin-RevId: 404051961
2021-10-18 13:56:58 -07:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Jake VanderPlas
245581411e Add PEP484-compatible export for jax and its subpackages 2021-09-13 14:08:48 -07:00
Peter Hawkins
bf351d1e93 Drop support for the deprecated StreamExecutor CPU backend.
The TFRT backend is better and there's no reason to keep the StreamExecutor backend around any longer.

PiperOrigin-RevId: 395455049
2021-09-08 06:04:45 -07:00
Skye Wanderman-Milne
4e6aef581b Factor out uncached version of xla_bridge.get_backend
PiperOrigin-RevId: 391158558
2021-08-16 16:05:51 -07:00
Qiao Zhang
8d6ff968af Internal backend config changes.
PiperOrigin-RevId: 390233762
2021-08-11 15:30:26 -07:00
Ryan Sepassi
f685e43553 Add backend errors 2021-08-05 18:18:12 -07:00
yashkatariya
2f6f788916 Fix the timer->timer_secs typo 2021-08-02 18:27:27 -07:00
yashkatariya
277b9449a2 Resolve comments 2021-07-30 10:28:47 -07:00
yashkatariya
cdfc2dc8fe Use warnings instead of absl.logging.warning 2021-07-28 18:28:37 -07:00
yashkatariya
bf28b884b2 Log a warning after 60 secs to remind the user to run code on all hosts for Cloud TPU 1VM 2021-07-28 15:22:42 -07:00
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
Peter Hawkins
f33ce0d844 Warn if importing jaxlib on Mac ARM machines.
We can remove this warning when Mac ARM has CI testing.
2021-07-13 09:24:48 -04:00
Peter Hawkins
b393d9a8c1 Update jax version and changelog for 0.1.27.
Disable tfrt CPU backend on jaxlib 0.1.68 to work around https://github.com/google/jax/issues/7229.
2021-07-09 15:21:52 -04:00
Matthew Johnson
5e92faccbb tweak xla_bridge.py flags
* add environment variables for jax_disable_most_optimizations and
  jax_cpu_backend_variant
* comment on the default values in help strings
2021-07-02 10:31:28 -07:00
Jake VanderPlas
c8e571ad84 Allow suppression of GPU warning via jax_platform_name 2021-06-28 12:54:21 -07:00
Peter Hawkins
d658108d36 Fix type errors with current mypy and NumPy.
Enable type stubs for jaxlib.

Fix a nondeterminism problem in jax2tf tests.
2021-06-24 10:51:06 -04:00
Qiao Zhang
132a54228b Add flag to select tfrt backend for CPU. 2021-06-22 14:12:36 -07:00
George Necula
6a48c60a72 Rename master to main in embedded links.
Tried to avoid the change on external links to repos that
have not yet renamed master.
2021-06-18 10:00:01 +03:00
Peter Hawkins
5dc9df386c [JAX] Attach a priority to JAX backends. Use the backend with the highest priority when choosing a default backend.
PiperOrigin-RevId: 377351657
2021-06-03 12:48:24 -07:00
Peter Hawkins
b2c7ae728d [JAX] Catch all exceptions from backend initialization.
PiperOrigin-RevId: 377278098
2021-06-03 06:49:56 -07:00
Peter Hawkins
7db0c56a22 [JAX] Change how JAX manages XLA platforms.
* Combine the concepts of "platform" and "backend". The main upshot of this is that the tpu_driver backend requires users to write `jit(..., backend="tpu_driver")` if mixing CPU and TPU execution, however I doubt users are writing that because it didn't work to mix CPU and tpu_driver before.
* Initialize all platforms at startup, rather than lazily initializing platforms on demand. This makes it easy to do things like "list the available platforms".
* Don't use two levels of caching. Cache backends only in xla_bridge.py, not xla_client.py.

PiperOrigin-RevId: 376883261
2021-06-01 11:44:31 -07:00
Peter Hawkins
d481013f47 Add a CPU feature guard module to JAX.
To make sure that the CPU feature guard happens first, before any other code that may use instructions that do not exist, use a separate C extension module.

Fixes https://github.com/google/jax/issues/6671

PiperOrigin-RevId: 374683190
2021-05-19 10:58:35 -07:00
Jake VanderPlas
2868030160 Generalize constant handlers for multi-buffer objects
Co-authored-by: Matthew Johnson <mattjj@google.com>
2021-05-06 09:44:51 -07:00
Peter Hawkins
c983d3c660 Bundle libdevice.10.bc with jaxlib wheels.
libdevice.10.bc is a redistributable part of the CUDA SDK.

This avoids problems trying to locate a copy of libdevice inside the user's CUDA installation.
2021-04-29 10:26:03 -04:00
Skye Wanderman-Milne
9128ba0c74 Replace host_id with process_index terminology, take 2.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.

This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.

This was originally commited in
b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in
14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test
failures from renaming the local_devices argument name. This change is
identical except it also adds staging for the argument name change.
2021-04-20 18:13:34 -07:00
Skye Wanderman-Milne
1614572eb9 Add optional distributed debugging logging.
This can be enabled by setting the environment variable
`JAX_DISTRIBUTED_DEBUG=1` (or other true-like values), the flag
`--jax_distributed_debug=1`, or `jax.config.distributed_debug =
True`. It's off by default.

This enables WARNING-level logging of each distributed computation
that's run and related debugging information. This is designed to help
with multi-process debugging, e.g. to identify mismatched pmaps across
processes. All debugging information is enclosed between
`DISTRIBUTED_DEBUG_BEGIN` and `DISTRIBUTED_DEBUG_END` to faciliate
grepping for this info.

Example output:

```
DISTRIBUTED_DEBUG_BEGIN
Initialized backend: tpu
  process_index: 0
  device_count: 8
  local_devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
DISTRIBUTED_DEBUG_END

DISTRIBUTED_DEBUG_BEGIN
Running pmapped function: <lambda>
  python function: <function PmapTest.testArgAllReduce.<locals>.<lambda> at 0x7f77924d6c80>
  devices: None
  abstract args: [ShapedArray(float32[2,2])]
DISTRIBUTED_DEBUG_END

DISTRIBUTED_DEBUG_BEGIN
Running xmapped function: <lambda>
  python function: <function XMapTest.testAxisSizes.<locals>.<lambda> at 0x7fb33d86e158>
  mesh: Mesh(array([TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
       TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)],
      dtype=object), ('x',))
  abstract args: []
DISTRIBUTED_DEBUG_END

DISTRIBUTED_DEBUG_BEGIN
Running pjit'd function: f
  python function: <function PJitTest.testShardingConstraintPyTree.<locals>.f at 0x7fad672b8b70>
  mesh: Mesh(array([[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)],
       [TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)]],
      dtype=object), ('x', 'y'))
  abstract args: [ShapedArray(int32[8,8]), ShapedArray(int32[8,8]), ShapedArray(int32[8,8])]
DISTRIBUTED_DEBUG_END
```
2021-04-20 13:34:45 -07:00