Special case wrapping of ndarrays in the gradient tape code.
PiperOrigin-RevId: 317762474 Change-Id: Ie848ad90a88aff5b2faef4069c3f05887038c367
This commit is contained in:
parent
2d8d440dbb
commit
5d4a29eaf5
|
@ -588,6 +588,7 @@ py_library(
|
|||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:unconnected_gradients",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/ops/numpy_ops:numpy",
|
||||
"//tensorflow/python/ops/parallel_for:control_flow_ops",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
|
|
|
@ -62,6 +62,9 @@ from tensorflow.python.util.tf_export import tf_export
|
|||
pfor_ops = LazyLoader(
|
||||
"pfor_ops", globals(),
|
||||
"tensorflow.python.ops.parallel_for.control_flow_ops")
|
||||
np_arrays = LazyLoader(
|
||||
"np_arrays", globals(),
|
||||
"tensorflow.python.ops.numpy_ops.np_arrays")
|
||||
|
||||
function = LazyLoader("function", globals(),
|
||||
"tensorflow.python.eager.function")
|
||||
|
@ -721,9 +724,11 @@ pywrap_tfe.TFE_Py_RegisterVSpace(_default_vspace)
|
|||
|
||||
|
||||
def _handle_or_self(x):
  """Unwrap resource variables / numpy-wrapper ndarrays to plain tensors.

  Args:
    x: An arbitrary object. Resource variables are replaced by their handle
      tensor and `np_arrays.ndarray` wrappers by their backing tensor; any
      other value passes through untouched.

  Returns:
    The underlying tensor for variables and ndarrays, otherwise `x` itself.
  """
  if resource_variable_ops.is_resource_variable(x):
    unwrapped = x.handle
  elif isinstance(x, np_arrays.ndarray):
    unwrapped = x.data
  else:
    unwrapped = x
  return unwrapped
|
||||
|
||||
|
||||
|
@ -1023,6 +1028,7 @@ class GradientTape(object):
|
|||
"gradient in order to compute higher order "
|
||||
"derivatives.", 1)
|
||||
|
||||
num_ndarrays = 0
|
||||
flat_targets = []
|
||||
for t in nest.flatten(target):
|
||||
if not backprop_util.IsTrainable(t):
|
||||
|
@ -1033,7 +1039,12 @@ class GradientTape(object):
|
|||
if resource_variable_ops.is_resource_variable(t):
|
||||
with self:
|
||||
t = ops.convert_to_tensor(t)
|
||||
elif isinstance(t, np_arrays.ndarray):
|
||||
t = t.data
|
||||
num_ndarrays += 1
|
||||
flat_targets.append(t)
|
||||
# Only rewrap if all targets are ndarray. If not, prefer tensors.
|
||||
rewrap_as_ndarray = num_ndarrays == len(flat_targets)
|
||||
|
||||
flat_sources = nest.flatten(sources)
|
||||
flat_sources_raw = flat_sources
|
||||
|
@ -1066,6 +1077,9 @@ class GradientTape(object):
|
|||
self._watched_variables = self._tape.watched_variables()
|
||||
self._tape = None
|
||||
|
||||
if rewrap_as_ndarray:
|
||||
flat_grad = nest.map_structure(np_arrays.tensor_to_ndarray, flat_grad)
|
||||
|
||||
grad = nest.pack_sequence_as(sources, flat_grad)
|
||||
return grad
|
||||
|
||||
|
@ -1120,6 +1134,10 @@ class GradientTape(object):
|
|||
ValueError: If vectorization of jacobian computation fails.
|
||||
"""
|
||||
flat_sources = nest.flatten(sources)
|
||||
rewrap_as_ndarray = False
|
||||
if isinstance(target, np_arrays.ndarray):
|
||||
target = target.data
|
||||
rewrap_as_ndarray = True
|
||||
target_static_shape = target.shape
|
||||
target_shape = array_ops.shape(target)
|
||||
# Note that we push and pop the tape here and below. This is needed since we
|
||||
|
@ -1169,6 +1187,8 @@ class GradientTape(object):
|
|||
out = array_ops.reshape(out, new_shape)
|
||||
if context.executing_eagerly():
|
||||
out.set_shape(target_static_shape.concatenate(flat_sources[i].shape))
|
||||
if rewrap_as_ndarray:
|
||||
out = np_arrays.tensor_to_ndarray(out)
|
||||
output[i] = out
|
||||
|
||||
return nest.pack_sequence_as(sources, output)
|
||||
|
|
|
@ -82,10 +82,10 @@ class NdarraySpec(type_spec.BatchableTypeSpec):
|
|||
return (self._data_spec,)
|
||||
|
||||
def _batch(self, batch_size):
  """Returns an `NdarraySpec` describing a batch of `batch_size` ndarrays."""
  # NOTE(review): the patch switches from the public-looking `batch` to the
  # TypeSpec protected `_batch` API — presumably `batch` does not exist on
  # the inner spec; confirm against the TypeSpec base class.
  batched_inner = self._data_spec._batch(batch_size)  # pylint: disable=protected-access
  return NdarraySpec(batched_inner)
|
||||
|
||||
def _unbatch(self):
  """Returns an `NdarraySpec` for a single element of this batched spec."""
  # Delegates to the TypeSpec protected `_unbatch` on the wrapped data spec.
  unbatched_inner = self._data_spec._unbatch()  # pylint: disable=protected-access
  return NdarraySpec(unbatched_inner)
|
||||
|
||||
|
||||
class ndarray(composite_tensor.CompositeTensor): # pylint: disable=invalid-name
|
||||
|
@ -306,10 +306,6 @@ class ndarray(composite_tensor.CompositeTensor): # pylint: disable=invalid-name
|
|||
def __repr__(self):
  """Debug representation that delegates to the wrapped tensor's repr."""
  inner_repr = repr(self.data)
  return 'ndarray<{}>'.format(inner_repr)
||||
|
||||
@property
|
||||
def _id(self):
|
||||
return self.data._id # pylint: disable=protected-access
|
||||
|
||||
|
||||
def tensor_to_ndarray(tensor):
  """Wraps a tensor in the numpy-compatibility `ndarray` class.

  Args:
    tensor: The tensor to wrap; handed straight to `ndarray.from_tensor`.

  Returns:
    An `ndarray` instance backed by `tensor`.
  """
  wrapped = ndarray.from_tensor(tensor)
  return wrapped
|
||||
|
|
Loading…
Reference in New Issue