Update the tutorial to simplify dataset creation.

PiperOrigin-RevId: 341914190
Change-Id: I3a8e7cd62d068fae9924f47d8a5115a9d73109f5

commit cc4ca559ef
parent d0dbdb763a
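
In short, the input pipeline now builds the dataset in one step: `Dataset.from_generator` gets a full `output_signature` (dtype and shape together), so the old `output_types` dict and the extra preprocessing/reshaping maps go away. Below is a minimal standalone sketch of the new pattern, written against the public `tf.*` API rather than the internal `dataset_ops`/`tensor_spec` aliases used in the test; the `FEATURE_VOCAB` values are placeholders, not taken from this commit.

    import random

    import tensorflow as tf

    # Placeholder vocabulary; the test defines its own FEATURE_VOCAB.
    FEATURE_VOCAB = ["avenger", "ironman", "batman", "hulk", "spiderman"]

    def feature_and_label_gen():
      while True:
        features = random.sample(FEATURE_VOCAB, 3)
        label = ["yes"] if "avenger" in features else ["no"]
        yield {"features": features, "label": label}

    # output_signature declares dtypes and shapes in one place, so no
    # separate output_types dict or post-hoc label reshaping is needed.
    raw_dataset = tf.data.Dataset.from_generator(
        feature_and_label_gen,
        output_signature={
            "features": tf.TensorSpec([3], tf.string),
            "label": tf.TensorSpec([1], tf.string),
        }).shuffle(100).batch(32)
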
@@ -66,6 +66,7 @@ def make_coordinator(num_workers, num_ps):
       parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver))


+# TODO(yuefengz): move this to keras/integration_tests.
 class KPLTest(test.TestCase, parameterized.TestCase):

   @classmethod
@@ -98,7 +99,7 @@ class KPLTest(test.TestCase, parameterized.TestCase):
     feature_ps = keras.Model({"features": raw_feature_input}, feature_id_input)

     raw_label_input = keras.layers.Input(
-        shape=(), dtype=dtypes.string, name="label")
+        shape=(1,), dtype=dtypes.string, name="label")
     label_id_input = label_lookup_layer(raw_label_input)
     label_ps = keras.Model({"label": raw_label_input}, label_id_input)

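The `shape=(1,)` input matches the labels now being yielded as single-element lists, so the lookup layer sees a `[batch, 1]` string tensor. A sketch of how such a label pipeline can be wired up with a Keras preprocessing layer; the `StringLookup` construction here is an assumption, since the layer the test actually uses is defined outside this hunk:

    import tensorflow as tf

    # Assumed lookup layer; the test builds its own outside this hunk.
    label_lookup_layer = tf.keras.layers.experimental.preprocessing.StringLookup(
        num_oov_indices=0, mask_token=None, vocabulary=["yes", "no"])

    # shape=(1,) matches labels yielded as ["yes"] / ["no"].
    raw_label_input = tf.keras.layers.Input(
        shape=(1,), dtype=tf.string, name="label")
    label_id_input = label_lookup_layer(raw_label_input)
    label_ps = tf.keras.Model({"label": raw_label_input}, label_id_input)
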
@@ -123,29 +124,22 @@ class KPLTest(test.TestCase, parameterized.TestCase):
       def feature_and_label_gen():
         while True:
           features = random.sample(FEATURE_VOCAB, 3)
-          label = "yes" if "avenger" in features else "no"
+          label = ["yes"] if "avenger" in features else ["no"]
           yield {"features": features, "label": label}

-      # The dataset will be created on the coordinator?
+      # The dataset will be created on the coordinator.
       raw_dataset = dataset_ops.Dataset.from_generator(
           feature_and_label_gen,
-          output_types={
-              "features": dtypes.string,
-              "label": dtypes.string
-          }).shuffle(200).batch(32)
-      preproc_dataset = raw_dataset.map(
-          lambda x: {  # pylint: disable=g-long-lambda
-              "features": feature_ps(x["features"]),
-              "label": label_ps(x["label"])
-          })
-      train_dataset = preproc_dataset.map(lambda x: (  # pylint: disable=g-long-lambda
-          {
-              "features": x["features"]
-          }, [x["label"]]))
+          output_signature={
+              "features": tensor_spec.TensorSpec([3], dtypes.string),
+              "label": tensor_spec.TensorSpec([1], dtypes.string)
+          }).shuffle(100).batch(32)
+      train_dataset = raw_dataset.map(lambda x: (  # pylint: disable=g-long-lambda
+          {
+              "features": feature_ps(x["features"])
+          }, label_ps(x["label"])))
       return train_dataset

-    distributed_dataset = self.coordinator.create_per_worker_dataset(
-        dataset_fn)

     # Create the model. The input needs to be compatible with KPLs.
     model_input = keras.layers.Input(
@@ -161,33 +155,36 @@ class KPLTest(test.TestCase, parameterized.TestCase):
         emb_output)
     model = keras.Model({"features": model_input}, dense_output)

-    optimizer = rmsprop.RMSprop(learning_rate=0.01)
+    optimizer = rmsprop.RMSprop(learning_rate=0.1)
     accuracy = keras.metrics.Accuracy()

     @def_function.function
     def worker_fn(iterator):

       def replica_fn(iterator):
         batch_data, labels = next(iterator)
         with backprop.GradientTape() as tape:
           pred = model(batch_data, training=True)
           loss = nn.compute_average_loss(
               keras.losses.BinaryCrossentropy(
                   reduction=loss_reduction.ReductionV2.NONE)(labels, pred))
           gradients = tape.gradient(loss, model.trainable_variables)

         optimizer.apply_gradients(zip(gradients, model.trainable_variables))

         actual_pred = math_ops.cast(math_ops.greater(pred, 0.5), dtypes.int64)
         accuracy.update_state(labels, actual_pred)

       self.coordinator._strategy.run(replica_fn, args=(iterator,))

+    distributed_dataset = self.coordinator.create_per_worker_dataset(dataset_fn)
     distributed_iterator = iter(distributed_dataset)
-    for _ in range(10):
-      self.coordinator.schedule(worker_fn, args=(distributed_iterator,))
-    self.coordinator.join()
-    self.assertGreater(accuracy.result().numpy(), 0.0)
+    for _ in range(4):
+      accuracy.reset_states()
+      for _ in range(7):
+        self.coordinator.schedule(worker_fn, args=(distributed_iterator,))
+      self.coordinator.join()
+    self.assertGreater(accuracy.result().numpy(), 0.5)

     # Create a saved model.
     model.feature_ps = feature_ps
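
The training loop now creates the per-worker dataset after `worker_fn` is defined and runs explicit epochs, resetting the metric each time so the final assertion checks only the last epoch. A sketch of that pattern with the public `ClusterCoordinator` API, assuming `strategy` is a parameter server strategy backed by a real cluster and that `dataset_fn`, `worker_fn`, and `accuracy` are set up as in the diff:

    import tensorflow as tf

    # Sketch only: `strategy`, `dataset_fn`, `worker_fn`, and `accuracy`
    # are assumed to exist, configured as in the test above.
    coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
        strategy)

    distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
    distributed_iterator = iter(distributed_dataset)

    for epoch in range(4):
      accuracy.reset_states()  # fresh metric for each epoch
      for _ in range(7):
        # schedule() is asynchronous: it queues worker_fn on some worker.
        coordinator.schedule(worker_fn, args=(distributed_iterator,))
      coordinator.join()  # block until all steps scheduled this epoch finish
      print("epoch", epoch, "accuracy:", accuracy.result().numpy())
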