From 80f0540bc83d55a5e33407f38e1c370f5853814d Mon Sep 17 00:00:00 2001
From: Allen Lavoie
Date: Thu, 16 Jan 2020 18:34:56 -0800
Subject: [PATCH] Fix TPU initialization for local servers

Routes the TPU initialization function's output through an Identity op,
which is required to avoid placement errors. I believe this only comes up
when using local servers (i.e., it mostly affects testing; we do have
plenty of tests for TPUs on remote jobs). A rough transliteration of the
resulting FunctionDef appears after the diff.

Also exposes a mapping from job name to TPU topology. I have a concrete
use for it: when replicating a function, we need to look up the topology
for the job it targets. A usage sketch appears after the diff.

PiperOrigin-RevId: 290188617
Change-Id: I24e1e5995f6f55b565a1aac05909698ac3ee49d8
---
 tensorflow/c/c_api_experimental.cc      | 20 +++++++++++++-------
 tensorflow/python/eager/context.py      | 14 ++++++++++----
 tensorflow/python/eager/context_test.py | 15 +++++++++++++++
 3 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 3355e9c4df5..43df88ca667 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -820,13 +820,19 @@ void MakeTPUInitializationFunctionDef(
   tensorflow::OpDef_ArgDef* arg_def(signature_def->add_output_arg());
   arg_def->set_name("topology_proto");
   arg_def->set_type(tensorflow::DataType::DT_STRING);
-  tensorflow::NodeDef* node_def(function_def->add_node_def());
-  node_def->set_name("ConfigureDistributedTPU");
-  node_def->set_op("ConfigureDistributedTPU");
-  (*node_def->mutable_attr())["compilation_failure_closes_chips"].set_b(false);
-  node_def->set_device(tpu_system_device_name);
-  (*function_def->mutable_ret())["topology_proto"] =
-      "ConfigureDistributedTPU:topology:0";
+  tensorflow::NodeDef* configure_node_def(function_def->add_node_def());
+  configure_node_def->set_name("ConfigureDistributedTPU");
+  configure_node_def->set_op("ConfigureDistributedTPU");
+  (*configure_node_def->mutable_attr())["compilation_failure_closes_chips"]
+      .set_b(false);
+  configure_node_def->set_device(tpu_system_device_name);
+  tensorflow::NodeDef* identity_node_def(function_def->add_node_def());
+  identity_node_def->set_name("Identity");
+  identity_node_def->set_op("Identity");
+  identity_node_def->add_input("ConfigureDistributedTPU:topology:0");
+  (*identity_node_def->mutable_attr())["T"].set_type(
+      tensorflow::DataType::DT_STRING);
+  (*function_def->mutable_ret())["topology_proto"] = "Identity:output:0";
   (*function_def->mutable_control_ret())["ConfigureDistributedTPU"] =
       "ConfigureDistributedTPU";
 }
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index b2fb2975260..05f20a342f9 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -429,7 +429,7 @@ class Context(object):
     self._soft_device_placement = None
     self._log_device_placement = None
     self._enable_mlir_bridge = None
-    self._tpu_topologies = []
+    self._tpu_topologies_by_job = {}
     self._attempted_tpu_initialization = set()
 
     self._optimizer_experimental_options = {}
@@ -478,8 +478,8 @@ class Context(object):
       # TODO(b/134094971): Remove this when lazy tensor copy in multi-device
       # function has been implemented.
       self.mirroring_policy = MIRRORING_ALL
-      self._tpu_topologies.append(
-          topology.Topology(serialized=topology_proto_data))
+      parsed_topology = topology.Topology(serialized=topology_proto_data)
+      self._tpu_topologies_by_job[job] = parsed_topology
 
   def _initialize_logical_devices(self):
     """Helper to initialize devices."""
@@ -1441,7 +1441,13 @@ class Context(object):
   def tpu_topologies(self):
     """A sequence of TPU topologies for connected TPU systems."""
     ensure_initialized()
-    return self._tpu_topologies
+    return tuple(self._tpu_topologies_by_job.values())
+
+  @property
+  def tpu_topologies_by_job(self):
+    """A mapping from job name to TPU topology for connected TPU systems."""
+    ensure_initialized()
+    return self._tpu_topologies_by_job
 
   @property
   def log_device_placement(self):
diff --git a/tensorflow/python/eager/context_test.py b/tensorflow/python/eager/context_test.py
index 5059bb45241..c5ede8f8304 100644
--- a/tensorflow/python/eager/context_test.py
+++ b/tensorflow/python/eager/context_test.py
@@ -23,10 +23,12 @@ import numpy as np
 
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
+from tensorflow.python.eager import remote
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import test
 from tensorflow.python.tpu import tpu
+from tensorflow.python.training import server_lib
 
 
 class ContextTest(test.TestCase):
@@ -121,6 +123,19 @@ class ContextTest(test.TestCase):
     self.assertGreater(topology.num_tasks, 0)
     self.assertGreater(topology.num_tpus_per_task, 0)
 
+  def testTPUInitializationMultiHost(self):
+    ctx = context.context()
+    if not ctx.list_physical_devices('TPU'):
+      self.assertEmpty(ctx.tpu_topologies_by_job)
+      self.skipTest('A TPU is required to run this test.')
+    self.assertEqual(['localhost'], list(ctx.tpu_topologies_by_job.keys()))
+    server = server_lib.Server.create_local_server()
+    target = server.target[len('grpc://'):]
+    remote.connect_to_remote_host([target])
+    self.assertIn('localhost', ctx.tpu_topologies_by_job)
+    self.assertIn('worker', ctx.tpu_topologies_by_job)
+    self.assertLen(ctx.tpu_topologies, 2)
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()
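The net effect of the c_api_experimental.cc change, transliterated into
Python against the FunctionDef proto. This is a sketch for review, not code
from the change: the function's signature name and its registration are
outside this patch, so the 'TPUInitialization' name below is hypothetical.

  from tensorflow.core.framework import function_pb2
  from tensorflow.core.framework import types_pb2

  def make_tpu_initialization_function_def(tpu_system_device_name):
    """Builds approximately the FunctionDef the C++ helper now produces."""
    function_def = function_pb2.FunctionDef()
    function_def.signature.name = 'TPUInitialization'  # hypothetical name
    output_arg = function_def.signature.output_arg.add()
    output_arg.name = 'topology_proto'
    output_arg.type = types_pb2.DT_STRING
    # ConfigureDistributedTPU is pinned to the TPU_SYSTEM device.
    configure = function_def.node_def.add()
    configure.name = 'ConfigureDistributedTPU'
    configure.op = 'ConfigureDistributedTPU'
    configure.attr['compilation_failure_closes_chips'].b = False
    configure.device = tpu_system_device_name
    # The new Identity node: the function's return value now comes from an
    # unpinned op instead of directly from the device-pinned op above.
    identity = function_def.node_def.add()
    identity.name = 'Identity'
    identity.op = 'Identity'
    identity.input.append('ConfigureDistributedTPU:topology:0')
    identity.attr['T'].type = types_pb2.DT_STRING
    function_def.ret['topology_proto'] = 'Identity:output:0'
    function_def.control_ret['ConfigureDistributedTPU'] = (
        'ConfigureDistributedTPU')
    return function_def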
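A usage sketch for the new tpu_topologies_by_job mapping, mirroring
testTPUInitializationMultiHost above. It assumes a TPU-enabled host;
'worker' is the default job name connect_to_remote_host assigns to the
remote job.

  from tensorflow.python.eager import context
  from tensorflow.python.eager import remote
  from tensorflow.python.training import server_lib

  # Start an in-process server and connect to it as a second job named
  # 'worker' (connect_to_remote_host's default).
  server = server_lib.Server.create_local_server()
  remote.connect_to_remote_host([server.target[len('grpc://'):]])

  ctx = context.context()
  # One topology is recorded per job whose TPU system was initialized, so
  # replication code can look up the topology for the job it targets.
  for job, job_topology in ctx.tpu_topologies_by_job.items():
    print(job, job_topology.num_tasks, job_topology.num_tpus_per_task)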