mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Remove the jax_array config option, which does nothing.
PiperOrigin-RevId: 548981491
This commit is contained in:
parent
8016fb3b66
commit
59509dc2b3
@ -60,8 +60,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
behavior (as documented) if the second and third arguments are
|
||||
callable, even if other operands are callable as well. See
|
||||
[#16413](https://github.com/google/jax/issues/16413).
|
||||
* The deprecated config option `jax_jit_pjit_api_merge`, which did nothing,
|
||||
has been removed.
|
||||
* The deprecated config options `jax_array` and `jax_jit_pjit_api_merge`,
|
||||
which did nothing, have been removed. These options have been true by
|
||||
default for many releases.
|
||||
|
||||
* New features
|
||||
* JAX now supports a configuration flag --jax_serialization_version
|
||||
|
@ -468,7 +468,7 @@ class Config:
|
||||
return (axis_env_state, mesh_context_manager, self.x64_enabled,
|
||||
self.jax_numpy_rank_promotion, self.jax_default_matmul_precision,
|
||||
self.jax_dynamic_shapes, self.jax_numpy_dtype_promotion,
|
||||
self.jax_default_device, self.jax_array,
|
||||
self.jax_default_device,
|
||||
self.jax_threefry_partitionable,
|
||||
self.jax_softmax_custom_jvp,
|
||||
# Technically this affects jaxpr->MHLO lowering, not tracing.
|
||||
@ -769,30 +769,6 @@ parallel_functions_output_gda = config.define_bool_state(
|
||||
default=False,
|
||||
help='If True, pjit will output GDAs.')
|
||||
|
||||
def _update_jax_array_global(val):
|
||||
if val is not None and not val:
|
||||
raise ValueError(
|
||||
'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
|
||||
' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
|
||||
' jax.config.jax_array.')
|
||||
|
||||
def _update_jax_array_thread_local(val):
|
||||
if val is not None and not val:
|
||||
raise ValueError(
|
||||
'jax.config.jax_array cannot be disabled after jax 0.4.7 release.'
|
||||
' Please downgrade to jax and jaxlib 0.4.6 if you want to disable'
|
||||
' jax.config.jax_array.')
|
||||
|
||||
jax_array = config.define_bool_state(
|
||||
name='jax_array',
|
||||
default=True,
|
||||
upgrade=True,
|
||||
update_global_hook = _update_jax_array_global,
|
||||
update_thread_local_hook = _update_jax_array_thread_local,
|
||||
help=('If True, new pjit behavior will be enabled and `jax.Array` will be '
|
||||
'used.'))
|
||||
|
||||
|
||||
pmap_shmap_merge = config.define_bool_state(
|
||||
name='jax_pmap_shmap_merge',
|
||||
default=False,
|
||||
@ -842,10 +818,9 @@ threefry_partitionable = config.define_bool_state(
|
||||
default=False,
|
||||
upgrade=True,
|
||||
help=('Enables internal threefry PRNG implementation changes that '
|
||||
'render it automatically partitionable in some cases. For use '
|
||||
'with pjit and/or jax_array=True. Without this flag, using the '
|
||||
'standard jax.random pseudo-random number generation may result '
|
||||
'in extraneous communication and/or redundant distributed '
|
||||
'render it automatically partitionable in some cases. Without this '
|
||||
'flag, using the standard jax.random pseudo-random number generation '
|
||||
'may result in extraneous communication and/or redundant distributed '
|
||||
'computation. With this flag, the communication overheads disappear '
|
||||
'in some cases.'),
|
||||
update_global_hook=lambda val: _update_global_jit_state(
|
||||
|
Loading…
x
Reference in New Issue
Block a user