Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
Removed unused which_linear= param from pallas_call_p
As far as I can tell, it was threaded through everywhere, but never actually used.

PiperOrigin-RevId: 644457293
This commit is contained in:
parent c8cdf303fb
commit 2bb80d540c
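The removal is mechanical: every signature and bind call that carried which_linear drops it, and the single place that produced it (which_linear = (False,) * len(flat_args) inside pallas_call) disappears. Below is a minimal, hypothetical sketch of the pattern being cleaned up; it is not JAX source, and lowering_rule, bind, and call are stand-ins for the real lowering rules and pallas_call_p.bind shown in the diff.

def lowering_rule(*, jaxpr, name, which_linear=None, **params):
  # The consumer accepts which_linear but never reads it.
  del which_linear  # unused, so callers can safely stop passing it
  return jaxpr, name, params

def bind(**params):
  # Stand-in for pallas_call_p.bind: it only forwards parameters.
  return lowering_rule(**params)

def call(flat_args, jaxpr, name):
  # Before this commit the caller built which_linear = (False,) * len(flat_args)
  # and threaded it through bind; afterwards it is simply not passed, since no
  # consumer ever looked at it.
  return bind(jaxpr=jaxpr, name=name)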
@@ -36,7 +36,6 @@ def pallas_call_tpu_lowering_rule(
     ctx: mlir.LoweringRuleContext, *in_nodes,
     jaxpr: jax_core.Jaxpr,
     name: str,
-    which_linear: tuple[bool, ...],
     grid_mapping: core.GridMapping,
     input_output_aliases: tuple[tuple[int, int], ...],
     in_shapes: tuple[jax.ShapeDtypeStruct, ...],
@@ -49,7 +48,6 @@ def pallas_call_tpu_lowering_rule(
   return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
       ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
       in_shapes=in_shapes,
-      which_linear=which_linear,
       interpret=interpret, debug=debug,
       input_output_aliases=input_output_aliases,
       grid_mapping=grid_mapping,
@@ -35,7 +35,6 @@ def pallas_call_lowering(
     name: str,
     in_shapes: tuple[jax.ShapeDtypeStruct, ...],
     out_shapes: tuple[jax.ShapeDtypeStruct, ...],
-    which_linear: tuple[bool, ...],
     interpret: bool,
     debug: bool,
     input_output_aliases: tuple[tuple[int, int], ...],
@@ -50,7 +49,6 @@ def pallas_call_lowering(
       name=name,
       out_shapes=out_shapes,
       in_shapes=in_shapes,
-      which_linear=which_linear,
       interpret=interpret,
       debug=debug,
       input_output_aliases=input_output_aliases,
@@ -163,7 +163,7 @@ def _get_next_indices(grid, indices):
     next_indices.append(jnp.where(carry, 0, i))
   return tuple(reversed(next_indices))
 
-def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
+def _pallas_call_impl(*args, jaxpr, name, out_shapes,
     interpret, debug: bool,
     in_shapes,
     input_output_aliases: tuple[tuple[int, int], ...],
@@ -294,7 +294,7 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
     return out_nopad
   return xla.apply_primitive(pallas_call_p, *args, jaxpr=jaxpr, name=name,
       in_shapes=in_shapes,
-      out_shapes=out_shapes, which_linear=which_linear,
+      out_shapes=out_shapes,
       grid_mapping=grid_mapping, interpret=interpret,
       debug=debug,
       input_output_aliases=input_output_aliases,
@@ -305,7 +305,7 @@ def _pallas_call_abstract_eval(*avals, out_shapes, **_):
   return map(lambda x: jax_core.ShapedArray(x.shape, x.dtype), out_shapes)
 pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)
 
-def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
+def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name,
     input_output_aliases: tuple[tuple[int, int], ...],
     in_shapes, out_shapes, grid_mapping, debug, interpret, compiler_params: Any):
   if grid_mapping.num_dynamic_grid_bounds:
@@ -351,7 +351,6 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name, which_linear,
       in_shapes=(*in_shapes, *in_shapes),
       out_shapes=(*out_shapes, *out_shapes),
       grid_mapping=grid_mapping.replace(block_mappings=jvp_bms),
-      which_linear=which_linear + (True,) * len(tangents),
       interpret=interpret,
       debug=debug,
       input_output_aliases=(),
@@ -439,7 +438,6 @@ def _batch_with_explicit_loop(
     input_output_aliases: tuple[tuple[int, int], ...],
     debug: bool,
     interpret: bool,
-    which_linear: tuple[bool, ...],
     compiler_params: Any,
 ):
   """Batch the pallas_call by calling it in loop over the batch size.
@@ -506,7 +504,6 @@ def _batch_with_explicit_loop(
         name=name,
         in_shapes=in_shapes,
         out_shapes=out_shapes,
-        which_linear=which_linear,
         grid_mapping=grid_mapping,
         input_output_aliases=input_output_aliases,
         debug=debug,
@@ -540,7 +537,6 @@ def _pallas_call_batching_rule(
     input_output_aliases: tuple[tuple[int, int], ...],
     debug: bool,
     interpret: bool,
-    which_linear: tuple[bool, ...],
     compiler_params: Any,
 ):
 
@@ -562,7 +558,6 @@ def _pallas_call_batching_rule(
         name=name,
         in_shapes=in_shapes,
         out_shapes=out_shapes,
-        which_linear=which_linear,
         grid_mapping=grid_mapping,
         input_output_aliases=input_output_aliases,
         debug=debug,
@@ -596,7 +591,6 @@ def _pallas_call_batching_rule(
         name=name,
         in_shapes=in_shapes,
         out_shapes=out_shapes,
-        which_linear=which_linear,
         grid_mapping=grid_mapping,
         input_output_aliases=input_output_aliases,
         debug=debug,
@@ -631,7 +625,6 @@ def _pallas_call_batching_rule(
         name=name,
         in_shapes=in_shapes,
         out_shapes=out_shapes,
-        which_linear=which_linear,
         grid_mapping=grid_mapping,
         input_output_aliases=input_output_aliases,
         debug=debug,
@@ -690,7 +683,6 @@ def _pallas_call_batching_rule(
       name=f"batched_{name}",
       in_shapes=batched_in_shapes,
       out_shapes=batched_out_shapes,
-      which_linear=which_linear,
       grid_mapping=batched_grid_mapping,
       input_output_aliases=input_output_aliases,
       debug=debug,
@@ -1040,10 +1032,9 @@ def pallas_call(
     grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr(
         f, grid_spec, flat_in_avals, flat_out_avals, in_tree,
         out_tree, interpret=interpret)
-    which_linear = (False,) * len(flat_args)
     out_flat = pallas_call_p.bind(
         *dynamic_grid_bounds, *consts, *flat_args,
-        jaxpr=jaxpr, name=name, which_linear=which_linear,
+        jaxpr=jaxpr, name=name,
         in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
                         for a in flat_args),
         out_shapes=tuple(flat_out_shapes), debug=debug,
@@ -47,7 +47,6 @@ def pallas_call_lowering(
     name: str,
     in_shapes: tuple[jax.ShapeDtypeStruct, ...],
     out_shapes: tuple[jax.ShapeDtypeStruct, ...],
-    which_linear: tuple[bool, ...],
     interpret: bool,
     debug: bool,
     input_output_aliases: tuple[tuple[int, int], ...],
@@ -62,7 +61,6 @@ def pallas_call_lowering(
       name=name,
       out_shapes=out_shapes,
       in_shapes=in_shapes,
-      which_linear=which_linear,
       interpret=interpret,
       debug=debug,
       input_output_aliases=input_output_aliases,