Create Optimizer utilities file and move gradient aggregation and filtering functions to this file.

PiperOrigin-RevId: 323831904
Change-Id: I59ff3f42771eba8c7fb27ea68c53296bd2425011
parent 24b36f123d
commit aa462dd250
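In short, the commit turns two pieces of private optimizer machinery into public helpers in a new utils module. A sketch of the caller-side mapping; the aggregate_and_filter wrapper is illustrative, only the optimizer_utils names come from the diff below:

    # Illustrative sketch of the new public surface, assuming a list of
    # (gradient, variable) pairs produced elsewhere (e.g. by a GradientTape).
    from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils

    def aggregate_and_filter(grads_and_vars):
      # Previously the private _filter_grads in optimizer_v2.py; drops
      # (None, var) pairs with a warning, raising if nothing is left.
      grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
      # Previously inlined in OptimizerV2._aggregate_gradients; sums the
      # surviving gradients across replicas.
      return optimizer_utils.all_reduce_sum_gradients(grads_and_vars)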
tensorflow/python/keras/optimizer_v2/BUILD

@@ -28,6 +28,7 @@ py_library(
         "nadam.py",
         "optimizer_v2.py",
         "rmsprop.py",
+        "utils.py",
     ],
     srcs_version = "PY2AND3",
     deps = [
tensorflow/python/keras/optimizer_v2/optimizer_v2.py

@@ -27,7 +27,6 @@ import six
 
 from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
 from tensorflow.python.distribute import parameter_server_strategy
-from tensorflow.python.distribute import reduce_util as ds_reduce_util
 from tensorflow.python.distribute import values as ds_values
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -38,6 +37,7 @@ from tensorflow.python.keras import backend
 from tensorflow.python.keras import initializers
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
+from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils
 from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
@@ -47,7 +47,6 @@ from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import gradients
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables as tf_variables
-from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import revived_types
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.training.tracking import tracking
@@ -510,7 +509,7 @@ class OptimizerV2(trackable.Trackable):
       TypeError: If `grads_and_vars` is malformed.
       ValueError: If none of the variables have gradients.
     """
-    grads_and_vars = _filter_grads(grads_and_vars)
+    grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
     var_list = [v for (_, v) in grads_and_vars]
 
     with backend.name_scope(self._name):
@@ -550,7 +549,10 @@
           })
 
   def _aggregate_gradients(self, grads_and_vars):
-    """Returns all-reduced gradients.
+    """Returns aggregated gradients.
+
+    This method must be preserved to maintain backwards compatibility with
+    Horovod aggregation.
 
     Args:
       grads_and_vars: List of (gradient, variable) pairs.
@@ -558,32 +560,7 @@
     Returns:
       A list of all-reduced gradients.
     """
-    grads_and_vars = list(grads_and_vars)
-    filtered_grads_and_vars = _filter_grads(grads_and_vars)
-    def all_reduce_fn(distribution, grads_and_vars):
-      return distribution.extended.batch_reduce_to(
-          ds_reduce_util.ReduceOp.SUM, grads_and_vars)
-    # We switch to a cross-replica context since there is a bug which causes
-    # IndexedSlices to be converted to dense tensors when all-reduced in a
-    # replica context.
-    # TODO(b/150507409): Do not switch to a cross-replica context once the bug
-    # is fixed.
-    if filtered_grads_and_vars:
-      reduced = distribute_ctx.get_replica_context().merge_call(
-          all_reduce_fn, args=(filtered_grads_and_vars,))
-    else:
-      reduced = []
-    # Copy 'reduced' but add None gradients back in
-    reduced_with_nones = []
-    reduced_pos = 0
-    for g, _ in grads_and_vars:
-      if g is None:
-        reduced_with_nones.append(None)
-      else:
-        reduced_with_nones.append(reduced[reduced_pos])
-        reduced_pos += 1
-    assert reduced_pos == len(reduced), "Failed to add all gradients"
-    return reduced_with_nones
+    return optimizer_utils.all_reduce_sum_gradients(grads_and_vars)
 
   def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
     """`apply_gradients` using a `DistributionStrategy`."""
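The Horovod note added to the docstring above is why _aggregate_gradients survives as a one-line wrapper rather than being deleted outright: external optimizers override it to substitute their own aggregation. A minimal sketch of that override pattern, with HorovodLikeOptimizer and my_allreduce as illustrative stand-ins (not this commit's code, nor Horovod's actual API):

    from tensorflow.python.keras.optimizer_v2 import optimizer_v2

    class HorovodLikeOptimizer(optimizer_v2.OptimizerV2):
      """Illustrative subclass that relies on the preserved hook."""

      def _aggregate_gradients(self, grads_and_vars):
        # Swap the default sum-all-reduce for custom aggregation. The
        # contract stays the same: return one gradient per input pair,
        # keeping None for variables that had no gradient.
        return [my_allreduce(g) if g is not None else None
                for g, _ in grads_and_vars]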
@@ -1259,29 +1236,6 @@
     yield
 
 
-def _filter_grads(grads_and_vars):
-  """Filter out iterable with grad equal to None."""
-  grads_and_vars = tuple(grads_and_vars)
-  if not grads_and_vars:
-    return grads_and_vars
-  filtered = []
-  vars_with_empty_grads = []
-  for grad, var in grads_and_vars:
-    if grad is None:
-      vars_with_empty_grads.append(var)
-    else:
-      filtered.append((grad, var))
-  filtered = tuple(filtered)
-  if not filtered:
-    raise ValueError("No gradients provided for any variable: %s." %
-                     ([v.name for _, v in grads_and_vars],))
-  if vars_with_empty_grads:
-    logging.warning(
-        ("Gradients do not exist for variables %s when minimizing the loss."),
-        ([v.name for v in vars_with_empty_grads]))
-  return filtered
-
-
 def _var_key(var):
   """Key for representing a primary variable, for looking up slots.
tensorflow/python/keras/optimizer_v2/utils.py (new file, 87 lines)

@@ -0,0 +1,87 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Optimizer utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
+from tensorflow.python.distribute import reduce_util as ds_reduce_util
+from tensorflow.python.platform import tf_logging as logging
+
+
+def all_reduce_sum_gradients(grads_and_vars):
+  """Returns all-reduced gradients aggregated via summation.
+
+  Args:
+    grads_and_vars: List of (gradient, variable) pairs.
+
+  Returns:
+    A list of all-reduced gradients.
+  """
+  grads_and_vars = list(grads_and_vars)
+  filtered_grads_and_vars = filter_empty_gradients(grads_and_vars)
+  # We switch to a cross-replica context since there is a bug which causes
+  # IndexedSlices to be converted to dense tensors when all-reduced in a
+  # replica context.
+  # TODO(b/150507409): Do not switch to a cross-replica context once the bug
+  # is fixed.
+  if filtered_grads_and_vars:
+    reduced = distribute_ctx.get_replica_context().merge_call(
+        _all_reduce_sum_fn, args=(filtered_grads_and_vars,))
+  else:
+    reduced = []
+  # Copy 'reduced' but add None gradients back in
+  reduced_with_nones = []
+  reduced_pos = 0
+  for g, _ in grads_and_vars:
+    if g is None:
+      reduced_with_nones.append(None)
+    else:
+      reduced_with_nones.append(reduced[reduced_pos])
+      reduced_pos += 1
+  assert reduced_pos == len(reduced), "Failed to add all gradients"
+  return reduced_with_nones
+
+
+def filter_empty_gradients(grads_and_vars):
+  """Filter out `(grad, var)` pairs that have a gradient equal to `None`."""
+  grads_and_vars = tuple(grads_and_vars)
+  if not grads_and_vars:
+    return grads_and_vars
+
+  filtered = []
+  vars_with_empty_grads = []
+  for grad, var in grads_and_vars:
+    if grad is None:
+      vars_with_empty_grads.append(var)
+    else:
+      filtered.append((grad, var))
+  filtered = tuple(filtered)
+
+  if not filtered:
+    raise ValueError("No gradients provided for any variable: %s." %
+                     ([v.name for _, v in grads_and_vars],))
+  if vars_with_empty_grads:
+    logging.warning(
+        ("Gradients do not exist for variables %s when minimizing the loss."),
+        ([v.name for v in vars_with_empty_grads]))
+  return filtered
+
+
+def _all_reduce_sum_fn(distribution, grads_and_vars):
+  return distribution.extended.batch_reduce_to(ds_reduce_util.ReduceOp.SUM,
+                                               grads_and_vars)
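Now that they are public, the helpers can also be exercised outside OptimizerV2. A small behavioral sketch, assuming eager TensorFlow with constant gradients standing in for real tape output:

    import tensorflow as tf
    from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils

    v1 = tf.Variable(1.0, name="v1")
    v2 = tf.Variable(2.0, name="v2")

    # The None gradient is dropped (with a warning naming v2); the rest
    # pass through unchanged as a tuple of (grad, var) pairs.
    kept = optimizer_utils.filter_empty_gradients(
        [(tf.constant(0.5), v1), (None, v2)])
    assert [v.name for _, v in kept] == ["v1:0"]

    # If no variable at all has a gradient, ValueError is raised instead.
    # all_reduce_sum_gradients additionally re-inserts the Nones, so its
    # output lines up index for index with the input pairs; under a
    # tf.distribute strategy the surviving gradients are summed across
    # replicas via batch_reduce_to(ReduceOp.SUM, ...).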