166 Commits

Author SHA1 Message Date
Dan Foreman-Mackey
ca23b7495e Add dtype dispatching to FFI example. 2025-02-21 05:57:37 -05:00
rajasekharporeddy
180be99798 Fix typos 2025-02-18 17:01:29 +05:30
Dan Foreman-Mackey
28afd25259 Add FFI example demonstrating the use of XLA's FFI state.
Support for this was added in JAX v0.5.0.

PiperOrigin-RevId: 722597704
2025-02-03 04:06:10 -08:00
Sergei Lebedev
5cad02c1f9 [pjrt] Use the PjRtMemorySpace version of BufferFromHostLiteral
Buffers live in memory spaces and not on devices. The `PjRtDevice` version
of `BufferFromHostLiteral` is deprecated and will be removed once the migration
is complete.

PiperOrigin-RevId: 721860983
2025-01-31 12:22:23 -08:00
Dan Foreman-Mackey
e72c148457 Internal change.
PiperOrigin-RevId: 715008710
2025-01-13 09:57:02 -08:00
Dan Foreman-Mackey
62656b32db Add an example demonstrating input-output aliasing with the FFI. 2025-01-07 13:21:59 -05:00
Dan Foreman-Mackey
cb4d97aa1f Move jex.ffi to jax.ffi. 2024-12-29 13:06:19 +00:00
Peter Hawkins
62e66b684b Don't monkey-patch functions in test_utils to count events for tests.
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.

Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
2024-12-12 09:58:14 -05:00
Dan Foreman-Mackey
84a9cba85b Refactor FFI examples to consolidate several examples into one submodule. 2024-11-25 09:08:20 -05:00
Mason Chang
42fbd301fc Move JAX example to public XLA:CPU API
PiperOrigin-RevId: 698143471
2024-11-19 14:19:29 -08:00
jax authors
83700828c5 Merge pull request #23805 from dfm:ffi-examples-state
PiperOrigin-RevId: 696383873
2024-11-13 21:43:41 -08:00
Dan Foreman-Mackey
f08648366e Add an example FFI call to demonstrate the use of global state. 2024-11-12 08:36:12 -08:00
Dan Foreman-Mackey
f757054267 Update some outdated syntax in FFI tutorial. 2024-11-12 08:34:24 -08:00
Dan Foreman-Mackey
ce8dba98fb Move the CUDA end-to-end example to FFI examples workflow + hosted
runner.
2024-10-31 12:21:51 -04:00
Dan Foreman-Mackey
61701af4a2 Rename vmap methods for callbacks. 2024-10-21 15:03:04 -04:00
Dan Foreman-Mackey
0b651f0f45 Make ffi_call return a callable 2024-10-21 12:16:57 -04:00
Dan Foreman-Mackey
1d27d420ac Deprecate the vectorized argument to pure_callback and ffi_call. 2024-10-02 11:33:51 -04:00
Dan Foreman-Mackey
f60c5ccdee Add support for passing array attributes via ffi_call 2024-10-01 19:22:04 -04:00
jax authors
9d277e61ce Merge pull request #23409 from dfm:ffi-examples
PiperOrigin-RevId: 678690801
2024-09-25 07:23:26 -07:00
Dan Foreman-Mackey
e1a68eee5e Add FFI example project and test on CI.
This PR includes an end-to-end example project which demonstrates the
use of the FFI. This complements [the FFI
tutorial](https://jax.readthedocs.io/en/latest/ffi.html) by putting all
of the code in one place, as well as demonstrating how FFI extensions
can be packaged. Alongside the example project, I have also added a new
GitHub Actions workflow to test the example as part of CI.

For now, the tests only run on CPU, but once we have GPU runners for
GitHub Actions (soon!), I plan on migrating the custom call examples
from `docs/gpu_ops` and `docs/cuda_custom_call` into this test case.

Similarly, I wanted to start small and this example project only
includes exactly the same functions as the tutorial for now, but I think
this could be a good place to showcase more advanced examples (including
custom calls with state).
2024-09-24 17:23:13 -04:00
Peter Hawkins
6d1f51e63d Clean up BUILD files.
PiperOrigin-RevId: 667604964
2024-08-26 09:11:17 -07:00
Kyle Lucke
84d748f43c Stop using xla/statusor.h now that it just contains an alias for absl::Status.
In some situations, this meant also changing unrelated files to directly include tsl/platform/statusor.h to get the definitions for TF_ASSIGN_OR_RETURN, etc., where they were getting transitively included for free.

PiperOrigin-RevId: 645169743
2024-06-20 15:09:40 -07:00
Kyle Lucke
418b68828a Automated Code Change
PiperOrigin-RevId: 635818645
2024-05-21 08:40:34 -07:00
Peter Hawkins
24b47318bd Force float32 matmuls in examples_test.
This test started failing when we changed our CI to use L4 GPUs. Using
highest precision resolves the problem.
2024-05-10 19:30:02 +00:00
Peter Hawkins
89d25bb1a3 Reenable examples_test in Bazel build.
Fix bitrot.

This test was disabled years ago because it was slow, but it isn't any more.

PiperOrigin-RevId: 632138101
2024-05-09 07:10:07 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Peter Hawkins
1baed9b285 [PJRT:CPU] Replace references to pjrt/tfrt_cpu_pjrt_client with pjrt/cpu/cpu_client.h.
The two are aliases and the former is a forwarding header pointing to the latter.

Cleanup only, no functional changes.

PiperOrigin-RevId: 621341188
2024-04-02 17:20:16 -07:00
Roy Frostig
78fd4f1664 update top-level examples to use new-style typed keys 2024-03-07 12:40:10 -08:00
Oleg Shyshkov
fb80d2abcb [XLA][NFC] Make interface of module loaders consistent.
LoadModuleFromData has (data, format, config, ...) signature while FromFile has (path, config, format, ...). Change the latter so `format` becomes the second argument in both cases.

Since I'm touching this file:
* Use `std::string_view` and `absl::Status`
* Change `ovr_config` parameter to `const &`

PiperOrigin-RevId: 601304308
2024-01-24 19:16:43 -08:00
Peter Hawkins
70b7d50181 Switch jaxlib to use nanobind instead of pybind11.
nanobind has a number of advantages (https://nanobind.readthedocs.io/en/latest/why.html), notably speed of compilation and dispatch, but the main reason to do this for these bindings is because nanobind can target the Python Stable ABI starting with Python 3.12. This means that we will not need to ship per-Python version CUDA plugins starting with Python 3.12.

PiperOrigin-RevId: 559898790
2023-08-24 16:07:56 -07:00
David Cottrell
40d0d40b6c Fix for log-normal prior in example. 2023-08-09 13:38:37 +01:00
Peter Hawkins
76cda0ae07 Update flags to use the ABSL typed flag API.
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.

For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.

Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.

This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.

PiperOrigin-RevId: 551604974
2023-07-27 12:15:58 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Peter Hawkins
3bb7386149 [JAX] Improve handling of metadata in compilation cache.
Metadata, in particular code location information is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which begs the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.

There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.

This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.

PiperOrigin-RevId: 525534901
2023-04-19 13:27:04 -07:00
Peter Hawkins
9f03bd51ca Make differential privacy example work again.
* Use dp_accounting import instead of the TensorFlow privacy package.
* Correct example dataset import.
* Fix use of list argument to jnp.linalg.norm()

Fixes https://github.com/google/jax/issues/15500

PiperOrigin-RevId: 523168395
2023-04-10 11:31:06 -07:00
Peter Hawkins
ab45383038 Fix build breakage from OpenXLA switch.
PiperOrigin-RevId: 516325478
2023-03-13 14:37:35 -07:00
jax authors
42ef649e65 Merge pull request #14475 from hawkinsp:openxla
PiperOrigin-RevId: 516316330
2023-03-13 14:04:41 -07:00
Peter Hawkins
172a831219 Switch JAX to use the OpenXLA repository. 2023-03-13 18:38:26 +00:00
Peter Hawkins
88379603e0 [PJRT] Delete the old :cpu_device target that uses StreamExecutor.
The TFRT CPU client is better in every way and the SE CPU client is unmaintained and has not been used by JAX in many months.

PiperOrigin-RevId: 489246256
2022-11-17 10:29:03 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jean-Baptiste Lespiau
2929d79d0b Use value instead of ValueOrDie.
PiperOrigin-RevId: 473011038
2022-09-08 09:35:52 -07:00
jax authors
eb7040d6d6 Merge pull request #11629 from froystig:rm-control-example
PiperOrigin-RevId: 463462047
2022-07-26 17:12:51 -07:00
jax authors
32e77772b9 Rename xla::PjRtExecutable to xla::PjRtLoadedExecutable
PiperOrigin-RevId: 463460929
2022-07-26 17:06:45 -07:00
Roy Frostig
eadc3466c0 remove control example 2022-07-26 15:37:46 -07:00
Matthew Johnson
e350894371 remove resnet50 example 2022-07-19 09:40:39 -07:00
jax authors
b5e6145a42 Merge pull request #11359 from hawkinsp:bazel
PiperOrigin-RevId: 459234031
2022-07-06 06:13:20 -07:00
Peter Hawkins
1c75eee1ff Document how to run tests using Bazel.
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
2022-07-06 08:30:35 -04:00
Jake VanderPlas
1a995a0c61 [x64] make examples/control_test compatible with strict dtype promotion 2022-06-16 16:20:54 -07:00
Galen Andrew
ef9036abf3 Migrate from deprecated tensorflow_privacy RDP accountant to differential_privacy.
PiperOrigin-RevId: 454724315
2022-06-13 16:26:51 -07:00
jvmncs
c128876b34
incorrect link in dpsgd example 2022-05-19 16:35:07 -04:00