Undo some of the composite tensor changes in gradient tape code
It's not clear that we should always take gradients with respect to component tensors.

PiperOrigin-RevId: 317708164
Change-Id: I8a0cdddd705497e5539857afcbb60aaa38821e0c
parent af8f596d21
commit dfd21eaec6
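The diff below makes GradientTape.gradient treat CompositeTensors as single leaves again instead of decomposing them into their component tensors. A minimal sketch of the distinction the commit message refers to, using plain tensors to stand in for a composite's components (illustrative only, not part of the commit):

import tensorflow as tf

a = tf.constant(3.0)
b = tf.constant(2.0)
with tf.GradientTape() as tape:
  tape.watch([a, b])
  y = a * 2.0 + b * 3.0

# Differentiating with respect to the individual components gives one
# gradient per component tensor: [2.0, 3.0]. With expand_composites=True,
# a composite source was implicitly broken up this way; after this commit
# the caller decides whether to decompose.
print(tape.gradient(y, [a, b]))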
@@ -1024,7 +1024,7 @@ class GradientTape(object):
           "derivatives.", 1)
 
     flat_targets = []
-    for t in nest.flatten(target, expand_composites=True):
+    for t in nest.flatten(target):
       if not backprop_util.IsTrainable(t):
         logging.vlog(
             logging.WARN, "The dtype of the target tensor must be "
@@ -1035,7 +1035,7 @@ class GradientTape(object):
         t = ops.convert_to_tensor(t)
       flat_targets.append(t)
 
-    flat_sources = nest.flatten(sources, expand_composites=True)
+    flat_sources = nest.flatten(sources)
     flat_sources_raw = flat_sources
     flat_sources = [_handle_or_self(x) for x in flat_sources]
     for t in flat_sources_raw:
@@ -1051,8 +1051,7 @@ class GradientTape(object):
 
     if output_gradients is not None:
       output_gradients = [None if x is None else ops.convert_to_tensor(x)
-                          for x in nest.flatten(
-                              output_gradients, expand_composites=True)]
+                          for x in nest.flatten(output_gradients)]
 
     flat_grad = imperative_grad.imperative_grad(
         self._tape,
@@ -1067,7 +1066,7 @@ class GradientTape(object):
       self._watched_variables = self._tape.watched_variables()
       self._tape = None
 
-    grad = nest.pack_sequence_as(sources, flat_grad, expand_composites=True)
+    grad = nest.pack_sequence_as(sources, flat_grad)
     return grad
 
   def jacobian(self,
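The four hunks above all make the same substitution: nest.flatten and nest.pack_sequence_as are now called without expand_composites=True. A quick sketch of what that flag changes, using the public tf.nest alias of the nest module used here (illustrative, not from the diff):

import tensorflow as tf

sp = tf.sparse.SparseTensor(indices=[[0, 0]], values=[1.0], dense_shape=[2, 2])

# With expand_composites=True, the composite is decomposed into its
# component tensors (indices, values, dense_shape):
print(tf.nest.flatten(sp, expand_composites=True))

# Without the flag, the SparseTensor is treated as a single leaf, which is
# the behavior this commit restores inside GradientTape.gradient:
print(tf.nest.flatten(sp))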
@@ -28,7 +28,6 @@ from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function
 from tensorflow.python.eager import tape as tape_lib
 from tensorflow.python.eager import test
-from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
@@ -37,7 +36,6 @@ from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import test_util
-from tensorflow.python.framework import type_spec
 from tensorflow.python.framework.memory_checker import MemoryChecker
 from tensorflow.python.layers.pooling import max_pooling3d
 from tensorflow.python.ops import array_ops
@@ -54,44 +52,6 @@ from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.training import training
 from tensorflow.python.util import nest
 
 
-# TODO(nareshmodi): This is copied from composite_tensor_test.py. Extract it out
-# to a common library to avoid duplication.
-class CTSpec(type_spec.TypeSpec):
-  """A generic CompositeTensor TypeSpec, used for constructing tests."""
-
-  def __init__(self, component_specs):
-    self.component_specs = component_specs
-
-  value_type = property(lambda self: CT)
-  _component_specs = property(lambda self: self.component_specs)
-
-  def _serialize(self):
-    return (self.component_specs,)
-
-  def _to_components(self, value):
-    return value.components
-
-  def _from_components(self, tensor_list):
-    return CT(tensor_list)
-
-
-class CT(composite_tensor.CompositeTensor):
-  """A generic CompositeTensor, used for constructing tests."""
-  _type_spec_class = CTSpec
-
-  def __init__(self, components):
-    if isinstance(components, list):
-      components = tuple(components)
-    self.components = components
-
-  @property
-  def _type_spec(self):
-    component_specs = nest.map_structure(type_spec.type_spec_from_value,
-                                         self.components)
-    return self._type_spec_class(component_specs)
-
-
 class BackpropTest(test.TestCase, parameterized.TestCase):
@@ -1621,35 +1581,6 @@ class BackpropTest(test.TestCase, parameterized.TestCase):
     memory_checker.report()
     memory_checker.assert_no_leak_if_all_possibly_except_one()
 
-  def testCompositeTensorAsSource(self):
-    t = CT([constant_op.constant(3.), constant_op.constant(2.)])
-    with backprop.GradientTape() as gt:
-      gt.watch(t)
-      y = CT([t.components[0] * 2, t.components[1] * 3])
-
-    grad = gt.gradient(y, t)
-    expected_grad = CT([constant_op.constant(2.), constant_op.constant(3.)])
-
-    flat_grads = nest.flatten(grad, expand_composites=True)
-    flat_expected_grads = nest.flatten(expected_grad, expand_composites=True)
-
-    self.assertAllClose(flat_grads, flat_expected_grads)
-
-  def testCompositeTensorAsOutputGradients(self):
-    t = CT([constant_op.constant(3.), constant_op.constant(2.)])
-    with backprop.GradientTape() as gt:
-      gt.watch(t)
-      y = CT([t.components[0] * 2, t.components[1] * 3])
-
-    output_gradients = CT([constant_op.constant(5.), constant_op.constant(10.)])
-    grad = gt.gradient(y, t, output_gradients=output_gradients)
-    expected_grad = CT([constant_op.constant(10.), constant_op.constant(30.)])
-
-    flat_grads = nest.flatten(grad, expand_composites=True)
-    flat_expected_grads = nest.flatten(expected_grad, expand_composites=True)
-
-    self.assertAllClose(flat_grads, flat_expected_grads)
-
-
 class JacobianTest(test.TestCase):
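The two deleted tests exercised gradients taken directly through a composite source. After this change, an equivalent result can still be obtained by watching the component tensors explicitly; a hedged sketch of that workaround (illustrative, not from the commit):

import tensorflow as tf

components = [tf.constant(3.0), tf.constant(2.0)]
with tf.GradientTape() as tape:
  tape.watch(components)
  y = [components[0] * 2.0, components[1] * 3.0]

# Matches the expected_grad values in the removed testCompositeTensorAsSource:
print(tape.gradient(y, components))  # [2.0, 3.0]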
@@ -306,6 +306,10 @@ class ndarray(composite_tensor.CompositeTensor):  # pylint: disable=invalid-name
   def __repr__(self):
     return 'ndarray<{}>'.format(self.data.__repr__())
 
+  @property
+  def _id(self):
+    return self.data._id  # pylint: disable=protected-access
+
 
 def tensor_to_ndarray(tensor):
   return ndarray.from_tensor(tensor)
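The new _id property forwards the backing tensor's unique id, apparently so that code which keys objects by tensor id, such as the gradient tape's watch bookkeeping, treats an ndarray and the tensor it wraps as the same entity. A minimal sketch of the delegation pattern, where TensorWrapper is a hypothetical stand-in for ndarray:

class TensorWrapper:
  """Hypothetical wrapper mirroring the ndarray change above."""

  def __init__(self, data):
    self.data = data  # the backing tf.Tensor

  @property
  def _id(self):
    # Forward the wrapped tensor's id so id-keyed bookkeeping sees the
    # wrapper and its tensor as one object.
    return self.data._id  # pylint: disable=protected-access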
@@ -64,8 +64,9 @@ class InteropTest(test.TestCase):
 
     dx, dy = t.gradient([xx, yy], [x, y])
 
-    self.assertIsInstance(dx, np_arrays.ndarray)
-    self.assertIsInstance(dy, np_arrays.ndarray)
+    # TODO(nareshmodi): Figure out a way to rewrap ndarray as tensors.
+    # self.assertIsInstance(dx, np_arrays.ndarray)
+    # self.assertIsInstance(dy, np_arrays.ndarray)
     self.assertAllClose(dx, 2.0)
     self.assertAllClose(dy, 3.0)
 
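With the instance checks commented out, dx and dy may now come back as plain Tensors rather than ndarray wrappers. Until the TODO above is resolved, a caller that needs ndarrays could rewrap them manually via tensor_to_ndarray from the np_arrays hunk; a sketch reusing the test's own names (surrounding setup assumed):

dx, dy = t.gradient([xx, yy], [x, y])
dx = np_arrays.tensor_to_ndarray(dx)  # rewrap the plain Tensor as an ndarray
dy = np_arrays.tensor_to_ndarray(dy)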