[XLA:Python] Add TPU backend.

[XLA:PJRT] Sets the number of retries for initializing the TPU backend to 0, since we always link the TPU client into the Python extension even when it is not being run on TPU.

PiperOrigin-RevId: 337190434
Change-Id: I81e2f93d1603c796f630aafa2cd3ae5106e8c6af
Parent: 82f4f50f4f
Commit: 60424aaaeb
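Because the TPU client is linked in unconditionally, a CPU- or GPU-only process that imports the extension should not stall retrying TPU platform initialization. A minimal sketch of the intended behavior, assuming the extension is importable as below (the exact module path and exception type depend on the build):

from tensorflow.compiler.xla.python import xla_extension as _xla

try:
  client = _xla.get_tpu_client(asynchronous=True)
except RuntimeError as e:
  # With a single initialization attempt and no retries, a host without
  # TPUs fails immediately with an error like "TpuPlatform is not
  # available." instead of blocking.
  print(e)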
@@ -176,6 +176,7 @@ cc_library(
         "//learning/brain/research/jax:__subpackages__",
         "//learning/deepmind/tensorflow/tensorfn:__subpackages__",
         "//learning/pathways:__subpackages__",
+        "//tensorflow/compiler/xla:friends",
     ],
     deps = [
         ":local_device_state",
@@ -199,7 +199,8 @@ StatusOr<std::vector<std::unique_ptr<PjRtDevice>>> GetTpuDevices(
 StatusOr<std::shared_ptr<PjRtClient>> GetTpuClient(
     bool asynchronous, absl::Duration init_retry_timeout) {
   tf_tpu::TpuPlatformInterface* platform =
-      tf_tpu::TpuPlatformInterface::GetRegisteredPlatform();
+      tf_tpu::TpuPlatformInterface::GetRegisteredPlatform(
+          /*initialize_platform=*/true, /*num_tries=*/1);
   if (platform == nullptr) {
     return InvalidArgument("TpuPlatform is not available.");
   }
@@ -434,6 +434,7 @@ pybind_extension(
         "//tensorflow/compiler/xla/pjrt:interpreter_device",
         "//tensorflow/compiler/xla/pjrt:nvidia_gpu_device",
         "//tensorflow/compiler/xla/pjrt:pjrt_client",
+        "//tensorflow/compiler/xla/pjrt:tpu_client",
         "//tensorflow/compiler/xla/pjrt:tracked_device_buffer",
         "//tensorflow/compiler/xla/pjrt/distributed",
         "//tensorflow/compiler/xla/pjrt/distributed:client",
@@ -42,6 +42,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
 #include "tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h"
 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/pjrt/tpu_client.h"
 #include "tensorflow/compiler/xla/python/bfloat16.h"
 #include "tensorflow/compiler/xla/python/dlpack.h"
 #include "tensorflow/compiler/xla/python/jax_jit.h"
@@ -580,6 +581,14 @@ PYBIND11_MODULE(xla_extension, m) {
       py::arg("asynchronous") = true,
       py::arg("allocator_config") = GpuAllocatorConfig(),
       py::arg("distributed_client") = nullptr, py::arg("node_id") = 0);
+  m.def(
+      "get_tpu_client",
+      [](bool asynchronous) -> StatusOr<std::shared_ptr<PyClient>> {
+        TF_ASSIGN_OR_RETURN(std::shared_ptr<PjRtClient> client,
+                            GetTpuClient(asynchronous));
+        return std::make_shared<PyClient>(std::move(client));
+      },
+      py::arg("asynchronous") = true);
 
   py::class_<Traceback::Frame>(m, "Frame")
       .def_readonly("file_name", &Traceback::Frame::file_name)
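For illustration, a hedged sketch of calling the new binding directly; the import path and the client accessors shown are assumptions, not part of this change:

# Sketch under assumptions: module path and Client methods may differ.
from tensorflow.compiler.xla.python import xla_extension as _xla

client = _xla.get_tpu_client(asynchronous=True)
print(client.platform)        # assumed property, expected 'tpu'
print(client.device_count())  # assumed method returning TPU device count

In practice users reach this through the xla_client factory below rather than calling the extension module directly.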
@@ -90,11 +90,16 @@ def _gpu_backend_factory(distributed_client=None, node_id=0):
       node_id=node_id)
 
 
+def _tpu_backend_factory():
+  return _xla.get_tpu_client(asynchronous=True)
+
+
 # Backend factories, keyed by user-visible name, in increasing priority order.
 _local_backend_factories = collections.OrderedDict([
     ('interpreter', _interpreter_backend_factory),
     ('cpu', _cpu_backend_factory),
     ('gpu', _gpu_backend_factory),
+    ('tpu', _tpu_backend_factory),
 ])
 
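With the factory registered under the name 'tpu', the backend is selectable like the existing ones. A usage sketch, assuming xla_client exposes the get_local_backend helper of this era:

from tensorflow.compiler.xla.python import xla_client

# 'tpu' resolves through _local_backend_factories; on a host without a
# registered TPU platform this fails fast, per the commit description.
backend = xla_client.get_local_backend('tpu')
print(backend.devices())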