1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-21 06:16:06 +00:00

38 Commits

Author SHA1 Message Date
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Jake VanderPlas
33465274da fix some additional warnings related to 2024-06-13 16:06:14 -07:00
Jake VanderPlas
8b630452ae fix multi_backend_tests 2024-06-13 11:17:31 -07:00
Jake VanderPlas
3f210c63a0 avoid globally silencing the jit backend/device warning 2024-06-12 14:43:14 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Jake VanderPlas
97beb01c43 Deprecate the device() method of JAX arrays 2023-11-30 11:43:02 -08:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to .
2023-10-12 17:32:15 +01:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Yash Katariya
181355335c Remove references to jax.config.jax_jit_pjit_api_merge, which is always True at head.
PiperOrigin-RevId: 516998437
2023-03-15 20:07:20 -07:00
Yash Katariya
53fceab17c pjit allows nesting of pjits where the outer backend is None while the inner backend is something other than device_under_test(). This is because the inner backend will take priority.
PiperOrigin-RevId: 502721834
2023-01-17 16:39:45 -08:00
Peter Hawkins
72f4f389be Migrate remaining tests from jtu.cases_from_list to jtu.sample_product.
Delete jtu.cases_from_list.
2022-10-12 15:20:53 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Yash Katariya
980aa318fb Minimally support device argument on jit in the jax.Array path
This means that only a single device is allowed to flow through this path. This is a compromise i.e. it will support the existing codepaths but won't support sharded arrays to go through this path and encourage users to use other well supported techniques like using device_put explicitly instead of relying on `jit` to do that for you.

PiperOrigin-RevId: 473373822
2022-09-09 16:56:43 -07:00
Yash Katariya
6340952e2a Make jit == pjit. This means that the lowering and execution paths of jit and pjit are merged.
A fallback to `lower_xla_callable` is taken when pmap appears in the jaxpr during the jit lowering path.

Added support for `keep_unused`, `committed` and `core.Token` to pxla.py.

PiperOrigin-RevId: 470896270
2022-08-29 22:03:21 -07:00
Yash Katariya
314cf8a439 Use .device() to get the device and platform from the device and fix TODO to point to github issue
PiperOrigin-RevId: 468769708
2022-08-19 13:14:13 -07:00
Yash Katariya
d77848bcc9 Enable jax_array on CPU for the entire JAX test suite!
PiperOrigin-RevId: 468726200
2022-08-19 10:04:35 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Peter Hawkins
2bd010ae88 Cleanup internal representation of XLA translation rules.
Over time JAX has sprouted many variants of XLA translation rules, each with slightly different but overlapping arguments. This change consolidates them into a new xla.TranslationRule signature:
rule(ctx, avals_in, avals_out, *args, **params)
where ctx contains the parts of the other signatures that were typically not specific to a particular equation.

Since there are many JAX rules to migrate, and even a number of translation rules belonging to projects downstream of JAX, we leave backwards compatibility shims around `xla.translations`, `xla.backend_specific_translations`, and `xla.call_translations` which seem to be the only ones used outside JAX itself.

In passing, this change alters the semantics of `backend` arguments to nested `jit` blocks. We now always canonicalize the backend to a specific backend at the outermost `jit`, and do not complain if an inner `jit` has an explicit `backend` that matches the current default choice.

PiperOrigin-RevId: 403607667
2021-10-16 07:53:24 -07:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
9f083d11da Use jax.* APIs rather than api.* names in tests.
Tests should use our own public APIs where they exist.
2021-09-13 16:01:32 -04:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Peter Hawkins
fffdb2daa8
Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_… ()
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.

Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.

No functional changes intended.

* Fix a number of lax reference implementations to preserve types.
2020-06-01 17:19:23 -04:00
Peter Hawkins
b1bc841ae5
Replace np -> jnp, onp -> np in more places. ()
* Replace np -> jnp, onp -> np in more places.

Context: 

* Fix typo in random_test.py
2020-05-05 16:40:41 -04:00
Matthew Johnson
e06bde8cc0
revise xla.device_put device logic ()
* revise xla.device_put device logic, fixes 

* remove test of behavior we don't want

Previously, we were testing that for a DeviceArray x, writing
jax.device_put(x) would evaluate to a DeviceArray *on the default
device*. Instead, we should be happy with just returning the same
DeviceArray without any movement.
2020-04-30 17:21:10 -07:00
George Necula
8d4b6857ad
Fix typo in tests; caught on GPU and TPU () 2020-04-30 19:16:05 +03:00
George Necula
b39da1f842
Fix jit with device placement ()
In setups with multiple backends, a jit happens on the default
backend, unless we give a `backend` parameter. This is true
even if the inputs are committed to a device on the non-default
backend, or if we pass a `device` parameter to jit.
2020-04-30 10:16:14 +03:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. ()
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Peter Hawkins
7dbc8dc1bc
Minimal changes to make Jax pass a pytype check. () 2020-01-18 08:26:23 -05:00
Matthew Johnson
e51b6b34e1 fix test typo 2020-01-08 10:53:27 -08:00
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. ()
Remove six dependency.
2020-01-08 13:17:55 -05:00
Matthew Johnson
b04019ea74 fix test typos 2020-01-07 22:30:54 -08:00
Matthew Johnson
ad9b6d4d94 implement lazy sublanguage
Before this commit, this computation would avoid materializing the iota
array at trace time:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(n)

But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(m)[:, None]

The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.

Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).

This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.

Incidentally fixes 

See https://github.com/google/jax/pull/1668 for more.
2020-01-07 20:48:26 -08:00
Matthew Johnson
a3eb2b1b96 improve computation-follows-data policy
fixes  (see discussion there)

The new policy is that JAX's DeviceArrays, which are backed by device
memory but potentially on different devices (like distinct GPUs, or CPU
and GPU), can either be "stuck" to their device or not (i.e. "sticky" or
not). A DeviceArray result is stuck to its device if
  1. it was produced by a computation with an explicit user-provided
  device or backend label, i.e. a `jit` or `device_put` with an explicit
  device or backend argument, or
  2. it was produced by a computation that consumed as an argument a
  sticky DeviceArray value.
If a computation without an explicit device/backend label is applied to
all non-sticky arguments, the result is non-sticky. If a computation
without an explicit device/backend label is applied to any sticky
arguments, then if all the sticky device labels agree the result is
sticky on the same device, and otherwise an error is raised. (A
computation with an explicit device/backend label can consume any sticky
or non-sticky values without error, regardless of their devices.)

Implementation-wise, every DeviceArray has an attribute _device
(introduced in , revised here) that set either to a value that
represents a Device instance (actually stored as a Device class / int id
pair), indicating that the DeviceArray is sticky on that device, or None
indicating that the DeviceArray is not sticky. The value of the _device
attribute for results of a computation is computed when the XLA
executable is compiled and stored in the result handler (which packages
up a raw device buffer into a DeviceArray).
2019-12-26 14:27:12 -08:00
Matthew Johnson
286ec51f61 make op-by-op computation follow arg placement 2019-12-18 14:40:20 -08:00
Peter Hawkins
bbf8129aa6
Change test tolerance logic not to choose tolerance values based on f… ()
* Change test tolerance logic not to choose tolerance values based on flags (in particular, --jax_enable_x64).

We would like to move away from having global flags to enable 64-bit mode. We therefore need other methods to select test tolerances. Instead, use a per-type default tolerance, and allow tests to pass per-type dictionaries of tolerances as atol and rtol values. Fix up a number of tolerances to make tests pass.

* Fix test tolerances.

* Fix dtype canonicalization for test tolerances.

* Relax core test_vjp tolerance.
2019-11-16 13:51:42 -05:00
Anselm Levskaya
91a2311601 clean up multibackend tests 2019-08-23 23:42:08 -07:00
Anselm Levskaya
c839c6a602 Added basic behavior unit tests of multibackend jit. 2019-08-21 20:59:18 -07:00