mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Use omnistaging env var even when not using absl flags for config. (#4152)
This commit is contained in:
parent
1d93991003
commit
f0fb7d0925
@ -42,10 +42,8 @@ class Config:
|
||||
self.meta = {}
|
||||
self.FLAGS = NameSpace(self.read)
|
||||
self.use_absl = False
|
||||
self.omnistaging_enabled = False
|
||||
|
||||
self.omnistaging_enabled = False
|
||||
self.omnistaging_enablers = []
|
||||
self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', False)
|
||||
self._omnistaging_enablers = []
|
||||
|
||||
def update(self, name, val):
|
||||
if self.use_absl:
|
||||
@ -119,10 +117,16 @@ class Config:
|
||||
if FLAGS.jax_omnistaging:
|
||||
self.enable_omnistaging()
|
||||
|
||||
def register_omnistaging_enabler(self, enabler):
|
||||
if not self.omnistaging_enabled:
|
||||
self._omnistaging_enablers.append(enabler)
|
||||
else:
|
||||
enabler()
|
||||
|
||||
# TODO(mattjj): remove this when omnistaging fully lands
|
||||
def enable_omnistaging(self):
|
||||
if not self.omnistaging_enabled:
|
||||
for enabler in self.omnistaging_enablers:
|
||||
for enabler in self._omnistaging_enablers:
|
||||
enabler()
|
||||
self.omnistaging_enabled = True
|
||||
|
||||
|
@ -1431,7 +1431,7 @@ def pp_kv_pairs(kv_pairs):
|
||||
axis_frame = None
|
||||
|
||||
# TODO(mattjj): remove when omnistaging fully lands
|
||||
@config.omnistaging_enablers.append
|
||||
@config.register_omnistaging_enabler
|
||||
@no_type_check
|
||||
def omnistaging_enabler() -> None:
|
||||
global thread_local_state, call_bind, find_top_trace, initial_style_staging, \
|
||||
|
@ -603,7 +603,7 @@ batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
|
||||
|
||||
|
||||
# TODO(mattjj): remove when omnistaging fully lands
|
||||
@config.omnistaging_enablers.append
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enabler() -> None:
|
||||
global _initial_style_jaxpr
|
||||
|
||||
|
@ -681,7 +681,7 @@ def defvjp2(prim, *vjps):
|
||||
|
||||
|
||||
# TODO(mattjj): remove when omnistaging fully lands
|
||||
@config.omnistaging_enablers.append
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enabler() -> None:
|
||||
global jvp_jaxpr
|
||||
|
||||
|
@ -426,7 +426,7 @@ def _merge_bdims(x, y):
|
||||
return x # arbitrary
|
||||
|
||||
|
||||
@config.omnistaging_enablers.append
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enabler() -> None:
|
||||
global batch_jaxpr
|
||||
|
||||
|
@ -1080,7 +1080,7 @@ def fun_sourceinfo(fun):
|
||||
|
||||
# TODO(mattjj): remove when omnistaging fully lands
|
||||
|
||||
@config.omnistaging_enablers.append
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enabler() -> None:
|
||||
global trace_to_jaxpr, partial_eval_jaxpr
|
||||
|
||||
|
@ -1356,8 +1356,10 @@ def _axis_index_soft_pmap_rule(vals, mapped, chunk_size, *, axis_name):
|
||||
idx = core.axis_index(axis_name) # type: ignore
|
||||
return idx * chunk_size + np.arange(chunk_size), True
|
||||
|
||||
def deleted_with_omnistaging(*a, **k):
|
||||
assert False, "Should be deleted"
|
||||
|
||||
@config.omnistaging_enablers.append
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enable() -> None:
|
||||
global DynamicAxisEnvFrame, DynamicAxisEnv, _ThreadLocalState, \
|
||||
_thread_local_state, extend_dynamic_axis_env, unmapped_device_count, \
|
||||
@ -1368,9 +1370,11 @@ def omnistaging_enable() -> None:
|
||||
del DynamicAxisEnvFrame, DynamicAxisEnv, _ThreadLocalState, \
|
||||
_thread_local_state, extend_dynamic_axis_env, unmapped_device_count, \
|
||||
axis_index, _axis_index_bind, _axis_index_translation_rule, \
|
||||
apply_parallel_primitive, parallel_pure_rules, \
|
||||
_pvals_to_results_handler, _pval_to_result_handler, replicate
|
||||
|
||||
apply_parallel_primitive = deleted_with_omnistaging
|
||||
parallel_pure_rules.clear()
|
||||
|
||||
def avals_to_results_handler(size, nrep, npart, out_parts, out_avals):
|
||||
nouts = len(out_avals)
|
||||
if out_parts is None:
|
||||
|
@ -354,7 +354,7 @@ def with_sharding_constraint(x, partitions: Optional[PartitionSpec]):
|
||||
return sharding_constraint_p.bind(x, partitions=partitions)
|
||||
|
||||
|
||||
@config.omnistaging_enablers.append
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enabler() -> None:
|
||||
global _avals_to_results_handler, _aval_to_result_handler, \
|
||||
_pvals_to_results_handler, _pval_to_result_handler
|
||||
|
@ -1280,7 +1280,17 @@ call_translations[core.call_p] = _call_translation_rule
|
||||
|
||||
# TODO(mattjj): remove when omnistaging fully lands
|
||||
|
||||
@config.omnistaging_enablers.append
|
||||
def _pval_to_result_handler(device, pval):
|
||||
pv, const = pval
|
||||
if pv is None:
|
||||
const = _device_put_impl(const, device) if device else const
|
||||
return lambda _: const
|
||||
else:
|
||||
return aval_to_result_handler(device, pv)
|
||||
|
||||
pe.staged_out_calls.add(xla_call_p)
|
||||
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enabler() -> None:
|
||||
global _pval_to_result_handler
|
||||
del _pval_to_result_handler
|
||||
@ -1292,13 +1302,3 @@ def omnistaging_enabler() -> None:
|
||||
unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
|
||||
return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
|
||||
parallel_translations[core.axis_index_p] = _axis_index_translation_rule # type: ignore
|
||||
|
||||
def _pval_to_result_handler(device, pval):
|
||||
pv, const = pval
|
||||
if pv is None:
|
||||
const = _device_put_impl(const, device) if device else const
|
||||
return lambda _: const
|
||||
else:
|
||||
return aval_to_result_handler(device, pv)
|
||||
|
||||
pe.staged_out_calls.add(xla_call_p)
|
||||
|
@ -5962,7 +5962,7 @@ def _check_user_dtype_supported(dtype, fun_name=None):
|
||||
warnings.warn(msg.format(dtype, fun_name , truncated_dtype))
|
||||
|
||||
|
||||
@config.omnistaging_enablers.append
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enabler() -> None:
|
||||
global _tie_in_transpose_rule, _tie_in_batch_rule, _tie_in_impl, tie_in_p
|
||||
del _tie_in_transpose_rule, _tie_in_batch_rule, _tie_in_impl, tie_in_p
|
||||
|
@ -2437,7 +2437,7 @@ def associative_scan(fn, elems):
|
||||
|
||||
|
||||
# TODO(mattjj): remove when omnistaging fully lands
|
||||
@config.omnistaging_enablers.append
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enabler() -> None:
|
||||
global _initial_style_untyped_jaxpr, _initial_style_jaxpr, \
|
||||
_initial_style_jaxprs_with_common_consts
|
||||
|
@ -579,7 +579,7 @@ def all_gather(x, axis_name):
|
||||
return _allgather(x, 0, psum(1, axis_name), axis_name)
|
||||
|
||||
|
||||
@config.omnistaging_enablers.append
|
||||
@config.register_omnistaging_enabler
|
||||
def omnistaging_enabler() -> None:
|
||||
# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
|
||||
# tracing time.
|
||||
@ -595,3 +595,5 @@ def omnistaging_enabler() -> None:
|
||||
return tuple(size * x for x in args)
|
||||
return core.Primitive.bind(
|
||||
psum_p, *args, axis_name=axis_name, axis_index_groups=axis_index_groups)
|
||||
|
||||
del pxla.parallel_pure_rules[psum_p]
|
||||
|
Loading…
x
Reference in New Issue
Block a user