Move JAX example to public XLA:CPU API

PiperOrigin-RevId: 698143471
This commit is contained in:
Mason Chang 2024-11-19 14:18:32 -08:00 committed by jax authors
parent 3161a28424
commit 42fbd301fc
2 changed files with 16 additions and 3 deletions

View File

@ -26,8 +26,13 @@ cc_binary(
"@tsl//tsl/platform:platform_port",
"@xla//xla:literal",
"@xla//xla:literal_util",
"@xla//xla/hlo/builder:xla_computation",
"@xla//xla/hlo/ir:hlo",
"@xla//xla/pjrt:pjrt_client",
"@xla//xla/pjrt/cpu:cpu_client",
"@xla//xla/pjrt:pjrt_executable",
"@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options",
"@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
"@xla//xla/service:hlo_module_config",
"@xla//xla/tools:hlo_module_loader",
],
)

View File

@ -36,15 +36,21 @@ limitations under the License.
// }
// )
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include "third_party/absl/status/statusor.h"
#include "xla/hlo/builder/xla_computation.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/pjrt/cpu/cpu_client.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_executable.h"
#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h"
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
#include "xla/service/hlo_module_config.h"
#include "xla/tools/hlo_module_loader.h"
#include "tsl/platform/init_main.h"
#include "tsl/platform/logging.h"
@ -66,8 +72,10 @@ int main(int argc, char** argv) {
// Run it using JAX C++ Runtime (PJRT).
// Get a CPU client.
xla::CpuClientOptions options;
options.asynchronous = true;
std::unique_ptr<xla::PjRtClient> client =
xla::GetTfrtCpuClient(/*asynchronous=*/true).value();
xla::GetXlaPjrtCpuClient(options).value();
// Compile XlaComputation to PjRtExecutable.
xla::XlaComputation xla_computation(test_module_proto);