[XLA:Python] Plumb in the HLO interpreter platform.

Enable it by default at the lowest priority (below CPU, which is always present); a short usage sketch follows the commit metadata below.

PiperOrigin-RevId: 314437599
Change-Id: Ib21c20edee8e006fe09c03fdc6b794371a8b9878
Peter Hawkins, 2020-06-02 17:34:44 -07:00 (committed by TensorFlower Gardener)
parent e06251b493
commit 5bca7e22aa
6 changed files with 117 additions and 0 deletions
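
For illustration only (not part of this commit): a minimal Python sketch of how the new binding added below might be exercised. The import path mirrors how xla_client.py loads the pybind11 extension and is an assumption here; `get_interpreter_client` is the function registered in xla.cc in this change.

# Sketch: build a single-device client backed by the HLO "interpreter"
# platform via the binding added in this commit.
from tensorflow.compiler.xla.python import xla_extension as _xla

interpreter_client = _xla.get_interpreter_client()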

tensorflow/compiler/xla/pjrt/BUILD

@@ -158,6 +158,20 @@ cc_library(
],
)
cc_library(
name = "interpreter_device",
srcs = ["interpreter_device.cc"],
hdrs = ["interpreter_device.h"],
deps = [
":pjrt_client",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/service:interpreter_plugin",
"//tensorflow/compiler/xla/service:platform_util",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "cpu_device",
srcs = ["cpu_device.cc"],

tensorflow/compiler/xla/pjrt/interpreter_device.cc (new file)

@@ -0,0 +1,59 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
namespace xla {
static const char kInterpreterPlatformName[] = "interpreter";
InterpreterDevice::InterpreterDevice(
int id, std::unique_ptr<LocalDeviceState> local_device_state)
: Device(id, std::move(local_device_state), kInterpreterPlatformName,
/*device_kind=*/kInterpreterPlatformName) {}
StatusOr<std::shared_ptr<PjRtClient>> GetInterpreterClient() {
TF_ASSIGN_OR_RETURN(se::Platform * platform,
PlatformUtil::GetPlatform("Interpreter"));
if (platform->VisibleDeviceCount() != 1) {
return FailedPrecondition(
"Interpreter platform should have exactly one device.");
}
LocalClientOptions options;
options.set_platform(platform);
TF_ASSIGN_OR_RETURN(LocalClient * client,
ClientLibrary::GetOrCreateLocalClient(options));
std::vector<std::unique_ptr<Device>> devices;
se::StreamExecutor* executor =
client->backend().stream_executor(0).ValueOrDie();
auto device_state = absl::make_unique<LocalDeviceState>(
executor, client, LocalDeviceState::kSynchronous, /*asynchronous=*/false,
/*allow_event_reuse=*/false);
auto device =
absl::make_unique<InterpreterDevice>(0, std::move(device_state));
devices.push_back(std::move(device));
return std::make_shared<PjRtClient>(
kInterpreterPlatformName, client, std::move(devices), /*host_id=*/0,
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*gpu_run_options=*/nullptr);
}
} // namespace xla

tensorflow/compiler/xla/pjrt/interpreter_device.h (new file)

@@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_PJRT_INTERPRETER_DEVICE_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_INTERPRETER_DEVICE_H_
#include <memory>
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
class InterpreterDevice : public Device {
public:
InterpreterDevice(int id,
std::unique_ptr<LocalDeviceState> local_device_state);
};
StatusOr<std::shared_ptr<PjRtClient>> GetInterpreterClient();
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PJRT_INTERPRETER_DEVICE_H_

tensorflow/compiler/xla/python/BUILD

@@ -315,6 +315,7 @@ pybind_extension(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/pjrt:cpu_device",
"//tensorflow/compiler/xla/pjrt:interpreter_device",
"//tensorflow/compiler/xla/pjrt:nvidia_gpu_device",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"//tensorflow/compiler/xla/pjrt:tracked_device_buffer",

tensorflow/compiler/xla/python/xla.cc

@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/pjrt/distributed/client.h"
#include "tensorflow/compiler/xla/pjrt/distributed/distributed.h"
#include "tensorflow/compiler/xla/pjrt/distributed/service.h"
#include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
#include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/bfloat16.h"
@@ -767,6 +768,7 @@ PYBIND11_MODULE(xla_extension, m) {
py::arg("computation"), py::arg("compile_options") = CompileOptions());
m.def("get_cpu_client", &GetCpuClient, py::arg("asynchronous") = true);
m.def("get_interpreter_client", &GetInterpreterClient);
m.def("get_nvidia_gpu_client", &GetNvidiaGpuClient,
py::arg("asynchronous") = true,
py::arg("allocator_config") = GpuAllocatorConfig(),

tensorflow/compiler/xla/python/xla_client.py

@@ -52,6 +52,10 @@ xla_platform_names = {
}
def _interpreter_backend_factory():
  return _xla.get_interpreter_client()


def _cpu_backend_factory():
  return _xla.get_cpu_client(asynchronous=True)
@@ -85,6 +89,7 @@ def _gpu_backend_factory(distributed_client=None, node_id=0):
# Backend factories, keyed by user-visible name, in increasing priority order.
_local_backend_factories = collections.OrderedDict([
    ('interpreter', _interpreter_backend_factory),
    ('cpu', _cpu_backend_factory),
    ('gpu', _gpu_backend_factory),
])
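
Because 'interpreter' is registered first in this ordered dict, it has the lowest priority: with no explicit name, backend selection still prefers 'gpu' or 'cpu'. A hedged usage sketch follows; it assumes xla_client's pre-existing get_local_backend helper (not part of this diff) is the lookup that consumes _local_backend_factories.

from tensorflow.compiler.xla.python import xla_client

# Request the interpreter backend explicitly by its user-visible name.
interpreter = xla_client.get_local_backend('interpreter')

# With no name, the highest-priority backend that initializes wins, so the
# default remains 'gpu' when available, otherwise 'cpu'.
default = xla_client.get_local_backend()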