17187 Commits

Author SHA1 Message Date
Peter Hawkins
975dae34a4 Deprecate jax.numpy.trapz.
Expose the current implementation of jax.numpy.trapz as jax.scipy.integrate.trapezoid instead.

Fixes https://github.com/google/jax/issues/17244
2023-08-25 09:04:13 -06:00
jax authors
a454081390 Merge pull request #17289 from google:nanobind
PiperOrigin-RevId: 560090600
2023-08-25 07:53:04 -07:00
Peter Hawkins
9a5b808853 Update nanobind version to 1.5.2. 2023-08-25 14:45:32 +00:00
Peter Hawkins
ac8ea86103 Fix accidental signature change to get_serialized_metadata() from nanobind PR.
pybind11 accepts either Python strings or bytes as a std::string argument, whereas nanobind accepts only strings. Change the argument to nb::bytes instead.

PiperOrigin-RevId: 560086072
2023-08-25 07:31:31 -07:00
jax authors
87cec1a6e8 Update XLA dependency to use revision
14c0e8565a.

PiperOrigin-RevId: 560077922
2023-08-25 06:51:19 -07:00
Yash Katariya
eea2603363 Add a proper jax config for memories so that we can iteratively develop and enable it.
PiperOrigin-RevId: 559977015
2023-08-24 22:23:55 -07:00
Roy Frostig
009932760c get aval directly via attribute in key array shard arg handler
No need to go through `core.get_aval` here.

PiperOrigin-RevId: 559945841
2023-08-24 19:47:35 -07:00
Peter Hawkins
70b7d50181 Switch jaxlib to use nanobind instead of pybind11.
nanobind has a number of advantages (https://nanobind.readthedocs.io/en/latest/why.html), notably speed of compilation and dispatch, but the main reason to do this for these bindings is because nanobind can target the Python Stable ABI starting with Python 3.12. This means that we will not need to ship per-Python version CUDA plugins starting with Python 3.12.

PiperOrigin-RevId: 559898790
2023-08-24 16:07:56 -07:00
jax authors
ec95862592 Add an example of jax.tree_util.register_static to the jax 101 pytree docs.
PiperOrigin-RevId: 559896201
2023-08-24 15:58:31 -07:00
jax authors
d452eea9b6 Merge pull request #17276 from jakevdp:finfo-tests
PiperOrigin-RevId: 559878846
2023-08-24 14:56:53 -07:00
Roy Frostig
a71c0e6ecc create jax.extend.random as a copy of jax.prng
Co-authored-by: Jake Vanderplas <jakevdp@google.com>
PiperOrigin-RevId: 559874051
2023-08-24 14:41:56 -07:00
Jake VanderPlas
71275b5f89 Remove dtypes_test.py:testFinfo
Why? The definition of finfo has been moved to the ml_dtypes package, which has
more comprehensive unit tests than these.
2023-08-24 12:13:54 -07:00
Ruoxin Sang
48921a1b31 Use self.aval.str_short() to represent array shape in the error message.
PiperOrigin-RevId: 559799200
2023-08-24 10:40:24 -07:00
Jake VanderPlas
0da3a7ffb5 jnp.einsum: lower to mixed-precision dot_general when possible.
This is a re-landing of https://github.com/google/jax/pull/16733. The downstream issues should be fixed by https://github.com/google/jax/pull/17152.

Reverts c6f40e202c7f5724b9be61afa33541a8f4abfdd0

PiperOrigin-RevId: 559794120
2023-08-24 10:31:39 -07:00
Richard Levasseur
f891cbf64b Load Python rules from rules_python
PiperOrigin-RevId: 559789250
2023-08-24 10:22:57 -07:00
Jake VanderPlas
665b176c2c remove deprecated jax.lax.prod function
PiperOrigin-RevId: 559787522
2023-08-24 10:13:59 -07:00
Ce Zheng
26643aa96e [XLA] Delete _xla_host_transfer_original_type and _xla_host_transfer_is_lower_bits.
PiperOrigin-RevId: 559786239
2023-08-24 10:05:01 -07:00
Jake VanderPlas
368d3433a6 Add random benchmarks
The purpose of this is to measure the difference in dispatch seed between raw keys and new-style typed keys. The latter does not yet hit the C++ fast path, and so we expect it to incur a small additional overhead at dispatch time. Part of #9263.

PiperOrigin-RevId: 559782442
2023-08-24 09:55:07 -07:00
jax authors
55079a3910 Merge pull request #17270 from Sai-Suraj-27:new_versions
PiperOrigin-RevId: 559782297
2023-08-24 09:45:13 -07:00
Sai-Suraj-27
b042b1e4c6 Updated mirrors-mypy and pre-commit hooks vesrion and checked all the files by running them.
Updated flake8 version in pre-commit configuration file.
2023-08-24 21:24:21 +05:30
jax authors
36cdafdcf4 Merge pull request #17271 from eltociear:eltociear-patch-1-1
PiperOrigin-RevId: 559766803
2023-08-24 08:49:58 -07:00
Ikko Eltociear Ashimine
2c9de9d9d3
Fix typo in shape_poly_test.py
overriden -> overridden
2023-08-25 00:10:13 +09:00
Christian Sigg
b4f7928a81 [NFC] Explicitly set dialects to usePropertiesForAttributes=0 in preparation for https://reviews.llvm.org/D158581 (flipping the default to 1) to land.
This allows us to switch dialects to use properties one by one.

PiperOrigin-RevId: 559751065
2023-08-24 07:52:19 -07:00
jax authors
443f74bb65 Merge pull request #17261 from hawkinsp:in1d
PiperOrigin-RevId: 559748998
2023-08-24 07:42:21 -07:00
Adam Paszke
45428c1375 [Mosaic] Add an annotation explaining signedness semantics of tpu.pack_elements
PiperOrigin-RevId: 559747952
2023-08-24 07:32:59 -07:00
jax authors
3f60993b27 Update XLA dependency to use revision
116ceaeccd.

PiperOrigin-RevId: 559729371
2023-08-24 06:03:10 -07:00
Ruoxin Sang
943bdd22b1 Enhance "Array has been deleted" error message with shape and type information.
PiperOrigin-RevId: 559673904
2023-08-24 01:06:09 -07:00
jax authors
8280bd457a Merge pull request #17262 from jakevdp:fix-bernoulli-tolerance
PiperOrigin-RevId: 559611733
2023-08-23 19:15:25 -07:00
Jake VanderPlas
75d12a2e21 Fix tolerance on bernoulli test 2023-08-23 16:59:27 -07:00
Peter Hawkins
7c871916f7 Deprecate jax.numpy.in1d.
Issue https://github.com/google/jax/issues/17244
2023-08-23 17:36:14 -06:00
jax authors
f62b9eb3d1 Merge pull request #17253 from jakevdp:bernoulli
PiperOrigin-RevId: 559568126
2023-08-23 15:56:16 -07:00
Yash Katariya
6a63bc567d Return the default memory kind for PmapSharding always. This is because PmapSharding can be given as an input to jit and that shouldn't cause any errors when we canonicalize by default.
PiperOrigin-RevId: 559544534
2023-08-23 14:31:57 -07:00
Yash Katariya
0501a15fd5 Print str_short of the arg and remove printing the value of the arg.
PiperOrigin-RevId: 559524941
2023-08-23 13:31:35 -07:00
jax authors
f19e748303 Merge pull request #17016 from mattjj:royroyroy
PiperOrigin-RevId: 559524338
2023-08-23 13:22:37 -07:00
Jake VanderPlas
042111eb08 Add jax.scipy.special.bernoulli 2023-08-23 12:58:37 -07:00
jax authors
d1547ca45b Ensure that CompileOptions serializes deterministically.
CompileOptions has two serialization mechanisms: Py pickle and
SerializeAsString. Neither mechanism serializes deterministically.
Deterministic serialization (also called idempotent serialization
or in-order serialization) ensures that a given structure
serializes to the same string repeatedly. Both these mechanisms
serialize by first generating the proto and then serializing it.
There are three points to note:

. Deterministic serialization will yield the same result
  even if proto map fields are in a different order. Thus
  map({"1": 1, "2": 2}) and map({"2": 2, "1": 1}) will
  serialize the same.

. Deterministic serialization does not yield the same
  result for repeated fields that are out of order. Thus,
  for message Foo { repeated string s = 1; },
  Foo{s: "1", s: "2"} will not result in the same
  serialization as Foo{s: "2", s: "1"}.

. Deterministic serialization applies only in the context
  of a given binary. It does not apply across releases.

Testing: the original serialization code with the new unit
test fails as expected while the revised code does not.
PiperOrigin-RevId: 559492626
2023-08-23 11:34:21 -07:00
Sharad Vikram
71e867392e [Pallas] Add support for memory space annotations in Mosaic
PiperOrigin-RevId: 559483745
2023-08-23 11:07:22 -07:00
jax authors
86f5de855b Update XLA dependency to use revision
a52ced2433.

PiperOrigin-RevId: 559407111
2023-08-23 06:36:17 -07:00
Yash Katariya
aeb62cc006 Add TransferToMemoryKind as a private API to allow device_put to transfer to different memories without specifying the sharding and allowing the SPMD partitioner to choose the sharding for the intermediate.
Exposing it as a public API can be done later.

PiperOrigin-RevId: 559314369
2023-08-22 22:11:38 -07:00
Sharad Vikram
bad217b2f8 [Pallas] Add support for casting int8->fp* in Mosaic lowering
PiperOrigin-RevId: 559313313
2023-08-22 22:02:25 -07:00
Sharad Vikram
f08df0f0c3 [Pallas] Add Mosaic lowering for pow2
PiperOrigin-RevId: 559291137
2023-08-22 19:42:17 -07:00
Parker Schuh
e58ddb7258 Add _manual_axes support to NamedSharding. This is needed because
custom_partitioning may produce manually sharded axes.

PiperOrigin-RevId: 559288864
2023-08-22 19:24:29 -07:00
jax authors
517e0a93ca Merge pull request #17234 from google:pjrt_tpu
PiperOrigin-RevId: 559272340
2023-08-22 17:42:35 -07:00
jax authors
af42359433 Merge pull request #16419 from mattjj:pow-jvp
PiperOrigin-RevId: 559266945
2023-08-22 17:15:04 -07:00
Skye Wanderman-Milne
9d1cbc7d21 Default to PJRT TPU runtime instead of StreamExecutor on older jaxlibs.
I messed up the forwards compat in
3e50fea29e. The
next jaxlib release won't need the env var at all, but jaxlib 0.4.14
and older still do.
2023-08-23 00:06:16 +00:00
Matthew Johnson
1f8fb2c8bd change lowering rule to satisfy jax2tf 2023-08-22 16:48:11 -07:00
Peter Hawkins
abff9d2898 Remove jax.numpy.alltrue from type stub.
This function is already deprecated.

PiperOrigin-RevId: 559257301
2023-08-22 16:33:38 -07:00
George Necula
26f091e446 [callback] Disable stream_executor tests.
PiperOrigin-RevId: 559252832
2023-08-22 16:15:00 -07:00
jax authors
bf29c5e5f1 Merge pull request #17225 from jakevdp:prng-flag
PiperOrigin-RevId: 559243191
2023-08-22 15:41:10 -07:00
jax authors
c052bbfa68 Merge pull request #17206 from jakevdp:stats-sf
PiperOrigin-RevId: 559241134
2023-08-22 15:32:23 -07:00