Special case wrapping of ndarrays in the gradient tape code.
PiperOrigin-RevId: 317762474
Change-Id: Ie848ad90a88aff5b2faef4069c3f05887038c367
parent 2d8d440dbb
commit 5d4a29eaf5
tensorflow/python/eager/BUILD
@@ -588,6 +588,7 @@ py_library(
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:unconnected_gradients",
         "//tensorflow/python:util",
+        "//tensorflow/python/ops/numpy_ops:numpy",
         "//tensorflow/python/ops/parallel_for:control_flow_ops",
         "@six_archive//:six",
     ],
tensorflow/python/eager/backprop.py
@@ -62,6 +62,9 @@ from tensorflow.python.util.tf_export import tf_export
 pfor_ops = LazyLoader(
     "pfor_ops", globals(),
     "tensorflow.python.ops.parallel_for.control_flow_ops")
+np_arrays = LazyLoader(
+    "np_arrays", globals(),
+    "tensorflow.python.ops.numpy_ops.np_arrays")
 
 function = LazyLoader("function", globals(),
                       "tensorflow.python.eager.function")
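Context: np_arrays is bound through LazyLoader rather than a plain import so that loading backprop.py does not pull in (and potentially cycle with) the numpy_ops package at module-import time; the real loader lives in tensorflow.python.util.lazy_loader. A simplified sketch of the pattern, for illustration only:

    import importlib
    import types

    class LazyLoader(types.ModuleType):
      """Defers an import until the module is first touched."""

      def __init__(self, local_name, parent_module_globals, name):
        self._local_name = local_name
        self._parent_module_globals = parent_module_globals
        super(LazyLoader, self).__init__(name)

      def _load(self):
        # Import the real module and swap it into the caller's globals so
        # that later lookups bypass this placeholder entirely.
        module = importlib.import_module(self.__name__)
        self._parent_module_globals[self._local_name] = module
        self.__dict__.update(module.__dict__)
        return module

      def __getattr__(self, item):
        return getattr(self._load(), item)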
@@ -721,9 +724,11 @@ pywrap_tfe.TFE_Py_RegisterVSpace(_default_vspace)
 
 
 def _handle_or_self(x):
-  """If x is ResourceVariable, return its handle, else x."""
+  """Unwrap resource variable/ndarray to return tensors."""
   if resource_variable_ops.is_resource_variable(x):
-    x = x.handle
+    return x.handle
+  if isinstance(x, np_arrays.ndarray):
+    return x.data
   return x
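With the new branch, _handle_or_self returns early for both wrapper types instead of only rewriting resource variables. A hedged illustration of the three cases (eager mode assumed; the helper is private to backprop.py):

    import tensorflow as tf
    from tensorflow.python.ops.numpy_ops import np_arrays

    v = tf.Variable(1.0)
    nd = np_arrays.tensor_to_ndarray(tf.constant(2.0))
    t = tf.constant(3.0)
    # _handle_or_self(v)  -> v.handle (the variable's resource tensor)
    # _handle_or_self(nd) -> nd.data  (the backing tf.Tensor)
    # _handle_or_self(t)  -> t        (passed through unchanged)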
@@ -1023,6 +1028,7 @@ class GradientTape(object):
           "gradient in order to compute higher order "
           "derivatives.", 1)
 
+    num_ndarrays = 0
     flat_targets = []
     for t in nest.flatten(target):
       if not backprop_util.IsTrainable(t):
@@ -1033,7 +1039,12 @@ class GradientTape(object):
       if resource_variable_ops.is_resource_variable(t):
         with self:
           t = ops.convert_to_tensor(t)
+      elif isinstance(t, np_arrays.ndarray):
+        t = t.data
+        num_ndarrays += 1
       flat_targets.append(t)
+    # Only rewrap if all targets are ndarray. If not, prefer tensors.
+    rewrap_as_ndarray = num_ndarrays == len(flat_targets)
 
     flat_sources = nest.flatten(sources)
     flat_sources_raw = flat_sources
@@ -1066,6 +1077,9 @@ class GradientTape(object):
       self._watched_variables = self._tape.watched_variables()
       self._tape = None
 
+    if rewrap_as_ndarray:
+      flat_grad = nest.map_structure(np_arrays.tensor_to_ndarray, flat_grad)
+
     grad = nest.pack_sequence_as(sources, flat_grad)
     return grad
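Net effect on GradientTape.gradient: targets are unwrapped before reaching the tape, and the flat gradients are rewrapped as ndarrays only when every flattened target was an ndarray (mixed structures fall back to plain tensors). A usage sketch under those assumptions, watching the backing tensor explicitly since watch() is untouched by this commit:

    import tensorflow as tf
    from tensorflow.python.ops.numpy_ops import np_arrays

    x = tf.constant(3.0)
    with tf.GradientTape() as tape:
      tape.watch(x)
      # Every flattened target is an ndarray, so rewrap_as_ndarray is True.
      y = np_arrays.tensor_to_ndarray(x * x)

    g = tape.gradient(y, x)
    # g should be an np_arrays.ndarray wrapping the tensor 6.0; with a mixed
    # target such as [y, x * x] it would stay a plain tf.Tensor.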
@@ -1120,6 +1134,10 @@ class GradientTape(object):
       ValueError: If vectorization of jacobian computation fails.
     """
     flat_sources = nest.flatten(sources)
+    rewrap_as_ndarray = False
+    if isinstance(target, np_arrays.ndarray):
+      target = target.data
+      rewrap_as_ndarray = True
     target_static_shape = target.shape
     target_shape = array_ops.shape(target)
     # Note that we push and pop the tape here and below. This is needed since we
@@ -1169,6 +1187,8 @@ class GradientTape(object):
         out = array_ops.reshape(out, new_shape)
         if context.executing_eagerly():
           out.set_shape(target_static_shape.concatenate(flat_sources[i].shape))
+        if rewrap_as_ndarray:
+          out = np_arrays.tensor_to_ndarray(out)
       output[i] = out
 
     return nest.pack_sequence_as(sources, output)
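jacobian applies the same rewrap rule to each per-source output. A sketch (eager mode, single tensor source, names as above):

    import tensorflow as tf
    from tensorflow.python.ops.numpy_ops import np_arrays

    x = tf.constant([1.0, 2.0])
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = np_arrays.tensor_to_ndarray(x * x)  # ndarray target

    j = tape.jacobian(y, x)
    # j should come back as an np_arrays.ndarray of shape [2, 2] holding
    # the diagonal Jacobian diag([2., 4.]).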
tensorflow/python/ops/numpy_ops/np_arrays.py
@@ -82,10 +82,10 @@ class NdarraySpec(type_spec.BatchableTypeSpec):
     return (self._data_spec,)
 
   def _batch(self, batch_size):
-    return NdarraySpec(self._data_spec.batch(batch_size))
+    return NdarraySpec(self._data_spec._batch(batch_size))  # pylint: disable=protected-access
 
   def _unbatch(self):
-    return NdarraySpec(self._data_spec.unbatch())
+    return NdarraySpec(self._data_spec._unbatch())  # pylint: disable=protected-access
 
 
 class ndarray(composite_tensor.CompositeTensor):  # pylint: disable=invalid-name
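The NdarraySpec change is a straight bug fix: TypeSpec exposes batching only through the protected _batch/_unbatch methods defined by BatchableTypeSpec, so the old calls to a public batch()/unbatch() on the wrapped TensorSpec would raise AttributeError the first time the spec was batched. A small sketch of the internal API this now relies on (signatures as of this commit; illustration only):

    import tensorflow as tf
    from tensorflow.python.framework import tensor_spec

    spec = tensor_spec.TensorSpec(shape=[3], dtype=tf.float32)
    batched = spec._batch(4)       # TensorSpec(shape=(4, 3), dtype=tf.float32)
    restored = batched._unbatch()  # TensorSpec(shape=(3,), dtype=tf.float32)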
@@ -306,10 +306,6 @@ class ndarray(composite_tensor.CompositeTensor):  # pylint: disable=invalid-name
   def __repr__(self):
     return 'ndarray<{}>'.format(self.data.__repr__())
 
-  @property
-  def _id(self):
-    return self.data._id  # pylint: disable=protected-access
-
 
 def tensor_to_ndarray(tensor):
   return ndarray.from_tensor(tensor)