Fork the keras related mirrored_strategy_test to keras/distribute.

PiperOrigin-RevId: 317329998
Change-Id: I7dc55499cb0409129729696c286862a6b6d574aa

commit d00691f7aa (parent 4d751f9da4)
tensorflow/python/distribute/BUILD
@@ -1486,7 +1486,6 @@ cuda_py_test(
         "//tensorflow/python/autograph/core:test_lib",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:test",
-        "//tensorflow/python/keras/layers",
     ],
 )

tensorflow/python/distribute/mirrored_strategy_test.py
@@ -22,7 +22,6 @@ import json
 import sys

 from absl.testing import parameterized
-import numpy as np

 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python import tf2
@@ -50,16 +49,12 @@ from tensorflow.python.framework import func_graph
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
-from tensorflow.python.keras.engine import training as keras_training
-from tensorflow.python.keras.layers import core as keras_core
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradients
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
-from tensorflow.python.training import gradient_descent
-from tensorflow.python.training import optimizer as optimizer_lib
 from tensorflow.python.training import server_lib


@@ -988,22 +983,6 @@ class MockModel(object):
     return x


-class MiniModel(keras_training.Model):
-  """Minimal model for mnist.
-
-  Useful for testing and debugging on slow TPU simulators.
-  """
-
-  def __init__(self):
-    super(MiniModel, self).__init__(name="")
-    self.fc = keras_core.Dense(1, name="fc", kernel_initializer="ones",
-                               bias_initializer="ones")
-
-  def call(self, inputs, training=True):
-    inputs = array_ops.ones([1, 10])
-    return self.fc(inputs)
-
-
 @combinations.generate(
     combinations.combine(
         distribution=[
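Note that MiniModel deliberately ignores its real inputs: call() always feeds a fixed all-ones [1, 10] batch through a single Dense(1) unit whose kernel and bias are initialized to ones, so its output (and hence the test's loss surface) is fully predictable. A quick eager-mode sketch of one forward pass (a hypothetical snippet reusing the test's own module aliases, not part of the change):

    model = MiniModel()
    out = model(array_ops.ones([1, 10]))
    # kernel and bias are all ones: out == 1 * 10 + 1 == [[11.]]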
@@ -1116,32 +1095,6 @@ class MirroredStrategyDefunTest(test.TestCase):
     expected_result = values.PerReplica((5.0 * 1.25, 3.0 * 1.25))
     self._call_and_check(distribution, fn1, [factors], expected_result, [fn1])

-  def testTrain(self, distribution):
-    with distribution.scope():
-      mock_model = MiniModel()
-      mock_model.call = function.defun(mock_model.call)
-
-      def loss_fn(ctx):
-        del ctx
-        return mock_model(array_ops.ones([1, 10]))
-
-      gradients_fn = backprop.implicit_grad(loss_fn)
-      gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn)
-      grads_and_vars = distribution.extended.call_for_each_replica(
-          gradients_fn, args=(None,))
-
-      optimizer = gradient_descent.GradientDescentOptimizer(0.25)
-      update_ops = optimizer._distributed_apply(distribution, grads_and_vars)  # pylint: disable=protected-access
-
-      if not context.executing_eagerly():
-        self.evaluate(variables.global_variables_initializer())
-        self.evaluate(update_ops)
-
-      updated_var_values = self.evaluate(mock_model.variables)
-      # All variables start at 1.0 and get two updates of 0.25.
-      self.assertAllEqual(0.5 * np.ones([10, 1]), updated_var_values[0])
-      self.assertAllEqual([0.5], updated_var_values[1])
-
-
 @combinations.generate(
     combinations.combine(
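The forked testTrain drives one distributed training step through internal APIs: backprop.implicit_grad builds the gradient function, strategy.extended.call_for_each_replica runs it once per replica, and the optimizer's private _distributed_apply applies the aggregated update. For orientation, here is a rough public-API equivalent (a minimal sketch, assuming TF 2.x with a visible GPU so the strategy has two replicas; this is illustrative, not the code being moved):

    import numpy as np
    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy(["/gpu:0", "/cpu:0"])

    with strategy.scope():
      # Same shape as MiniModel: one Dense(1) unit, all weights initialized to 1.
      layer = tf.keras.layers.Dense(
          1, kernel_initializer="ones", bias_initializer="ones")
      optimizer = tf.keras.optimizers.SGD(learning_rate=0.25)

    @tf.function
    def train_step():
      def step_fn():
        with tf.GradientTape() as tape:
          loss = tf.reduce_sum(layer(tf.ones([1, 10])))
        grads = tape.gradient(loss, layer.trainable_variables)
        # Inside strategy.run, apply_gradients sums the per-replica gradients.
        optimizer.apply_gradients(zip(grads, layer.trainable_variables))
      strategy.run(step_fn)

    train_step()
    # Two replicas each contribute lr * grad = 0.25, so all-ones weights end at 0.5.
    np.testing.assert_allclose(layer.kernel.numpy(), 0.5 * np.ones([10, 1]))
    np.testing.assert_allclose(layer.bias.numpy(), [0.5])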
tensorflow/python/keras/distribute/BUILD
@@ -324,6 +324,29 @@ cuda_py_test(
     ],
 )

+cuda_py_test(
+    name = "mirrored_strategy_test",
+    srcs = ["mirrored_strategy_test.py"],
+    python_version = "PY3",
+    tags = [
+        "multi_and_single_gpu",
+        "no_windows_gpu",  # TODO(b/130551176)
+    ],
+    deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:training_lib",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/distribute:combinations",
+        "//tensorflow/python/distribute:strategy_combinations",
+        "//tensorflow/python/eager:backprop",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:function",
+        "//tensorflow/python/eager:test",
+        "//tensorflow/python/keras/engine",
+        "//tensorflow/python/keras/layers:core",
+    ],
+)
+
 cuda_py_test(
     name = "multi_worker_test",
     srcs = ["multi_worker_test.py"],
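The deps of the new target track the import list of the forked test file one-for-one: array_ops, variables, the distribute combinations helpers, the eager test utilities, and the Keras engine and core layers. Once wired in, the test should be runnable in the usual way, e.g. bazel test --config=cuda //tensorflow/python/keras/distribute:mirrored_strategy_test (the --config=cuda flag assumes a CUDA-enabled build; as a cuda_py_test it is also expected to run under the CPU-only configuration).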
tensorflow/python/keras/distribute/mirrored_strategy_test.py (new file, 89 lines)
@@ -0,0 +1,89 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for MirroredStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.eager import test
+from tensorflow.python.keras.engine import training as keras_training
+from tensorflow.python.keras.layers import core as keras_core
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import optimizer as optimizer_lib
+
+
+class MiniModel(keras_training.Model):
+  """Minimal model for mnist.
+
+  Useful for testing and debugging on slow TPU simulators.
+  """
+
+  def __init__(self):
+    super(MiniModel, self).__init__(name="")
+    self.fc = keras_core.Dense(1, name="fc", kernel_initializer="ones",
+                               bias_initializer="ones")
+
+  def call(self, inputs, training=True):
+    inputs = array_ops.ones([1, 10])
+    return self.fc(inputs)
+
+
+@combinations.generate(
+    combinations.combine(
+        distribution=[
+            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+        ],
+        mode=["graph", "eager"]))
+class MirroredStrategyDefunTest(test.TestCase):
+
+  def testTrain(self, distribution):
+    with distribution.scope():
+      mock_model = MiniModel()
+      mock_model.call = function.defun(mock_model.call)
+
+      def loss_fn(ctx):
+        del ctx
+        return mock_model(array_ops.ones([1, 10]))
+
+      gradients_fn = backprop.implicit_grad(loss_fn)
+      gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn)
+      grads_and_vars = distribution.extended.call_for_each_replica(
+          gradients_fn, args=(None,))
+
+      optimizer = gradient_descent.GradientDescentOptimizer(0.25)
+      update_ops = optimizer._distributed_apply(distribution, grads_and_vars)  # pylint: disable=protected-access
+
+      if not context.executing_eagerly():
+        self.evaluate(variables.global_variables_initializer())
+        self.evaluate(update_ops)
+
+      updated_var_values = self.evaluate(mock_model.variables)
+      # All variables start at 1.0 and get two updates of 0.25.
+      self.assertAllEqual(0.5 * np.ones([10, 1]), updated_var_values[0])
+      self.assertAllEqual([0.5], updated_var_values[1])
+
+
+if __name__ == "__main__":
+  test.main()
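As a sanity check on the expected values in testTrain: the loss is fc(ones([1, 10])) = sum(kernel) + bias, so the gradient with respect to every kernel entry and the bias is exactly 1.0 on each replica. With the gpu-and-cpu combination there are two replicas, the per-replica updates of learning_rate * 1 = 0.25 are summed, and every variable moves from 1.0 to 0.5, which is what the two assertAllEqual calls verify. The same arithmetic, redone in plain numpy:

    import numpy as np

    kernel, bias = np.ones([10, 1]), np.ones([1])
    lr, num_replicas = 0.25, 2
    # d(sum(kernel) + bias) / d(kernel) is 1 everywhere, likewise for bias.
    kernel -= num_replicas * lr * np.ones_like(kernel)
    bias -= num_replicas * lr * np.ones_like(bias)
    assert (kernel == 0.5).all() and (bias == 0.5).all()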