841 Commits

Author SHA1 Message Date
Matthew Willson
17d89ad166 Fix jax.device_put so it doesn't use tree_map for _check_sharding.
This causes it to unnecessarily attempt to unflatten the None return values from _check_sharding into the original tree structure, which is a problem for custom datatypes registered with jax.tree_util that don't accept None values in place of jax arrays.

PiperOrigin-RevId: 570189648
2023-10-02 15:01:03 -07:00
George Necula
552fef6fcd Introduce a LoweringParameters dataclass for easier plumbing
There are currently two parameters that are used to configure
lowering: lowering_platform (for cross-platform lowering), and
override_lowering_rules. Each of them are passed as separate arguments
through several layers of lowering internal functions. This is tedious,
and error prone. In fact, override_lowering_rules was not plumbed
in all places, and due to using default arguments in all places,
this leads to silent errors.

We foresee introducing other parameters for lowering: for multi-platform
lowering, for controlling the lowering of effects.

Here is pack all such parameters into a `mlir.LoweringParameters`
dataclass and we plumb that through.
2023-09-29 08:23:05 +03:00
Junwhan Ahn
8bfe3b92bc Roll back f92a70a41e
Reverts bb4382f0bce074ab081e1e02871e32ba331d1d46

PiperOrigin-RevId: 569292433
2023-09-28 14:32:23 -07:00
Junwhan Ahn
bb4382f0bc Destruct objects owned by WeakRefLRUCache::CacheEntry out of band using GlobalPyRefManager()
This assumes less about whether the thread that destructs `CacheEntry` has GIL or not, which is difficult to reason about due to the `xla::LRUCache`'s use of `std::shared_ptr<CacheEntry>`.

The following changes have been made in JAX to accommodate the behavior differences from direct destruction to GC:

* Since `PyLoadedExecutable`s cached in `WeakRefLRUCache` are now destructed out of band, `PyClient::LiveExecutables()` calls `GlobalPyRefManager()->CollectGarbage()` to make the returned information accurate and up to date.
* `test_jit_reference_dropping` has been updated to call `gc.collect()` before verifying the live executable counts since the destruction of executables owned by weak ref maps is now done out of band as part of `GlobalPyRefManager`'s GC.

PiperOrigin-RevId: 569062402
2023-09-27 22:15:22 -07:00
Peter Hawkins
210fab1aae Remove the "No GPU/TPU found" warning.
Instead, add a lightweight test for NVIDIA GPUs and Google TPUs. Warn
only if we suspect either is present but JAX is not using them.
2023-09-26 19:04:34 +00:00
Peter Hawkins
5aaa15df84 Remove the skip_on_xla_cpu_mlir decorator.
We no longer test this variant in CI, so we don't need code to skip it.

PiperOrigin-RevId: 568219651
2023-09-25 08:04:56 -07:00
Yash Katariya
8276038f63 Relax the memory alignment check between numpy array and jax array on CPU
PiperOrigin-RevId: 567722405
2023-09-22 14:49:00 -07:00
Jake VanderPlas
bfed3d862e Improve behavior of core.valid_jaxtype 2023-09-22 13:46:09 -07:00
Yash Katariya
426970591b If an input to jnp.asarray is a numpy array, then convert it to a jax.Array via device_put to avoid a copy.
Do a similar thing for jax.Array too if dtypes match.

Fixes https://github.com/google/jax/issues/17702

PiperOrigin-RevId: 567644997
2023-09-22 09:40:25 -07:00
Jake VanderPlas
0dc2252f71 Better errors for array scalar/boolean conversion 2023-09-19 09:00:19 -07:00
Parker Schuh
21389415cc Add support for float flags to compiler_options.
PiperOrigin-RevId: 565475731
2023-09-14 14:19:39 -07:00
jax authors
6b5af15eea Merge pull request #17593 from jakeh-gc:test_changes
PiperOrigin-RevId: 565428268
2023-09-14 11:30:55 -07:00
Yash Katariya
a2720ee2c3 Deprecate jax.experimental.pjit.with_sharding_constraint. Replacement is jax.lax.with_sharding_constraint which has been available since 1 year.
PiperOrigin-RevId: 565389746
2023-09-14 09:23:03 -07:00
Jake Hall
f59a4163fa Test changes for out-of-tree backend. 2023-09-14 12:18:37 +01:00
Yash Katariya
a36598b2a7 Set the jax_enable_memories flag to True.
If all memory_kinds in the jaxpr are the default memory kind, then annotate_device_placement custom calls are not inserted. This allows for existing code to work without any changes.

If non-default memory kind is present in the jaxpr, then we allow custom calls to be inserted.

PiperOrigin-RevId: 564457393
2023-09-11 11:55:09 -07:00
Jake VanderPlas
2f878a7168 Tests: set jax_legacy_prng_key='error' 2023-08-28 10:56:09 -07:00
Yash Katariya
970f4c9d4d Remove trivial execution from jax since it leads to 100x slower dispatch time.
Trivial computations were added for a pre-omnistaging world. After omnistaging, JAX produces less trivial computations, so there is need for this to exist.

In the future, if we want to support forwarding of inputs to outputs, there would need to be a different way which the C++ dispatch path knows about.

```
jit_trivial_dispatch                                   246µs ± 3%                4µs ± 1%  -98.52%          (p=0.008 n=5+5)
jit_trivial                                            250µs ± 3%                5µs ± 1%  -98.19%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 560141018
2023-08-25 10:59:48 -07:00
Matthew Johnson
78199bbc32 [custom-jvp/vjp] for symbolic zeros, ensure rules can be run more than once
Co-authored-by: Roy Frostig <frostig@google.com>
2023-08-21 15:28:43 -07:00
Peter Hawkins
2c32660a8f Replace references to DeviceArray with Array.
A number of stale references are lurking in our documentation.
2023-08-18 17:46:00 -04:00
Yash Katariya
4ddf6a9a54 Bump minimum_jaxlib_version to 0.4.14. xla_extension_version is 174 and mlir_api_version is 54
PiperOrigin-RevId: 552816893
2023-08-01 08:53:28 -07:00
Yash Katariya
e4955ecd23 Fix resolve_argnums if inspect.signature fails. In this case, donate_argnames was None leading an error in assert_no_intersection.
PiperOrigin-RevId: 552554677
2023-07-31 12:10:15 -07:00
Yash Katariya
3929a63a74 Fix the bug where args_info didn't have correct donated bit for args when donate_argnums was set on jax.jit. Fixes https://github.com/google/jax/issues/16906
PiperOrigin-RevId: 552509081
2023-07-31 09:43:50 -07:00
Jake VanderPlas
561c9531ff Lower jax.numpy.dot to mixed-precision dot_general 2023-07-21 10:10:30 -07:00
Parker Schuh
dcad04d244 Add support for int fields to compiler_options.
PiperOrigin-RevId: 549790380
2023-07-20 17:37:19 -07:00
Jake VanderPlas
65751bb328 make jvp(asarray, (1.,), (2.,)) produce Arrays
fixes #15676

Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-07-20 09:21:55 -07:00
Peter Hawkins
7df3477926 [JAX] Use MLIR argument locations instead of a bespoke jax.arg_info attribute.
514dddbeba allowed for specifying argument Locations in the MLIR Python bindings. We should use them, in the form of a Name location, rather than making up our own attribute.

Example of new output:

```
In [1]: import jax
In [2]: ir = jax.jit(lambda x, y: x + y).lower(7, 3).compiler_ir()
In [3]: ir.operation.print(enable_debug_info=True)
#loc1 = loc("x")
#loc2 = loc("y")
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.sharding = "{replicated}"} loc("x"), %arg1: tensor<i32> {mhlo.sharding = "{replicated}"} loc("y")) -> (tensor<i32> {jax.result_info = ""}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<i32> loc(#loc4)
    return %0 : tensor<i32> loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc3 = loc("<ipython-input-2-ef5a568a0c1c>":1:0)
#loc4 = loc("jit(<lambda>)/jit(main)/add"(#loc3))
```

Note debug information must be enabled.

PiperOrigin-RevId: 549325621
2023-07-19 08:39:16 -07:00
Jake VanderPlas
74159132b6 support np.array(x) where x is a custom pytree with __jax_array__ 2023-07-17 13:33:17 -07:00
Yash Katariya
89c78bf53f jax.jit now works correctly if both donate_argnums and donate_argnames are specified.
Update the docstring and changelog too to mention `donate_argnames`.

PiperOrigin-RevId: 548223395
2023-07-14 14:28:16 -07:00
jax authors
165a3db98f Merge pull request #16696 from jakevdp:custom-float-jit
PiperOrigin-RevId: 547643087
2023-07-12 17:14:27 -07:00
Yash Katariya
b337c26c72 Add donate_argnames to jax.jit. This works similarly to static_argnames.
Note that if donate_argnames is not None and donate_argnums is None, then JAX will infer donate_argnums from the names which will then we used to find the donation_vector. This is fine because currently, the same thing happens from static_argnums and static_argnames.

I'll fix the TODOs, etc in follow up CLs.

Fixes https://github.com/google/jax/issues/10539

PiperOrigin-RevId: 547612861
2023-07-12 15:09:57 -07:00
Jake VanderPlas
31c5044c1d Make jit work with custom float inputs 2023-07-12 13:06:03 -07:00
Roy Frostig
598e311191 document and test CustomVJPPrimal type as API symbol 2023-07-12 09:19:02 -07:00
Jake VanderPlas
a29d4bcd33 remove deprecation warning test in preparation for removing deprecated APIs
PiperOrigin-RevId: 547229078
2023-07-11 10:52:10 -07:00
Juliana Franco
f81a48a819 Makes it possible to lower primitives with user-defined lowering rules.
PiperOrigin-RevId: 547228102
2023-07-11 10:26:07 -07:00
jax authors
e894e4817a Remove deprecated compiler_ir from Compiled
PiperOrigin-RevId: 547211085
2023-07-11 09:24:48 -07:00
Roy Frostig
1ad0a11897 AOT: better error messages on call signature mismatch
Also update error example in AOT docs.
2023-07-10 22:10:50 -07:00
treyra
b0c309a25c Added test for vmap inconsistent sized arrays msg 2023-07-09 20:46:40 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
jax authors
f67acee129 Merge pull request #16430 from jakevdp:bool-error
PiperOrigin-RevId: 542951181
2023-06-23 14:00:12 -07:00
Jake VanderPlas
ad35702934 Drop support for numpy 1.21
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
Yash Katariya
fc0dcd15a2 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Roll forward] Update required Python version to 3.9

PiperOrigin-RevId: 542728213
2023-06-22 18:58:30 -07:00
Skye Wanderman-Milne
10424c5972 Update JAX's XlaExecutable.cost_analysis and related plumbing so it works on Cloud TPU
* Exposes LoadedExecutable.cost_analysis via pybind
* Updates XlaExecutable.cost_analysis to try
  LoadedExecutable.cost_analysis, then fallback to the client method.

PiperOrigin-RevId: 542671990
2023-06-22 14:43:00 -07:00
Jake VanderPlas
f1e603e4b3 errors: create TracerBoolConversionError for more targeted debugging tips 2023-06-21 01:41:45 -07:00
Jake VanderPlas
452a3b928b Errors: avoid printing tracer repr for concretization errors 2023-06-20 00:33:51 -07:00
Patrick Kidger
f2d64f6afb Added argument jax.linearize(..., has_aux=...) 2023-06-14 22:34:13 -07:00
jax authors
349938eb11 Merge pull request #15637 from nouiz:error
PiperOrigin-RevId: 538580301
2023-06-07 13:42:15 -07:00
Parker Schuh
47b8e55451 functools.partial should not unpack curried kwargs in custom_jvp.
PiperOrigin-RevId: 538274209
2023-06-06 13:23:13 -07:00
Yash Katariya
ae9d1498e5 Bump minimum jaxlib version to 0.4.11. xla_extension_version is 158 and mlir_api_version is 49. It will subsume https://github.com/google/jax/pull/16161#issuecomment-1564977332
PiperOrigin-RevId: 537047525
2023-06-01 09:42:55 -07:00
Yash Katariya
6a54ebd031 Fix the lu.clear_all_cache function by adding the memoized_fun to the global weakref set rather than the function local fun_caches weakrefDict.
PiperOrigin-RevId: 534971855
2023-05-24 13:58:51 -07:00
Frederic Bastien
d3216703bd Add a test 2023-05-19 12:24:18 -07:00