mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Move CostEstimate from pltu to pl
* Move CostEstimate from TPU-specific `compiler_params` to a platform-independent argument of `pallas_call`. Passing a CostEstimate in `compiler_params` is now deprecated and will be removed in three months' time. * Update the CostEstimate when batching a kernel by scaling it by the size of the batch axis. PiperOrigin-RevId: 659560330
This commit is contained in:
parent
ecf9f64240
commit
9762ac53c8
@ -826,3 +826,16 @@ class PallasMesh(mesh_lib.Mesh):
|
||||
@property
|
||||
def _is_jax_device_mesh(self):
|
||||
return False
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class CostEstimate:
|
||||
flops: int
|
||||
transcendentals: int
|
||||
bytes_accessed: int
|
||||
|
||||
def to_json(self) -> bytes:
|
||||
return (
|
||||
f'{{"flops": {self.flops}, "transcendentals": {self.transcendentals},'
|
||||
f' "bytes_accessed": {self.bytes_accessed}}}'
|
||||
).encode("ascii")
|
||||
|
@ -22,8 +22,8 @@ from typing import Any
|
||||
import warnings
|
||||
|
||||
import jax
|
||||
from jax import dtypes
|
||||
from jax import core as jax_core
|
||||
from jax import dtypes
|
||||
from jax._src import config
|
||||
from jax._src import core as jax_src_core
|
||||
from jax._src import sharding_impls
|
||||
@ -34,6 +34,7 @@ from jax._src.pallas.mosaic import lowering
|
||||
from jax._src.pallas.mosaic import verification
|
||||
from jax.experimental import mosaic
|
||||
from jax.experimental.mosaic.dialects import tpu
|
||||
from jax.experimental.pallas import tpu as pltpu
|
||||
|
||||
def _maybe_cast_to_int(x: jax.Array | jax_core.ShapedArray):
|
||||
"""Casts boolean values to integers.
|
||||
@ -71,7 +72,9 @@ def pallas_call_tpu_lowering_rule(
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
debug: bool,
|
||||
interpret: bool,
|
||||
compiler_params: dict[str, Any]):
|
||||
compiler_params: dict[str, Any],
|
||||
cost_estimate: core.CostEstimate | None,
|
||||
):
|
||||
"""Lowers a pallas_call to a Mosaic TPU custom call."""
|
||||
del interpret
|
||||
if debug:
|
||||
@ -90,6 +93,22 @@ def pallas_call_tpu_lowering_rule(
|
||||
mosaic_params = compiler_params["mosaic"]
|
||||
else:
|
||||
mosaic_params = {}
|
||||
|
||||
if "cost_estimate" in mosaic_params:
|
||||
# TODO(amagni): Remove this branch after October 22nd, 2024.
|
||||
if cost_estimate is not None:
|
||||
raise ValueError(
|
||||
"Passing cost estimate via both compiler_params=dict(mosaic=...) and"
|
||||
" pallas_call(..., cost_estimate=...) is not supported."
|
||||
)
|
||||
|
||||
warnings.warn(
|
||||
"Passing cost estimate via compiler_params=dict(cost_estimate=...) is"
|
||||
" deprecated. Use pallas_call(..., cost_estimate=...) instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
cost_estimate = mosaic_params["cost_estimate"]
|
||||
|
||||
mesh = None
|
||||
axis_context = ctx.module_context.axis_context
|
||||
if axis_context is not None:
|
||||
@ -172,27 +191,33 @@ def pallas_call_tpu_lowering_rule(
|
||||
# Dynamic grid bounds have to go at the front.
|
||||
dynamic_grid_args, args = in_nodes[:num_dyn_bounds], in_nodes[num_dyn_bounds:]
|
||||
kernel_ctx = ctx.replace(avals_in=kernel_in_avals, avals_out=kernel_out_avals)
|
||||
out_nodes = mosaic.lower_module_to_custom_call(
|
||||
kernel_ctx,
|
||||
*dynamic_grid_args,
|
||||
*extra_args,
|
||||
*args,
|
||||
module=mosaic_module,
|
||||
out_type=kernel_out_avals,
|
||||
backend="tpu",
|
||||
kernel_name=name_and_src_info.name,
|
||||
cost_estimate=mosaic_params.get("cost_estimate"),
|
||||
vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"),
|
||||
flags=mosaic_params.get("flags"),
|
||||
allow_input_fusion=mosaic_params.get("allow_input_fusion"),
|
||||
input_output_aliases=input_output_aliases,
|
||||
serialization_format=mosaic_params.get("serialization_format", 1),
|
||||
device_type=mosaic_params.get("device_type"),
|
||||
internal_scratch_in_bytes=mosaic_params.get(
|
||||
"internal_scratch_in_bytes"
|
||||
),
|
||||
collective_id=mosaic_params.get("collective_id", None),
|
||||
if cost_estimate is not None:
|
||||
mosaic_cost_estimate = pltpu.CostEstimate(
|
||||
flops=cost_estimate.flops,
|
||||
bytes_accessed=cost_estimate.bytes_accessed,
|
||||
transcendentals=cost_estimate.transcendentals,
|
||||
)
|
||||
else:
|
||||
mosaic_cost_estimate = None
|
||||
out_nodes = mosaic.lower_module_to_custom_call(
|
||||
kernel_ctx,
|
||||
*dynamic_grid_args,
|
||||
*extra_args,
|
||||
*args,
|
||||
module=mosaic_module,
|
||||
out_type=kernel_out_avals,
|
||||
backend="tpu",
|
||||
kernel_name=name_and_src_info.name,
|
||||
cost_estimate=mosaic_cost_estimate,
|
||||
vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"),
|
||||
flags=mosaic_params.get("flags"),
|
||||
allow_input_fusion=mosaic_params.get("allow_input_fusion"),
|
||||
input_output_aliases=input_output_aliases,
|
||||
serialization_format=mosaic_params.get("serialization_format", 1),
|
||||
device_type=mosaic_params.get("device_type"),
|
||||
internal_scratch_in_bytes=mosaic_params.get("internal_scratch_in_bytes"),
|
||||
collective_id=mosaic_params.get("collective_id", None),
|
||||
)
|
||||
_maybe_cast_to_bool = lambda x, aval: x.astype(
|
||||
jax.numpy.bool_) if aval.dtype == jax.numpy.bool_ else x
|
||||
def _maybe_cast_outputs(*args):
|
||||
|
@ -146,7 +146,9 @@ def lower_jaxpr_to_module(
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
name_and_src_info: pallas_core.NameAndSrcInfo,
|
||||
compiler_params: dict[str, Any],
|
||||
cost_estimate: pallas_core.CostEstimate | None,
|
||||
) -> LoweringResult:
|
||||
del cost_estimate # Unused.
|
||||
in_structs = tuple(grid_mapping.in_shapes)
|
||||
out_structs = grid_mapping.out_shapes
|
||||
assert len(jaxpr.outvars) == 0
|
||||
|
@ -36,6 +36,7 @@ def pallas_call_lowering(
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
grid_mapping: pallas_core.GridMapping,
|
||||
compiler_params: dict[str, Any],
|
||||
cost_estimate: pallas_core.CostEstimate | None,
|
||||
):
|
||||
del interpret
|
||||
if grid_mapping.num_dynamic_grid_bounds:
|
||||
@ -58,6 +59,7 @@ def pallas_call_lowering(
|
||||
jaxpr,
|
||||
name_and_src_info,
|
||||
compiler_params,
|
||||
cost_estimate,
|
||||
)
|
||||
if debug:
|
||||
print(f"\nThe Mosaic GPU module for pallas_call {name_and_src_info}:")
|
||||
|
@ -60,6 +60,7 @@ BlockSpec = pallas_core.BlockSpec
|
||||
BlockSpecTree = pallas_core.BlockSpecTree
|
||||
NoBlockSpec = pallas_core.NoBlockSpec
|
||||
no_block_spec = pallas_core.no_block_spec
|
||||
CostEstimate = pallas_core.CostEstimate
|
||||
|
||||
# See the docstring for GridMapping for the calling convention
|
||||
pallas_call_p = jax_core.Primitive('pallas_call')
|
||||
@ -164,6 +165,7 @@ def _get_next_indices(grid, indices):
|
||||
def _pallas_call_impl(*args, **kwargs):
|
||||
assert False # We always jit a pallas call, we only need the lowering rule
|
||||
|
||||
|
||||
def _pallas_call_impl_interpret(
|
||||
*args,
|
||||
jaxpr: jax_core.Jaxpr,
|
||||
@ -171,8 +173,10 @@ def _pallas_call_impl_interpret(
|
||||
debug: bool,
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
grid_mapping: GridMapping,
|
||||
compiler_params: Any):
|
||||
del compiler_params
|
||||
compiler_params: Any,
|
||||
cost_estimate: CostEstimate,
|
||||
):
|
||||
del compiler_params, cost_estimate
|
||||
# If we're in interpreter mode, we *scan* over the grid and eval the
|
||||
# discharged jaxpr.
|
||||
dynamic_grid_args, args = split_list( # type: ignore
|
||||
@ -294,6 +298,7 @@ def _pallas_call_impl_interpret(
|
||||
out_nopad.append(o)
|
||||
return out_nopad
|
||||
|
||||
|
||||
pallas_call_p.def_impl(_pallas_call_impl)
|
||||
|
||||
def _pallas_call_abstract_eval(*avals, grid_mapping: GridMapping, **_):
|
||||
@ -302,9 +307,20 @@ def _pallas_call_abstract_eval(*avals, grid_mapping: GridMapping, **_):
|
||||
for bm in grid_mapping.block_mappings_output)
|
||||
pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval)
|
||||
|
||||
def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name_and_src_info,
|
||||
|
||||
def _pallas_call_jvp_rule(
|
||||
primals,
|
||||
tangents,
|
||||
*,
|
||||
jaxpr,
|
||||
name_and_src_info,
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
grid_mapping, debug, interpret, compiler_params: Any):
|
||||
grid_mapping,
|
||||
debug,
|
||||
interpret,
|
||||
compiler_params: Any,
|
||||
cost_estimate: CostEstimate | None,
|
||||
):
|
||||
if grid_mapping.num_dynamic_grid_bounds:
|
||||
raise NotImplementedError("interpret with dynamic grid bounds unsupported")
|
||||
if grid_mapping.num_index_operands:
|
||||
@ -346,20 +362,32 @@ def _pallas_call_jvp_rule(primals, tangents, *, jaxpr, name_and_src_info,
|
||||
num_inputs=grid_mapping.num_inputs * 2,
|
||||
num_outputs=grid_mapping.num_outputs * 2,
|
||||
)
|
||||
if cost_estimate is not None:
|
||||
jvp_cost_estimate = CostEstimate(
|
||||
flops=2 * cost_estimate.flops,
|
||||
bytes_accessed=2 * cost_estimate.bytes_accessed,
|
||||
transcendentals=2 * cost_estimate.transcendentals,
|
||||
)
|
||||
else:
|
||||
jvp_cost_estimate = None
|
||||
out_flat = pallas_call_p.bind(
|
||||
*primals,
|
||||
*tangents,
|
||||
jaxpr=jvp_jaxpr,
|
||||
name_and_src_info=name_and_src_info.replace(
|
||||
name=f"{name_and_src_info.name}_jvp"),
|
||||
name=f"{name_and_src_info.name}_jvp"
|
||||
),
|
||||
grid_mapping=jvp_grid_mapping,
|
||||
interpret=interpret,
|
||||
debug=debug,
|
||||
input_output_aliases=(),
|
||||
compiler_params=compiler_params,
|
||||
cost_estimate=jvp_cost_estimate,
|
||||
)
|
||||
out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2])
|
||||
return out_primals, out_tangents
|
||||
|
||||
|
||||
ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule
|
||||
|
||||
def _batch_block_mapping(grid_mapping: GridMapping,
|
||||
@ -436,6 +464,7 @@ def _batch_with_explicit_loop(
|
||||
debug: bool,
|
||||
interpret: bool,
|
||||
compiler_params: Any,
|
||||
cost_estimate: CostEstimate | None,
|
||||
):
|
||||
"""Batch the pallas_call by calling it in loop over the batch size.
|
||||
|
||||
@ -501,6 +530,7 @@ def _batch_with_explicit_loop(
|
||||
debug=debug,
|
||||
interpret=interpret,
|
||||
compiler_params=compiler_params,
|
||||
cost_estimate=cost_estimate,
|
||||
)
|
||||
for i, batch_out_array in enumerate(batch_out):
|
||||
state[i] = jax.lax.dynamic_update_index_in_dim(
|
||||
@ -528,6 +558,7 @@ def _pallas_call_batching_rule(
|
||||
debug: bool,
|
||||
interpret: bool,
|
||||
compiler_params: Any,
|
||||
cost_estimate: CostEstimate | None,
|
||||
):
|
||||
def _maybe_squeeze_out_bdim(
|
||||
x: jax.Array, bdim: int | batching.NotMapped
|
||||
@ -536,8 +567,9 @@ def _pallas_call_batching_rule(
|
||||
return x
|
||||
return jnp.squeeze(x, axis=bdim)
|
||||
|
||||
axis_size, = {x.shape[d] for x, d in zip(args, dims)
|
||||
if d is not batching.not_mapped}
|
||||
(axis_size,) = {
|
||||
x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped
|
||||
}
|
||||
if axis_size == 1:
|
||||
# Why are we even vmapping?
|
||||
args = map(_maybe_squeeze_out_bdim, args, dims)
|
||||
@ -550,6 +582,7 @@ def _pallas_call_batching_rule(
|
||||
debug=debug,
|
||||
interpret=interpret,
|
||||
compiler_params=compiler_params,
|
||||
cost_estimate=cost_estimate,
|
||||
)
|
||||
return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out)
|
||||
|
||||
@ -581,6 +614,7 @@ def _pallas_call_batching_rule(
|
||||
debug=debug,
|
||||
interpret=interpret,
|
||||
compiler_params=compiler_params,
|
||||
cost_estimate=cost_estimate,
|
||||
)
|
||||
else:
|
||||
pass # No dynamic grid dimensions
|
||||
@ -613,6 +647,7 @@ def _pallas_call_batching_rule(
|
||||
debug=debug,
|
||||
interpret=interpret,
|
||||
compiler_params=compiler_params,
|
||||
cost_estimate=cost_estimate,
|
||||
)
|
||||
|
||||
if not dims:
|
||||
@ -657,18 +692,29 @@ def _pallas_call_batching_rule(
|
||||
block_mappings=tuple(batched_block_mappings),
|
||||
index_map_avals=batched_index_map_avals,
|
||||
index_map_tree=batched_index_map_tree,
|
||||
vmapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.vmapped_dims))
|
||||
vmapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.vmapped_dims),
|
||||
)
|
||||
if cost_estimate is not None:
|
||||
batched_cost_estimate = CostEstimate(
|
||||
flops=cost_estimate.flops * axis_size,
|
||||
bytes_accessed=cost_estimate.bytes_accessed * axis_size,
|
||||
transcendentals=cost_estimate.transcendentals * axis_size,
|
||||
)
|
||||
else:
|
||||
batched_cost_estimate = None
|
||||
out = pallas_call_p.bind(
|
||||
*dynamic_grid_args,
|
||||
*args,
|
||||
jaxpr=jaxpr,
|
||||
name_and_src_info=name_and_src_info.replace(
|
||||
name=f"{name_and_src_info.name}_batched"),
|
||||
name=f"{name_and_src_info.name}_batched"
|
||||
),
|
||||
grid_mapping=batched_grid_mapping,
|
||||
input_output_aliases=input_output_aliases,
|
||||
debug=debug,
|
||||
interpret=interpret,
|
||||
compiler_params=compiler_params,
|
||||
cost_estimate=batched_cost_estimate,
|
||||
)
|
||||
return out, (0,) * len(out)
|
||||
|
||||
@ -971,6 +1017,7 @@ def pallas_call(
|
||||
interpret: bool = False,
|
||||
name: str | None = None,
|
||||
compiler_params: dict[str, Any] | None = None,
|
||||
cost_estimate: CostEstimate | None = None,
|
||||
) -> Callable[..., Any]:
|
||||
"""Invokes a Pallas kernel on some inputs.
|
||||
|
||||
@ -1106,14 +1153,18 @@ def pallas_call(
|
||||
|
||||
index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands])
|
||||
out_flat = pallas_call_p.bind(
|
||||
*dynamic_grid_bounds, *index_args, *rest_args,
|
||||
*dynamic_grid_bounds,
|
||||
*index_args,
|
||||
*rest_args,
|
||||
jaxpr=jaxpr,
|
||||
name_and_src_info=name_and_src_info,
|
||||
debug=debug,
|
||||
interpret=interpret,
|
||||
grid_mapping=grid_mapping,
|
||||
input_output_aliases=tuple(input_output_aliases.items()),
|
||||
compiler_params=compiler_params)
|
||||
compiler_params=compiler_params,
|
||||
cost_estimate=cost_estimate,
|
||||
)
|
||||
out = tree_util.tree_unflatten(out_tree, out_flat)
|
||||
return out
|
||||
return wrapped
|
||||
|
@ -48,6 +48,7 @@ def pallas_call_lowering(
|
||||
input_output_aliases: tuple[tuple[int, int], ...],
|
||||
grid_mapping: pallas_core.GridMapping,
|
||||
compiler_params: dict[str, Any],
|
||||
cost_estimate: pallas_core.CostEstimate | None,
|
||||
):
|
||||
del interpret
|
||||
if grid_mapping.num_dynamic_grid_bounds:
|
||||
|
@ -21,6 +21,7 @@ https://jax.readthedocs.io/en/latest/pallas.html.
|
||||
from jax._src.deprecations import register as _register_deprecation
|
||||
from jax._src.pallas.core import Blocked
|
||||
from jax._src.pallas.core import BlockSpec
|
||||
from jax._src.pallas.core import CostEstimate
|
||||
from jax._src.pallas.core import IndexingMode
|
||||
from jax._src.pallas.core import no_block_spec
|
||||
from jax._src.pallas.core import Unblocked
|
||||
|
@ -520,7 +520,7 @@ def gmm(
|
||||
(lhs_bytes * tiles_n) + (rhs_bytes * max_active_tiles) + out_bytes
|
||||
)
|
||||
flops = 2 * m * k * n
|
||||
cost_estimate = pltpu.CostEstimate(
|
||||
cost_estimate = pl.CostEstimate(
|
||||
flops=flops, bytes_accessed=bytes_accessed, transcendentals=0
|
||||
)
|
||||
call_gmm = pl.pallas_call(
|
||||
@ -541,10 +541,10 @@ def gmm(
|
||||
compiler_params=dict(
|
||||
mosaic=dict(
|
||||
dimension_semantics=("parallel", "arbitrary", "arbitrary"),
|
||||
cost_estimate=cost_estimate,
|
||||
)
|
||||
),
|
||||
interpret=interpret,
|
||||
cost_estimate=cost_estimate,
|
||||
)
|
||||
|
||||
out = call_gmm(
|
||||
@ -759,7 +759,7 @@ def tgmm(
|
||||
(lhs_bytes * tiles_n) + (rhs_bytes * tiles_k) + out_bytes
|
||||
)
|
||||
flops = 2 * m * k * n
|
||||
cost_estimate = pltpu.CostEstimate(
|
||||
cost_estimate = pl.CostEstimate(
|
||||
flops=flops, bytes_accessed=bytes_accessed, transcendentals=0
|
||||
)
|
||||
lhs = lhs.swapaxes(0, 1)
|
||||
|
@ -22,12 +22,12 @@ from jax._src.pallas.mosaic.core import semaphore
|
||||
from jax._src.pallas.mosaic.core import SemaphoreType
|
||||
from jax._src.pallas.mosaic.core import TPUMemorySpace
|
||||
from jax._src.pallas.mosaic.lowering import LoweringException
|
||||
from jax._src.pallas.mosaic.pipeline import ARBITRARY
|
||||
from jax._src.pallas.mosaic.pipeline import BufferedRef
|
||||
from jax._src.pallas.mosaic.pipeline import emit_pipeline
|
||||
from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations
|
||||
from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule
|
||||
from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations
|
||||
from jax._src.pallas.mosaic.pipeline import ARBITRARY
|
||||
from jax._src.pallas.mosaic.pipeline import PARALLEL
|
||||
from jax._src.pallas.mosaic.primitives import async_copy
|
||||
from jax._src.pallas.mosaic.primitives import async_remote_copy
|
||||
@ -38,14 +38,15 @@ from jax._src.pallas.mosaic.primitives import DeviceIdType
|
||||
from jax._src.pallas.mosaic.primitives import get_barrier_semaphore
|
||||
from jax._src.pallas.mosaic.primitives import make_async_copy
|
||||
from jax._src.pallas.mosaic.primitives import make_async_remote_copy
|
||||
from jax._src.pallas.mosaic.primitives import prng_random_bits
|
||||
from jax._src.pallas.mosaic.primitives import prng_seed
|
||||
from jax._src.pallas.mosaic.primitives import repeat
|
||||
from jax._src.pallas.mosaic.primitives import roll
|
||||
from jax._src.pallas.mosaic.primitives import semaphore_read
|
||||
from jax._src.pallas.mosaic.primitives import semaphore_signal
|
||||
from jax._src.pallas.mosaic.primitives import semaphore_wait
|
||||
from jax._src.pallas.mosaic.primitives import prng_seed
|
||||
from jax._src.pallas.mosaic.primitives import prng_random_bits
|
||||
from jax._src.pallas.mosaic.random import to_pallas_key
|
||||
# Remove this import after October 22nd, 2024.
|
||||
from jax._src.tpu_custom_call import CostEstimate
|
||||
|
||||
# TODO(cperivol): Temporary alias to the global run_scoped. Remove
|
||||
|
@ -1487,12 +1487,8 @@ class PallasCallTest(PallasBaseTest):
|
||||
f = self.pallas_call(
|
||||
kernel,
|
||||
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
|
||||
compiler_params=dict(
|
||||
mosaic=dict(
|
||||
cost_estimate=pltpu.CostEstimate(
|
||||
flops=1234, transcendentals=21, bytes_accessed=12345
|
||||
)
|
||||
)
|
||||
cost_estimate=pl.CostEstimate(
|
||||
flops=1234, transcendentals=21, bytes_accessed=12345
|
||||
),
|
||||
)
|
||||
(analysis_result,) = jax.jit(f).lower(x).compile().cost_analysis()
|
||||
@ -1500,6 +1496,26 @@ class PallasCallTest(PallasBaseTest):
|
||||
self.assertEqual(analysis_result['transcendentals'], 21)
|
||||
self.assertEqual(analysis_result['bytes accessed'], 12345)
|
||||
|
||||
|
||||
def test_cost_analysis_vmap(self):
|
||||
def kernel(x, y):
|
||||
y[:] = x[:]
|
||||
batch_size = 3
|
||||
x = jnp.arange(batch_size * 1024.).reshape(batch_size, 8, 128)
|
||||
f = pl.pallas_call(
|
||||
kernel,
|
||||
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
|
||||
cost_estimate=pl.CostEstimate(
|
||||
flops=1234, transcendentals=21, bytes_accessed=12345
|
||||
),
|
||||
)
|
||||
f = jax.vmap(f)
|
||||
(analysis_result,) = jax.jit(f).lower(x).compile().cost_analysis()
|
||||
self.assertEqual(analysis_result['flops'], batch_size * 1234)
|
||||
self.assertEqual(analysis_result['transcendentals'], batch_size * 21)
|
||||
self.assertEqual(analysis_result['bytes accessed'], batch_size * 12345)
|
||||
|
||||
|
||||
def test_vmem_limit(self):
|
||||
shape = (128, 128)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user