After disabling TFRT, the broken tests should pass.
PiperOrigin-RevId: 356842632 Change-Id: I484af6196b1415b3e7a6f84440110f63fed02563
This commit is contained in:
parent
41b3e84fa9
commit
779e443222
@ -692,6 +692,7 @@ cuda_py_test(
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
"no_tfrt", # TODO(b/179839466): Reenable TFRT after the issue is resolved.
|
||||
"no_windows_gpu", # TODO(b/130551176)
|
||||
],
|
||||
deps = [
|
||||
@ -706,6 +707,7 @@ cuda_py_test(
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/keras/engine",
|
||||
"//tensorflow/python/keras/layers:core",
|
||||
"//tensorflow/python/keras/utils:kpl_test_utils",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -19,15 +19,22 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.distribute import combinations as ds_combinations
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_combinations as combinations
|
||||
from tensorflow.python.keras.engine import training as keras_training
|
||||
from tensorflow.python.keras.layers import core as keras_core
|
||||
from tensorflow.python.keras.optimizer_v2 import rmsprop
|
||||
from tensorflow.python.keras.utils import kpl_test_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import gradient_descent
|
||||
@ -84,6 +91,51 @@ class MirroredStrategyDefunTest(test.TestCase):
|
||||
self.assertAllEqual(0.5 * np.ones([10, 1]), updated_var_values[0])
|
||||
self.assertAllEqual([0.5], updated_var_values[1])
|
||||
|
||||
def testTrainAndServeWithKPL(self, distribution):
|
||||
use_adapt = False
|
||||
test_utils_obj = kpl_test_utils.DistributeKplTestUtils()
|
||||
with distribution.scope():
|
||||
feature_mapper, label_mapper = test_utils_obj.define_kpls_for_training(
|
||||
use_adapt)
|
||||
model = test_utils_obj.define_model()
|
||||
optimizer = rmsprop.RMSprop(learning_rate=0.1)
|
||||
accuracy = keras.metrics.Accuracy()
|
||||
|
||||
def dataset_fn(_):
|
||||
return test_utils_obj.dataset_fn(feature_mapper, label_mapper)
|
||||
|
||||
@def_function.function
|
||||
def train_step(iterator):
|
||||
"""The step function for one training step."""
|
||||
|
||||
def step_fn(inputs):
|
||||
"""The computation to run on each TPU device."""
|
||||
features, labels = inputs
|
||||
with backprop.GradientTape() as tape:
|
||||
pred = model(features, training=True)
|
||||
loss = keras.losses.binary_crossentropy(labels, pred)
|
||||
loss = nn.compute_average_loss(loss)
|
||||
grads = tape.gradient(loss, model.trainable_variables)
|
||||
optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
|
||||
|
||||
actual_pred = math_ops.cast(math_ops.greater(pred, 0.5), dtypes.int64)
|
||||
accuracy.update_state(labels, actual_pred)
|
||||
|
||||
distribution.run(step_fn, args=(next(iterator),))
|
||||
|
||||
distributed_dataset = distribution.distribute_datasets_from_function(
|
||||
dataset_fn)
|
||||
distributed_iterator = iter(distributed_dataset)
|
||||
num_epochs = 4
|
||||
num_steps = 7
|
||||
for _ in range(num_epochs):
|
||||
accuracy.reset_states()
|
||||
for _ in range(num_steps):
|
||||
train_step(distributed_iterator)
|
||||
|
||||
self.assertGreater(accuracy.result().numpy(), 0.5)
|
||||
self.assertEqual(optimizer.iterations.numpy(), num_epochs * num_steps)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -40,6 +40,7 @@ py_library(
|
||||
":control_flow_util",
|
||||
":engine_utils",
|
||||
":generic_utils",
|
||||
":kpl_test_utils",
|
||||
":layer_utils",
|
||||
":multi_gpu_utils",
|
||||
":np_utils",
|
||||
@ -54,6 +55,13 @@ py_library(
|
||||
deps = [],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "kpl_test_utils",
|
||||
srcs = ["kpl_test_utils.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "data_utils",
|
||||
srcs = ["data_utils.py"],
|
||||
|
129
tensorflow/python/keras/utils/kpl_test_utils.py
Normal file
129
tensorflow/python/keras/utils/kpl_test_utils.py
Normal file
@ -0,0 +1,129 @@
|
||||
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test related utilities for KPL + tf.distribute."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras.layers.preprocessing import string_lookup
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
|
||||
class DistributeKplTestUtils:
|
||||
"""Utils for test of tf.distribute + KPL."""
|
||||
FEATURE_VOCAB = [
|
||||
"avenger", "ironman", "batman", "hulk", "spiderman", "kingkong",
|
||||
"wonder_woman"
|
||||
]
|
||||
LABEL_VOCAB = ["yes", "no"]
|
||||
|
||||
def define_kpls_for_training(self, use_adapt):
|
||||
"""Function that defines KPL used for unit tests of tf.distribute.
|
||||
|
||||
Args:
|
||||
use_adapt: if adapt will be called. False means there will be precomputed
|
||||
statistics.
|
||||
|
||||
Returns:
|
||||
feature_mapper: a simple keras model with one keras StringLookup layer
|
||||
which maps feature to index.
|
||||
label_mapper: similar to feature_mapper, but maps label to index.
|
||||
|
||||
"""
|
||||
if use_adapt:
|
||||
feature_lookup_layer = (
|
||||
string_lookup.StringLookup(
|
||||
num_oov_indices=1))
|
||||
feature_lookup_layer.adapt(self.FEATURE_VOCAB)
|
||||
label_lookup_layer = (
|
||||
string_lookup.StringLookup(
|
||||
num_oov_indices=0, mask_token=None))
|
||||
label_lookup_layer.adapt(self.LABEL_VOCAB)
|
||||
else:
|
||||
feature_lookup_layer = (
|
||||
string_lookup.StringLookup(
|
||||
vocabulary=self.FEATURE_VOCAB, num_oov_indices=1))
|
||||
label_lookup_layer = (
|
||||
string_lookup.StringLookup(
|
||||
vocabulary=self.LABEL_VOCAB, num_oov_indices=0, mask_token=None))
|
||||
|
||||
raw_feature_input = keras.layers.Input(
|
||||
shape=(3,), dtype=dtypes.string, name="feature", ragged=True)
|
||||
feature_id_input = feature_lookup_layer(raw_feature_input)
|
||||
feature_mapper = keras.Model({"features": raw_feature_input},
|
||||
feature_id_input)
|
||||
|
||||
raw_label_input = keras.layers.Input(
|
||||
shape=(1,), dtype=dtypes.string, name="label")
|
||||
label_id_input = label_lookup_layer(raw_label_input)
|
||||
label_mapper = keras.Model({"label": raw_label_input}, label_id_input)
|
||||
|
||||
return feature_mapper, label_mapper
|
||||
|
||||
def dataset_fn(self, feature_mapper, label_mapper):
|
||||
"""Function that generates dataset for test of tf.distribute + KPL.
|
||||
|
||||
Args:
|
||||
feature_mapper: a simple keras model with one keras StringLookup layer
|
||||
which maps feature to index.
|
||||
label_mapper: similar to feature_mapper, but maps label to index.
|
||||
|
||||
Returns:
|
||||
Generated dataset for test of tf.distribute + KPL.
|
||||
|
||||
"""
|
||||
|
||||
def feature_and_label_gen():
|
||||
# Generator of dataset.
|
||||
while True:
|
||||
features = random.sample(self.FEATURE_VOCAB, 3)
|
||||
label = ["yes"] if self.FEATURE_VOCAB[0] in features else ["no"]
|
||||
yield {"features": features, "label": label}
|
||||
|
||||
raw_dataset = dataset_ops.Dataset.from_generator(
|
||||
feature_and_label_gen,
|
||||
output_signature={
|
||||
"features": tensor_spec.TensorSpec([3], dtypes.string),
|
||||
"label": tensor_spec.TensorSpec([1], dtypes.string)
|
||||
}).shuffle(100).batch(32)
|
||||
|
||||
train_dataset = raw_dataset.map(lambda x: ( # pylint: disable=g-long-lambda
|
||||
{
|
||||
"features": feature_mapper(x["features"])
|
||||
}, label_mapper(x["label"])))
|
||||
return train_dataset
|
||||
|
||||
def define_model(self):
|
||||
"""A simple model for test of tf.distribute + KPL."""
|
||||
# Create the model. The input needs to be compatible with KPLs.
|
||||
model_input = keras.layers.Input(
|
||||
shape=(3,), dtype=dtypes.int64, name="model_input")
|
||||
|
||||
# input_dim includes a mask token and an oov token.
|
||||
emb_output = keras.layers.Embedding(
|
||||
input_dim=len(self.FEATURE_VOCAB) + 2, output_dim=20)(
|
||||
model_input)
|
||||
emb_output = math_ops.reduce_mean(emb_output, axis=1)
|
||||
dense_output = keras.layers.Dense(
|
||||
units=1, activation="sigmoid")(
|
||||
emb_output)
|
||||
model = keras.Model({"features": model_input}, dense_output)
|
||||
return model
|
Loading…
Reference in New Issue
Block a user