jax2tf with `native_serialization=False` or with `enable_xla=False` has been deprecated since July 2024.
This change turns an attempt to use `native_serialization=False` or `enable_xla=False` into an error.
PiperOrigin-RevId: 689708392
Repeated string addition is a known anti-pattern. It hardly matters in this spot, but we may as well do it properly.
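A minimal sketch of the difference (example values are hypothetical, not from the change itself):

```python
parts = ["alpha", "beta", "gamma"]

# Anti-pattern: each `+=` may copy the whole accumulated string,
# giving quadratic behavior for long sequences.
joined = ""
for part in parts:
    if joined:
        joined += ", "
    joined += part

# Idiomatic: a single join allocates the result once.
assert joined == ", ".join(parts) == "alpha, beta, gamma"
```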
PiperOrigin-RevId: 689416587
Consider the use case when we call_tf a restored saved model that
includes parameters (hence functions closing over tf.Variable), and then
we jax2tf.convert it with native serialization, under tf.function (or
for saving to saved model).
The lowering for call_tf in presence of functions with captured inputs
requires looking up the tf.Variable and reading its value. This fails
with an error that `v.numpy()` is not allowed in graph mode. The fix
is to use `tf.init_scope()` to lift out of graph building mode, so that
we can read the value of the variables.
`random.choice` uses `np.insert(arr.shape, new_shape)` which attempts
to coerce all the values in `new_shape` to constants when `arr.shape`
is constant. Replace use of `np.insert` with tuple slicing and
concatenation.
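A minimal sketch of the shape-building change (the shapes, axis, and sample count below are hypothetical, and the actual jax code operates on possibly symbolic dimensions):

```python
import numpy as np

arr_shape = (3, 4, 5)   # hypothetical input shape
axis = 1                # hypothetical sampled axis
n = 7                   # hypothetical number of samples

# `np.insert` routes the shape through an ndarray, coercing every
# dimension to a concrete integer; this fails when some dimensions
# are non-constant (symbolic) sizes.
via_insert = tuple(np.insert(np.delete(arr_shape, axis), axis, n))

# Tuple slicing and concatenation builds the same shape without
# coercing the untouched dimensions.
via_slicing = arr_shape[:axis] + (n,) + arr_shape[axis + 1:]

assert via_insert == via_slicing == (3, 7, 5)
```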
The case where the sampled axis has non-constant size and
`replace=False` is not supported, because `permutation` is not
supported on arrays with non-constant size.
Adds tests for many combinations of arguments for `random.choice`.
Improves a few error messages.
As an extra minor change, we now disallow specifying the predicate when uniform is
unset, as that implies that we're going to use two different mechanisms to select
a single thread.
PiperOrigin-RevId: 689289365
Originally proposed in #24021. Slightly rewritten to make testing with internal LLVM toolchains better.
Use CUDA driver API to query major and minor compute capabilities, thus arriving at a "base" SM string (e.g. `sm_90`).
Then use LLVM to see if we can "upgrade" the base SM string to one that enables architecture-specific capabilities (e.g. `sm_90a`).
Then use LLVM to map the SM string to a PTX ISA version that supports the SM.
Co-authored-by: Andrey Portnoy <aportnoy@nvidia.com>
PiperOrigin-RevId: 689286774
Backends often have custom effectful primitives, but their effects do not extend
beyond the scope of a single kernel, so we should remove them in core_map's abstract eval.
PiperOrigin-RevId: 688990275
It is not possible for primitives to return references, so in order to support reshaping we need to use TransformRef. This CL introduces both a reshape memref transform and a function for the user to create transformed refs of that type.
PiperOrigin-RevId: 688966337
* Uninitialized values
* Custom ref aval construction
This will allow us to replace `run_scoped` with `run_state`, and allow us to change the memory space of initialized values.
Co-authored-by: Sharad Vikram <sharadmv@google.com>
PiperOrigin-RevId: 688965089