Fix TPU initialization for local servers

Adds an Identity node to the TPU initialization function to avoid placement errors. I believe this only comes up when using local servers (i.e., it mostly affects testing; we do have plenty of tests for TPUs on remote jobs).

Also exposes a mapping from job name to TPU topology. This has an immediate use: when replicating a function, we need to look up the topology corresponding to the correct job.

PiperOrigin-RevId: 290188617
Change-Id: I24e1e5995f6f55b565a1aac05909698ac3ee49d8
Author: Allen Lavoie (2020-01-16 18:34:56 -08:00), committed by TensorFlower Gardener
Commit: 80f0540bc8 (parent: c6fa2dc9e4)
3 changed files with 38 additions and 11 deletions

@@ -820,13 +820,19 @@ void MakeTPUInitializationFunctionDef(
   tensorflow::OpDef_ArgDef* arg_def(signature_def->add_output_arg());
   arg_def->set_name("topology_proto");
   arg_def->set_type(tensorflow::DataType::DT_STRING);
-  tensorflow::NodeDef* node_def(function_def->add_node_def());
-  node_def->set_name("ConfigureDistributedTPU");
-  node_def->set_op("ConfigureDistributedTPU");
-  (*node_def->mutable_attr())["compilation_failure_closes_chips"].set_b(false);
-  node_def->set_device(tpu_system_device_name);
-  (*function_def->mutable_ret())["topology_proto"] =
-      "ConfigureDistributedTPU:topology:0";
+  tensorflow::NodeDef* configure_node_def(function_def->add_node_def());
+  configure_node_def->set_name("ConfigureDistributedTPU");
+  configure_node_def->set_op("ConfigureDistributedTPU");
+  (*configure_node_def->mutable_attr())["compilation_failure_closes_chips"]
+      .set_b(false);
+  configure_node_def->set_device(tpu_system_device_name);
+  tensorflow::NodeDef* identity_node_def(function_def->add_node_def());
+  identity_node_def->set_name("Identity");
+  identity_node_def->set_op("Identity");
+  identity_node_def->add_input("ConfigureDistributedTPU:topology:0");
+  (*identity_node_def->mutable_attr())["T"].set_type(
+      tensorflow::DataType::DT_STRING);
+  (*function_def->mutable_ret())["topology_proto"] = "Identity:output:0";
+  (*function_def->mutable_control_ret())["ConfigureDistributedTPU"] =
+      "ConfigureDistributedTPU";
 }
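For illustration, here is a rough Python sketch of the graph structure the updated FunctionDef describes (this is not code from this change; the function name and device string are placeholders, and tf.raw_ops.ConfigureDistributedTPU is used as an approximation of the generated node). ConfigureDistributedTPU stays pinned to the TPU_SYSTEM device, while the Identity node carries no device assignment, so the placer is free to decide where the function's "topology_proto" output is produced, which is presumably what avoids the placement errors on local servers.

import tensorflow as tf

def tpu_initialization_graph(tpu_system_device):
  # Hypothetical sketch; the device name passed in is a placeholder.
  graph = tf.Graph()
  with graph.as_default():
    with tf.device(tpu_system_device):
      # Pinned to the TPU_SYSTEM device, like configure_node_def above.
      serialized_topology = tf.raw_ops.ConfigureDistributedTPU(
          compilation_failure_closes_chips=False)
    # No explicit device, like identity_node_def above: the placer chooses
    # where the "topology_proto" output lives.
    topology_proto = tf.identity(serialized_topology, name='topology_proto')
  return graph, topology_proto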

@@ -429,7 +429,7 @@ class Context(object):
     self._soft_device_placement = None
     self._log_device_placement = None
     self._enable_mlir_bridge = None
-    self._tpu_topologies = []
+    self._tpu_topologies_by_job = {}
     self._attempted_tpu_initialization = set()
     self._optimizer_experimental_options = {}
 
@@ -478,8 +478,8 @@ class Context(object):
       # TODO(b/134094971): Remove this when lazy tensor copy in multi-device
       # function has been implemented.
       self.mirroring_policy = MIRRORING_ALL
-      self._tpu_topologies.append(
-          topology.Topology(serialized=topology_proto_data))
+      parsed_topology = topology.Topology(serialized=topology_proto_data)
+      self._tpu_topologies_by_job[job] = parsed_topology
 
   def _initialize_logical_devices(self):
     """Helper to initialize devices."""
@@ -1441,7 +1441,13 @@ class Context(object):
   def tpu_topologies(self):
     """A sequence of TPU topologies for connected TPU systems."""
     ensure_initialized()
-    return self._tpu_topologies
+    return tuple(self._tpu_topologies_by_job.values())
+
+  @property
+  def tpu_topologies_by_job(self):
+    """A mapping from job name to TPU topology for connected TPU systems."""
+    ensure_initialized()
+    return self._tpu_topologies_by_job
 
   @property
   def log_device_placement(self):
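As a hedged usage sketch of the new mapping (the helper below is illustrative and not part of this change; only context.context().tpu_topologies_by_job comes from this commit, while device_assignment() is the existing utility in tensorflow/python/tpu/device_assignment.py), replication code can pick the topology for the job it targets and build a device assignment from it:

from tensorflow.python.eager import context
from tensorflow.python.tpu import device_assignment as device_assignment_lib

def device_assignment_for_job(job, num_replicas=1):
  """Hypothetical helper: look up a job's topology and build an assignment."""
  topologies = context.context().tpu_topologies_by_job
  if job not in topologies:
    raise ValueError('No TPU topology recorded for job %r' % job)
  # device_assignment() pairs the topology with a replica layout so that
  # tpu.replicate()-style APIs place each replica on the intended cores.
  return device_assignment_lib.device_assignment(
      topologies[job], num_replicas=num_replicas)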

@@ -23,10 +23,12 @@ import numpy as np
 
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
+from tensorflow.python.eager import remote
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import test
 from tensorflow.python.tpu import tpu
+from tensorflow.python.training import server_lib
 
 
 class ContextTest(test.TestCase):
@@ -121,6 +123,19 @@ class ContextTest(test.TestCase):
     self.assertGreater(topology.num_tasks, 0)
     self.assertGreater(topology.num_tpus_per_task, 0)
 
+  def testTPUInitializationMultiHost(self):
+    ctx = context.context()
+    if not ctx.list_physical_devices('TPU'):
+      self.assertEmpty(ctx.tpu_topologies_by_job)
+      self.skipTest('A TPU is required to run this test.')
+    self.assertEqual(['localhost'], list(ctx.tpu_topologies_by_job.keys()))
+    server = server_lib.Server.create_local_server()
+    target = server.target[len('grpc://'):]
+    remote.connect_to_remote_host([target])
+    self.assertIn('localhost', ctx.tpu_topologies_by_job)
+    self.assertIn('worker', ctx.tpu_topologies_by_job)
+    self.assertLen(ctx.tpu_topologies, 2)
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()