# rocm_jax/jax/_src/dispatch.py
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Primitive dispatch and jit dispatch.
from __future__ import annotations

import atexit
from collections.abc import Callable, Sequence
import contextlib
import dataclasses
import enum
from functools import partial
import itertools
import time
from typing import Any, NamedTuple
import logging
import threading

import numpy as np

import jax
from jax._src import basearray
from jax._src import config
from jax._src import core
from jax._src import api
from jax._src import array
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.abstract_arrays import array_types
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.interpreters import pxla
from jax._src import lib
from jax._src.mesh import AbstractMesh
from jax._src.lib import xla_client as xc
from jax._src.monitoring import record_event_duration_secs
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
SingleDeviceSharding, NamedSharding,
GSPMDSharding, TransferToMemoryKind, is_single_device_sharding)
from jax._src.layout import Layout, DeviceLocalLayout

JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration"
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"

traceback_util.register_exclusion(__file__)
xe = xc._xla
Backend = xe.Client
Device = xc.Device
CompileOptions = xc.CompileOptions

map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

logger = logging.getLogger(__name__)

# This flag is set on exit; no logging should be attempted
_on_exit = False

### op-by-op execution

def apply_primitive(prim, *args, **params):
"""Impl rule that compiles and runs a single primitive 'prim' using XLA."""
Replace apply_primitive internals with `jax.jit`. This allows deletion of a lot of code and leads to ~40% eager performance speedup. Benchmarks: ``` name old time/op new time/op delta eager_unary_dispatch 31.3µs ± 1% 19.4µs ± 6% -37.91% (p=0.016 n=4+5) eager_unary 32.1µs ± 0% 19.8µs ± 4% -38.26% (p=0.016 n=4+5) eager_binary_dispatch 35.9µs ± 1% 20.5µs ± 4% -42.93% (p=0.016 n=4+5) eager_binary 36.6µs ± 1% 21.1µs ± 4% -42.29% (p=0.016 n=4+5) jit_trivial_dispatch 3.87µs ± 2% 4.12µs ±25% ~ (p=1.000 n=5+5) jit_trivial 4.75µs ± 2% 4.82µs ±11% ~ (p=0.690 n=5+5) jit_simple_dispatch 2.95µs ± 2% 2.97µs ± 7% ~ (p=1.000 n=5+5) jit_simple 3.52µs ± 6% 3.51µs ± 5% ~ (p=0.841 n=5+5) jit_simple_dispatch_array 2.95µs ± 2% 2.96µs ± 6% ~ (p=1.000 n=5+5) jit_simple_array 3.46µs ± 2% 3.51µs ± 5% ~ (p=0.690 n=5+5) jit_small_matmul 3.01µs ± 1% 3.00µs ± 4% ~ (p=0.548 n=5+5) jit_big_matmul 34.0µs ±18% 35.5µs ±17% ~ (p=0.310 n=5+5) jit_simple_many_args_dispatch/num_args:10 6.93µs ± 6% 6.80µs ± 6% ~ (p=0.481 n=10+10) jit_simple_many_args_dispatch/num_args:100 47.7µs ± 7% 45.4µs ± 2% ~ (p=0.237 n=10+8) jit_simple_many_args_dispatch/num_args:1000 545µs ± 8% 516µs ± 2% ~ (p=0.101 n=10+8) jit_simple_many_args_dispatch/num_args:2000 1.12ms ± 7% 1.07ms ± 2% ~ (p=0.237 n=10+8) jit_simple_many_args/num_args:10 7.42µs ± 5% 7.23µs ± 2% ~ (p=0.173 n=10+8) jit_simple_many_args/num_args:100 48.4µs ± 7% 45.6µs ± 2% ~ (p=0.237 n=10+8) jit_simple_many_args/num_args:1000 542µs ± 6% 524µs ± 8% ~ (p=0.089 n=10+10) jit_simple_many_args/num_args:2000 1.12ms ± 7% 1.08ms ± 1% ~ (p=0.068 n=10+8) jit_simple_pruned_args_dispatch_10 4.79µs ± 8% 4.98µs ±10% ~ (p=0.421 n=5+5) jit_simple_pruned_args_10 5.32µs ± 6% 5.30µs ± 4% ~ (p=1.000 n=5+5) jit_simple_pruned_args_dispatch_100 24.7µs ± 6% 23.8µs ± 8% ~ (p=0.548 n=5+5) jit_simple_pruned_args_100 25.2µs ± 6% 24.4µs ± 8% ~ (p=0.690 n=5+5) jit_simple_pruned_args_dispatch_1000 238µs ± 7% 232µs ± 8% ~ (p=0.841 n=5+5) jit_simple_pruned_args_1000 240µs ± 7% 234µs ± 8% ~ 
(p=1.000 n=5+5) jit_simple_pruned_args_dispatch_2000 516µs ± 6% 497µs ± 1% ~ (p=0.413 n=5+4) jit_simple_pruned_args_2000 517µs ± 6% 505µs ± 7% ~ (p=0.690 n=5+5) jit_dispatch_without_transfer 719µs ± 9% 751µs ± 8% ~ (p=0.222 n=5+5) jit_dispatch_with_transfer 799µs ±14% 793µs ± 9% ~ (p=1.000 n=5+5) pmap_trivial_2_devices 49.9µs ±40% 48.2µs ±42% ~ (p=0.841 n=5+5) pmap_trivial_dispatch_8_devices 74.5µs ±24% 78.9µs ±29% ~ (p=0.421 n=5+5) pmap_trivial_8_devices 79.3µs ± 6% 82.7µs ±20% ~ (p=0.841 n=5+5) pmap_simple_2_devices 47.1µs ±17% 49.1µs ±20% ~ (p=0.548 n=5+5) pmap_simple_dispatch_8_devices 73.4µs ±16% 76.8µs ±21% ~ (p=0.690 n=5+5) pmap_simple_8_devices 76.0µs ±10% 80.6µs ±29% ~ (p=1.000 n=5+5) pmap_simple_dispatch_8_devices_100_args 1.12ms ±22% 1.08ms ±42% ~ (p=0.841 n=5+5) pmap_simple_8_devices_100_args 12.5ms ± 8% 12.8ms ±10% ~ (p=1.000 n=5+5) sda_index_1 413µs ± 1% 686µs ± 4% +66.08% (p=0.008 n=5+5) sda_index_2 850µs ± 1% 1378µs ± 4% +62.02% (p=0.008 n=5+5) sda_index_8 3.60ms ± 1% 5.69ms ± 4% +58.00% (p=0.008 n=5+5) bench_shaped_abstractify 300µs ± 1% 305µs ± 3% ~ (p=0.056 n=5+5) bench_xla_abstractify_scalar_int 6.45µs ± 1% 6.50µs ± 3% ~ (p=0.548 n=5+5) bench_xla_abstractify_scalar_float 3.73µs ± 1% 3.73µs ± 3% ~ (p=0.690 n=5+5) bench_xla_abstractify_scalar_numpy_int32 4.97µs ± 1% 4.83µs ± 3% ~ (p=0.095 n=5+5) bench_xla_abstractify_scalar_numpy_uint32 4.91µs ± 1% 4.75µs ± 0% -3.30% (p=0.016 n=5+4) bench_xla_abstractify_numpy_random 4.34µs ± 2% 4.31µs ± 3% ~ (p=0.310 n=5+5) bench_xla_abstractify_numpy_arange_100_float32 3.94µs ± 1% 3.93µs ± 3% ~ (p=0.548 n=5+5) bench_xla_abstractify_enum 6.85µs ± 1% 7.06µs ± 7% +3.07% (p=0.032 n=5+5) bench_are_op_shardings_equal 26.9µs ± 2% 27.0µs ± 3% ~ (p=0.841 n=5+5) bench_pjit_check_aval_sharding 691µs ± 2% 711µs ±13% ~ (p=0.841 n=5+5) bench_addressable_shards_index 656ns ± 4% 688ns ± 9% ~ (p=0.095 n=5+5) bench_remat_eager_retracing_overheads 12.7ms ± 4% 10.7ms ± 1% -15.48% (p=0.016 n=5+4) 
bench_remat_eager_retracing_overheads_static_argnums 13.0ms ± 2% 11.3ms ± 6% -13.71% (p=0.008 n=5+5) bench_slicing_compilation 12.1ms ± 1% 12.3ms ± 4% ~ (p=0.690 n=5+5) bench_slicing_compilation2 11.3ms ± 0% 11.5ms ± 6% ~ (p=0.690 n=5+5) bench_repeated_static_indexing 62.5ms ± 2% 40.8ms ± 8% -34.77% (p=0.008 n=5+5) bench_repeated_static_slicing 46.7ms ± 1% 31.4ms ± 2% -32.76% (p=0.008 n=5+5) pjit_simple_1_device/num_args:1 2.72µs ± 2% 2.68µs ± 5% ~ (p=0.151 n=5+5) pjit_simple_1_device/num_args:10 12.6µs ± 7% 12.3µs ± 3% ~ (p=0.310 n=5+5) pjit_simple_1_device/num_args:100 109µs ± 3% 108µs ± 4% ~ (p=0.548 n=5+5) pjit_simple_4_device/num_args:1 38.0µs ±26% 36.8µs ±19% ~ (p=0.690 n=5+5) pjit_simple_4_device/num_args:10 93.3µs ±19% 96.6µs ±23% ~ (p=0.841 n=5+5) pjit_simple_4_device/num_args:100 730µs ±16% 698µs ±48% ~ (p=0.841 n=5+5) pjit_aot_1_device/num_args:1 3.29µs ± 2% 3.12µs ± 4% -5.24% (p=0.016 n=4+5) pjit_aot_1_device/num_args:10 13.0µs ± 1% 12.7µs ± 2% ~ (p=0.063 n=4+5) pjit_aot_1_device/num_args:100 111µs ± 5% 110µs ±11% ~ (p=0.421 n=5+5) pjit_aot_4_device/num_args:1 38.4µs ±19% 38.9µs ±24% ~ (p=1.000 n=5+5) pjit_aot_4_device/num_args:10 91.3µs ±15% 96.9µs ±29% ~ (p=0.548 n=5+5) pjit_aot_4_device/num_args:100 676µs ±20% 689µs ±41% ~ (p=0.841 n=5+5) host_local_array_to_global_array 196µs ± 6% 194µs ± 4% ~ (p=0.548 n=5+5) device_put 50.8µs ± 1% 50.7µs ± 4% ~ (p=0.413 n=4+5) device_put_sharded 176µs ± 0% 177µs ± 4% ~ (p=0.190 n=4+5) device_get_8_devices 3.96ms ± 4% 4.03ms ± 7% ~ (p=0.413 n=4+5) np_asarray_8_devices 3.34ms ±18% 3.30ms ±10% ~ (p=0.548 n=5+5) jax_array_arrays_8_devices 5.01ms ±10% 5.09ms ±21% ~ (p=0.421 n=5+5) batch_inplace_while_scatter 440µs ± 1% 439µs ± 1% ~ (p=0.421 n=5+5) batch_inplace_while_dynamic_update_slice 454µs ± 0% 457µs ± 1% ~ (p=0.905 n=4+5) serial_dot_products 4.51µs ± 3% 4.41µs ± 2% ~ (p=0.151 n=5+5) bench_make_array_from_callback_fully_replicated_sharding 26.6µs ± 1% 27.0µs ± 2% ~ (p=0.056 n=5+5) ``` PiperOrigin-RevId: 586505950
2023-11-29 18:06:36 -08:00
fun = xla_primitive_callable(prim, **params)
# TODO(yashkatariya): Investigate adding is_primitive to jit and never
# triggering the disable jit path instead of messing around with it here.
prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
try:
outs = fun(*args)
finally:
lib.jax_jit.swap_thread_local_state_disable_jit(prev)
  return outs


@util.cache()
def xla_primitive_callable(prim: core.Primitive, **params):
  def prim_fun(*args):
    return prim.bind(*args, **params)
  prim_fun.__name__ = prim.name
  prim_fun.__qualname__ = prim.name
  return api.jit(prim_fun)


def simple_impl(prim):
  prim.def_impl(partial(apply_primitive, prim))


RuntimeToken = Any
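A jax-free sketch of the caching pattern used by `xla_primitive_callable` above: build one wrapper per (primitive, params) pair, cache it, and reuse it on later calls. `Prim` and `make_callable` are hypothetical stand-ins for `core.Primitive` and the `util.cache()`-wrapped factory; `functools.lru_cache` stands in for `util.cache`.

```python
import functools

class Prim:
    """Hypothetical stand-in for core.Primitive."""
    def __init__(self, name):
        self.name = name
    def bind(self, *args, **params):
        # Stand-in for real dispatch: combine args with a param.
        return sum(args) + params.get("offset", 0)

@functools.lru_cache(maxsize=None)
def make_callable(prim, params_items):
    # params must be passed in hashable form (a tuple of items) to be cacheable,
    # just as the real factory keys its cache on the primitive and its params.
    params = dict(params_items)
    def prim_fun(*args):
        return prim.bind(*args, **params)
    prim_fun.__name__ = prim.name
    prim_fun.__qualname__ = prim.name
    return prim_fun

add3 = Prim("add3")
f = make_callable(add3, (("offset", 10),))
print(f(1, 2, 3))                                    # 16
print(f is make_callable(add3, (("offset", 10),)))   # True: cache hit
```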
class RuntimeTokenSet(threading.local):
  """See docstring for effects.py module for the calling convention for tokens."""

  # For each ordered effect, the token returned by the last dispatched
  # computation, sharded over the devices in that computation.
  current_tokens: dict[core.Effect, core.Token]

  # For each device, the runtime token returned by the last dispatched
  # computation on that device.
  output_runtime_tokens: dict[Device, RuntimeToken]

  def __init__(self):
    self.current_tokens = {}
    self.output_runtime_tokens = {}

  def get_token_input(
      self, eff: core.Effect, devices: list[Device]
  ) -> core.Token:
    tok = self.current_tokens.get(eff, np.zeros(0, np.bool_))

    if isinstance(tok, core.Token):
      # The order of devices may change, so we need to reshard if necessary.
      # TODO(yueshengys): This might still be buggy in a multi-process SPMD
      # scenario. Revise the logic later. A distributed shutdown barrier inside
      # the XLA program may be needed.
      return jax.device_put(tok, jax.sharding.PositionalSharding(devices))

    # We only use replicated sharding for the first time when the token for the
    # ordered effect hasn't been created.
    s = jax.sharding.GSPMDSharding.get_replicated(devices)
    sharded_tok = core.Token(pxla.shard_args([s], [None], [tok])[0])
    self.current_tokens[eff] = sharded_tok
    return sharded_tok

  def set_token_result(self, eff: core.Effect, token: core.Token):
    self.current_tokens[eff] = token

  def set_output_runtime_token(self, device: Device, token: RuntimeToken):
    # We're free to clobber the previous output token because on each
    # device we have a total ordering of computations. Only the token
    # from the latest computation matters.
    self.output_runtime_tokens[device] = token

  def clear(self):
    self.current_tokens = {}
    self.output_runtime_tokens = {}

  def block_until_ready(self):
    for token in self.current_tokens.values():
      token.block_until_ready()
    for token in self.output_runtime_tokens.values():
      token.block_until_ready()
    self.clear()


runtime_tokens: RuntimeTokenSet = RuntimeTokenSet()


@atexit.register
def wait_for_tokens():
  runtime_tokens.block_until_ready()
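A minimal, jax-free sketch of the per-thread token-threading pattern above, assuming only the stdlib: a `threading.local` subclass keeps one "latest token" per effect, creating it lazily on first use and replacing it after each dispatched computation. `TokenSet` and its methods are hypothetical simplifications of `RuntimeTokenSet`.

```python
import threading

class TokenSet(threading.local):
    """Per-thread registry mapping an effect to its latest token (sketch)."""
    def __init__(self):
        self.current_tokens = {}

    def get_token_input(self, eff):
        # First use of an effect creates a fresh token; later uses thread the
        # previously stored one through, mirroring get_token_input above.
        return self.current_tokens.setdefault(eff, object())

    def set_token_result(self, eff, token):
        # Clobber the old token: only the latest computation's token matters.
        self.current_tokens[eff] = token

    def clear(self):
        self.current_tokens = {}

tokens = TokenSet()
t1 = tokens.get_token_input("io_effect")
print(tokens.get_token_input("io_effect") is t1)  # True: same token reused
tokens.set_token_result("io_effect", object())
print(tokens.get_token_input("io_effect") is t1)  # False: replaced
```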
@contextlib.contextmanager
def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None):
  if _on_exit:
    yield
  else:
    log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
    start_time = time.time()
    yield
    elapsed_time = time.time() - start_time
    if logger.isEnabledFor(log_priority):
      logger.log(log_priority, fmt.format(
          fun_name=fun_name, elapsed_time=elapsed_time))
    if event is not None:
      record_event_duration_secs(event, elapsed_time)
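A self-contained sketch of the timing context manager above: time the `with` body, then emit a single log record at the chosen priority. `log_elapsed` is a hypothetical, simplified stand-in for `log_elapsed_time` (it uses `time.monotonic` instead of `time.time` and drops the `_on_exit`/event plumbing); the stream handler is only there to make the output observable.

```python
import contextlib
import io
import logging
import time

logger = logging.getLogger("dispatch_demo")
logger.setLevel(logging.DEBUG)
stream = io.StringIO()
logger.addHandler(logging.StreamHandler(stream))

@contextlib.contextmanager
def log_elapsed(fmt, fun_name, level=logging.DEBUG):
    # Same shape as log_elapsed_time above: yield to the body, then log once.
    start = time.monotonic()
    yield
    elapsed = time.monotonic() - start
    if logger.isEnabledFor(level):
        logger.log(level, fmt.format(fun_name=fun_name, elapsed_time=elapsed))

with log_elapsed("Finished {fun_name} in {elapsed_time:.6f} sec", "demo"):
    sum(range(1000))

print("Finished demo in" in stream.getvalue())  # True
```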
def should_tuple_args(num_args: int, platform: str) -> bool:
  # CPU and GPU do not need tuples as they use host-side data structures that
  # do not have small bounds.
  # TPU only needs a tuple for very long lists.
  if platform == "tpu":
    return num_args > 2000
  else:
    return False
def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool:
  """Whether there is a primitive given by user anywhere inside a Jaxpr."""
  for eqn in jaxpr.eqns:
    if prim_name in eqn.primitive.name:
      return True
  for subjaxpr in core.subjaxprs(jaxpr):
    if jaxpr_has_primitive(subjaxpr, prim_name):
      return True
  return False
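A pure-Python sketch of the recursive search pattern used by `jaxpr_has_primitive` (and `jaxpr_has_prim_requiring_devices` below): check the current level's equations, then recurse into sub-jaxprs. `FakeEqn` and `FakeJaxpr` are hypothetical stand-ins for the real jaxpr data structures.

```python
from dataclasses import dataclass, field

@dataclass
class FakeEqn:
    name: str  # stand-in for eqn.primitive.name

@dataclass
class FakeJaxpr:
    eqns: list = field(default_factory=list)
    subjaxprs: list = field(default_factory=list)  # stand-in for core.subjaxprs

def has_primitive(jaxpr, prim_name):
    # Check this level first, then recurse — same shape as the real function.
    for eqn in jaxpr.eqns:
        if prim_name in eqn.name:
            return True
    return any(has_primitive(sub, prim_name) for sub in jaxpr.subjaxprs)

inner = FakeJaxpr(eqns=[FakeEqn("sin")])
outer = FakeJaxpr(eqns=[FakeEqn("add")], subjaxprs=[inner])
print(has_primitive(outer, "sin"))  # True: found in the nested jaxpr
print(has_primitive(outer, "cos"))  # False
```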
# Use this registry with caution. It will void the guarantee that lowering to
# stablehlo is oblivious of physical devices.
prim_requires_devices_during_lowering: set[core.Primitive] = set()

@util.weakref_lru_cache
def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr) -> bool:
  for eqn in jaxpr.eqns:
    if eqn.primitive in prim_requires_devices_during_lowering:
      return True
  for subjaxpr in core.subjaxprs(jaxpr):
    if jaxpr_has_prim_requiring_devices(subjaxpr):
      return True
  return False
class SourceInfo(NamedTuple):
  source_info: source_info_util.SourceInfo
  eqn_name: str


@util.weakref_lru_cache
def get_intermediate_shardings(
    jaxpr: core.Jaxpr) -> Sequence[tuple[Sharding, SourceInfo]]:
  from jax._src import pjit
  from jax.experimental import shard_map

  out = []
  for eqn in jaxpr.eqns:
    if eqn.primitive is pjit.sharding_constraint_p:
      s = eqn.params['sharding']
      if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
        continue
      source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
      out.append((s, source_info))
    elif eqn.primitive is pjit.pjit_p:
      source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
      out.extend((i, source_info) for i in eqn.params['in_shardings'])
      out.extend((o, source_info) for o in eqn.params['out_shardings'])
    elif eqn.primitive is shard_map.shard_map_p:
      if not eqn.params['mesh']._is_jax_device_mesh:
        continue
      source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
      def _names_to_pspec(names):
        ndmin = max(names) + 1 if names else 0
        return PartitionSpec(*(names.get(i) for i in range(ndmin)))
      out.extend((NamedSharding(eqn.params['mesh'], _names_to_pspec(names)), source_info)
                 for names in [*eqn.params['in_names'], *eqn.params['out_names']])
    elif eqn.primitive is device_put_p:
      source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
      out.extend((s, source_info) for s in eqn.params['devices']
                 if isinstance(s, Sharding) and s.memory_kind is not None)
  for subjaxpr in core.subjaxprs(jaxpr):
    out.extend(get_intermediate_shardings(subjaxpr))
  return out
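A small sketch of the `_names_to_pspec` conversion above: shard_map stores axis names as a `{dim_index: axis_name}` dict, and a `PartitionSpec` needs one entry per dimension up to the largest mapped index, with `None` for unmapped dimensions. A plain tuple stands in for `jax.sharding.PartitionSpec` here.

```python
def names_to_spec(names):
    # One entry per dimension up to the largest mapped index; None elsewhere.
    ndmin = max(names) + 1 if names else 0
    return tuple(names.get(i) for i in range(ndmin))

print(names_to_spec({0: "x", 2: "y"}))  # ('x', None, 'y')
print(names_to_spec({1: "x"}))          # (None, 'x')
print(names_to_spec({}))                # ()
```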
def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:
  return (any(type(v.aval.dtype) is core.bint for v in jaxpr.invars
              if isinstance(v.aval, core.UnshapedArray)) or
          any(_is_bint_axis_size(d)
              for j in itertools.chain([jaxpr], core.subjaxprs(jaxpr))
              for e in j.eqns for v in e.outvars
              if isinstance(v.aval, core.DShapedArray) for d in v.aval.shape))

def _is_bint_axis_size(d: core.AxisSize) -> bool:
  if isinstance(d, core.DArray):
    assert not d.shape
    return type(d.dtype) is core.bint
  elif isinstance(d, core.Var):
    return (isinstance(d.aval, core.DShapedArray) and
            type(d.aval.dtype) is core.bint)
  return False
# We can optionally set a Jaxpr rewriter that can be applied just before
# compilation. This mechanism is used for compiling id_tap; we can
# remove it once we bring the id_tap implementation into the core.
outfeed_rewriter: Callable[[core.Jaxpr], core.Jaxpr] | None = None

def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
  if outfeed_rewriter is not None:
    return outfeed_rewriter(jaxpr)
  else:
    return jaxpr


def check_arg(arg: Any):
  if not (isinstance(arg, core.Tracer) or core.valid_jaxtype(arg)):
    raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid "
                    "JAX type.")
def jaxpr_replicas(jaxpr: core.Jaxpr) -> int:
  """The number of replicas needed for a jaxpr.

  For an eqn, multiply the `axis_size` with the `jaxpr_replicas` of the
  subjaxprs. For a list of eqns, take the maximum number of replicas.
  """
  return max(unsafe_map(_eqn_replicas, jaxpr.eqns), default=1)

# TODO(mattjj): this function assumes that only pmap has a parameter named
# axis_size, and that it corresponds to cross-replica mapping.
def _eqn_replicas(eqn: core.JaxprEqn) -> int:
  call_jaxpr = eqn.params.get("call_jaxpr")
  if call_jaxpr:
    return eqn.params.get('axis_size', 1) * jaxpr_replicas(call_jaxpr)
  elif eqn.primitive in xla.initial_style_primitives:
    return _initial_style_primitive_replicas(eqn.params)
  else:
    return 1

def _initial_style_primitive_replicas(params: dict[str, Any]) -> int:
  return max(core.traverse_jaxpr_params(jaxpr_replicas, params).values(),
             default=1)
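A pure-Python sketch of the replica-counting recursion above: an equation's replica count is its `axis_size` times the replicas of its `call_jaxpr`, and a jaxpr's count is the max over its equations (defaulting to 1). Plain dicts and lists are hypothetical stand-ins for `core.JaxprEqn` and `core.Jaxpr`.

```python
def eqn_replicas(eqn):
    # pmap-like call: axis_size * replicas of the inner jaxpr.
    if eqn.get("call_jaxpr") is not None:
        return eqn.get("axis_size", 1) * jaxpr_replicas_demo(eqn["call_jaxpr"])
    return 1

def jaxpr_replicas_demo(jaxpr_eqns):
    # Max over equations, 1 for an empty jaxpr — same default as above.
    return max((eqn_replicas(e) for e in jaxpr_eqns), default=1)

inner = [{"call_jaxpr": None}]                   # plain eqn: 1 replica
outer = [{"call_jaxpr": inner, "axis_size": 4},  # pmap-like eqn: 4 * 1
         {"call_jaxpr": None}]
print(jaxpr_replicas_demo(outer))  # 4
```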
def needs_check_special() -> bool:
  return config.debug_infs.value or config.debug_nans.value

def check_special(name: str, bufs: Sequence[basearray.Array]) -> None:
  if needs_check_special():
    for buf in bufs:
      _check_special(name, buf.dtype, buf)

def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
  if dtypes.issubdtype(dtype, np.inexact):
    if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
      raise FloatingPointError(f"invalid value (nan) encountered in {name}")
    if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))):
      raise FloatingPointError(f"invalid value (inf) encountered in {name}")
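A standalone sketch of the `_check_special` logic, assuming only the stdlib: scan a buffer of floats for NaN/inf and raise `FloatingPointError`, the way the `debug_nans`/`debug_infs` flags gate the checks above. `check_floats` is a hypothetical simplification that takes the flags as arguments instead of reading global config.

```python
import math

def check_floats(name, values, debug_nans=True, debug_infs=True):
    # Mirrors _check_special: nan check first, then inf, each behind a flag.
    if debug_nans and any(math.isnan(v) for v in values):
        raise FloatingPointError(f"invalid value (nan) encountered in {name}")
    if debug_infs and any(math.isinf(v) for v in values):
        raise FloatingPointError(f"invalid value (inf) encountered in {name}")

check_floats("mul", [1.0, 2.0])  # fine: no special values
try:
    check_floats("div", [1.0, float("nan")])
except FloatingPointError as e:
    print(e)  # invalid value (nan) encountered in div
```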
class CopySemantics(enum.Enum):
  ALIAS = enum.auto()
  COPY = enum.auto()
  DONATE = enum.auto()


def _identity_fn(x):
  return x
def _different_device_order_reshard(x, target_sharding, copy: CopySemantics):
x._check_if_deleted()
inp_sharding = x.sharding
donate_argnums = 0 if copy == CopySemantics.DONATE else None
if inp_sharding._device_assignment == target_sharding._device_assignment:
return api.jit(_identity_fn, out_shardings=target_sharding,
donate_argnums=donate_argnums)(x)
if inp_sharding.device_set != target_sharding.device_set:
inp_ids = [d.id for d in inp_sharding._device_assignment]
inp_plat = inp_sharding._device_assignment[0].platform.upper()
target_ids = [d.id for d in target_sharding._device_assignment]
target_plat = target_sharding._device_assignment[0].platform.upper()
raise ValueError("Input and target sharding should have the same set of "
f"devices. Got input's device set ids: {inp_ids} on "
f"platform {inp_plat} and target sharding's device set "
f"ids: {target_ids} on platform {target_plat}")
old_hlo_sharding = inp_sharding._to_xla_hlo_sharding(x.ndim)
if old_hlo_sharding.is_replicated():
new_hlo_sharding = old_hlo_sharding
else:
permute_order = np.vectorize(target_sharding._device_assignment.index,
otypes=[int])(inp_sharding._device_assignment)
    # Unfortunately we need to fall back to V1 sharding here.
new_op_sharding = old_hlo_sharding.to_proto()
new_op_sharding.iota_reshape_dims = []
new_op_sharding.iota_transpose_perm = []
new_op_sharding.tile_assignment_devices = np.take(
permute_order, old_hlo_sharding.tile_assignment_devices()
)
new_hlo_sharding = xc.HloSharding.from_proto(new_op_sharding)
# TODO(yashkatariya): Enable this when HloSharding conversion is fixed in
# XLA.
# assert (new_op_sharding.tile_assignment_dimensions
# == new_hlo_sharding.tile_assignment_dimensions())
# assert (new_op_sharding.tile_assignment_devices
# == new_hlo_sharding.tile_assignment_devices())
assert (list(np.take(inp_sharding._device_assignment,
old_hlo_sharding.tile_assignment_devices()))
== list(np.take(target_sharding._device_assignment,
new_op_sharding.tile_assignment_devices)))
new_x = array.make_array_from_single_device_arrays(
x.shape,
GSPMDSharding(target_sharding._device_assignment, new_hlo_sharding,
memory_kind=target_sharding.memory_kind),
x._arrays,
)
return api.jit(_identity_fn, out_shardings=target_sharding,
donate_argnums=donate_argnums)(new_x)
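# The permutation above can be illustrated in isolation. A minimal sketch with
# plain numpy and made-up device ids (not part of this module): for each input
# device we look up its position in the target assignment, then re-index the
# tile assignment so every shard stays on the same physical device.

```python
import numpy as np

# Hypothetical device ids: both assignments contain the same devices,
# but the target lists them in a different order.
inp_assignment = [0, 1, 2, 3]
target_assignment = [2, 3, 0, 1]

# Mirrors np.vectorize(target_sharding._device_assignment.index, ...) above:
# position of each input device within the target assignment.
permute_order = np.vectorize(target_assignment.index,
                             otypes=[int])(inp_assignment)

# Re-index an old tile assignment through the permutation, as done with
# np.take(permute_order, old_hlo_sharding.tile_assignment_devices()).
old_tile_devices = np.array([0, 1, 2, 3])
new_tile_devices = np.take(permute_order, old_tile_devices)

# The invariant checked by the assert in the function body: shard i resolves
# to the same physical device under the old and new assignments.
assert ([inp_assignment[d] for d in old_tile_devices]
        == [target_assignment[d] for d in new_tile_devices])
```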
@dataclasses.dataclass(frozen=True)
class _DeferredShardArg:
"""Deferred call to `pxla.shard_args`.
Per-array impls return this object instead of a result array to indicate a
deferred `shard_args` call. `_batched_device_put_impl` then batches all
`_DeferredShardArg` objects into a single `shard_args` call.
"""
x: Any
s: Sharding
aval: core.AbstractValue
committed: bool
@property
def result_handler(self):
return pxla.global_aval_to_result_handler(self.aval, self.s, self.committed)


def _device_put_sharding_impl(x, aval, device, copy):
from jax.experimental import multihost_utils
if isinstance(device, Sharding):
s = device
if (getattr(x, 'sharding', None) == s and getattr(x, '_committed', False)
and copy == CopySemantics.ALIAS):
return x
if (not s.is_fully_addressable and
isinstance(x, array.ArrayImpl) and not x.is_fully_addressable):
assert isinstance(s, Sharding)
return _different_device_order_reshard(x, s, copy)
if (s.is_fully_addressable and isinstance(x, array.ArrayImpl) and
x.is_fully_addressable and s.num_devices > 1 and
s._internal_device_list != x.sharding._internal_device_list and # pytype: disable=attribute-error
s.device_set == x.sharding.device_set):
assert isinstance(s, Sharding)
return _different_device_order_reshard(x, s, copy)
if not s.is_fully_addressable:
if ((isinstance(x, array.ArrayImpl) and not x._committed) or
type(x) in array_types):
# TODO(yashkatariya): Move this check to `jit`.
multihost_utils.assert_equal(
x, fail_message=(
f"{type(x)} passed to device_put is not the same on each"
" process. Make sure you are passing the same value of"
f" {type(x)} on each process."))
return api.jit(
_identity_fn, out_shardings=s,
donate_argnums=(0 if copy == CopySemantics.DONATE else None))(x)
# TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array.
raise ValueError(
"device_put's second argument must be a Device or a Sharding which"
f" represents addressable devices, but got {s}. Please pass device or"
" Sharding which represents addressable devices.")
return _DeferredShardArg(x, s, aval, True)
  # Only `Device` is handled below; `Sharding` instances were handled above.
if isinstance(x, array.ArrayImpl):
if not x.is_fully_addressable:
raise ValueError(
"device_put's first argument must be a fully addressable array, but "
f"got value with devices {x.devices()}")
if device is None and copy == CopySemantics.ALIAS:
return x
elif is_single_device_sharding(x.sharding):
device = x.sharding._device_assignment[0] if device is None else device
return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x],
[device])
sh = SingleDeviceSharding(pxla._get_default_device()
if device is None else device)
return _DeferredShardArg(x, sh, aval, device is not None)
def _device_put_impl(
x, *, device: Device | Sharding | Layout | None,
src: Device | Sharding | Layout | None, copy: CopySemantics):
if (isinstance(device, TransferToMemoryKind) or
isinstance(src, TransferToMemoryKind)):
raise ValueError(
"TransferToMemoryKind argument to jax.device_put can only be used"
" inside jax.jit. If you are using device_put outside jax.jit, then"
" please provide a concrete Sharding with memory_kind.")
try:
aval = xla.abstractify(x)
except TypeError as err:
raise TypeError(
f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
if isinstance(device, Layout):
l = device
dll = l.device_local_layout
x_dll = x.layout.device_local_layout if hasattr(x, 'layout') else None
if dll is None and l.sharding is None:
return _device_put_sharding_impl(x, aval, l.sharding, copy)
if (not isinstance(l.sharding, Sharding) or
not isinstance(dll, (DeviceLocalLayout, type(None)))):
raise ValueError(
"sharding and device_local_layout in `Layout` instance should be"
f" concrete. Got layout: {l} for input {aval.str_short()}")
if (getattr(x, 'layout', None) == l and getattr(x, '_committed', False) and
copy == CopySemantics.ALIAS):
return x
if x_dll is None and dll is None:
return _device_put_sharding_impl(x, aval, l.sharding, copy)
return api.jit(
_identity_fn, out_shardings=l,
donate_argnums=(0 if copy == CopySemantics.DONATE else None))(x)
return _device_put_sharding_impl(x, aval, device, copy)
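# The `copy` argument threaded through `_device_put_impl` follows the
# may_alias/donate table in the change description. A standalone sketch of
# that resolution (not JAX's actual code; `resolve_copy_semantics` is a
# hypothetical helper named here for illustration):

```python
from enum import Enum


class CopySemantics(Enum):
    ALIAS = "alias"    # may return the original buffer
    COPY = "copy"      # always produce a fresh buffer
    DONATE = "donate"  # input is marked deleted; output may reuse its memory


def resolve_copy_semantics(may_alias, donate):
    # Per the table: may_alias=None defaults to the opposite of donate,
    # and may_alias=True with donate=True is an error.
    if may_alias is None:
        may_alias = not donate
    if may_alias and donate:
        raise ValueError("may_alias and donate cannot both be True")
    if donate:
        return CopySemantics.DONATE
    return CopySemantics.ALIAS if may_alias else CopySemantics.COPY
```

# With the defaults (may_alias=None, donate=False) this resolves to ALIAS,
# matching the current behavior of `jax.device_put` before the planned
# switch to copy-by-default.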
def _batched_device_put_impl(
*xs,
devices: Sequence[Device | Sharding | Layout | None],
srcs: Sequence[Device | Sharding | Layout | None],
copy_semantics: Sequence[CopySemantics]):
ys = []
shard_arg_indices, shard_arg_xs, shard_arg_shardings = [], [], []
for i, (x, device, src, cp) in enumerate(zip(xs, devices, srcs, copy_semantics)):
y = _device_put_impl(x, device=device, src=src, copy=cp)
if isinstance(y, _DeferredShardArg):
shard_arg_indices.append(i)
shard_arg_xs.append(y.x)
shard_arg_shardings.append(y.s)
ys.append(y)
if shard_arg_xs:
# Batch shard_arg calls. Helps improve efficiency for backends that support
# efficient batch transfer.
# device_put handles `Layout` via a different path, so just pass `None` as
# the layout here.
shard_arg_results = pxla.shard_args(
shard_arg_shardings, [None] * len(shard_arg_xs), shard_arg_xs)
for i, shard_arg_result in zip(shard_arg_indices, shard_arg_results):
assert isinstance(ys[i], _DeferredShardArg)
ys[i] = ys[i].result_handler(shard_arg_result)
return ys
device_put_p = core.Primitive('device_put')
device_put_p.multiple_results = True
device_put_p.def_impl(_batched_device_put_impl)
device_put_p.def_abstract_eval(lambda *xs, devices, srcs, copy_semantics: xs)
def _device_put_transpose(cts, *_, devices, srcs, copy_semantics):
results = [None] * len(cts)
dp_args = []
for i, (ct, device, src, cp) in enumerate(zip(cts, devices, srcs, copy_semantics)):
if type(ct) is not ad.Zero:
dp_args.append((i, ct, device, src, cp))
if dp_args:
indices, args, devices, srcs, copy_semantics = list(zip(*dp_args))
new_copy_semantics = []
for cp in copy_semantics:
if cp == CopySemantics.DONATE:
raise ValueError(
"donate=True is not allowed during transposition of device_put."
" Please file an issue if you want this to be supported.")
elif cp == CopySemantics.ALIAS:
new_copy_semantics.append(CopySemantics.COPY)
else:
assert cp == CopySemantics.COPY
new_copy_semantics.append(CopySemantics.COPY)
ys = device_put_p.bind(*args, devices=srcs, srcs=devices,
copy_semantics=new_copy_semantics)
for i, y in zip(indices, ys):
results[i] = y
return results
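# The transpose rule above swaps `devices` and `srcs` and remaps the copy
# semantics: DONATE is rejected outright, and ALIAS weakens to COPY on the
# backward pass. A minimal standalone sketch of that mapping (illustrative
# only; `transpose_copy_semantics` is a hypothetical helper, and the enum is
# redefined here so the snippet is self-contained):

```python
from enum import Enum


class CopySemantics(Enum):
    ALIAS = "alias"
    COPY = "copy"
    DONATE = "donate"


def transpose_copy_semantics(cp):
    # A donated input is consumed on the forward pass, so there is no
    # buffer to transpose through; aliasing is not guaranteed to be
    # reversible, so it degrades to a plain copy.
    if cp is CopySemantics.DONATE:
        raise ValueError(
            "donate=True is not allowed during transposition of device_put.")
    assert cp in (CopySemantics.ALIAS, CopySemantics.COPY)
    return CopySemantics.COPY
```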
ad.primitive_jvps[device_put_p] = partial(ad.linear_jvp, device_put_p)
ad.primitive_transposes[device_put_p] = _device_put_transpose
def _device_put_batcher(batched_args, batch_dims, **params):
mapped_batch_dims = [bd for bd in batch_dims if bd is not batching.not_mapped]
assert not mapped_batch_dims or all(
mapped_batch_dims[0] == bd for bd in mapped_batch_dims[1:]
), batch_dims
return device_put_p.bind(*batched_args, **params), batch_dims
batching.primitive_batchers[device_put_p] = _device_put_batcher
def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics):
# TODO(yashkatariya): Maybe we should add the custom calls anyway if it's
# being used inside jit? At least for now, this preserves the old behavior.
if ctx.module_context.all_default_mem_kind:
return xs
def lower(x, device, aval, out_aval):
if (isinstance(device, (Sharding, TransferToMemoryKind)) and
device.memory_kind is not None):
if isinstance(device, Sharding):
if config.use_shardy_partitioner.value:
x = mlir.wrap_with_sharding_op(
ctx, x, out_aval,
device._to_sdy_sharding(aval.ndim))
else:
x = mlir.wrap_with_sharding_op(
ctx, x, out_aval,
device._to_xla_hlo_sharding(aval.ndim).to_proto())
x = mlir.wrap_with_memory_kind(x, device.memory_kind, out_aval)
return x
return x
  return list(map(lower, xs, devices, ctx.avals_in, ctx.avals_out))
mlir.register_lowering(
    device_put_p, _tpu_gpu_device_put_lowering, platform='tpu')
mlir.register_lowering(
    device_put_p, _tpu_gpu_device_put_lowering, platform='gpu')
# Default lowering for all other platforms: device_put is a no-op inside a
# lowered computation, so just pass the operands through unchanged.
def _common_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics):
  return xs
mlir.register_lowering(device_put_p, _common_device_put_lowering)
# Memory-kind propagation rule for device_put: report each destination's
# memory kind, or None when the destination does not specify one.
def _propagate_mem_kind_dp(*xm, devices, srcs, copy_semantics):
  memory_kinds = []
  for device in devices:
    if isinstance(device, (Sharding, TransferToMemoryKind)):
      memory_kinds.append(device.memory_kind)
    else:
      memory_kinds.append(None)
  return memory_kinds
pxla.memory_kind_propagate_rule[device_put_p] = _propagate_mem_kind_dp