Move the optimizer name scope from model.training to optimizer
PiperOrigin-RevId: 249328139
commit 611db3490b
parent 438ff85035
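Overview of the change: `Model._make_train_function` used to wrap `optimizer.get_updates()` in a name scope derived from the optimizer's class, so the prefix on optimizer-created ops and variables depended on the call site. After this change, `OptimizerV2` resolves its own name once at construction time and re-enters that scope itself in `_compute_gradients`, `get_gradients`, and `apply_gradients`. A minimal standalone sketch of that ownership move (illustrative only; `ScopedUpdater` is a made-up stand-in, not the Keras optimizer):

import tensorflow as tf

class ScopedUpdater(object):
  """Toy stand-in for an object that prefixes everything it creates."""

  def __init__(self, name):
    self._name = name

  def make_update(self, value):
    # New behaviour: the object opens its own scope, so the prefix no
    # longer depends on who calls it.
    with tf.name_scope(self._name):
      return tf.constant(value, name="update")

graph = tf.Graph()
with graph.as_default():
  # Old behaviour (sketch): the caller decided the prefix.
  with tf.name_scope("training"):
    old_style = tf.constant(0.0, name="update")
  # New behaviour (sketch): the object carries its prefix with it.
  new_style = ScopedUpdater("SGD").make_update(0.0)

print(old_style.name)  # training/update:0
print(new_style.name)  # SGD/update:0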
@@ -72,8 +72,9 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
   def testTrainNetwork(self, distribution, optimizer_fn,
                        use_callable_loss=True):
     with distribution.scope():
+      optimizer = optimizer_fn()
       model_fn, dataset_fn, layer = minimize_loss_example(
-          optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
+          optimizer, use_bias=True, use_callable_loss=use_callable_loss)
       iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn())
 
       def run_step():
@@ -47,10 +47,12 @@ VAR_MAP_V1 = {
 }
 
 VAR_MAP_V2 = {
-    "SGD": ("dense/bias", "learning_rate", "decay", "iter", "dense/kernel",
-            "momentum"),
-    "Adagrad": ("iter", "dense/bias", "dense/kernel", "learning_rate", "decay",
-                "dense/kernel/accumulator", "dense/bias/accumulator")
+    "SGD": ("dense/bias", "SGD/learning_rate", "SGD/decay", "SGD/iter",
+            "dense/kernel", "SGD/momentum"),
+    "Adagrad":
+        ("Adagrad/iter", "dense/bias", "dense/kernel", "Adagrad/learning_rate",
+         "Adagrad/decay", "Adagrad/dense/kernel/accumulator",
+         "Adagrad/dense/bias/accumulator")
 }
 
 
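With hyperparameters now created inside the optimizer's own name scope, the expected graph names above gain the optimizer name as a prefix (`SGD/learning_rate`, `Adagrad/iter`, ...), while model-owned variables (`dense/kernel`, `dense/bias`) keep their names. A small illustrative sketch of the kind of check this map drives, mirroring the test's `get_expected_variables` helper (the values are copied from this diff; the snippet itself is not the test code):

# ":0" is the output suffix carried by graph-mode variable names.
VAR_MAP_V2 = {
    "SGD": ("dense/bias", "SGD/learning_rate", "SGD/decay", "SGD/iter",
            "dense/kernel", "SGD/momentum"),
}

def expected_variable_names(optimizer_name):
  return set(v + ":0" for v in VAR_MAP_V2[optimizer_name])

print(sorted(expected_variable_names("SGD")))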
@@ -81,8 +83,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
           use_callable_loss=[True, False]))
   def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss):
     with distribution.scope():
+      optimizer = optimizer_fn()
       model_fn, dataset_fn, layer = minimize_loss_example(
-          optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
+          optimizer, use_bias=True, use_callable_loss=use_callable_loss)
 
       def step_fn(ctx, inputs):
         del ctx  # Unused
@@ -123,8 +126,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
   def testTrainNetworkByCallForEachReplica(self, distribution, optimizer_fn,
                                            use_callable_loss):
     with distribution.scope():
+      optimizer = optimizer_fn()
       model_fn, dataset_fn, layer = minimize_loss_example(
-          optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
+          optimizer, use_bias=True, use_callable_loss=use_callable_loss)
 
       iterator = self._get_iterator(distribution, dataset_fn)
 
@@ -171,11 +175,9 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
       # `distribution.scope`.
       with variable_scope.variable_creator_scope(
           appending_creator), distribution.scope():
+        optimizer = optimizer_fn()
         model_fn, dataset_fn, _ = minimize_loss_example(
-            optimizer_fn,
-            use_bias=True,
-            use_callable_loss=True,
-            create_optimizer_inside_model_fn=True)
+            optimizer, use_bias=True, use_callable_loss=True)
 
         def step_fn(ctx, inputs):
           del ctx  # Unused
@@ -195,8 +197,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
       self.evaluate(variables_lib.global_variables_initializer())
       run_step()
 
-      def get_expected_variables(optimizer_fn, num_parameter_devices):
-        optimizer = optimizer_fn()
+      def get_expected_variables(num_parameter_devices):
         name = optimizer._name
 
         if isinstance(optimizer, optimizer_v2.OptimizerV2):
@@ -213,8 +214,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
         return set([v + ":0" for v in variables])
 
       self.assertEqual(
-          get_expected_variables(optimizer_fn,
-                                 len(distribution.extended.parameter_devices)),
+          get_expected_variables(len(distribution.extended.parameter_devices)),
           set(created_variables))
 
   @combinations.generate(
@@ -50,27 +50,15 @@ def single_loss_example(optimizer_fn, distribution, use_bias=False,
   return single_loss_step, layer
 
 
-def minimize_loss_example(optimizer_fn,
-                          use_bias=False,
-                          use_callable_loss=True,
-                          create_optimizer_inside_model_fn=False):
+def minimize_loss_example(optimizer, use_bias=False, use_callable_loss=True):
   """Example of non-distribution-aware legacy code."""
 
-  if isinstance(optimizer_fn(), optimizer_v2.OptimizerV2):
-    # Keras optimizer v2 always uses callable loss
-    assert use_callable_loss
-
   def dataset_fn():
     dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
     # TODO(isaprykin): batch with drop_remainder causes shapes to be
     # fully defined for TPU. Remove this when XLA supports dynamic shapes.
     return dataset.batch(1, drop_remainder=True)
 
-  # An Optimizer instance is created either outside or inside model_fn.
-  outer_optimizer = None
-  if not create_optimizer_inside_model_fn:
-    outer_optimizer = optimizer_fn()
-
   layer = core.Dense(1, use_bias=use_bias)
 
   def model_fn(x):
@@ -80,12 +68,9 @@ def minimize_loss_example(optimizer_fn,
       y = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
       return y * y
 
-    optimizer = outer_optimizer or optimizer_fn()
-
     if isinstance(optimizer, optimizer_v2.OptimizerV2):
       return optimizer.minimize(loss_fn, lambda: layer.trainable_variables)
-
-    if use_callable_loss:
+    elif use_callable_loss:
       return optimizer.minimize(loss_fn)
     else:
       return optimizer.minimize(loss_fn())
@@ -2214,10 +2214,9 @@ class Model(network.Network):
 
       with K.get_graph().as_default():
         with K.name_scope('training'):
-          with K.name_scope(self.optimizer.__class__.__name__):
-            # Training updates
-            updates = self.optimizer.get_updates(
-                params=self._collected_trainable_weights, loss=self.total_loss)
+          # Training updates
+          updates = self.optimizer.get_updates(
+              params=self._collected_trainable_weights, loss=self.total_loss)
           # Unconditional updates
           updates += self.get_updates_for(None)
           # Conditional updates relevant to this model
@@ -36,6 +36,7 @@ from tensorflow.python.keras import backend
 from tensorflow.python.keras import initializers
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
+from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
@@ -257,7 +258,10 @@ class OptimizerV2(trackable.Trackable):
         raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k]))
 
     self._use_locking = True
-    self._name = name
+    self._init_set_name(name)
+    # in graph mode, name_scope performs uniquification, so keep scope_context.
+    with backend.name_scope(self._name) as name_scope:
+      self._scope_ctx = name_scope
     self._hyper = {}
     # dict: {variable name : {slot name : variable}}
     self._slots = {}
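The constructor now resolves the optimizer's name once (uniquified by `name_scope` in graph mode) and keeps the resulting scope string in `self._scope_ctx`, so later entry points can re-enter exactly the same scope instead of spawning `SGD_1`, `SGD_2`, ... A standalone sketch of that capture-and-reenter pattern using plain `tf.name_scope` on a scratch graph (illustrative, not the optimizer code):

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  # Capture the uniquified scope string once; it comes back ending in "/".
  with tf.name_scope("SGD") as scope_ctx:
    pass  # scope_ctx == "SGD/"

  # Re-entering with the captured string reuses the same absolute scope
  # rather than creating a fresh "SGD_1" scope.
  with tf.name_scope(scope_ctx):
    lr = tf.constant(0.01, name="learning_rate")
  with tf.name_scope(scope_ctx):
    decay = tf.constant(0.0, name="decay")

print(lr.name)     # SGD/learning_rate:0
print(decay.name)  # SGD/decay:0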
@@ -349,15 +353,16 @@ class OptimizerV2(trackable.Trackable):
     if callable(var_list):
       var_list = var_list()
     var_list = nest.flatten(var_list)
-    grads = tape.gradient(loss_value, var_list, grad_loss)
-
-    if hasattr(self, "clipnorm"):
-      grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
-    if hasattr(self, "clipvalue"):
-      grads = [
-          clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
-          for g in grads
-      ]
-
-    grads_and_vars = list(zip(grads, var_list))
-    self._assert_valid_dtypes([
+    with backend.name_scope(self._scope_ctx):
+      grads = tape.gradient(loss_value, var_list, grad_loss)
+
+      if hasattr(self, "clipnorm"):
+        grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
+      if hasattr(self, "clipvalue"):
+        grads = [
+            clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
+            for g in grads
+        ]
+
+      grads_and_vars = list(zip(grads, var_list))
+      self._assert_valid_dtypes([
@@ -382,22 +387,22 @@
         function not implemented).
     """
     params = nest.flatten(params)
-    with backend.get_graph().as_default():
+    with backend.get_graph().as_default(), backend.name_scope(self._scope_ctx):
       grads = gradients.gradients(loss, params)
       for grad, param in zip(grads, params):
         if grad is None:
           raise ValueError("Variable {} has `None` for gradient. "
                            "Please make sure that all of your ops have a "
                            "gradient defined (i.e. are differentiable). "
                            "Common ops without gradient: "
                            "K.argmax, K.round, K.eval.".format(param))
       if hasattr(self, "clipnorm"):
         grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
       if hasattr(self, "clipvalue"):
         grads = [
             clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
             for g in grads
         ]
     return grads
 
   def apply_gradients(self, grads_and_vars, name=None):
@@ -422,16 +427,19 @@
     grads_and_vars = _filter_grads(grads_and_vars)
     var_list = [v for (_, v) in grads_and_vars]
 
-    # Create iteration if necessary.
-    with ops.init_scope():
-      _ = self.iterations
-      self._create_hypers()
-      self._create_slots(var_list)
+    with backend.name_scope(self._scope_ctx):
+      # Create iteration if necessary.
+      with ops.init_scope():
+        _ = self.iterations
+        self._create_hypers()
+        self._create_slots(var_list)
 
-    self._prepare(var_list)
+      self._prepare(var_list)
 
-    return distribute_ctx.get_replica_context().merge_call(
-        self._distributed_apply, args=(grads_and_vars,), kwargs={"name": name})
+      return distribute_ctx.get_replica_context().merge_call(
+          self._distributed_apply,
+          args=(grads_and_vars,),
+          kwargs={"name": name})
 
   def _distributed_apply(self, distribution, grads_and_vars, name):
     """`apply_gradients` using a `DistributionStrategy`."""
@@ -764,6 +772,14 @@ class OptimizerV2(trackable.Trackable):
 
     return variable
 
+  def _init_set_name(self, name, zero_based=True):
+    if not name:
+      self._name = backend.unique_object_name(
+          generic_utils.to_snake_case(self.__class__.__name__),
+          zero_based=zero_based)
+    else:
+      self._name = name
+
   def _assert_valid_dtypes(self, tensors):
     """Asserts tensors are all valid types (see `_valid_dtypes`).
 
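`_init_set_name` mirrors the Keras layer naming convention: when no explicit name is passed, the optimizer falls back to a uniquified snake_case version of its class name via `generic_utils.to_snake_case` and `backend.unique_object_name`. A rough standalone approximation of that fallback (this is not the Keras implementation, only the intent):

import collections
import re

_OBJECT_COUNTS = collections.defaultdict(int)

def to_snake_case(class_name):
  # "MyOptimizer" -> "my_optimizer"; a simplification of the Keras helper.
  s = re.sub(r"(.)([A-Z][a-z0-9]+)", r"\1_\2", class_name)
  return re.sub(r"([a-z])([A-Z])", r"\1_\2", s).lower()

def unique_object_name(base, zero_based=True):
  # First use gets the bare name, later uses get a numeric suffix.
  count = _OBJECT_COUNTS[base]
  _OBJECT_COUNTS[base] += 1
  if zero_based and count == 0:
    return base
  return "{}_{}".format(base, count)

print(unique_object_name(to_snake_case("Adagrad")))  # adagrad
print(unique_object_name(to_snake_case("Adagrad")))  # adagrad_1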
@@ -336,12 +336,10 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
     self.assertEqual(
         "my_model/dense/kernel",
         named_variables["model/_named_dense/kernel" + suffix].full_name)
-    self.assertEqual(
-        "beta_1",
-        named_variables["optimizer/beta_1" + suffix].full_name)
-    self.assertEqual(
-        "beta_2",
-        named_variables["optimizer/beta_2" + suffix].full_name)
+    self.assertEqual("Adam/beta_1",
+                     named_variables["optimizer/beta_1" + suffix].full_name)
+    self.assertEqual("Adam/beta_2",
+                     named_variables["optimizer/beta_2" + suffix].full_name)
     # Spot check the generated protocol buffers.
     self.assertEqual("optimizer",
                      serialized_graph.nodes[0].children[1].local_name)
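Note what this test pins down: the checkpoint's object-graph keys (the `"optimizer/beta_1" + suffix` lookups) are unchanged; only the variables' `full_name` (their graph-level name) gains the optimizer prefix, which is why the same lookup keys still resolve. An illustrative contrast, with values taken from the assertions above:

# checkpoint key (unchanged) -> (full_name after this change, full_name before)
checkpoint_key_to_full_name = {
    "optimizer/beta_1": ("Adam/beta_1", "beta_1"),
    "optimizer/beta_2": ("Adam/beta_2", "beta_2"),
}
for key, (new_name, old_name) in sorted(checkpoint_key_to_full_name.items()):
  print("{}: {} (was {})".format(key, new_name, old_name))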
@@ -350,7 +348,7 @@ class CheckpointingTests(parameterized.TestCase, test.TestCase):
     children = [node.local_name for node in optimizer_node.children]
     six.assertCountEqual(
         self,
-        # Non-slot dependencies
+        # hyper variable dependencies
        ["beta_1", "beta_2", "iter", "decay", "learning_rate"],
        children)
     serialized_slot_keys = []