Update keras to use tf.__internal__ for distribute tests.

PiperOrigin-RevId: 341332973
Change-Id: I61cce014764668a040bdcedb4e5e956b13983ddb
This commit is contained in:
Scott Zhu 2020-11-08 20:33:10 -08:00 committed by TensorFlower Gardener
parent 21550f8dcd
commit 17a521d9f1
20 changed files with 125 additions and 86 deletions

View File

@ -226,6 +226,7 @@ distribute_py_test(
],
deps = [
":optimizer_combinations",
":strategy_combinations",
"//tensorflow/python:platform_test",
"//tensorflow/python:util",
"//tensorflow/python/compat:v2_compat",
@ -246,6 +247,7 @@ distribute_py_test(
"multi_and_single_gpu",
],
deps = [
":strategy_combinations",
"//tensorflow/python:errors",
"//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
@ -271,6 +273,7 @@ distribute_py_test(
"notsan", # TODO(b/170869466)
],
deps = [
":strategy_combinations",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
@ -295,6 +298,7 @@ distribute_py_test(
"multi_and_single_gpu",
],
deps = [
":strategy_combinations",
"//tensorflow/python:framework_ops",
"//tensorflow/python:variables",
"//tensorflow/python/distribute:combinations",
@ -314,6 +318,7 @@ py_library(
],
deps = [
":optimizer_combinations",
":strategy_combinations",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
"//tensorflow/python/distribute:combinations",
@ -400,6 +405,7 @@ py_library(
"keras_stateful_lstm_model_correctness_test.py",
],
deps = [
":strategy_combinations",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
"//tensorflow/python/distribute:collective_all_reduce_strategy",
@ -510,6 +516,7 @@ distribute_py_test(
"multi_and_single_gpu",
],
deps = [
":strategy_combinations",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/eager:test",
@ -949,3 +956,11 @@ tf_py_test(
"//tensorflow/python/keras/engine:base_layer",
],
)
py_library(
name = "strategy_combinations",
srcs = ["strategy_combinations.py"],
deps = [
"//tensorflow/python/distribute:strategy_combinations",
],
)

View File

@ -28,7 +28,6 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
@ -36,6 +35,7 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.framework import test_util
from tensorflow.python.keras.distribute import optimizer_combinations
from tensorflow.python.keras.distribute import strategy_combinations
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.platform import test

View File

@ -24,11 +24,11 @@ import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import metrics
from tensorflow.python.keras.distribute import strategy_combinations
from tensorflow.python.platform import test

View File

@ -28,10 +28,10 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.distribute import strategy_combinations
from tensorflow.python.module import module
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test

View File

@ -25,6 +25,7 @@ from tensorflow.python.distribute import values
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.distribute import strategy_combinations as keras_strategy_combinations
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -35,7 +36,7 @@ class OptimizerTest(test.TestCase, parameterized.TestCase):
@ds_combinations.generate(
combinations.times(
combinations.combine(
distribution=strategy_combinations.multidevice_strategies,
distribution=keras_strategy_combinations.multidevice_strategies,
mode=["eager"],
),
combinations.combine(

View File

@ -51,6 +51,11 @@ from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.keras.distribute import distributed_training_utils_v1
from tensorflow.python.keras.distribute import optimizer_combinations
from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
from tensorflow.python.keras.distribute.strategy_combinations import multi_worker_mirrored_strategies
from tensorflow.python.keras.distribute.strategy_combinations import strategies_minus_default_minus_tpu
from tensorflow.python.keras.distribute.strategy_combinations import strategies_minus_tpu
from tensorflow.python.keras.distribute.strategy_combinations import tpu_strategies
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
@ -230,37 +235,6 @@ def multi_input_output_model():
return model
strategies_minus_default_minus_tpu = [
strategy_combinations.one_device_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.central_storage_strategy_with_gpu_and_cpu
]
strategies_minus_tpu = [
strategy_combinations.default_strategy,
strategy_combinations.one_device_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.central_storage_strategy_with_gpu_and_cpu
]
multi_worker_mirrored_strategies = [
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
]
tpu_strategies = [
strategy_combinations.tpu_strategy,
]
all_strategies = (
strategies_minus_tpu + tpu_strategies + multi_worker_mirrored_strategies)
def strategy_minus_tpu_combinations():
return combinations.combine(
distribution=strategies_minus_tpu, mode=['graph', 'eager'])
@ -1704,8 +1678,7 @@ class TestRegularizerLoss(test.TestCase, parameterized.TestCase):
return math_ops.reduce_mean(y_pred)
@ds_combinations.generate(
combinations.times(
strategy_combinations.all_strategy_combinations_minus_default()))
combinations.times(all_strategy_combinations_minus_default()))
def test_regularizer_loss(self, distribution):
batch_size = 2
if not distributed_training_utils.global_batch_size_supported(distribution):
@ -2648,9 +2621,7 @@ class TestModelCapturesStrategy(test.TestCase, parameterized.TestCase):
"""Tests that model creation captures the strategy."""
@ds_combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies,
mode=['eager']))
combinations.combine(distribution=all_strategies, mode=['eager']))
def test_fit_and_evaluate(self, distribution):
dataset = dataset_ops.DatasetV2.from_tensor_slices(
(array_ops.ones(shape=(64,)), array_ops.ones(shape=(64,))))

View File

@ -32,6 +32,9 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.distribute import distributed_training_utils
from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
from tensorflow.python.keras.distribute.strategy_combinations import multi_worker_mirrored_strategies
from tensorflow.python.keras.distribute.strategy_combinations import strategies_minus_tpu
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.preprocessing import sequence
from tensorflow.python.platform import test
@ -44,22 +47,6 @@ _GLOBAL_BATCH_SIZE = 64
# Note: Please make sure the tests in this file are also covered in
# keras_backward_compat_test for features that are supported with both APIs.
all_strategies = [
strategy_combinations.default_strategy,
strategy_combinations.one_device_strategy,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.tpu_strategy, # steps_per_run=2
]
# TODO(b/159831559): add to all_strategies once all tests pass.
multi_worker_mirrored = [
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
]
def eager_mode_test_configuration():
return combinations.combine(
@ -85,8 +72,7 @@ def all_strategy_and_input_config_combinations_eager():
def strategy_minus_tpu_and_input_config_combinations_eager():
return (combinations.times(
combinations.combine(
distribution=strategy_combinations.strategies_minus_tpu),
combinations.combine(distribution=strategies_minus_tpu),
eager_mode_test_configuration()))
@ -130,13 +116,13 @@ def test_combinations_with_tpu_strategies_graph():
def multi_worker_mirrored_eager():
return combinations.times(
combinations.combine(distribution=multi_worker_mirrored),
combinations.combine(distribution=multi_worker_mirrored_strategies),
eager_mode_test_configuration())
def multi_worker_mirrored_eager_and_graph():
return combinations.times(
combinations.combine(distribution=multi_worker_mirrored),
combinations.combine(distribution=multi_worker_mirrored_strategies),
eager_mode_test_configuration() + graph_mode_test_configuration())

View File

@ -28,15 +28,16 @@ from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.distribute import keras_correctness_test_base
from tensorflow.python.keras.distribute import strategy_combinations
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
from tensorflow.python.training import gradient_descent
def all_strategy_combinations_with_eager_and_graph_modes():
return (combinations.combine(
distribution=keras_correctness_test_base.all_strategies,
distribution=strategy_combinations.all_strategies,
mode=['graph', 'eager']) + combinations.combine(
distribution=keras_correctness_test_base.multi_worker_mirrored,
distribution=strategy_combinations.multi_worker_mirrored_strategies,
mode='eager'))

View File

@ -23,8 +23,8 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
from tensorflow.python.platform import test
@ -32,7 +32,7 @@ class KerasModelsTest(test.TestCase, parameterized.TestCase):
@ds_combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies, mode=["eager"]))
distribution=all_strategies, mode=["eager"]))
def test_lstm_model_with_dynamic_batch(self, distribution):
input_data = np.random.random([1, 32, 64, 64, 3])
input_shape = tuple(input_data.shape[1:])

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import tpu_strategy
@ -116,6 +117,11 @@ class DistributionStrategyLstmModelCorrectnessTest(
@testing_utils.enable_v2_dtype_behavior
def test_lstm_model_correctness_mixed_precision(self, distribution, use_numpy,
use_validation_data):
if isinstance(distribution,
(central_storage_strategy.CentralStorageStrategy,
central_storage_strategy.CentralStorageStrategyV1)):
self.skipTest('CentralStorageStrategy is not supported by '
'mixed precision.')
if isinstance(distribution,
(tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)):
policy_name = 'mixed_bfloat16'

View File

@ -0,0 +1,59 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Strategy combinations for combinations.combine()."""
from tensorflow.python.distribute import strategy_combinations
multidevice_strategies = [
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.tpu_strategy,
]
multiworker_strategies = [
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu
]
strategies_minus_default_minus_tpu = [
strategy_combinations.one_device_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.central_storage_strategy_with_gpu_and_cpu
]
strategies_minus_tpu = [
strategy_combinations.default_strategy,
strategy_combinations.one_device_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.central_storage_strategy_with_gpu_and_cpu
]
multi_worker_mirrored_strategies = [
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
]
tpu_strategies = [
strategy_combinations.tpu_strategy,
]
all_strategies = strategies_minus_tpu + tpu_strategies

View File

@ -404,9 +404,9 @@ distribute_py_test(
"//tensorflow/python:framework_test_combinations_lib",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/distribute:tpu_strategy",
"//tensorflow/python/keras",
"//tensorflow/python/keras/distribute:strategy_combinations",
],
)
@ -430,9 +430,9 @@ distribute_py_test(
"//tensorflow/python:framework_test_combinations_lib",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/distribute:tpu_strategy",
"//tensorflow/python/keras",
"//tensorflow/python/keras/distribute:strategy_combinations",
],
)
@ -453,8 +453,8 @@ distribute_py_test(
"//tensorflow/python:framework_test_combinations_lib",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/keras",
"//tensorflow/python/keras/distribute:strategy_combinations",
],
)
@ -489,8 +489,8 @@ distribute_py_test(
"//tensorflow/python:config",
"//tensorflow/python:framework_test_combinations_lib",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/keras",
"//tensorflow/python/keras/distribute:strategy_combinations",
],
)
@ -526,8 +526,8 @@ tpu_py_test(
deps = [
":hashing",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/keras",
"//tensorflow/python/keras/distribute:strategy_combinations",
],
)
@ -557,8 +557,8 @@ tpu_py_test(
deps = [
":index_lookup",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/keras",
"//tensorflow/python/keras/distribute:strategy_combinations",
],
)
@ -637,9 +637,9 @@ distribute_py_test(
"//tensorflow/python:framework_test_combinations_lib",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/eager:context",
"//tensorflow/python/keras",
"//tensorflow/python/keras/distribute:strategy_combinations",
],
)
@ -694,9 +694,9 @@ distribute_py_test(
"//tensorflow/python:framework_test_combinations_lib",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/eager:context",
"//tensorflow/python/keras",
"//tensorflow/python/keras/distribute:strategy_combinations",
],
)

View File

@ -23,12 +23,12 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
from tensorflow.python.keras.layers.preprocessing import category_crossing
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.platform import test
@ -49,7 +49,7 @@ def batch_wrapper(dataset, batch_size, distribution, repeat=None):
@ds_combinations.generate(
combinations.combine(
# Investigate why crossing is not supported with TPU.
distribution=strategy_combinations.all_strategies,
distribution=all_strategies,
mode=['eager', 'graph']))
class CategoryCrossingDistributionTest(
keras_parameterized.TestCase,

View File

@ -23,12 +23,12 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.distribute import strategy_combinations
from tensorflow.python.keras.layers.preprocessing import category_encoding
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.platform import test

View File

@ -22,10 +22,10 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.framework import config
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.distribute import strategy_combinations
from tensorflow.python.keras.layers.preprocessing import discretization
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.platform import test

View File

@ -23,11 +23,11 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
from tensorflow.python.keras.layers.preprocessing import hashing
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.platform import test
@ -35,7 +35,7 @@ from tensorflow.python.platform import test
@ds_combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies,
distribution=all_strategies,
mode=["eager", "graph"]))
class HashingDistributionTest(keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest):

View File

@ -23,10 +23,10 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.platform import test
@ -34,7 +34,7 @@ from tensorflow.python.platform import test
@ds_combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies,
distribution=all_strategies,
mode=["eager", "graph"]))
class ImagePreprocessingDistributionTest(
keras_parameterized.TestCase,

View File

@ -23,12 +23,12 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
from tensorflow.python.keras.layers.preprocessing import index_lookup
from tensorflow.python.keras.layers.preprocessing import index_lookup_v1
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
@ -44,7 +44,7 @@ def get_layer_class():
@ds_combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies,
distribution=all_strategies,
mode=["eager"])) # Eager-only, no graph: b/158793009
class IndexLookupDistributionTest(
keras_parameterized.TestCase,

View File

@ -23,10 +23,10 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import context
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
from tensorflow.python.keras.layers.preprocessing import normalization
from tensorflow.python.keras.layers.preprocessing import normalization_v1
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
@ -108,7 +108,7 @@ def _get_layer_computation_test_cases():
@ds_combinations.generate(
combinations.times(
combinations.combine(
distribution=strategy_combinations.all_strategies,
distribution=all_strategies,
mode=["eager", "graph"]), _get_layer_computation_test_cases()))
class NormalizationTest(keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest):

View File

@ -23,12 +23,12 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.distribute.strategy_combinations import all_strategies
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.keras.layers.preprocessing import text_vectorization
from tensorflow.python.keras.layers.preprocessing import text_vectorization_v1
@ -44,7 +44,7 @@ def get_layer_class():
@ds_combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies,
distribution=all_strategies,
mode=["eager", "graph"]))
class TextVectorizationDistributionTest(
keras_parameterized.TestCase,