Add upgrade=True to jax_array flag so that its marked as transient flag which will eventually be set to True.

This commit is contained in:
Yash Katariya 2022-08-22 13:07:53 -07:00 committed by GitHub
parent 384776f0c9
commit 37089ec1b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -655,6 +655,7 @@ parallel_functions_output_gda = config.define_bool_state(
jax_array = config.define_bool_state(
name='jax_array',
default=False,
upgrade=True,
help=('If True, new pjit behavior will be enabled and `jax.Array` will be '
'used.'))