diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD
index dbd33705d0e..dd50d0577d4 100644
--- a/tensorflow/compiler/xla/pjrt/BUILD
+++ b/tensorflow/compiler/xla/pjrt/BUILD
@@ -158,6 +158,20 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "interpreter_device",
+    srcs = ["interpreter_device.cc"],
+    hdrs = ["interpreter_device.h"],
+    deps = [
+        ":pjrt_client",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla/client:client_library",
+        "//tensorflow/compiler/xla/service:interpreter_plugin",
+        "//tensorflow/compiler/xla/service:platform_util",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
 cc_library(
     name = "cpu_device",
     srcs = ["cpu_device.cc"],
diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.cc b/tensorflow/compiler/xla/pjrt/interpreter_device.cc
new file mode 100644
index 00000000000..63254d4aa70
--- /dev/null
+++ b/tensorflow/compiler/xla/pjrt/interpreter_device.cc
@@ -0,0 +1,59 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+
+namespace xla {
+
+static const char kInterpreterPlatformName[] = "interpreter";
+
+InterpreterDevice::InterpreterDevice(
+    int id, std::unique_ptr<LocalDeviceState> local_device_state)
+    : Device(id, std::move(local_device_state), kInterpreterPlatformName,
+             /*device_kind=*/kInterpreterPlatformName) {}
+
+StatusOr<std::shared_ptr<PjRtClient>> GetInterpreterClient() {
+  TF_ASSIGN_OR_RETURN(se::Platform * platform,
+                      PlatformUtil::GetPlatform("Interpreter"));
+  if (platform->VisibleDeviceCount() != 1) {
+    return FailedPrecondition(
+        "Interpreter platform should have exactly one device.");
+  }
+  LocalClientOptions options;
+  options.set_platform(platform);
+  TF_ASSIGN_OR_RETURN(LocalClient * client,
+                      ClientLibrary::GetOrCreateLocalClient(options));
+
+  std::vector<std::unique_ptr<Device>> devices;
+  se::StreamExecutor* executor =
+      client->backend().stream_executor(0).ValueOrDie();
+  auto device_state = absl::make_unique<LocalDeviceState>(
+      executor, client, LocalDeviceState::kSynchronous, /*asynchronous=*/false,
+      /*allow_event_reuse=*/false);
+  auto device =
+      absl::make_unique<InterpreterDevice>(0, std::move(device_state));
+  devices.push_back(std::move(device));
+
+  return std::make_shared<PjRtClient>(
+      kInterpreterPlatformName, client, std::move(devices), /*host_id=*/0,
+      /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
+      /*gpu_run_options=*/nullptr);
+}
+
+}  // namespace xla
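The factory above deliberately builds the most conservative client configuration: a single `InterpreterDevice`, a `LocalDeviceState` with `asynchronous=false` and event reuse disabled (so every execution blocks the calling thread), and no custom allocators. A minimal sketch of the invariants this hands to Python callers, assuming the `get_interpreter_client` binding added later in this change and the `xla_extension` module path used by `xla_client.py`:

```python
# Sketch only: get_interpreter_client is the pybind11 binding added below in
# xla.cc; device_count/host_id are accessors the existing PJRT client exposes.
from tensorflow.compiler.xla.python import xla_extension as _xla

client = _xla.get_interpreter_client()
assert client.device_count() == 1  # the factory registers exactly one device
assert client.host_id() == 0       # hard-coded /*host_id=*/0 above
```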
diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.h b/tensorflow/compiler/xla/pjrt/interpreter_device.h
new file mode 100644
index 00000000000..58b210ad762
--- /dev/null
+++ b/tensorflow/compiler/xla/pjrt/interpreter_device.h
@@ -0,0 +1,36 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PJRT_INTERPRETER_DEVICE_H_
+#define TENSORFLOW_COMPILER_XLA_PJRT_INTERPRETER_DEVICE_H_
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+class InterpreterDevice : public Device {
+ public:
+  InterpreterDevice(int id,
+                    std::unique_ptr<LocalDeviceState> local_device_state);
+};
+
+StatusOr<std::shared_ptr<PjRtClient>> GetInterpreterClient();
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_PJRT_INTERPRETER_DEVICE_H_
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index 10737489331..7913c2d9dd4 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -315,6 +315,7 @@ pybind_extension(
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/client:xla_computation",
         "//tensorflow/compiler/xla/pjrt:cpu_device",
+        "//tensorflow/compiler/xla/pjrt:interpreter_device",
         "//tensorflow/compiler/xla/pjrt:nvidia_gpu_device",
         "//tensorflow/compiler/xla/pjrt:pjrt_client",
         "//tensorflow/compiler/xla/pjrt:tracked_device_buffer",
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index 0b6824e83e9..6ebe5e85245 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -38,6 +38,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/pjrt/distributed/client.h"
 #include "tensorflow/compiler/xla/pjrt/distributed/distributed.h"
 #include "tensorflow/compiler/xla/pjrt/distributed/service.h"
+#include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
 #include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h"
 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
 #include "tensorflow/compiler/xla/python/bfloat16.h"
@@ -767,6 +768,7 @@ PYBIND11_MODULE(xla_extension, m) {
         py::arg("computation"),
         py::arg("compile_options") = CompileOptions());
   m.def("get_cpu_client", &GetCpuClient, py::arg("asynchronous") = true);
+  m.def("get_interpreter_client", &GetInterpreterClient);
   m.def("get_nvidia_gpu_client", &GetNvidiaGpuClient,
         py::arg("asynchronous") = true,
         py::arg("allocator_config") = GpuAllocatorConfig(),
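Unlike `get_cpu_client` and `get_nvidia_gpu_client`, the new binding takes no `asynchronous` argument, since the interpreter client is always synchronous. A hedged sketch of calling it directly; the `Device` attribute names are assumptions carried over from the existing backends:

```python
from tensorflow.compiler.xla.python import xla_extension as _xla

client = _xla.get_interpreter_client()
for device in client.devices():
  # Both fields report "interpreter", per the InterpreterDevice constructor.
  print(device.id, device.platform, device.device_kind)
```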
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 76c3bc33a91..3085715bf12 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -52,6 +52,10 @@ xla_platform_names = {
 }
 
 
+def _interpreter_backend_factory():
+  return _xla.get_interpreter_client()
+
+
 def _cpu_backend_factory():
   return _xla.get_cpu_client(asynchronous=True)
 
@@ -85,6 +89,7 @@ def _gpu_backend_factory(distributed_client=None, node_id=0):
 
 # Backend factories, keyed by user-visible name, in increasing priority order.
 _local_backend_factories = collections.OrderedDict([
+    ('interpreter', _interpreter_backend_factory),
     ('cpu', _cpu_backend_factory),
     ('gpu', _gpu_backend_factory),
 ])
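Because the factory table is in increasing priority order, registering `interpreter` first makes it the lowest-priority default; users opt into it by name. An end-to-end sketch, assuming the `get_local_backend` helper defined elsewhere in this file (exact builder and compile spellings vary between xla_client versions):

```python
import numpy as np
from tensorflow.compiler.xla.python import xla_client

backend = xla_client.get_local_backend('interpreter')

# Build f(x) = x + 1 and run it on the interpreter backend.
c = xla_client.XlaBuilder('add_one')
x = xla_client.ops.Parameter(
    c, 0, xla_client.shape_from_pyval(np.float32(0.)))
xla_client.ops.Add(x, xla_client.ops.Constant(c, np.float32(1.)))

compiled = backend.compile(c.build())
result, = xla_client.execute_with_python_values(
    compiled, (np.float32(41.),), backend)
print(result)  # 42.0
```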