82 Commits

Author SHA1 Message Date
Peter Hawkins
66293d8897 Remove code present to support jaxlib < 0.5.1.
The new minimum xla_extension_version is 317 and the new mlir_api_version is 58.
2025-02-26 07:40:40 -05:00
Dan Foreman-Mackey
83457c115a Always dispatch CPU executables synchronously when they include callbacks.
As discussed in https://github.com/jax-ml/jax/issues/25861 and https://github.com/jax-ml/jax/issues/24255, using host callbacks within an asynchronously-dispatched CPU executable can deadlock when the body of the callback itself asynchronously dispatches JAX CPU code. My rough understanding of the problem is that the XLA intra op thread pool gets filled up with callbacks waiting for their body to execute, but there aren't enough resources to schedule the inner computations.

There's probably a better way to fix this within XLA:CPU, but the temporary fix that I've come up with is to disable asynchronous dispatch on CPU when either:

1. Executing a program that includes any host callbacks, or
2. when running within the body of a callback.

It seems like both of these conditions are needed in general because I was able to find test cases that failed with just one or the other implemented.

This PR includes just the first change, and the second will be implemented in a follow-up.

PiperOrigin-RevId: 720777713
2025-01-28 18:23:35 -08:00
Peter Hawkins
8f2f4b45fb Annotate several tests as thread-unsafe.
PiperOrigin-RevId: 714117130
2025-01-10 11:24:39 -08:00
Peter Hawkins
c61b2f6b81 Make JAX test suite pass (at least most of the time) with multiple threads enabled.
Add a new jtu.thread_unsafe_test_class() decorator to tag entire `TestCase` classes as thread-hostile.

PiperOrigin-RevId: 714037277
2025-01-10 06:58:46 -08:00
Peter Hawkins
51b9fe3010 [JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU directly.
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.

In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.

PiperOrigin-RevId: 713272197
2025-01-08 06:37:44 -08:00
Yash Katariya
23eaf2160a Make inspect_array_sharding work without mesh context manager too.
PiperOrigin-RevId: 712702329
2025-01-06 17:32:15 -08:00
Jake VanderPlas
330606320a jax.debug.print: respect local np.printoptions 2025-01-02 16:10:54 -08:00
Yash Katariya
c191bbcdb1 Make debug.print work with static args. Fixes: https://github.com/google/jax/issues/23600
PiperOrigin-RevId: 676005582
2024-09-18 08:41:29 -07:00
Sergei Lebedev
91df9d1a17 Fixed validation in jax.debug.format
This commit ensures that no formatting is done during validation, because the
arguments could be abstract values.

Closes #23475.
2024-09-09 10:53:35 +01:00
Yash Katariya
e1b497078e Rename jtu.create_global_mesh to jtu.create_mesh and use jax.make_mesh inside jtu.create_mesh to get maximum test coverage of the new API.
PiperOrigin-RevId: 670744047
2024-09-03 16:23:07 -07:00
Yash Katariya
0d5dae09ff Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
Jake VanderPlas
a861c55a28 test cleanup: use ExitStack to reduce test boilerplate 2024-06-06 14:18:27 -07:00
Christos Perivolaropoulos
9939cc9974 test_util.capture_stdout redirects using file descriptors rather than mocking the python interface.
PiperOrigin-RevId: 640183718
2024-06-04 09:41:47 -07:00
Sergei Lebedev
6e23c14f85 jax.debug.callback now passes arguments as jax.Arrays
Prior to this change the behavior in eager and under jax.jit was inconsistent

    >>> (lambda *args: jax.debug.callback(print, *args))([42])
    [42]
    >>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42])
    [array(42, dtype=int32)]

It was also inconsistent with other callback APIs, which cast the arguments
to jax.Arrays.

Closes #20627.

PiperOrigin-RevId: 626461904
2024-04-19 13:57:18 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Jake VanderPlas
84e49bd6ce Remove internal references to deprecated jax.experimental.maps 2024-03-19 09:24:52 -07:00
Matthew Johnson
7608cce86f improve a debug.callback type error message for idiots
(i am the idiot)
2023-12-06 14:41:52 -08:00
Yash Katariya
ef20526a76 Return PositionalSharding if input's rank is >= 3 or a NamedSharding if a mesh is available via the context from inspect_array_sharding. Never return GSPMDSharding from inspect_array_sharding.
PiperOrigin-RevId: 573048344
2023-10-12 16:55:12 -07:00
Peter Hawkins
6be860bda8 Clean up some device opt-in/opt-outs in test suite.
Use allowlists rather than denylists in a few places.

PiperOrigin-RevId: 568968749
2023-09-27 14:56:00 -07:00
Peter Hawkins
bbfba9ace8 Remove code that disabled tests on "stream_executor" backends.
These tests work on both GPU and the current (non-stream_executor) TPU runtime, so the conditions aren't needed any more.

Tag a couple of tests as "multiaccelerator" since they appear to benefit from multiple devices.

PiperOrigin-RevId: 565367453
2023-09-14 07:52:43 -07:00
Yash Katariya
ae9d1498e5 Bump minimum jaxlib version to 0.4.11. xla_extension_version is 158 and mlir_api_version is 49. It will subsume https://github.com/google/jax/pull/16161#issuecomment-1564977332
PiperOrigin-RevId: 537047525
2023-06-01 09:42:55 -07:00
Parker Schuh
5f4408ded7 Convert inspect_sharding to register the handler directly in c++ so that it can
work across the c-api boundary.

PiperOrigin-RevId: 527322386
2023-04-26 11:22:28 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Skye Wanderman-Milne
ef5e4a4035 Remove 'pjrt_c_api_unimplemented' pytest mark.
Instead, we skip tests that the PJRT C API doesn't support. We had
this tag for feature development so it was easy to broadly disable,
but now we don't expect to need to do that.
2023-03-24 23:14:54 +00:00
Frederic Bastien
42e9753431 Fix inspect_array_sharding with grad. 2023-03-21 07:58:27 -07:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Peter Hawkins
1925aa1109 Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.

PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
Yash Katariya
418c2f9d2a Rename in_axis_resources and out_axis_resources with in_shardings and out_shardings. This is just a simple name replacement. It does not change any of the current pjit semantics and doesn't break any code.
This is a safe and trivial name replacement. It does not change any of the semantics. You can still pass in PatitionSpecs to in_shardings and out_shardings.

PiperOrigin-RevId: 510671300
2023-02-18 10:00:36 -08:00
Peter Hawkins
428189f8fb Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Skye Wanderman-Milne
f90b5eed52 Add pjrt_c_api_unimplemented pytest marker to skip unsupported tests.
Also adds `test_util.pytest_mark_if_available` helper function.
2023-01-12 22:17:23 +00:00
Yash Katariya
b1415bbb0c Skip the test for stream_executor like all the other tests in this file
PiperOrigin-RevId: 496107041
2022-12-17 10:47:32 -08:00
Yash Katariya
8520678249 Fix the failure caused by adding effects to call_tf primitive
PiperOrigin-RevId: 496037178
2022-12-16 23:01:43 -08:00
Yash Katariya
c42bad85ef Make MeshPspecSharding an alias for NamedSharding (it was the other way around before this CL).
PiperOrigin-RevId: 488473538
2022-11-14 14:44:00 -08:00
Sharad Vikram
3731e446c0 Set default layout for Python callback
PiperOrigin-RevId: 487388682
2022-11-09 17:18:49 -08:00
jax authors
93bcb599c2 Merge pull request #13065 from sharadmv:vis-colors
PiperOrigin-RevId: 485467801
2022-11-01 18:11:06 -07:00
Sharad Vikram
3bbd5f3028 Add colors to sharding visualization 2022-11-01 16:55:06 -07:00
Sharad Vikram
3e38675ac4 Update debugging_primitives_test to not use nontrivial floating point text comparisons
PiperOrigin-RevId: 484325096
2022-10-27 12:48:11 -07:00
Peter Hawkins
320d531521 Increase the minimum jaxlib version to 0.3.22.
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
Jingxin Ye
63964237b2 Skip two unit tests about custom sharding on libtpu
DETAILS:
Due to xc.register_custom_call_partitioner is not supported on libtpu, the following two tests are skipped:
tests/pjit_test.py::PJitTest::test_custom_partitioner
tests/debugging_primitives_test.py::InspectShardingTest::test_inspect_sharding_is_called_in_pjit
2022-10-25 20:55:15 +00:00
Jingxin Ye
59374c1cd8 skip some tests if runtime is stream_executor
DETAILS:
Run on CloudTPU v2-8 and found some tests in debugging_primitives_test
fail due to stream_executor runtime cannot support host callback.
Since host callback only support TFRT, so that skip all those types if
runtime type is stream_executor.

TESTED:
passed unit test on both TPU v2-8 and CPU.
2022-10-18 17:42:33 +00:00
Peter Hawkins
0d3277b5c3 Port more tests from jtu.cases_from_list to jtu.sample_product. 2022-10-11 21:06:08 +00:00
jax authors
96abd9ac75 Merge pull request #12540 from sharadmv:cond-lowering-fix
PiperOrigin-RevId: 477358889
2022-09-27 22:33:12 -07:00
Sharad Vikram
ddeaa8dbbc Fix lowering bug in effectful batched cond and add tests 2022-09-27 22:12:13 -07:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Sharad Vikram
805073f36a Add inspect_array_sharding, enabling looking at shardings in pjit-ted functions
PiperOrigin-RevId: 476237731
2022-09-22 17:36:56 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Sharad Vikram
0276a6e77c Add support for pmap sharding 2022-09-19 19:29:44 -07:00
Sharad Vikram
f825a3c8c0 Limit console width for visualize_sharding 2022-09-19 18:41:45 -07:00
jax authors
441f400358 Merge pull request #12386 from sharadmv:viz_sharding
PiperOrigin-RevId: 475387460
2022-09-19 14:36:21 -07:00