Removed unused which_linear= param from pallas_call_p

As far as I can tell, it was threaded through everywhere, but never actually used.

PiperOrigin-RevId: 644457293
This commit is contained in:
Sergei Lebedev 2024-06-18 11:29:21 -07:00 committed by jax authors
parent c8cdf303fb
commit 2bb80d540c
4 changed files with 4 additions and 19 deletions

View File

@ -36,7 +36,6 @@ def pallas_call_tpu_lowering_rule(
ctx: mlir.LoweringRuleContext, *in_nodes,
jaxpr: jax_core.Jaxpr,
name: str,
which_linear: tuple[bool, ...],
grid_mapping: core.GridMapping,
input_output_aliases: tuple[tuple[int, int], ...],
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
@ -49,7 +48,6 @@ def pallas_call_tpu_lowering_rule(
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
in_shapes=in_shapes,
which_linear=which_linear,
interpret=interpret, debug=debug,
input_output_aliases=input_output_aliases,
grid_mapping=grid_mapping,

View File

@ -35,7 +35,6 @@ def pallas_call_lowering(
name: str,
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
which_linear: tuple[bool, ...],
interpret: bool,
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
@ -50,7 +49,6 @@ def pallas_call_lowering(
name=name,
out_shapes=out_shapes,
in_shapes=in_shapes,
which_linear=which_linear,
interpret=interpret,
debug=debug,
input_output_aliases=input_output_aliases,

View File

@ -163,7 +163,7 @@ def _get_next_indices(grid, indices):
next_indices.append(jnp.where(carry, 0, i))
return tuple(reversed(next_indices))
def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
def _pallas_call_impl(*args, jaxpr, name, out_shapes,
interpret, debug: bool,
in_shapes,
input_output_aliases: tuple[tuple[int, int], ...],
@ -294,7 +294,7 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
return out_nopad
return xla.apply_primitive(pallas_call_p, *args, jaxpr=jaxpr, name=name,
in_shapes=in_shapes,
out_shapes=out_shapes, which_linear=which_linear,
out_shapes=out_shapes,
grid_mapping=grid_mapping, interpret=interpret,
debug=debug,
input_output_aliases=input_output_aliases,
@ -305,7 +305,7 @@ def _pallas_call_abstract_eval(*avals, out_shapes, **_):
return map(lambda x: jax_core.ShapedArray(x.shape, x.dtype), out_shapes)
pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)
def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
input_output_aliases: tuple[tuple[int, int], ...],
in_shapes, out_shapes, grid_mapping, debug, interpret, compiler_params: Any):
if grid_mapping.num_dynamic_grid_bounds:
@ -351,7 +351,6 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
in_shapes=(*in_shapes, *in_shapes),
out_shapes=(*out_shapes, *out_shapes),
grid_mapping=grid_mapping.replace(block_mappings=jvp_bms),
which_linear=which_linear + (True,) * len(tangents),
interpret=interpret,
debug=debug,
input_output_aliases=(),
@ -439,7 +438,6 @@ def _batch_with_explicit_loop(
input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
interpret: bool,
which_linear: tuple[bool, ...],
compiler_params: Any,
):
"""Batch the pallas_call by calling it in loop over the batch size.
@ -506,7 +504,6 @@ def _batch_with_explicit_loop(
name=name,
in_shapes=in_shapes,
out_shapes=out_shapes,
which_linear=which_linear,
grid_mapping=grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
@ -540,7 +537,6 @@ def _pallas_call_batching_rule(
input_output_aliases: tuple[tuple[int, int], ...],
debug: bool,
interpret: bool,
which_linear: tuple[bool, ...],
compiler_params: Any,
):
@ -562,7 +558,6 @@ def _pallas_call_batching_rule(
name=name,
in_shapes=in_shapes,
out_shapes=out_shapes,
which_linear=which_linear,
grid_mapping=grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
@ -596,7 +591,6 @@ def _pallas_call_batching_rule(
name=name,
in_shapes=in_shapes,
out_shapes=out_shapes,
which_linear=which_linear,
grid_mapping=grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
@ -631,7 +625,6 @@ def _pallas_call_batching_rule(
name=name,
in_shapes=in_shapes,
out_shapes=out_shapes,
which_linear=which_linear,
grid_mapping=grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
@ -690,7 +683,6 @@ def _pallas_call_batching_rule(
name=f"batched_{name}",
in_shapes=batched_in_shapes,
out_shapes=batched_out_shapes,
which_linear=which_linear,
grid_mapping=batched_grid_mapping,
input_output_aliases=input_output_aliases,
debug=debug,
@ -1040,10 +1032,9 @@ def pallas_call(
grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr(
f, grid_spec, flat_in_avals, flat_out_avals, in_tree,
out_tree, interpret=interpret)
which_linear = (False,) * len(flat_args)
out_flat = pallas_call_p.bind(
*dynamic_grid_bounds, *consts, *flat_args,
jaxpr=jaxpr, name=name, which_linear=which_linear,
jaxpr=jaxpr, name=name,
in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
for a in flat_args),
out_shapes=tuple(flat_out_shapes), debug=debug,

View File

@ -47,7 +47,6 @@ def pallas_call_lowering(
name: str,
in_shapes: tuple[jax.ShapeDtypeStruct, ...],
out_shapes: tuple[jax.ShapeDtypeStruct, ...],
which_linear: tuple[bool, ...],
interpret: bool,
debug: bool,
input_output_aliases: tuple[tuple[int, int], ...],
@ -62,7 +61,6 @@ def pallas_call_lowering(
name=name,
out_shapes=out_shapes,
in_shapes=in_shapes,
which_linear=which_linear,
interpret=interpret,
debug=debug,
input_output_aliases=input_output_aliases,