Wrap/rewrap ndarrays in tf.vectorized_map
PiperOrigin-RevId: 322714372 Change-Id: I59a66d8d60674800df36712ab53902037642bf2f
This commit is contained in:
parent
ee74b70ee5
commit
488448c742
@ -281,8 +281,7 @@ class InteropTest(tf.test.TestCase):
|
||||
a = np.ones((batch_size, 32, 32))
|
||||
c = tf.vectorized_map(outer_product, a)
|
||||
|
||||
# # TODO(nareshmodi): vectorized_map doesn't rewrap tensors in ndarray.
|
||||
# self.assertIsInstance(c, np.ndarray)
|
||||
self.assertIsInstance(c, np.ndarray)
|
||||
self.assertEqual(c.shape, (batch_size, 32, 32, 32, 32))
|
||||
|
||||
def testJacobian(self):
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops.numpy_ops import np_arrays
|
||||
from tensorflow.python.ops.parallel_for.pfor import PFor
|
||||
from tensorflow.python.ops.parallel_for.pfor import PForConfig
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
@ -246,6 +247,7 @@ def _pfor_impl(loop_fn,
|
||||
loop_fn_outputs = loop_fn(loop_var)
|
||||
|
||||
# Convert outputs to Tensor if needed.
|
||||
rewrap_as_ndarray = False
|
||||
tmp_loop_fn_outputs = []
|
||||
for loop_fn_output in nest.flatten(loop_fn_outputs):
|
||||
if (loop_fn_output is not None and not isinstance(
|
||||
@ -256,7 +258,12 @@ def _pfor_impl(loop_fn,
|
||||
" Alternatively, output the indices and values of the"
|
||||
" IndexedSlices separately, and handle the vectorized"
|
||||
" outputs directly." % loop_fn_output)
|
||||
loop_fn_output = ops.convert_to_tensor(loop_fn_output)
|
||||
loop_fn_output = ops.convert_to_tensor(loop_fn_output)
|
||||
elif isinstance(loop_fn_output, np_arrays.ndarray):
|
||||
loop_fn_output = loop_fn_output.data
|
||||
rewrap_as_ndarray = True
|
||||
else:
|
||||
loop_fn_output = ops.convert_to_tensor(loop_fn_output)
|
||||
tmp_loop_fn_outputs.append(loop_fn_output)
|
||||
loop_fn_outputs = nest.pack_sequence_as(loop_fn_outputs, tmp_loop_fn_outputs)
|
||||
|
||||
@ -277,7 +284,10 @@ def _pfor_impl(loop_fn,
|
||||
pfor_config=pfor_config)
|
||||
outputs = []
|
||||
for loop_fn_output in nest.flatten(loop_fn_outputs):
|
||||
outputs.append(converter.convert(loop_fn_output))
|
||||
output = converter.convert(loop_fn_output)
|
||||
if rewrap_as_ndarray:
|
||||
output = np_arrays.tensor_to_ndarray(output)
|
||||
outputs.append(output)
|
||||
return nest.pack_sequence_as(loop_fn_outputs, outputs)
|
||||
else:
|
||||
if pfor_config is not None and pfor_config._has_reductions(): # pylint: disable=protected-access
|
||||
@ -294,7 +304,10 @@ def _pfor_impl(loop_fn,
|
||||
remaining_outputs = []
|
||||
flattened_loop_fn_outputs = nest.flatten(loop_fn_outputs)
|
||||
for loop_fn_output in flattened_loop_fn_outputs:
|
||||
remaining_outputs.append(converter.convert(loop_fn_output))
|
||||
output = converter.convert(loop_fn_output)
|
||||
if rewrap_as_ndarray:
|
||||
output = np_arrays.tensor_to_ndarray(output)
|
||||
remaining_outputs.append(output)
|
||||
|
||||
with ops.name_scope("pfor_tiled"):
|
||||
loop_fn_dtypes = [ops.convert_to_tensor(x).dtype
|
||||
@ -329,6 +342,10 @@ def _pfor_impl(loop_fn,
|
||||
for x, y in zip(remaining_outputs, tiled_outputs)])
|
||||
else:
|
||||
outputs = tiled_outputs
|
||||
flattened_outputs = nest.flatten(outputs)
|
||||
if rewrap_as_ndarray:
|
||||
flattened_outputs = [
|
||||
np_arrays.tensor_to_ndarray(x) for x in flattened_outputs]
|
||||
return nest.pack_sequence_as(loop_fn_outputs, nest.flatten(outputs))
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user