Update minimize_loss_test to not rely on Keras.
Also update the other related tests to use the extracted utils in strategy_test_lib.py PiperOrigin-RevId: 320122161 Change-Id: I0a8f66d19b8f6cf32978f3386d7cbbcbe3dc4b84
This commit is contained in:
parent
3dda4182aa
commit
d93971b09e
@ -1333,6 +1333,7 @@ distribute_py_test(
|
|||||||
":mirrored_strategy",
|
":mirrored_strategy",
|
||||||
":single_loss_example",
|
":single_loss_example",
|
||||||
":strategy_combinations",
|
":strategy_combinations",
|
||||||
|
":strategy_test_lib",
|
||||||
"//tensorflow/python:control_flow_ops",
|
"//tensorflow/python:control_flow_ops",
|
||||||
"//tensorflow/python:control_flow_v2_toggles",
|
"//tensorflow/python:control_flow_v2_toggles",
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
@ -1342,8 +1343,6 @@ distribute_py_test(
|
|||||||
"//tensorflow/python/data/ops:dataset_ops",
|
"//tensorflow/python/data/ops:dataset_ops",
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
"//tensorflow/python/eager:test",
|
"//tensorflow/python/eager:test",
|
||||||
"//tensorflow/python/keras/layers",
|
|
||||||
"//tensorflow/python/keras/optimizer_v2",
|
|
||||||
"//tensorflow/python/ops/losses",
|
"//tensorflow/python/ops/losses",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
"@absl_py//absl/testing:parameterized",
|
"@absl_py//absl/testing:parameterized",
|
||||||
@ -1355,6 +1354,7 @@ py_library(
|
|||||||
srcs = ["single_loss_example.py"],
|
srcs = ["single_loss_example.py"],
|
||||||
deps = [
|
deps = [
|
||||||
":step_fn",
|
":step_fn",
|
||||||
|
":strategy_test_lib",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:constant_op",
|
"//tensorflow/python:constant_op",
|
||||||
"//tensorflow/python:layers",
|
"//tensorflow/python:layers",
|
||||||
|
@ -20,22 +20,24 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.distribute import combinations
|
from tensorflow.python.distribute import combinations
|
||||||
from tensorflow.python.distribute import reduce_util
|
from tensorflow.python.distribute import reduce_util
|
||||||
from tensorflow.python.distribute import strategy_combinations
|
from tensorflow.python.distribute import strategy_combinations
|
||||||
|
from tensorflow.python.distribute import strategy_test_lib
|
||||||
from tensorflow.python.distribute.single_loss_example import batchnorm_example
|
from tensorflow.python.distribute.single_loss_example import batchnorm_example
|
||||||
from tensorflow.python.distribute.single_loss_example import minimize_loss_example
|
from tensorflow.python.distribute.single_loss_example import minimize_loss_example
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.keras.layers import core
|
|
||||||
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import control_flow_v2_toggles
|
from tensorflow.python.ops import control_flow_v2_toggles
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import nn_ops
|
||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.ops import variables as variables_lib
|
from tensorflow.python.ops import variables as variables_lib
|
||||||
from tensorflow.python.ops.losses import losses_impl
|
from tensorflow.python.ops.losses import losses_impl
|
||||||
@ -208,7 +210,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
|||||||
def get_expected_variables(num_parameter_devices):
|
def get_expected_variables(num_parameter_devices):
|
||||||
name = optimizer._name
|
name = optimizer._name
|
||||||
|
|
||||||
if isinstance(optimizer, optimizer_v2.OptimizerV2):
|
if strategy_test_lib.is_optimizer_v2_instance(optimizer):
|
||||||
variables = VAR_MAP_V2[name]
|
variables = VAR_MAP_V2[name]
|
||||||
else:
|
else:
|
||||||
variables = VAR_MAP_V1[name]
|
variables = VAR_MAP_V1[name]
|
||||||
@ -349,7 +351,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
optimizer = optimizer_fn() # GradientDescent with 0.2 learning rate
|
optimizer = optimizer_fn() # GradientDescent with 0.2 learning rate
|
||||||
|
|
||||||
if isinstance(optimizer, optimizer_v2.OptimizerV2):
|
if strategy_test_lib.is_optimizer_v2_instance(optimizer):
|
||||||
return optimizer.minimize(loss_fn, [w])
|
return optimizer.minimize(loss_fn, [w])
|
||||||
else:
|
else:
|
||||||
if use_callable_loss:
|
if use_callable_loss:
|
||||||
@ -426,7 +428,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
|||||||
return dataset.batch(batch_size=1, drop_remainder=True)
|
return dataset.batch(batch_size=1, drop_remainder=True)
|
||||||
|
|
||||||
optimizer = optimizer_fn()
|
optimizer = optimizer_fn()
|
||||||
layer = core.Dense(1, use_bias=True)
|
kernel = strategy_test_lib.create_variable_like_keras_layer(
|
||||||
|
"kernel", (1, 1), dtypes.float32)
|
||||||
|
bias = strategy_test_lib.create_variable_like_keras_layer(
|
||||||
|
"bias", (1,), dtypes.float32)
|
||||||
|
# layer = core.Dense(1, use_bias=True)
|
||||||
|
|
||||||
key1 = "foo"
|
key1 = "foo"
|
||||||
value1 = "bar"
|
value1 = "bar"
|
||||||
@ -434,12 +440,13 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
|||||||
def model_fn(output_context, x):
|
def model_fn(output_context, x):
|
||||||
"""A very simple model written by the user."""
|
"""A very simple model written by the user."""
|
||||||
def loss_fn():
|
def loss_fn():
|
||||||
y = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
|
y = array_ops.reshape(nn_ops.bias_add(
|
||||||
|
math_ops.matmul(x, kernel), bias), []) - constant_op.constant(1.)
|
||||||
return y * y
|
return y * y
|
||||||
|
|
||||||
if isinstance(optimizer, optimizer_v2.OptimizerV2):
|
if strategy_test_lib.is_optimizer_v2_instance(optimizer):
|
||||||
train_op = optimizer.minimize(
|
train_op = optimizer.minimize(
|
||||||
loss_fn, lambda: layer.trainable_variables)
|
loss_fn, lambda: [kernel, bias])
|
||||||
else:
|
else:
|
||||||
train_op = optimizer.minimize(loss_fn)
|
train_op = optimizer.minimize(loss_fn)
|
||||||
loss = loss_fn()
|
loss = loss_fn()
|
||||||
@ -508,8 +515,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
|||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
_, loss = run_step()
|
_, loss = run_step()
|
||||||
losses.append(loss)
|
losses.append(loss)
|
||||||
weights.append(self.evaluate(layer.kernel))
|
weights.append(self.evaluate(kernel))
|
||||||
biases.append(self.evaluate(layer.bias))
|
biases.append(self.evaluate(bias))
|
||||||
|
|
||||||
loss_is_not_increasing = all(y <= x for x, y in zip(losses, losses[1:]))
|
loss_is_not_increasing = all(y <= x for x, y in zip(losses, losses[1:]))
|
||||||
self.assertTrue(loss_is_not_increasing)
|
self.assertTrue(loss_is_not_increasing)
|
||||||
|
@ -19,7 +19,6 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import functools
|
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
@ -51,7 +50,6 @@ from tensorflow.python.framework import tensor_util
|
|||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import gradients
|
from tensorflow.python.ops import gradients
|
||||||
from tensorflow.python.ops import init_ops_v2
|
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import partitioned_variables
|
from tensorflow.python.ops import partitioned_variables
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
@ -452,10 +450,8 @@ class ParameterServerStrategyTestBase(
|
|||||||
self.cached_session(target=master_target,
|
self.cached_session(target=master_target,
|
||||||
config=sess_config) as sess, \
|
config=sess_config) as sess, \
|
||||||
d.scope():
|
d.scope():
|
||||||
initializer = functools.partial(
|
kernel = strategy_test_lib.create_variable_like_keras_layer(
|
||||||
init_ops_v2.GlorotUniform(), (1, 1), dtype=dtypes.float32)
|
'kernel', (1, 1), dtypes.float32,)
|
||||||
kernel = variables.Variable(
|
|
||||||
initial_value=initializer, name='kernel', trainable=True)
|
|
||||||
|
|
||||||
def loss_fn(x):
|
def loss_fn(x):
|
||||||
y = array_ops.reshape(
|
y = array_ops.reshape(
|
||||||
|
@ -20,13 +20,13 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.distribute import step_fn
|
from tensorflow.python.distribute import step_fn
|
||||||
|
from tensorflow.python.distribute import strategy_test_lib
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.layers import core
|
from tensorflow.python.layers import core
|
||||||
from tensorflow.python.layers import normalization
|
from tensorflow.python.layers import normalization
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.util import tf_inspect
|
|
||||||
|
|
||||||
|
|
||||||
def single_loss_example(optimizer_fn, distribution, use_bias=False,
|
def single_loss_example(optimizer_fn, distribution, use_bias=False,
|
||||||
@ -69,7 +69,7 @@ def minimize_loss_example(optimizer, use_bias=False, use_callable_loss=True):
|
|||||||
y = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
|
y = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
|
||||||
return y * y
|
return y * y
|
||||||
|
|
||||||
if _is_optimizer_v2_instance(optimizer):
|
if strategy_test_lib.is_optimizer_v2_instance(optimizer):
|
||||||
return optimizer.minimize(loss_fn, lambda: layer.trainable_variables)
|
return optimizer.minimize(loss_fn, lambda: layer.trainable_variables)
|
||||||
elif use_callable_loss:
|
elif use_callable_loss:
|
||||||
return optimizer.minimize(loss_fn)
|
return optimizer.minimize(loss_fn)
|
||||||
@ -112,17 +112,10 @@ def batchnorm_example(optimizer_fn,
|
|||||||
# `x` and `y` will be fetched by the gradient computation, but not `loss`.
|
# `x` and `y` will be fetched by the gradient computation, but not `loss`.
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
if _is_optimizer_v2_instance(optimizer):
|
if strategy_test_lib.is_optimizer_v2_instance(optimizer):
|
||||||
return optimizer.minimize(loss_fn, lambda: layer.trainable_variables)
|
return optimizer.minimize(loss_fn, lambda: layer.trainable_variables)
|
||||||
|
|
||||||
# Callable loss.
|
# Callable loss.
|
||||||
return optimizer.minimize(loss_fn)
|
return optimizer.minimize(loss_fn)
|
||||||
|
|
||||||
return model_fn, dataset_fn, batchnorm
|
return model_fn, dataset_fn, batchnorm
|
||||||
|
|
||||||
|
|
||||||
def _is_optimizer_v2_instance(optimizer):
|
|
||||||
# For a optimizer instance, the v2 implementation has var_list as a required
|
|
||||||
# argument.
|
|
||||||
arg_spec = tf_inspect.getfullargspec(optimizer.minimize)
|
|
||||||
return 'var_list' in arg_spec.args[:-len(arg_spec.defaults)]
|
|
||||||
|
@ -52,6 +52,7 @@ from tensorflow.python.platform import gfile
|
|||||||
from tensorflow.python.training import optimizer
|
from tensorflow.python.training import optimizer
|
||||||
from tensorflow.python.training import training_util
|
from tensorflow.python.training import training_util
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
from tensorflow.python.util import tf_inspect
|
||||||
|
|
||||||
|
|
||||||
class _TestException(Exception):
|
class _TestException(Exception):
|
||||||
@ -113,18 +114,27 @@ def _events_from_logdir(test_case, logdir):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class DistributionTestBase(test.TestCase):
|
def create_variable_like_keras_layer(name, shape, dtype):
|
||||||
"""Some tests that should work with any DistributionStrategy."""
|
"""Utitlity for create variables that works like variable in keras layer."""
|
||||||
|
|
||||||
def _create_variable_like_keras_dense_layer(self, name, shape, dtype):
|
|
||||||
initializer = functools.partial(
|
initializer = functools.partial(
|
||||||
init_ops_v2.GlorotUniform(), shape, dtype=dtype)
|
init_ops_v2.GlorotUniform(), shape, dtype=dtype)
|
||||||
return variables.Variable(
|
return variables.Variable(
|
||||||
initial_value=initializer, name=name, trainable=True)
|
initial_value=initializer, name=name, trainable=True)
|
||||||
|
|
||||||
|
|
||||||
|
def is_optimizer_v2_instance(optimizer_obj):
|
||||||
|
# For a optimizer instance, the v2 implementation has var_list as a required
|
||||||
|
# argument.
|
||||||
|
arg_spec = tf_inspect.getfullargspec(optimizer_obj.minimize)
|
||||||
|
return "var_list" in arg_spec.args[:-len(arg_spec.defaults)]
|
||||||
|
|
||||||
|
|
||||||
|
class DistributionTestBase(test.TestCase):
|
||||||
|
"""Some tests that should work with any DistributionStrategy."""
|
||||||
|
|
||||||
def _test_minimize_loss_eager(self, d):
|
def _test_minimize_loss_eager(self, d):
|
||||||
with d.scope():
|
with d.scope():
|
||||||
kernel = self._create_variable_like_keras_dense_layer(
|
kernel = create_variable_like_keras_layer(
|
||||||
name="kernel", shape=(1, 1), dtype=dtypes.float32)
|
name="kernel", shape=(1, 1), dtype=dtypes.float32)
|
||||||
def loss(x):
|
def loss(x):
|
||||||
y = array_ops.reshape(
|
y = array_ops.reshape(
|
||||||
@ -182,7 +192,7 @@ class DistributionTestBase(test.TestCase):
|
|||||||
ops.Graph().as_default(), \
|
ops.Graph().as_default(), \
|
||||||
self.cached_session(config=config) as sess, \
|
self.cached_session(config=config) as sess, \
|
||||||
d.scope():
|
d.scope():
|
||||||
kernel = self._create_variable_like_keras_dense_layer(
|
kernel = create_variable_like_keras_layer(
|
||||||
name="kernel", shape=(1, 1), dtype=dtypes.float32)
|
name="kernel", shape=(1, 1), dtype=dtypes.float32)
|
||||||
|
|
||||||
def loss(x):
|
def loss(x):
|
||||||
|
Loading…
Reference in New Issue
Block a user