Internal change
PiperOrigin-RevId: 312694720
Change-Id: I04439efff13aabe38d18c98c025d45ae33d33f46
This commit is contained in:
parent 3d9ec6298a
commit bfe0b28c37
@@ -199,6 +199,7 @@ def _test_gradients(testcase,
  # And the symbolic computations should be much closer.
  testcase.assertAllClose(sym_jac_back, sym_jac_fwd)


class ForwardpropTest(test.TestCase, parameterized.TestCase):

  def testJVPFunction(self):
@@ -360,17 +361,14 @@ class ForwardpropTest(test.TestCase, parameterized.TestCase):

    _test_gradients(self, f, [constant_op.constant([1., 2.])], order=3)

  # TODO(allenl): investigate why assert_no_new_pyobjects_executing_eagerly fails around this test?
  def testExceptionCustomGradientRecomputeGradForward(self):
  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testCustomGradientRecomputeGrad(self):

    @custom_gradient.recompute_grad
    def f(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    with self.assertRaisesRegexp(NotImplementedError,
                                 "recompute_grad tried to transpose"):
      primals = [constant_op.constant([1.])]
      sym_jac_fwd = _jacfwd(f, primals)
    _test_gradients(self, f, [constant_op.constant([1.])], order=3)

  def testExceptionInCustomGradientNotSwallowed(self):
@@ -1,7 +1,7 @@
# Description:
# Contains Keras integration tests that verify with other TF high level APIs.

load("//tensorflow:tensorflow.bzl", "cuda_py_test", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "tf_py_test")

package(
    default_visibility = [
@@ -70,13 +70,3 @@ tf_py_test(
        "//tensorflow/python:extra_py_tests_deps",
    ],
)

cuda_py_test(
    name = "gradient_checkpoint_test",
    srcs = ["gradient_checkpoint_test.py"],
    python_version = "PY3",
    deps = [
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:extra_py_tests_deps",
    ],
)
@@ -1,158 +0,0 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
layers = tf.keras.layers
optimizers = tf.keras.optimizers


def _get_big_cnn_model(img_dim, n_channels, num_partitions,
                       blocks_per_partition):
  """Creates a test model whose activations are significantly larger than model size."""
  model = tf.keras.Sequential()
  model.add(layers.Input(shape=(img_dim, img_dim, n_channels)))
  for _ in range(num_partitions):
    for _ in range(blocks_per_partition):
      model.add(layers.Conv2D(10, 5, padding='same', activation=tf.nn.relu))
      model.add(layers.MaxPooling2D((1, 1), padding='same'))
      model.add(layers.Conv2D(40, 5, padding='same', activation=tf.nn.relu))
      model.add(layers.MaxPooling2D((1, 1), padding='same'))
      model.add(layers.Conv2D(20, 5, padding='same', activation=tf.nn.relu))
      model.add(layers.MaxPooling2D((1, 1), padding='same'))
  model.add(layers.Flatten())
  model.add(layers.Dense(32, activation=tf.nn.relu))
  model.add(layers.Dense(10))
  return model


def _get_split_cnn_model(img_dim, n_channels, num_partitions,
                         blocks_per_partition):
  """Creates a test model that is split into `num_partitions` smaller models"""
  models = [tf.keras.Sequential() for _ in range(num_partitions)]
  models[0].add(layers.Input(shape=(img_dim, img_dim, n_channels)))
  for i in range(num_partitions):
    model = models[i]
    if i > 0:
      last_shape = models[i - 1].layers[-1].output_shape
      model.add(layers.Input(shape=last_shape[1:]))
    for _ in range(blocks_per_partition):
      model.add(layers.Conv2D(10, 5, padding='same', activation=tf.nn.relu))
      model.add(layers.MaxPooling2D((1, 1), padding='same'))
      model.add(layers.Conv2D(40, 5, padding='same', activation=tf.nn.relu))
      model.add(layers.MaxPooling2D((1, 1), padding='same'))
      model.add(layers.Conv2D(20, 5, padding='same', activation=tf.nn.relu))
      model.add(layers.MaxPooling2D((1, 1), padding='same'))
  models[-1].add(layers.Flatten())
  models[-1].add(layers.Dense(32, activation=tf.nn.relu))
  models[-1].add(layers.Dense(10))
  return models


def _compute_loss(logits, labels):
  return tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=logits, labels=labels))


def _limit_gpu_memory():
  """Helper function to limit GPU memory for testing """
  gpus = tf.config.experimental.list_physical_devices('GPU')
  if gpus:
    tf.config.experimental.set_virtual_device_configuration(
        gpus[0],
        [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)])
    return True
  return False


def _get_dummy_data(img_dim, n_channels, batch_size):
  inputs = tf.ones([batch_size, img_dim, img_dim, n_channels])
  labels = tf.ones([batch_size], dtype=tf.int64)
  return inputs, labels


def _train_no_recompute(n_steps):
  """Trains a single large model without gradient checkpointing."""
  img_dim, n_channels, batch_size = 256, 1, 4
  x, y = _get_dummy_data(img_dim, n_channels, batch_size)
  model = _get_big_cnn_model(
      img_dim, n_channels, num_partitions=3, blocks_per_partition=2)
  optimizer = optimizers.SGD()
  losses = []
  tr_vars = model.trainable_variables
  for _ in range(n_steps):
    with tf.GradientTape() as tape:
      logits = model(x)
      loss = _compute_loss(logits, y)
      losses.append(loss)
    grads = tape.gradient(loss, tr_vars)  # tr_vars
    optimizer.apply_gradients(zip(grads, tr_vars))
    del grads
  return losses


def _train_with_recompute(n_steps):
  """Trains a single large model with gradient checkpointing using tf.recompute_grad."""
  img_dim, n_channels, batch_size = 256, 1, 4
  x, y = _get_dummy_data(img_dim, n_channels, batch_size)
  # This model is the same model as _get_big_cnn_model but split into 3 parts.
  models = _get_split_cnn_model(
      img_dim, n_channels, num_partitions=3, blocks_per_partition=2)
  model1, model2, model3 = models
  # Apply gradient checkpointing to the submodels using tf.recompute_grad.
  model1_re = tf.recompute_grad(model1)
  model2_re = tf.recompute_grad(model2)
  model3_re = tf.recompute_grad(model3)
  optimizer = optimizers.SGD()
  tr_vars = (
      model1.trainable_variables + model2.trainable_variables +
      model3.trainable_variables)
  losses = []
  for _ in range(n_steps):
    with tf.GradientTape() as tape:
      logits1 = model1_re(x)
      logits2 = model2_re(logits1)
      logits3 = model3_re(logits2)
      loss = _compute_loss(logits3, y)
      losses.append(loss)
    grads = tape.gradient(loss, tr_vars)  # tr_vars
    optimizer.apply_gradients(zip(grads, tr_vars))
    del grads
  return losses


class GradientCheckpointTest(tf.test.TestCase):

  def test_raises_oom_exception(self):
    if not _limit_gpu_memory():
      self.skipTest('No virtual GPUs found')
    with self.assertRaises(Exception) as context:
      _train_no_recompute(1)
    self.assertTrue(
        context.exception.__class__.__name__ == 'ResourceExhaustedError')

  def test_does_not_raise_oom_exception(self):
    if not _limit_gpu_memory():
      self.skipTest('No virtual GPUs found')
    n_step = 2
    losses = _train_with_recompute(n_step)
    self.assertTrue(len(losses) == n_step)


if __name__ == '__main__':
  tf.test.main()
@@ -28,7 +28,6 @@ from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import op_selector
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
@@ -483,47 +482,28 @@ def recompute_grad(f):
  def inner(*args, **kwargs):
    """Inner function closure for calculating gradients."""
    current_var_scope = variable_scope.get_variable_scope()
    with tape_lib.stop_recording():
      result = f(*args, **kwargs)

    def grad_wrapper(*wrapper_args, **grad_kwargs):
      """Wrapper function to accomodate lack of kwargs in graph mode decorator."""
      result = f(*args, **kwargs)

      @custom_gradient
      def inner_recompute_grad(*dresult):
        """Nested custom gradient function for computing grads in reverse and forward mode autodiff."""
        # Gradient calculation for reverse mode autodiff.
        variables = grad_kwargs.get("variables")
        with backprop.GradientTape() as t:
          id_args = [gen_array_ops.identity(x) for x in args]
          t.watch(id_args)
          if variables is not None:
            t.watch(variables)
          with ops.control_dependencies(dresult):
            with variable_scope.variable_scope(current_var_scope):
              result = f(*id_args, **kwargs)
        kw_vars = []
    def grad(*dresult, **grad_kwargs):
      """Gradient function calculation for inner function."""
      variables = grad_kwargs.get("variables")
      with backprop.GradientTape() as t:
        id_args = [gen_array_ops.identity(x) for x in args]
        t.watch(id_args)
        if variables is not None:
          kw_vars = list(variables)
        grads = t.gradient(
            result,
            list(id_args) + kw_vars,
            output_gradients=dresult,
            unconnected_gradients=UnconnectedGradients.ZERO)
          t.watch(variables)
        with ops.control_dependencies(dresult):
          with variable_scope.variable_scope(current_var_scope):
            result = f(*id_args, **kwargs)
      kw_vars = []
      if variables is not None:
        kw_vars = list(variables)
      grads = t.gradient(
          result, list(id_args) + kw_vars, output_gradients=dresult)
      return grads[:len(id_args)], grads[len(id_args):]

        def transpose(*t_args, **t_kwargs):
          """Gradient function calculation for forward mode autodiff."""
          # Just throw an error since gradients / activations are not stored on tape for recompute.
          raise NotImplementedError(
              "recompute_grad tried to transpose grad of {}. "
              "Consider not using recompute_grad in forward mode"
              "autodiff".format(f.__name__))

        return (grads[:len(id_args)], grads[len(id_args):]), transpose

      return inner_recompute_grad(*wrapper_args)

    return result, grad_wrapper
    return result, grad

  return inner
@@ -59,7 +59,6 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.nn_ops import bias_add
from tensorflow.python.platform import googletest
from tensorflow.python.ops import gradient_checker_v2


class GradientsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
@@ -1341,46 +1340,6 @@ class VariablesGradientTest(test_util.TensorFlowTestCase):

    return grads_re, grads

  def _grad(self, f, argnums=0):
    """Return a function which computes the gradient of `f`."""

    def _f(*params):
      with backprop.GradientTape() as tape:
        tape.watch(params)
        outputs = f(*params)
      return tape.gradient(
          outputs,
          params[argnums],
          unconnected_gradients=unconnected_gradients.UnconnectedGradients.ZERO)

    return _f

  def _test_gradients(self, f, inputs, order, delta=1e-3, rtol=1e-2, atol=1e-6):
    """Tests backward jacobians of `f`'s [0, `order`)-order gradients."""
    if order < 1:
      raise ValueError(
          "`order` should be a positive integer, got '{}'.".format(order))
    if order > 1:
      self._test_gradients(
          f=self._grad(f),
          inputs=inputs,
          order=order - 1,
          delta=delta,
          rtol=rtol,
          atol=atol)
    sym_jac_back, num_jac = gradient_checker_v2.compute_gradient(
        f, inputs, delta=delta)
    self.assertAllClose(num_jac, sym_jac_back, rtol=rtol, atol=atol)

  @test_util.run_v2_only
  def testCustomGradientRecomputeGradHigherOrder(self):

    @custom_gradient.recompute_grad
    def f(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    self._test_gradients(f, [constant_op.constant([1.])], order=3)

  @test_util.run_in_graph_and_eager_modes
  def testFnRecompute(self):
    """Checks that recompute_grad works grads of function args."""
@@ -1397,8 +1356,8 @@ class VariablesGradientTest(test_util.TensorFlowTestCase):
        shape=10,
        trainable=True,
    )
    self.evaluate(test_var.assign(np.ones([10])))
    test_input = constant(np.ones((10, 10), dtype=np.float32))

    test_input = constant(np.zeros((10, 10), dtype=np.float32))

    grads_re, grads = self._TestFnVariablesGradient(test_input, TestFn,
                                                    test_input)
@@ -1441,7 +1400,6 @@ class VariablesGradientTest(test_util.TensorFlowTestCase):
          shape=10,
          trainable=True,
      )
      self.evaluate(test_var.assign(np.ones([10])))
      return input_t * test_var

    test_input_t = constant(np.zeros((10, 10), dtype=np.float32))
@@ -1484,8 +1442,6 @@ class VariablesGradientTest(test_util.TensorFlowTestCase):
      out_re = test_fn_re(test_input_t)
      out = TestFn(test_input_t)

      init = variables.global_variables_initializer()
      self.evaluate(init)
      grads_re = gradients.gradients(out_re, variables.trainable_variables())
      grads = gradients.gradients(out, variables.trainable_variables())