49 Commits

Author SHA1 Message Date
Peter Hawkins
c61b2f6b81 Make JAX test suite pass (at least most of the time) with multiple threads enabled.
Add a new jtu.thread_unsafe_test_class() decorator to tag entire `TestCase` classes as thread-hostile.

PiperOrigin-RevId: 714037277
2025-01-10 06:58:46 -08:00
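One way to picture a class-level "thread-hostile" tag like the one this commit describes is the sketch below. It is an illustration only: the `_thread_hostile` attribute and the `is_thread_hostile` helper are hypothetical, and the real `jtu.thread_unsafe_test_class()` may work differently.

```python
import unittest

def thread_unsafe_test_class():
    """Returns a class decorator that tags a TestCase as thread-hostile."""
    def decorator(cls):
        assert issubclass(cls, unittest.TestCase), cls
        cls._thread_hostile = True  # hypothetical attribute name
        return cls
    return decorator

def is_thread_hostile(cls):
    """Hypothetical helper a parallel runner could use to pick serial tests."""
    return getattr(cls, "_thread_hostile", False)

@thread_unsafe_test_class()
class GlobalStateTest(unittest.TestCase):
    def test_nothing(self):
        pass

class OrdinaryTest(unittest.TestCase):
    def test_nothing(self):
        pass
```

A runner built on this idea would run tagged classes one at a time and everything else in parallel.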
Peter Hawkins
51b9fe3010 [JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU devices directly.
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.

In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.

PiperOrigin-RevId: 713272197
2025-01-08 06:37:44 -08:00
Peter Hawkins
62e66b684b Don't monkey-patch functions in test_utils to count events for tests.
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.

Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
2024-12-12 09:58:14 -05:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Ruturaj4
fb5c516405 [ROCM] test_computation_follows_data fix for rocm and cuda 2024-07-15 16:28:22 -05:00
Jake VanderPlas
33465274da fix some additional warnings related to #21834 2024-06-13 16:06:14 -07:00
Jake VanderPlas
a861c55a28 test cleanup: use ExitStack to reduce test boilerplate 2024-06-06 14:18:27 -07:00
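A generic sketch of the `ExitStack` pattern in a test's `setUp` is shown below. It is not the code from this commit: `fake_flag` is a hypothetical stand-in for whatever context managers a test needs to enter.

```python
import contextlib
import unittest

@contextlib.contextmanager
def fake_flag(log, name):
    """Hypothetical stand-in for a config context manager."""
    log.append(f"enter {name}")
    try:
        yield
    finally:
        log.append(f"exit {name}")

class ExampleTest(unittest.TestCase):
    def setUp(self):
        self.log = []
        stack = contextlib.ExitStack()
        # Everything entered on the stack is exited, in reverse order, at cleanup.
        self.addCleanup(stack.close)
        stack.enter_context(fake_flag(self.log, "flag_a"))
        stack.enter_context(fake_flag(self.log, "flag_b"))

    def test_flags_active(self):
        self.assertEqual(self.log, ["enter flag_a", "enter flag_b"])

suite = unittest.defaultTestLoader.loadTestsFromTestCase(ExampleTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

A single stack replaces a chain of nested `with` blocks or hand-written `tearDown` logic.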
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Mark Sandler
2717daee8f Simplifies full to not instantiate an intermediate array with default sharding. This significantly reduces overhead when creating sharded arrays in eager mode when using jnp.zeros_like(...)
PiperOrigin-RevId: 606765964
2024-02-13 15:32:44 -08:00
Jake VanderPlas
8476b44aac Test: move lax.full sharding tests to multi_device_test 2024-01-23 12:10:00 -08:00
Jake VanderPlas
97beb01c43 Deprecate the device() method of JAX arrays 2023-11-30 11:43:02 -08:00
Matthew Johnson
c9ab0bfd3c fix grad device_put src inference, and a small device_put bug
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-11-29 16:24:24 -08:00
Jake VanderPlas
2f878a7168 Tests: set jax_legacy_prng_key='error' 2023-08-28 10:56:09 -07:00
Yash Katariya
970f4c9d4d Remove trivial execution from jax since it leads to 100x slower dispatch time.
Trivial computations were added for a pre-omnistaging world. After omnistaging, JAX produces fewer trivial computations, so there is no need for this to exist.

In the future, if we want to support forwarding of inputs to outputs, there would need to be a different way which the C++ dispatch path knows about.

```
jit_trivial_dispatch                                   246µs ± 3%                4µs ± 1%  -98.52%          (p=0.008 n=5+5)
jit_trivial                                            250µs ± 3%                5µs ± 1%  -98.19%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 560141018
2023-08-25 10:59:48 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Yash Katariya
181355335c Remove references to jax.config.jax_jit_pjit_api_merge, which is always True at head.
PiperOrigin-RevId: 516998437
2023-03-15 20:07:20 -07:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
Yash Katariya
1526c3e20c Improve the error message which is raised from _get_and_check_device_assignment.
Before:

```
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0] on platform CPU and another array's device ids [0, 1, 2, 3] on platform CPU
```

After:

```
ValueError: Received incompatible devices for jitted computation. Got argument inp of ArrayPjitTest.test_jit_with_sharding_constraint_committed_inp_error.<locals>.sharded_inp with bfloat16[8,2] and device ids [0] on platform CPU and with_sharding_constraint or nested pjit or shard_map with device ids [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)
```
PiperOrigin-RevId: 508746961
2023-02-10 13:54:15 -08:00
Yash Katariya
74601e59e1 Fix the error message of different devices when jit/pjit are merged
PiperOrigin-RevId: 500727596
2023-01-09 09:03:55 -08:00
Yash Katariya
cbf34cb609 Rename the concrete class Array to ArrayImpl
PiperOrigin-RevId: 477017236
2022-09-26 16:18:30 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Yash Katariya
980aa318fb Minimally support device argument on jit in the jax.Array path
This means that only a single device is allowed to flow through this path. This is a compromise: it supports the existing codepaths but does not allow sharded arrays to go through this path, and it encourages users to use other well-supported techniques, such as calling device_put explicitly, instead of relying on `jit` to do that for them.

PiperOrigin-RevId: 473373822
2022-09-09 16:56:43 -07:00
jax authors
edfbbd7203 Merge pull request #12297 from mattjj:computation-follows-data-prng
PiperOrigin-RevId: 473092328
2022-09-08 14:57:31 -07:00
Matthew Johnson
47b2dfe92f add _device attribute to PRNGKeyArray so that computation follows key placement
unrelated: remove some redundant hasattr + try / except AttributeError
2022-09-08 14:30:18 -07:00
Yash Katariya
7fbf8ec669 Fix Forward. The fix is on the user's end. Original PR: https://github.com/google/jax/pull/12217
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 472999907
2022-09-08 08:49:40 -07:00
jax authors
14f1a345a1 roll back breakage
PiperOrigin-RevId: 472949225
2022-09-08 03:59:54 -07:00
Yash Katariya
b7e4e44cbf DCE jaxpr and trivial_jaxpr support for lower_sharding_computation
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 471274989
2022-09-06 14:09:10 -07:00
Yash Katariya
6340952e2a Make jit == pjit. This means that the lowering and execution paths of jit and pjit are merged.
A fallback to `lower_xla_callable` is taken when pmap appears in the jaxpr during the jit lowering path.

Added support for `keep_unused`, `committed` and `core.Token` to pxla.py.

PiperOrigin-RevId: 470896270
2022-08-29 22:03:21 -07:00
Yash Katariya
314cf8a439 Use .device() to get the device and platform from the device and fix TODO to point to github issue
PiperOrigin-RevId: 468769708
2022-08-19 13:14:13 -07:00
Yash Katariya
d77848bcc9 Enable jax_array on CPU for the entire JAX test suite!
PiperOrigin-RevId: 468726200
2022-08-19 10:04:35 -07:00
Reza Rahimi
a0d9d81f92 Update JAX to use new math libraries in ROCm-5.0. 2022-03-01 20:02:15 +00:00
Peter Hawkins
d262bae88b Split jax.interpreters.xla up into three pieces:
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).

The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)

PiperOrigin-RevId: 411565432
2021-11-22 08:22:43 -08:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Matthew Johnson
2b79264354 remove disable_omnistaging mechanism 2021-03-29 15:26:57 -07:00
Peter Hawkins
140c0acbbe Remove the JAX lazy sublanguage.
Back in the mists of time, before omnistaging landed in JAX, we used lazy
expressions to avoid materializing large constants inside `jit` computations.
Omnistaging, which means that computations that are in the dynamic scope of a
`jit` are staged into the `jit` computation, has subsumed most of the reasons
for laziness to exist, and this PR removes the laziness support for simplicity.

At the time of this PR, laziness is used only for broadcasts and transposes in
eager mode (i.e., outside a `jit`). This allows us to:
a) fuse together multiple broadcasts and transposes, and
b) if a lazy expression is lexically captured by a `jit` computation, we can
   avoid materializing it in its expanded form.

It is not clear that laziness has sufficient power to weight ratio to continue
to exist, and it is making other work on improving JAX dispatch times more
difficult. As a result, this PR removes laziness to unblock that work; if we
want laziness again we would want to reimplement it in C++ anyway.
2021-03-09 21:40:46 -05:00
Jake VanderPlas
6393349783 raise_to_shaped: preserve weak_type by default 2020-10-08 11:53:52 -07:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 for more information.
2020-07-30 12:59:36 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Peter Hawkins
b543652332
Replace np -> jnp, onp -> np in tests. (#2969) 2020-05-05 14:59:16 -04:00
George Necula
d315564ebf
Fixed a few more places where device commitment was lost. (#2913)
* trivial jit computations were forcing commitment to the default device
* a device_put with a device specification would not set the commitment
  if the data was already (uncommitted) on the specified device.
* added tests for the above
* once the above were fixed the LazyTest.test_zeros_ones_compilation
  started to fail because the `sticky` parameter to lazy_force_computation
  was changing. Fixed this by removing stickiness from the compilation key.
* Expanded docstring for jax.device_put; expanded the
  device placement FAQ entry.
2020-05-04 11:30:28 +03:00
George Necula
ac023bf28f
Fixed a few places where device sticky-ness was lost. Added FAQ (#2882)
* Fixed a few places where device stickiness was lost. Added FAQ for device
placement.

I have also added a new test (multi_device_test.test_computation_follows_data),
written more as part of the documentation. It is shorter than the
old test_computation_follows_data (which is still there, renamed
as test_computation_follows_data_old). I believe there is no
extra coverage in test_computation_follows_data_old w.r.t. all the
other tests we have.

* Fix mypy annotations and updates based on comments

* Undid some changes, will make another PR
2020-05-01 10:06:59 +03:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Matthew Johnson
ad9b6d4d94 implement lazy sublanguage
Before this commit, this computation would avoid materializing the iota
array at trace time:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(n)

But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(m)[:, None]

The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.

Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).

This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.

Incidentally fixes #1431

See https://github.com/google/jax/pull/1668 for more.
2020-01-07 20:48:26 -08:00
Matthew Johnson
a3eb2b1b96 improve computation-follows-data policy
fixes #1914 (see discussion there)

The new policy is that JAX's DeviceArrays, which are backed by device
memory but potentially on different devices (like distinct GPUs, or CPU
and GPU), can either be "stuck" to their device or not (i.e. "sticky" or
not). A DeviceArray result is stuck to its device if
  1. it was produced by a computation with an explicit user-provided
  device or backend label, i.e. a `jit` or `device_put` with an explicit
  device or backend argument, or
  2. it was produced by a computation that consumed as an argument a
  sticky DeviceArray value.
If a computation without an explicit device/backend label is applied to
all non-sticky arguments, the result is non-sticky. If a computation
without an explicit device/backend label is applied to any sticky
arguments, then if all the sticky device labels agree the result is
sticky on the same device, and otherwise an error is raised. (A
computation with an explicit device/backend label can consume any sticky
or non-sticky values without error, regardless of their devices.)

Implementation-wise, every DeviceArray has an attribute _device
(introduced in #1884, revised here) that is set either to a value that
represents a Device instance (actually stored as a Device class / int id
pair), indicating that the DeviceArray is sticky on that device, or None
indicating that the DeviceArray is not sticky. The value of the _device
attribute for results of a computation is computed when the XLA
executable is compiled and stored in the result handler (which packages
up a raw device buffer into a DeviceArray).
2019-12-26 14:27:12 -08:00
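The stickiness rules in this commit message can be restated as a small pure-Python function. This is a sketch of the policy as worded above, not the JAX implementation: `result_device` and the string device labels are hypothetical, and a label of `None` stands for a non-sticky value.

```python
from typing import Optional, Sequence

def result_device(arg_devices: Sequence[Optional[str]],
                  explicit_device: Optional[str] = None) -> Optional[str]:
    """Returns the device a result is stuck to, or None if it is not sticky.

    Each entry of `arg_devices` is the device label of a sticky argument,
    or None for a non-sticky argument.
    """
    if explicit_device is not None:
        # An explicit label wins and accepts any mix of arguments.
        return explicit_device
    sticky = {d for d in arg_devices if d is not None}
    if len(sticky) > 1:
        raise ValueError(
            f"arguments are stuck to different devices: {sorted(sticky)}")
    # No sticky arguments gives a non-sticky result; one label gives a
    # result that is sticky on that device.
    return next(iter(sticky), None)
```

For example, `result_device(["gpu:0", None])` yields `"gpu:0"`, while `result_device(["gpu:0", "gpu:1"])` raises unless an explicit device is supplied.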
Skye Wanderman-Milne
891aecb941
Add test utilities for counting compilations. (#1895)
Also uses the new utilities to check that pmap doesn't compile constant computations.
2019-12-19 11:19:58 -08:00
Matthew Johnson
0ed842ed11 add another multi-device test 2019-12-18 15:55:39 -08:00
Matthew Johnson
2a394ce31b move multi-device test into its own file 2019-12-18 14:40:20 -08:00