Fork the keras related tpu_strategy_test to keras integration test.
PiperOrigin-RevId: 317232048 Change-Id: If05867985ff1ff81ac45bb601b701ee68d4d5279
This commit is contained in:
parent
64f7bdd56a
commit
c159f15995
@ -654,7 +654,6 @@ tpu_py_test(
|
||||
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
|
||||
"//tensorflow/python/eager:remote",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/keras",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -18,7 +18,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
@ -364,24 +363,6 @@ class TPUStrategyTest(test.TestCase):
|
||||
expected_result,
|
||||
strategy.experimental_local_results(train_step(next(input_iterator))))
|
||||
|
||||
def test_keras_metric_outside_strategy_scope_per_replica(self):
|
||||
strategy = get_tpu_strategy()
|
||||
metric = keras.metrics.Mean("test_metric", dtype=dtypes.float32)
|
||||
|
||||
dataset = dataset_ops.Dataset.range(strategy.num_replicas_in_sync *
|
||||
2).batch(2)
|
||||
dataset = strategy.experimental_distribute_dataset(dataset)
|
||||
|
||||
@def_function.function
|
||||
def step_fn(i):
|
||||
metric.update_state(i)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Trying to run metric.update_state "
|
||||
"in replica context"):
|
||||
with strategy.scope():
|
||||
for i in dataset:
|
||||
strategy.run(step_fn, args=(i,))
|
||||
|
||||
# TODO(b/145574622): Remove this test once it is re-enabled in values_test.py.
|
||||
def test_all_reduce_on_sync_on_read_variable(self):
|
||||
strategy = get_tpu_strategy()
|
||||
|
@ -2,6 +2,7 @@
|
||||
# Contains Keras integration tests that verify with other TF high level APIs.
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test", "tf_py_test")
|
||||
load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -91,3 +92,15 @@ cuda_py_test(
|
||||
"//tensorflow/python:extra_py_tests_deps",
|
||||
],
|
||||
)
|
||||
|
||||
tpu_py_test(
|
||||
name = "tpu_strategy_test",
|
||||
srcs = ["tpu_strategy_test.py"],
|
||||
disable_experimental = True,
|
||||
python_version = "PY3",
|
||||
tags = ["no_oss"],
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:extra_py_tests_deps",
|
||||
],
|
||||
)
|
||||
|
@ -0,0 +1,69 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for TPUStrategy."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import flags
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
|
||||
flags.DEFINE_string("project", None, "Name of GCP project with TPU.")
|
||||
flags.DEFINE_string("zone", None, "Name of GCP zone with TPU.")
|
||||
|
||||
|
||||
def get_tpu_cluster_resolver():
|
||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
|
||||
tpu=FLAGS.tpu,
|
||||
zone=FLAGS.zone,
|
||||
project=FLAGS.project,
|
||||
)
|
||||
return resolver
|
||||
|
||||
|
||||
def get_tpu_strategy():
|
||||
resolver = get_tpu_cluster_resolver()
|
||||
tf.config.experimental_connect_to_cluster(resolver)
|
||||
tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||
return tf.distribute.experimental.TPUStrategy(resolver)
|
||||
|
||||
|
||||
class TpuStrategyTest(tf.test.TestCase):
|
||||
|
||||
def test_keras_metric_outside_strategy_scope_per_replica(self):
|
||||
strategy = get_tpu_strategy()
|
||||
metric = tf.keras.metrics.Mean("test_metric", dtype=tf.float32)
|
||||
|
||||
dataset = tf.data.Dataset.range(strategy.num_replicas_in_sync * 2).batch(2)
|
||||
dataset = strategy.experimental_distribute_dataset(dataset)
|
||||
|
||||
@tf.function
|
||||
def step_fn(i):
|
||||
metric.update_state(i)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Trying to run metric.update_state "
|
||||
"in replica context"):
|
||||
with strategy.scope():
|
||||
for i in dataset:
|
||||
strategy.run(step_fn, args=(i,))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
Loading…
Reference in New Issue
Block a user