Roy Frostig
f18bff5371
inline and remove scatter_mlir
rules
2023-05-17 20:07:59 -07:00
Roy Frostig
cc54b6e6ad
inline and remove select_mlir
rules
2023-05-17 20:07:59 -07:00
Roy Frostig
301d058b3d
inline and remove gather_mlir
rules
2023-05-17 20:07:59 -07:00
Roy Frostig
071c77e5bb
inline and remove transpose_mlir
rules
2023-05-17 20:07:59 -07:00
Roy Frostig
06132ac764
inline and remove broadcast_in_dim_mlir
rules
2023-05-17 20:07:59 -07:00
Roy Frostig
0ac792f4ed
inline and remove dynamic_update_slice_mlir
rules
2023-05-17 20:07:59 -07:00
Roy Frostig
2dbdf1a6c1
inline and remove dynamic_slice_mlir
rules
2023-05-17 20:07:59 -07:00
Roy Frostig
aed77c5031
inline and remove slice_mlir
rules
2023-05-17 20:07:58 -07:00
Roy Frostig
129a4a5f35
inline and remove empty_mlir
rules
2023-05-17 20:07:58 -07:00
Roy Frostig
180e26dafb
remove physical_avals
rule in favor of physical_element_aval
2023-05-17 20:07:58 -07:00
Jake VanderPlas
48abe7c684
PRNGKeyArray: add several missing attributes & methods
2023-05-17 14:47:22 -07:00
Jake VanderPlas
6ef4e5f01a
Custom PRNG: make KeyArray compatible with custom_jvp
2023-05-17 10:31:09 -07:00
Jake VanderPlas
0e483223c6
Custom PRNG: support lax.full() and related constructors
2023-05-17 09:04:50 -07:00
Peter Hawkins
eaf7eb2626
Break cycle between _src/core.py and _src/dtypes.py.
...
PiperOrigin-RevId: 532788430
2023-05-17 07:58:59 -07:00
Jake VanderPlas
b9aa236dac
Custom PRNG: support PRNGKeyArray.copy()
2023-05-12 15:50:22 -07:00
Jake VanderPlas
6ada8785aa
PRNGKeyArray: fix dynamic slice index dtype
2023-05-10 09:24:18 -07:00
jax authors
68ba54241c
Merge pull request #15929 from gnecula:fix_mlir_ir
...
PiperOrigin-RevId: 530675418
2023-05-09 12:02:35 -07:00
jax authors
cf4c1edafa
Merge pull request #15920 from froystig:issue15869
...
PiperOrigin-RevId: 530634021
2023-05-09 09:39:48 -07:00
George Necula
daf6a30f6e
Import "ir" directly rather than as "mlir.ir"
2023-05-09 17:55:13 +02:00
George Necula
de2a811fe9
[shape_poly] Improvements and more testing for shape polymorphism for random primitives
...
* added support for shape polymorphism for partitionable threefry and for
random_split.
* removed footgun that was ignoring the partitionable flag in presence of
shape polymorphism.
* Replicated the PRNG tests for threefry (partitionable and non-partitionable),
and unsafe_rbg.
* Added general support for overriding jax.config flags for PolyHarness
This fixes the known bug with random_gamma.
The known missing feature is shape polymorphism for RngBitGenerator.
https://github.com/openxla/stablehlo/issues/1344
2023-05-09 13:55:27 +02:00
Roy Frostig
051c5dda6e
delegate select
lowering to opaque dtype rule
...
... and implement it for PRNG key arrays
2023-05-08 19:02:42 -07:00
jax authors
99e7e8ee17
Merge pull request #15874 from jakevdp:keyarray-make-array
...
PiperOrigin-RevId: 529550502
2023-05-04 16:52:29 -07:00
Jake VanderPlas
4db717c52a
KeyArray: support make_array_from_* APIs
2023-05-04 16:32:49 -07:00
Jake VanderPlas
b031cc2660
jax2tf: better handling for opaque dtypes
2023-05-04 14:22:15 -07:00
George Necula
6dfd248e74
[shape_poly] Add support for shape polymorphism for prng GPU custom call
...
We are using the new support for dynamic shapes for hlo.CustomCallOp, where
we need to pass the output shapes as additional operands.
This allows us to enable multiple "random" tests that were previously disabled.
PiperOrigin-RevId: 528990469
2023-05-02 22:26:58 -07:00
Jake VanderPlas
979aa3235b
KeyArray: implement sharded & replicated device_put
2023-05-01 14:17:01 -07:00
Jake VanderPlas
054fca5cd4
KeyArray: define itemsize on opaque dtype
2023-04-27 15:59:57 -07:00
Jake VanderPlas
50405b1081
KeyArray: add size attribute
2023-04-27 14:06:55 -07:00
Yash Katariya
86c1f5bcee
Preserve the sharding type of physical sharding on logical sharding when .sharding
is accessed on a PRNGKeyArray
...
PiperOrigin-RevId: 527639257
2023-04-27 11:41:00 -07:00
Dinghua Li
7d6fb535a9
[shape_poly] Add support for shape polymorphism for _unsafe_rbg_split.
...
PiperOrigin-RevId: 527619524
2023-04-27 10:36:36 -07:00
Jake VanderPlas
fcffbac346
KeyArray: implement __eq__ and __ne__
2023-04-26 13:12:24 -07:00
Jake VanderPlas
a47a71ff80
KeyArray: better errors for operators
2023-04-26 11:34:07 -07:00
Jake VanderPlas
6e84ed2992
PRNGKeyArray: implement scatter/gather via .at()
2023-04-25 15:55:08 -07:00
Jake VanderPlas
6374a77176
KeyArray: remove _stackable registration mechanism
2023-04-24 15:06:22 -07:00
Jake VanderPlas
a5737f82af
custom prng: remove stackable override for jnp.concatenate
2023-04-24 12:26:58 -07:00
Jake VanderPlas
e50138608a
PRNGKeyArrayImpl: add aval property
...
This makes it more readily compatible with jax.numpy routines.
2023-04-24 08:59:14 -07:00
Jake VanderPlas
72bb8ab753
jax.Array: dynamically define abstract methods
2023-04-18 13:08:32 -07:00
Roy Frostig
bf55dc947d
jit
the threefry seed function
...
The Threefry PRNG's seeding function involves operations with small
constants, such as `lax.shift_right_logical(seed, 32)`. This causes to
host-to-device transfers of small scalars (e.g. `32`) every time that
one seeds outside of a `jit`. To avoid these transfers, and any
inflexibility under JAX's transfer guard, we `jit` the seeding
function.
This shifts costs around a bit. Whereas previously we were moving
scalars to device on every (eager) seed call, we are now tracing and
compiling the seed function. The latter will only happen once per
input shape.
2023-04-15 10:38:46 -07:00
Jake VanderPlas
5521423d92
Change np.prod->math.prod
...
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Roy Frostig
cd5e2380d8
make PRNGKeyArray
abstract, separate from implementation
...
We expose the `PRNGKeyArray` symbol publicly, at least for use in
annotations (especially by libraries). Separating interface from
implementation helps ensure no instantiations. Also, should anyone try
to inherit from the public type, they will not pick up all of the
magic behavior of the implementing class (e.g. presence in pytype-aval
mappings).
This reflects what we do with `jax.Array` as well.
Makes a few other annotation fixups in `jax._src.prng` along the way.
2023-04-07 17:47:03 -07:00
Peter Hawkins
b4402185db
Move PartitionSpec into its own file (jax/_src/partition_spec.py).
...
No functional changes intended.
A subsequent change will move ParsedPartitionSpec and array mapping utilities here also.
PiperOrigin-RevId: 522393166
2023-04-06 11:43:25 -07:00
Peter Hawkins
dfe95dcb4e
Split ShardingSpecs and most of the helpers for constructing them into a separate file (jax/_src/sharding_specs.py).
...
PiperOrigin-RevId: 522360232
2023-04-06 09:48:51 -07:00
Yash Katariya
728a5ed96a
[shard-map] fix eager shmap+prngs, revise phys aval/sharding logic
...
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-04-05 23:04:41 -07:00
Peter Hawkins
c1f65fc8b2
Avoid imports from the public jax.* namespace in more places internally.
...
This change is in preparation for more cycle breaking in the Bazel dependency graph.
PiperOrigin-RevId: 521822756
2023-04-04 11:41:40 -07:00
Peter Hawkins
abf1acf76c
Replace references to jax.interpreters with jax._src.interpreters in JAX core.
...
PiperOrigin-RevId: 520933067
2023-03-31 08:58:00 -07:00
Jake VanderPlas
87aec2433b
internal: refactor array methods into separate private submodule
2023-03-23 10:57:53 -07:00
Etienne Pot
4cb32ba46f
Fix isinstance(k, PRNGKeyArray) on PRNGKeyArray subclasses
...
PiperOrigin-RevId: 518803946
2023-03-23 02:32:06 -07:00
Yash Katariya
58fed7001a
Remove pxla.OutputType enum class now that the only output can be jax.Array
...
PiperOrigin-RevId: 517985356
2023-03-20 09:09:58 -07:00
Yash Katariya
c2d5527f72
[Jax cleanup]
...
* Remove lower_xla_callable and all related functions
* Remove pxla.device_put
* Remove dispatch.device_put_handlers
PiperOrigin-RevId: 517249345
2023-03-16 15:47:28 -07:00
Peter Hawkins
dea7450e4e
Remove references to jax.config.jax_array, which is always True at head.
...
PiperOrigin-RevId: 516970232
2023-03-15 17:09:11 -07:00