Removed various unused functions

To rerun the analysis, run:

    python -m vulture jax/_src --ignore-names "[A-Za-z]*" --ignore-decorators "*"
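
The `--ignore-names "[A-Za-z]*"` glob makes vulture skip every name that starts with a letter, so only underscore-prefixed (private) names are reported. As a minimal sketch of what vulture flags, consider a hypothetical module (`demo.py` is not part of this commit):

    # demo.py -- hypothetical example, not part of this commit
    def public_helper():     # starts with a letter: skipped by the glob above
        return 1

    def _private_helper():   # underscore-prefixed and never called: reported
        return 2

Running `python -m vulture demo.py --ignore-names "[A-Za-z]*"` would print something like `demo.py:5: unused function '_private_helper' (60% confidence)`.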
Sergei Lebedev 2024-08-01 11:13:08 +01:00
parent 0734345279
commit 92b1f71314
11 changed files with 2 additions and 177 deletions

View File

@@ -18,7 +18,6 @@ from functools import partial
import numpy as np
from jax._src import ad_util
from jax._src import core
from jax._src import dtypes
@@ -62,11 +61,6 @@ for t in array_types:
core.literalable_types.update(array_types)
def _zeros_like_python_scalar(t, x):
dtype = dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[t])
aval = core.ShapedArray((), dtype, weak_type=True)
return ad_util.zeros_like_aval(aval)
def _make_concrete_python_scalar(t, x):
dtype = dtypes._scalar_type_to_dtype(t, x)
weak_type = dtypes.is_weakly_typed(x)

View File

@@ -757,12 +757,6 @@ def flatten_ir_values(xs: Iterable[IrValues]) -> list[ir.Value]:
out.extend(x)
return out
_unflatten_done = object()
def _unwrap_singleton_ir_values(x): return x[0] if len(x) == 1 else x
def flatten_ir_types(xs: Iterable[IrTypes]) -> list[ir.Type]:
"""Concatenates/flattens a list of ir.Types or ir.Type sequences."""
out = []

View File

@@ -81,7 +81,6 @@ from jax._src import dtypes
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.lax import lax
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
@@ -244,21 +243,6 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension,
weak_type=operand.weak_type),
operand.update(shape=dims, dtype=np.dtype(np.int32)))
def _comparator_builder(op_type, is_max_k):
c = xc.XlaBuilder(
'top_k_{}_comparator'.format('gt' if is_max_k else 'lt'))
p0 = xla.parameter(c, 0, xc.Shape.scalar_shape(op_type))
p1 = xla.parameter(c, 1, xc.Shape.scalar_shape(op_type))
xla.parameter(c, 2, xc.Shape.scalar_shape(np.dtype(np.int32)))
xla.parameter(c, 3, xc.Shape.scalar_shape(np.dtype(np.int32)))
if is_max_k:
cmp_result = xc.ops.Gt(p0, p1)
else:
cmp_result = xc.ops.Lt(p0, p1)
return c.build(cmp_result)
def _get_init_val_literal(op_type, is_max_k):
return np.array(-np.inf if is_max_k else np.inf, dtype=op_type)

View File

@@ -1118,23 +1118,6 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts,
f'called with sequence whose items have type\n{_avals_short(x_avals_mapped)}')
return [*init_avals, *y_avals], jaxpr.effects
def _scan_pp_rule(eqn, context, settings):
printed_params = dict(eqn.params)
del printed_params['linear']
if eqn.params['num_consts'] + eqn.params['num_carry'] == len(eqn.invars):
del printed_params['length']
if printed_params['unroll'] == 1:
del printed_params['unroll']
if printed_params['num_carry'] == 0:
del printed_params['num_carry']
if printed_params['num_consts'] == 0:
del printed_params['num_consts']
if not printed_params['reverse']:
del printed_params['reverse']
if not printed_params['_split_transpose']:
del printed_params['_split_transpose']
return core._pp_eqn(eqn.replace(params=printed_params), context, settings)
def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
num_carry, linear, unroll, reverse, length,
_split_transpose):
@@ -1233,8 +1216,6 @@ pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom
pe.padding_rules[scan_p] = _scan_padding_rule
pe.dce_rules[scan_p] = _scan_dce_rule
state_discharge.register_discharge_rule(scan_p)(_scan_state_discharge_rule)
# TODO(mattjj,frostig): un-comment this pp rule
# core.pp_eqn_rules[scan_p] = _scan_pp_rule
def _propagate_mem_kind_scan(*xm, reverse, length, num_consts, num_carry, jaxpr,
linear, unroll, _split_transpose):

View File

@@ -2267,11 +2267,6 @@ def _add_transpose(t, x, y):
else:
return [_unbroadcast(x_aval, t), _unbroadcast(y_aval, t)]
def _add_inverse(r, x, y):
xr = r - y
yr = r - x
return xr, yr
# TODO(slebedev): Why does mypy fail to infer the type here?
add_p: Primitive = standard_naryop([_num, _num], 'add')
ad.primitive_jvps[add_p] = _add_jvp
@@ -2321,11 +2316,6 @@ def _mul_transpose(ct, x, y):
else:
return [None, _unbroadcast(y.aval, mul(x, ct))]
def _mul_inverse(r, x, y):
xr = r / y
yr = r / x
return xr, yr
mul_p = standard_naryop([_num, _num], 'mul')
ad.defjvp(mul_p,
lambda xdot, x, y: mul(xdot, y),
@@ -3352,15 +3342,6 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions) ->
return [mlir.broadcast_in_dim(ctx, x, aval_out,
broadcast_dimensions=broadcast_dimensions)]
def _broadcast_in_dim_pp_rule(eqn, context, settings):
# Don't print shape or trivial broadcast_dimensions in params, since it can be
# inferred from the let-binder's type annotation.
printed_params = {}
if eqn.params['broadcast_dimensions']:
printed_params['broadcast_dimensions'] = eqn.params['broadcast_dimensions']
new_eqn = eqn.replace(params=printed_params, invars=eqn.invars[:1])
return core._pp_eqn(new_eqn, context, settings)
def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions):
if (not dyn_shape and
not any(isinstance(d, core.DArray) and
@@ -3385,8 +3366,6 @@ pe.custom_staging_rules[broadcast_in_dim_p] = _broadcast_in_dim_staging_rule
pe.padding_rules[broadcast_in_dim_p] = _broadcast_in_dim_padding_rule
core.custom_typechecks[broadcast_in_dim_p] = _broadcast_in_dim_typecheck_rule
mlir.register_lowering(broadcast_in_dim_p, _broadcast_in_dim_lower)
# TODO(mattjj): un-comment the next line
# core.pp_eqn_rules[broadcast_in_dim_p] = _broadcast_in_dim_pp_rule
def _clamp_shape_rule(min, operand, max):
@@ -4161,9 +4140,6 @@ pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, _reduce_prod,
_get_prod_identity)
def _reduce_chooser_shape_rule(operand, *, axes):
return tuple(np.delete(operand.shape, axes))
def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
# TODO(mattjj): an alternative is to use variadic reduce to compute the chosen
# locations in a single pass (rather than comparing equality) and use a
@@ -4989,13 +4965,6 @@ def _iota_batching_rule(in_vals, in_dims, *, dtype, shape, dimension):
return iota, batching.RaggedAxis(ax, ((ragged_axis+1, segment_lengths),))
batching.primitive_batchers[iota_p] = _iota_batching_rule
def _iota_pp_rule(eqn, context, settings):
printed_params = {}
if len(eqn.params['shape']) > 1:
printed_params['dimension'] = eqn.params['dimension']
return core._pp_eqn(eqn.replace(params=printed_params), context, settings)
# core.pp_eqn_rules[iota_p] = _iota_pp_rule
def _iota_padding_rule(in_avals, out_avals, *dyn_shape, dtype, shape, dimension):
out_aval, = out_avals
new_shape = []

View File

@@ -841,17 +841,6 @@ def _foldaxis(axis, x):
new_shape[axis:axis+2] = [x.shape[axis] * x.shape[axis + 1]]
return x.reshape(new_shape)
def _index_in_group(axis_name, axis_index_groups):
cur_device_id = axis_index(axis_name)
if axis_index_groups is None:
return cur_device_id
# We use argsort to invert the axis_index_groups permutation
flat_groups = np.array(axis_index_groups).flatten()
device_id_to_idx = flat_groups.argsort() % len(axis_index_groups[0])
return lax.squeeze(
slicing.dynamic_slice_in_dim(device_id_to_idx, cur_device_id, 1), [0])
def _all_to_all_lowering(
ctx, x, *, split_axis, concat_axis, axis_name, axis_index_groups, tiled
):
@@ -1070,18 +1059,6 @@ def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False):
axis_size=axis_size, tiled=tiled)
return tree_util.tree_map(bind, x)
def _expand(dim, size, index, tiled, x):
shape = list(x.shape)
if tiled:
tile_size = shape[dim]
shape[dim] *= size
out = lax.full(shape, lax._const(x, 0))
return slicing.dynamic_update_slice_in_dim(out, x, index * tile_size, dim)
else:
shape.insert(dim, size)
out = lax.full(shape, lax._const(x, 0))
return slicing.dynamic_update_index_in_dim(out, x, index, dim)
def _all_gather_impl(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled):
raise AssertionError("Unexpected call to _all_gather_impl")
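
Aside: the removed `_index_in_group` above inverts the `axis_index_groups` permutation via `argsort`. A minimal sketch of that trick, with illustrative values not taken from this commit:

    import numpy as np

    # Two groups of two devices; the device ids form a permutation of 0..3.
    axis_index_groups = [[2, 0], [1, 3]]
    flat_groups = np.array(axis_index_groups).flatten()  # [2 0 1 3]
    # argsort maps each device id to its position in flat_groups,
    # and % group_size turns that into an index within the group.
    device_id_to_idx = flat_groups.argsort() % len(axis_index_groups[0])
    print(device_id_to_idx)  # [1 0 0 1]

Device 0 sits second in group `[2, 0]`, hence index 1; device 1 sits first in `[1, 3]`, hence index 0.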

View File

@@ -569,13 +569,6 @@ def _convert_block_spec_to_block_mapping(
mapping.check_invariants()
return mapping
def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
) -> state.AbstractRef:
if block_shape is None:
return ref
shape = tuple(s for s in block_shape if s is not None)
return ref.update(inner_aval=ref.inner_aval.update(shape=shape))
index_map_grid_aval = jax_core.ShapedArray((), jnp.int32)
@dataclasses.dataclass(init=False)

View File

@@ -847,27 +847,6 @@ def _ensure_mlir_value(val, aval):
)
def _convert_flat_indexing_to_indexer(ref_aval, non_slice_idx,
non_slice_idx_avals, indexed_dims):
non_slice_idx_iter = iter(zip(non_slice_idx, non_slice_idx_avals))
splatted_idx_idx_avals = tuple(
next(non_slice_idx_iter)
if indexed
else (primitives.Slice(0, s), primitives.Slice(0, s))
for s, indexed in zip(ref_aval.shape, indexed_dims)
)
splatted_idx, splatted_idx_avals = unzip2(splatted_idx_idx_avals)
if non_slice_idx:
(int_indexer_shape,) = {idx_aval.shape for idx_aval in splatted_idx_avals
if not isinstance(idx_aval, primitives.Slice)}
else:
int_indexer_shape = ()
nd_indexer = NDIndexer(splatted_idx, ref_aval.shape, int_indexer_shape)
nd_indexer_avals = NDIndexer(splatted_idx_avals, ref_aval.shape,
int_indexer_shape)
return nd_indexer, nd_indexer_avals
def _get_lowering_rule(
ctx: LoweringRuleContext, ref, *idx, tree,
):
@@ -2151,27 +2130,6 @@ def _lower_jaxpr_to_for_loop(ctx: LoweringRuleContext,
return for_op.results
def _lower_jaxpr_to_unrolled_for_loop(ctx: LoweringRuleContext,
jaxpr: jax_core.Jaxpr, start: int,
num_steps: int, consts, *args,
has_loop_index: bool):
for i in range(start, start + num_steps):
if has_loop_index:
lowering_context = ctx.lowering_context.replace(
block_shapes=ctx.block_shapes)
args = jaxpr_subcomp(
lowering_context, jaxpr, *consts,
ir_constant(i, mlir_type=_dtype_to_ir_type(jnp.dtype('int32'))),
*args)
else:
lowering_context = ctx.lowering_context.replace(
block_shapes=ctx.block_shapes[:len(consts)]
+ ctx.block_shapes[len(consts) + 1:],
)
args = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
return args
def _scan_lowering_rule(
ctx: LoweringRuleContext,
*args,

View File

@@ -71,10 +71,6 @@ def _mul(scalar, tree):
return tree_map(partial(operator.mul, scalar), tree)
def _div(tree, scalar):
return tree_map(partial(lambda v: v / scalar), tree)
_add = partial(tree_map, operator.add)
_sub = partial(tree_map, operator.sub)
_dot_tree = partial(tree_map, _dot)

View File

@@ -286,12 +286,6 @@ def _get_discharge(x, idx, tree):
indexers = tree_util.tree_unflatten(tree, idx)
return index_array(x, indexers)
def _indexer(idx, indexed_dims):
idx_ = iter(idx)
indexer = tuple(next(idx_) if b else slice(None) for b in indexed_dims)
assert next(idx_, None) is None
return indexer
@register_discharge_rule(swap_p)
def _swap_discharge_rule(
in_avals: Sequence[core.AbstractValue],

View File

@@ -15,6 +15,7 @@
from __future__ import annotations
from functools import partial
import types
from typing import Any, Union
import numpy as np
@@ -56,18 +57,7 @@ zip, unsafe_zip = safe_zip, zip
get_p = core.Primitive("get")
get_p.def_impl(partial(dispatch.apply_primitive, get_p))
Indexer = tuple[Union[int, slice, Array], ...]
# or Ellipsis, but that can't be annotated until Python 3.10? (types.EllipsisType)
def _get_slice_output_shape(in_shape: tuple[int, ...],
idx_shapes: tuple[tuple[int, ...], ...],
indexed_dims: tuple[bool, ...]) -> tuple[int, ...]:
shape_suffix = [d for i, d in zip(indexed_dims, in_shape) if not i]
shape_prefix, = set(idx_shapes) or [()] # tie fighter
# Move shape prefix dimensions to the front
shape = (*shape_prefix, *shape_suffix)
return shape
Indexer = tuple[Union[int, slice, Array, types.EllipsisType], ...]
def get_ref_and_indexers(
ref_or_view: Any, idx: Indexer | None, function_name: str
@@ -408,11 +398,6 @@ pe.partial_eval_jaxpr_custom_rules[addupdate_p] = partial(
## get/swap/addupdate batching rules
def _output_bdim(indexed_dims: tuple[bool, ...], ref_dim: int,
idxs_shape: tuple[int, ...]):
num_idxs_to_left = sum(indexed_dims[:ref_dim])
return ref_dim - num_idxs_to_left + len(idxs_shape)
def _batch_indexer(indexer: indexing.NDIndexer, dims,
axis_size: int,
ref_shape: tuple[int, ...],