STT-tensorflow/tensorflow/python/keras/optimizer_v2/utils.py
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizer utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.ops import clip_ops
from tensorflow.python.platform import tf_logging as logging


def all_reduce_sum_gradients(grads_and_vars):
  """Returns all-reduced gradients aggregated via summation.

  Args:
    grads_and_vars: List of (gradient, variable) pairs.

  Returns:
    List of (gradient, variable) pairs where gradients have been all-reduced.
  """
grads_and_vars = list(grads_and_vars)
filtered_grads_and_vars = filter_empty_gradients(grads_and_vars)
# We switch to a cross-replica context since there is a bug which causes
# IndexedSlices to be converted to dense tensors when all-reduced in a
# replica context.
# TODO(b/150507409): Do not switch to a cross-replica context once the bug
# is fixed.
if filtered_grads_and_vars:
reduced = distribute_ctx.get_replica_context().merge_call(
_all_reduce_sum_fn, args=(filtered_grads_and_vars,))
else:
reduced = []
# Copy 'reduced' but add None gradients back in
reduced_with_nones = []
reduced_pos = 0
for g, v in grads_and_vars:
if g is None:
reduced_with_nones.append((None, v))
else:
reduced_with_nones.append((reduced[reduced_pos], v))
reduced_pos += 1
assert reduced_pos == len(reduced), "Failed to add all gradients"
return reduced_with_nones
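
# Illustrative usage sketch (not part of this module's API): under a
# `tf.distribute` strategy, per-replica gradients can be summed with
# `all_reduce_sum_gradients` and then applied without further aggregation.
# `strategy`, `model`, `optimizer` and `compute_loss` below are assumed to be
# provided by the caller, and `tf` is `import tensorflow as tf`.
#
#   @tf.function
#   def train_step(inputs):
#     def step_fn(replica_inputs):
#       with tf.GradientTape() as tape:
#         loss = compute_loss(replica_inputs)
#       grads = tape.gradient(loss, model.trainable_variables)
#       reduced = all_reduce_sum_gradients(
#           list(zip(grads, model.trainable_variables)))
#       optimizer.apply_gradients(
#           reduced, experimental_aggregate_gradients=False)
#     strategy.run(step_fn, args=(inputs,))
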
def filter_empty_gradients(grads_and_vars):
"""Filter out `(grad, var)` pairs that have a gradient equal to `None`."""
grads_and_vars = tuple(grads_and_vars)
if not grads_and_vars:
return grads_and_vars
filtered = []
vars_with_empty_grads = []
for grad, var in grads_and_vars:
if grad is None:
vars_with_empty_grads.append(var)
else:
filtered.append((grad, var))
filtered = tuple(filtered)
if not filtered:
raise ValueError("No gradients provided for any variable: %s." %
([v.name for _, v in grads_and_vars],))
if vars_with_empty_grads:
logging.warning(
("Gradients do not exist for variables %s when minimizing the loss."),
([v.name for v in vars_with_empty_grads]))
return filtered
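
# Example behavior (a sketch, assuming `import tensorflow as tf`): pairs whose
# gradient is `None` are dropped and a warning is logged; if every gradient is
# `None`, a `ValueError` is raised instead.
#
#   v1, v2 = tf.Variable(1.0, name="a"), tf.Variable(2.0, name="b")
#   filter_empty_gradients([(tf.constant(0.5), v1), (None, v2)])
#   # -> ((<tf.Tensor: 0.5>, v1),) plus a warning that mentions variable "b"
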
def make_gradient_clipnorm_fn(clipnorm):
"""Creates a gradient transformation function for clipping by norm."""
if clipnorm is None:
return lambda grads_and_vars: grads_and_vars
def gradient_clipnorm_fn(grads_and_vars):
if isinstance(distribute_ctx.get_strategy(),
(central_storage_strategy.CentralStorageStrategy,
central_storage_strategy.CentralStorageStrategyV1)):
raise ValueError(
"`clipnorm` is not supported with `CenteralStorageStrategy`")
clipped_grads_and_vars = [
(clip_ops.clip_by_norm(g, clipnorm), v) for g, v in grads_and_vars
]
return clipped_grads_and_vars
return gradient_clipnorm_fn
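
# Illustrative sketch (assuming `import tensorflow as tf` and a caller-provided
# variable `var`): the returned function clips each gradient independently to
# the given norm. A gradient [3.0, 4.0] has norm 5.0, so with `clipnorm=1.0`
# it is rescaled by 1/5.
#
#   clip_fn = make_gradient_clipnorm_fn(1.0)
#   (clipped_grad, _), = clip_fn([(tf.constant([3.0, 4.0]), var)])
#   # clipped_grad ~= [0.6, 0.8]
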
def make_global_gradient_clipnorm_fn(clipnorm):
"""Creates a gradient transformation function for clipping by norm."""
if clipnorm is None:
return lambda grads_and_vars: grads_and_vars
def gradient_clipnorm_fn(grads_and_vars):
if isinstance(distribute_ctx.get_strategy(),
(central_storage_strategy.CentralStorageStrategy,
central_storage_strategy.CentralStorageStrategyV1)):
raise ValueError(
"`global_clipnorm` is not supported with `CenteralStorageStrategy`")
grads, variables = zip(*grads_and_vars)
clipped_grads, _ = clip_ops.clip_by_global_norm(grads, clipnorm)
clipped_grads_and_vars = list(zip(clipped_grads, variables))
return clipped_grads_and_vars
return gradient_clipnorm_fn
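
# Illustrative sketch (assuming `import tensorflow as tf` and caller-provided
# variables `v1`, `v2`): unlike per-gradient `clipnorm`, the global variant
# rescales all gradients jointly so their combined norm is at most the
# threshold. Gradients [3.0] and [4.0] have global norm 5.0, so with a
# threshold of 1.0 both are scaled by 1/5.
#
#   clip_fn = make_global_gradient_clipnorm_fn(1.0)
#   clipped = clip_fn([(tf.constant([3.0]), v1), (tf.constant([4.0]), v2)])
#   # -> gradients ~= [0.6] and [0.8]
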
def make_gradient_clipvalue_fn(clipvalue):
"""Creates a gradient transformation function for clipping by value."""
if clipvalue is None:
return lambda grads_and_vars: grads_and_vars
def gradient_clipvalue_fn(grads_and_vars):
if isinstance(distribute_ctx.get_strategy(),
(central_storage_strategy.CentralStorageStrategy,
central_storage_strategy.CentralStorageStrategyV1)):
raise ValueError(
"`clipvalue` is not supported with `CenteralStorageStrategy`")
clipped_grads_and_vars = [(clip_ops.clip_by_value(g, -clipvalue,
clipvalue), v)
for g, v in grads_and_vars]
return clipped_grads_and_vars
return gradient_clipvalue_fn
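
# Illustrative sketch (assuming `import tensorflow as tf` and a caller-provided
# variable `var`): each gradient element is clamped to [-clipvalue, clipvalue].
#
#   clip_fn = make_gradient_clipvalue_fn(0.5)
#   (clipped_grad, _), = clip_fn([(tf.constant([-2.0, 0.1, 3.0]), var)])
#   # clipped_grad == [-0.5, 0.1, 0.5]
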
def _all_reduce_sum_fn(distribution, grads_and_vars):
return distribution.extended.batch_reduce_to(ds_reduce_util.ReduceOp.SUM,
grads_and_vars)