Fork the keras related tpu_strategy_test to keras integration test.

PiperOrigin-RevId: 317232048
Change-Id: If05867985ff1ff81ac45bb601b701ee68d4d5279
This commit is contained in:
Scott Zhu 2020-06-18 19:35:14 -07:00 committed by TensorFlower Gardener
parent 64f7bdd56a
commit c159f15995
4 changed files with 82 additions and 20 deletions

View File

@ -654,7 +654,6 @@ tpu_py_test(
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
"//tensorflow/python/eager:remote",
"//tensorflow/python/eager:test",
"//tensorflow/python/keras",
],
)

View File

@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context
@ -364,24 +363,6 @@ class TPUStrategyTest(test.TestCase):
expected_result,
strategy.experimental_local_results(train_step(next(input_iterator))))
def test_keras_metric_outside_strategy_scope_per_replica(self):
strategy = get_tpu_strategy()
metric = keras.metrics.Mean("test_metric", dtype=dtypes.float32)
dataset = dataset_ops.Dataset.range(strategy.num_replicas_in_sync *
2).batch(2)
dataset = strategy.experimental_distribute_dataset(dataset)
@def_function.function
def step_fn(i):
metric.update_state(i)
with self.assertRaisesRegex(ValueError, "Trying to run metric.update_state "
"in replica context"):
with strategy.scope():
for i in dataset:
strategy.run(step_fn, args=(i,))
# TODO(b/145574622): Remove this test once it is re-enabled in values_test.py.
def test_all_reduce_on_sync_on_read_variable(self):
strategy = get_tpu_strategy()

View File

@ -2,6 +2,7 @@
# Contains Keras integration tests that verify with other TF high level APIs.
load("//tensorflow:tensorflow.bzl", "cuda_py_test", "tf_py_test")
load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")
package(
default_visibility = [
@ -91,3 +92,15 @@ cuda_py_test(
"//tensorflow/python:extra_py_tests_deps",
],
)
tpu_py_test(
name = "tpu_strategy_test",
srcs = ["tpu_strategy_test.py"],
disable_experimental = True,
python_version = "PY3",
tags = ["no_oss"],
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/python:extra_py_tests_deps",
],
)

View File

@ -0,0 +1,69 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPUStrategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
import tensorflow as tf
FLAGS = flags.FLAGS
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
flags.DEFINE_string("project", None, "Name of GCP project with TPU.")
flags.DEFINE_string("zone", None, "Name of GCP zone with TPU.")
def get_tpu_cluster_resolver():
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu=FLAGS.tpu,
zone=FLAGS.zone,
project=FLAGS.project,
)
return resolver
def get_tpu_strategy():
resolver = get_tpu_cluster_resolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
return tf.distribute.experimental.TPUStrategy(resolver)
class TpuStrategyTest(tf.test.TestCase):
def test_keras_metric_outside_strategy_scope_per_replica(self):
strategy = get_tpu_strategy()
metric = tf.keras.metrics.Mean("test_metric", dtype=tf.float32)
dataset = tf.data.Dataset.range(strategy.num_replicas_in_sync * 2).batch(2)
dataset = strategy.experimental_distribute_dataset(dataset)
@tf.function
def step_fn(i):
metric.update_state(i)
with self.assertRaisesRegex(ValueError, "Trying to run metric.update_state "
"in replica context"):
with strategy.scope():
for i in dataset:
strategy.run(step_fn, args=(i,))
if __name__ == "__main__":
tf.test.main()