From 898c3b6d335cc8e50b53e69099d449b07326388e Mon Sep 17 00:00:00 2001
From: Yuefeng Zhou
Date: Wed, 21 Oct 2020 12:42:53 -0700
Subject: [PATCH] Add a KPL test for PSStrategy with precomputed states.

PiperOrigin-RevId: 338319527
Change-Id: I91edcb84994b7d053dcf35c14fa791de26374bcf
---
 tensorflow/python/keras/distribute/BUILD      |   1 +
 .../parameter_server_training_test.py         | 105 ++++++++++--------
 2 files changed, 62 insertions(+), 44 deletions(-)

diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index a89e33e8498..440e4cbf564 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -867,6 +867,7 @@ py_test(
         "//tensorflow/python:variables",
         "//tensorflow/python/compat:v2_compat",
         "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/distribute:combinations",
         "//tensorflow/python/distribute:multi_worker_test_base",
         "//tensorflow/python/distribute:parameter_server_strategy_v2",
         "//tensorflow/python/distribute:sharded_variable",
diff --git a/tensorflow/python/keras/distribute/parameter_server_training_test.py b/tensorflow/python/keras/distribute/parameter_server_training_test.py
index de983b95440..503dd68eb71 100644
--- a/tensorflow/python/keras/distribute/parameter_server_training_test.py
+++ b/tensorflow/python/keras/distribute/parameter_server_training_test.py
@@ -21,10 +21,12 @@ from __future__ import print_function
 import random
 import tempfile
 
+from absl.testing import parameterized
 from tensorflow.python import keras
 from tensorflow.python.compat import v2_compat
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import parameter_server_strategy_v2
 from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
@@ -44,6 +46,14 @@ from tensorflow.python.platform import test
 from tensorflow.python.training.server_lib import ClusterSpec
 
 
+# These vocabularies usually come from TFT or a Beam pipeline.
+FEATURE_VOCAB = [
+    "avenger", "ironman", "batman", "hulk", "spiderman", "kingkong",
+    "wonder_woman"
+]
+LABEL_VOCAB = ["yes", "no"]
+
+
 def make_coordinator(num_workers, num_ps):
   cluster_def = multi_worker_test_base.create_in_process_cluster(
       num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
@@ -56,60 +66,63 @@ def make_coordinator(num_workers, num_ps):
       parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver))
 
 
-class KPLTest(test.TestCase):
+class KPLTest(test.TestCase, parameterized.TestCase):
 
   @classmethod
   def setUpClass(cls):
     super(KPLTest, cls).setUpClass()
     cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
 
-  def testTrainAndServe(self):
-    # These vocabularies usually come from TFT or a Beam pipeline.
-    feature_vocab = [
-        "avenger", "ironman", "batman", "hulk", "spiderman", "kingkong",
-        "wonder_woman"
-    ]
-    label_vocab = ["yes", "no"]
+  def define_kpls_for_training(self, use_adapt):
+    # Define KPLs under the strategy's scope. Right now, if they have lookup
+    # tables, they are created on the client and their variables are placed
+    # on the PS. Ideally they should be cached on each worker, since they do
+    # not change during a training step.
+    if use_adapt:
+      feature_lookup_layer = string_lookup.StringLookup(num_oov_indices=1)
+      feature_lookup_layer.adapt(FEATURE_VOCAB)
+      label_lookup_layer = string_lookup.StringLookup(
+          num_oov_indices=0, mask_token=None)
+      label_lookup_layer.adapt(LABEL_VOCAB)
+    else:
+      feature_lookup_layer = string_lookup.StringLookup(
+          vocabulary=FEATURE_VOCAB, num_oov_indices=1)
+      label_lookup_layer = string_lookup.StringLookup(
+          vocabulary=LABEL_VOCAB, num_oov_indices=0, mask_token=None)
+
+    raw_feature_input = keras.layers.Input(
+        shape=(3,), dtype=dtypes.string, name="feature", ragged=True)
+    feature_id_input = feature_lookup_layer(raw_feature_input)
+
+    # Model creates variables as well.
+    feature_ps = keras.Model({"features": raw_feature_input}, feature_id_input)
+
+    raw_label_input = keras.layers.Input(
+        shape=(), dtype=dtypes.string, name="label")
+    label_id_input = label_lookup_layer(raw_label_input)
+    label_ps = keras.Model({"label": raw_label_input}, label_id_input)
+
+    return feature_ps, label_ps
+
+  def define_reverse_lookup_layer(self):
+    # Only needed for serving.
+    label_inverse_lookup_layer = string_lookup.StringLookup(
+        num_oov_indices=1, mask_token=None, vocabulary=LABEL_VOCAB, invert=True)
+    return label_inverse_lookup_layer
+
+  @combinations.generate(
+      combinations.combine(mode=["eager"], use_adapt=[True, False]))
+  def testTrainAndServe(self, use_adapt):
 
     with self.coordinator.strategy.scope():
-      # Define KPLs under strategy's scope. Right now, if they have look up
-      # tables, they will be created on the coordinator. Their variables will be
-      # created on PS. Ideally they should be cached on each worker since they
-      # will not be changed in a training step.
-      feature_lookup_layer = string_lookup.StringLookup()
-      raw_feature_input = keras.layers.Input(
-          shape=(3,), dtype=dtypes.string, name="feature", ragged=True)
-      feature_id_input = feature_lookup_layer(raw_feature_input)
-
-      # Model creates variables as well.
-      feature_ps = keras.Model({"features": raw_feature_input},
-                               feature_id_input)
-
-      # TODO(yuefengz): adapt may be expensive for large vocab?
-      feature_lookup_layer.adapt(feature_vocab)
-
-      label_lookup_layer = string_lookup.StringLookup(
-          num_oov_indices=0, mask_token=None)
-      raw_label_input = keras.layers.Input(
-          shape=(), dtype=dtypes.string, name="label")
-      label_id_input = label_lookup_layer(raw_label_input)
-      label_ps = keras.Model({"label": raw_label_input}, label_id_input)
-
-      label_lookup_layer.adapt(label_vocab)
-
-      # Only needed for serving.
-      label_inverse_lookup_layer = string_lookup.StringLookup(
-          num_oov_indices=1,
-          mask_token=None,
-          vocabulary=label_lookup_layer.get_vocabulary(),
-          invert=True)
+      feature_ps, label_ps = self.define_kpls_for_training(use_adapt)
 
     def dataset_fn():
 
       def feature_and_label_gen():
         while True:
-          features = random.sample(feature_vocab, 3)
+          features = random.sample(FEATURE_VOCAB, 3)
           label = "yes" if "avenger" in features else "no"
           yield {"features": features, "label": label}
 
@@ -134,23 +147,27 @@ class KPLTest(test.TestCase):
     distributed_dataset = self.coordinator.create_per_worker_dataset(
         dataset_fn)
 
+    # Create the model. The input needs to be compatible with KPLs.
     model_input = keras.layers.Input(
         shape=(3,), dtype=dtypes.int64, name="model_input")
+
+    # input_dim includes a mask token and an oov token.
     emb_output = keras.layers.Embedding(
-        input_dim=len(feature_lookup_layer.get_vocabulary()), output_dim=20)(
+        input_dim=len(FEATURE_VOCAB) + 2, output_dim=20)(
             model_input)
     emb_output = math_ops.reduce_mean(emb_output, axis=1)
     dense_output = keras.layers.Dense(
         units=1, activation="sigmoid")(
             emb_output)
     model = keras.Model({"features": model_input}, dense_output)
+
     optimizer = rmsprop.RMSprop(learning_rate=0.01)
     accuracy = keras.metrics.Accuracy()
 
     @def_function.function
     def worker_fn(iterator):
 
-      def train_step(iterator):
+      def replica_fn(iterator):
         batch_data, labels = next(iterator)
         with backprop.GradientTape() as tape:
           pred = model(batch_data, training=True)
@@ -164,7 +181,7 @@ class KPLTest(test.TestCase):
         actual_pred = math_ops.cast(math_ops.greater(pred, 0.5), dtypes.int64)
         accuracy.update_state(labels, actual_pred)
 
-      self.coordinator._strategy.run(train_step, args=(iterator,))
+      self.coordinator._strategy.run(replica_fn, args=(iterator,))
 
     distributed_iterator = iter(distributed_dataset)
     for _ in range(10):
@@ -175,7 +192,7 @@ class KPLTest(test.TestCase):
     # Create a saved model.
     model.feature_ps = feature_ps
     model.label_ps = label_ps
-    model.label_inverse_lookup_layer = label_inverse_lookup_layer
+    model.label_inverse_lookup_layer = self.define_reverse_lookup_layer()
 
     def create_serving_signature(model):
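
For readers unfamiliar with the two code paths the patch parameterizes over, here is a minimal standalone sketch of what use_adapt=True versus use_adapt=False exercises, and of the inverted serving-side lookup. It is illustrative only and not part of the patch; it assumes the TF 2.4-era tf.keras.layers.experimental.preprocessing.StringLookup API that the test itself uses, and the "thanos" query string is a made-up OOV example.

import tensorflow as tf
from tensorflow.keras.layers.experimental.preprocessing import StringLookup

FEATURE_VOCAB = [
    "avenger", "ironman", "batman", "hulk", "spiderman", "kingkong",
    "wonder_woman"
]
LABEL_VOCAB = ["yes", "no"]

# use_adapt=False: the vocabulary is precomputed (e.g. by TFT or a Beam
# pipeline) and passed to the constructor; no adapt() call is needed.
precomputed = StringLookup(vocabulary=FEATURE_VOCAB, num_oov_indices=1)

# use_adapt=True: the vocabulary is learned from the data via adapt().
adapted = StringLookup(num_oov_indices=1)
adapted.adapt(FEATURE_VOCAB)

# Both paths produce the same table contents: the default mask token plus
# one OOV token in addition to the seven feature terms, which is why the
# Embedding layer in the test uses input_dim=len(FEATURE_VOCAB) + 2.
assert len(precomputed.get_vocabulary()) == len(FEATURE_VOCAB) + 2
assert set(precomputed.get_vocabulary()) == set(adapted.get_vocabulary())

# Out-of-vocabulary strings map to the OOV index instead of failing.
print(precomputed(tf.constant([["avenger", "batman", "thanos"]])))

# Serving side: the inverted lookup from define_reverse_lookup_layer() maps
# predicted label ids back to strings. With one OOV index at id 0, id 1 is
# "yes" and id 2 is "no".
label_inverse = StringLookup(
    num_oov_indices=1, mask_token=None, vocabulary=LABEL_VOCAB, invert=True)
print(label_inverse(tf.constant([1, 2], dtype=tf.int64)))  # [b'yes', b'no']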