Peter Hawkins
926e42e025
[JAX] Delete ShardedDeviceArray.
...
Replace it with a temporary shim that is Any to type checkers and an uninstantiatable class at runtime.
PiperOrigin-RevId: 518074394
2023-03-20 14:24:09 -07:00
Anish Tondwalkar
143dfcd74b
Eigh primitive is now a customcall
...
PiperOrigin-RevId: 518074163
2023-03-20 14:17:29 -07:00
Anish Tondwalkar
bf416a8b5c
geqrf_p and householder_product_p directly call custom_calls
...
This replaces the xla_fallback path, which just used the Client HLO API to
generate custom_calls.
PiperOrigin-RevId: 518060025
2023-03-20 13:29:29 -07:00
George Necula
82b7c03d39
[jax2tf] Minor improvement in an error message
2023-03-18 11:00:46 +02:00
Blake Hechtman
1412eca9ea
[LAX:RBG] Allow any type to RngBitGenerator. BF16 values are heavily quantized for long distributions which leads to failing the distribution test but in reality the distributions match.
...
PiperOrigin-RevId: 517586411
2023-03-17 22:39:43 -07:00
Peter Hawkins
dea7450e4e
Remove references to jax.config.jax_array, which is always True at head.
...
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00
Peter Hawkins
01dcd3a3fc
Relax argument type annotation for lax.dynamic_slice.
...
PiperOrigin-RevId: 516881433
2023-03-15 11:28:22 -07:00
Peter Hawkins
b6c1cd904c
Relax the argument type annotation of dynamic_index_in_dim.
...
dynamic_index_in_dim accepts concrete scalars also.
PiperOrigin-RevId: 516537760
2023-03-14 08:55:01 -07:00
Yash Katariya
136749d955
Bump minimum jaxlib version to 0.4.6 which means xla_extension_version == 137 and mlir_api_version == 45
...
PiperOrigin-RevId: 516364523
2023-03-13 17:09:41 -07:00
Jake VanderPlas
760deb310e
Remove leading underscores in jax._src.numpy.util
2023-03-13 12:18:36 -07:00
Peter Hawkins
1925aa1109
Split Sharding subclasses out of _src/sharding.py into _src/sharding_impls.py
...
By defining the Sharding base class in its own module, we can pull it out into a separate Bazel submodule, which will help pytype inference when defining Array.
PiperOrigin-RevId: 516223009
2023-03-13 08:50:18 -07:00
jax authors
c27f79e41d
Merge pull request #14849 from jakeh-gc:fix_implicit_rank_bcast
...
PiperOrigin-RevId: 515186872
2023-03-08 17:28:57 -08:00
Jake Hall
e0c2185c1d
Fix implicit rank promotion.
2023-03-08 22:53:39 +00:00
Jake VanderPlas
c8c269f5f5
internal: avoid unused imports in lax_numpy
2023-03-08 10:29:04 -08:00
jax authors
cc694c66ce
Merge pull request #14798 from nicholasjng:custom-linear-solve-batching-fix
...
PiperOrigin-RevId: 514873672
2023-03-07 16:39:18 -08:00
Matthew Johnson
b05975b964
add result info to mhlo, fixes #14780
...
incidentally fixes #14787
2023-03-06 21:21:26 -08:00
Peter Hawkins
eb286315fa
Fix a TODO updating users of lax._check_user_dtype_supported to dtypes.check_user_dtype_supported.
...
PiperOrigin-RevId: 514435374
2023-03-06 09:33:01 -08:00
Nicholas Junge
89039ecc9c
Fix custom_linear_solve
batching rule in case of auxiliary arguments
...
Previously, batching in-/out axes of the wrong lengths were passed into
the batched jaxpr builders for the `matvec` and `solve_t` jaxprs. This
commit is a best-effort fix from debugging the axes designations in
the batched jaxpr constructions of these functions.
2023-03-06 17:03:48 +01:00
George Necula
c51537f827
[shape_poly] Add support for jnp.cum{sum,prod,max,min} with shape polymorphism
...
Unfortunately, on CPU and GPU where we use associative scan, we cannot
support shape polymorphism with native lowering.
2023-03-03 10:22:45 +01:00
Jake VanderPlas
853f65fd99
DOC: clarify variable names in scan doc
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-03-02 10:25:54 -08:00
Anish Tondwalkar
3bad6fa223
[CHLO] Add erf_inv and lowering to mhlo
...
PiperOrigin-RevId: 513183138
2023-03-01 02:52:52 -08:00
Peter Hawkins
8fb1fd318d
Replace jax._src.util.prod with math.prod.
...
math.prod() was added in Python 3.8, so we can assume it is always present.
PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Peter Hawkins
f66f6ec98a
[JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
...
Limit jax._src.lib to shims around jaxlib and nothing else.
The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.
PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
Peter Hawkins
148774587a
Remove circular dependency between source_info_util and util.
...
Move util.new_name_stack into source_info_util. Replace uses of util.extend_name_stack with stack.extend().
PiperOrigin-RevId: 512685810
2023-02-27 11:41:46 -08:00
Jake VanderPlas
4918b9d1d0
DOC: improve lax.dot_general documentation
2023-02-27 09:46:04 -08:00
jax authors
8ebfb0be48
Merge pull request #14614 from sharadmv:ref
...
PiperOrigin-RevId: 512315462
2023-02-25 11:12:00 -08:00
Yash Katariya
d84ac2240c
Remove use_stablehlo as minimum mlir_api_version >= 43
...
PiperOrigin-RevId: 512176274
2023-02-24 15:20:09 -08:00
Sharad Vikram
4960e656af
Refactor Ref
abstract type to contain other AbstractValue
s
2023-02-23 17:02:40 -08:00
Sharad Vikram
a6c4c87f3e
Add JaxprInputEffect
and refactor StateEffect
s to use it
2023-02-21 16:30:06 -08:00
Sharad Vikram
af2306c0a8
Refactor effects system to use effect types, not objects
2023-02-17 17:40:08 -08:00
jax authors
51182258bb
Merge pull request #14529 from jakevdp:lax-bitcast-validation
...
PiperOrigin-RevId: 510410676
2023-02-17 06:01:15 -08:00
Peter Hawkins
54269c1145
Remove more exported names from jax.interpreters.xla.
...
None of these appear to have public users, and this module is not included in the deprecation policy.
Also:
* shorten a number of alias chains.
* move make_op_metadata() into its only caller in jax2tf
* delete the unused function dtype_to_primitive_type.
PiperOrigin-RevId: 510205315
2023-02-16 11:56:30 -08:00
Jake VanderPlas
11acec03c3
lax.bitcast_convert_type: better input validation
2023-02-16 10:56:06 -08:00
Jake VanderPlas
b18cbbe101
lax.bitcast_convert_type: support casting between types of different width
2023-02-16 08:21:18 -08:00
Peter Hawkins
cd0533cab0
Replace uses of jnp.ndarray with jax.Array inside JAX.
...
PiperOrigin-RevId: 509939691
2023-02-15 14:53:00 -08:00
Roy Frostig
cb8dcce2fe
migrate more internal dependencies from jax.core
to jax._src.core
...
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
George Necula
582c042079
Implement lowering for convolutions with dynamic padding
...
PiperOrigin-RevId: 509451627
2023-02-14 00:55:45 -08:00
jax authors
9e01ee4d50
Merge pull request #14457 from mattjj:djax-bug-fix
...
PiperOrigin-RevId: 509377741
2023-02-13 17:28:37 -08:00
Matthew Johnson
96c558d5de
fix minor broadcasting bug
...
Co-authored-by: Adam Paszke <apaszke@google.com>
2023-02-13 15:13:13 -08:00
Yash Katariya
d0eedf7e57
Plumb spmd_axis_name through batch_jaxpr2 and batch_jaxpr
...
PiperOrigin-RevId: 509341618
2023-02-13 14:58:20 -08:00
Roy Frostig
1c84e4a753
migrate internal dependencies from jax.interpreters.batching
to jax._src.interpreters.batching
...
... in preparation for paring down `jax.interpreters.batching`'s exported symbols.
PiperOrigin-RevId: 508487887
2023-02-09 15:11:57 -08:00
Matthew Johnson
a964dc3b9a
simpler pretty-print for pjit, tweak custom pp rule signature
2023-02-09 12:45:51 -08:00
Peter Hawkins
cc8d7fae32
Move jax.interpreters.mlir to jax._src.interpreters.mlir.
...
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.
PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
Jake VanderPlas
3c6183498a
lax.top_k: improve documentation and errors on invalid values
2023-02-08 11:07:56 -08:00
Peter Hawkins
98b75cf27b
Prune accidental exports from jax.interpreters.pxla.
...
These imports do not appear to have users outside JAX itself.
PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
Roy Frostig
219723c738
migrate internal dependencies from jax.interpreters.ad
to jax._src.interpreters.ad
...
... in preparation for paring down `jax.interpreters.ad`'s exported symbols.
Includes some import fixups along the way.
PiperOrigin-RevId: 507684262
2023-02-06 22:52:36 -08:00
Yash Katariya
8a69444ff9
Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
...
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
Peter Hawkins
def35b7e24
Remove scatter/gather dimension proto helpers.
...
These are unused since the MHLO switch.
PiperOrigin-RevId: 506969590
2023-02-03 12:40:31 -08:00
Jake VanderPlas
0b5443c6e8
Clean up: remove unused helper functions
2023-02-01 09:55:58 -08:00
Jake VanderPlas
671c72a782
Update signature of ad.defbilinear to simplify transpose rules
2023-01-31 09:07:39 -08:00