Fix TPU initialization for local servers
Adds an Identity op to the TPU initialization function to avoid placement errors. I believe this only comes up when using local servers (i.e. it mostly affects testing; we already have plenty of tests for TPUs on remote jobs). Also exposes a mapping from job name to TPU topology. I have a use for it: we need to look up the topology corresponding to the correct job when replicating a function.

PiperOrigin-RevId: 290188617
Change-Id: I24e1e5995f6f55b565a1aac05909698ac3ee49d8
parent c6fa2dc9e4
commit 80f0540bc8
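Not part of the change, just an illustration of the topology-lookup use case described above: a minimal sketch of reading the per-job topology after connecting to a remote worker. The worker address is assumed, the job name 'worker' follows the test below, and the device assignment is only one possible consumer when replicating a function.

    from tensorflow.python.eager import context
    from tensorflow.python.eager import remote
    from tensorflow.python.tpu import device_assignment as device_assignment_lib

    # Assumed address of a server hosting the TPU worker; purely illustrative.
    remote.connect_to_remote_host(['localhost:2222'])

    ctx = context.context()
    # Pick the topology belonging to the job we are replicating onto, rather
    # than whichever entry happens to come first in `tpu_topologies`.
    topology = ctx.tpu_topologies_by_job['worker']

    # One possible consumer: a device assignment spanning every core of that job.
    assignment = device_assignment_lib.device_assignment(
        topology, num_replicas=topology.num_tasks * topology.num_tpus_per_task)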
@@ -820,13 +820,19 @@ void MakeTPUInitializationFunctionDef(
   tensorflow::OpDef_ArgDef* arg_def(signature_def->add_output_arg());
   arg_def->set_name("topology_proto");
   arg_def->set_type(tensorflow::DataType::DT_STRING);
-  tensorflow::NodeDef* node_def(function_def->add_node_def());
-  node_def->set_name("ConfigureDistributedTPU");
-  node_def->set_op("ConfigureDistributedTPU");
-  (*node_def->mutable_attr())["compilation_failure_closes_chips"].set_b(false);
-  node_def->set_device(tpu_system_device_name);
-  (*function_def->mutable_ret())["topology_proto"] =
-      "ConfigureDistributedTPU:topology:0";
+  tensorflow::NodeDef* configure_node_def(function_def->add_node_def());
+  configure_node_def->set_name("ConfigureDistributedTPU");
+  configure_node_def->set_op("ConfigureDistributedTPU");
+  (*configure_node_def->mutable_attr())["compilation_failure_closes_chips"]
+      .set_b(false);
+  configure_node_def->set_device(tpu_system_device_name);
+  tensorflow::NodeDef* identity_node_def(function_def->add_node_def());
+  identity_node_def->set_name("Identity");
+  identity_node_def->set_op("Identity");
+  identity_node_def->add_input("ConfigureDistributedTPU:topology:0");
+  (*identity_node_def->mutable_attr())["T"].set_type(
+      tensorflow::DataType::DT_STRING);
+  (*function_def->mutable_ret())["topology_proto"] = "Identity:output:0";
   (*function_def->mutable_control_ret())["ConfigureDistributedTPU"] =
       "ConfigureDistributedTPU";
 }
@@ -429,7 +429,7 @@ class Context(object):
     self._soft_device_placement = None
     self._log_device_placement = None
     self._enable_mlir_bridge = None
-    self._tpu_topologies = []
+    self._tpu_topologies_by_job = {}
     self._attempted_tpu_initialization = set()
     self._optimizer_experimental_options = {}
@@ -478,8 +478,8 @@ class Context(object):
       # TODO(b/134094971): Remove this when lazy tensor copy in multi-device
       # function has been implemented.
      self.mirroring_policy = MIRRORING_ALL
-      self._tpu_topologies.append(
-          topology.Topology(serialized=topology_proto_data))
+      parsed_topology = topology.Topology(serialized=topology_proto_data)
+      self._tpu_topologies_by_job[job] = parsed_topology

   def _initialize_logical_devices(self):
     """Helper to initialize devices."""
@@ -1441,7 +1441,13 @@ class Context(object):
   def tpu_topologies(self):
     """A sequence of TPU topologies for connected TPU systems."""
     ensure_initialized()
-    return self._tpu_topologies
+    return tuple(self._tpu_topologies_by_job.values())
+
+  @property
+  def tpu_topologies_by_job(self):
+    """A mapping from job name to TPU topology for connected TPU systems."""
+    ensure_initialized()
+    return self._tpu_topologies_by_job

   @property
   def log_device_placement(self):
@@ -23,10 +23,12 @@ import numpy as np

 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
+from tensorflow.python.eager import remote
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import test
 from tensorflow.python.tpu import tpu
+from tensorflow.python.training import server_lib


 class ContextTest(test.TestCase):
@@ -121,6 +123,19 @@ class ContextTest(test.TestCase):
     self.assertGreater(topology.num_tasks, 0)
     self.assertGreater(topology.num_tpus_per_task, 0)

+  def testTPUInitializationMultiHost(self):
+    ctx = context.context()
+    if not ctx.list_physical_devices('TPU'):
+      self.assertEmpty(ctx.tpu_topologies_by_job)
+      self.skipTest('A TPU is required to run this test.')
+    self.assertEqual(['localhost'], list(ctx.tpu_topologies_by_job.keys()))
+    server = server_lib.Server.create_local_server()
+    target = server.target[len('grpc://'):]
+    remote.connect_to_remote_host([target])
+    self.assertIn('localhost', ctx.tpu_topologies_by_job)
+    self.assertIn('worker', ctx.tpu_topologies_by_job)
+    self.assertLen(ctx.tpu_topologies, 2)
+

 if __name__ == '__main__':
   ops.enable_eager_execution()