diff --git a/tensorflow/python/tpu/tpu.py b/tensorflow/python/tpu/tpu.py index 846053f69e2..c70a26f2b4d 100644 --- a/tensorflow/python/tpu/tpu.py +++ b/tensorflow/python/tpu/tpu.py @@ -118,10 +118,21 @@ def initialize_system(embedding_config=None, config_string = ("" if embedding_config is None else embedding_config.SerializeToString()) with ops.device(_tpu_system_device_name(job)): - return tpu_ops.configure_distributed_tpu( - embedding_config=config_string, + topology = tpu_ops.configure_distributed_tpu( compilation_failure_closes_chips=compilation_failure_closes_chips) + if embedding_config is None: + return topology + + # This set of control dependencies is needed as this function is expected to + # return an op which will return the topology when executed, but we need to + # call the embedding initialization op between initializing the TPU and + # returning the topology. + with ops.control_dependencies([topology]): + embedding_init = tpu_ops.configure_tpu_embedding(config=config_string) + with ops.control_dependencies([embedding_init]): + return array_ops.identity(topology, name="tpu_init_identity") + def initialize_system_for_tpu_embedding(embedding_config, job=None): """Initializes a distributed TPU Embedding system for use with TensorFlow.