mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Removed various ununsed functions
To rerun the analysis do python -m vulture jax/_src --ignore-names "[A-Za-z]*" --ignore-decorators "*"
This commit is contained in:
parent
0734345279
commit
92b1f71314
@ -18,7 +18,6 @@ from functools import partial
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
|
||||
@ -62,11 +61,6 @@ for t in array_types:
|
||||
|
||||
core.literalable_types.update(array_types)
|
||||
|
||||
def _zeros_like_python_scalar(t, x):
|
||||
dtype = dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[t])
|
||||
aval = core.ShapedArray((), dtype, weak_type=True)
|
||||
return ad_util.zeros_like_aval(aval)
|
||||
|
||||
def _make_concrete_python_scalar(t, x):
|
||||
dtype = dtypes._scalar_type_to_dtype(t, x)
|
||||
weak_type = dtypes.is_weakly_typed(x)
|
||||
|
@ -757,12 +757,6 @@ def flatten_ir_values(xs: Iterable[IrValues]) -> list[ir.Value]:
|
||||
out.extend(x)
|
||||
return out
|
||||
|
||||
|
||||
_unflatten_done = object()
|
||||
|
||||
def _unwrap_singleton_ir_values(x): return x[0] if len(x) == 1 else x
|
||||
|
||||
|
||||
def flatten_ir_types(xs: Iterable[IrTypes]) -> list[ir.Type]:
|
||||
"""Concatenates/flattens a list of ir.Types or ir.Type sequences."""
|
||||
out = []
|
||||
|
@ -81,7 +81,6 @@ from jax._src import dtypes
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib.mlir import ir
|
||||
@ -244,21 +243,6 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension,
|
||||
weak_type=operand.weak_type),
|
||||
operand.update(shape=dims, dtype=np.dtype(np.int32)))
|
||||
|
||||
|
||||
def _comparator_builder(op_type, is_max_k):
|
||||
c = xc.XlaBuilder(
|
||||
'top_k_{}_comparator'.format('gt' if is_max_k else 'lt'))
|
||||
p0 = xla.parameter(c, 0, xc.Shape.scalar_shape(op_type))
|
||||
p1 = xla.parameter(c, 1, xc.Shape.scalar_shape(op_type))
|
||||
xla.parameter(c, 2, xc.Shape.scalar_shape(np.dtype(np.int32)))
|
||||
xla.parameter(c, 3, xc.Shape.scalar_shape(np.dtype(np.int32)))
|
||||
if is_max_k:
|
||||
cmp_result = xc.ops.Gt(p0, p1)
|
||||
else:
|
||||
cmp_result = xc.ops.Lt(p0, p1)
|
||||
return c.build(cmp_result)
|
||||
|
||||
|
||||
def _get_init_val_literal(op_type, is_max_k):
|
||||
return np.array(-np.inf if is_max_k else np.inf, dtype=op_type)
|
||||
|
||||
|
@ -1118,23 +1118,6 @@ def _scan_typecheck(bind_time, *in_atoms, reverse, length, num_consts,
|
||||
f'called with sequence whose items have type\n{_avals_short(x_avals_mapped)}')
|
||||
return [*init_avals, *y_avals], jaxpr.effects
|
||||
|
||||
def _scan_pp_rule(eqn, context, settings):
|
||||
printed_params = dict(eqn.params)
|
||||
del printed_params['linear']
|
||||
if eqn.params['num_consts'] + eqn.params['num_carry'] == len(eqn.invars):
|
||||
del printed_params['length']
|
||||
if printed_params['unroll'] == 1:
|
||||
del printed_params['unroll']
|
||||
if printed_params['num_carry'] == 0:
|
||||
del printed_params['num_carry']
|
||||
if printed_params['num_consts'] == 0:
|
||||
del printed_params['num_consts']
|
||||
if not printed_params['reverse']:
|
||||
del printed_params['reverse']
|
||||
if not printed_params['_split_transpose']:
|
||||
del printed_params['_split_transpose']
|
||||
return core._pp_eqn(eqn.replace(params=printed_params), context, settings)
|
||||
|
||||
def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
|
||||
num_carry, linear, unroll, reverse, length,
|
||||
_split_transpose):
|
||||
@ -1233,8 +1216,6 @@ pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom
|
||||
pe.padding_rules[scan_p] = _scan_padding_rule
|
||||
pe.dce_rules[scan_p] = _scan_dce_rule
|
||||
state_discharge.register_discharge_rule(scan_p)(_scan_state_discharge_rule)
|
||||
# TODO(mattjj,frostig): un-comment this pp rule
|
||||
# core.pp_eqn_rules[scan_p] = _scan_pp_rule
|
||||
|
||||
def _propagate_mem_kind_scan(*xm, reverse, length, num_consts, num_carry, jaxpr,
|
||||
linear, unroll, _split_transpose):
|
||||
|
@ -2267,11 +2267,6 @@ def _add_transpose(t, x, y):
|
||||
else:
|
||||
return [_unbroadcast(x_aval, t), _unbroadcast(y_aval, t)]
|
||||
|
||||
def _add_inverse(r, x, y):
|
||||
xr = r - y
|
||||
yr = r - x
|
||||
return xr, yr
|
||||
|
||||
# TODO(slebedev): Why does mypy fail to infer the type here?
|
||||
add_p: Primitive = standard_naryop([_num, _num], 'add')
|
||||
ad.primitive_jvps[add_p] = _add_jvp
|
||||
@ -2321,11 +2316,6 @@ def _mul_transpose(ct, x, y):
|
||||
else:
|
||||
return [None, _unbroadcast(y.aval, mul(x, ct))]
|
||||
|
||||
def _mul_inverse(r, x, y):
|
||||
xr = r / y
|
||||
yr = r / x
|
||||
return xr, yr
|
||||
|
||||
mul_p = standard_naryop([_num, _num], 'mul')
|
||||
ad.defjvp(mul_p,
|
||||
lambda xdot, x, y: mul(xdot, y),
|
||||
@ -3352,15 +3342,6 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions) ->
|
||||
return [mlir.broadcast_in_dim(ctx, x, aval_out,
|
||||
broadcast_dimensions=broadcast_dimensions)]
|
||||
|
||||
def _broadcast_in_dim_pp_rule(eqn, context, settings):
|
||||
# Don't print shape or trivial broadcast_dimensions in params, since it can be
|
||||
# inferred from the let-binder's type annotation.
|
||||
printed_params = {}
|
||||
if eqn.params['broadcast_dimensions']:
|
||||
printed_params['broadcast_dimensions'] = eqn.params['broadcast_dimensions']
|
||||
new_eqn = eqn.replpace(params=printed_params, invars=eqn.invars[:1])
|
||||
return core._pp_eqn(new_eqn, context, settings)
|
||||
|
||||
def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions):
|
||||
if (not dyn_shape and
|
||||
not any(isinstance(d, core.DArray) and
|
||||
@ -3385,8 +3366,6 @@ pe.custom_staging_rules[broadcast_in_dim_p] = _broadcast_in_dim_staging_rule
|
||||
pe.padding_rules[broadcast_in_dim_p] = _broadcast_in_dim_padding_rule
|
||||
core.custom_typechecks[broadcast_in_dim_p] = _broadcast_in_dim_typecheck_rule
|
||||
mlir.register_lowering(broadcast_in_dim_p, _broadcast_in_dim_lower)
|
||||
# TODO(mattjj): un-comment the next line
|
||||
# core.pp_eqn_rules[broadcast_in_dim_p] = _broadcast_in_dim_pp_rule
|
||||
|
||||
|
||||
def _clamp_shape_rule(min, operand, max):
|
||||
@ -4161,9 +4140,6 @@ pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, _reduce_prod,
|
||||
_get_prod_identity)
|
||||
|
||||
|
||||
def _reduce_chooser_shape_rule(operand, *, axes):
|
||||
return tuple(np.delete(operand.shape, axes))
|
||||
|
||||
def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
|
||||
# TODO(mattjj): an alternative is to use variadic reduce to compute the chosen
|
||||
# locations in a single pass (rather than comparing equality) and use a
|
||||
@ -4989,13 +4965,6 @@ def _iota_batching_rule(in_vals, in_dims, *, dtype, shape, dimension):
|
||||
return iota, batching.RaggedAxis(ax, ((ragged_axis+1, segment_lengths),))
|
||||
batching.primitive_batchers[iota_p] = _iota_batching_rule
|
||||
|
||||
def _iota_pp_rule(eqn, context, settings):
|
||||
printed_params = {}
|
||||
if len(eqn.params['shape']) > 1:
|
||||
printed_params['dimension'] = eqn.params['dimension']
|
||||
return core._pp_eqn(eqn.replace(params=printed_params), context, settings)
|
||||
# core.pp_eqn_rules[iota_p] = _iota_pp_rule
|
||||
|
||||
def _iota_padding_rule(in_avals, out_avals, *dyn_shape, dtype, shape, dimension):
|
||||
out_aval, = out_avals
|
||||
new_shape = []
|
||||
|
@ -841,17 +841,6 @@ def _foldaxis(axis, x):
|
||||
new_shape[axis:axis+2] = [x.shape[axis] * x.shape[axis + 1]]
|
||||
return x.reshape(new_shape)
|
||||
|
||||
def _index_in_group(axis_name, axis_index_groups):
|
||||
cur_device_id = axis_index(axis_name)
|
||||
if axis_index_groups is None:
|
||||
return cur_device_id
|
||||
# We use argsort to invert the axis_index_groups permutation
|
||||
flat_groups = np.array(axis_index_groups).flatten()
|
||||
device_id_to_idx = flat_groups.argsort() % len(axis_index_groups[0])
|
||||
return lax.squeeze(
|
||||
slicing.dynamic_slice_in_dim(device_id_to_idx, cur_device_id, 1), [0])
|
||||
|
||||
|
||||
def _all_to_all_lowering(
|
||||
ctx, x, *, split_axis, concat_axis, axis_name, axis_index_groups, tiled
|
||||
):
|
||||
@ -1070,18 +1059,6 @@ def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False):
|
||||
axis_size=axis_size, tiled=tiled)
|
||||
return tree_util.tree_map(bind, x)
|
||||
|
||||
def _expand(dim, size, index, tiled, x):
|
||||
shape = list(x.shape)
|
||||
if tiled:
|
||||
tile_size = shape[dim]
|
||||
shape[dim] *= size
|
||||
out = lax.full(shape, lax._const(x, 0))
|
||||
return slicing.dynamic_update_slice_in_dim(out, x, index * tile_size, dim)
|
||||
else:
|
||||
shape.insert(dim, size)
|
||||
out = lax.full(shape, lax._const(x, 0))
|
||||
return slicing.dynamic_update_index_in_dim(out, x, index, dim)
|
||||
|
||||
def _all_gather_impl(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled):
|
||||
raise AssertionError("Unexpected call to _all_gather_impl")
|
||||
|
||||
|
@ -569,13 +569,6 @@ def _convert_block_spec_to_block_mapping(
|
||||
mapping.check_invariants()
|
||||
return mapping
|
||||
|
||||
def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
|
||||
) -> state.AbstractRef:
|
||||
if block_shape is None:
|
||||
return ref
|
||||
shape = tuple(s for s in block_shape if s is not None)
|
||||
return ref.update(inner_aval=ref.inner_aval.update(shape=shape))
|
||||
|
||||
index_map_grid_aval = jax_core.ShapedArray((), jnp.int32)
|
||||
|
||||
@dataclasses.dataclass(init=False)
|
||||
|
@ -847,27 +847,6 @@ def _ensure_mlir_value(val, aval):
|
||||
)
|
||||
|
||||
|
||||
def _convert_flat_indexing_to_indexer(ref_aval, non_slice_idx,
|
||||
non_slice_idx_avals, indexed_dims):
|
||||
non_slice_idx_iter = iter(zip(non_slice_idx, non_slice_idx_avals))
|
||||
splatted_idx_idx_avals = tuple(
|
||||
next(non_slice_idx_iter)
|
||||
if indexed
|
||||
else (primitives.Slice(0, s), primitives.Slice(0, s))
|
||||
for s, indexed in zip(ref_aval.shape,indexed_dims)
|
||||
)
|
||||
splatted_idx, splatted_idx_avals = unzip2(splatted_idx_idx_avals)
|
||||
if non_slice_idx:
|
||||
(int_indexer_shape,) = {idx_aval.shape for idx_aval in splatted_idx_avals
|
||||
if not isinstance(idx_aval, primitives.Slice)}
|
||||
else:
|
||||
int_indexer_shape = ()
|
||||
nd_indexer = NDIndexer(splatted_idx, ref_aval.shape, int_indexer_shape)
|
||||
nd_indexer_avals = NDIndexer(splatted_idx_avals, ref_aval.shape,
|
||||
int_indexer_shape)
|
||||
return nd_indexer, nd_indexer_avals
|
||||
|
||||
|
||||
def _get_lowering_rule(
|
||||
ctx: LoweringRuleContext, ref, *idx, tree,
|
||||
):
|
||||
@ -2151,27 +2130,6 @@ def _lower_jaxpr_to_for_loop(ctx: LoweringRuleContext,
|
||||
return for_op.results
|
||||
|
||||
|
||||
def _lower_jaxpr_to_unrolled_for_loop(ctx: LoweringRuleContext,
|
||||
jaxpr: jax_core.Jaxpr, start: int,
|
||||
num_steps: int, consts, *args,
|
||||
has_loop_index: bool):
|
||||
for i in range(start, start + num_steps):
|
||||
if has_loop_index:
|
||||
lowering_context = ctx.lowering_context.replace(
|
||||
block_shapes=ctx.block_shapes)
|
||||
args = jaxpr_subcomp(
|
||||
lowering_context, jaxpr, *consts,
|
||||
ir_constant(i, mlir_type=_dtype_to_ir_type(jnp.dtype('int32'))),
|
||||
*args)
|
||||
else:
|
||||
lowering_context = ctx.lowering_context.replace(
|
||||
block_shapes=ctx.block_shapes[:len(consts)]
|
||||
+ ctx.block_shapes[len(consts) + 1:],
|
||||
)
|
||||
args = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
|
||||
return args
|
||||
|
||||
|
||||
def _scan_lowering_rule(
|
||||
ctx: LoweringRuleContext,
|
||||
*args,
|
||||
|
@ -71,10 +71,6 @@ def _mul(scalar, tree):
|
||||
return tree_map(partial(operator.mul, scalar), tree)
|
||||
|
||||
|
||||
def _div(tree, scalar):
|
||||
return tree_map(partial(lambda v: v / scalar), tree)
|
||||
|
||||
|
||||
_add = partial(tree_map, operator.add)
|
||||
_sub = partial(tree_map, operator.sub)
|
||||
_dot_tree = partial(tree_map, _dot)
|
||||
|
@ -286,12 +286,6 @@ def _get_discharge(x, idx, tree):
|
||||
indexers = tree_util.tree_unflatten(tree, idx)
|
||||
return index_array(x, indexers)
|
||||
|
||||
def _indexer(idx, indexed_dims):
|
||||
idx_ = iter(idx)
|
||||
indexer = tuple(next(idx_) if b else slice(None) for b in indexed_dims)
|
||||
assert next(idx_, None) is None
|
||||
return indexer
|
||||
|
||||
@register_discharge_rule(swap_p)
|
||||
def _swap_discharge_rule(
|
||||
in_avals: Sequence[core.AbstractValue],
|
||||
|
@ -15,6 +15,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
import types
|
||||
from typing import Any, Union
|
||||
|
||||
import numpy as np
|
||||
@ -56,18 +57,7 @@ zip, unsafe_zip = safe_zip, zip
|
||||
get_p = core.Primitive("get")
|
||||
get_p.def_impl(partial(dispatch.apply_primitive, get_p))
|
||||
|
||||
Indexer = tuple[Union[int, slice, Array], ...]
|
||||
# or Ellipsis, but that can't be annotated until Python 3.10? (types.EllipsisType)
|
||||
|
||||
def _get_slice_output_shape(in_shape: tuple[int, ...],
|
||||
idx_shapes: tuple[tuple[int, ...], ...],
|
||||
indexed_dims: tuple[bool, ...]) -> tuple[int, ...]:
|
||||
shape_suffix = [d for i, d in zip(indexed_dims, in_shape) if not i]
|
||||
shape_prefix, = set(idx_shapes) or [()] # tie fighter
|
||||
# Move shape prefix dimensions to the front
|
||||
shape = (*shape_prefix, *shape_suffix)
|
||||
return shape
|
||||
|
||||
Indexer = tuple[Union[int, slice, Array, types.EllipsisType], ...]
|
||||
|
||||
def get_ref_and_indexers(
|
||||
ref_or_view: Any, idx: Indexer | None, function_name: str
|
||||
@ -408,11 +398,6 @@ pe.partial_eval_jaxpr_custom_rules[addupdate_p] = partial(
|
||||
|
||||
## get/swap/addupdate batching rules
|
||||
|
||||
def _output_bdim(indexed_dims: tuple[bool, ...], ref_dim: int,
|
||||
idxs_shape: tuple[int, ...]):
|
||||
num_idxs_to_left = sum(indexed_dims[:ref_dim])
|
||||
return ref_dim - num_idxs_to_left + len(idxs_shape)
|
||||
|
||||
def _batch_indexer(indexer: indexing.NDIndexer, dims,
|
||||
axis_size: int,
|
||||
ref_shape: tuple[int, ...],
|
||||
|
Loading…
x
Reference in New Issue
Block a user