Update tpu initialization to use separate embedding initialization op.
PiperOrigin-RevId: 304212916 Change-Id: I59324931158807b54a1bc23c741c86b9104594d7
This commit is contained in:
parent
2c80835262
commit
d15ddcfdbf
|
@ -118,10 +118,21 @@ def initialize_system(embedding_config=None,
|
|||
config_string = ("" if embedding_config is None else
|
||||
embedding_config.SerializeToString())
|
||||
with ops.device(_tpu_system_device_name(job)):
|
||||
return tpu_ops.configure_distributed_tpu(
|
||||
embedding_config=config_string,
|
||||
topology = tpu_ops.configure_distributed_tpu(
|
||||
compilation_failure_closes_chips=compilation_failure_closes_chips)
|
||||
|
||||
if embedding_config is None:
|
||||
return topology
|
||||
|
||||
# This set of control dependencies is needed as this function is expected to
|
||||
# return an op which will return the topology when executed, but we need to
|
||||
# call the embedding initialization op between initializing the TPU and
|
||||
# returning the topology.
|
||||
with ops.control_dependencies([topology]):
|
||||
embedding_init = tpu_ops.configure_tpu_embedding(config=config_string)
|
||||
with ops.control_dependencies([embedding_init]):
|
||||
return array_ops.identity(topology, name="tpu_init_identity")
|
||||
|
||||
|
||||
def initialize_system_for_tpu_embedding(embedding_config, job=None):
|
||||
"""Initializes a distributed TPU Embedding system for use with TensorFlow.
|
||||
|
|
Loading…
Reference in New Issue