Split distribute/custom_training_loop_test into three parts as it is timing out on our kokoro TPU continuous tests.
PiperOrigin-RevId: 294765478
Change-Id: I574ca0433ade67673e1b5ea731db94e40e28ae5f

This commit is contained in:
parent bceb1d7854
commit a7f1d52b03
tensorflow/python/distribute/BUILD
@@ -925,9 +925,66 @@ distribute_py_test(
 )
 
 distribute_py_test(
-    name = "custom_training_loop_test",
-    srcs = ["custom_training_loop_test.py"],
-    main = "custom_training_loop_test.py",
+    name = "custom_training_loop_gradient_test",
+    srcs = ["custom_training_loop_gradient_test.py"],
+    main = "custom_training_loop_gradient_test.py",
+    tags = [
+        "multi_and_single_gpu",
+    ],
+    deps = [
+        "//tensorflow/python:errors",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/distribute:combinations",
+        "//tensorflow/python/distribute:strategy_combinations",
+        "//tensorflow/python/eager:test",
+        "//tensorflow/python/keras",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
+distribute_py_test(
+    name = "custom_training_loop_input_test",
+    srcs = ["custom_training_loop_input_test.py"],
+    main = "custom_training_loop_input_test.py",
+    tags = [
+        "multi_and_single_gpu",
+    ],
+    deps = [
+        "//tensorflow/python:errors",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/distribute:combinations",
+        "//tensorflow/python/distribute:strategy_combinations",
+        "//tensorflow/python/eager:test",
+        "//tensorflow/python/keras",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
+distribute_py_test(
+    name = "custom_training_loop_metrics_test",
+    srcs = ["custom_training_loop_metrics_test.py"],
+    main = "custom_training_loop_metrics_test.py",
+    tags = [
+        "multi_and_single_gpu",
+    ],
+    deps = [
+        "//tensorflow/python:errors",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/distribute:combinations",
+        "//tensorflow/python/distribute:strategy_combinations",
+        "//tensorflow/python/eager:test",
+        "//tensorflow/python/keras",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
+distribute_py_test(
+    name = "custom_training_loop_models_test",
+    srcs = ["custom_training_loop_models_test.py"],
+    main = "custom_training_loop_models_test.py",
     tags = [
         "multi_and_single_gpu",
     ],
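All four targets exercise the same basic pattern: a per-replica step function wrapped in tf.function and dispatched through the distribution strategy. Below is a minimal sketch of that pattern written against the public tf.* names available at this revision (Strategy.experimental_run_v2 was later renamed Strategy.run); the MirroredStrategy and the toy batch values are illustrative assumptions, not code from this change.

# Illustrative sketch only -- not part of this commit. Assumes MirroredStrategy
# and toy data; the tests below use the internal module equivalents
# (def_function.function, backprop.GradientTape, distribution.experimental_run_v2).
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def train_step(x):
  with tf.GradientTape() as tape:
    tape.watch(x)  # x is a plain tensor, not a variable, so watch it explicitly.
    y = tf.square(x)
  return tape.gradient(y, x)

for x in dist_dataset:
  per_replica_grads = strategy.experimental_run_v2(train_step, args=(x,))
  print(strategy.experimental_local_results(per_replica_grads))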
tensorflow/python/distribute/custom_training_loop_gradient_test.py (new file, 152 lines)
@@ -0,0 +1,152 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for custom training loops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables


def get_dataset_from_tensor_slices(inp_array):
  dataset = dataset_ops.DatasetV2.from_tensor_slices(inp_array)
  # TODO(b/138326910): Remove Dataset V1 version once bug resolved.
  if not tf2.enabled():
    dataset = dataset_ops.Dataset.from_tensor_slices(inp_array)
  return dataset


class AssertFlattenedMixin(object):
  """Mixin for specialized asserts."""

  def assert_equal_flattened(self, expected_results, actual_results):
    """Asserts that flattened results are equal.

    Due to the number of replicas in the strategy, the output may have a
    different structure and needs to be flattened for comparison.

    Args:
      expected_results: The results expected as a result of a computation.
      actual_results: The actual results of a computation.
    """
    self.assertEqual(len(expected_results), len(actual_results))

    for i, expected_result in enumerate(expected_results):
      final_result = []
      actual_result = actual_results[i]
      for val in actual_result:
        final_result.extend(val.numpy())
      self.assertAllEqual(expected_result, final_result)


class GradientTapeTest(test.TestCase, parameterized.TestCase,
                       AssertFlattenedMixin):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testStepInFunctionGradient(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    @def_function.function
    def train_step(x):
      def computation(x):
        return math_ops.square(x)
      with backprop.GradientTape() as tape:
        tape.watch(x)  # Manually watch non-variable tensors.
        y = computation(x)
      grads = tape.gradient(y, x)
      return grads

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(x,)))
      results.append(output)
    self.assert_equal_flattened([[10., 12.], [14., 16.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def testRunInFunctionGradient(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

    @def_function.function
    def run(x):
      def train_step(x):
        def computation(x):
          return math_ops.square(x)
        with backprop.GradientTape() as tape:
          tape.watch(x)  # Manually watch non-variable tensors.
          y = computation(x)
        grads = tape.gradient(y, x)
        return grads
      return distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(x,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = run(x)
      results.append(output)
    self.assert_equal_flattened([[10., 12.], [14., 16.]], results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"],
          model_in_tf_function=[True, False]
      ))
  def testNestedFunction(self, distribution, model_in_tf_function):
    def model(x):
      return x * x

    if model_in_tf_function:
      model = def_function.function(model)

    with distribution.scope():
      x = variables.Variable(1.0)

      @def_function.function
      def train_step():
        def replica_step():
          with backprop.GradientTape() as tape:
            y = model(x)
          return tape.gradient(y, x)
        return distribution.experimental_run_v2(replica_step)

      grads = distribution.experimental_local_results(train_step())
      self.assertLen(grads, distribution.num_replicas_in_sync)
      self.assertTrue(all(g is not None for g in grads))


if __name__ == "__main__":
  test.main()
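A short worked illustration of what assert_equal_flattened checks in the file above, assuming a two-replica strategy (the replica count is an assumption for the example): since d/dx x^2 = 2x, the global batches [5., 6.] and [7., 8.] yield gradients [10., 12.] and [14., 16.]; each step returns one tensor per replica, and the helper concatenates their values before comparing.

# Worked illustration only (assumes 2 replicas); numpy arrays stand in for the
# per-replica tensors returned by experimental_local_results.
import numpy as np

per_replica_outputs = (np.array([10.]), np.array([12.]))  # one training step
flattened = []
for val in per_replica_outputs:
  flattened.extend(val)  # the test calls val.numpy(); these stand-ins are already arrays
assert flattened == [10., 12.]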
@@ -18,18 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
@@ -596,447 +591,5 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
    self.assert_equal_flattened([[25., 36.], [49., 64.]], results)

The code removed by this hunk is the GradientTapeTest, KerasModelsTest, and KerasMetricsTest classes, identical to the code added in the new custom_training_loop_gradient_test.py, custom_training_loop_models_test.py, and custom_training_loop_metrics_test.py files shown below.

if __name__ == "__main__":
  test.main()
tensorflow/python/distribute/custom_training_loop_metrics_test.py (new file, 84 lines)
@@ -0,0 +1,84 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for custom training loops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op


class KerasMetricsTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_multiple_keras_metrics_experimental_run(self, distribution):
    with distribution.scope():
      loss_metric = keras.metrics.Mean("loss", dtype=np.float32)
      loss_metric_2 = keras.metrics.Mean("loss_2", dtype=np.float32)

    @def_function.function
    def train_step():
      def step_fn():
        loss = constant_op.constant(5.0, dtype=np.float32)
        loss_metric.update_state(loss)
        loss_metric_2.update_state(loss)

      distribution.experimental_run_v2(step_fn)

    train_step()
    self.assertEqual(loss_metric.result().numpy(),
                     loss_metric_2.result().numpy())
    self.assertEqual(loss_metric.result().numpy(), 5.0)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_update_keras_metric_declared_in_strategy_scope(self, distribution):
    with distribution.scope():
      metric = keras.metrics.Mean("test_metric", dtype=np.float32)

    dataset = dataset_ops.Dataset.range(10).batch(2)
    dataset = distribution.experimental_distribute_dataset(dataset)

    @def_function.function
    def step_fn(i):
      metric.update_state(i)

    for i in dataset:
      distribution.experimental_run_v2(step_fn, args=(i,))

    # This should be the mean of integers 0-9 which has a sum of 45 and a count
    # of 10 resulting in mean of 4.5.
    self.assertEqual(metric.result().numpy(), 4.5)


if __name__ == "__main__":
  test.main()
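The two metrics tests above rely on metric variables created inside strategy.scope() being aggregated across replicas. A sketch of the second test's pattern using public API names follows; tf.distribute.MirroredStrategy is an assumption for the example (the test itself runs under all strategies), and the sketch is not code from this commit.

# Sketch using public API names; illustrative, not code from this commit.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumed strategy for the example
with strategy.scope():
  metric = tf.keras.metrics.Mean("test_metric", dtype=tf.float32)

dataset = tf.data.Dataset.range(10).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def step_fn(i):
  metric.update_state(i)

for i in dist_dataset:
  strategy.experimental_run_v2(step_fn, args=(i,))

print(metric.result().numpy())  # mean of 0..9 == 4.5, however the batches were split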
							
								
								
									
tensorflow/python/distribute/custom_training_loop_models_test.py (new file, 344 lines)
@@ -0,0 +1,344 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for custom training loops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest


class KerasModelsTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_single_keras_layer_experimental_run(self, distribution):
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = keras.layers.Dense(4, name="dense")

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        return grads

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_model_creation_experimental_run(self, distribution):
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = self._get_model()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        return grads

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_model_optimizer_experimental_run(self, distribution):
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = self._get_model()
      optimizer = keras.optimizer_v2.rmsprop.RMSprop()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_subclass_model_optimizer_experimental_run(self, distribution):
    def get_subclass_model():

      class KerasSubclassModel(keras.Model):

        def __init__(self):
          super(KerasSubclassModel, self).__init__()
          self.l = keras.layers.Dense(4, name="dense")

        def call(self, x):
          return self.l(x)

      return KerasSubclassModel()
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = get_subclass_model()
      optimizer = keras.optimizer_v2.rmsprop.RMSprop()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(iterator),))
      return nest.map_structure(distribution.experimental_local_results,
                                outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_keras_model_optimizer_experimental_run_loop(self, distribution):
    dataset = self._get_dataset()
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = self._get_model()
      optimizer = keras.optimizer_v2.rmsprop.RMSprop()

    @def_function.function
    def train_step(iterator):
      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          outputs = model(images)
          loss = math_ops.reduce_sum(outputs - targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      for _ in range(5):
        distribution.experimental_run_v2(step_fn, args=(next(iterator),))

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies,
          mode=["eager"]
      ))
  def test_lstm(self, distribution):

    batch_size = 32

    def create_lstm_model():
      model = keras.models.Sequential()
      # We only have LSTM variables so we can detect no gradient issues more
      # easily.
      model.add(
          keras.layers.LSTM(1, return_sequences=False, input_shape=(10, 1)))
      return model

    def create_lstm_data():
      seq_length = 10

      x_train = np.random.rand(batch_size, seq_length, 1).astype("float32")
      y_train = np.random.rand(batch_size, 1).astype("float32")
      return x_train, y_train

    x, y = create_lstm_data()
    dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
    dataset = dataset.batch(batch_size, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    with distribution.scope():
      model = create_lstm_model()
      optimizer = keras.optimizer_v2.gradient_descent.SGD()

    @def_function.function
    def train_step(input_iterator):

      def step_fn(inputs):
        inps, targ = inputs
        with backprop.GradientTape() as tape:
          output = model(inps)
          loss = math_ops.reduce_mean(
              keras.losses.binary_crossentropy(
                  y_true=targ, y_pred=output, from_logits=False))
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))
        return loss

      outputs = distribution.experimental_run_v2(
          step_fn, args=(next(input_iterator),))
      return distribution.experimental_local_results(outputs)

    train_step(input_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies, mode=["eager"]))
  def test_nested_tf_functions(self, distribution):
    # The test builds two computations with keras layers, one with nested
    # tf.function, and the other without nested tf.function. We run these
    # computations independently on the model with same weights, and make sure
    # the variables are still the same after one training step.

    inputs = np.random.random((10, 3)).astype(np.float32)
    targets = np.ones((10, 4), dtype=np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).repeat()
    dataset = dataset.batch(10, drop_remainder=True)
    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))

    def get_model():
      x = keras.layers.Input(shape=(3,), name="input")
      y = keras.layers.Dense(4, name="dense")(x)
      model = keras.Model(x, y)
      return model

    with distribution.scope():
      model = get_model()
      optimizer = keras.optimizer_v2.gradient_descent.SGD(0.1, momentum=0.01)
      weights_file = os.path.join(self.get_temp_dir(), ".h5")
      model.save_weights(weights_file)
      model2 = get_model()
      model2.load_weights(weights_file)

    # Make sure model and model2 variables are in sync when initialized.
    for model_v, model2_v in zip(model.variables, model2.variables):
      self.assertAllClose(model_v.numpy(), model2_v.numpy())

    def compute_loss(images, targets):
      outputs = model(images)
      return math_ops.reduce_sum(outputs - targets)

    @def_function.function
    def train_step_without_nested_tf_function(inputs):

      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          loss = compute_loss(images, targets)
        grads = tape.gradient(loss, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables))

      distribution.experimental_run_v2(step_fn, args=(inputs,))

    @def_function.function
    def compute_loss2(images, targets):
      outputs = model2(images)
      return math_ops.reduce_sum(outputs - targets)

    @def_function.function
    def train_step_with_nested_tf_function(inputs):

      def step_fn(inputs):
        images, targets = inputs
        with backprop.GradientTape() as tape:
          loss = compute_loss2(images, targets)
        grads = tape.gradient(loss, model2.variables)
        optimizer.apply_gradients(zip(grads, model2.variables))

      distribution.experimental_run_v2(step_fn, args=(inputs,))

    inputs = next(input_iterator)

    train_step_without_nested_tf_function(inputs)
    train_step_with_nested_tf_function(inputs)

    # Make sure model and model2 variables are still in sync.
    for model_v, model2_v in zip(model.variables, model2.variables):
      self.assertAllClose(model_v.numpy(), model2_v.numpy())

  def _get_dataset(self):
    inputs = np.zeros((10, 3), dtype=np.float32)
    targets = np.zeros((10, 4), dtype=np.float32)
    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
    dataset = dataset.repeat(100)
    dataset = dataset.batch(10, drop_remainder=True)
    return dataset

  def _get_model(self):
    x = keras.layers.Input(shape=(3,), name="input")
    y = keras.layers.Dense(4, name="dense")(x)
    model = keras.Model(x, y)
    return model


if __name__ == "__main__":
  test.main()
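Every KerasModelsTest case above follows the same step shape: a forward pass and loss under GradientTape, tape.gradient over model.variables, optimizer.apply_gradients, all executed inside distribution.experimental_run_v2. A sketch of that shared step using public tf.keras names follows; the MirroredStrategy and the zero-filled data mirror _get_dataset, but the sketch itself is not part of this change.

# Illustrative sketch of the shared train-step shape, using public tf.keras names.
import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumed strategy for the example
inputs = np.zeros((10, 3), np.float32)
targets = np.zeros((10, 4), np.float32)
dataset = tf.data.Dataset.from_tensor_slices((inputs, targets)).repeat(100).batch(
    10, drop_remainder=True)
iterator = iter(strategy.experimental_distribute_dataset(dataset))

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,))])
  optimizer = tf.keras.optimizers.RMSprop()

@tf.function
def train_step(it):
  def step_fn(data):
    images, labels = data
    with tf.GradientTape() as tape:
      loss = tf.reduce_sum(model(images) - labels)
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(zip(grads, model.variables))
    return loss
  return strategy.experimental_run_v2(step_fn, args=(next(it),))

train_step(iterator)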