Remove experimental_cpp_jit, since that flag is unused, and also remove experimental_cpp_pjit.

For dynamic-shapes experimentation and normal debugging, `python_pjit` still exists, so nothing is lost by removing these two flags. I am leaving pmap's flag alone for now.

PiperOrigin-RevId: 522602754
This commit is contained in:
parent b15ebb1bc5
commit 694e43a44a
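From the user's perspective nothing changes: jit/pjit simply always take the C++ fast path now. A minimal sketch (illustrative, not part of this commit):

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  return jnp.sin(x) + 1.0

# No JAX_CPP_JIT / JAX_CPP_PJIT environment toggles are consulted anymore.
print(f(jnp.arange(4.0)))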
@@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list

 ## jax 0.4.9

+* Changes
+  * The flags experimental_cpp_jit and experimental_cpp_pjit have been removed.
+    They are now always on.
+
 * Deprecations
   * `jax.experimental.gda_serialization` is deprecated and has been renamed to
     `jax.experimental.array_serialization`.
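A hedged migration sketch for the deprecation above (only the module rename is stated in this diff; any deprecation-period alias behavior is not):

# Old (deprecated):
#   from jax.experimental import gda_serialization
# New:
from jax.experimental import array_serialization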
@@ -20,7 +20,6 @@ import operator
 import google_benchmark
 import jax
 from jax import lax
-from jax._src import config as jax_config
 from jax.experimental import sparse
 from jax._src.api_util import shaped_abstractify  # technically not an api fn
 from jax._src.ad_checkpoint import checkpoint  # new jax.remat implementation
@@ -693,9 +692,6 @@ def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):

   x = [x for _ in range(num_args)]

-  prev_state = jax_config.FLAGS.experimental_cpp_pjit
-  jax_config.FLAGS.experimental_cpp_pjit = cpp_jit
-
   in_axis_resources = jax.sharding.NamedSharding(mesh, spec)
   out_axis_resources = jax.sharding.NamedSharding(mesh, spec)

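For reference, `NamedSharding` as used above pairs a device mesh with a `PartitionSpec`. A minimal standalone sketch (single-axis mesh over whatever devices are available; not part of this diff):

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over the available devices and shard along axis 'x'.
mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
sharding = NamedSharding(mesh, PartitionSpec('x'))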
@@ -713,54 +709,40 @@ def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):
   while state:
     x = f(x)

-  jax_config.FLAGS.experimental_cpp_pjit = prev_state
-

 @google_benchmark.register
-@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
-@google_benchmark.option.args([1, False])
-@google_benchmark.option.args([1, True])
-@google_benchmark.option.args([10, False])
-@google_benchmark.option.args([10, True])
-@google_benchmark.option.args([100, False])
-@google_benchmark.option.args([100, True])
+@google_benchmark.option.arg_names(['num_args'])
+@google_benchmark.option.args([1])
+@google_benchmark.option.args([10])
+@google_benchmark.option.args([100])
 def pjit_simple_1_device(state):
   pjit_simple_benchmark(
       state, num_devices=1, num_args=state.range(0), cpp_jit=state.range(1))

 @google_benchmark.register
-@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
-@google_benchmark.option.args([1, False])
-@google_benchmark.option.args([1, True])
-@google_benchmark.option.args([10, False])
-@google_benchmark.option.args([10, True])
-@google_benchmark.option.args([100, False])
-@google_benchmark.option.args([100, True])
+@google_benchmark.option.arg_names(['num_args'])
+@google_benchmark.option.args([1])
+@google_benchmark.option.args([10])
+@google_benchmark.option.args([100])
 def pjit_simple_4_device(state):
   pjit_simple_benchmark(
       state, num_devices=4, num_args=state.range(0), cpp_jit=state.range(1))

 @google_benchmark.register
-@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
-@google_benchmark.option.args([1, False])
-@google_benchmark.option.args([1, True])
-@google_benchmark.option.args([10, False])
-@google_benchmark.option.args([10, True])
-@google_benchmark.option.args([100, False])
-@google_benchmark.option.args([100, True])
+@google_benchmark.option.arg_names(['num_args'])
+@google_benchmark.option.args([1])
+@google_benchmark.option.args([10])
+@google_benchmark.option.args([100])
 def pjit_simple_4000_device(state):
   pjit_simple_benchmark(
       state, num_devices=4000, num_args=state.range(0), cpp_jit=state.range(1))


 @google_benchmark.register
-@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
-@google_benchmark.option.args([1, False])
-@google_benchmark.option.args([1, True])
-@google_benchmark.option.args([10, False])
-@google_benchmark.option.args([10, True])
-@google_benchmark.option.args([100, False])
-@google_benchmark.option.args([100, True])
+@google_benchmark.option.arg_names(['num_args'])
+@google_benchmark.option.args([1])
+@google_benchmark.option.args([10])
+@google_benchmark.option.args([100])
 def pjit_aot_1_device(state):
   pjit_simple_benchmark(
       state,
@@ -771,13 +753,10 @@ def pjit_aot_1_device(state):


 @google_benchmark.register
-@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
-@google_benchmark.option.args([1, False])
-@google_benchmark.option.args([1, True])
-@google_benchmark.option.args([10, False])
-@google_benchmark.option.args([10, True])
-@google_benchmark.option.args([100, False])
-@google_benchmark.option.args([100, True])
+@google_benchmark.option.arg_names(['num_args'])
+@google_benchmark.option.args([1])
+@google_benchmark.option.args([10])
+@google_benchmark.option.args([100])
 def pjit_aot_4_device(state):
   pjit_simple_benchmark(
       state,
@@ -788,13 +767,10 @@ def pjit_aot_4_device(state):


 @google_benchmark.register
-@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
-@google_benchmark.option.args([1, False])
-@google_benchmark.option.args([1, True])
-@google_benchmark.option.args([10, False])
-@google_benchmark.option.args([10, True])
-@google_benchmark.option.args([100, False])
-@google_benchmark.option.args([100, True])
+@google_benchmark.option.arg_names(['num_args'])
+@google_benchmark.option.args([1])
+@google_benchmark.option.args([10])
+@google_benchmark.option.args([100])
 def pjit_aot_4000_device(state):
   pjit_simple_benchmark(
       state,
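For context on the harness above: google_benchmark drives the timed region through `state`, and `state.range(i)` returns the i-th value registered with `option.args`. A self-contained sketch (benchmark name and workload are illustrative):

import google_benchmark

@google_benchmark.register
@google_benchmark.option.arg_names(['n'])
@google_benchmark.option.args([100])
@google_benchmark.option.args([10_000])
def sum_bench(state):
  n = state.range(0)  # 100 or 10_000, one registered variant each
  while state:        # each iteration of this loop is one timed run
    sum(range(n))

if __name__ == '__main__':
  google_benchmark.main()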
@@ -40,21 +40,11 @@ traceback_util.register_exclusion(__file__)

 FLAGS = flags.FLAGS

-flags.DEFINE_bool(
-    "experimental_cpp_jit", bool_env("JAX_CPP_JIT", True),
-    "A flag enabling the C++ jax.jit fast path."
-    "Set this to `False` only if it crashes otherwise and report "
-    "the error to the jax-team.")
 flags.DEFINE_bool(
     "experimental_cpp_pmap", bool_env("JAX_CPP_PMAP", True),
     "A flag enabling the C++ jax.pmap fast path. Until the default "
     "is switched to True, the feature is not supported and possibly broken "
     "(e.g. it may use unreleased code from jaxlib.")
-flags.DEFINE_bool(
-    "experimental_cpp_pjit", bool_env("JAX_CPP_PJIT", True),
-    "A flag enabling the C++ pjit fast path. Until the default "
-    "is switched to True, the feature is not supported and possibly broken "
-    "(e.g. it may use unreleased code from jaxlib.")

 map = safe_map

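The `bool_env` helper that seeds the remaining pmap flag's default parses an environment variable as a boolean. A rough re-implementation for illustration (the real helper lives in jax._src.config; the exact accepted spellings may differ):

import os

def bool_env(varname: str, default: bool) -> bool:
  # Parse e.g. JAX_CPP_PMAP=0/1/true/false into a bool, falling back
  # to `default` when the variable is unset.
  val = os.getenv(varname, str(default))
  return val.lower() in ('y', 'yes', 't', 'true', 'on', '1')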
@@ -51,7 +51,6 @@ from jax._src import util
 from jax._src import xla_bridge as xb
 from jax._src.abstract_arrays import array_types
 from jax._src.config import config
-from jax._src.config import flags
 from jax._src.core import ShapedArray
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
@@ -99,8 +98,6 @@ MeshDimAssignment = Union[ShardedAxis, Replicated]
 ShardingSpec = sharding_specs.ShardingSpec

-
-

 ### util

 def identity(x): return x
@@ -2811,9 +2808,6 @@ class MeshExecutable(stages.XlaExecutable):
         not self.unsafe_call.has_host_callbacks):
       return None

-    if not flags.FLAGS.experimental_cpp_pjit:
-      return None
-
     def aot_cache_miss(*args, **kwargs):
       params = stages.CompiledCallParams(self, no_kwargs, in_tree, out_tree)
       outs, out_flat, args_flat = stages.Compiled.call(params, *args, **kwargs)
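The `aot_cache_miss` path above backs ahead-of-time compiled calls, which after this change always use the fast path. From the user side the AOT flow looks like this (example values illustrative):

import jax
import jax.numpy as jnp

def f(x):
  return x * 2.0

# Lower and compile ahead of time, then call the compiled executable.
compiled = jax.jit(f).lower(jnp.float32(3.0)).compile()
print(compiled(jnp.float32(3.0)))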
@@ -37,7 +37,7 @@ from jax._src import xla_bridge as xb
 from jax._src.api_util import (
     argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
     donation_vector, shaped_abstractify, check_callable, resolve_argnums,
-    argnames_partial_except, debug_info, result_paths, jaxpr_debug_info, FLAGS)
+    argnames_partial_except, debug_info, result_paths, jaxpr_debug_info)
 from jax._src.errors import JAXTypeError
 from jax._src.interpreters import partial_eval as pe
 from jax._src.partition_spec import PartitionSpec
@@ -354,7 +354,7 @@ def pre_infer_params(fun, in_shardings, out_shardings,
 def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
                       donate_argnums, abstracted_axes,
                       pjit_has_explicit_sharding):
-  if FLAGS.experimental_cpp_pjit and abstracted_axes is None:
+  if abstracted_axes is None:
     wrapped = _cpp_pjit(fun, infer_params_fn, static_argnums, static_argnames,
                         donate_argnums, pjit_has_explicit_sharding)
   else:
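Net effect of the `post_infer_params` change above, summarized as a tiny illustrative helper (name hypothetical, not in the codebase):

def uses_cpp_fast_path(abstracted_axes) -> bool:
  # C++ pjit fast path whenever no abstracted axes are requested;
  # the Python implementation remains only for dynamic-shapes work.
  return abstracted_axes is None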