Support Google-internal TPU resolution in strategy combinations.
PiperOrigin-RevId: 313461094
Change-Id: I42c3c87a7f8c41d6b11b6de408ddfc44fd8bbda0
This commit is contained in:
parent
67a155f782
commit
3446f153cb
|
@ -111,7 +111,6 @@ class GradientTapeTest(test.TestCase, parameterized.TestCase,
|
||||||
return grads
|
return grads
|
||||||
return distribution.experimental_local_results(
|
return distribution.experimental_local_results(
|
||||||
distribution.run(train_step, args=(x,)))
|
distribution.run(train_step, args=(x,)))
|
||||||
|
|
||||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||||
results = []
|
results = []
|
||||||
for x in dist_dataset:
|
for x in dist_dataset:
|
||||||
|
|
|
@ -47,7 +47,6 @@ from tensorflow.python.training import ftrl
|
||||||
from tensorflow.python.training import gradient_descent
|
from tensorflow.python.training import gradient_descent
|
||||||
from tensorflow.python.training import rmsprop
|
from tensorflow.python.training import rmsprop
|
||||||
|
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
_did_connect_to_cluster = False
|
_did_connect_to_cluster = False
|
||||||
|
@ -58,16 +57,26 @@ def _get_tpu_strategy_creator(steps_per_run, use_single_core=False, **kwargs):
|
||||||
def _create_tpu_strategy():
|
def _create_tpu_strategy():
|
||||||
global _did_connect_to_cluster
|
global _did_connect_to_cluster
|
||||||
|
|
||||||
# These flags will be defined by tpu_test_wrapper.py.
|
try:
|
||||||
resolver = tpu_cluster_resolver.TPUClusterResolver(
|
# Attempt to locally discover the TPU. This will fail for Cloud TPU, in
|
||||||
tpu=hasattr(FLAGS, "tpu") and FLAGS.tpu or "",
|
# which case we fall back to the values passed as flags.
|
||||||
zone=hasattr(FLAGS, "zone") and FLAGS.zone or None,
|
resolver = tpu_cluster_resolver.TPUClusterResolver()
|
||||||
project=hasattr(FLAGS, "project") and FLAGS.project or None,
|
did_automatically_resolve = True
|
||||||
)
|
except ValueError:
|
||||||
|
did_automatically_resolve = False
|
||||||
|
|
||||||
|
# These flags will be defined by tpu_test_wrapper.py.
|
||||||
|
resolver = tpu_cluster_resolver.TPUClusterResolver(
|
||||||
|
tpu=hasattr(FLAGS, "tpu") and FLAGS.tpu or "",
|
||||||
|
zone=hasattr(FLAGS, "zone") and FLAGS.zone or None,
|
||||||
|
project=hasattr(FLAGS, "project") and FLAGS.project or None,
|
||||||
|
)
|
||||||
|
|
||||||
# Only connect once per process, rather than per test method.
|
# Only connect once per process, rather than per test method.
|
||||||
if hasattr(FLAGS, "tpu") and FLAGS.tpu and not _did_connect_to_cluster:
|
if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
|
||||||
remote.connect_to_cluster(resolver)
|
if not _did_connect_to_cluster:
|
||||||
_did_connect_to_cluster = True
|
remote.connect_to_cluster(resolver)
|
||||||
|
_did_connect_to_cluster = True
|
||||||
|
|
||||||
topology = tpu_strategy_util.initialize_tpu_system(resolver)
|
topology = tpu_strategy_util.initialize_tpu_system(resolver)
|
||||||
device_assignment = None
|
device_assignment = None
|
||||||
|
|
Loading…
Reference in New Issue