Fork the keras related tpu_strategy_test to keras integration test.
PiperOrigin-RevId: 317232048 Change-Id: If05867985ff1ff81ac45bb601b701ee68d4d5279
This commit is contained in:
parent
64f7bdd56a
commit
c159f15995
@ -654,7 +654,6 @@ tpu_py_test(
|
|||||||
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
|
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
|
||||||
"//tensorflow/python/eager:remote",
|
"//tensorflow/python/eager:remote",
|
||||||
"//tensorflow/python/eager:test",
|
"//tensorflow/python/eager:test",
|
||||||
"//tensorflow/python/keras",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -18,7 +18,6 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python import keras
|
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.distribute import distribute_lib
|
from tensorflow.python.distribute import distribute_lib
|
||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
@ -364,24 +363,6 @@ class TPUStrategyTest(test.TestCase):
|
|||||||
expected_result,
|
expected_result,
|
||||||
strategy.experimental_local_results(train_step(next(input_iterator))))
|
strategy.experimental_local_results(train_step(next(input_iterator))))
|
||||||
|
|
||||||
def test_keras_metric_outside_strategy_scope_per_replica(self):
|
|
||||||
strategy = get_tpu_strategy()
|
|
||||||
metric = keras.metrics.Mean("test_metric", dtype=dtypes.float32)
|
|
||||||
|
|
||||||
dataset = dataset_ops.Dataset.range(strategy.num_replicas_in_sync *
|
|
||||||
2).batch(2)
|
|
||||||
dataset = strategy.experimental_distribute_dataset(dataset)
|
|
||||||
|
|
||||||
@def_function.function
|
|
||||||
def step_fn(i):
|
|
||||||
metric.update_state(i)
|
|
||||||
|
|
||||||
with self.assertRaisesRegex(ValueError, "Trying to run metric.update_state "
|
|
||||||
"in replica context"):
|
|
||||||
with strategy.scope():
|
|
||||||
for i in dataset:
|
|
||||||
strategy.run(step_fn, args=(i,))
|
|
||||||
|
|
||||||
# TODO(b/145574622): Remove this test once it is re-enabled in values_test.py.
|
# TODO(b/145574622): Remove this test once it is re-enabled in values_test.py.
|
||||||
def test_all_reduce_on_sync_on_read_variable(self):
|
def test_all_reduce_on_sync_on_read_variable(self):
|
||||||
strategy = get_tpu_strategy()
|
strategy = get_tpu_strategy()
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
# Contains Keras integration tests that verify with other TF high level APIs.
|
# Contains Keras integration tests that verify with other TF high level APIs.
|
||||||
|
|
||||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test", "tf_py_test")
|
load("//tensorflow:tensorflow.bzl", "cuda_py_test", "tf_py_test")
|
||||||
|
load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = [
|
default_visibility = [
|
||||||
@ -91,3 +92,15 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:extra_py_tests_deps",
|
"//tensorflow/python:extra_py_tests_deps",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tpu_py_test(
|
||||||
|
name = "tpu_strategy_test",
|
||||||
|
srcs = ["tpu_strategy_test.py"],
|
||||||
|
disable_experimental = True,
|
||||||
|
python_version = "PY3",
|
||||||
|
tags = ["no_oss"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//tensorflow/python:extra_py_tests_deps",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -0,0 +1,69 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for TPUStrategy."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl import flags
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
|
||||||
|
flags.DEFINE_string("project", None, "Name of GCP project with TPU.")
|
||||||
|
flags.DEFINE_string("zone", None, "Name of GCP zone with TPU.")
|
||||||
|
|
||||||
|
|
||||||
|
def get_tpu_cluster_resolver():
|
||||||
|
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
|
||||||
|
tpu=FLAGS.tpu,
|
||||||
|
zone=FLAGS.zone,
|
||||||
|
project=FLAGS.project,
|
||||||
|
)
|
||||||
|
return resolver
|
||||||
|
|
||||||
|
|
||||||
|
def get_tpu_strategy():
|
||||||
|
resolver = get_tpu_cluster_resolver()
|
||||||
|
tf.config.experimental_connect_to_cluster(resolver)
|
||||||
|
tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||||
|
return tf.distribute.experimental.TPUStrategy(resolver)
|
||||||
|
|
||||||
|
|
||||||
|
class TpuStrategyTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test_keras_metric_outside_strategy_scope_per_replica(self):
|
||||||
|
strategy = get_tpu_strategy()
|
||||||
|
metric = tf.keras.metrics.Mean("test_metric", dtype=tf.float32)
|
||||||
|
|
||||||
|
dataset = tf.data.Dataset.range(strategy.num_replicas_in_sync * 2).batch(2)
|
||||||
|
dataset = strategy.experimental_distribute_dataset(dataset)
|
||||||
|
|
||||||
|
@tf.function
|
||||||
|
def step_fn(i):
|
||||||
|
metric.update_state(i)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, "Trying to run metric.update_state "
|
||||||
|
"in replica context"):
|
||||||
|
with strategy.scope():
|
||||||
|
for i in dataset:
|
||||||
|
strategy.run(step_fn, args=(i,))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.test.main()
|
Loading…
Reference in New Issue
Block a user