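Add a Keras preprocessing layer (KPL) test for ParameterServerStrategy (PSStrategy) with precomputed states, i.e. vocabularies passed to the lookup layers at construction instead of being learned via adapt().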
PiperOrigin-RevId: 338319527
Change-Id: I91edcb84994b7d053dcf35c14fa791de26374bcf
This commit is contained in:
parent dbf191bb17
commit 898c3b6d33
@ -867,6 +867,7 @@ py_test(
|
|||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
"//tensorflow/python/compat:v2_compat",
|
"//tensorflow/python/compat:v2_compat",
|
||||||
"//tensorflow/python/data/ops:dataset_ops",
|
"//tensorflow/python/data/ops:dataset_ops",
|
||||||
|
"//tensorflow/python/distribute:combinations",
|
||||||
"//tensorflow/python/distribute:multi_worker_test_base",
|
"//tensorflow/python/distribute:multi_worker_test_base",
|
||||||
"//tensorflow/python/distribute:parameter_server_strategy_v2",
|
"//tensorflow/python/distribute:parameter_server_strategy_v2",
|
||||||
"//tensorflow/python/distribute:sharded_variable",
|
"//tensorflow/python/distribute:sharded_variable",
|
||||||
|
@ -21,10 +21,12 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
from tensorflow.python.compat import v2_compat
|
from tensorflow.python.compat import v2_compat
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.distribute import combinations
|
||||||
from tensorflow.python.distribute import multi_worker_test_base
|
from tensorflow.python.distribute import multi_worker_test_base
|
||||||
from tensorflow.python.distribute import parameter_server_strategy_v2
|
from tensorflow.python.distribute import parameter_server_strategy_v2
|
||||||
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
|
||||||
@ -44,6 +46,14 @@ from tensorflow.python.platform import test
|
|||||||
from tensorflow.python.training.server_lib import ClusterSpec
|
from tensorflow.python.training.server_lib import ClusterSpec
|
||||||
|
|
||||||
|
|
||||||
|
# These vocabularies usually come from TFT or a Beam pipeline.
|
||||||
|
FEATURE_VOCAB = [
|
||||||
|
"avenger", "ironman", "batman", "hulk", "spiderman", "kingkong",
|
||||||
|
"wonder_woman"
|
||||||
|
]
|
||||||
|
LABEL_VOCAB = ["yes", "no"]
|
||||||
|
|
||||||
|
|
||||||
def make_coordinator(num_workers, num_ps):
|
def make_coordinator(num_workers, num_ps):
|
||||||
cluster_def = multi_worker_test_base.create_in_process_cluster(
|
cluster_def = multi_worker_test_base.create_in_process_cluster(
|
||||||
num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
|
num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
|
||||||
@ -56,60 +66,63 @@ def make_coordinator(num_workers, num_ps):
|
|||||||
parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver))
|
parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver))
|
||||||
|
|
||||||
|
|
||||||
class KPLTest(test.TestCase):
|
class KPLTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
super(KPLTest, cls).setUpClass()
|
super(KPLTest, cls).setUpClass()
|
||||||
cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
|
cls.coordinator = make_coordinator(num_workers=3, num_ps=2)
|
||||||
|
|
||||||
def testTrainAndServe(self):
|
def define_kpls_for_training(self, use_adapt):
|
||||||
# These vocabularies usually come from TFT or a Beam pipeline.
|
|
||||||
feature_vocab = [
|
|
||||||
"avenger", "ironman", "batman", "hulk", "spiderman", "kingkong",
|
|
||||||
"wonder_woman"
|
|
||||||
]
|
|
||||||
label_vocab = ["yes", "no"]
|
|
||||||
|
|
||||||
with self.coordinator.strategy.scope():
|
|
||||||
|
|
||||||
# Define KPLs under strategy's scope. Right now, if they have look up
|
# Define KPLs under strategy's scope. Right now, if they have look up
|
||||||
# tables, they will be created on the coordinator. Their variables will be
|
# tables, they will be created on the client. Their variables will be
|
||||||
# created on PS. Ideally they should be cached on each worker since they
|
# created on PS. Ideally they should be cached on each worker since they
|
||||||
# will not be changed in a training step.
|
# will not be changed in a training step.
|
||||||
feature_lookup_layer = string_lookup.StringLookup()
|
if use_adapt:
|
||||||
|
feature_lookup_layer = string_lookup.StringLookup(num_oov_indices=1)
|
||||||
|
feature_lookup_layer.adapt(FEATURE_VOCAB)
|
||||||
|
label_lookup_layer = string_lookup.StringLookup(
|
||||||
|
num_oov_indices=0, mask_token=None)
|
||||||
|
label_lookup_layer.adapt(LABEL_VOCAB)
|
||||||
|
else:
|
||||||
|
feature_lookup_layer = string_lookup.StringLookup(
|
||||||
|
vocabulary=FEATURE_VOCAB, num_oov_indices=1)
|
||||||
|
label_lookup_layer = string_lookup.StringLookup(
|
||||||
|
vocabulary=LABEL_VOCAB, num_oov_indices=0, mask_token=None)
|
||||||
|
|
||||||
raw_feature_input = keras.layers.Input(
|
raw_feature_input = keras.layers.Input(
|
||||||
shape=(3,), dtype=dtypes.string, name="feature", ragged=True)
|
shape=(3,), dtype=dtypes.string, name="feature", ragged=True)
|
||||||
feature_id_input = feature_lookup_layer(raw_feature_input)
|
feature_id_input = feature_lookup_layer(raw_feature_input)
|
||||||
|
|
||||||
# Model creates variables as well.
|
# Model creates variables as well.
|
||||||
feature_ps = keras.Model({"features": raw_feature_input},
|
feature_ps = keras.Model({"features": raw_feature_input}, feature_id_input)
|
||||||
feature_id_input)
|
|
||||||
|
|
||||||
# TODO(yuefengz): adapt may be expensive for large vocab?
|
|
||||||
feature_lookup_layer.adapt(feature_vocab)
|
|
||||||
|
|
||||||
label_lookup_layer = string_lookup.StringLookup(
|
|
||||||
num_oov_indices=0, mask_token=None)
|
|
||||||
raw_label_input = keras.layers.Input(
|
raw_label_input = keras.layers.Input(
|
||||||
shape=(), dtype=dtypes.string, name="label")
|
shape=(), dtype=dtypes.string, name="label")
|
||||||
label_id_input = label_lookup_layer(raw_label_input)
|
label_id_input = label_lookup_layer(raw_label_input)
|
||||||
label_ps = keras.Model({"label": raw_label_input}, label_id_input)
|
label_ps = keras.Model({"label": raw_label_input}, label_id_input)
|
||||||
|
|
||||||
label_lookup_layer.adapt(label_vocab)
|
return feature_ps, label_ps
|
||||||
|
|
||||||
|
def define_reverse_lookup_layer(self):
|
||||||
# Only needed for serving.
|
# Only needed for serving.
|
||||||
label_inverse_lookup_layer = string_lookup.StringLookup(
|
label_inverse_lookup_layer = string_lookup.StringLookup(
|
||||||
num_oov_indices=1,
|
num_oov_indices=1, mask_token=None, vocabulary=LABEL_VOCAB, invert=True)
|
||||||
mask_token=None,
|
return label_inverse_lookup_layer
|
||||||
vocabulary=label_lookup_layer.get_vocabulary(),
|
|
||||||
invert=True)
|
@combinations.generate(
|
||||||
|
combinations.combine(mode=["eager"], use_adapt=[True, False]))
|
||||||
|
def testTrainAndServe(self, use_adapt):
|
||||||
|
|
||||||
|
with self.coordinator.strategy.scope():
|
||||||
|
|
||||||
|
feature_ps, label_ps = self.define_kpls_for_training(use_adapt)
|
||||||
|
|
||||||
def dataset_fn():
|
def dataset_fn():
|
||||||
|
|
||||||
def feature_and_label_gen():
|
def feature_and_label_gen():
|
||||||
while True:
|
while True:
|
||||||
features = random.sample(feature_vocab, 3)
|
features = random.sample(FEATURE_VOCAB, 3)
|
||||||
label = "yes" if "avenger" in features else "no"
|
label = "yes" if "avenger" in features else "no"
|
||||||
yield {"features": features, "label": label}
|
yield {"features": features, "label": label}
|
||||||
|
|
||||||
@ -134,23 +147,27 @@ class KPLTest(test.TestCase):
|
|||||||
distributed_dataset = self.coordinator.create_per_worker_dataset(
|
distributed_dataset = self.coordinator.create_per_worker_dataset(
|
||||||
dataset_fn)
|
dataset_fn)
|
||||||
|
|
||||||
|
# Create the model. The input needs to be compatible with KPLs.
|
||||||
model_input = keras.layers.Input(
|
model_input = keras.layers.Input(
|
||||||
shape=(3,), dtype=dtypes.int64, name="model_input")
|
shape=(3,), dtype=dtypes.int64, name="model_input")
|
||||||
|
|
||||||
|
# input_dim includes a mask token and an oov token.
|
||||||
emb_output = keras.layers.Embedding(
|
emb_output = keras.layers.Embedding(
|
||||||
input_dim=len(feature_lookup_layer.get_vocabulary()), output_dim=20)(
|
input_dim=len(FEATURE_VOCAB) + 2, output_dim=20)(
|
||||||
model_input)
|
model_input)
|
||||||
emb_output = math_ops.reduce_mean(emb_output, axis=1)
|
emb_output = math_ops.reduce_mean(emb_output, axis=1)
|
||||||
dense_output = keras.layers.Dense(
|
dense_output = keras.layers.Dense(
|
||||||
units=1, activation="sigmoid")(
|
units=1, activation="sigmoid")(
|
||||||
emb_output)
|
emb_output)
|
||||||
model = keras.Model({"features": model_input}, dense_output)
|
model = keras.Model({"features": model_input}, dense_output)
|
||||||
|
|
||||||
optimizer = rmsprop.RMSprop(learning_rate=0.01)
|
optimizer = rmsprop.RMSprop(learning_rate=0.01)
|
||||||
accuracy = keras.metrics.Accuracy()
|
accuracy = keras.metrics.Accuracy()
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
def worker_fn(iterator):
|
def worker_fn(iterator):
|
||||||
|
|
||||||
def train_step(iterator):
|
def replica_fn(iterator):
|
||||||
batch_data, labels = next(iterator)
|
batch_data, labels = next(iterator)
|
||||||
with backprop.GradientTape() as tape:
|
with backprop.GradientTape() as tape:
|
||||||
pred = model(batch_data, training=True)
|
pred = model(batch_data, training=True)
|
||||||
@ -164,7 +181,7 @@ class KPLTest(test.TestCase):
|
|||||||
actual_pred = math_ops.cast(math_ops.greater(pred, 0.5), dtypes.int64)
|
actual_pred = math_ops.cast(math_ops.greater(pred, 0.5), dtypes.int64)
|
||||||
accuracy.update_state(labels, actual_pred)
|
accuracy.update_state(labels, actual_pred)
|
||||||
|
|
||||||
self.coordinator._strategy.run(train_step, args=(iterator,))
|
self.coordinator._strategy.run(replica_fn, args=(iterator,))
|
||||||
|
|
||||||
distributed_iterator = iter(distributed_dataset)
|
distributed_iterator = iter(distributed_dataset)
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
@ -175,7 +192,7 @@ class KPLTest(test.TestCase):
|
|||||||
# Create a saved model.
|
# Create a saved model.
|
||||||
model.feature_ps = feature_ps
|
model.feature_ps = feature_ps
|
||||||
model.label_ps = label_ps
|
model.label_ps = label_ps
|
||||||
model.label_inverse_lookup_layer = label_inverse_lookup_layer
|
model.label_inverse_lookup_layer = self.define_reverse_lookup_layer()
|
||||||
|
|
||||||
def create_serving_signature(model):
|
def create_serving_signature(model):
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user