Move the optimizer name scope from model.training to optimizer

PiperOrigin-RevId: 249328139
This commit is contained in:
Zhenyu Tan 2019-05-21 14:39:22 -07:00 committed by TensorFlower Gardener
parent 438ff85035
commit 611db3490b
6 changed files with 74 additions and 75 deletions

View File

@ -72,8 +72,9 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
def testTrainNetwork(self, distribution, optimizer_fn,
use_callable_loss=True):
with distribution.scope():
optimizer = optimizer_fn()
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
optimizer, use_bias=True, use_callable_loss=use_callable_loss)
iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn())
def run_step():

View File

@ -47,10 +47,12 @@ VAR_MAP_V1 = {
}
VAR_MAP_V2 = {
"SGD": ("dense/bias", "learning_rate", "decay", "iter", "dense/kernel",
"momentum"),
"Adagrad": ("iter", "dense/bias", "dense/kernel", "learning_rate", "decay",
"dense/kernel/accumulator", "dense/bias/accumulator")
"SGD": ("dense/bias", "SGD/learning_rate", "SGD/decay", "SGD/iter",
"dense/kernel", "SGD/momentum"),
"Adagrad":
("Adagrad/iter", "dense/bias", "dense/kernel", "Adagrad/learning_rate",
"Adagrad/decay", "Adagrad/dense/kernel/accumulator",
"Adagrad/dense/bias/accumulator")
}
@ -81,8 +83,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
use_callable_loss=[True, False]))
def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss):
with distribution.scope():
optimizer = optimizer_fn()
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
optimizer, use_bias=True, use_callable_loss=use_callable_loss)
def step_fn(ctx, inputs):
del ctx # Unused
@ -123,8 +126,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
def testTrainNetworkByCallForEachReplica(self, distribution, optimizer_fn,
use_callable_loss):
with distribution.scope():
optimizer = optimizer_fn()
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
optimizer, use_bias=True, use_callable_loss=use_callable_loss)
iterator = self._get_iterator(distribution, dataset_fn)
@ -171,11 +175,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
# `distribution.scope`.
with variable_scope.variable_creator_scope(
appending_creator), distribution.scope():
optimizer = optimizer_fn()
model_fn, dataset_fn, _ = minimize_loss_example(
optimizer_fn,
use_bias=True,
use_callable_loss=True,
create_optimizer_inside_model_fn=True)
optimizer, use_bias=True, use_callable_loss=True)
def step_fn(ctx, inputs):
del ctx # Unused
@ -195,8 +197,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
self.evaluate(variables_lib.global_variables_initializer())
run_step()
def get_expected_variables(optimizer_fn, num_parameter_devices):
optimizer = optimizer_fn()
def get_expected_variables(num_parameter_devices):
name = optimizer._name
if isinstance(optimizer, optimizer_v2.OptimizerV2):
@ -213,8 +214,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
return set([v + ":0" for v in variables])
self.assertEqual(
get_expected_variables(optimizer_fn,
len(distribution.extended.parameter_devices)),
get_expected_variables(len(distribution.extended.parameter_devices)),
set(created_variables))
@combinations.generate(

View File

@ -50,27 +50,15 @@ def single_loss_example(optimizer_fn, distribution, use_bias=False,
return single_loss_step, layer
def minimize_loss_example(optimizer_fn,
use_bias=False,
use_callable_loss=True,
create_optimizer_inside_model_fn=False):
def minimize_loss_example(optimizer, use_bias=False, use_callable_loss=True):
"""Example of non-distribution-aware legacy code."""
if isinstance(optimizer_fn(), optimizer_v2.OptimizerV2):
# Keras optimizer v2 always uses callable loss
assert use_callable_loss
def dataset_fn():
dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
# TODO(isaprykin): batch with drop_remainder causes shapes to be
# fully defined for TPU. Remove this when XLA supports dynamic shapes.
return dataset.batch(1, drop_remainder=True)
# An Optimizer instance is created either outside or inside model_fn.
outer_optimizer = None
if not create_optimizer_inside_model_fn:
outer_optimizer = optimizer_fn()
layer = core.Dense(1, use_bias=use_bias)
def model_fn(x):
@ -80,12 +68,9 @@ def minimize_loss_example(optimizer_fn,
y = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
return y * y
optimizer = outer_optimizer or optimizer_fn()
if isinstance(optimizer, optimizer_v2.OptimizerV2):
return optimizer.minimize(loss_fn, lambda: layer.trainable_variables)
if use_callable_loss:
elif use_callable_loss:
return optimizer.minimize(loss_fn)
else:
return optimizer.minimize(loss_fn())

View File

@ -2214,10 +2214,9 @@ class Model(network.Network):
with K.get_graph().as_default():
with K.name_scope('training'):
with K.name_scope(self.optimizer.__class__.__name__):
# Training updates
updates = self.optimizer.get_updates(
params=self._collected_trainable_weights, loss=self.total_loss)
# Training updates
updates = self.optimizer.get_updates(
params=self._collected_trainable_weights, loss=self.total_loss)
# Unconditional updates
updates += self.get_updates_for(None)
# Conditional updates relevant to this model

View File

@ -36,6 +36,7 @@ from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
@ -257,7 +258,10 @@ class OptimizerV2(trackable.Trackable):
raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k]))
self._use_locking = True
self._name = name
self._init_set_name(name)
# in graph mode, name_scope performs uniquification, so keep scope_context.
with backend.name_scope(self._name) as name_scope:
self._scope_ctx = name_scope
self._hyper = {}
# dict: {variable name : {slot name : variable}}
self._slots = {}
@ -349,15 +353,16 @@ class OptimizerV2(trackable.Trackable):
if callable(var_list):
var_list = var_list()
var_list = nest.flatten(var_list)
grads = tape.gradient(loss_value, var_list, grad_loss)
with backend.name_scope(self._scope_ctx):
grads = tape.gradient(loss_value, var_list, grad_loss)
if hasattr(self, "clipnorm"):
grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
if hasattr(self, "clipvalue"):
grads = [
clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
for g in grads
]
if hasattr(self, "clipnorm"):
grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
if hasattr(self, "clipvalue"):
grads = [
clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
for g in grads
]
grads_and_vars = list(zip(grads, var_list))
self._assert_valid_dtypes([
@ -382,22 +387,22 @@ class OptimizerV2(trackable.Trackable):
function not implemented).
"""
params = nest.flatten(params)
with backend.get_graph().as_default():
with backend.get_graph().as_default(), backend.name_scope(self._scope_ctx):
grads = gradients.gradients(loss, params)
for grad, param in zip(grads, params):
if grad is None:
raise ValueError("Variable {} has `None` for gradient. "
"Please make sure that all of your ops have a "
"gradient defined (i.e. are differentiable). "
"Common ops without gradient: "
"K.argmax, K.round, K.eval.".format(param))
if hasattr(self, "clipnorm"):
grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
if hasattr(self, "clipvalue"):
grads = [
clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
for g in grads
]
for grad, param in zip(grads, params):
if grad is None:
raise ValueError("Variable {} has `None` for gradient. "
"Please make sure that all of your ops have a "
"gradient defined (i.e. are differentiable). "
"Common ops without gradient: "
"K.argmax, K.round, K.eval.".format(param))
if hasattr(self, "clipnorm"):
grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
if hasattr(self, "clipvalue"):
grads = [
clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
for g in grads
]
return grads
def apply_gradients(self, grads_and_vars, name=None):
@ -422,16 +427,19 @@ class OptimizerV2(trackable.Trackable):
grads_and_vars = _filter_grads(grads_and_vars)
var_list = [v for (_, v) in grads_and_vars]
# Create iteration if necessary.
with ops.init_scope():
_ = self.iterations
self._create_hypers()
self._create_slots(var_list)
with backend.name_scope(self._scope_ctx):
# Create iteration if necessary.
with ops.init_scope():
_ = self.iterations
self._create_hypers()
self._create_slots(var_list)
self._prepare(var_list)
self._prepare(var_list)
return distribute_ctx.get_replica_context().merge_call(
self._distributed_apply, args=(grads_and_vars,), kwargs={"name": name})
return distribute_ctx.get_replica_context().merge_call(
self._distributed_apply,
args=(grads_and_vars,),
kwargs={"name": name})
def _distributed_apply(self, distribution, grads_and_vars, name):
"""`apply_gradients` using a `DistributionStrategy`."""
@ -764,6 +772,14 @@ class OptimizerV2(trackable.Trackable):
return variable
def _init_set_name(self, name, zero_based=True):
if not name:
self._name = backend.unique_object_name(
generic_utils.to_snake_case(self.__class__.__name__),
zero_based=zero_based)
else:
self._name = name
def _assert_valid_dtypes(self, tensors):
"""Asserts tensors are all valid types (see `_valid_dtypes`).

View File

@ -336,12 +336,10 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
self.assertEqual(
"my_model/dense/kernel",
named_variables["model/_named_dense/kernel" + suffix].full_name)
self.assertEqual(
"beta_1",
named_variables["optimizer/beta_1" + suffix].full_name)
self.assertEqual(
"beta_2",
named_variables["optimizer/beta_2" + suffix].full_name)
self.assertEqual("Adam/beta_1",
named_variables["optimizer/beta_1" + suffix].full_name)
self.assertEqual("Adam/beta_2",
named_variables["optimizer/beta_2" + suffix].full_name)
# Spot check the generated protocol buffers.
self.assertEqual("optimizer",
serialized_graph.nodes[0].children[1].local_name)
@ -350,7 +348,7 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
children = [node.local_name for node in optimizer_node.children]
six.assertCountEqual(
self,
# Non-slot dependencies
# hyper variable dependencies
["beta_1", "beta_2", "iter", "decay", "learning_rate"],
children)
serialized_slot_keys = []