Wrap/rewrap ndarrays in tf.vectorized_map

PiperOrigin-RevId: 322714372
Change-Id: I59a66d8d60674800df36712ab53902037642bf2f
This commit is contained in:
Akshay Modi 2020-07-22 20:56:00 -07:00 committed by TensorFlower Gardener
parent ee74b70ee5
commit 488448c742
2 changed files with 21 additions and 5 deletions

View File

@ -281,8 +281,7 @@ class InteropTest(tf.test.TestCase):
a = np.ones((batch_size, 32, 32))
c = tf.vectorized_map(outer_product, a)
# # TODO(nareshmodi): vectorized_map doesn't rewrap tensors in ndarray.
# self.assertIsInstance(c, np.ndarray)
self.assertIsInstance(c, np.ndarray)
self.assertEqual(c.shape, (batch_size, 32, 32, 32, 32))
def testJacobian(self):

View File

@ -31,6 +31,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.parallel_for.pfor import PFor
from tensorflow.python.ops.parallel_for.pfor import PForConfig
from tensorflow.python.platform import tf_logging as logging
@ -246,6 +247,7 @@ def _pfor_impl(loop_fn,
loop_fn_outputs = loop_fn(loop_var)
# Convert outputs to Tensor if needed.
rewrap_as_ndarray = False
tmp_loop_fn_outputs = []
for loop_fn_output in nest.flatten(loop_fn_outputs):
if (loop_fn_output is not None and not isinstance(
@ -256,7 +258,12 @@ def _pfor_impl(loop_fn,
" Alternatively, output the indices and values of the"
" IndexedSlices separately, and handle the vectorized"
" outputs directly." % loop_fn_output)
loop_fn_output = ops.convert_to_tensor(loop_fn_output)
loop_fn_output = ops.convert_to_tensor(loop_fn_output)
elif isinstance(loop_fn_output, np_arrays.ndarray):
loop_fn_output = loop_fn_output.data
rewrap_as_ndarray = True
else:
loop_fn_output = ops.convert_to_tensor(loop_fn_output)
tmp_loop_fn_outputs.append(loop_fn_output)
loop_fn_outputs = nest.pack_sequence_as(loop_fn_outputs, tmp_loop_fn_outputs)
@ -277,7 +284,10 @@ def _pfor_impl(loop_fn,
pfor_config=pfor_config)
outputs = []
for loop_fn_output in nest.flatten(loop_fn_outputs):
outputs.append(converter.convert(loop_fn_output))
output = converter.convert(loop_fn_output)
if rewrap_as_ndarray:
output = np_arrays.tensor_to_ndarray(output)
outputs.append(output)
return nest.pack_sequence_as(loop_fn_outputs, outputs)
else:
if pfor_config is not None and pfor_config._has_reductions(): # pylint: disable=protected-access
@ -294,7 +304,10 @@ def _pfor_impl(loop_fn,
remaining_outputs = []
flattened_loop_fn_outputs = nest.flatten(loop_fn_outputs)
for loop_fn_output in flattened_loop_fn_outputs:
remaining_outputs.append(converter.convert(loop_fn_output))
output = converter.convert(loop_fn_output)
if rewrap_as_ndarray:
output = np_arrays.tensor_to_ndarray(output)
remaining_outputs.append(output)
with ops.name_scope("pfor_tiled"):
loop_fn_dtypes = [ops.convert_to_tensor(x).dtype
@ -329,6 +342,10 @@ def _pfor_impl(loop_fn,
for x, y in zip(remaining_outputs, tiled_outputs)])
else:
outputs = tiled_outputs
flattened_outputs = nest.flatten(outputs)
if rewrap_as_ndarray:
flattened_outputs = [
np_arrays.tensor_to_ndarray(x) for x in flattened_outputs]
return nest.pack_sequence_as(loop_fn_outputs, nest.flatten(outputs))