249 Commits

Author SHA1 Message Date
Roy Frostig
f18bff5371 inline and remove scatter_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
cc54b6e6ad inline and remove select_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
301d058b3d inline and remove gather_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
071c77e5bb inline and remove transpose_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
06132ac764 inline and remove broadcast_in_dim_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
0ac792f4ed inline and remove dynamic_update_slice_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
2dbdf1a6c1 inline and remove dynamic_slice_mlir rules 2023-05-17 20:07:59 -07:00
Roy Frostig
aed77c5031 inline and remove slice_mlir rules 2023-05-17 20:07:58 -07:00
Roy Frostig
129a4a5f35 inline and remove empty_mlir rules 2023-05-17 20:07:58 -07:00
Roy Frostig
180e26dafb remove physical_avals rule in favor of physical_element_aval 2023-05-17 20:07:58 -07:00
Jake VanderPlas
48abe7c684 PRNGKeyArray: add several missing attributes & methods 2023-05-17 14:47:22 -07:00
Jake VanderPlas
6ef4e5f01a Custom PRNG: make KeyArray compatible with custom_jvp 2023-05-17 10:31:09 -07:00
Jake VanderPlas
0e483223c6 Custom PRNG: support lax.full() and related constructors 2023-05-17 09:04:50 -07:00
Peter Hawkins
eaf7eb2626 Break cycle between _src/core.py and _src/dtypes.py.
PiperOrigin-RevId: 532788430
2023-05-17 07:58:59 -07:00
Jake VanderPlas
b9aa236dac Custom PRNG: support PRNGKeyArray.copy() 2023-05-12 15:50:22 -07:00
Jake VanderPlas
6ada8785aa PRNGKeyArray: fix dynamic slice index dtype 2023-05-10 09:24:18 -07:00
jax authors
68ba54241c Merge pull request #15929 from gnecula:fix_mlir_ir
PiperOrigin-RevId: 530675418
2023-05-09 12:02:35 -07:00
jax authors
cf4c1edafa Merge pull request #15920 from froystig:issue15869
PiperOrigin-RevId: 530634021
2023-05-09 09:39:48 -07:00
George Necula
daf6a30f6e Import "ir" directly rather than as "mlir.ir" 2023-05-09 17:55:13 +02:00
George Necula
de2a811fe9 [shape_poly] Improvements and more testing for shape polymorphism for random primitives
* added support for shape polymorphism for partitionable threefry and for
    random_split.
  * removed footgun that was ignoring the partitionable flag in presence of
    shape polymorphism.
  * Replicated the PRNG tests for threefry (partitionable and non-partitionable),
    and unsafe_rbg.
  * Added general support for overriding jax.config flags for PolyHarness

This fixes the known bug with random_gamma.
The known missing feature is shape polymorphism for RngBitGenerator.
https://github.com/openxla/stablehlo/issues/1344
2023-05-09 13:55:27 +02:00
Roy Frostig
051c5dda6e delegate select lowering to opaque dtype rule
... and implement it for PRNG key arrays
2023-05-08 19:02:42 -07:00
jax authors
99e7e8ee17 Merge pull request #15874 from jakevdp:keyarray-make-array
PiperOrigin-RevId: 529550502
2023-05-04 16:52:29 -07:00
Jake VanderPlas
4db717c52a KeyArray: support make_array_from_* APIs 2023-05-04 16:32:49 -07:00
Jake VanderPlas
b031cc2660 jax2tf: better handling for opaque dtypes 2023-05-04 14:22:15 -07:00
George Necula
6dfd248e74 [shape_poly] Add support for shape polymorphism for prng GPU custom call
We are using the new support for dynamic shapes for hlo.CustomCallOp, where
we need to pass the output shapes as additional operands.

This allows us to enable multiple "random" tests that were previously disabled.

PiperOrigin-RevId: 528990469
2023-05-02 22:26:58 -07:00
Jake VanderPlas
979aa3235b KeyArray: implement sharded & replicated device_put 2023-05-01 14:17:01 -07:00
Jake VanderPlas
054fca5cd4 KeyArray: define itemsize on opaque dtype 2023-04-27 15:59:57 -07:00
Jake VanderPlas
50405b1081 KeyArray: add size attribute 2023-04-27 14:06:55 -07:00
Yash Katariya
86c1f5bcee Preserve the sharding type of physical sharding on logical sharding when .sharding is accessed on a PRNGKeyArray
PiperOrigin-RevId: 527639257
2023-04-27 11:41:00 -07:00
Dinghua Li
7d6fb535a9 [shape_poly] Add support for shape polymorphism for _unsafe_rbg_split.
PiperOrigin-RevId: 527619524
2023-04-27 10:36:36 -07:00
Jake VanderPlas
fcffbac346 KeyArray: implement __eq__ and __ne__ 2023-04-26 13:12:24 -07:00
Jake VanderPlas
a47a71ff80 KeyArray: better errors for operators 2023-04-26 11:34:07 -07:00
Jake VanderPlas
6e84ed2992 PRNGKeyArray: implement scatter/gather via .at() 2023-04-25 15:55:08 -07:00
Jake VanderPlas
6374a77176 KeyArray: remove _stackable registration mechanism 2023-04-24 15:06:22 -07:00
Jake VanderPlas
a5737f82af custom prng: remove stackable override for jnp.concatenate 2023-04-24 12:26:58 -07:00
Jake VanderPlas
e50138608a PRNGKeyArrayImpl: add aval property
This makes it more readily compatible with jax.numpy routines.
2023-04-24 08:59:14 -07:00
Jake VanderPlas
72bb8ab753 jax.Array: dynamically define abstract methods 2023-04-18 13:08:32 -07:00
Roy Frostig
bf55dc947d jit the threefry seed function
The Threefry PRNG's seeding function involves operations with small
constants, such as `lax.shift_right_logical(seed, 32)`. This causes to
host-to-device transfers of small scalars (e.g. `32`) every time that
one seeds outside of a `jit`. To avoid these transfers, and any
inflexibility under JAX's transfer guard, we `jit` the seeding
function.

This shifts costs around a bit. Whereas previously we were moving
scalars to device on every (eager) seed call, we are now tracing and
compiling the seed function. The latter will only happen once per
input shape.
2023-04-15 10:38:46 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Roy Frostig
cd5e2380d8 make PRNGKeyArray abstract, separate from implementation
We expose the `PRNGKeyArray` symbol publicly, at least for use in
annotations (especially by libraries). Separating interface from
implementation helps ensure no instantiations. Also, should anyone try
to inherit from the public type, they will not pick up all of the
magic behavior of the implementing class (e.g. presence in pytype-aval
mappings).

This reflects what we do with `jax.Array` as well.

Makes a few other annotation fixups in `jax._src.prng` along the way.
2023-04-07 17:47:03 -07:00
Peter Hawkins
b4402185db Move PartitionSpec into its own file (jax/_src/partition_spec.py).
No functional changes intended.

A subsequent change will move ParsedPartitionSpec and array mapping utilities here also.

PiperOrigin-RevId: 522393166
2023-04-06 11:43:25 -07:00
Peter Hawkins
dfe95dcb4e Split ShardingSpecs and most of the helpers for constructing them into a separate file (jax/_src/sharding_specs.py).
PiperOrigin-RevId: 522360232
2023-04-06 09:48:51 -07:00
Yash Katariya
728a5ed96a [shard-map] fix eager shmap+prngs, revise phys aval/sharding logic
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-04-05 23:04:41 -07:00
Peter Hawkins
c1f65fc8b2 Avoid imports from the public jax.* namespace in more places internally.
This change is in preparation for more cycle breaking in the Bazel dependency graph.

PiperOrigin-RevId: 521822756
2023-04-04 11:41:40 -07:00
Peter Hawkins
abf1acf76c Replace references to jax.interpreters with jax._src.interpreters in JAX core.
PiperOrigin-RevId: 520933067
2023-03-31 08:58:00 -07:00
Jake VanderPlas
87aec2433b internal: refactor array methods into separate private submodule 2023-03-23 10:57:53 -07:00
Etienne Pot
4cb32ba46f Fix isinstance(k, PRNGKeyArray) on PRNGKeyArray subclasses
PiperOrigin-RevId: 518803946
2023-03-23 02:32:06 -07:00
Yash Katariya
58fed7001a Remove pxla.OutputType enum class now that the only output can be jax.Array
PiperOrigin-RevId: 517985356
2023-03-20 09:09:58 -07:00
Yash Katariya
c2d5527f72 [Jax cleanup]
* Remove lower_xla_callable and all related functions
* Remove pxla.device_put
* Remove dispatch.device_put_handlers

PiperOrigin-RevId: 517249345
2023-03-16 15:47:28 -07:00
Peter Hawkins
dea7450e4e Remove references to jax.config.jax_array, which is always True at head.
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00