Support Google-internal TPU resolution in strategy combinations.

PiperOrigin-RevId: 313461094
Change-Id: I42c3c87a7f8c41d6b11b6de408ddfc44fd8bbda0
This commit is contained in:
Revan Sopher 2020-05-27 14:11:01 -07:00 committed by TensorFlower Gardener
parent 67a155f782
commit 3446f153cb
2 changed files with 19 additions and 11 deletions

View File

@ -111,7 +111,6 @@ class GradientTapeTest(test.TestCase, parameterized.TestCase,
return grads return grads
return distribution.experimental_local_results( return distribution.experimental_local_results(
distribution.run(train_step, args=(x,))) distribution.run(train_step, args=(x,)))
dist_dataset = distribution.experimental_distribute_dataset(dataset) dist_dataset = distribution.experimental_distribute_dataset(dataset)
results = [] results = []
for x in dist_dataset: for x in dist_dataset:

View File

@ -47,7 +47,6 @@ from tensorflow.python.training import ftrl
from tensorflow.python.training import gradient_descent from tensorflow.python.training import gradient_descent
from tensorflow.python.training import rmsprop from tensorflow.python.training import rmsprop
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
_did_connect_to_cluster = False _did_connect_to_cluster = False
@ -58,16 +57,26 @@ def _get_tpu_strategy_creator(steps_per_run, use_single_core=False, **kwargs):
def _create_tpu_strategy(): def _create_tpu_strategy():
global _did_connect_to_cluster global _did_connect_to_cluster
# These flags will be defined by tpu_test_wrapper.py. try:
resolver = tpu_cluster_resolver.TPUClusterResolver( # Attempt to locally discover the TPU. This will fail for Cloud TPU, in
tpu=hasattr(FLAGS, "tpu") and FLAGS.tpu or "", # which case we fall back to the values passed as flags.
zone=hasattr(FLAGS, "zone") and FLAGS.zone or None, resolver = tpu_cluster_resolver.TPUClusterResolver()
project=hasattr(FLAGS, "project") and FLAGS.project or None, did_automatically_resolve = True
) except ValueError:
did_automatically_resolve = False
# These flags will be defined by tpu_test_wrapper.py.
resolver = tpu_cluster_resolver.TPUClusterResolver(
tpu=hasattr(FLAGS, "tpu") and FLAGS.tpu or "",
zone=hasattr(FLAGS, "zone") and FLAGS.zone or None,
project=hasattr(FLAGS, "project") and FLAGS.project or None,
)
# Only connect once per process, rather than per test method. # Only connect once per process, rather than per test method.
if hasattr(FLAGS, "tpu") and FLAGS.tpu and not _did_connect_to_cluster: if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
remote.connect_to_cluster(resolver) if not _did_connect_to_cluster:
_did_connect_to_cluster = True remote.connect_to_cluster(resolver)
_did_connect_to_cluster = True
topology = tpu_strategy_util.initialize_tpu_system(resolver) topology = tpu_strategy_util.initialize_tpu_system(resolver)
device_assignment = None device_assignment = None