Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Merge pull request #21504 from gnecula:poly_approx
PiperOrigin-RevId: 638550165
This commit is contained in commit f72b0f0ca6
@@ -229,11 +229,19 @@ def _approx_top_k_abstract_eval(operand, *, k, reduction_dimension,
   if not dtypes.issubdtype(operand.dtype, np.floating):
     raise ValueError('operand must be a floating type')
   reduction_input_size = dims[reduction_dimension]
-  dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize(
-      reduction_input_size, len(dims), k, recall_target, aggregate_to_topk,
-      reduction_input_size_override)[0]
-  return (operand.update(
-      shape=dims, dtype=operand.dtype, weak_type=operand.weak_type),
+  if aggregate_to_topk:
+    dims[reduction_dimension] = k
+  elif core.is_constant_shape((reduction_input_size, k)):
+    dims[reduction_dimension] = xc.ops.ApproxTopKReductionOutputSize(
+        reduction_input_size, len(dims), k, recall_target, aggregate_to_topk,
+        reduction_input_size_override)[0]
+  else:
+    raise NotImplementedError(
+        "approx_top_k with aggregate_to_topk=False not yet implemented when "
+        f"either the `k` ({k}) or the "
+        f" reduction dimension size ({reduction_input_size}) are symbolic")
+  return (operand.update(shape=dims, dtype=operand.dtype,
+                         weak_type=operand.weak_type),
           operand.update(shape=dims, dtype=np.dtype(np.int32)))
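Note (not part of the diff): the new branching reflects how the result shape of lax.approx_max_k depends on aggregate_to_topk. A minimal sketch with concrete shapes, assuming only the public jax.lax.approx_max_k API:

import jax.numpy as jnp
from jax import lax

x = jnp.ones((3, 8, 32), dtype=jnp.float32)

# aggregate_to_topk=True: the reduction dimension is collapsed to exactly k,
# so the output shape is known even when the input dimension is symbolic.
vals, idxs = lax.approx_max_k(x, k=4, reduction_dimension=1,
                              aggregate_to_topk=True)
assert vals.shape == (3, 4, 32) and idxs.dtype == jnp.int32

# aggregate_to_topk=False: XLA's ApproxTopKReductionOutputSize picks the
# (possibly larger) output size, which is why this branch still requires k
# and the reduction dimension size to be known constants.
vals2, idxs2 = lax.approx_max_k(x, k=4, reduction_dimension=1,
                                aggregate_to_topk=False)
assert vals2.shape[1] >= 4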
@@ -1997,6 +1997,32 @@ _POLY_SHAPE_TEST_HARNESSES = [
                                     + jnp.sin(x))),
                 arg_descriptors=[RandArg((3, 4), _f32)],
                 polymorphic_shapes=["b, ..."]),
+    [  # approx_max_k
+        # x: f32[b, {n}, 32] with n being either 8 or the symbol "n"
+        # we reduce on dim=1, with size n
+        # k is either the constant 4 or the symbol "k"
+        PolyHarness("approx_max_k", f"n_{n}_k_{k}_agg={agg}",
+                    lambda x, x_k, agg: lax.approx_max_k(
+                        x, k=x_k.shape[0], reduction_dimension=1,
+                        aggregate_to_topk=agg),
+                    arg_descriptors=[RandArg((3, 8, 32), _f32),
+                                     RandArg((4,), _f32),
+                                     StaticArg(agg)],
+                    polymorphic_shapes=[f"b, {n}, 32", f"{k},"],
+                    # k must be at most the reduction dimension size
+                    symbolic_constraints=[f"{k} <= {n}"],
+                    expect_error=(
+                        (NotImplementedError, "aggregate_to_topk=False") if (
+                            not agg and (isinstance(k, str) or
+                                         isinstance(n, str))) else
+                        # TODO(b/339398482) fix case with k symbolic
+                        (TypeError, "get") if (agg and isinstance(k, str)) else
+                        None
+                    ))
+        for n in [8, "n"]
+        for k in [4, "k"]
+        for agg in [True, False]
+    ],
     [  # arange
         PolyHarness("arange", name,
                     f_jax,
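Note (not part of the diff): the harnesses above exercise the new shape rule under symbolic shapes. A minimal standalone sketch of the same idea, assuming the jax.experimental.export.symbolic_shape API with its constraints argument (as available around this JAX version); the constraint mirrors the harness's symbolic_constraints for the constant k == 4:

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import export

def top4(x):
  # Reduce on dim 1 and keep the top 4 values per slice.
  return lax.approx_max_k(x, k=4, reduction_dimension=1,
                          aggregate_to_topk=True)

# "b" and "n" are symbolic dimensions; n >= 4 plays the role of "k <= n".
b, n = export.symbolic_shape("b, n", constraints=["n >= 4"])
x_spec = jax.ShapeDtypeStruct((b, n, 32), jnp.float32)

# Abstract evaluation only: with aggregate_to_topk=True the new rule sets the
# reduction dimension of the output to exactly k.
vals, idxs = jax.eval_shape(top4, x_spec)
print(vals.shape)   # (b, 4, 32), with b symbolic
print(idxs.dtype)   # int32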
@@ -3311,5 +3337,6 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
     for fname, _ in config_flags.items():
       jax.config.update(fname, prev_jax_config_flags[fname])
 
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())