Peter Hawkins
975dae34a4
Deprecate jax.numpy.trapz.
...
Expose the current implementation of jax.numpy.trapz as jax.scipy.integrate.trapezoid instead.
Fixes https://github.com/google/jax/issues/17244
2023-08-25 09:04:13 -06:00
jax authors
a454081390
Merge pull request #17289 from google:nanobind
...
PiperOrigin-RevId: 560090600
2023-08-25 07:53:04 -07:00
Peter Hawkins
9a5b808853
Update nanobind version to 1.5.2.
2023-08-25 14:45:32 +00:00
Peter Hawkins
ac8ea86103
Fix accidental signature change to get_serialized_metadata() from nanobind PR.
...
pybind11 accepts either Python strings or bytes as a std::string argument, whereas nanobind accepts only strings. Change the argument to nb::bytes instead.
PiperOrigin-RevId: 560086072
2023-08-25 07:31:31 -07:00
jax authors
87cec1a6e8
Update XLA dependency to use revision
...
14c0e8565a
.
PiperOrigin-RevId: 560077922
2023-08-25 06:51:19 -07:00
Yash Katariya
eea2603363
Add a proper jax config for memories so that we can iteratively develop and enable it.
...
PiperOrigin-RevId: 559977015
2023-08-24 22:23:55 -07:00
Roy Frostig
009932760c
get aval directly via attribute in key array shard arg handler
...
No need to go through `core.get_aval` here.
PiperOrigin-RevId: 559945841
2023-08-24 19:47:35 -07:00
Peter Hawkins
70b7d50181
Switch jaxlib to use nanobind instead of pybind11.
...
nanobind has a number of advantages (https://nanobind.readthedocs.io/en/latest/why.html ), notably speed of compilation and dispatch, but the main reason to do this for these bindings is because nanobind can target the Python Stable ABI starting with Python 3.12. This means that we will not need to ship per-Python version CUDA plugins starting with Python 3.12.
PiperOrigin-RevId: 559898790
2023-08-24 16:07:56 -07:00
jax authors
ec95862592
Add an example of jax.tree_util.register_static to the jax 101 pytree docs.
...
PiperOrigin-RevId: 559896201
2023-08-24 15:58:31 -07:00
jax authors
d452eea9b6
Merge pull request #17276 from jakevdp:finfo-tests
...
PiperOrigin-RevId: 559878846
2023-08-24 14:56:53 -07:00
Roy Frostig
a71c0e6ecc
create jax.extend.random
as a copy of jax.prng
...
Co-authored-by: Jake Vanderplas <jakevdp@google.com>
PiperOrigin-RevId: 559874051
2023-08-24 14:41:56 -07:00
Jake VanderPlas
71275b5f89
Remove dtypes_test.py:testFinfo
...
Why? The definition of finfo has been moved to the ml_dtypes package, which has
more comprehensive unit tests than these.
2023-08-24 12:13:54 -07:00
Ruoxin Sang
48921a1b31
Use self.aval.str_short()
to represent array shape in the error message.
...
PiperOrigin-RevId: 559799200
2023-08-24 10:40:24 -07:00
Jake VanderPlas
0da3a7ffb5
jnp.einsum: lower to mixed-precision dot_general when possible.
...
This is a re-landing of https://github.com/google/jax/pull/16733 . The downstream issues should be fixed by https://github.com/google/jax/pull/17152 .
Reverts c6f40e202c7f5724b9be61afa33541a8f4abfdd0
PiperOrigin-RevId: 559794120
2023-08-24 10:31:39 -07:00
Richard Levasseur
f891cbf64b
Load Python rules from rules_python
...
PiperOrigin-RevId: 559789250
2023-08-24 10:22:57 -07:00
Jake VanderPlas
665b176c2c
remove deprecated jax.lax.prod function
...
PiperOrigin-RevId: 559787522
2023-08-24 10:13:59 -07:00
Ce Zheng
26643aa96e
[XLA] Delete _xla_host_transfer_original_type
and _xla_host_transfer_is_lower_bits
.
...
PiperOrigin-RevId: 559786239
2023-08-24 10:05:01 -07:00
Jake VanderPlas
368d3433a6
Add random benchmarks
...
The purpose of this is to measure the difference in dispatch seed between raw keys and new-style typed keys. The latter does not yet hit the C++ fast path, and so we expect it to incur a small additional overhead at dispatch time. Part of #9263 .
PiperOrigin-RevId: 559782442
2023-08-24 09:55:07 -07:00
jax authors
55079a3910
Merge pull request #17270 from Sai-Suraj-27:new_versions
...
PiperOrigin-RevId: 559782297
2023-08-24 09:45:13 -07:00
Sai-Suraj-27
b042b1e4c6
Updated mirrors-mypy and pre-commit hooks vesrion and checked all the files by running them.
...
Updated flake8 version in pre-commit configuration file.
2023-08-24 21:24:21 +05:30
jax authors
36cdafdcf4
Merge pull request #17271 from eltociear:eltociear-patch-1-1
...
PiperOrigin-RevId: 559766803
2023-08-24 08:49:58 -07:00
Ikko Eltociear Ashimine
2c9de9d9d3
Fix typo in shape_poly_test.py
...
overriden -> overridden
2023-08-25 00:10:13 +09:00
Christian Sigg
b4f7928a81
[NFC] Explicitly set dialects to usePropertiesForAttributes=0
in preparation for https://reviews.llvm.org/D158581 (flipping the default to 1
) to land.
...
This allows us to switch dialects to use properties one by one.
PiperOrigin-RevId: 559751065
2023-08-24 07:52:19 -07:00
jax authors
443f74bb65
Merge pull request #17261 from hawkinsp:in1d
...
PiperOrigin-RevId: 559748998
2023-08-24 07:42:21 -07:00
Adam Paszke
45428c1375
[Mosaic] Add an annotation explaining signedness semantics of tpu.pack_elements
...
PiperOrigin-RevId: 559747952
2023-08-24 07:32:59 -07:00
jax authors
3f60993b27
Update XLA dependency to use revision
...
116ceaeccd
.
PiperOrigin-RevId: 559729371
2023-08-24 06:03:10 -07:00
Ruoxin Sang
943bdd22b1
Enhance "Array has been deleted" error message with shape and type information.
...
PiperOrigin-RevId: 559673904
2023-08-24 01:06:09 -07:00
jax authors
8280bd457a
Merge pull request #17262 from jakevdp:fix-bernoulli-tolerance
...
PiperOrigin-RevId: 559611733
2023-08-23 19:15:25 -07:00
Jake VanderPlas
75d12a2e21
Fix tolerance on bernoulli test
2023-08-23 16:59:27 -07:00
Peter Hawkins
7c871916f7
Deprecate jax.numpy.in1d.
...
Issue https://github.com/google/jax/issues/17244
2023-08-23 17:36:14 -06:00
jax authors
f62b9eb3d1
Merge pull request #17253 from jakevdp:bernoulli
...
PiperOrigin-RevId: 559568126
2023-08-23 15:56:16 -07:00
Yash Katariya
6a63bc567d
Return the default memory kind for PmapSharding always. This is because PmapSharding can be given as an input to jit and that shouldn't cause any errors when we canonicalize by default.
...
PiperOrigin-RevId: 559544534
2023-08-23 14:31:57 -07:00
Yash Katariya
0501a15fd5
Print str_short
of the arg and remove printing the value of the arg.
...
PiperOrigin-RevId: 559524941
2023-08-23 13:31:35 -07:00
jax authors
f19e748303
Merge pull request #17016 from mattjj:royroyroy
...
PiperOrigin-RevId: 559524338
2023-08-23 13:22:37 -07:00
Jake VanderPlas
042111eb08
Add jax.scipy.special.bernoulli
2023-08-23 12:58:37 -07:00
jax authors
d1547ca45b
Ensure that CompileOptions serializes deterministically.
...
CompileOptions has two serialization mechanisms: Py pickle and
SerializeAsString. Neither mechanism serializes deterministically.
Deterministic serialization (also called idempotent serialization
or in-order serialization) ensures that a given structure
serializes to the same string repeatedly. Both these mechanisms
serialize by first generating the proto and then serializing it.
There are three points to note:
. Deterministic serialization will yield the same result
even if proto map fields are in a different order. Thus
map({"1": 1, "2": 2}) and map({"2": 2, "1": 1}) will
serialize the same.
. Deterministic serialization does not yield the same
result for repeated fields that are out of order. Thus,
for message Foo { repeated string s = 1; },
Foo{s: "1", s: "2"} will not result in the same
serialization as Foo{s: "2", s: "1"}.
. Deterministic serialization applies only in the context
of a given binary. It does not apply across releases.
Testing: the original serialization code with the new unit
test fails as expected while the revised code does not.
PiperOrigin-RevId: 559492626
2023-08-23 11:34:21 -07:00
Sharad Vikram
71e867392e
[Pallas] Add support for memory space annotations in Mosaic
...
PiperOrigin-RevId: 559483745
2023-08-23 11:07:22 -07:00
jax authors
86f5de855b
Update XLA dependency to use revision
...
a52ced2433
.
PiperOrigin-RevId: 559407111
2023-08-23 06:36:17 -07:00
Yash Katariya
aeb62cc006
Add TransferToMemoryKind
as a private API to allow device_put to transfer to different memories without specifying the sharding and allowing the SPMD partitioner to choose the sharding for the intermediate.
...
Exposing it as a public API can be done later.
PiperOrigin-RevId: 559314369
2023-08-22 22:11:38 -07:00
Sharad Vikram
bad217b2f8
[Pallas] Add support for casting int8->fp* in Mosaic lowering
...
PiperOrigin-RevId: 559313313
2023-08-22 22:02:25 -07:00
Sharad Vikram
f08df0f0c3
[Pallas] Add Mosaic lowering for pow2
...
PiperOrigin-RevId: 559291137
2023-08-22 19:42:17 -07:00
Parker Schuh
e58ddb7258
Add _manual_axes support to NamedSharding. This is needed because
...
custom_partitioning may produce manually sharded axes.
PiperOrigin-RevId: 559288864
2023-08-22 19:24:29 -07:00
jax authors
517e0a93ca
Merge pull request #17234 from google:pjrt_tpu
...
PiperOrigin-RevId: 559272340
2023-08-22 17:42:35 -07:00
jax authors
af42359433
Merge pull request #16419 from mattjj:pow-jvp
...
PiperOrigin-RevId: 559266945
2023-08-22 17:15:04 -07:00
Skye Wanderman-Milne
9d1cbc7d21
Default to PJRT TPU runtime instead of StreamExecutor on older jaxlibs.
...
I messed up the forwards compat in
3e50fea29e
. The
next jaxlib release won't need the env var at all, but jaxlib 0.4.14
and older still do.
2023-08-23 00:06:16 +00:00
Matthew Johnson
1f8fb2c8bd
change lowering rule to satisfy jax2tf
2023-08-22 16:48:11 -07:00
Peter Hawkins
abff9d2898
Remove jax.numpy.alltrue from type stub.
...
This function is already deprecated.
PiperOrigin-RevId: 559257301
2023-08-22 16:33:38 -07:00
George Necula
26f091e446
[callback] Disable stream_executor tests.
...
PiperOrigin-RevId: 559252832
2023-08-22 16:15:00 -07:00
jax authors
bf29c5e5f1
Merge pull request #17225 from jakevdp:prng-flag
...
PiperOrigin-RevId: 559243191
2023-08-22 15:41:10 -07:00
jax authors
c052bbfa68
Merge pull request #17206 from jakevdp:stats-sf
...
PiperOrigin-RevId: 559241134
2023-08-22 15:32:23 -07:00