Update tpu initialization to use separate embedding initialization op.
PiperOrigin-RevId: 304212916
Change-Id: I59324931158807b54a1bc23c741c86b9104594d7
parent 2c80835262
commit d15ddcfdbf
|
@@ -118,10 +118,21 @@ def initialize_system(embedding_config=None,
   config_string = ("" if embedding_config is None else
                    embedding_config.SerializeToString())
   with ops.device(_tpu_system_device_name(job)):
-    return tpu_ops.configure_distributed_tpu(
-        embedding_config=config_string,
+    topology = tpu_ops.configure_distributed_tpu(
         compilation_failure_closes_chips=compilation_failure_closes_chips)
 
+    if embedding_config is None:
+      return topology
+
+    # This set of control dependencies is needed as this function is expected to
+    # return an op which will return the topology when executed, but we need to
+    # call the embedding initialization op between initializing the TPU and
+    # returning the topology.
+    with ops.control_dependencies([topology]):
+      embedding_init = tpu_ops.configure_tpu_embedding(config=config_string)
+    with ops.control_dependencies([embedding_init]):
+      return array_ops.identity(topology, name="tpu_init_identity")
+
 
 def initialize_system_for_tpu_embedding(embedding_config, job=None):
   """Initializes a distributed TPU Embedding system for use with TensorFlow.
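
For context, below is a minimal, self-contained sketch (not part of this commit) of the same control-dependency sequencing the new code relies on, using stand-in ops in place of the TPU configuration ops so it runs on any machine; the constant value and op names here are illustrative assumptions, not TensorFlow APIs added by this change.

# Minimal sketch (not from this commit): the same control-dependency pattern
# with stand-in ops, so the sequencing is easy to see in isolation.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Stands in for configure_distributed_tpu: the op whose value we want to return.
topology = tf.constant("serialized-topology", name="configure_tpu_stub")
with tf.control_dependencies([topology]):
  # Stands in for the separate embedding initialization op; the control
  # dependency forces it to run after the stub "configure" op above.
  embedding_init = tf.no_op(name="embedding_init_stub")
with tf.control_dependencies([embedding_init]):
  # The returned tensor still carries the topology value, but evaluating it
  # now also runs the embedding init op in between.
  result = tf.identity(topology, name="tpu_init_identity")

with tf.Session() as sess:
  print(sess.run(result))  # b'serialized-topology'; runs all three ops in order

This mirrors why the changed initialize_system wraps the topology in array_ops.identity: the function must still return an op that yields the topology when executed, while guaranteeing the embedding initialization runs between configuring the TPU and producing that result.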