1010 Commits

Author SHA1 Message Date
Peter Hawkins
926e42e025 [JAX] Delete ShardedDeviceArray.
Replace it with a temporary shim that is Any to type checkers and an uninstantiatable class at runtime.

PiperOrigin-RevId: 518074394
2023-03-20 14:24:09 -07:00
Anish Tondwalkar
143dfcd74b Eigh primitive is now a customcall
PiperOrigin-RevId: 518074163
2023-03-20 14:17:29 -07:00
Anish Tondwalkar
bf416a8b5c geqrf_p and householder_product_p directly call custom_calls
This replaces the xla_fallback path, which just used the Client HLO API to
generate custom_calls.

PiperOrigin-RevId: 518060025
2023-03-20 13:29:29 -07:00
George Necula
82b7c03d39 [jax2tf] Minor improvement in an error message 2023-03-18 11:00:46 +02:00
Blake Hechtman
1412eca9ea [LAX:RBG] Allow any type to RngBitGenerator. BF16 values are heavily quantized for long distributions which leads to failing the distribution test but in reality the distributions match.
PiperOrigin-RevId: 517586411
2023-03-17 22:39:43 -07:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Peter Hawkins
01dcd3a3fc Relax argument type annotation for lax.dynamic_slice.
PiperOrigin-RevId: 516881433
2023-03-15 11:28:22 -07:00
Peter Hawkins
b6c1cd904c Relax the argument type annotation of dynamic_index_in_dim.
dynamic_index_in_dim accepts concrete scalars also.

PiperOrigin-RevId: 516537760
2023-03-14 08:55:01 -07:00
Yash Katariya
136749d955 Bump minimum jaxlib version to 0.4.6 which means xla_extension_version == 137 and mlir_api_version == 45
PiperOrigin-RevId: 516364523
2023-03-13 17:09:41 -07:00
Jake VanderPlas
760deb310e Remove leading underscores in jax._src.numpy.util 2023-03-13 12:18:36 -07:00
Peter Hawkins
1925aa1109 Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.

PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
jax authors
c27f79e41d Merge pull request #14849 from jakeh-gc:fix_implicit_rank_bcast
PiperOrigin-RevId: 515186872
2023-03-08 17:28:57 -08:00
Jake Hall
e0c2185c1d Fix implicit rank promotion. 2023-03-08 22:53:39 +00:00
Jake VanderPlas
c8c269f5f5 internal: avoid unused imports in lax_numpy 2023-03-08 10:29:04 -08:00
jax authors
cc694c66ce Merge pull request #14798 from nicholasjng:custom-linear-solve-batching-fix
PiperOrigin-RevId: 514873672
2023-03-07 16:39:18 -08:00
Matthew Johnson
b05975b964 add result info to mhlo, fixes #14780
incidentally fixes #14787
2023-03-06 21:21:26 -08:00
Peter Hawkins
eb286315fa Fix a TODO updating users of lax._check_user_dtype_supported to dtypes.check_user_dtype_supported.
PiperOrigin-RevId: 514435374
2023-03-06 09:33:01 -08:00
Nicholas Junge
89039ecc9c Fix custom_linear_solve batching rule in case of auxiliary arguments
Previously, batching in-/out axes of the wrong lengths were passed into
the batched jaxpr builders for the `matvec` and `solve_t` jaxprs. This
commit is a best-effort fix from debugging the axes designations in
the batched jaxpr constructions of these functions.
2023-03-06 17:03:48 +01:00
George Necula
c51537f827 [shape_poly] Add support for jnp.cum{sum,prod,max,min} with shape polymorphism
Unfortunately, on CPU and GPU where we use associative scan, we cannot
support shape polymorphism with native lowering.
2023-03-03 10:22:45 +01:00
Jake VanderPlas
853f65fd99 DOC: clarify variable names in scan doc
Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-03-02 10:25:54 -08:00
Anish Tondwalkar
3bad6fa223 [CHLO] Add erf_inv and lowering to mhlo
PiperOrigin-RevId: 513183138
2023-03-01 02:52:52 -08:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
Peter Hawkins
148774587a Remove circular dependency between source_info_util and util.
Move util.new_name_stack into source_info_util. Replace uses of util.extend_name_stack with stack.extend().

PiperOrigin-RevId: 512685810
2023-02-27 11:41:46 -08:00
Jake VanderPlas
4918b9d1d0 DOC: improve lax.dot_general documentation 2023-02-27 09:46:04 -08:00
jax authors
8ebfb0be48 Merge pull request #14614 from sharadmv:ref
PiperOrigin-RevId: 512315462
2023-02-25 11:12:00 -08:00
Yash Katariya
d84ac2240c Remove use_stablehlo as minimum mlir_api_version >= 43
PiperOrigin-RevId: 512176274
2023-02-24 15:20:09 -08:00
Sharad Vikram
4960e656af Refactor Ref abstract type to contain other AbstractValues 2023-02-23 17:02:40 -08:00
Sharad Vikram
a6c4c87f3e Add JaxprInputEffect and refactor StateEffects to use it 2023-02-21 16:30:06 -08:00
Sharad Vikram
af2306c0a8 Refactor effects system to use effect types, not objects 2023-02-17 17:40:08 -08:00
jax authors
51182258bb Merge pull request #14529 from jakevdp:lax-bitcast-validation
PiperOrigin-RevId: 510410676
2023-02-17 06:01:15 -08:00
Peter Hawkins
54269c1145 Remove more exported names from jax.interpreters.xla.
None of these appear to have public users, and this module is not included in the deprecation policy.

Also:
* shorten a number of alias chains.
* move make_op_metadata() into its only caller in jax2tf
* delete the unused function dtype_to_primitive_type.
PiperOrigin-RevId: 510205315
2023-02-16 11:56:30 -08:00
Jake VanderPlas
11acec03c3 lax.bitcast_convert_type: better input validation 2023-02-16 10:56:06 -08:00
Jake VanderPlas
b18cbbe101 lax.bitcast_convert_type: support casting between types of different width 2023-02-16 08:21:18 -08:00
Peter Hawkins
cd0533cab0 Replace uses of jnp.ndarray with jax.Array inside JAX.
PiperOrigin-RevId: 509939691
2023-02-15 14:53:00 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
George Necula
582c042079 Implement lowering for convolutions with dynamic padding
PiperOrigin-RevId: 509451627
2023-02-14 00:55:45 -08:00
jax authors
9e01ee4d50 Merge pull request #14457 from mattjj:djax-bug-fix
PiperOrigin-RevId: 509377741
2023-02-13 17:28:37 -08:00
Matthew Johnson
96c558d5de fix minor broadcasting bug
Co-authored-by: Adam Paszke <apaszke@google.com>
2023-02-13 15:13:13 -08:00
Yash Katariya
d0eedf7e57 Plumb spmd_axis_name through batch_jaxpr2 and batch_jaxpr
PiperOrigin-RevId: 509341618
2023-02-13 14:58:20 -08:00
Roy Frostig
1c84e4a753 migrate internal dependencies from jax.interpreters.batching to jax._src.interpreters.batching
... in preparation for paring down `jax.interpreters.batching`'s exported symbols.

PiperOrigin-RevId: 508487887
2023-02-09 15:11:57 -08:00
Matthew Johnson
a964dc3b9a simpler pretty-print for pjit, tweak custom pp rule signature 2023-02-09 12:45:51 -08:00
Peter Hawkins
cc8d7fae32 Move jax.interpreters.mlir to jax._src.interpreters.mlir.
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
Jake VanderPlas
3c6183498a lax.top_k: improve documentation and errors on invalid values 2023-02-08 11:07:56 -08:00
Peter Hawkins
98b75cf27b Prune accidental exports from jax.interpreters.pxla.
These imports do not appear to have users outside JAX itself.

PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
Roy Frostig
219723c738 migrate internal dependencies from jax.interpreters.ad to jax._src.interpreters.ad
... in preparation for paring down `jax.interpreters.ad`'s exported symbols.

Includes some import fixups along the way.

PiperOrigin-RevId: 507684262
2023-02-06 22:52:36 -08:00
Yash Katariya
8a69444ff9 Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
Peter Hawkins
def35b7e24 Remove scatter/gather dimension proto helpers.
These are unused since the MHLO switch.

PiperOrigin-RevId: 506969590
2023-02-03 12:40:31 -08:00
Jake VanderPlas
0b5443c6e8 Clean up: remove unused helper functions 2023-02-01 09:55:58 -08:00
Jake VanderPlas
671c72a782 Update signature of ad.defbilinear to simplify transpose rules 2023-01-31 09:07:39 -08:00