mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Move JAX example to public XLA:CPU API
PiperOrigin-RevId: 698143471
This commit is contained in:
parent
3161a28424
commit
42fbd301fc
@ -26,8 +26,13 @@ cc_binary(
|
||||
"@tsl//tsl/platform:platform_port",
|
||||
"@xla//xla:literal",
|
||||
"@xla//xla:literal_util",
|
||||
"@xla//xla/hlo/builder:xla_computation",
|
||||
"@xla//xla/hlo/ir:hlo",
|
||||
"@xla//xla/pjrt:pjrt_client",
|
||||
"@xla//xla/pjrt/cpu:cpu_client",
|
||||
"@xla//xla/pjrt:pjrt_executable",
|
||||
"@xla//xla/pjrt/plugin/xla_cpu:cpu_client_options",
|
||||
"@xla//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
|
||||
"@xla//xla/service:hlo_module_config",
|
||||
"@xla//xla/tools:hlo_module_loader",
|
||||
],
|
||||
)
|
||||
|
@ -36,15 +36,21 @@ limitations under the License.
|
||||
// }
|
||||
// )
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/absl/status/statusor.h"
|
||||
#include "xla/hlo/builder/xla_computation.h"
|
||||
#include "xla/hlo/ir/hlo_module.h"
|
||||
#include "xla/literal.h"
|
||||
#include "xla/literal_util.h"
|
||||
#include "xla/pjrt/cpu/cpu_client.h"
|
||||
#include "xla/pjrt/pjrt_client.h"
|
||||
#include "xla/pjrt/pjrt_executable.h"
|
||||
#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h"
|
||||
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
|
||||
#include "xla/service/hlo_module_config.h"
|
||||
#include "xla/tools/hlo_module_loader.h"
|
||||
#include "tsl/platform/init_main.h"
|
||||
#include "tsl/platform/logging.h"
|
||||
@ -66,8 +72,10 @@ int main(int argc, char** argv) {
|
||||
// Run it using JAX C++ Runtime (PJRT).
|
||||
|
||||
// Get a CPU client.
|
||||
xla::CpuClientOptions options;
|
||||
options.asynchronous = true;
|
||||
std::unique_ptr<xla::PjRtClient> client =
|
||||
xla::GetTfrtCpuClient(/*asynchronous=*/true).value();
|
||||
xla::GetXlaPjrtCpuClient(options).value();
|
||||
|
||||
// Compile XlaComputation to PjRtExecutable.
|
||||
xla::XlaComputation xla_computation(test_module_proto);
|
||||
|
Loading…
x
Reference in New Issue
Block a user