Split distribute/custom_training_loop_test into three parts as it is timing out on our kokoro TPU continuous tests.
PiperOrigin-RevId: 294765478
Change-Id: I574ca0433ade67673e1b5ea731db94e40e28ae5f
This commit is contained in:
parent bceb1d7854
commit a7f1d52b03
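For reference, a minimal sketch of how the resulting split targets could be run individually; this assumes a standard `bazel test` invocation from the TensorFlow repository root and the target labels implied by the BUILD changes below (the exact flags used by the kokoro TPU continuous jobs are not part of this commit):

bazel test //tensorflow/python/distribute:custom_training_loop_gradient_test
bazel test //tensorflow/python/distribute:custom_training_loop_input_test
bazel test //tensorflow/python/distribute:custom_training_loop_metrics_test
bazel test //tensorflow/python/distribute:custom_training_loop_models_test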
tensorflow/python/distribute/BUILD
@@ -925,9 +925,66 @@ distribute_py_test(
)

distribute_py_test(
    name = "custom_training_loop_test",
    srcs = ["custom_training_loop_test.py"],
    main = "custom_training_loop_test.py",
    name = "custom_training_loop_gradient_test",
    srcs = ["custom_training_loop_gradient_test.py"],
    main = "custom_training_loop_gradient_test.py",
    tags = [
        "multi_and_single_gpu",
    ],
    deps = [
        "//tensorflow/python:errors",
        "//tensorflow/python:variables",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/distribute:combinations",
        "//tensorflow/python/distribute:strategy_combinations",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/keras",
        "@absl_py//absl/testing:parameterized",
    ],
)

distribute_py_test(
    name = "custom_training_loop_input_test",
    srcs = ["custom_training_loop_input_test.py"],
    main = "custom_training_loop_input_test.py",
    tags = [
        "multi_and_single_gpu",
    ],
    deps = [
        "//tensorflow/python:errors",
        "//tensorflow/python:variables",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/distribute:combinations",
        "//tensorflow/python/distribute:strategy_combinations",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/keras",
        "@absl_py//absl/testing:parameterized",
    ],
)

distribute_py_test(
    name = "custom_training_loop_metrics_test",
    srcs = ["custom_training_loop_metrics_test.py"],
    main = "custom_training_loop_metrics_test.py",
    tags = [
        "multi_and_single_gpu",
    ],
    deps = [
        "//tensorflow/python:errors",
        "//tensorflow/python:variables",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/distribute:combinations",
        "//tensorflow/python/distribute:strategy_combinations",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/keras",
        "@absl_py//absl/testing:parameterized",
    ],
)

distribute_py_test(
    name = "custom_training_loop_models_test",
    srcs = ["custom_training_loop_models_test.py"],
    main = "custom_training_loop_models_test.py",
    tags = [
        "multi_and_single_gpu",
    ],
152 tensorflow/python/distribute/custom_training_loop_gradient_test.py Normal file
@@ -0,0 +1,152 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for custom training loops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables


def get_dataset_from_tensor_slices(inp_array):
  dataset = dataset_ops.DatasetV2.from_tensor_slices(inp_array)
  # TODO(b/138326910): Remove Dataset V1 version once bug resolved.
  if not tf2.enabled():
    dataset = dataset_ops.Dataset.from_tensor_slices(inp_array)
  return dataset


class AssertFlattenedMixin(object):
  """Mixin for specialized asserts."""

  def assert_equal_flattened(self, expected_results, actual_results):
    """Asserts that flattened results are equal.

    Due to the number of replicas in the strategy, the output may have a
    different structure and needs to be flattened for comparison.

    Args:
      expected_results: The results expected as a result of a computation.
      actual_results: The actual results of a computation.
    """
    self.assertEqual(len(expected_results), len(actual_results))

    for i, expected_result in enumerate(expected_results):
      final_result = []
      actual_result = actual_results[i]
      for val in actual_result:
        final_result.extend(val.numpy())
      self.assertAllEqual(expected_result, final_result)


class GradientTapeTest(test.TestCase, parameterized.TestCase,
                       AssertFlattenedMixin):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testStepInFunctionGradient(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    @def_function.function
    def train_step(x):
      def computation(x):
        return math_ops.square(x)
      with backprop.GradientTape() as tape:
        tape.watch(x)  # Manually watch non-variable tensors.
        y = computation(x)
      grads = tape.gradient(y, x)
      return grads

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(x,)))
      results.append(output)
    self.assert_equal_flattened([[10., 12.], [14., 16.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunctionGradient(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    @def_function.function
    def run(x):
      def train_step(x):
        def computation(x):
          return math_ops.square(x)
        with backprop.GradientTape() as tape:
          tape.watch(x)  # Manually watch non-variable tensors.
          y = computation(x)
        grads = tape.gradient(y, x)
        return grads
      return distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(x,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = run(x)
      results.append(output)
    self.assert_equal_flattened([[10., 12.], [14., 16.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"],
          model_in_tf_function=[True, False]
      ))
  def testNestedFunction(self, distribution, model_in_tf_function):
    def model(x):
      return x * x

    if model_in_tf_function:
      model = def_function.function(model)

    with distribution.scope():
      x = variables.Variable(1.0)

    @def_function.function
    def train_step():
      def replica_step():
        with backprop.GradientTape() as tape:
          y = model(x)
        return tape.gradient(y, x)
      return distribution.experimental_run_v2(replica_step)

    grads = distribution.experimental_local_results(train_step())
    self.assertLen(grads, distribution.num_replicas_in_sync)
    self.assertTrue(all(g is not None for g in grads))


if __name__ == "__main__":
  test.main()
tensorflow/python/distribute/custom_training_loop_input_test.py
@@ -18,18 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
@@ -596,447 +591,5 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)


class GradientTapeTest(test.TestCase, parameterized.TestCase,
                       AssertFlattenedMixin):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testStepInFunctionGradient(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    @def_function.function
    def train_step(x):
      def computation(x):
        return math_ops.square(x)
      with backprop.GradientTape() as tape:
        tape.watch(x)  # Manually watch non-variable tensors.
        y = computation(x)
      grads = tape.gradient(y, x)
      return grads

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(x,)))
      results.append(output)
    self.assert_equal_flattened([[10., 12.], [14., 16.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunctionGradient(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    @def_function.function
    def run(x):
      def train_step(x):
        def computation(x):
          return math_ops.square(x)
        with backprop.GradientTape() as tape:
          tape.watch(x)  # Manually watch non-variable tensors.
          y = computation(x)
        grads = tape.gradient(y, x)
        return grads
      return distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(x,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = run(x)
      results.append(output)
    self.assert_equal_flattened([[10., 12.], [14., 16.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"],
          model_in_tf_function=[True, False]
      ))
  def testNestedFunction(self, distribution, model_in_tf_function):
    def model(x):
      return x * x

    if model_in_tf_function:
      model = def_function.function(model)

    with distribution.scope():
      x = variables.Variable(1.0)

    @def_function.function
    def train_step():
      def replica_step():
        with backprop.GradientTape() as tape:
          y = model(x)
        return tape.gradient(y, x)
      return distribution.experimental_run_v2(replica_step)

    grads = distribution.experimental_local_results(train_step())
    self.assertLen(grads, distribution.num_replicas_in_sync)
    self.assertTrue(all(g is not None for g in grads))


class KerasModelsTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_single_keras_layer_experimental_run(self, distribution):
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = keras.layers.Dense(4, name="dense")

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        return grads

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_model_creation_experimental_run(self, distribution):
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = self._get_model()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        return grads

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_model_optimizer_experimental_run(self, distribution):
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = self._get_model()
      optimizer = keras.optimizer_v2.rmsprop.RMSprop()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_subclass_model_optimizer_experimental_run(self, distribution):
    def get_subclass_model():

      class KerasSubclassModel(keras.Model):

        def __init__(self):
          super(KerasSubclassModel, self).__init__()
          self.l = keras.layers.Dense(4, name="dense")

        def call(self, x):
          return self.l(x)

      return KerasSubclassModel()
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = get_subclass_model()
      optimizer = keras.optimizer_v2.rmsprop.RMSprop()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_model_optimizer_experimental_run_loop(self, distribution):
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = self._get_model()
      optimizer = keras.optimizer_v2.rmsprop.RMSprop()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      for _ in range(5):
        distribution.experimental_run_v2(step_fn, args=(next(iterator),))

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_lstm(self, distribution):

    batch_size = 32

    def create_lstm_model():
      model = keras.models.Sequential()
      # We only have LSTM variables so we can detect no gradient issues more
      # easily.
      model.add(
          keras.layers.LSTM(1, return_sequences=False, input_shape=(10, 1)))
      return model

    def create_lstm_data():
      seq_length = 10

      x_train = np.random.rand(batch_size, seq_length, 1).astype("float32")
      y_train = np.random.rand(batch_size, 1).astype("float32")
      return x_train, y_train

    x, y = create_lstm_data()
    dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
    dataset = dataset.batch(batch_size, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = create_lstm_model()
      optimizer = keras.optimizer_v2.gradient_descent.SGD()

    @def_function.function
    def train_step(input_iterator):

      def step_fn(inputs):
        inps, targ = inputs
        with backprop.GradientTape() as tape:
          output = model(inps)
          loss = math_ops.reduce_mean(
              keras.losses.binary_crossentropy(
                  y_true=targ, y_pred=output, from_logits=False))
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(input_iterator),))
      return distribution.experimental_local_results(outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies, mode=["eager"]))
  def test_nested_tf_functions(self, distribution):
    # The test builds two computations with keras layers, one with nested
    # tf.function, and the other without nested tf.function. We run these
    # computations independently on the model with same weights, and make sure
    # the variables are still the same after one training step.

    inputs = np.random.random((10, 3)).astype(np.float32)
    targets = np.ones((10, 4), dtype=np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).repeat()
    dataset = dataset.batch(10, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    def get_model():
      x = keras.layers.Input(shape=(3,), name="input")
      y = keras.layers.Dense(4, name="dense")(x)
      model = keras.Model(x, y)
      return model

    with distribution.scope():
      model = get_model()
      optimizer = keras.optimizer_v2.gradient_descent.SGD(0.1, momentum=0.01)
      weights_file = os.path.join(self.get_temp_dir(), ".h5")
      model.save_weights(weights_file)
      model2 = get_model()
      model2.load_weights(weights_file)

    # Make sure model and model2 variables are in sync when initialized.
    for model_v, model2_v in zip(model.variables, model2.variables):
      self.assertAllClose(model_v.numpy(), model2_v.numpy())

    def compute_loss(images, targets):
      outputs = model(images)
      return math_ops.reduce_sum(outputs - targets)

    @def_function.function
    def train_step_without_nested_tf_function(inputs):

      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          loss = compute_loss(images, targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))

      distribution.experimental_run_v2(step_fn, args=(inputs,))

    @def_function.function
    def compute_loss2(images, targets):
      outputs = model2(images)
      return math_ops.reduce_sum(outputs - targets)

    @def_function.function
    def train_step_with_nested_tf_function(inputs):

      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          loss = compute_loss2(images, targets)
        grads = tape.gradient(loss, model2.variables)
        optimizer.apply_gradients(zip(grads, model2.variables))

      distribution.experimental_run_v2(step_fn, args=(inputs,))

    inputs = next(input_iterator)

    train_step_without_nested_tf_function(inputs)
    train_step_with_nested_tf_function(inputs)

    # Make sure model and model2 variables are still in sync.
    for model_v, model2_v in zip(model.variables, model2.variables):
      self.assertAllClose(model_v.numpy(), model2_v.numpy())

  def _get_dataset(self):
    inputs = np.zeros((10, 3), dtype=np.float32)
    targets = np.zeros((10, 4), dtype=np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
    dataset = dataset.repeat(100)
    dataset = dataset.batch(10, drop_remainder=True)
    return dataset

  def _get_model(self):
    x = keras.layers.Input(shape=(3,), name="input")
    y = keras.layers.Dense(4, name="dense")(x)
    model = keras.Model(x, y)
    return model


class KerasMetricsTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_multiple_keras_metrics_experimental_run(self, distribution):
    with distribution.scope():
      loss_metric = keras.metrics.Mean("loss", dtype=np.float32)
      loss_metric_2 = keras.metrics.Mean("loss_2", dtype=np.float32)

    @def_function.function
    def train_step():
      def step_fn():
        loss = constant_op.constant(5.0, dtype=np.float32)
        loss_metric.update_state(loss)
        loss_metric_2.update_state(loss)

      distribution.experimental_run_v2(step_fn)

    train_step()
    self.assertEqual(loss_metric.result().numpy(),
                     loss_metric_2.result().numpy())
    self.assertEqual(loss_metric.result().numpy(), 5.0)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_update_keras_metric_declared_in_strategy_scope(self, distribution):
    with distribution.scope():
      metric = keras.metrics.Mean("test_metric", dtype=np.float32)

    dataset = dataset_ops.Dataset.range(10).batch(2)
    dataset = distribution.experimental_distribute_dataset(dataset)

    @def_function.function
    def step_fn(i):
      metric.update_state(i)

    for i in dataset:
      distribution.experimental_run_v2(step_fn, args=(i,))

    # This should be the mean of integers 0-9 which has a sum of 45 and a count
    # of 10 resulting in mean of 4.5.
    self.assertEqual(metric.result().numpy(), 4.5)


if __name__ == "__main__":
  test.main()
84 tensorflow/python/distribute/custom_training_loop_metrics_test.py Normal file
@@ -0,0 +1,84 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for custom training loops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op


class KerasMetricsTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_multiple_keras_metrics_experimental_run(self, distribution):
    with distribution.scope():
      loss_metric = keras.metrics.Mean("loss", dtype=np.float32)
      loss_metric_2 = keras.metrics.Mean("loss_2", dtype=np.float32)

    @def_function.function
    def train_step():
      def step_fn():
        loss = constant_op.constant(5.0, dtype=np.float32)
        loss_metric.update_state(loss)
        loss_metric_2.update_state(loss)

      distribution.experimental_run_v2(step_fn)

    train_step()
    self.assertEqual(loss_metric.result().numpy(),
                     loss_metric_2.result().numpy())
    self.assertEqual(loss_metric.result().numpy(), 5.0)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_update_keras_metric_declared_in_strategy_scope(self, distribution):
    with distribution.scope():
      metric = keras.metrics.Mean("test_metric", dtype=np.float32)

    dataset = dataset_ops.Dataset.range(10).batch(2)
    dataset = distribution.experimental_distribute_dataset(dataset)

    @def_function.function
    def step_fn(i):
      metric.update_state(i)

    for i in dataset:
      distribution.experimental_run_v2(step_fn, args=(i,))

    # This should be the mean of integers 0-9 which has a sum of 45 and a count
    # of 10 resulting in mean of 4.5.
    self.assertEqual(metric.result().numpy(), 4.5)


if __name__ == "__main__":
  test.main()

344 tensorflow/python/distribute/custom_training_loop_models_test.py Normal file
@@ -0,0 +1,344 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for custom training loops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest


class KerasModelsTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_single_keras_layer_experimental_run(self, distribution):
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = keras.layers.Dense(4, name="dense")

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        return grads

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_model_creation_experimental_run(self, distribution):
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = self._get_model()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        return grads

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_model_optimizer_experimental_run(self, distribution):
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = self._get_model()
      optimizer = keras.optimizer_v2.rmsprop.RMSprop()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_subclass_model_optimizer_experimental_run(self, distribution):
    def get_subclass_model():

      class KerasSubclassModel(keras.Model):

        def __init__(self):
          super(KerasSubclassModel, self).__init__()
          self.l = keras.layers.Dense(4, name="dense")

        def call(self, x):
          return self.l(x)

      return KerasSubclassModel()
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = get_subclass_model()
      optimizer = keras.optimizer_v2.rmsprop.RMSprop()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_model_optimizer_experimental_run_loop(self, distribution):
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = self._get_model()
      optimizer = keras.optimizer_v2.rmsprop.RMSprop()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      for _ in range(5):
        distribution.experimental_run_v2(step_fn, args=(next(iterator),))

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_lstm(self, distribution):

    batch_size = 32

    def create_lstm_model():
      model = keras.models.Sequential()
      # We only have LSTM variables so we can detect no gradient issues more
      # easily.
      model.add(
          keras.layers.LSTM(1, return_sequences=False, input_shape=(10, 1)))
      return model

    def create_lstm_data():
      seq_length = 10

      x_train = np.random.rand(batch_size, seq_length, 1).astype("float32")
      y_train = np.random.rand(batch_size, 1).astype("float32")
      return x_train, y_train

    x, y = create_lstm_data()
    dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
    dataset = dataset.batch(batch_size, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = create_lstm_model()
      optimizer = keras.optimizer_v2.gradient_descent.SGD()

    @def_function.function
    def train_step(input_iterator):

      def step_fn(inputs):
        inps, targ = inputs
        with backprop.GradientTape() as tape:
          output = model(inps)
          loss = math_ops.reduce_mean(
              keras.losses.binary_crossentropy(
                  y_true=targ, y_pred=output, from_logits=False))
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(input_iterator),))
      return distribution.experimental_local_results(outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies, mode=["eager"]))
  def test_nested_tf_functions(self, distribution):
    # The test builds two computations with keras layers, one with nested
    # tf.function, and the other without nested tf.function. We run these
    # computations independently on the model with same weights, and make sure
    # the variables are still the same after one training step.

    inputs = np.random.random((10, 3)).astype(np.float32)
    targets = np.ones((10, 4), dtype=np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).repeat()
    dataset = dataset.batch(10, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    def get_model():
      x = keras.layers.Input(shape=(3,), name="input")
      y = keras.layers.Dense(4, name="dense")(x)
      model = keras.Model(x, y)
      return model

    with distribution.scope():
      model = get_model()
      optimizer = keras.optimizer_v2.gradient_descent.SGD(0.1, momentum=0.01)
      weights_file = os.path.join(self.get_temp_dir(), ".h5")
      model.save_weights(weights_file)
      model2 = get_model()
      model2.load_weights(weights_file)

    # Make sure model and model2 variables are in sync when initialized.
    for model_v, model2_v in zip(model.variables, model2.variables):
      self.assertAllClose(model_v.numpy(), model2_v.numpy())

    def compute_loss(images, targets):
      outputs = model(images)
      return math_ops.reduce_sum(outputs - targets)

    @def_function.function
    def train_step_without_nested_tf_function(inputs):

      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          loss = compute_loss(images, targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))

      distribution.experimental_run_v2(step_fn, args=(inputs,))

    @def_function.function
    def compute_loss2(images, targets):
      outputs = model2(images)
      return math_ops.reduce_sum(outputs - targets)

    @def_function.function
    def train_step_with_nested_tf_function(inputs):

      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          loss = compute_loss2(images, targets)
        grads = tape.gradient(loss, model2.variables)
        optimizer.apply_gradients(zip(grads, model2.variables))

      distribution.experimental_run_v2(step_fn, args=(inputs,))

    inputs = next(input_iterator)

    train_step_without_nested_tf_function(inputs)
    train_step_with_nested_tf_function(inputs)

    # Make sure model and model2 variables are still in sync.
    for model_v, model2_v in zip(model.variables, model2.variables):
      self.assertAllClose(model_v.numpy(), model2_v.numpy())

  def _get_dataset(self):
    inputs = np.zeros((10, 3), dtype=np.float32)
    targets = np.zeros((10, 4), dtype=np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
    dataset = dataset.repeat(100)
    dataset = dataset.batch(10, drop_remainder=True)
    return dataset

  def _get_model(self):
    x = keras.layers.Input(shape=(3,), name="input")
    y = keras.layers.Dense(4, name="dense")(x)
    model = keras.Model(x, y)
    return model


if __name__ == "__main__":
  test.main()