From d15ddcfdbf779455713dabf8ae74d49389c84058 Mon Sep 17 00:00:00 2001 From: Bruce Fontaine Date: Wed, 1 Apr 2020 10:28:17 -0700 Subject: [PATCH] Update tpu initialization to use separate embedding initialization op. PiperOrigin-RevId: 304212916 Change-Id: I59324931158807b54a1bc23c741c86b9104594d7 --- tensorflow/python/tpu/tpu.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py index 846053f69e2..c70a26f2b4d 100644 --- a/tensorflow/python/tpu/tpu.py +++ b/tensorflow/python/tpu/tpu.py @@ -118,10 +118,21 @@ def initialize_system(embedding_config=None, config_string = ("" if embedding_config is None else embedding_config.SerializeToString()) with ops.device(_tpu_system_device_name(job)): - return tpu_ops.configure_distributed_tpu( - embedding_config=config_string, + topology = tpu_ops.configure_distributed_tpu( compilation_failure_closes_chips=compilation_failure_closes_chips) + if embedding_config is None: + return topology + + # This set of control dependencies is needed as this function is expected to + # return an op which will return the topology when executed, but we need to + # call the embedding initialization op between initializing the TPU and + # returning the topology. + with ops.control_dependencies([topology]): + embedding_init = tpu_ops.configure_tpu_embedding(config=config_string) + with ops.control_dependencies([embedding_init]): + return array_ops.identity(topology, name="tpu_init_identity") + def initialize_system_for_tpu_embedding(embedding_config, job=None): """Initializes a distributed TPU Embedding system for use with TensorFlow.