Support Google-internal TPU resolution in strategy combinations.

PiperOrigin-RevId: 313461094
Change-Id: I42c3c87a7f8c41d6b11b6de408ddfc44fd8bbda0
This commit is contained in:
Revan Sopher 2020-05-27 14:11:01 -07:00 committed by TensorFlower Gardener
parent 67a155f782
commit 3446f153cb
2 changed files with 19 additions and 11 deletions

View File

@ -111,7 +111,6 @@ class GradientTapeTest(test.TestCase, parameterized.TestCase,
return grads
return distribution.experimental_local_results(
distribution.run(train_step, args=(x,)))
dist_dataset = distribution.experimental_distribute_dataset(dataset)
results = []
for x in dist_dataset:

View File

@ -47,7 +47,6 @@ from tensorflow.python.training import ftrl
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import rmsprop
FLAGS = flags.FLAGS
_did_connect_to_cluster = False
@ -58,14 +57,24 @@ def _get_tpu_strategy_creator(steps_per_run, use_single_core=False, **kwargs):
def _create_tpu_strategy():
global _did_connect_to_cluster
try:
# Attempt to locally discover the TPU. This will fail for Cloud TPU, in
# which case we fall back to the values passed as flags.
resolver = tpu_cluster_resolver.TPUClusterResolver()
did_automatically_resolve = True
except ValueError:
did_automatically_resolve = False
# These flags will be defined by tpu_test_wrapper.py.
resolver = tpu_cluster_resolver.TPUClusterResolver(
tpu=hasattr(FLAGS, "tpu") and FLAGS.tpu or "",
zone=hasattr(FLAGS, "zone") and FLAGS.zone or None,
project=hasattr(FLAGS, "project") and FLAGS.project or None,
)
# Only connect once per process, rather than per test method.
if hasattr(FLAGS, "tpu") and FLAGS.tpu and not _did_connect_to_cluster:
if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
if not _did_connect_to_cluster:
remote.connect_to_cluster(resolver)
_did_connect_to_cluster = True