14 Commits

Author SHA1 Message Date
Zac Mustin
bff0fa18ad Support conv unfused_flops in roofline.
Since calculating flops is non-trivial, we don't test all the cases currently tested by `test_conv_general_dilated_unfused_hbm_bytes`. Instead, we test behaviors more directly.

PiperOrigin-RevId: 743272840
2025-04-02 14:08:30 -07:00
Zac Mustin
b89cf0de91 Stop using mesh and *_specs in roofline tests.
These args are optional, so not specifying them in our tests will make them simpler and easier to read. This change is a no-op.

PiperOrigin-RevId: 740015584
2025-03-24 11:41:22 -07:00
Zac Mustin
7953e6d0f8 Add tests for varying {batch, feature}_group_counts for roofline conv.
We'll need to use batch/feature when calculating flops, so it'll help reduce the size of the "calculating-flops" change if we can include them in our tests now.

PiperOrigin-RevId: 739081930
2025-03-21 00:36:38 -07:00
Zac Mustin
0c8e601f90 Support convolution in roofline.
So far we support only `unfused_hmb_bytes` and don't account for `{feature, batch}_group_count`s due to complexity.

PiperOrigin-RevId: 736948528
2025-03-14 12:26:20 -07:00
Yash Katariya
a4ca0dbc6c Make the signature of AbstractMesh to be AbstractMesh(axis_size: tuple[int, ...], axis_name: tuple[str, ...], *, axis_types) instead of AbstractMesh(shape_tuple: tuple[tuple[str, int], ...], *, axis_types) so that we are consistent across all Mesh APIs: Mesh, AbstractMesh and make_mesh
PiperOrigin-RevId: 736371111
2025-03-12 21:32:31 -07:00
Zac Mustin
8095d842c8 roofline: Support computing flops for unary ops.
PiperOrigin-RevId: 734351741
2025-03-06 17:44:36 -08:00
Matthias Kramm
e8543024e5 Add unfused_hbm usage to binary ops and dot_general.
PiperOrigin-RevId: 731066135
2025-02-25 16:10:25 -08:00
Matthias Kramm
aad178a6f8 roofline: Add support for min_p, max_p, reduce_sum_p.
PiperOrigin-RevId: 731024098
2025-02-25 14:10:15 -08:00
Matthias Kramm
08081c4db6 roofline: Support broadcasting, for binary ops.
PiperOrigin-RevId: 731014250
2025-02-25 13:46:00 -08:00
Matthias Kramm
79e1e1fcee Make mesh and *_spec parameters optional.
PiperOrigin-RevId: 730499695
2025-02-24 10:15:38 -08:00
Matthias Kramm
b3fcba7c05 roofline: Handle ClosedJaxpr instances.
PiperOrigin-RevId: 729636113
2025-02-21 13:19:31 -08:00
Matthias Kramm
7eee2de703 roofline: Support computing flops for binary ops.
PiperOrigin-RevId: 728708058
2025-02-19 09:45:24 -08:00
Peter Hawkins
51b9fe3010 [JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU directly.
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.

In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.

PiperOrigin-RevId: 713272197
2025-01-08 06:37:44 -08:00
Enrique Piqueras
8c521547b7
Add experimental JAX roofline API. 2024-11-27 14:38:57 -08:00