Add a KPL test for PSStrategy with precomputed states.

PiperOrigin-RevId: 338319527
Change-Id: I91edcb84994b7d053dcf35c14fa791de26374bcf
This commit is contained in:
Yuefeng Zhou 2020-10-21 12:42:53 -07:00 committed by TensorFlower Gardener
parent dbf191bb17
commit 898c3b6d33
2 changed files with 62 additions and 44 deletions

View File

@ -867,6 +867,7 @@ py_test(
"//tensorflow/python:variables",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:multi_worker_test_base",
"//tensorflow/python/distribute:parameter_server_strategy_v2",
"//tensorflow/python/distribute:sharded_variable",

View File

@ -21,10 +21,12 @@ from __future__ import print_function
import random
import tempfile
from absl.testing import parameterized
from tensorflow.python import keras
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
@ -44,6 +46,14 @@ from tensorflow.python.platform import test
from tensorflow.python.training.server_lib import ClusterSpec
# These vocabularies usually come from TFT or a Beam pipeline.
FEATURE_VOCAB = [
"avenger", "ironman", "batman", "hulk", "spiderman", "kingkong",
"wonder_woman"
]
LABEL_VOCAB = ["yes", "no"]
def make_coordinator(num_workers, num_ps):
cluster_def = multi_worker_test_base.create_in_process_cluster(
num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
@ -56,60 +66,63 @@ def make_coordinator(num_workers, num_ps):
parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver))
class KPLTest(test.TestCase):
class KPLTest(test.TestCase, parameterized.TestCase):
@classmethod
def setUpClass(cls):
super(KPLTest, cls).setUpClass()
cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
def testTrainAndServe(self):
# These vocabularies usually come from TFT or a Beam pipeline.
feature_vocab = [
"avenger", "ironman", "batman", "hulk", "spiderman", "kingkong",
"wonder_woman"
]
label_vocab = ["yes", "no"]
def define_kpls_for_training(self, use_adapt):
# Define KPLs under strategy's scope. Right now, if they have look up
# tables, they will be created on the client. Their variables will be
# created on PS. Ideally they should be cached on each worker since they
# will not be changed in a training step.
if use_adapt:
feature_lookup_layer = string_lookup.StringLookup(num_oov_indices=1)
feature_lookup_layer.adapt(FEATURE_VOCAB)
label_lookup_layer = string_lookup.StringLookup(
num_oov_indices=0, mask_token=None)
label_lookup_layer.adapt(LABEL_VOCAB)
else:
feature_lookup_layer = string_lookup.StringLookup(
vocabulary=FEATURE_VOCAB, num_oov_indices=1)
label_lookup_layer = string_lookup.StringLookup(
vocabulary=LABEL_VOCAB, num_oov_indices=0, mask_token=None)
raw_feature_input = keras.layers.Input(
shape=(3,), dtype=dtypes.string, name="feature", ragged=True)
feature_id_input = feature_lookup_layer(raw_feature_input)
# Model creates variables as well.
feature_ps = keras.Model({"features": raw_feature_input}, feature_id_input)
raw_label_input = keras.layers.Input(
shape=(), dtype=dtypes.string, name="label")
label_id_input = label_lookup_layer(raw_label_input)
label_ps = keras.Model({"label": raw_label_input}, label_id_input)
return feature_ps, label_ps
def define_reverse_lookup_layer(self):
# Only needed for serving.
label_inverse_lookup_layer = string_lookup.StringLookup(
num_oov_indices=1, mask_token=None, vocabulary=LABEL_VOCAB, invert=True)
return label_inverse_lookup_layer
@combinations.generate(
combinations.combine(mode=["eager"], use_adapt=[True, False]))
def testTrainAndServe(self, use_adapt):
with self.coordinator.strategy.scope():
# Define KPLs under strategy's scope. Right now, if they have look up
# tables, they will be created on the coordinator. Their variables will be
# created on PS. Ideally they should be cached on each worker since they
# will not be changed in a training step.
feature_lookup_layer = string_lookup.StringLookup()
raw_feature_input = keras.layers.Input(
shape=(3,), dtype=dtypes.string, name="feature", ragged=True)
feature_id_input = feature_lookup_layer(raw_feature_input)
# Model creates variables as well.
feature_ps = keras.Model({"features": raw_feature_input},
feature_id_input)
# TODO(yuefengz): adapt may be expensive for large vocab?
feature_lookup_layer.adapt(feature_vocab)
label_lookup_layer = string_lookup.StringLookup(
num_oov_indices=0, mask_token=None)
raw_label_input = keras.layers.Input(
shape=(), dtype=dtypes.string, name="label")
label_id_input = label_lookup_layer(raw_label_input)
label_ps = keras.Model({"label": raw_label_input}, label_id_input)
label_lookup_layer.adapt(label_vocab)
# Only needed for serving.
label_inverse_lookup_layer = string_lookup.StringLookup(
num_oov_indices=1,
mask_token=None,
vocabulary=label_lookup_layer.get_vocabulary(),
invert=True)
feature_ps, label_ps = self.define_kpls_for_training(use_adapt)
def dataset_fn():
def feature_and_label_gen():
while True:
features = random.sample(feature_vocab, 3)
features = random.sample(FEATURE_VOCAB, 3)
label = "yes" if "avenger" in features else "no"
yield {"features": features, "label": label}
@ -134,23 +147,27 @@ class KPLTest(test.TestCase):
distributed_dataset = self.coordinator.create_per_worker_dataset(
dataset_fn)
# Create the model. The input needs to be compatible with KPLs.
model_input = keras.layers.Input(
shape=(3,), dtype=dtypes.int64, name="model_input")
# input_dim includes a mask token and an oov token.
emb_output = keras.layers.Embedding(
input_dim=len(feature_lookup_layer.get_vocabulary()), output_dim=20)(
input_dim=len(FEATURE_VOCAB) + 2, output_dim=20)(
model_input)
emb_output = math_ops.reduce_mean(emb_output, axis=1)
dense_output = keras.layers.Dense(
units=1, activation="sigmoid")(
emb_output)
model = keras.Model({"features": model_input}, dense_output)
optimizer = rmsprop.RMSprop(learning_rate=0.01)
accuracy = keras.metrics.Accuracy()
@def_function.function
def worker_fn(iterator):
def train_step(iterator):
def replica_fn(iterator):
batch_data, labels = next(iterator)
with backprop.GradientTape() as tape:
pred = model(batch_data, training=True)
@ -164,7 +181,7 @@ class KPLTest(test.TestCase):
actual_pred = math_ops.cast(math_ops.greater(pred, 0.5), dtypes.int64)
accuracy.update_state(labels, actual_pred)
self.coordinator._strategy.run(train_step, args=(iterator,))
self.coordinator._strategy.run(replica_fn, args=(iterator,))
distributed_iterator = iter(distributed_dataset)
for _ in range(10):
@ -175,7 +192,7 @@ class KPLTest(test.TestCase):
# Create a saved model.
model.feature_ps = feature_ps
model.label_ps = label_ps
model.label_inverse_lookup_layer = label_inverse_lookup_layer
model.label_inverse_lookup_layer = self.define_reverse_lookup_layer()
def create_serving_signature(model):