Add a test case to cover the training phase when using TPUStrategy. The test case also involves KPL calls.
PiperOrigin-RevId: 355007578 Change-Id: I1717ccb86c0f8cd2c19da0c8f75f2f984cb50a2e
This commit is contained in:
parent
acea35d1f8
commit
2f05920fb7
@ -18,16 +18,26 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
|
||||
from absl import flags
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
preproc_layer = tf.keras.layers.experimental.preprocessing
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
|
||||
flags.DEFINE_string("project", None, "Name of GCP project with TPU.")
|
||||
flags.DEFINE_string("zone", None, "Name of GCP zone with TPU.")
|
||||
|
||||
# These vocabularies usually come from TFT or a Beam pipeline.
|
||||
FEATURE_VOCAB = [
|
||||
"avenger", "ironman", "batman", "hulk", "spiderman", "kingkong",
|
||||
"wonder_woman"
|
||||
]
|
||||
LABEL_VOCAB = ["yes", "no"]
|
||||
|
||||
|
||||
def get_tpu_cluster_resolver():
|
||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
|
||||
@ -47,6 +57,37 @@ def get_tpu_strategy():
|
||||
|
||||
class TpuStrategyTest(tf.test.TestCase):
|
||||
|
||||
def define_kpls_for_training(self, use_adapt):
|
||||
if use_adapt:
|
||||
feature_lookup_layer = (
|
||||
tf.keras.layers.experimental.preprocessing.StringLookup(
|
||||
num_oov_indices=1))
|
||||
feature_lookup_layer.adapt(FEATURE_VOCAB)
|
||||
label_lookup_layer = (
|
||||
tf.keras.layers.experimental.preprocessing.StringLookup(
|
||||
num_oov_indices=0, mask_token=None))
|
||||
label_lookup_layer.adapt(LABEL_VOCAB)
|
||||
else:
|
||||
feature_lookup_layer = (
|
||||
tf.keras.layers.experimental.preprocessing.StringLookup(
|
||||
vocabulary=FEATURE_VOCAB, num_oov_indices=1))
|
||||
label_lookup_layer = (
|
||||
tf.keras.layers.experimental.preprocessing.StringLookup(
|
||||
vocabulary=LABEL_VOCAB, num_oov_indices=0, mask_token=None))
|
||||
|
||||
raw_feature_input = tf.keras.layers.Input(
|
||||
shape=(3,), dtype=tf.dtypes.string, name="feature", ragged=True)
|
||||
feature_id_input = feature_lookup_layer(raw_feature_input)
|
||||
feature_mapper = tf.keras.Model({"features": raw_feature_input},
|
||||
feature_id_input)
|
||||
|
||||
raw_label_input = tf.keras.layers.Input(
|
||||
shape=(1,), dtype=tf.dtypes.string, name="label")
|
||||
label_id_input = label_lookup_layer(raw_label_input)
|
||||
label_mapper = tf.keras.Model({"label": raw_label_input}, label_id_input)
|
||||
|
||||
return feature_mapper, label_mapper
|
||||
|
||||
def test_keras_metric_outside_strategy_scope_per_replica(self):
|
||||
strategy = get_tpu_strategy()
|
||||
metric = tf.keras.metrics.Mean("test_metric", dtype=tf.float32)
|
||||
@ -58,12 +99,93 @@ class TpuStrategyTest(tf.test.TestCase):
|
||||
def step_fn(i):
|
||||
metric.update_state(i)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Trying to run metric.update_state "
|
||||
"in replica context"):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Trying to run metric.update_state "
|
||||
"in replica context"):
|
||||
with strategy.scope():
|
||||
for i in dataset:
|
||||
strategy.run(step_fn, args=(i,))
|
||||
|
||||
def test_train_and_serve(self):
|
||||
strategy = get_tpu_strategy()
|
||||
use_adapt = False
|
||||
|
||||
with strategy.scope():
|
||||
feature_mapper, label_mapper = self.define_kpls_for_training(use_adapt)
|
||||
|
||||
def dataset_fn(_):
|
||||
|
||||
def feature_and_label_gen():
|
||||
# Generator of dataset.
|
||||
while True:
|
||||
features = random.sample(FEATURE_VOCAB, 3)
|
||||
label = ["yes"] if "avenger" in features else ["no"]
|
||||
yield {"features": features, "label": label}
|
||||
|
||||
raw_dataset = tf.data.Dataset.from_generator(
|
||||
feature_and_label_gen,
|
||||
output_signature={
|
||||
"features": tf.TensorSpec([3], tf.dtypes.string),
|
||||
"label": tf.TensorSpec([1], tf.dtypes.string)
|
||||
}).shuffle(100).batch(32)
|
||||
|
||||
train_dataset = raw_dataset.map(lambda x: ( # pylint: disable=g-long-lambda
|
||||
{
|
||||
"features": feature_mapper(x["features"])
|
||||
}, label_mapper(x["label"])))
|
||||
return train_dataset
|
||||
|
||||
# Create the model. The input needs to be compatible with KPLs.
|
||||
model_input = tf.keras.layers.Input(
|
||||
shape=(3,), dtype=tf.dtypes.int64, name="model_input")
|
||||
|
||||
# input_dim includes a mask token and an oov token.
|
||||
emb_output = tf.keras.layers.Embedding(
|
||||
input_dim=len(FEATURE_VOCAB) + 2, output_dim=20)(
|
||||
model_input)
|
||||
emb_output = tf.math.reduce_mean(emb_output, axis=1)
|
||||
dense_output = tf.keras.layers.Dense(
|
||||
units=1, activation="sigmoid")(
|
||||
emb_output)
|
||||
model = tf.keras.Model({"features": model_input}, dense_output)
|
||||
|
||||
optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.1)
|
||||
accuracy = tf.keras.metrics.Accuracy()
|
||||
|
||||
@tf.function
|
||||
def train_step(iterator):
|
||||
"""The step function for one training step."""
|
||||
|
||||
def step_fn(inputs):
|
||||
"""The computation to run on each TPU device."""
|
||||
features, labels = inputs
|
||||
with tf.GradientTape() as tape:
|
||||
pred = model(features, training=True)
|
||||
loss = tf.keras.losses.binary_crossentropy(labels, pred)
|
||||
loss = tf.nn.compute_average_loss(loss)
|
||||
grads = tape.gradient(loss, model.trainable_variables)
|
||||
optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
|
||||
|
||||
actual_pred = tf.cast(tf.math.greater(pred, 0.5), tf.dtypes.int64)
|
||||
accuracy.update_state(labels, actual_pred)
|
||||
|
||||
strategy.run(step_fn, args=(next(iterator),))
|
||||
|
||||
distributed_dataset = strategy.distribute_datasets_from_function(
|
||||
dataset_fn)
|
||||
distributed_iterator = iter(distributed_dataset)
|
||||
num_epochs = 4
|
||||
num_steps = 7
|
||||
for _ in range(num_epochs):
|
||||
accuracy.reset_states()
|
||||
for _ in range(num_steps):
|
||||
train_step(distributed_iterator)
|
||||
|
||||
self.assertGreater(accuracy.result().numpy(), 0.5)
|
||||
self.assertEqual(optimizer.iterations.numpy(), num_epochs * num_steps)
|
||||
|
||||
# TODO(b/178495959): Add tests that cover the serving phase.
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user