Remove the jax_array config option, which does nothing.

PiperOrigin-RevId: 548981491
This commit is contained in:
Peter Hawkins 2023-07-18 06:15:24 -07:00 committed by jax authors
parent 8016fb3b66
commit 59509dc2b3
2 changed files with 7 additions and 31 deletions

View File

@ -60,8 +60,9 @@ Remember to align the itemized text with the first line of an item within a list
behavior (as documented) if the second and third arguments are
callable, even if other operands are callable as well. See
[#16413](https://github.com/google/jax/issues/16413).
* The deprecated config option `jax_jit_pjit_api_merge`, which did nothing,
has been removed.
* The deprecated config options `jax_array` and `jax_jit_pjit_api_merge`,
which did nothing, have been removed. These options have been true by
default for many releases.
* New features
* JAX now supports a configuration flag --jax_serialization_version

View File

@ -468,7 +468,7 @@ class Config:
return (axis_env_state, mesh_context_manager, self.x64_enabled,
self.jax_numpy_rank_promotion, self.jax_default_matmul_precision,
self.jax_dynamic_shapes, self.jax_numpy_dtype_promotion,
self.jax_default_device, self.jax_array,
self.jax_default_device,
self.jax_threefry_partitionable,
self.jax_softmax_custom_jvp,
# Technically this affects jaxpr->MHLO lowering, not tracing.
@ -769,30 +769,6 @@ parallel_functions_output_gda = config.define_bool_state(
default=False,
help='If True, pjit will output GDAs.')
def _update_jax_array_global(val):
if val is not None and not val:
raise ValueError(
'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
' jax.config.jax_array.')
def _update_jax_array_thread_local(val):
if val is not None and not val:
raise ValueError(
'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
' jax.config.jax_array.')
jax_array = config.define_bool_state(
name='jax_array',
default=True,
upgrade=True,
update_global_hook = _update_jax_array_global,
update_thread_local_hook = _update_jax_array_thread_local,
help=('If True, new pjit behavior will be enabled and `jax.Array` will be '
'used.'))
pmap_shmap_merge = config.define_bool_state(
name='jax_pmap_shmap_merge',
default=False,
@ -842,10 +818,9 @@ threefry_partitionable = config.define_bool_state(
default=False,
upgrade=True,
help=('Enables internal threefry PRNG implementation changes that '
'render it automatically partitionable in some cases. For use '
'with pjit and/or jax_array=True. Without this flag, using the '
'standard jax.random pseudo-random number generation may result '
'in extraneous communication and/or redundant distributed '
'render it automatically partitionable in some cases. Without this '
'flag, using the standard jax.random pseudo-random number generation '
'may result in extraneous communication and/or redundant distributed '
'computation. With this flag, the communication overheads disappear '
'in some cases.'),
update_global_hook=lambda val: _update_global_jit_state(