Remove the jax_array config option, which does nothing.

PiperOrigin-RevId: 548981491
This commit is contained in:
Peter Hawkins 2023-07-18 06:15:24 -07:00 committed by jax authors
parent 8016fb3b66
commit 59509dc2b3
2 changed files with 7 additions and 31 deletions

View File

@ -60,8 +60,9 @@ Remember to align the itemized text with the first line of an item within a list
behavior (as documented) if the second and third arguments are behavior (as documented) if the second and third arguments are
callable, even if other operands are callable as well. See callable, even if other operands are callable as well. See
[#16413](https://github.com/google/jax/issues/16413). [#16413](https://github.com/google/jax/issues/16413).
* The deprecated config option `jax_jit_pjit_api_merge`, which did nothing, * The deprecated config options `jax_array` and `jax_jit_pjit_api_merge`,
has been removed. which did nothing, have been removed. These options have been true by
default for many releases.
* New features * New features
* JAX now supports a configuration flag --jax_serialization_version * JAX now supports a configuration flag --jax_serialization_version

View File

@ -468,7 +468,7 @@ class Config:
return (axis_env_state, mesh_context_manager, self.x64_enabled, return (axis_env_state, mesh_context_manager, self.x64_enabled,
self.jax_numpy_rank_promotion, self.jax_default_matmul_precision, self.jax_numpy_rank_promotion, self.jax_default_matmul_precision,
self.jax_dynamic_shapes, self.jax_numpy_dtype_promotion, self.jax_dynamic_shapes, self.jax_numpy_dtype_promotion,
self.jax_default_device, self.jax_array, self.jax_default_device,
self.jax_threefry_partitionable, self.jax_threefry_partitionable,
self.jax_softmax_custom_jvp, self.jax_softmax_custom_jvp,
# Technically this affects jaxpr->MHLO lowering, not tracing. # Technically this affects jaxpr->MHLO lowering, not tracing.
@ -769,30 +769,6 @@ parallel_functions_output_gda = config.define_bool_state(
default=False, default=False,
help='If True, pjit will output GDAs.') help='If True, pjit will output GDAs.')
def _update_jax_array_global(val):
if val is not None and not val:
raise ValueError(
'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
' jax.config.jax_array.')
def _update_jax_array_thread_local(val):
if val is not None and not val:
raise ValueError(
'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
' jax.config.jax_array.')
jax_array = config.define_bool_state(
name='jax_array',
default=True,
upgrade=True,
update_global_hook = _update_jax_array_global,
update_thread_local_hook = _update_jax_array_thread_local,
help=('If True, new pjit behavior will be enabled and `jax.Array` will be '
'used.'))
pmap_shmap_merge = config.define_bool_state( pmap_shmap_merge = config.define_bool_state(
name='jax_pmap_shmap_merge', name='jax_pmap_shmap_merge',
default=False, default=False,
@ -842,10 +818,9 @@ threefry_partitionable = config.define_bool_state(
default=False, default=False,
upgrade=True, upgrade=True,
help=('Enables internal threefry PRNG implementation changes that ' help=('Enables internal threefry PRNG implementation changes that '
'render it automatically partitionable in some cases. For use ' 'render it automatically partitionable in some cases. Without this '
'with pjit and/or jax_array=True. Without this flag, using the ' 'flag, using the standard jax.random pseudo-random number generation '
'standard jax.random pseudo-random number generation may result ' 'may result in extraneous communication and/or redundant distributed '
'in extraneous communication and/or redundant distributed '
'computation. With this flag, the communication overheads disappear ' 'computation. With this flag, the communication overheads disappear '
'in some cases.'), 'in some cases.'),
update_global_hook=lambda val: _update_global_jit_state( update_global_hook=lambda val: _update_global_jit_state(