From 59509dc2b38cad37b211ad5fc65568cb7fcc7439 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 18 Jul 2023 06:15:24 -0700 Subject: [PATCH] Remove the jax_array config option, which does nothing. PiperOrigin-RevId: 548981491 --- CHANGELOG.md | 5 +++-- jax/_src/config.py | 33 ++++----------------------------- 2 files changed, 7 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d6733a41..1498ff306 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,8 +60,9 @@ Remember to align the itemized text with the first line of an item within a list behavior (as documented) if the second and third arguments are callable, even if other operands are callable as well. See [#16413](https://github.com/google/jax/issues/16413). - * The deprecated config option `jax_jit_pjit_api_merge`, which did nothing, - has been removed. + * The deprecated config options `jax_array` and `jax_jit_pjit_api_merge`, + which did nothing, have been removed. These options have been true by + default for many releases. * New features * JAX now supports a configuration flag --jax_serialization_version diff --git a/jax/_src/config.py b/jax/_src/config.py index ee1ce471c..300900e9d 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -468,7 +468,7 @@ class Config: return (axis_env_state, mesh_context_manager, self.x64_enabled, self.jax_numpy_rank_promotion, self.jax_default_matmul_precision, self.jax_dynamic_shapes, self.jax_numpy_dtype_promotion, - self.jax_default_device, self.jax_array, + self.jax_default_device, self.jax_threefry_partitionable, self.jax_softmax_custom_jvp, # Technically this affects jaxpr->MHLO lowering, not tracing. @@ -769,30 +769,6 @@ parallel_functions_output_gda = config.define_bool_state( default=False, help='If True, pjit will output GDAs.') -def _update_jax_array_global(val): - if val is not None and not val: - raise ValueError( - 'jax.config.jax_array cannot be disabled after jax 0.4.7 release.' - ' Please downgrade to jax and jaxlib 0.4.6 if you want to disable' - ' jax.config.jax_array.') - -def _update_jax_array_thread_local(val): - if val is not None and not val: - raise ValueError( - 'jax.config.jax_array cannot be disabled after jax 0.4.7 release.' - ' Please downgrade to jax and jaxlib 0.4.6 if you want to disable' - ' jax.config.jax_array.') - -jax_array = config.define_bool_state( - name='jax_array', - default=True, - upgrade=True, - update_global_hook = _update_jax_array_global, - update_thread_local_hook = _update_jax_array_thread_local, - help=('If True, new pjit behavior will be enabled and `jax.Array` will be ' - 'used.')) - - pmap_shmap_merge = config.define_bool_state( name='jax_pmap_shmap_merge', default=False, @@ -842,10 +818,9 @@ threefry_partitionable = config.define_bool_state( default=False, upgrade=True, help=('Enables internal threefry PRNG implementation changes that ' - 'render it automatically partitionable in some cases. For use ' - 'with pjit and/or jax_array=True. Without this flag, using the ' - 'standard jax.random pseudo-random number generation may result ' - 'in extraneous communication and/or redundant distributed ' + 'render it automatically partitionable in some cases. Without this ' + 'flag, using the standard jax.random pseudo-random number generation ' + 'may result in extraneous communication and/or redundant distributed ' 'computation. With this flag, the communication overheads disappear ' 'in some cases.'), update_global_hook=lambda val: _update_global_jit_state(