Use omnistaging env var even when not using absl flags for config. (#4152)

Tom Hennigan 2020-08-26 22:06:27 +01:00 committed by GitHub
parent 1d93991003
commit f0fb7d0925
12 changed files with 37 additions and 27 deletions
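
Note on usage: the first hunk below reads the JAX_OMNISTAGING environment variable directly in the Config constructor via bool_env, so omnistaging can be switched on without absl flag parsing. A minimal sketch of what that allows (the env var name and enable_omnistaging() are taken from this diff; the `from jax.config import config` spelling is the usual JAX idiom, assumed here rather than shown in the diff):

# Option 1: set the environment variable before JAX is imported,
# e.g. `JAX_OMNISTAGING=1 python train.py`, or:
import os
os.environ.setdefault("JAX_OMNISTAGING", "1")  # must run before `import jax`

import jax

# Option 2: flip it programmatically, before tracing anything.
from jax.config import config
config.enable_omnistaging()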

View File

@@ -42,10 +42,8 @@ class Config:
     self.meta = {}
     self.FLAGS = NameSpace(self.read)
     self.use_absl = False
-    self.omnistaging_enabled = False
-    self.omnistaging_enablers = []
+    self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', False)
+    self._omnistaging_enablers = []

   def update(self, name, val):
     if self.use_absl:
@@ -119,10 +117,16 @@ class Config:
     if FLAGS.jax_omnistaging:
       self.enable_omnistaging()

+  def register_omnistaging_enabler(self, enabler):
+    if not self.omnistaging_enabled:
+      self._omnistaging_enablers.append(enabler)
+    else:
+      enabler()
+
   # TODO(mattjj): remove this when omnistaging fully lands
   def enable_omnistaging(self):
     if not self.omnistaging_enabled:
-      for enabler in self.omnistaging_enablers:
+      for enabler in self._omnistaging_enablers:
         enabler()
       self.omnistaging_enabled = True
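
The two methods added above implement a small deferred-callback registry: an enabler registered before omnistaging is switched on is queued, one registered afterwards runs immediately, and enable_omnistaging() drains the queue exactly once. A self-contained sketch of the same behaviour (class and attribute names here are illustrative stand-ins, not the real jax Config):

from typing import Callable, List

class StagingConfig:
  """Illustrative stand-in for the Config behaviour shown in the hunk above."""

  def __init__(self, enabled: bool = False):
    self.omnistaging_enabled = enabled
    self._omnistaging_enablers: List[Callable[[], None]] = []

  def register_omnistaging_enabler(self, enabler: Callable[[], None]) -> None:
    if not self.omnistaging_enabled:
      self._omnistaging_enablers.append(enabler)  # defer until enable_omnistaging()
    else:
      enabler()                                   # already enabled: run right away

  def enable_omnistaging(self) -> None:
    if not self.omnistaging_enabled:
      for enabler in self._omnistaging_enablers:
        enabler()
      self.omnistaging_enabled = True

Because register_omnistaging_enabler returns None, using it as a decorator rebinds each module's omnistaging_enabler name to None; that is harmless, since the function is only ever invoked through the stored list.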

View File

@@ -1431,7 +1431,7 @@ def pp_kv_pairs(kv_pairs):
 axis_frame = None

 # TODO(mattjj): remove when omnistaging fully lands
-@config.omnistaging_enablers.append
+@config.register_omnistaging_enabler
 @no_type_check
 def omnistaging_enabler() -> None:
   global thread_local_state, call_bind, find_top_trace, initial_style_staging, \
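
The remaining hunks all make the same one-line change at each registration site. For context, the thing being registered is a zero-argument enabler that uses `global` to rebind module-level names to omnistaging-aware replacements; the toy below (all names invented for illustration) shows both that rebinding and why the old `@list.append` spelling worked as a decorator:

enablers = []                 # stands in for config.omnistaging_enablers

def greet() -> str:
  return "old behaviour"

@enablers.append              # pre-change spelling: append returns None,
def _enabler() -> None:       # but the decorated name is never needed again
  global greet
  def greet() -> str:         # the `global` declaration makes this def rebind
    return "new behaviour"    # the module-level name

def enable_all() -> None:     # stands in for config.enable_omnistaging()
  for f in enablers:
    f()

print(greet())  # old behaviour
enable_all()
print(greet())  # new behaviour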

View File

@@ -603,7 +603,7 @@ batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp

 # TODO(mattjj): remove when omnistaging fully lands
-@config.omnistaging_enablers.append
+@config.register_omnistaging_enabler
 def omnistaging_enabler() -> None:
   global _initial_style_jaxpr

View File

@@ -681,7 +681,7 @@ def defvjp2(prim, *vjps):

 # TODO(mattjj): remove when omnistaging fully lands
-@config.omnistaging_enablers.append
+@config.register_omnistaging_enabler
 def omnistaging_enabler() -> None:
   global jvp_jaxpr

View File

@@ -426,7 +426,7 @@ def _merge_bdims(x, y):
     return x  # arbitrary

-@config.omnistaging_enablers.append
+@config.register_omnistaging_enabler
 def omnistaging_enabler() -> None:
   global batch_jaxpr

View File

@@ -1080,7 +1080,7 @@ def fun_sourceinfo(fun):

 # TODO(mattjj): remove when omnistaging fully lands
-@config.omnistaging_enablers.append
+@config.register_omnistaging_enabler
 def omnistaging_enabler() -> None:
   global trace_to_jaxpr, partial_eval_jaxpr

View File

@@ -1356,8 +1356,10 @@ def _axis_index_soft_pmap_rule(vals, mapped, chunk_size, *, axis_name):
   idx = core.axis_index(axis_name)  # type: ignore
   return idx * chunk_size + np.arange(chunk_size), True

+def deleted_with_omnistaging(*a, **k):
+  assert False, "Should be deleted"
+
-@config.omnistaging_enablers.append
+@config.register_omnistaging_enabler
 def omnistaging_enable() -> None:
   global DynamicAxisEnvFrame, DynamicAxisEnv, _ThreadLocalState, \
     _thread_local_state, extend_dynamic_axis_env, unmapped_device_count, \
@@ -1368,9 +1370,11 @@ def omnistaging_enable() -> None:
   del DynamicAxisEnvFrame, DynamicAxisEnv, _ThreadLocalState, \
     _thread_local_state, extend_dynamic_axis_env, unmapped_device_count, \
     axis_index, _axis_index_bind, _axis_index_translation_rule, \
-    apply_parallel_primitive, parallel_pure_rules, \
     _pvals_to_results_handler, _pval_to_result_handler, replicate
+  apply_parallel_primitive = deleted_with_omnistaging
+  parallel_pure_rules.clear()

 def avals_to_results_handler(size, nrep, npart, out_parts, out_avals):
   nouts = len(out_avals)
   if out_parts is None:
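
Rather than deleting apply_parallel_primitive outright, the enabler above rebinds it to deleted_with_omnistaging, so any stale caller fails with an explicit assertion instead of a confusing NameError. A minimal sketch of that defensive pattern (names illustrative, not from the diff):

def deleted(*a, **k):
  assert False, "should be unreachable once the new path is enabled"

def legacy_entry_point(x):   # old code path, kept importable
  return x

# When the new path is switched on, poison the old entry point instead of deleting it.
legacy_entry_point = deleted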

View File

@@ -354,7 +354,7 @@ def with_sharding_constraint(x, partitions: Optional[PartitionSpec]):
   return sharding_constraint_p.bind(x, partitions=partitions)

-@config.omnistaging_enablers.append
+@config.register_omnistaging_enabler
 def omnistaging_enabler() -> None:
   global _avals_to_results_handler, _aval_to_result_handler, \
     _pvals_to_results_handler, _pval_to_result_handler

View File

@@ -1280,7 +1280,17 @@ call_translations[core.call_p] = _call_translation_rule

 # TODO(mattjj): remove when omnistaging fully lands
-@config.omnistaging_enablers.append
+def _pval_to_result_handler(device, pval):
+  pv, const = pval
+  if pv is None:
+    const = _device_put_impl(const, device) if device else const
+    return lambda _: const
+  else:
+    return aval_to_result_handler(device, pv)
+
+pe.staged_out_calls.add(xla_call_p)
+
+@config.register_omnistaging_enabler
 def omnistaging_enabler() -> None:
   global _pval_to_result_handler
   del _pval_to_result_handler
@@ -1292,13 +1302,3 @@ def omnistaging_enabler() -> None:
     unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
     return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
   parallel_translations[core.axis_index_p] = _axis_index_translation_rule  # type: ignore
-
-def _pval_to_result_handler(device, pval):
-  pv, const = pval
-  if pv is None:
-    const = _device_put_impl(const, device) if device else const
-    return lambda _: const
-  else:
-    return aval_to_result_handler(device, pv)
-
-pe.staged_out_calls.add(xla_call_p)

View File

@@ -5962,7 +5962,7 @@ def _check_user_dtype_supported(dtype, fun_name=None):
     warnings.warn(msg.format(dtype, fun_name , truncated_dtype))

-@config.omnistaging_enablers.append
+@config.register_omnistaging_enabler
 def omnistaging_enabler() -> None:
   global _tie_in_transpose_rule, _tie_in_batch_rule, _tie_in_impl, tie_in_p
   del _tie_in_transpose_rule, _tie_in_batch_rule, _tie_in_impl, tie_in_p

View File

@@ -2437,7 +2437,7 @@ def associative_scan(fn, elems):

 # TODO(mattjj): remove when omnistaging fully lands
-@config.omnistaging_enablers.append
+@config.register_omnistaging_enabler
 def omnistaging_enabler() -> None:
   global _initial_style_untyped_jaxpr, _initial_style_jaxpr, \
     _initial_style_jaxprs_with_common_consts

View File

@@ -579,7 +579,7 @@ def all_gather(x, axis_name):
   return _allgather(x, 0, psum(1, axis_name), axis_name)

-@config.omnistaging_enablers.append
+@config.register_omnistaging_enabler
 def omnistaging_enabler() -> None:
   # We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
   # tracing time.
@@ -595,3 +595,5 @@ def omnistaging_enabler() -> None:
       return tuple(size * x for x in args)
     return core.Primitive.bind(
       psum_p, *args, axis_name=axis_name, axis_index_groups=axis_index_groups)
+
+  del pxla.parallel_pure_rules[psum_p]
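
The special bind rule referenced in the comment above lets psum of a Python constant be folded at trace time, which is commonly used to recover the mapped axis size. A hedged usage sketch (not part of this diff; assumes one or more local devices):

import jax
import jax.numpy as jnp

def normalized(x):
  n = jax.lax.psum(1, 'i')   # with the bind rule above, this resolves to the
  return x / n               # size of axis 'i' while tracing

out = jax.pmap(normalized, axis_name='i')(jnp.ones(jax.device_count()))
print(out)  # each entry is 1 / device_count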