[XLA:Python] Plumb in the HLO interpreter platform.
Enable it by default at the lowest priority (below CPU, which is always present.) PiperOrigin-RevId: 314437599 Change-Id: Ib21c20edee8e006fe09c03fdc6b794371a8b9878
This commit is contained in:
parent
e06251b493
commit
5bca7e22aa
@ -158,6 +158,20 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "interpreter_device",
|
||||
srcs = ["interpreter_device.cc"],
|
||||
hdrs = ["interpreter_device.h"],
|
||||
deps = [
|
||||
":pjrt_client",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/service:interpreter_plugin",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cpu_device",
|
||||
srcs = ["cpu_device.cc"],
|
||||
|
59
tensorflow/compiler/xla/pjrt/interpreter_device.cc
Normal file
59
tensorflow/compiler/xla/pjrt/interpreter_device.cc
Normal file
@ -0,0 +1,59 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
static const char kInterpreterPlatformName[] = "interpreter";
|
||||
|
||||
InterpreterDevice::InterpreterDevice(
|
||||
int id, std::unique_ptr<LocalDeviceState> local_device_state)
|
||||
: Device(id, std::move(local_device_state), kInterpreterPlatformName,
|
||||
/*device_kind=*/kInterpreterPlatformName) {}
|
||||
|
||||
StatusOr<std::shared_ptr<PjRtClient>> GetInterpreterClient() {
|
||||
TF_ASSIGN_OR_RETURN(se::Platform * platform,
|
||||
PlatformUtil::GetPlatform("Interpreter"));
|
||||
if (platform->VisibleDeviceCount() != 1) {
|
||||
return FailedPrecondition(
|
||||
"Interpreter platform should have exactly one device.");
|
||||
}
|
||||
LocalClientOptions options;
|
||||
options.set_platform(platform);
|
||||
TF_ASSIGN_OR_RETURN(LocalClient * client,
|
||||
ClientLibrary::GetOrCreateLocalClient(options));
|
||||
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
se::StreamExecutor* executor =
|
||||
client->backend().stream_executor(0).ValueOrDie();
|
||||
auto device_state = absl::make_unique<LocalDeviceState>(
|
||||
executor, client, LocalDeviceState::kSynchronous, /*asynchronous=*/false,
|
||||
/*allow_event_reuse=*/false);
|
||||
auto device =
|
||||
absl::make_unique<InterpreterDevice>(0, std::move(device_state));
|
||||
devices.push_back(std::move(device));
|
||||
|
||||
return std::make_shared<PjRtClient>(
|
||||
kInterpreterPlatformName, client, std::move(devices), /*host_id=*/0,
|
||||
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
|
||||
/*gpu_run_options=*/nullptr);
|
||||
}
|
||||
|
||||
} // namespace xla
|
36
tensorflow/compiler/xla/pjrt/interpreter_device.h
Normal file
36
tensorflow/compiler/xla/pjrt/interpreter_device.h
Normal file
@ -0,0 +1,36 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_INTERPRETER_DEVICE_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_PJRT_INTERPRETER_DEVICE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
class InterpreterDevice : public Device {
|
||||
public:
|
||||
InterpreterDevice(int id,
|
||||
std::unique_ptr<LocalDeviceState> local_device_state);
|
||||
};
|
||||
|
||||
StatusOr<std::shared_ptr<PjRtClient>> GetInterpreterClient();
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PJRT_INTERPRETER_DEVICE_H_
|
@ -315,6 +315,7 @@ pybind_extension(
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/pjrt:cpu_device",
|
||||
"//tensorflow/compiler/xla/pjrt:interpreter_device",
|
||||
"//tensorflow/compiler/xla/pjrt:nvidia_gpu_device",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"//tensorflow/compiler/xla/pjrt:tracked_device_buffer",
|
||||
|
@ -38,6 +38,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/distributed.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/distributed/service.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/python/bfloat16.h"
|
||||
@ -767,6 +768,7 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
py::arg("computation"), py::arg("compile_options") = CompileOptions());
|
||||
|
||||
m.def("get_cpu_client", &GetCpuClient, py::arg("asynchronous") = true);
|
||||
m.def("get_interpreter_client", &GetInterpreterClient);
|
||||
m.def("get_nvidia_gpu_client", &GetNvidiaGpuClient,
|
||||
py::arg("asynchronous") = true,
|
||||
py::arg("allocator_config") = GpuAllocatorConfig(),
|
||||
|
@ -52,6 +52,10 @@ xla_platform_names = {
|
||||
}
|
||||
|
||||
|
||||
def _interpreter_backend_factory():
|
||||
return _xla.get_interpreter_client()
|
||||
|
||||
|
||||
def _cpu_backend_factory():
|
||||
return _xla.get_cpu_client(asynchronous=True)
|
||||
|
||||
@ -85,6 +89,7 @@ def _gpu_backend_factory(distributed_client=None, node_id=0):
|
||||
|
||||
# Backend factories, keyed by user-visible name, in increasing priority order.
|
||||
_local_backend_factories = collections.OrderedDict([
|
||||
('interpreter', _interpreter_backend_factory),
|
||||
('cpu', _cpu_backend_factory),
|
||||
('gpu', _gpu_backend_factory),
|
||||
])
|
||||
|
Loading…
Reference in New Issue
Block a user