George Necula
f147e82fa7
[shape_poly] Add support for evaluating div/mod for DimExpr
...
We have added the ability to represent floordiv and mod to
DimExper. Here we add support for evaluating these dimensions
for the native lowering.
2023-02-03 17:44:26 +02:00
jax authors
b8d6efe22f
Merge pull request #14273 from mattjj:shard-map
...
PiperOrigin-RevId: 506820113
2023-02-02 23:25:39 -08:00
Matthew Johnson
ff1e9b3973
shard_map (shmap) prototype and JEP
...
Co-authored-by: Sharad Vikram <sharadmv@google.com>
Co-authored-by: Sholto Douglas <sholto@google.com>
2023-02-02 23:01:30 -08:00
jax authors
f18b91d928
Merge pull request #14276 from mattjj:core-type-annotation-tweaks
...
PiperOrigin-RevId: 506802661
2023-02-02 21:15:54 -08:00
Matthew Johnson
644d3b650f
minor tweaks to type annotations, specialize code on those types
...
I noticed some slightly-too-general type annotations in core.py. By tightening
them we could simplify the code too. (I think these were leftovers from
pre-omnistaging...)
2023-02-02 20:24:26 -08:00
jax authors
0f289ab0e3
Merge pull request #14174 from google:pjrt_test
...
PiperOrigin-RevId: 506751529
2023-02-02 16:23:26 -08:00
jax authors
30c9376f67
Merge pull request #14272 from jakevdp:conditions
...
PiperOrigin-RevId: 506737218
2023-02-02 15:22:24 -08:00
Jake VanderPlas
c4ec2996af
Sharp bits: mention alternatives to lax.cond
2023-02-02 13:19:26 -08:00
jax authors
5e5199567d
Merge pull request #14269 from hawkinsp:notimpl
...
PiperOrigin-RevId: 506697948
2023-02-02 12:55:47 -08:00
jax authors
6d2aea2d3a
Merge pull request #14270 from hawkinsp:device
...
PiperOrigin-RevId: 506697883
2023-02-02 12:48:09 -08:00
Peter Hawkins
b730ed4645
Remove placeholder functions for unimplemented NumPy functions.
...
These don't seem necessary now JAX has fairly complete coverage of the NumPy API. Also removes the accidental export of _NOT_IMPLEMENTED in several modules.
2023-02-02 13:00:18 -05:00
Peter Hawkins
74f1ab0503
Export Device as jax.Device.
...
Users are writing things like jax.lib.xla_client.Device in type annotations which is not a public API. Add a supported public name for the Device type.
2023-02-02 12:58:15 -05:00
Peter Hawkins
365262b77a
Reapply: move jax.interpreters.ad
to jax._src.interpreters.ad
...
Re-export roughly all of the same symbols via `jax.interpreters.ad` for now.
This version of the PR includes the names jax.interpreters.ad.source_info_util and jax.interpreters.ad.config, which the neural tangents is using.
PiperOrigin-RevId: 506642132
2023-02-02 09:29:05 -08:00
jax authors
795c14b388
Merge pull request #14252 from jakevdp:sparse-conv
...
PiperOrigin-RevId: 506641181
2023-02-02 09:21:26 -08:00
Peter Hawkins
04525e896e
Revert: move jax.interpreters.ad
to jax._src.interpreters.ad
...
Re-export roughly all of the same symbols via `jax.interpreters.ad` for now.
This change broke some tests.
PiperOrigin-RevId: 506606721
2023-02-02 06:52:47 -08:00
Yash Katariya
e5b2c5ea44
Remove the jit_pjit_api_merge disable for api_test now that it is passing
...
PiperOrigin-RevId: 506508148
2023-02-01 21:03:30 -08:00
jax authors
a79dea58eb
Merge pull request #14263 from mattjj:custom-jvp-nondiff-argnums-tracers
...
PiperOrigin-RevId: 506506008
2023-02-01 20:49:47 -08:00
jax authors
a7964a7773
Merge pull request #14262 from froystig:issue14249
...
PiperOrigin-RevId: 506503680
2023-02-01 20:35:44 -08:00
Matthew Johnson
cd615b6be8
skip custom_jvp/vjp tests which dont work with initial-style staging
...
These tests, involving nondiff_argnums and/or closing over tracers, happen to
work with final-style JIT but not our initial-style primitives. We shouldn't
support this behavior anyway; there are good alternatives.
2023-02-01 20:34:47 -08:00
Roy Frostig
26b75ff4ae
add "linear solve batching via jacrev
" test from github.com/google/jax/issues/14249
2023-02-01 20:01:53 -08:00
Roy Frostig
e199b35f4e
Revert "Merge pull request #14113 from botev:main"
...
This reverts commit 69d18cc7b58ae4ed82246605d66ed07a49fad676, reversing
changes made to 13e875f8b8d8dd9152045c7e3b5045a9bb0d7db0.
Reverting until we address https://github.com/google/jax/issues/14249
2023-02-01 19:50:27 -08:00
Roy Frostig
0e77af0a28
move jax.interpreters.ad
to jax._src.interpreters.ad
...
Re-export roughly all of the same symbols via `jax.interpreters.ad` for now.
PiperOrigin-RevId: 506490796
2023-02-01 19:46:47 -08:00
Jake VanderPlas
038798ed25
[sparse] add support for simple 1D convolutions
2023-02-01 18:53:49 -08:00
jax authors
4d56def91f
Merge pull request #14257 from jakevdp:sparse-rev
...
PiperOrigin-RevId: 506483272
2023-02-01 18:51:58 -08:00
Eugene Zhulenev
9d5132f1fb
[jax] Skip compilation cache test for older jaxlibs
...
PiperOrigin-RevId: 506460144
2023-02-01 16:53:19 -08:00
jax authors
7a5a63f2ad
Merge pull request #14250 from mattjj:checkify-retracing
...
PiperOrigin-RevId: 506458253
2023-02-01 16:44:56 -08:00
Jake VanderPlas
4fa80b44cd
[sparse] implement sparse rule for lax.rev
2023-02-01 15:43:47 -08:00
Yash Katariya
782a34f5e9
Add more logging to serialization code to figure out exactly where we are during async checkpointing.
...
PiperOrigin-RevId: 506438425
2023-02-01 15:24:46 -08:00
jax authors
06e3d8cada
Merge pull request #14251 from jakevdp:sparse-len
...
PiperOrigin-RevId: 506428591
2023-02-01 14:53:47 -08:00
jax authors
1a858fedd1
Merge pull request #14254 from jakevdp:cleanup-dtypes
...
PiperOrigin-RevId: 506428355
2023-02-01 14:46:24 -08:00
John QiangZhang
0cd3dee349
Consolidate the experimental_get_compiler_ir eager and tf function path in jax2tf.call_tf.
...
PiperOrigin-RevId: 506424270
2023-02-01 14:31:14 -08:00
Roy Frostig
c241ae60b1
add blank line, mainly to trigger/test source sync
...
PiperOrigin-RevId: 506414439
2023-02-01 13:56:29 -08:00
Peter Hawkins
c90a85403b
Merge pull request #14248 from jakevdp:dead-code
...
PiperOrigin-RevId: 506405131
2023-02-01 21:25:46 +00:00
Jake VanderPlas
72dfb23c90
Remove jax.dtypes._jax_types
2023-02-01 12:49:06 -08:00
Jake VanderPlas
27c068e7b7
[sparse] implement __len__ on sparse objects
2023-02-01 11:46:02 -08:00
Matthew Johnson
684846bd0f
checkify: cache jaxpr formation so we don't always retrace
2023-02-01 10:19:47 -08:00
Jake VanderPlas
0b5443c6e8
Clean up: remove unused helper functions
2023-02-01 09:55:58 -08:00
jax authors
fcb9dfb080
Merge pull request #14236 from rwitten:rwitten_debug_docs
...
PiperOrigin-RevId: 506154997
2023-01-31 16:52:16 -08:00
jax authors
b0202b6ae2
Merge pull request #14235 from jakevdp:take-doc
...
PiperOrigin-RevId: 506151220
2023-01-31 16:34:50 -08:00
Rafi Witten
278ff25ae1
Update docs that jax.debug is unsupported on Cloud TPUs
2023-02-01 00:12:51 +00:00
Jake VanderPlas
14a0fe08c8
DOC: improve documentation of OOB indices in jnp.take
2023-01-31 15:59:06 -08:00
jax authors
957adbd5ea
Merge pull request #14234 from jakevdp:fix-doc
...
PiperOrigin-RevId: 506134475
2023-01-31 15:30:01 -08:00
Jake VanderPlas
179f5ab200
Fix documentation broken by 8dc1dff6
2023-01-31 15:23:31 -08:00
Peter Hawkins
8dc1dff610
Remove device_count, local_device_count, process_index exports from xla_bridge.
...
These were accidental exports and have public equivalents under the top-level jax namespace. The deprecation policy does not apply to names under jax.lib, which is intended to be private.
PiperOrigin-RevId: 506088434
2023-01-31 13:01:19 -08:00
jax authors
aaae5ed32d
Merge pull request #14230 from jakevdp:unused-cache-info
...
PiperOrigin-RevId: 506077999
2023-01-31 12:19:51 -08:00
Jake VanderPlas
b679ef025f
Remove unused CacheInfo namedtuple
2023-01-31 11:36:43 -08:00
Yash Katariya
518bb56c6e
Add is_ready() method to PyArray
...
PiperOrigin-RevId: 506044282
2023-01-31 10:33:09 -08:00
jax authors
188d66208d
Merge pull request #14211 from jakevdp:simplify-conv
...
PiperOrigin-RevId: 506034339
2023-01-31 10:01:27 -08:00
Jake VanderPlas
671c72a782
Update signature of ad.defbilinear to simplify transpose rules
2023-01-31 09:07:39 -08:00
jax authors
b374080e85
Merge pull request #14219 from eltociear:patch-4
...
PiperOrigin-RevId: 506017241
2023-01-31 08:56:29 -08:00