6660 Commits

Author SHA1 Message Date
Jieying Luo
c7f60fa6eb [PJRT C API] Implement framework side change for registering a custom call.
- Add a py extension to call the custom call C API.
- Change the implementation of register_custom_call_target to store handlers for the custom call targets and delays the registration until the handler for a xla platform is registered.
- Change register_plugin to load PJRT plugin when register_pluin is called (instead of when a client is created), and let it return the PJRT_Api* loaded.
- Delay calling discover_pjrt_plugins() and register_pjrt_plugin_factories_from_env() until the first time backends() is called.

PiperOrigin-RevId: 568265745
2023-09-25 10:52:29 -07:00
Peter Hawkins
625d2df735 Reverts d3f5e7f7956204ccccf4474423e2f189420e0f8e
PiperOrigin-RevId: 568249649
2023-09-25 09:59:54 -07:00
Peter Hawkins
69da839358 Remove test code that checks for the se_tpu runtime.
This runtime no longer exists.

PiperOrigin-RevId: 568242078
2023-09-25 09:30:07 -07:00
Jake VanderPlas
70e0098a87 [random] add itemsize property to custom PRNG 2023-09-25 08:52:26 -07:00
Peter Hawkins
5aaa15df84 Remove the skip_on_xla_cpu_mlir decorator.
We no longer test this variant in CI, so we don't need code to skip it.

PiperOrigin-RevId: 568219651
2023-09-25 08:04:56 -07:00
Peter Hawkins
d3f5e7f795 Remove code that skips array PRNG tests on CUDA.
https://github.com/google/jax/pull/13037 added this skip, but I have no idea why. The test seems to pass on GPU.

PiperOrigin-RevId: 568216977
2023-09-25 07:49:05 -07:00
jax authors
3a9289eac4 Merge pull request #17757 from gnecula:exp_dim_var
PiperOrigin-RevId: 568213625
2023-09-25 07:33:48 -07:00
George Necula
4d18595792 [export] Add mlir attributes to the platform index and dimension variables
In presence of multi-platform lowering or shape polymorphism we tag
the platform index and the dimension variable arguments with
MLIR attributions jax.platform_index and jax.dimension_variable.

So far nobody uses these attributes, hence this change should not
change any behavior. This is in preparation for implementing
exporting of effects.
2023-09-25 15:42:36 +02:00
Peter Hawkins
2fd6df45e4 Fix test failures under SciPy 1.11 for scipy.stats.mode. 2023-09-23 20:15:51 +00:00
jax authors
1466c3d2cb Merge pull request #17746 from jakevdp:fix-typos
PiperOrigin-RevId: 567746141
2023-09-22 16:37:48 -07:00
Berkin Ilbeyi
c9b5996f59 [XLA] Initialize tuple shapes of async-done in dataflow analysis.
PiperOrigin-RevId: 567724401
2023-09-22 14:59:31 -07:00
Jake VanderPlas
4a5bd9e046 Fix typos across the package 2023-09-22 14:54:31 -07:00
Yash Katariya
8276038f63 Relax the memory alignment check between numpy array and jax array on CPU
PiperOrigin-RevId: 567722405
2023-09-22 14:49:00 -07:00
jax authors
6dcda12140 Merge pull request #17739 from jakevdp:valid-jaxtype
PiperOrigin-RevId: 567713101
2023-09-22 14:21:21 -07:00
jax authors
389c5a5cf5 Merge pull request #17740 from jakevdp:issubdtype-fix
PiperOrigin-RevId: 567708670
2023-09-22 14:01:01 -07:00
Jake VanderPlas
bfed3d862e Improve behavior of core.valid_jaxtype 2023-09-22 13:46:09 -07:00
jax authors
681358fe02 Merge pull request #17425 from jon-chuang:jon-chuang/fix-deprecated-pallas-test
PiperOrigin-RevId: 567702337
2023-09-22 13:28:13 -07:00
Jake VanderPlas
8125e8bd03 issubdtype: fix corner cases with extended dtypes 2023-09-22 11:37:31 -07:00
Yash Katariya
426970591b If an input to jnp.asarray is a numpy array, then convert it to a jax.Array via device_put to avoid a copy.
Do a similar thing for jax.Array too if dtypes match.

Fixes https://github.com/google/jax/issues/17702

PiperOrigin-RevId: 567644997
2023-09-22 09:40:25 -07:00
Hyeontaek Lim
51589bbe70 Clarify that PjRtClient and PjRtDevice memory_spaces are not in particular order
PiperOrigin-RevId: 567630629
2023-09-22 08:37:07 -07:00
Yash Katariya
03877a9218 If a pmap out is replicated i.e. with out_axes=None make jnp.copy's impl go via apply_primitive which will put it on a single device.
If we don't do that, then it hits an error mentioned in https://github.com/google/jax/issues/17690.

Fixes https://github.com/google/jax/issues/17690

PiperOrigin-RevId: 567628026
2023-09-22 08:24:57 -07:00
George Necula
5b8f91fed7 [jax2tf] Fix higher-order differentiation.
We must ensure that we call jax2tf.convert recursively to ensure
that the proper tf.custom_gradient is used. This means that we can
reuse the conversion of the VJP function between native and graph
serialization.
2023-09-22 07:53:45 +02:00
Hyeontaek Lim
f0bde75dd3 [JAX] Export shard_map_test for testing on additional JAX backends
PiperOrigin-RevId: 567522898
2023-09-21 22:52:36 -07:00
Jon Chuang
1ecbf0e196 fix deprecated shuffle in pallas test 2023-09-21 23:06:08 -04:00
Jake VanderPlas
243a6a236c dtypes.issubdtype: validate a when b is dtypes.extended 2023-09-21 15:53:05 -07:00
jax authors
0ce97c2f46 Merge pull request #17719 from jakevdp:dep-rng-key
PiperOrigin-RevId: 567436801
2023-09-21 15:30:18 -07:00
jax authors
256612bb80 Merge pull request #17720 from superbobry:tuple-list-comp
PiperOrigin-RevId: 567433086
2023-09-21 15:16:12 -07:00
Sergei Lebedev
df7f6a06c0 MAINT Use a generator expression in tuple([... for ... in ...])
In a few cases I also replaced tuple([*xs, *ys]) with (*xs, ys), because
tuple literals support unpacking as well.
2023-09-21 22:25:38 +01:00
Sharad Vikram
afb6691885 Disable msan/tsan for xmap_tests thanks to timeouts
PiperOrigin-RevId: 567412260
2023-09-21 14:24:08 -07:00
Jake VanderPlas
22818d664f [random] deprecate named key creation functions 2023-09-21 13:57:49 -07:00
jax authors
8295007818 Merge pull request #17710 from jakevdp:issubdtype-generic
PiperOrigin-RevId: 567406904
2023-09-21 13:53:58 -07:00
Junwhan Ahn
6a551a1efa Add memories_test.py to the list of exported tests
PiperOrigin-RevId: 567375604
2023-09-21 11:57:09 -07:00
Jake VanderPlas
4edb74ba7b Fix some numpy 2.0 incompatibilities 2023-09-21 10:24:52 -07:00
Jake VanderPlas
4c48452652 dtypes: make issubdtype(dt, np.generic) work for custom types 2023-09-21 09:46:39 -07:00
Adrian Kuegel
d4965cd335 [XLA:GPU] Clean up Target util.
We have some differences between Triton codegen and other fusion codegen,
namely for Remainder/Fmod and Cbrt. Unify that.

- Remove two unused math functions.
- Add mapping from kRemainder to kFmod.
- Use kCbrt device function in elemental_ir_emitter.

PiperOrigin-RevId: 567274915
2023-09-21 05:12:06 -07:00
jax authors
68eddd16f3 Update the logic of PjRtArray::Reshard after PjRtBuffer::CopyToMemorySpace was introduced. Users should use PjRtBuffer::CopyToMemorySpace instead of PjRtBuffer::CopyToDevice when memories are supported, since the semantics of the latter one is to always copy to the default memory space of the device.
PiperOrigin-RevId: 567154400
2023-09-20 19:39:01 -07:00
Jake VanderPlas
95a209f28b Tests: fix some failures for upstream numpy 2023-09-20 12:26:12 -07:00
Peter Hawkins
f52926e832 Fix test breakage in RNN test with old jaxlibs.
Remove some outdated version guards.
2023-09-20 11:50:04 -04:00
jax authors
7bc01d9472 Add memory kind check in PjRtArray::Create.
PiperOrigin-RevId: 566851924
2023-09-19 22:58:37 -07:00
jax authors
33d862fb93 Merge pull request #17669 from andportnoy:aportnoy/xmap-test-use-float32
PiperOrigin-RevId: 566732138
2023-09-19 13:58:32 -07:00
jax authors
2332b07d3e Merge pull request #17666 from jakevdp:jex-array-types
PiperOrigin-RevId: 566732103
2023-09-19 13:48:29 -07:00
jax authors
0da983d138 Merge pull request #17653 from andportnoy:aportnoy/rnn-plumb-precision-param
PiperOrigin-RevId: 566705284
2023-09-19 12:11:35 -07:00
Andrey Portnoy
fc1c31d958 Run LSTM test using FP32 math (as opposed to TF32)
1. Add (limited) precision specifier handling to LSTM

Enables differentiating between TF32 and FP32 math. TF32 math had insufficient
precision to reliably pass LSTM correctness tests on A100 and H100.

2. Run the test using FP32

TF32 precision is not sufficient for the test to pass reliably on Ampere+ GPUs
such as A100 and H100.
2023-09-19 14:45:14 -04:00
Jake VanderPlas
48087cbe8d JEX: add jex.abstract_arrays.array_types 2023-09-19 11:37:05 -07:00
Andrey Portnoy
4eea05723b Use float32 for testNestedMap and testPdotBatching in XMapTest
Multiple test cases were failing on Ampere+ due to use of TF32.
2023-09-19 14:35:23 -04:00
Jake VanderPlas
0dc2252f71 Better errors for array scalar/boolean conversion 2023-09-19 09:00:19 -07:00
jax authors
0f1813921e Merge pull request #17651 from froystig:wrap-key-random
PiperOrigin-RevId: 566457857
2023-09-18 17:27:29 -07:00
Yash Katariya
05729513fb Delete TransferPjRtBufferBetweenMemories and replace it with CopyToMemorySpace which is more robust and fully async and transfers between any memory space.
PiperOrigin-RevId: 566420233
2023-09-18 14:48:55 -07:00
Roy Frostig
2bf9322ccc move wrap_key_data to jax.random
This is a fine function for the public API, rather than `jax.extend`.
2023-09-18 14:38:22 -07:00
Adam Paszke
9fa96f8bdf Raise SkipTest before running setUp
If `setUp` raises, then `tearDown` is never run. This, in conjunction with
`jtu.with_config` leads to leaked configs and causes downstreram test
failures due to NaN checking.
2023-09-18 15:28:00 +00:00