11528 Commits

Author SHA1 Message Date
Ben West
02f6fcb9da Add beta function 2023-11-05 15:37:38 -08:00
Jake VanderPlas
96d9f89415 [random] better errors for unsupported operations on prng keys 2023-11-03 19:23:18 -07:00
jax authors
808289d52a Merge pull request #18373 from jakevdp:sharding-doc
PiperOrigin-RevId: 579224607
2023-11-03 10:23:44 -07:00
Jake VanderPlas
cd3ea05665 Ensure sharding-related array properties are documented 2023-11-03 09:56:33 -07:00
jax authors
db07f40233 Fall-back to original device/backend hashing if topology-desc is unavailable.
The original cache-key generation algorithm hashed devices and backend as
part of generating the key. The new algorithm relies on serialized
PjRtTopologyDescription instead. Not all backends support serialized
PjRtTopologyDescription. Fall back to the original device/backend hashing
if the needed backend does not support it.

Testing: unit testing + test workloads.
PiperOrigin-RevId: 579039803
2023-11-02 18:43:48 -07:00
Skye Wanderman-Milne
55e3072d2e Update versions and CHANGELOG after jax 0.4.20 release 2023-11-02 16:30:56 -07:00
jax authors
62741d9744 Reverts 81ac67f38164b7626d733d081a87ff49b235b9d0
PiperOrigin-RevId: 579010408
2023-11-02 16:17:29 -07:00
Parker Schuh
c8b7c1b80b Remove temporary flag for forcing arg tuplization of lowered functions.
PiperOrigin-RevId: 578910366
2023-11-02 10:53:16 -07:00
jax authors
1f6264896d Merge pull request #18354 from jakevdp:dep-opaque
PiperOrigin-RevId: 578904025
2023-11-02 10:37:47 -07:00
jax authors
6d5eaa6ec3 Merge pull request #18295 from gnecula:lax_multi
PiperOrigin-RevId: 578892592
2023-11-02 10:07:34 -07:00
Jake VanderPlas
0111dcbda3 Finish deprecation of allow_opaque_dtype 2023-11-02 09:51:06 -07:00
Etienne Pot
81ac67f381 Fix typing annotations for @jax.named_call
PiperOrigin-RevId: 578852649
2023-11-02 07:55:04 -07:00
George Necula
8feb413211 Add a lax.platform_dependent API for writing platform-dependent code.
In JAX the actual platform on which a computation is run is determined
very late, e.g., based on where the data is located. When using AOT
lowering or serialization, the computation may execute on a different
machine, or even on a platform that is not available at lowering time.
This means that it is not safe to write platform-dependent code using
Python conditionals, e.g., based on the current default JAX platform.
The proper way to do this is to introduce a primitive with
platform-specific lowering rules. This change introduces such a
primitive along with a user-facing API.

See more details in the docstring of lax.platform_dependent.
2023-11-02 14:31:38 +01:00
Reed Wanderman-Milne
d41078fb95 Properly pack and unpack int4 arrays on CPU in PJRT.
Transferring an array from host to device on CPU sometimes does a zero-copy implementation where no memory is actually moved. This is now never done with int4, since int4 arrays are stored in packed format on device and an unpacked format on host. Similarly, transferring an array from device to host on CPU used to always use a zero-copy implementation, but now it will unpack and copy for int4 arrays.

PiperOrigin-RevId: 578692796
2023-11-01 17:39:24 -07:00
jax authors
a009f8d6c1 Pass flags from kernel into HLO backend config.
PiperOrigin-RevId: 578390868
2023-10-31 21:32:19 -07:00
jax authors
9f28512c4b Merge pull request #18340 from jakevdp:keyarray-error
PiperOrigin-RevId: 578355093
2023-10-31 17:49:28 -07:00
Roy Frostig
16d082b002 [jex] replace extend.random.PRNGImpl with extend.random.define_prng_impl
Instead of exposing a constructor, only expose a function that returns an opaque
object representing the defined implementation. This result can still be passed
to `jax.random.key` and `wrap_key_data`.

PiperOrigin-RevId: 578349699
2023-10-31 17:21:54 -07:00
Jake VanderPlas
a4e6b4e943 [random] add more information to KeyArray deprecation error 2023-10-31 16:54:49 -07:00
Roy Frostig
ed9a4c2939 add jax.threefry_partitionable context manager 2023-10-31 13:45:55 -07:00
Roy Frostig
b22e75716f add threefry_partitionable config setting to thread-local JIT context 2023-10-31 13:45:49 -07:00
Yash Katariya
20255dce84 Delete cached_call_jaxpr_lowerings since a more general cached_primitive_lowerings is available
PiperOrigin-RevId: 577993595
2023-10-30 16:38:57 -07:00
Yash Katariya
85af862efd [Try again] For nested pjit's cache the generation of StableHLO if it satifies the key. This should help in improving the lowering time.
Reverts 4a5c6f82009dee9c30ca4a85359a702d745ed035

PiperOrigin-RevId: 577974380
2023-10-30 15:28:43 -07:00
Jake VanderPlas
fbacebc11e jnp.einsum: mention default value for optimize param 2023-10-30 09:22:37 -07:00
Sergei Lebedev
fd3a8b2cc6 Deprecated define_* and DEFINE_* methods on jax.config
These methods are internal to JAX. Yet, prior to this commit they were
effectively part of the public API, since users could (and some did!) invoke
them on `jax.config`.
2023-10-29 20:58:19 +00:00
Parker Schuh
19c65353d2 Do not init backends from topology construction, instead directly init the
plugin.

PiperOrigin-RevId: 577331743
2023-10-27 16:21:01 -07:00
Yash Katariya
8ee58117e2 Don't print all the devices in the mesh during ResourceEnv's repr. Just print the mesh shape.
PiperOrigin-RevId: 577305337
2023-10-27 14:25:34 -07:00
jax authors
69d0a404c2 Merge pull request #18305 from superbobry:deprecate-jax-config-submodule
PiperOrigin-RevId: 577299278
2023-10-27 14:04:34 -07:00
jax authors
9ba305cced Invalidate in-memory caches on XLA-AutoFDO profile version change.
When the value in --jax_xla_profile_version changes, all tracing
and compilation caches should be invalidated since the XLA programs
need to be recompiled with the new XLA-AutoFDO profile.

Testing:
. New unit test.
. Test workload with instrumentation to repeatedly change
  the profile version. Before/after comparison.
PiperOrigin-RevId: 577280639
2023-10-27 12:52:57 -07:00
Sergei Lebedev
c90f1f0c96 Deprecated accessing config via the jax.config submodule
The `config` object it re-exports is available on the top-level `jax` package,
i.e.

    from jax.config import config

can always safely be replaced with

    from jax import config

or just

    import jax
    jax.config
2023-10-27 20:08:10 +01:00
jax authors
11c4e2c820 [JAX] Add an option subset_by_index that allows computing a contiguous subset of eigenvalues from eigh.
PiperOrigin-RevId: 577222219
2023-10-27 09:29:33 -07:00
Yash Katariya
0d57330fe0 Add mitigation techniques in the error message when a barrier timeout occurs.
PiperOrigin-RevId: 577214081
2023-10-27 08:57:39 -07:00
jax authors
15e6d7a3ad Fix an error format issue in jax.random
PiperOrigin-RevId: 577042296
2023-10-26 18:14:55 -07:00
jax authors
38e28693a8 [Pallas] Add Slice to public interface
PiperOrigin-RevId: 576840688
2023-10-26 05:50:59 -07:00
Yash Katariya
4d15375596 AOT sharding mismatch error shouldn't have GSPMDSharding in it.
PiperOrigin-RevId: 576668290
2023-10-25 15:48:01 -07:00
David Majnemer
ba9fd7744e Remove bitwise conversions from _canonicalize_float_for_sort
This lets us compute entirely in the float domain.

PiperOrigin-RevId: 576613806
2023-10-25 12:43:55 -07:00
George Necula
edbe49fb2a Cleanup the handling of single- and multi-platform lowering in ModuleContext
Previously, we introduced support for multi-platform lowering, by
adding a new LoweringParameters object that can be used to specify
a cross-lowering platform or even multiple platforms. But we had
kept the ModuleContext.platform in place because some lowering rules
were still referencing it. Now we replace ModuleContext.platform with
ModuleContext.platforms, which removes the redundancy, simplifies
the code, and makes it clearer that the lowering rules should not
simply assume single-platform lowering.

PiperOrigin-RevId: 576575376
2023-10-25 10:40:41 -07:00
Adam Paszke
7325b753fd Improve the error message reported when Mosaic passes fail
Those failures should be treated as internal and reported as bugs.
As a next step, we should add a verifier pass that explicitly checks
that all the ops in the module are supported.

PiperOrigin-RevId: 576502802
2023-10-25 05:49:34 -07:00
Sharad Vikram
d488812d0c [Pallas TPU] Add support for hoisted scratch spaces
PiperOrigin-RevId: 576336673
2023-10-24 17:29:54 -07:00
jax authors
4897a5fb5a Merge pull request #18217 from gnecula:multi_call_tf
PiperOrigin-RevId: 576218473
2023-10-24 12:13:29 -07:00
Jake VanderPlas
3e9c50290f Allow array-like inputs to random.seed_impl 2023-10-24 11:23:49 -07:00
George Necula
db44249afc [call_tf] Fix call_tf lowering for multi-platform lowering
call_tf has per-platform lowering because the lowering
of the called TF function may depend on the platform. When
doing multi-platform lowering this means that we lower
call_tf several times and wrap the lowerings with a
conditional. This results in an assertion failure
in add_to_call_tf_concrete_function_list, because we
are attempting to add the same function multiple times.

Here we remove the assertion (afaik, it is Ok to add
multiple functions with the same name, because all
we care about is the index of the called function in
the list). We also reuse the existing function if
we are adding an identical one.

We add tests for call_tf with multi-platform lowering.
2023-10-24 18:57:58 +02:00
Yash Katariya
8b05b1623c Make the directories (and it's parents) specified in JAX_DUMP_IR_TO flag if they don't exist
PiperOrigin-RevId: 576151618
2023-10-24 08:39:51 -07:00
Chris Jones
b61af5a104 [pallas:gpu] Lower more complex primitives using JAX functions in terms of more basic primitives.
PiperOrigin-RevId: 575883386
2023-10-23 11:45:45 -07:00
jax authors
8fa287e4e7 Merge pull request #18236 from jakevdp:numpy-core
PiperOrigin-RevId: 575877503
2023-10-23 11:27:20 -07:00
jax authors
366dfc2fa4 Merge pull request #18234 from superbobry:f-repr-str
PiperOrigin-RevId: 575858521
2023-10-23 10:26:29 -07:00
Jake VanderPlas
ea9eb6d2b1 [CI] avoid referencing numpy.core to fix nightly CI 2023-10-23 09:11:33 -07:00
Peter Hawkins
caee898fd0 Fix jaxlib build failure after upstream MLIR Python binding changes.
https://github.com/llvm/llvm-project/pull/68853 changed the structure of
the upstream MLIR Python bindings, breaking the jaxlib build. Update our
build scripts to match.
2023-10-23 14:27:52 +00:00
Sergei Lebedev
f2ce5dbd01 MAINT Do not use str() and repr() in f-string replacement fields
`str()` is called by default by the formatting machinery, and `repr()` only
needs `!r`.
2023-10-23 15:12:04 +01:00
George Necula
9726e1b011 Adjust tolerance for backward compatibility test for Qr on TPU.
Fixes flaky test.

PiperOrigin-RevId: 575752219
2023-10-23 02:23:15 -07:00
George Necula
e89212c81a [export] Set the default export serialization version to 8.
This version has been supported by XlaCallModule since July 21, 2023 and we are now past the forward-compatibility window.

See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions

Reverts ae81ac9cc21696a22b973b1eae6ce222c7318ba7

PiperOrigin-RevId: 575382324
2023-10-20 21:00:55 -07:00