Move the optimizer name scope from model.training to optimizer
PiperOrigin-RevId: 249328139
parent 438ff85035
commit 611db3490b
@@ -72,8 +72,9 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
   def testTrainNetwork(self, distribution, optimizer_fn,
                        use_callable_loss=True):
     with distribution.scope():
+      optimizer = optimizer_fn()
       model_fn, dataset_fn, layer = minimize_loss_example(
-          optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
+          optimizer, use_bias=True, use_callable_loss=use_callable_loss)
       iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn())

       def run_step():
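A minimal sketch of the new calling convention shown in this hunk: the test now builds the optimizer instance once inside the strategy scope and passes the instance (rather than the `optimizer_fn` factory) to the helper. `MirroredStrategy` and the SGD settings below are assumptions standing in for the parameterized fixtures.

# Sketch only; the real tests obtain strategy and optimizer from fixtures.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  optimizer = tf.keras.optimizers.SGD(0.2)   # optimizer_fn() -> concrete instance
  # ...which is then handed to the helper, e.g.:
  # model_fn, dataset_fn, layer = minimize_loss_example(
  #     optimizer, use_bias=True, use_callable_loss=True)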
@@ -47,10 +47,12 @@ VAR_MAP_V1 = {
 }

 VAR_MAP_V2 = {
-    "SGD": ("dense/bias", "learning_rate", "decay", "iter", "dense/kernel",
-            "momentum"),
-    "Adagrad": ("iter", "dense/bias", "dense/kernel", "learning_rate", "decay",
-                "dense/kernel/accumulator", "dense/bias/accumulator")
+    "SGD": ("dense/bias", "SGD/learning_rate", "SGD/decay", "SGD/iter",
+            "dense/kernel", "SGD/momentum"),
+    "Adagrad":
+        ("Adagrad/iter", "dense/bias", "dense/kernel", "Adagrad/learning_rate",
+         "Adagrad/decay", "Adagrad/dense/kernel/accumulator",
+         "Adagrad/dense/bias/accumulator")
 }


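The renamed entries in VAR_MAP_V2 reflect what the optimizer itself now creates under its own name scope. A hedged sketch of inspecting those names through the public API; the exact strings (":0" suffixes, "_1" uniquification) depend on the TF version and on graph vs. eager mode.

# Sketch: look at the optimizer-owned variables after one update step.
import tensorflow as tf

layer = tf.keras.layers.Dense(1, use_bias=True)
layer.build((None, 1))
opt = tf.keras.optimizers.Adagrad(learning_rate=0.1)
grads = [tf.ones_like(v) for v in layer.trainable_variables]
opt.apply_gradients(zip(grads, layer.trainable_variables))

# After this change the names are expected to resemble "Adagrad/iter",
# "Adagrad/learning_rate", "Adagrad/decay" and "Adagrad/.../accumulator".
print([v.name for v in opt.variables()])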
@@ -81,8 +83,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
           use_callable_loss=[True, False]))
   def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss):
     with distribution.scope():
+      optimizer = optimizer_fn()
       model_fn, dataset_fn, layer = minimize_loss_example(
-          optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
+          optimizer, use_bias=True, use_callable_loss=use_callable_loss)

       def step_fn(ctx, inputs):
         del ctx  # Unused
@@ -123,8 +126,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
   def testTrainNetworkByCallForEachReplica(self, distribution, optimizer_fn,
                                            use_callable_loss):
     with distribution.scope():
+      optimizer = optimizer_fn()
       model_fn, dataset_fn, layer = minimize_loss_example(
-          optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
+          optimizer, use_bias=True, use_callable_loss=use_callable_loss)

       iterator = self._get_iterator(distribution, dataset_fn)

@@ -171,11 +175,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     # `distribution.scope`.
     with variable_scope.variable_creator_scope(
         appending_creator), distribution.scope():
+      optimizer = optimizer_fn()
       model_fn, dataset_fn, _ = minimize_loss_example(
-          optimizer_fn,
-          use_bias=True,
-          use_callable_loss=True,
-          create_optimizer_inside_model_fn=True)
+          optimizer, use_bias=True, use_callable_loss=True)

       def step_fn(ctx, inputs):
         del ctx  # Unused
@@ -195,8 +197,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
       self.evaluate(variables_lib.global_variables_initializer())
       run_step()

-    def get_expected_variables(optimizer_fn, num_parameter_devices):
-      optimizer = optimizer_fn()
+    def get_expected_variables(num_parameter_devices):
       name = optimizer._name

       if isinstance(optimizer, optimizer_v2.OptimizerV2):
@@ -213,8 +214,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
       return set([v + ":0" for v in variables])

     self.assertEqual(
-        get_expected_variables(optimizer_fn,
-                               len(distribution.extended.parameter_devices)),
+        get_expected_variables(len(distribution.extended.parameter_devices)),
         set(created_variables))

   @combinations.generate(
@@ -50,27 +50,15 @@ def single_loss_example(optimizer_fn, distribution, use_bias=False,
   return single_loss_step, layer


-def minimize_loss_example(optimizer_fn,
-                          use_bias=False,
-                          use_callable_loss=True,
-                          create_optimizer_inside_model_fn=False):
+def minimize_loss_example(optimizer, use_bias=False, use_callable_loss=True):
   """Example of non-distribution-aware legacy code."""

-  if isinstance(optimizer_fn(), optimizer_v2.OptimizerV2):
-    # Keras optimizer v2 always uses callable loss
-    assert use_callable_loss
-
   def dataset_fn():
     dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
     # TODO(isaprykin): batch with drop_remainder causes shapes to be
     # fully defined for TPU. Remove this when XLA supports dynamic shapes.
     return dataset.batch(1, drop_remainder=True)

-  # An Optimizer instance is created either outside or inside model_fn.
-  outer_optimizer = None
-  if not create_optimizer_inside_model_fn:
-    outer_optimizer = optimizer_fn()
-
   layer = core.Dense(1, use_bias=use_bias)

   def model_fn(x):
@@ -80,12 +68,9 @@ def minimize_loss_example(optimizer_fn,
       y = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
       return y * y

-    optimizer = outer_optimizer or optimizer_fn()
-
     if isinstance(optimizer, optimizer_v2.OptimizerV2):
       return optimizer.minimize(loss_fn, lambda: layer.trainable_variables)
-
-    if use_callable_loss:
+    elif use_callable_loss:
       return optimizer.minimize(loss_fn)
     else:
       return optimizer.minimize(loss_fn())
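The v2 branch kept above relies on `OptimizerV2.minimize` taking a callable loss plus a callable var_list, which is also why the helper no longer needs the `create_optimizer_inside_model_fn` path. A standalone sketch of that pattern; the layer shape and input values are assumptions.

# Standalone sketch of the minimize() call used by model_fn above.
import tensorflow as tf

layer = tf.keras.layers.Dense(1, use_bias=True)
opt = tf.keras.optimizers.SGD(0.1)
x = tf.ones([1, 1])

def loss_fn():
  y = tf.reshape(layer(x), []) - tf.constant(1.)
  return y * y

# The loss is a callable; var_list may also be a callable, resolved after
# the loss has run and the layer's variables exist.
opt.minimize(loss_fn, lambda: layer.trainable_variables)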
@@ -2214,10 +2214,9 @@ class Model(network.Network):

       with K.get_graph().as_default():
         with K.name_scope('training'):
-          with K.name_scope(self.optimizer.__class__.__name__):
-            # Training updates
-            updates = self.optimizer.get_updates(
-                params=self._collected_trainable_weights, loss=self.total_loss)
+          # Training updates
+          updates = self.optimizer.get_updates(
+              params=self._collected_trainable_weights, loss=self.total_loss)
           # Unconditional updates
           updates += self.get_updates_for(None)
           # Conditional updates relevant to this model
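The net effect in the Keras training graph is that update ops still live under the 'training' scope, but the inner optimizer-named scope is now opened by the optimizer itself rather than by this method. A small illustrative sketch of the resulting nesting; 'SGD' stands in for whatever scope the optimizer opens, and the `backend` import mirrors the `K` used in the hunk above.

# Illustrative only: shows the scope nesting, not the optimizer internals.
import tensorflow as tf
from tensorflow.python.keras import backend as K  # same module as K above

with K.get_graph().as_default():
  with K.name_scope('training'):
    with K.name_scope('SGD'):  # OptimizerV2 now opens a scope like this itself
      update = tf.constant(0.0, name='update')

print(update.name)  # expected to resemble: training/SGD/update:0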
@@ -36,6 +36,7 @@ from tensorflow.python.keras import backend
 from tensorflow.python.keras import initializers
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
+from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
@@ -257,7 +258,10 @@ class OptimizerV2(trackable.Trackable):
         raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k]))

     self._use_locking = True
-    self._name = name
+    self._init_set_name(name)
+    # in graph mode, name_scope performs uniquification, so keep scope_context.
+    with backend.name_scope(self._name) as name_scope:
+      self._scope_ctx = name_scope
     self._hyper = {}
     # dict: {variable name : {slot name : variable}}
     self._slots = {}
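The comment about graph-mode uniquification is why `__init__` captures the scope string once: entering a name scope with a bare name can produce a new, suffixed scope on each use, while re-entering the captured string (which ends in "/") reuses the original scope exactly. A standalone sketch of that behavior; the names are illustrative.

# Why the captured scope string is kept instead of re-opening the name later.
import tensorflow as tf

g = tf.Graph()
with g.as_default():
  with tf.compat.v1.name_scope("SGD") as scope_ctx:  # scope_ctx == "SGD/"
    pass
  with tf.compat.v1.name_scope(scope_ctx):           # re-enters the exact scope
    a = tf.constant(1.0, name="iter")
  with tf.compat.v1.name_scope("SGD"):               # a bare name uniquifies
    b = tf.constant(1.0, name="iter")

print(a.name, b.name)  # expected: SGD/iter:0 SGD_1/iter:0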
@@ -349,15 +353,16 @@ class OptimizerV2(trackable.Trackable):
       if callable(var_list):
         var_list = var_list()
       var_list = nest.flatten(var_list)
-      grads = tape.gradient(loss_value, var_list, grad_loss)
+      with backend.name_scope(self._scope_ctx):
+        grads = tape.gradient(loss_value, var_list, grad_loss)

-    if hasattr(self, "clipnorm"):
-      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
-    if hasattr(self, "clipvalue"):
-      grads = [
-          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
-          for g in grads
-      ]
+        if hasattr(self, "clipnorm"):
+          grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
+        if hasattr(self, "clipvalue"):
+          grads = [
+              clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
+              for g in grads
+          ]

     grads_and_vars = list(zip(grads, var_list))
     self._assert_valid_dtypes([
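The clipping code that moves under the new name scope here is driven by the optional `clipnorm`/`clipvalue` attributes, which are plain constructor kwargs. A brief sketch with arbitrary values; the arithmetic makes the effect easy to check.

# clipnorm caps the gradient before the update is applied.
import tensorflow as tf

v = tf.Variable([2.0])
opt = tf.keras.optimizers.SGD(learning_rate=1.0, clipnorm=0.5)
opt.minimize(lambda: tf.reduce_sum(v * v), [v])

# The raw gradient 2*v = 4.0 is clipped to norm 0.5, so v becomes 2.0 - 0.5 = 1.5.
print(v.numpy())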
@@ -382,22 +387,22 @@
           function not implemented).
     """
     params = nest.flatten(params)
-    with backend.get_graph().as_default():
+    with backend.get_graph().as_default(), backend.name_scope(self._scope_ctx):
       grads = gradients.gradients(loss, params)
-    for grad, param in zip(grads, params):
-      if grad is None:
-        raise ValueError("Variable {} has `None` for gradient. "
-                         "Please make sure that all of your ops have a "
-                         "gradient defined (i.e. are differentiable). "
-                         "Common ops without gradient: "
-                         "K.argmax, K.round, K.eval.".format(param))
-    if hasattr(self, "clipnorm"):
-      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
-    if hasattr(self, "clipvalue"):
-      grads = [
-          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
-          for g in grads
-      ]
+      for grad, param in zip(grads, params):
+        if grad is None:
+          raise ValueError("Variable {} has `None` for gradient. "
+                           "Please make sure that all of your ops have a "
+                           "gradient defined (i.e. are differentiable). "
+                           "Common ops without gradient: "
+                           "K.argmax, K.round, K.eval.".format(param))
+      if hasattr(self, "clipnorm"):
+        grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
+      if hasattr(self, "clipvalue"):
+        grads = [
+            clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
+            for g in grads
+        ]
     return grads

   def apply_gradients(self, grads_and_vars, name=None):
@@ -422,16 +427,19 @@
     grads_and_vars = _filter_grads(grads_and_vars)
     var_list = [v for (_, v) in grads_and_vars]

-    # Create iteration if necessary.
-    with ops.init_scope():
-      _ = self.iterations
-      self._create_hypers()
-      self._create_slots(var_list)
+    with backend.name_scope(self._scope_ctx):
+      # Create iteration if necessary.
+      with ops.init_scope():
+        _ = self.iterations
+        self._create_hypers()
+        self._create_slots(var_list)

-    self._prepare(var_list)
+      self._prepare(var_list)

-    return distribute_ctx.get_replica_context().merge_call(
-        self._distributed_apply, args=(grads_and_vars,), kwargs={"name": name})
+      return distribute_ctx.get_replica_context().merge_call(
+          self._distributed_apply,
+          args=(grads_and_vars,),
+          kwargs={"name": name})

   def _distributed_apply(self, distribution, grads_and_vars, name):
     """`apply_gradients` using a `DistributionStrategy`."""
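Everything the wrapped block creates on the first `apply_gradients` call (the `iterations` counter, hyperparameter variables, and slots) is now made under the optimizer's own scope. A minimal sketch of that first call; the gradient and hyperparameter values are assumptions.

# First apply_gradients() call creates iter, hyper, and slot variables.
import tensorflow as tf

v = tf.Variable([1.0])
opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
opt.apply_gradients([(tf.constant([0.5]), v)])

print(opt.iterations.numpy())                 # 1 after the first step
print(opt.get_slot(v, "momentum").numpy())    # the momentum slot for v
print([var.name for var in opt.variables()])  # scoped under the optimizer name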
@@ -764,6 +772,14 @@ class OptimizerV2(trackable.Trackable):

     return variable

+  def _init_set_name(self, name, zero_based=True):
+    if not name:
+      self._name = backend.unique_object_name(
+          generic_utils.to_snake_case(self.__class__.__name__),
+          zero_based=zero_based)
+    else:
+      self._name = name
+
   def _assert_valid_dtypes(self, tensors):
     """Asserts tensors are all valid types (see `_valid_dtypes`).

@@ -336,12 +336,10 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
     self.assertEqual(
         "my_model/dense/kernel",
         named_variables["model/_named_dense/kernel" + suffix].full_name)
-    self.assertEqual(
-        "beta_1",
-        named_variables["optimizer/beta_1" + suffix].full_name)
-    self.assertEqual(
-        "beta_2",
-        named_variables["optimizer/beta_2" + suffix].full_name)
+    self.assertEqual("Adam/beta_1",
+                     named_variables["optimizer/beta_1" + suffix].full_name)
+    self.assertEqual("Adam/beta_2",
+                     named_variables["optimizer/beta_2" + suffix].full_name)
     # Spot check the generated protocol buffers.
     self.assertEqual("optimizer",
                      serialized_graph.nodes[0].children[1].local_name)
@@ -350,7 +348,7 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
     children = [node.local_name for node in optimizer_node.children]
     six.assertCountEqual(
         self,
-        # Non-slot dependencies
+        # hyper variable dependencies
        ["beta_1", "beta_2", "iter", "decay", "learning_rate"],
        children)
     serialized_slot_keys = []