21 Commits

Author SHA1 Message Date
Sergei Lebedev
56745818a6 Added basic support for int2/uint2 dtypes to JAX
#21369

PiperOrigin-RevId: 649366888
2024-07-04 04:13:24 -07:00
Yash Katariya
b1f7627c71 [Rollback] Bumped the minimum ml_dtypes version to 0.4.0
Reverts e86c436e7f8e4e0546eff8bc2d3756a7c49dc83b

PiperOrigin-RevId: 642741832
2024-06-12 14:40:40 -07:00
Sergei Lebedev
0a694a1b42 Bumped the minimum ml_dtypes version to 0.4.0 2024-05-23 21:51:00 +01:00
Jake VanderPlas
9b46e2d6a3 Support float8 in reduce_min & reduce_max 2023-12-18 13:37:45 -08:00
Reed Wanderman-Milne
d41078fb95 Properly pack and unpack int4 arrays on CPU in PJRT.
Transferring an array from host to device on CPU sometimes does a zero-copy implementation where no memory is actually moved. This is now never done with int4, since int4 arrays are stored in packed format on device and an unpacked format on host. Similarly, transferring an array from device to host on CPU used to always use a zero-copy implementation, but now it will unpack and copy for int4 arrays.

PiperOrigin-RevId: 578692796
2023-11-01 17:39:24 -07:00
Sergei Lebedev
65d3058944 Migrate a subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

PiperOrigin-RevId: 571932143
2023-10-09 07:29:53 -07:00
Peter Hawkins
cf2d61c701 Move --jax_test_dut into the non-public test utilities.
This flag only exists for the use of JAX's own tests, and doesn't need to exist for most JAX users.

4f805c2d8f allows this move since none of the test comparison utilities now depend on the choice of backend. (That dependency was only an administrative dependency for external users of JAX, since the only public users of the test comparison utilites are the gradient utilities, which always override the default tolerance with their own tolerances.)

PiperOrigin-RevId: 563195002
2023-09-06 13:21:48 -07:00
Peter Hawkins
4f805c2d8f [JAX] Change jax.test_util utilities to have identical tolerances on all platforms.
In cases where this causes TPU tests to fail, relax test tolerances in the test cases themselves.

TPUs are less precise only for specific operations, notably matrix multiplication (for which usually enabling higher-precision matrix multiplication is the right choice if precision is needed), and certain special functions (e.g., log/exp/pow).

The net effect of this change is mostly to tighten up many test tolerances on TPU.

PiperOrigin-RevId: 562953488
2023-09-05 18:48:55 -07:00
Jake Hall
85f124c18d Add support for float8_e4m3fnuz and float8_e5m2fnuz. 2023-08-07 11:48:53 +01:00
Peter Hawkins
76cda0ae07 Update flags to use the ABSL typed flag API.
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.

For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.

Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.

This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.

PiperOrigin-RevId: 551604974
2023-07-27 12:15:58 -07:00
Jake VanderPlas
9962065deb Require ml_dtypes>=0.2 2023-07-07 12:07:44 -07:00
Jake VanderPlas
8d165193be Fix test_util for new float8 type 2023-06-08 00:30:45 -07:00
Jake VanderPlas
4063922b38 Clean up old jaxlib version guards.
minimum_jaxlib_version is now 0.4.11, which has xla_client._version=158
2023-06-06 01:17:30 -07:00
Peter Hawkins
c1f65fc8b2 Avoid imports from the public jax.* namespace in more places internally.
This change is in preparation for more cycle breaking in the Bazel dependency graph.

PiperOrigin-RevId: 521822756
2023-04-04 11:41:40 -07:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
jax authors
78599e65d1 Roll-back https://github.com/google/jax/pull/14144 due to downstream test failures
PiperOrigin-RevId: 504628432
2023-01-25 12:15:36 -08:00
jax authors
d14e144651 Use pareto optimal step size for computing numerical Jacobians in JAX. This allows us to tighten the tolerances in gradient unit testing significantly, especially for float64 and complex128.
PiperOrigin-RevId: 504579516
2023-01-25 09:12:52 -08:00
Qiao Zhang
d203926c16 Expose fp8 in jax dtypes and mlir builder.
PiperOrigin-RevId: 501980811
2023-01-13 18:12:12 -08:00
Jake VanderPlas
9e53de888a [x64] make chack_grads() more type-safe 2022-12-02 12:51:41 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
1246b6fc73 Separate jax.test_util implementations into public and private sources.
Eventually the private functionality will no longer be exported via the jax.test_util submodule.

PiperOrigin-RevId: 439415485
2022-04-04 14:43:39 -07:00