diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index f51bd97e488..408d784ae82 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -588,6 +588,7 @@ py_library(
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:unconnected_gradients",
         "//tensorflow/python:util",
+        "//tensorflow/python/ops/numpy_ops:numpy",
        "//tensorflow/python/ops/parallel_for:control_flow_ops",
         "@six_archive//:six",
     ],
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 8da3f71360a..5800a51f89a 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -62,6 +62,9 @@ from tensorflow.python.util.tf_export import tf_export
 pfor_ops = LazyLoader(
     "pfor_ops", globals(),
     "tensorflow.python.ops.parallel_for.control_flow_ops")
+np_arrays = LazyLoader(
+    "np_arrays", globals(),
+    "tensorflow.python.ops.numpy_ops.np_arrays")
 
 function = LazyLoader("function", globals(),
                       "tensorflow.python.eager.function")
@@ -721,9 +724,11 @@ pywrap_tfe.TFE_Py_RegisterVSpace(_default_vspace)
 
 
 def _handle_or_self(x):
-  """If x is ResourceVariable, return its handle, else x."""
+  """Unwrap resource variable/ndarray to return tensors."""
   if resource_variable_ops.is_resource_variable(x):
-    x = x.handle
+    return x.handle
+  if isinstance(x, np_arrays.ndarray):
+    return x.data
   return x
 
 
@@ -1023,6 +1028,7 @@ class GradientTape(object):
             "gradient in order to compute higher order "
            "derivatives.", 1)
 
+    num_ndarrays = 0
     flat_targets = []
     for t in nest.flatten(target):
       if not backprop_util.IsTrainable(t):
@@ -1033,7 +1039,12 @@ class GradientTape(object):
       if resource_variable_ops.is_resource_variable(t):
         with self:
           t = ops.convert_to_tensor(t)
+      elif isinstance(t, np_arrays.ndarray):
+        t = t.data
+        num_ndarrays += 1
       flat_targets.append(t)
+    # Only rewrap if all targets are ndarray. If not, prefer tensors.
+    rewrap_as_ndarray = num_ndarrays == len(flat_targets)
 
     flat_sources = nest.flatten(sources)
     flat_sources_raw = flat_sources
@@ -1066,6 +1077,9 @@ class GradientTape(object):
       self._watched_variables = self._tape.watched_variables()
       self._tape = None
 
+    if rewrap_as_ndarray:
+      flat_grad = nest.map_structure(np_arrays.tensor_to_ndarray, flat_grad)
+
     grad = nest.pack_sequence_as(sources, flat_grad)
     return grad
 
@@ -1120,6 +1134,10 @@ class GradientTape(object):
       ValueError: If vectorization of jacobian computation fails.
     """
     flat_sources = nest.flatten(sources)
+    rewrap_as_ndarray = False
+    if isinstance(target, np_arrays.ndarray):
+      target = target.data
+      rewrap_as_ndarray = True
     target_static_shape = target.shape
     target_shape = array_ops.shape(target)
     # Note that we push and pop the tape here and below. This is needed since we
@@ -1169,6 +1187,8 @@ class GradientTape(object):
         out = array_ops.reshape(out, new_shape)
         if context.executing_eagerly():
           out.set_shape(target_static_shape.concatenate(flat_sources[i].shape))
+        if rewrap_as_ndarray:
+          out = np_arrays.tensor_to_ndarray(out)
         output[i] = out
 
     return nest.pack_sequence_as(sources, output)
diff --git a/tensorflow/python/ops/numpy_ops/np_arrays.py b/tensorflow/python/ops/numpy_ops/np_arrays.py
index fd26318bea9..eca84421d1b 100644
--- a/tensorflow/python/ops/numpy_ops/np_arrays.py
+++ b/tensorflow/python/ops/numpy_ops/np_arrays.py
@@ -82,10 +82,10 @@ class NdarraySpec(type_spec.BatchableTypeSpec):
     return (self._data_spec,)
 
   def _batch(self, batch_size):
-    return NdarraySpec(self._data_spec.batch(batch_size))
+    return NdarraySpec(self._data_spec._batch(batch_size))  # pylint: disable=protected-access
 
   def _unbatch(self):
-    return NdarraySpec(self._data_spec.unbatch())
+    return NdarraySpec(self._data_spec._unbatch())  # pylint: disable=protected-access
 
 
 class ndarray(composite_tensor.CompositeTensor):  # pylint: disable=invalid-name
@@ -306,10 +306,6 @@ class ndarray(composite_tensor.CompositeTensor):  # pylint: disable=invalid-name
   def __repr__(self):
     return 'ndarray<{}>'.format(self.data.__repr__())
 
-  @property
-  def _id(self):
-    return self.data._id  # pylint: disable=protected-access
-
 
 def tensor_to_ndarray(tensor):
   return ndarray.from_tensor(tensor)
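
Rough usage sketch of what the backprop.py changes enable (not part of the diff; it imports the internal np_arrays module directly, and the constant/watch plumbing is illustrative only):

# Sketch only: exercises GradientTape.gradient with an np_arrays.ndarray target.
# Assumes eager execution and the internal numpy_ops package touched by this change.
import tensorflow as tf
from tensorflow.python.ops.numpy_ops import np_arrays

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  # Wrap the result as an ndarray; GradientTape unwraps it via `t.data`.
  y = np_arrays.tensor_to_ndarray(x * x)

# Every flattened target is an ndarray, so rewrap_as_ndarray is True and the
# returned gradient is wrapped back into an np_arrays.ndarray (dy/dx = 6.0).
grad = tape.gradient(y, x)
print(type(grad), grad)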