Test TextVectorization with all distribution strategies.

PiperOrigin-RevId: 308949199
Change-Id: I33f99224767720f5636b4c98fef4695c89d274f9
Authored by A. Unique TensorFlower on 2020-04-28 20:13:10 -07:00; committed by TensorFlower Gardener
parent d3610b144d
commit 9403febb65
2 changed files with 32 additions and 18 deletions

tensorflow/python/keras/layers/preprocessing/BUILD

@@ -4,6 +4,7 @@
 load("//tensorflow:tensorflow.bzl", "tf_py_test")
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test")
+load("//tensorflow/core/platform/default:distribute.bzl", "distribute_py_test")
 
 package(
     default_visibility = [
@@ -421,17 +422,20 @@ tf_py_test(
     ],
 )
 
-tpu_py_test(
-    name = "text_vectorization_tpu_test",
-    srcs = ["text_vectorization_tpu_test.py"],
-    disable_experimental = True,
+distribute_py_test(
+    name = "text_vectorization_distribution_test",
+    srcs = ["text_vectorization_distribution_test.py"],
+    main = "text_vectorization_distribution_test.py",
     python_version = "PY3",
-    tags = ["no_oss"],
+    tags = [
+        "multi_and_single_gpu",
+    ],
     deps = [
         ":text_vectorization",
+        "//tensorflow/python/distribute:combinations",
+        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/eager:test",
         "//tensorflow/python/keras",
-        "//tensorflow/python/keras/distribute:tpu_strategy_test_utils",
     ],
 )
 
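The new distribute_py_test target is tagged "multi_and_single_gpu", so the test is meant to run under both single- and multi-GPU configurations rather than only on TPU hardware. Assuming the package path shown above and the usual Bazel workflow (the exact test targets generated by the distribute_py_test macro may carry additional suffixes), it would be invoked with something like:

    bazel test //tensorflow/python/keras/layers/preprocessing:text_vectorization_distribution_test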

tensorflow/python/keras/layers/preprocessing/text_vectorization_tpu_test.py → tensorflow/python/keras/layers/preprocessing/text_vectorization_distribution_test.py

@@ -22,22 +22,34 @@ import numpy as np
 
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from tensorflow.python.eager import context
 from tensorflow.python.framework import config
 from tensorflow.python.framework import dtypes
 from tensorflow.python.keras import keras_parameterized
-from tensorflow.python.keras.distribute import tpu_strategy_test_utils
 from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
 from tensorflow.python.keras.layers.preprocessing import text_vectorization
+from tensorflow.python.keras.layers.preprocessing import text_vectorization_v1
 from tensorflow.python.platform import test
 
 
-@keras_parameterized.run_all_keras_modes(
-    always_skip_v1=True, always_skip_eager=True)
-class TextVectorizationTPUDistributionTest(
+def get_layer_class():
+  if context.executing_eagerly():
+    return text_vectorization.TextVectorization
+  else:
+    return text_vectorization_v1.TextVectorization
+
+
+@combinations.generate(
+    combinations.combine(
+        distribution=strategy_combinations.all_strategies,
+        mode=["eager", "graph"]))
+class TextVectorizationDistributionTest(
     keras_parameterized.TestCase,
     preprocessing_test_utils.PreprocessingLayerTest):
 
-  def test_distribution_strategy_output(self):
+  def test_distribution_strategy_output(self, distribution):
     vocab_data = ["earth", "wind", "and", "fire"]
     input_array = np.array([["earth", "wind", "and", "fire"],
                             ["fire", "and", "earth", "michigan"]])
@@ -47,11 +59,10 @@ class TextVectorizationTPUDistributionTest(
     expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]]
 
     config.set_soft_device_placement(True)
-    strategy = tpu_strategy_test_utils.get_tpu_strategy()
 
-    with strategy.scope():
+    with distribution.scope():
       input_data = keras.Input(shape=(None,), dtype=dtypes.string)
-      layer = text_vectorization.TextVectorization(
+      layer = get_layer_class()(
           max_tokens=None,
           standardize=None,
           split=None,
@@ -63,7 +74,7 @@ class TextVectorizationTPUDistributionTest(
     output_dataset = model.predict(input_dataset)
     self.assertAllEqual(expected_output, output_dataset)
 
-  def test_distribution_strategy_output_with_adapt(self):
+  def test_distribution_strategy_output_with_adapt(self, distribution):
     vocab_data = [[
         "earth", "earth", "earth", "earth", "wind", "wind", "wind", "and",
         "and", "fire"
@@ -77,11 +88,10 @@ class TextVectorizationTPUDistributionTest(
     expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]]
 
     config.set_soft_device_placement(True)
-    strategy = tpu_strategy_test_utils.get_tpu_strategy()
 
-    with strategy.scope():
+    with distribution.scope():
       input_data = keras.Input(shape=(None,), dtype=dtypes.string)
-      layer = text_vectorization.TextVectorization(
+      layer = get_layer_class()(
           max_tokens=None,
           standardize=None,
           split=None,