14790 Commits

Author SHA1 Message Date
George Necula
f147e82fa7 [shape_poly] Add support for evaluating div/mod for DimExpr
We have added the ability to represent floordiv and mod to
DimExper. Here we add support for evaluating these dimensions
for the native lowering.
2023-02-03 17:44:26 +02:00
jax authors
b8d6efe22f Merge pull request #14273 from mattjj:shard-map
PiperOrigin-RevId: 506820113
2023-02-02 23:25:39 -08:00
Matthew Johnson
ff1e9b3973 shard_map (shmap) prototype and JEP
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
2023-02-02 23:01:30 -08:00
jax authors
f18b91d928 Merge pull request #14276 from mattjj:core-type-annotation-tweaks
PiperOrigin-RevId: 506802661
2023-02-02 21:15:54 -08:00
Matthew Johnson
644d3b650f minor tweaks to type annotations, specialize code on those types
I noticed some slightly-too-general type annotations in core.py. By tightening
them we could simplify the code too. (I think these were leftovers from
pre-omnistaging...)
2023-02-02 20:24:26 -08:00
jax authors
0f289ab0e3 Merge pull request #14174 from google:pjrt_test
PiperOrigin-RevId: 506751529
2023-02-02 16:23:26 -08:00
jax authors
30c9376f67 Merge pull request #14272 from jakevdp:conditions
PiperOrigin-RevId: 506737218
2023-02-02 15:22:24 -08:00
Jake VanderPlas
c4ec2996af Sharp bits: mention alternatives to lax.cond 2023-02-02 13:19:26 -08:00
jax authors
5e5199567d Merge pull request #14269 from hawkinsp:notimpl
PiperOrigin-RevId: 506697948
2023-02-02 12:55:47 -08:00
jax authors
6d2aea2d3a Merge pull request #14270 from hawkinsp:device
PiperOrigin-RevId: 506697883
2023-02-02 12:48:09 -08:00
Peter Hawkins
b730ed4645 Remove placeholder functions for unimplemented NumPy functions.
These don't seem necessary now JAX has fairly complete coverage of the NumPy API. Also removes the accidental export of _NOT_IMPLEMENTED in several modules.
2023-02-02 13:00:18 -05:00
Peter Hawkins
74f1ab0503 Export Device as jax.Device.
Users are writing things like jax.lib.xla_client.Device in type annotations which is not a public API. Add a supported public name for the Device type.
2023-02-02 12:58:15 -05:00
Peter Hawkins
365262b77a Reapply: move jax.interpreters.ad to jax._src.interpreters.ad
Re-export roughly all of the same symbols via `jax.interpreters.ad` for now.

This version of the PR includes the names jax.interpreters.ad.source_info_util and jax.interpreters.ad.config, which the neural tangents is using.

PiperOrigin-RevId: 506642132
2023-02-02 09:29:05 -08:00
jax authors
795c14b388 Merge pull request #14252 from jakevdp:sparse-conv
PiperOrigin-RevId: 506641181
2023-02-02 09:21:26 -08:00
Peter Hawkins
04525e896e Revert: move jax.interpreters.ad to jax._src.interpreters.ad
Re-export roughly all of the same symbols via `jax.interpreters.ad` for now.

This change broke some tests.

PiperOrigin-RevId: 506606721
2023-02-02 06:52:47 -08:00
Yash Katariya
e5b2c5ea44 Remove the jit_pjit_api_merge disable for api_test now that it is passing
PiperOrigin-RevId: 506508148
2023-02-01 21:03:30 -08:00
jax authors
a79dea58eb Merge pull request #14263 from mattjj:custom-jvp-nondiff-argnums-tracers
PiperOrigin-RevId: 506506008
2023-02-01 20:49:47 -08:00
jax authors
a7964a7773 Merge pull request #14262 from froystig:issue14249
PiperOrigin-RevId: 506503680
2023-02-01 20:35:44 -08:00
Matthew Johnson
cd615b6be8 skip custom_jvp/vjp tests which dont work with initial-style staging
These tests, involving nondiff_argnums and/or closing over tracers, happen to
work with final-style JIT but not our initial-style primitives. We shouldn't
support this behavior anyway; there are good alternatives.
2023-02-01 20:34:47 -08:00
Roy Frostig
26b75ff4ae add "linear solve batching via jacrev" test from github.com/google/jax/issues/14249 2023-02-01 20:01:53 -08:00
Roy Frostig
e199b35f4e Revert "Merge pull request #14113 from botev:main"
This reverts commit 69d18cc7b58ae4ed82246605d66ed07a49fad676, reversing
changes made to 13e875f8b8d8dd9152045c7e3b5045a9bb0d7db0.

Reverting until we address https://github.com/google/jax/issues/14249
2023-02-01 19:50:27 -08:00
Roy Frostig
0e77af0a28 move jax.interpreters.ad to jax._src.interpreters.ad
Re-export roughly all of the same symbols via `jax.interpreters.ad` for now.

PiperOrigin-RevId: 506490796
2023-02-01 19:46:47 -08:00
Jake VanderPlas
038798ed25 [sparse] add support for simple 1D convolutions 2023-02-01 18:53:49 -08:00
jax authors
4d56def91f Merge pull request #14257 from jakevdp:sparse-rev
PiperOrigin-RevId: 506483272
2023-02-01 18:51:58 -08:00
Eugene Zhulenev
9d5132f1fb [jax] Skip compilation cache test for older jaxlibs
PiperOrigin-RevId: 506460144
2023-02-01 16:53:19 -08:00
jax authors
7a5a63f2ad Merge pull request #14250 from mattjj:checkify-retracing
PiperOrigin-RevId: 506458253
2023-02-01 16:44:56 -08:00
Jake VanderPlas
4fa80b44cd [sparse] implement sparse rule for lax.rev 2023-02-01 15:43:47 -08:00
Yash Katariya
782a34f5e9 Add more logging to serialization code to figure out exactly where we are during async checkpointing.
PiperOrigin-RevId: 506438425
2023-02-01 15:24:46 -08:00
jax authors
06e3d8cada Merge pull request #14251 from jakevdp:sparse-len
PiperOrigin-RevId: 506428591
2023-02-01 14:53:47 -08:00
jax authors
1a858fedd1 Merge pull request #14254 from jakevdp:cleanup-dtypes
PiperOrigin-RevId: 506428355
2023-02-01 14:46:24 -08:00
John QiangZhang
0cd3dee349 Consolidate the experimental_get_compiler_ir eager and tf function path in jax2tf.call_tf.
PiperOrigin-RevId: 506424270
2023-02-01 14:31:14 -08:00
Roy Frostig
c241ae60b1 add blank line, mainly to trigger/test source sync
PiperOrigin-RevId: 506414439
2023-02-01 13:56:29 -08:00
Peter Hawkins
c90a85403b Merge pull request #14248 from jakevdp:dead-code
PiperOrigin-RevId: 506405131
2023-02-01 21:25:46 +00:00
Jake VanderPlas
72dfb23c90 Remove jax.dtypes._jax_types 2023-02-01 12:49:06 -08:00
Jake VanderPlas
27c068e7b7 [sparse] implement __len__ on sparse objects 2023-02-01 11:46:02 -08:00
Matthew Johnson
684846bd0f checkify: cache jaxpr formation so we don't always retrace 2023-02-01 10:19:47 -08:00
Jake VanderPlas
0b5443c6e8 Clean up: remove unused helper functions 2023-02-01 09:55:58 -08:00
jax authors
fcb9dfb080 Merge pull request #14236 from rwitten:rwitten_debug_docs
PiperOrigin-RevId: 506154997
2023-01-31 16:52:16 -08:00
jax authors
b0202b6ae2 Merge pull request #14235 from jakevdp:take-doc
PiperOrigin-RevId: 506151220
2023-01-31 16:34:50 -08:00
Rafi Witten
278ff25ae1 Update docs that jax.debug is unsupported on Cloud TPUs 2023-02-01 00:12:51 +00:00
Jake VanderPlas
14a0fe08c8 DOC: improve documentation of OOB indices in jnp.take 2023-01-31 15:59:06 -08:00
jax authors
957adbd5ea Merge pull request #14234 from jakevdp:fix-doc
PiperOrigin-RevId: 506134475
2023-01-31 15:30:01 -08:00
Jake VanderPlas
179f5ab200 Fix documentation broken by 8dc1dff6 2023-01-31 15:23:31 -08:00
Peter Hawkins
8dc1dff610 Remove device_count, local_device_count, process_index exports from xla_bridge.
These were accidental exports and have public equivalents under the top-level jax namespace. The deprecation policy does not apply to names under jax.lib, which is intended to be private.

PiperOrigin-RevId: 506088434
2023-01-31 13:01:19 -08:00
jax authors
aaae5ed32d Merge pull request #14230 from jakevdp:unused-cache-info
PiperOrigin-RevId: 506077999
2023-01-31 12:19:51 -08:00
Jake VanderPlas
b679ef025f Remove unused CacheInfo namedtuple 2023-01-31 11:36:43 -08:00
Yash Katariya
518bb56c6e Add is_ready() method to PyArray
PiperOrigin-RevId: 506044282
2023-01-31 10:33:09 -08:00
jax authors
188d66208d Merge pull request #14211 from jakevdp:simplify-conv
PiperOrigin-RevId: 506034339
2023-01-31 10:01:27 -08:00
Jake VanderPlas
671c72a782 Update signature of ad.defbilinear to simplify transpose rules 2023-01-31 09:07:39 -08:00
jax authors
b374080e85 Merge pull request #14219 from eltociear:patch-4
PiperOrigin-RevId: 506017241
2023-01-31 08:56:29 -08:00