978 Commits

Author SHA1 Message Date
Jake Hall
f59a4163fa Test changes for out-of-tree backend. 2023-09-14 12:18:37 +01:00
Yash Katariya
a36598b2a7 Set the jax_enable_memories flag to True.
If all memory_kinds in the jaxpr are the default memory kind, then annotate_device_placement custom calls are not inserted. This allows for existing code to work without any changes.

If non-default memory kind is present in the jaxpr, then we allow custom calls to be inserted.

PiperOrigin-RevId: 564457393
2023-09-11 11:55:09 -07:00
Jake VanderPlas
2f878a7168 Tests: set jax_legacy_prng_key='error' 2023-08-28 10:56:09 -07:00
Yash Katariya
970f4c9d4d Remove trivial execution from jax since it leads to 100x slower dispatch time.
Trivial computations were added for a pre-omnistaging world. After omnistaging, JAX produces less trivial computations, so there is need for this to exist.

In the future, if we want to support forwarding of inputs to outputs, there would need to be a different way which the C++ dispatch path knows about.

```
jit_trivial_dispatch                                   246µs ± 3%                4µs ± 1%  -98.52%          (p=0.008 n=5+5)
jit_trivial                                            250µs ± 3%                5µs ± 1%  -98.19%          (p=0.008 n=5+5)
```

PiperOrigin-RevId: 560141018
2023-08-25 10:59:48 -07:00
Matthew Johnson
78199bbc32 [custom-jvp/vjp] for symbolic zeros, ensure rules can be run more than once
Co-authored-by: Roy Frostig <frostig@google.com>
2023-08-21 15:28:43 -07:00
Peter Hawkins
2c32660a8f Replace references to DeviceArray with Array.
A number of stale references are lurking in our documentation.
2023-08-18 17:46:00 -04:00
Yash Katariya
4ddf6a9a54 Bump minimum_jaxlib_version to 0.4.14. xla_extension_version is 174 and mlir_api_version is 54
PiperOrigin-RevId: 552816893
2023-08-01 08:53:28 -07:00
Yash Katariya
e4955ecd23 Fix resolve_argnums if inspect.signature fails. In this case, donate_argnames was None leading an error in assert_no_intersection.
PiperOrigin-RevId: 552554677
2023-07-31 12:10:15 -07:00
Yash Katariya
3929a63a74 Fix the bug where args_info didn't have correct donated bit for args when donate_argnums was set on jax.jit. Fixes https://github.com/google/jax/issues/16906
PiperOrigin-RevId: 552509081
2023-07-31 09:43:50 -07:00
Jake VanderPlas
561c9531ff Lower jax.numpy.dot to mixed-precision dot_general 2023-07-21 10:10:30 -07:00
Parker Schuh
dcad04d244 Add support for int fields to compiler_options.
PiperOrigin-RevId: 549790380
2023-07-20 17:37:19 -07:00
Jake VanderPlas
65751bb328 make jvp(asarray, (1.,), (2.,)) produce Arrays
fixes #15676

Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-07-20 09:21:55 -07:00
Peter Hawkins
7df3477926 [JAX] Use MLIR argument locations instead of a bespoke jax.arg_info attribute.
514dddbeba allowed for specifying argument Locations in the MLIR Python bindings. We should use them, in the form of a Name location, rather than making up our own attribute.

Example of new output:

```
In [1]: import jax
In [2]: ir = jax.jit(lambda x, y: x + y).lower(7, 3).compiler_ir()
In [3]: ir.operation.print(enable_debug_info=True)
#loc1 = loc("x")
#loc2 = loc("y")
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.sharding = "{replicated}"} loc("x"), %arg1: tensor<i32> {mhlo.sharding = "{replicated}"} loc("y")) -> (tensor<i32> {jax.result_info = ""}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<i32> loc(#loc4)
    return %0 : tensor<i32> loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc3 = loc("<ipython-input-2-ef5a568a0c1c>":1:0)
#loc4 = loc("jit(<lambda>)/jit(main)/add"(#loc3))
```

Note debug information must be enabled.

PiperOrigin-RevId: 549325621
2023-07-19 08:39:16 -07:00
Jake VanderPlas
74159132b6 support np.array(x) where x is a custom pytree with __jax_array__ 2023-07-17 13:33:17 -07:00
Yash Katariya
89c78bf53f jax.jit now works correctly if both donate_argnums and donate_argnames are specified.
Update the docstring and changelog too to mention `donate_argnames`.

PiperOrigin-RevId: 548223395
2023-07-14 14:28:16 -07:00
jax authors
165a3db98f Merge pull request #16696 from jakevdp:custom-float-jit
PiperOrigin-RevId: 547643087
2023-07-12 17:14:27 -07:00
Yash Katariya
b337c26c72 Add donate_argnames to jax.jit. This works similarly to static_argnames.
Note that if donate_argnames is not None and donate_argnums is None, then JAX will infer donate_argnums from the names which will then we used to find the donation_vector. This is fine because currently, the same thing happens from static_argnums and static_argnames.

I'll fix the TODOs, etc in follow up CLs.

Fixes https://github.com/google/jax/issues/10539

PiperOrigin-RevId: 547612861
2023-07-12 15:09:57 -07:00
Jake VanderPlas
31c5044c1d Make jit work with custom float inputs 2023-07-12 13:06:03 -07:00
Roy Frostig
598e311191 document and test CustomVJPPrimal type as API symbol 2023-07-12 09:19:02 -07:00
Jake VanderPlas
a29d4bcd33 remove deprecation warning test in preparation for removing deprecated APIs
PiperOrigin-RevId: 547229078
2023-07-11 10:52:10 -07:00
Juliana Franco
f81a48a819 Makes it possible to lower primitives with user-defined lowering rules.
PiperOrigin-RevId: 547228102
2023-07-11 10:26:07 -07:00
jax authors
e894e4817a Remove deprecated compiler_ir from Compiled
PiperOrigin-RevId: 547211085
2023-07-11 09:24:48 -07:00
Roy Frostig
1ad0a11897 AOT: better error messages on call signature mismatch
Also update error example in AOT docs.
2023-07-10 22:10:50 -07:00
treyra
b0c309a25c Added test for vmap inconsistent sized arrays msg 2023-07-09 20:46:40 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
jax authors
f67acee129 Merge pull request #16430 from jakevdp:bool-error
PiperOrigin-RevId: 542951181
2023-06-23 14:00:12 -07:00
Jake VanderPlas
ad35702934 Drop support for numpy 1.21
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
Yash Katariya
fc0dcd15a2 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Roll forward] Update required Python version to 3.9

PiperOrigin-RevId: 542728213
2023-06-22 18:58:30 -07:00
Skye Wanderman-Milne
10424c5972 Update JAX's XlaExecutable.cost_analysis and related plumbing so it works on Cloud TPU
* Exposes LoadedExecutable.cost_analysis via pybind
* Updates XlaExecutable.cost_analysis to try
  LoadedExecutable.cost_analysis, then fallback to the client method.

PiperOrigin-RevId: 542671990
2023-06-22 14:43:00 -07:00
Jake VanderPlas
f1e603e4b3 errors: create TracerBoolConversionError for more targeted debugging tips 2023-06-21 01:41:45 -07:00
Jake VanderPlas
452a3b928b Errors: avoid printing tracer repr for concretization errors 2023-06-20 00:33:51 -07:00
Patrick Kidger
f2d64f6afb Added argument jax.linearize(..., has_aux=...) 2023-06-14 22:34:13 -07:00
jax authors
349938eb11 Merge pull request #15637 from nouiz:error
PiperOrigin-RevId: 538580301
2023-06-07 13:42:15 -07:00
Parker Schuh
47b8e55451 functools.partial should not unpack curried kwargs in custom_jvp.
PiperOrigin-RevId: 538274209
2023-06-06 13:23:13 -07:00
Yash Katariya
ae9d1498e5 Bump minimum jaxlib version to 0.4.11. xla_extension_version is 158 and mlir_api_version is 49. It will subsume https://github.com/google/jax/pull/16161#issuecomment-1564977332
PiperOrigin-RevId: 537047525
2023-06-01 09:42:55 -07:00
Yash Katariya
6a54ebd031 Fix the lu.clear_all_cache function by adding the memoized_fun to the global weakref set rather than the function local fun_caches weakrefDict.
PiperOrigin-RevId: 534971855
2023-05-24 13:58:51 -07:00
Frederic Bastien
d3216703bd Add a test 2023-05-19 12:24:18 -07:00
Patrick Kidger
edbf7a5a37 linear_transpose now performs DCE before transposition 2023-05-13 18:20:37 -07:00
Yash Katariya
6506ee2a40 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Rollback] Update required Python version to 3.9

PiperOrigin-RevId: 528905991
2023-05-02 15:33:29 -07:00
Jake VanderPlas
57af5360a1 Update required Python version to 3.9 2023-05-01 10:00:57 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Yash Katariya
53e6382f4a Add arg_names to aval mismatch error raised during AOT compilation to raise better error messages
PiperOrigin-RevId: 525561905
2023-04-19 15:08:53 -07:00
Yash Katariya
3e93833ed8 Remove in_parts, out_parts from jax.xla_computation since they were only used for sharded_jit and sharded_jit is long gone
Also remove instantiate_const_outputs since that is unused

PiperOrigin-RevId: 524113088
2023-04-13 15:05:21 -07:00
Peter Hawkins
2e524411db Add unregistered mhlo.num_replicas and mhlo.num_partitions attributes to HLO output.
These are to allow PJRT plugin developers an inline way to determine the number of replicas/partitions to which the module is targeted. There are no stability guarantees on these attributes at the moment.

PiperOrigin-RevId: 524013922
2023-04-13 08:55:44 -07:00
Matthew Johnson
03e72e3b77 Copybara import of the project:
--
75a7e7a07d58e14de73190d060414fd3a1ba3d52 by Matthew Johnson <mattjj@google.com>:

Handle jaxpr-round-tripping of custom jvp rules w/ sym zero

fixes #14833

Co-authored-by: Roy Frostig <frostig@google.com>
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/15426 from mattjj:custom-jvp-symbolic-zeros-3 75a7e7a07d58e14de73190d060414fd3a1ba3d52
PiperOrigin-RevId: 523817551
2023-04-12 15:11:55 -07:00
Matthew Johnson
c6fa6c1557 check that overflow is raised on builtin int overflow
Co-authored-by: Jake Vanderplas <jakevdp@google.com>
2023-04-11 15:42:31 -07:00
Tom Hennigan
d4ccab5a78 In test_concurrent_device_get_and_put make sure to init on main thread.
PiperOrigin-RevId: 523367633
2023-04-11 04:53:26 -07:00
jax authors
c42aae9fd7 Merge pull request #15221 from froystig:custom-vjp-symbolic-zeros2
PiperOrigin-RevId: 522823918
2023-04-08 09:49:45 -07:00
Matthew Johnson
ccb58783da sick eq store, and test
Co-authored-by: Roy Frostig <frostig@google.com>
2023-04-07 18:56:50 -07:00
Matthew Johnson
26562a4382 [JAX] Add jax.clear_caches, plumb a way to clear pmap caches
fixes #10828

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 522654093
2023-04-07 12:19:00 -07:00