1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 21:36:05 +00:00

Remove experimental_cpp_jit since that flag is unused and also remove experimental_cpp_pjit.

For dynamic shapes experimentation and normal debugging, `python_pjit` still exists so that problem doesn't exist which makes us free to remove these 2 flags.

I am leaving pmap's flag alone for now.

PiperOrigin-RevId: 522602754
This commit is contained in:
Yash Katariya 2023-04-07 08:28:46 -07:00 committed by jax authors
parent b15ebb1bc5
commit 694e43a44a
5 changed files with 30 additions and 66 deletions

@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.9
* Changes
* The flags experimental_cpp_jit, and experimental_cpp_pjit have been removed.
They are now always on.
* Deprecations
* `jax.experimental.gda_serialization` is deprecated and has been renamed to
`jax.experimental.array_serialization`.

@ -20,7 +20,6 @@ import operator
import google_benchmark
import jax
from jax import lax
from jax._src import config as jax_config
from jax.experimental import sparse
from jax._src.api_util import shaped_abstractify # technically not an api fn
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
@ -693,9 +692,6 @@ def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):
x = [x for _ in range(num_args)]
prev_state = jax_config.FLAGS.experimental_cpp_pjit
jax_config.FLAGS.experimental_cpp_pjit = cpp_jit
in_axis_resources = jax.sharding.NamedSharding(mesh, spec)
out_axis_resources = jax.sharding.NamedSharding(mesh, spec)
@ -713,54 +709,40 @@ def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):
while state:
x = f(x)
jax_config.FLAGS.experimental_cpp_pjit = prev_state
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([1])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
def pjit_simple_1_device(state):
pjit_simple_benchmark(
state, num_devices=1, num_args=state.range(0), cpp_jit=state.range(1))
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([1])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
def pjit_simple_4_device(state):
pjit_simple_benchmark(
state, num_devices=4, num_args=state.range(0), cpp_jit=state.range(1))
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([1])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
def pjit_simple_4000_device(state):
pjit_simple_benchmark(
state, num_devices=4000, num_args=state.range(0), cpp_jit=state.range(1))
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([1])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
def pjit_aot_1_device(state):
pjit_simple_benchmark(
state,
@ -771,13 +753,10 @@ def pjit_aot_1_device(state):
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([1])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
def pjit_aot_4_device(state):
pjit_simple_benchmark(
state,
@ -788,13 +767,10 @@ def pjit_aot_4_device(state):
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@google_benchmark.option.arg_names(['num_args'])
@google_benchmark.option.args([1])
@google_benchmark.option.args([10])
@google_benchmark.option.args([100])
def pjit_aot_4000_device(state):
pjit_simple_benchmark(
state,

@ -40,21 +40,11 @@ traceback_util.register_exclusion(__file__)
FLAGS = flags.FLAGS
flags.DEFINE_bool(
"experimental_cpp_jit", bool_env("JAX_CPP_JIT", True),
"A flag enabling the C++ jax.jit fast path."
"Set this to `False` only if it crashes otherwise and report "
"the error to the jax-team.")
flags.DEFINE_bool(
"experimental_cpp_pmap", bool_env("JAX_CPP_PMAP", True),
"A flag enabling the C++ jax.pmap fast path. Until the default "
"is switched to True, the feature is not supported and possibly broken "
"(e.g. it may use unreleased code from jaxlib.")
flags.DEFINE_bool(
"experimental_cpp_pjit", bool_env("JAX_CPP_PJIT", True),
"A flag enabling the C++ pjit fast path. Until the default "
"is switched to True, the feature is not supported and possibly broken "
"(e.g. it may use unreleased code from jaxlib.")
map = safe_map

@ -51,7 +51,6 @@ from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.abstract_arrays import array_types
from jax._src.config import config
from jax._src.config import flags
from jax._src.core import ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@ -99,8 +98,6 @@ MeshDimAssignment = Union[ShardedAxis, Replicated]
ShardingSpec = sharding_specs.ShardingSpec
### util
def identity(x): return x
@ -2811,9 +2808,6 @@ class MeshExecutable(stages.XlaExecutable):
not self.unsafe_call.has_host_callbacks):
return None
if not flags.FLAGS.experimental_cpp_pjit:
return None
def aot_cache_miss(*args, **kwargs):
params = stages.CompiledCallParams(self, no_kwargs, in_tree, out_tree)
outs, out_flat, args_flat = stages.Compiled.call(params, *args, **kwargs)

@ -37,7 +37,7 @@ from jax._src import xla_bridge as xb
from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
donation_vector, shaped_abstractify, check_callable, resolve_argnums,
argnames_partial_except, debug_info, result_paths, jaxpr_debug_info, FLAGS)
argnames_partial_except, debug_info, result_paths, jaxpr_debug_info)
from jax._src.errors import JAXTypeError
from jax._src.interpreters import partial_eval as pe
from jax._src.partition_spec import PartitionSpec
@ -354,7 +354,7 @@ def pre_infer_params(fun, in_shardings, out_shardings,
def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
donate_argnums, abstracted_axes,
pjit_has_explicit_sharding):
if FLAGS.experimental_cpp_pjit and abstracted_axes is None:
if abstracted_axes is None:
wrapped = _cpp_pjit(fun, infer_params_fn, static_argnums, static_argnames,
donate_argnums, pjit_has_explicit_sharding)
else: