Support Google-internal TPU resolution in strategy combinations.
PiperOrigin-RevId: 313461094 Change-Id: I42c3c87a7f8c41d6b11b6de408ddfc44fd8bbda0
This commit is contained in:
parent
67a155f782
commit
3446f153cb
|
@ -111,7 +111,6 @@ class GradientTapeTest(test.TestCase, parameterized.TestCase,
|
|||
return grads
|
||||
return distribution.experimental_local_results(
|
||||
distribution.run(train_step, args=(x,)))
|
||||
|
||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||
results = []
|
||||
for x in dist_dataset:
|
||||
|
|
|
@ -47,7 +47,6 @@ from tensorflow.python.training import ftrl
|
|||
from tensorflow.python.training import gradient_descent
|
||||
from tensorflow.python.training import rmsprop
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
_did_connect_to_cluster = False
|
||||
|
@ -58,14 +57,24 @@ def _get_tpu_strategy_creator(steps_per_run, use_single_core=False, **kwargs):
|
|||
def _create_tpu_strategy():
|
||||
global _did_connect_to_cluster
|
||||
|
||||
try:
|
||||
# Attempt to locally discover the TPU. This will fail for Cloud TPU, in
|
||||
# which case we fall back to the values passed as flags.
|
||||
resolver = tpu_cluster_resolver.TPUClusterResolver()
|
||||
did_automatically_resolve = True
|
||||
except ValueError:
|
||||
did_automatically_resolve = False
|
||||
|
||||
# These flags will be defined by tpu_test_wrapper.py.
|
||||
resolver = tpu_cluster_resolver.TPUClusterResolver(
|
||||
tpu=hasattr(FLAGS, "tpu") and FLAGS.tpu or "",
|
||||
zone=hasattr(FLAGS, "zone") and FLAGS.zone or None,
|
||||
project=hasattr(FLAGS, "project") and FLAGS.project or None,
|
||||
)
|
||||
|
||||
# Only connect once per process, rather than per test method.
|
||||
if hasattr(FLAGS, "tpu") and FLAGS.tpu and not _did_connect_to_cluster:
|
||||
if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
|
||||
if not _did_connect_to_cluster:
|
||||
remote.connect_to_cluster(resolver)
|
||||
_did_connect_to_cluster = True
|
||||
|
||||
|
|
Loading…
Reference in New Issue