Fix TPU initialization for local servers
Adds an Identity node to the TPU initialization function to avoid placement errors. I believe this only comes up when using local servers, so it mostly affects testing; we already have plenty of tests for TPUs on remote jobs.

Also exposes a mapping from job name to TPU topology. I have a use for it: when replicating a function, we need to look up the topology corresponding to the correct job.

PiperOrigin-RevId: 290188617
Change-Id: I24e1e5995f6f55b565a1aac05909698ac3ee49d8
parent c6fa2dc9e4
commit 80f0540bc8
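For a sense of the intended use, here is a hypothetical sketch of reading the new mapping; the job name 'worker' and the printed fields are illustrative, not something this change hard-codes:

    from tensorflow.python.eager import context

    ctx = context.context()
    # Look up the topology for one job; None if that job has no TPU system.
    topology = ctx.tpu_topologies_by_job.get('worker')
    if topology is not None:
      print(topology.num_tasks, topology.num_tpus_per_task)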
@@ -820,13 +820,19 @@ void MakeTPUInitializationFunctionDef(
   tensorflow::OpDef_ArgDef* arg_def(signature_def->add_output_arg());
   arg_def->set_name("topology_proto");
   arg_def->set_type(tensorflow::DataType::DT_STRING);
-  tensorflow::NodeDef* node_def(function_def->add_node_def());
-  node_def->set_name("ConfigureDistributedTPU");
-  node_def->set_op("ConfigureDistributedTPU");
-  (*node_def->mutable_attr())["compilation_failure_closes_chips"].set_b(false);
-  node_def->set_device(tpu_system_device_name);
-  (*function_def->mutable_ret())["topology_proto"] =
-      "ConfigureDistributedTPU:topology:0";
+  tensorflow::NodeDef* configure_node_def(function_def->add_node_def());
+  configure_node_def->set_name("ConfigureDistributedTPU");
+  configure_node_def->set_op("ConfigureDistributedTPU");
+  (*configure_node_def->mutable_attr())["compilation_failure_closes_chips"]
+      .set_b(false);
+  configure_node_def->set_device(tpu_system_device_name);
+  tensorflow::NodeDef* identity_node_def(function_def->add_node_def());
+  identity_node_def->set_name("Identity");
+  identity_node_def->set_op("Identity");
+  identity_node_def->add_input("ConfigureDistributedTPU:topology:0");
+  (*identity_node_def->mutable_attr())["T"].set_type(
+      tensorflow::DataType::DT_STRING);
+  (*function_def->mutable_ret())["topology_proto"] = "Identity:output:0";
   (*function_def->mutable_control_ret())["ConfigureDistributedTPU"] =
       "ConfigureDistributedTPU";
 }
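Presumably the placement error came from mapping the function's return value straight to ConfigureDistributedTPU, whose NodeDef is pinned to the TPU_SYSTEM device; routing the output through an unpinned Identity node gives the placer something it can assign freely. A minimal Python sketch of the same forwarding pattern, with tf.identity standing in for the Identity NodeDef above and a constant standing in for the op's topology output:

    import tensorflow as tf

    @tf.function
    def initialize_tpu_like():
      # Stand-in for ConfigureDistributedTPU's serialized-topology output.
      raw_topology = tf.constant(b'serialized-topology')
      # Returning through Identity, as the FunctionDef now does, leaves the
      # output node unpinned instead of tied to the producing op's device.
      return tf.identity(raw_topology)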
@@ -429,7 +429,7 @@ class Context(object):
     self._soft_device_placement = None
     self._log_device_placement = None
     self._enable_mlir_bridge = None
-    self._tpu_topologies = []
+    self._tpu_topologies_by_job = {}
     self._attempted_tpu_initialization = set()
     self._optimizer_experimental_options = {}
 
@@ -478,8 +478,8 @@ class Context(object):
       # TODO(b/134094971): Remove this when lazy tensor copy in multi-device
       # function has been implemented.
       self.mirroring_policy = MIRRORING_ALL
-      self._tpu_topologies.append(
-          topology.Topology(serialized=topology_proto_data))
+      parsed_topology = topology.Topology(serialized=topology_proto_data)
+      self._tpu_topologies_by_job[job] = parsed_topology
 
   def _initialize_logical_devices(self):
     """Helper to initialize devices."""
@@ -1441,7 +1441,13 @@ class Context(object):
   def tpu_topologies(self):
     """A sequence of TPU topologies for connected TPU systems."""
     ensure_initialized()
-    return self._tpu_topologies
+    return tuple(self._tpu_topologies_by_job.values())
 
+  @property
+  def tpu_topologies_by_job(self):
+    """A mapping from job name to TPU topology for connected TPU systems."""
+    ensure_initialized()
+    return self._tpu_topologies_by_job
+
   @property
   def log_device_placement(self):
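A plain-Python sketch of the accessor pair above (the string values are placeholders for topology.Topology objects): tpu_topologies is now derived from the job mapping, so the two stay consistent, and tuple() hands callers an immutable snapshot rather than the internal dict.

    # Placeholder strings stand in for topology.Topology instances.
    _tpu_topologies_by_job = {'localhost': 'topology_a', 'worker': 'topology_b'}

    def tpu_topologies():
      # Snapshot the values so callers cannot mutate internal state.
      return tuple(_tpu_topologies_by_job.values())

    assert len(tpu_topologies()) == len(_tpu_topologies_by_job)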
@@ -23,10 +23,12 @@ import numpy as np
 
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
+from tensorflow.python.eager import remote
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import test
 from tensorflow.python.tpu import tpu
+from tensorflow.python.training import server_lib
 
 
 class ContextTest(test.TestCase):
@@ -121,6 +123,19 @@ class ContextTest(test.TestCase):
     self.assertGreater(topology.num_tasks, 0)
     self.assertGreater(topology.num_tpus_per_task, 0)
 
+  def testTPUInitializationMultiHost(self):
+    ctx = context.context()
+    if not ctx.list_physical_devices('TPU'):
+      self.assertEmpty(ctx.tpu_topologies_by_job)
+      self.skipTest('A TPU is required to run this test.')
+    self.assertEqual(['localhost'], list(ctx.tpu_topologies_by_job.keys()))
+    server = server_lib.Server.create_local_server()
+    target = server.target[len('grpc://'):]
+    remote.connect_to_remote_host([target])
+    self.assertIn('localhost', ctx.tpu_topologies_by_job)
+    self.assertIn('worker', ctx.tpu_topologies_by_job)
+    self.assertLen(ctx.tpu_topologies, 2)
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()
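Restating the new test's flow as a standalone sketch (it needs a real TPU to run to completion; 'worker' appears because it is remote.connect_to_remote_host's default job_name):

    from tensorflow.python.eager import context, remote
    from tensorflow.python.training import server_lib

    # Start an in-process gRPC server and connect to it as the 'worker' job.
    server = server_lib.Server.create_local_server()
    remote.connect_to_remote_host([server.target[len('grpc://'):]])
    # After connecting, topologies are keyed by both the local and remote jobs.
    print(sorted(context.context().tpu_topologies_by_job))  # ['localhost', 'worker']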