[XLA:Python] Add TPU backend.

[XLA:PJRT] Set the number of retries for initializing the TPU backend
to 0, since the TPU client is always linked into the Python extension,
even when it is not running on a TPU.

PiperOrigin-RevId: 337190434
Change-Id: I81e2f93d1603c796f630aafa2cd3ae5106e8c6af
Author: Skye Wanderman-Milne
Date: 2020-10-14 15:52:49 -07:00
Committed by: TensorFlower Gardener
Commit: 60424aaaeb (parent: 82f4f50f4f)
5 changed files with 18 additions and 1 deletion

tensorflow/compiler/xla/pjrt/BUILD

@@ -176,6 +176,7 @@ cc_library(
         "//learning/brain/research/jax:__subpackages__",
         "//learning/deepmind/tensorflow/tensorfn:__subpackages__",
         "//learning/pathways:__subpackages__",
+        "//tensorflow/compiler/xla:friends",
     ],
     deps = [
         ":local_device_state",

tensorflow/compiler/xla/pjrt/tpu_client.cc

@@ -199,7 +199,8 @@ StatusOr<std::vector<std::unique_ptr<PjRtDevice>>> GetTpuDevices(
 StatusOr<std::shared_ptr<PjRtClient>> GetTpuClient(
     bool asynchronous, absl::Duration init_retry_timeout) {
   tf_tpu::TpuPlatformInterface* platform =
-      tf_tpu::TpuPlatformInterface::GetRegisteredPlatform();
+      tf_tpu::TpuPlatformInterface::GetRegisteredPlatform(
+          /*initialize_platform=*/true, /*num_tries=*/1);
   if (platform == nullptr) {
     return InvalidArgument("TpuPlatform is not available.");
   }
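On the Python side, the practical effect is that probing for a TPU on a
machine without one now fails promptly instead of looping through
platform-initialization retries. A minimal sketch, assuming the
get_tpu_client binding added to xla.cc below and that a failed status
surfaces as a Python RuntimeError:

from tensorflow.compiler.xla.python import xla_extension as _xla

try:
  client = _xla.get_tpu_client(asynchronous=True)
except RuntimeError:
  # With /*num_tries=*/1 this raises promptly (e.g. "TpuPlatform is not
  # available.") rather than blocking while platform initialization retries.
  client = None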

tensorflow/compiler/xla/python/BUILD

@@ -434,6 +434,7 @@ pybind_extension(
         "//tensorflow/compiler/xla/pjrt:interpreter_device",
         "//tensorflow/compiler/xla/pjrt:nvidia_gpu_device",
         "//tensorflow/compiler/xla/pjrt:pjrt_client",
+        "//tensorflow/compiler/xla/pjrt:tpu_client",
         "//tensorflow/compiler/xla/pjrt:tracked_device_buffer",
         "//tensorflow/compiler/xla/pjrt/distributed",
         "//tensorflow/compiler/xla/pjrt/distributed:client",

tensorflow/compiler/xla/python/xla.cc

@@ -42,6 +42,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
 #include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h"
 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/pjrt/tpu_client.h"
 #include "tensorflow/compiler/xla/python/bfloat16.h"
 #include "tensorflow/compiler/xla/python/dlpack.h"
 #include "tensorflow/compiler/xla/python/jax_jit.h"
@@ -580,6 +581,14 @@ PYBIND11_MODULE(xla_extension, m) {
       py::arg("asynchronous") = true,
       py::arg("allocator_config") = GpuAllocatorConfig(),
       py::arg("distributed_client") = nullptr, py::arg("node_id") = 0);
+  m.def(
+      "get_tpu_client",
+      [](bool asynchronous) -> StatusOr<std::shared_ptr<PyClient>> {
+        TF_ASSIGN_OR_RETURN(std::shared_ptr<PjRtClient> client,
+                            GetTpuClient(asynchronous));
+        return std::make_shared<PyClient>(std::move(client));
+      },
+      py::arg("asynchronous") = true);
   py::class_<Traceback::Frame>(m, "Frame")
       .def_readonly("file_name", &Traceback::Frame::file_name)
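For illustration, the new binding can be called directly once the
extension is built. This sketch assumes a host with TPUs attached; the
platform and device_count accessors on the returned client are assumed,
not shown in this diff:

from tensorflow.compiler.xla.python import xla_extension as _xla

client = _xla.get_tpu_client()  # asynchronous defaults to True
print(client.platform, client.device_count())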

tensorflow/compiler/xla/python/xla_client.py

@@ -90,11 +90,16 @@ def _gpu_backend_factory(distributed_client=None, node_id=0):
       node_id=node_id)


+def _tpu_backend_factory():
+  return _xla.get_tpu_client(asynchronous=True)
+
+
 # Backend factories, keyed by user-visible name, in increasing priority order.
 _local_backend_factories = collections.OrderedDict([
     ('interpreter', _interpreter_backend_factory),
     ('cpu', _cpu_backend_factory),
     ('gpu', _gpu_backend_factory),
+    ('tpu', _tpu_backend_factory),
 ])
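With the factory registered, the TPU backend becomes reachable through
xla_client's normal lookup path. A usage sketch, assuming this module's
existing get_local_backend accessor:

from tensorflow.compiler.xla.python import xla_client

# Instantiates _tpu_backend_factory on first use; raises if the TPU
# platform is unavailable.
backend = xla_client.get_local_backend('tpu')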