Create Optimizer utilities file and move gradient aggregation and filtering functions to this file.

PiperOrigin-RevId: 323831904
Change-Id: I59ff3f42771eba8c7fb27ea68c53296bd2425011
Thomas O'Malley 2020-07-29 11:36:14 -07:00 committed by TensorFlower Gardener
parent 24b36f123d
commit aa462dd250
3 changed files with 95 additions and 53 deletions


@@ -28,6 +28,7 @@ py_library(
"nadam.py",
"optimizer_v2.py",
"rmsprop.py",
"utils.py",
],
srcs_version = "PY2AND3",
deps = [


@@ -27,7 +27,6 @@ import six
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.distribute import values as ds_values
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@@ -38,6 +37,7 @@ from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
@@ -47,7 +47,6 @@ from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import tracking
@@ -510,7 +509,7 @@ class OptimizerV2(trackable.Trackable):
TypeError: If `grads_and_vars` is malformed.
ValueError: If none of the variables have gradients.
"""
grads_and_vars = _filter_grads(grads_and_vars)
grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
var_list = [v for (_, v) in grads_and_vars]
with backend.name_scope(self._name):
@@ -550,7 +549,10 @@
})
def _aggregate_gradients(self, grads_and_vars):
"""Returns all-reduced gradients.
"""Returns aggregated gradients.
This method must be preserved to maintain backwards compatibility with
Horovod aggregation.
Args:
grads_and_vars: List of (gradient, variable) pairs.
@@ -558,32 +560,7 @@
Returns:
A list of all-reduced gradients.
"""
grads_and_vars = list(grads_and_vars)
filtered_grads_and_vars = _filter_grads(grads_and_vars)
def all_reduce_fn(distribution, grads_and_vars):
return distribution.extended.batch_reduce_to(
ds_reduce_util.ReduceOp.SUM, grads_and_vars)
# We switch to a cross-replica context since there is a bug which causes
# IndexedSlices to be converted to dense tensors when all-reduced in a
# replica context.
# TODO(b/150507409): Do not switch to a cross-replica context once the bug
# is fixed.
if filtered_grads_and_vars:
reduced = distribute_ctx.get_replica_context().merge_call(
all_reduce_fn, args=(filtered_grads_and_vars,))
else:
reduced = []
# Copy 'reduced' but add None gradients back in
reduced_with_nones = []
reduced_pos = 0
for g, _ in grads_and_vars:
if g is None:
reduced_with_nones.append(None)
else:
reduced_with_nones.append(reduced[reduced_pos])
reduced_pos += 1
assert reduced_pos == len(reduced), "Failed to add all gradients"
return reduced_with_nones
return optimizer_utils.all_reduce_sum_gradients(grads_and_vars)
def _distributed_apply(self, distribution, grads_and_vars, name, apply_state):
"""`apply_gradients` using a `DistributionStrategy`."""
@@ -1259,29 +1236,6 @@
yield
def _filter_grads(grads_and_vars):
"""Filter out iterable with grad equal to None."""
grads_and_vars = tuple(grads_and_vars)
if not grads_and_vars:
return grads_and_vars
filtered = []
vars_with_empty_grads = []
for grad, var in grads_and_vars:
if grad is None:
vars_with_empty_grads.append(var)
else:
filtered.append((grad, var))
filtered = tuple(filtered)
if not filtered:
raise ValueError("No gradients provided for any variable: %s." %
([v.name for _, v in grads_and_vars],))
if vars_with_empty_grads:
logging.warning(
("Gradients do not exist for variables %s when minimizing the loss."),
([v.name for v in vars_with_empty_grads]))
return filtered
def _var_key(var):
"""Key for representing a primary variable, for looking up slots.


@@ -0,0 +1,87 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizer utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.platform import tf_logging as logging
def all_reduce_sum_gradients(grads_and_vars):
"""Returns all-reduced gradients aggregated via summation.
Args:
grads_and_vars: List of (gradient, variable) pairs.
Returns:
A list of all-reduced gradients.
"""
grads_and_vars = list(grads_and_vars)
filtered_grads_and_vars = filter_empty_gradients(grads_and_vars)
# We switch to a cross-replica context since there is a bug which causes
# IndexedSlices to be converted to dense tensors when all-reduced in a
# replica context.
# TODO(b/150507409): Do not switch to a cross-replica context once the bug
# is fixed.
if filtered_grads_and_vars:
reduced = distribute_ctx.get_replica_context().merge_call(
_all_reduce_sum_fn, args=(filtered_grads_and_vars,))
else:
reduced = []
# Copy 'reduced' but add None gradients back in
reduced_with_nones = []
reduced_pos = 0
for g, _ in grads_and_vars:
if g is None:
reduced_with_nones.append(None)
else:
reduced_with_nones.append(reduced[reduced_pos])
reduced_pos += 1
assert reduced_pos == len(reduced), "Failed to add all gradients"
return reduced_with_nones
def filter_empty_gradients(grads_and_vars):
"""Filter out `(grad, var)` pairs that have a gradient equal to `None`."""
grads_and_vars = tuple(grads_and_vars)
if not grads_and_vars:
return grads_and_vars
filtered = []
vars_with_empty_grads = []
for grad, var in grads_and_vars:
if grad is None:
vars_with_empty_grads.append(var)
else:
filtered.append((grad, var))
filtered = tuple(filtered)
if not filtered:
raise ValueError("No gradients provided for any variable: %s." %
([v.name for _, v in grads_and_vars],))
if vars_with_empty_grads:
logging.warning(
("Gradients do not exist for variables %s when minimizing the loss."),
([v.name for v in vars_with_empty_grads]))
return filtered
def _all_reduce_sum_fn(distribution, grads_and_vars):
return distribution.extended.batch_reduce_to(ds_reduce_util.ReduceOp.SUM,
grads_and_vars)
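As a usage sketch (assuming the module path introduced above and illustrative variables), filter_empty_gradients can be exercised directly; all_reduce_sum_gradients is meant to run under a tf.distribute replica context, where it sums per-replica gradients while passing None entries through unchanged.

from tensorflow.python.keras.optimizer_v2 import utils as optimizer_utils
import tensorflow as tf

v1 = tf.Variable(1.0)
v2 = tf.Variable(2.0)
grads_and_vars = [(tf.constant(3.0), v1), (None, v2)]

# Drops the (None, v2) pair and logs a warning that v2 has no gradient.
filtered = optimizer_utils.filter_empty_gradients(grads_and_vars)
assert len(filtered) == 1 and filtered[0][1] is v1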