Update tpu initialization to use separate embedding initialization op.

PiperOrigin-RevId: 304212916
Change-Id: I59324931158807b54a1bc23c741c86b9104594d7
Bruce Fontaine, 2020-04-01 10:28:17 -07:00, committed by TensorFlower Gardener
parent 2c80835262
commit d15ddcfdbf
1 changed file with 13 additions and 2 deletions


@@ -118,10 +118,21 @@ def initialize_system(embedding_config=None,
   config_string = ("" if embedding_config is None else
                    embedding_config.SerializeToString())
   with ops.device(_tpu_system_device_name(job)):
-    return tpu_ops.configure_distributed_tpu(
-        embedding_config=config_string,
+    topology = tpu_ops.configure_distributed_tpu(
         compilation_failure_closes_chips=compilation_failure_closes_chips)
+
+    if embedding_config is None:
+      return topology
+
+    # This set of control dependencies is needed as this function is expected
+    # to return an op which will return the topology when executed, but we
+    # need to call the embedding initialization op between initializing the
+    # TPU and returning the topology.
+    with ops.control_dependencies([topology]):
+      embedding_init = tpu_ops.configure_tpu_embedding(config=config_string)
+    with ops.control_dependencies([embedding_init]):
+      return array_ops.identity(topology, name="tpu_init_identity")
 
 
 def initialize_system_for_tpu_embedding(embedding_config, job=None):
   """Initializes a distributed TPU Embedding system for use with TensorFlow.