Jieying Luo
c7f60fa6eb
[PJRT C API] Implement framework side change for registering a custom call.
...
- Add a py extension to call the custom call C API.
- Change the implementation of register_custom_call_target to store handlers for the custom call targets and delays the registration until the handler for a xla platform is registered.
- Change register_plugin to load PJRT plugin when register_pluin is called (instead of when a client is created), and let it return the PJRT_Api* loaded.
- Delay calling discover_pjrt_plugins() and register_pjrt_plugin_factories_from_env() until the first time backends() is called.
PiperOrigin-RevId: 568265745
2023-09-25 10:52:29 -07:00
Peter Hawkins
625d2df735
Reverts d3f5e7f7956204ccccf4474423e2f189420e0f8e
...
PiperOrigin-RevId: 568249649
2023-09-25 09:59:54 -07:00
Peter Hawkins
69da839358
Remove test code that checks for the se_tpu runtime.
...
This runtime no longer exists.
PiperOrigin-RevId: 568242078
2023-09-25 09:30:07 -07:00
Jake VanderPlas
70e0098a87
[random] add itemsize property to custom PRNG
2023-09-25 08:52:26 -07:00
Peter Hawkins
5aaa15df84
Remove the skip_on_xla_cpu_mlir decorator.
...
We no longer test this variant in CI, so we don't need code to skip it.
PiperOrigin-RevId: 568219651
2023-09-25 08:04:56 -07:00
Peter Hawkins
d3f5e7f795
Remove code that skips array PRNG tests on CUDA.
...
https://github.com/google/jax/pull/13037 added this skip, but I have no idea why. The test seems to pass on GPU.
PiperOrigin-RevId: 568216977
2023-09-25 07:49:05 -07:00
jax authors
3a9289eac4
Merge pull request #17757 from gnecula:exp_dim_var
...
PiperOrigin-RevId: 568213625
2023-09-25 07:33:48 -07:00
George Necula
4d18595792
[export] Add mlir attributes to the platform index and dimension variables
...
In presence of multi-platform lowering or shape polymorphism we tag
the platform index and the dimension variable arguments with
MLIR attributions jax.platform_index and jax.dimension_variable.
So far nobody uses these attributes, hence this change should not
change any behavior. This is in preparation for implementing
exporting of effects.
2023-09-25 15:42:36 +02:00
Peter Hawkins
2fd6df45e4
Fix test failures under SciPy 1.11 for scipy.stats.mode.
2023-09-23 20:15:51 +00:00
jax authors
1466c3d2cb
Merge pull request #17746 from jakevdp:fix-typos
...
PiperOrigin-RevId: 567746141
2023-09-22 16:37:48 -07:00
Berkin Ilbeyi
c9b5996f59
[XLA] Initialize tuple shapes of async-done in dataflow analysis.
...
PiperOrigin-RevId: 567724401
2023-09-22 14:59:31 -07:00
Jake VanderPlas
4a5bd9e046
Fix typos across the package
2023-09-22 14:54:31 -07:00
Yash Katariya
8276038f63
Relax the memory alignment check between numpy array and jax array on CPU
...
PiperOrigin-RevId: 567722405
2023-09-22 14:49:00 -07:00
jax authors
6dcda12140
Merge pull request #17739 from jakevdp:valid-jaxtype
...
PiperOrigin-RevId: 567713101
2023-09-22 14:21:21 -07:00
jax authors
389c5a5cf5
Merge pull request #17740 from jakevdp:issubdtype-fix
...
PiperOrigin-RevId: 567708670
2023-09-22 14:01:01 -07:00
Jake VanderPlas
bfed3d862e
Improve behavior of core.valid_jaxtype
2023-09-22 13:46:09 -07:00
jax authors
681358fe02
Merge pull request #17425 from jon-chuang:jon-chuang/fix-deprecated-pallas-test
...
PiperOrigin-RevId: 567702337
2023-09-22 13:28:13 -07:00
Jake VanderPlas
8125e8bd03
issubdtype: fix corner cases with extended dtypes
2023-09-22 11:37:31 -07:00
Yash Katariya
426970591b
If an input to jnp.asarray
is a numpy array, then convert it to a jax.Array via device_put to avoid a copy.
...
Do a similar thing for jax.Array too if dtypes match.
Fixes https://github.com/google/jax/issues/17702
PiperOrigin-RevId: 567644997
2023-09-22 09:40:25 -07:00
Hyeontaek Lim
51589bbe70
Clarify that PjRtClient and PjRtDevice memory_spaces are not in particular order
...
PiperOrigin-RevId: 567630629
2023-09-22 08:37:07 -07:00
Yash Katariya
03877a9218
If a pmap out is replicated i.e. with out_axes=None
make jnp.copy's impl go via apply_primitive which will put it on a single device.
...
If we don't do that, then it hits an error mentioned in https://github.com/google/jax/issues/17690 .
Fixes https://github.com/google/jax/issues/17690
PiperOrigin-RevId: 567628026
2023-09-22 08:24:57 -07:00
George Necula
5b8f91fed7
[jax2tf] Fix higher-order differentiation.
...
We must ensure that we call jax2tf.convert recursively to ensure
that the proper tf.custom_gradient is used. This means that we can
reuse the conversion of the VJP function between native and graph
serialization.
2023-09-22 07:53:45 +02:00
Hyeontaek Lim
f0bde75dd3
[JAX] Export shard_map_test for testing on additional JAX backends
...
PiperOrigin-RevId: 567522898
2023-09-21 22:52:36 -07:00
Jon Chuang
1ecbf0e196
fix deprecated shuffle in pallas test
2023-09-21 23:06:08 -04:00
Jake VanderPlas
243a6a236c
dtypes.issubdtype: validate a when b is dtypes.extended
2023-09-21 15:53:05 -07:00
jax authors
0ce97c2f46
Merge pull request #17719 from jakevdp:dep-rng-key
...
PiperOrigin-RevId: 567436801
2023-09-21 15:30:18 -07:00
jax authors
256612bb80
Merge pull request #17720 from superbobry:tuple-list-comp
...
PiperOrigin-RevId: 567433086
2023-09-21 15:16:12 -07:00
Sergei Lebedev
df7f6a06c0
MAINT Use a generator expression in tuple([... for ... in ...])
...
In a few cases I also replaced tuple([*xs, *ys]) with (*xs, ys), because
tuple literals support unpacking as well.
2023-09-21 22:25:38 +01:00
Sharad Vikram
afb6691885
Disable msan/tsan for xmap_tests thanks to timeouts
...
PiperOrigin-RevId: 567412260
2023-09-21 14:24:08 -07:00
Jake VanderPlas
22818d664f
[random] deprecate named key creation functions
2023-09-21 13:57:49 -07:00
jax authors
8295007818
Merge pull request #17710 from jakevdp:issubdtype-generic
...
PiperOrigin-RevId: 567406904
2023-09-21 13:53:58 -07:00
Junwhan Ahn
6a551a1efa
Add memories_test.py to the list of exported tests
...
PiperOrigin-RevId: 567375604
2023-09-21 11:57:09 -07:00
Jake VanderPlas
4edb74ba7b
Fix some numpy 2.0 incompatibilities
2023-09-21 10:24:52 -07:00
Jake VanderPlas
4c48452652
dtypes: make issubdtype(dt, np.generic) work for custom types
2023-09-21 09:46:39 -07:00
Adrian Kuegel
d4965cd335
[XLA:GPU] Clean up Target util.
...
We have some differences between Triton codegen and other fusion codegen,
namely for Remainder/Fmod and Cbrt. Unify that.
- Remove two unused math functions.
- Add mapping from kRemainder to kFmod.
- Use kCbrt device function in elemental_ir_emitter.
PiperOrigin-RevId: 567274915
2023-09-21 05:12:06 -07:00
jax authors
68eddd16f3
Update the logic of PjRtArray::Reshard
after PjRtBuffer::CopyToMemorySpace
was introduced. Users should use PjRtBuffer::CopyToMemorySpace
instead of PjRtBuffer::CopyToDevice
when memories are supported, since the semantics of the latter one is to always copy to the default memory space of the device.
...
PiperOrigin-RevId: 567154400
2023-09-20 19:39:01 -07:00
Jake VanderPlas
95a209f28b
Tests: fix some failures for upstream numpy
2023-09-20 12:26:12 -07:00
Peter Hawkins
f52926e832
Fix test breakage in RNN test with old jaxlibs.
...
Remove some outdated version guards.
2023-09-20 11:50:04 -04:00
jax authors
7bc01d9472
Add memory kind check in PjRtArray::Create
.
...
PiperOrigin-RevId: 566851924
2023-09-19 22:58:37 -07:00
jax authors
33d862fb93
Merge pull request #17669 from andportnoy:aportnoy/xmap-test-use-float32
...
PiperOrigin-RevId: 566732138
2023-09-19 13:58:32 -07:00
jax authors
2332b07d3e
Merge pull request #17666 from jakevdp:jex-array-types
...
PiperOrigin-RevId: 566732103
2023-09-19 13:48:29 -07:00
jax authors
0da983d138
Merge pull request #17653 from andportnoy:aportnoy/rnn-plumb-precision-param
...
PiperOrigin-RevId: 566705284
2023-09-19 12:11:35 -07:00
Andrey Portnoy
fc1c31d958
Run LSTM test using FP32 math (as opposed to TF32)
...
1. Add (limited) precision specifier handling to LSTM
Enables differentiating between TF32 and FP32 math. TF32 math had insufficient
precision to reliably pass LSTM correctness tests on A100 and H100.
2. Run the test using FP32
TF32 precision is not sufficient for the test to pass reliably on Ampere+ GPUs
such as A100 and H100.
2023-09-19 14:45:14 -04:00
Jake VanderPlas
48087cbe8d
JEX: add jex.abstract_arrays.array_types
2023-09-19 11:37:05 -07:00
Andrey Portnoy
4eea05723b
Use float32 for testNestedMap and testPdotBatching in XMapTest
...
Multiple test cases were failing on Ampere+ due to use of TF32.
2023-09-19 14:35:23 -04:00
Jake VanderPlas
0dc2252f71
Better errors for array scalar/boolean conversion
2023-09-19 09:00:19 -07:00
jax authors
0f1813921e
Merge pull request #17651 from froystig:wrap-key-random
...
PiperOrigin-RevId: 566457857
2023-09-18 17:27:29 -07:00
Yash Katariya
05729513fb
Delete TransferPjRtBufferBetweenMemories
and replace it with CopyToMemorySpace
which is more robust and fully async and transfers between any memory space.
...
PiperOrigin-RevId: 566420233
2023-09-18 14:48:55 -07:00
Roy Frostig
2bf9322ccc
move wrap_key_data
to jax.random
...
This is a fine function for the public API, rather than `jax.extend`.
2023-09-18 14:38:22 -07:00
Adam Paszke
9fa96f8bdf
Raise SkipTest before running setUp
...
If `setUp` raises, then `tearDown` is never run. This, in conjunction with
`jtu.with_config` leads to leaked configs and causes downstreram test
failures due to NaN checking.
2023-09-18 15:28:00 +00:00