These are also set in the TSL build rules as part of the CUDA stub libraries, which these libraries depend on, so these copies of the rpath settings are redundant.
PiperOrigin-RevId: 716844265
This commits adds a Github action workflow that will be used to jobs that test the nightly/release artifacts. These artifacts are built by our internal CI jobs and are uploaded to a transient GCS bucket. After all the wheels have finished uploading, an internal job is run that that will trigger the `wheel_tests_nightly_release.yml` workflow.
PiperOrigin-RevId: 716789482
Also, use a conditional expression in the continuous workflow to control concurrent runs. We don't want to cancel runs on multiple pushes to main or release branch.
PiperOrigin-RevId: 716780290
Also make the `axes` parameter optional of hidden_axes and visible_axes functions. If axes is optional, you drop into full hidden/visible mode.
PiperOrigin-RevId: 716771872
* mesh_cast only works when the axis types between src and dst mesh changes. Hence the name!
* No explicit data movement is allowed. Specs containing axes that are visible cannot be different between src and dst shardings.
* src and dst mesh axis_names and axis_sizes should be the same.
TODO: Make `shardings` parameter to `mesh_cast` optional.
PiperOrigin-RevId: 716727084
The pass adds versioning to the Mosaic GPU IR in the lowered custom calls
and can apply forward/backward migration rules. Currently, no rules are
necessary since we are at version 1.
PiperOrigin-RevId: 716596848
Replace `with set_mesh(mesh):` with `with use_mesh(mesh):` context manager
Also expose `AxisTypes` and `use_mesh` into public API via `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.
PiperOrigin-RevId: 716446406
This commit stores the GCS upload URI as a step output and then maps the job output to it so that we can make it available when a workflow call is made to `build_artifacts.yml`.
This helps in the case where we want to debug a workflow such as https://github.com/jax-ml/jax/blob/main/.github/workflows/wheel_tests_continuous.yml where artifact build jobs and test jobs are run separately. Previously, we could not re-try a failing test workflow as the upload URI contains `${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}`. In GitHub, `github.run_attempt` is mapped to a workflow and not individual jobs so even a re-trigger of a test job alone would lead to the `run_attempt` value being increased which in turn invalidates the GCS download URI that it reads.
By storing the the upload URI as a step output, we freeze the upload/download URIs until the build artifact jobs are re-run. However, note that this still has a edge case where things can break - `run-pytest-cuda` job in `wheel_tests_continuous.yml` depends on both `build-jaxlib-artifact` and `build-cuda-artifacts` but consumes the upload URI from the output of `build-jaxlib-artifact` alone. This is done on the assumption that both these jobs will have uploaded to the same location. However, that would not be the case if one of these jobs fail and have to re-run. We are working on a longterm solution for this case but in the meantime, the recommendation for now is just to re-run the whole set of jobs again.
PiperOrigin-RevId: 716348745
and multiple devices.
Whenever this happens we can essentially introduce an effects barrier
instead of doing the normal device -> host -> device transfer.
Fixes https://github.com/jax-ml/jax/issues/25671.
PiperOrigin-RevId: 716309978
This will basically drop the gather operation into full auto mode and add a sharding constraint on the output given by the user via `out_spec`.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 716295953
1. For apply, llvm::StringMap()::insert(MapEntryTy*) will cause dangling reference if not constructing mlir::tpu::extensions::rules() with const-reference. However, if we do construct it with const-reference, the signature is not const-qualified and fails to compile. Hence, change it to llvm::StringMap()::insert(std::pair<...>) and get extension rules by const-reference.
2. Pass default tiling to infer rule, we need it to infer single op. See infer of tpu::MatmulOp.
PiperOrigin-RevId: 716274818