Fix tf.map_fn interop with np ndarray

PiperOrigin-RevId: 323628461
Change-Id: Ib56efd25a2cd200cb1f2cb2469940f24b93cfcd2
This commit is contained in:
Akshay Modi 2020-07-28 12:27:16 -07:00 committed by TensorFlower Gardener
parent f72b707dfc
commit cdac4cf5c5
2 changed files with 19 additions and 1 deletions

View File

@ -38,10 +38,16 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
np_arrays = lazy_loader.LazyLoader(
"np_arrays", globals(),
"tensorflow.python.ops.numpy_ops.np_arrays")
@tf_export(v1=["map_fn"])
@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype")
def map_fn(fn,
@ -419,7 +425,10 @@ def map_fn(fn,
]
# Check that inputs are not scalars.
elems_static_shape = elems_flat[0].shape
first_elem = elems_flat[0]
if isinstance(first_elem, np_arrays.ndarray):
first_elem = first_elem.data
elems_static_shape = first_elem.shape
if elems_static_shape.ndims is not None and elems_static_shape.ndims < 1:
if len(elems_flat) == 1:
raise ValueError("elems must be a 1+ dimensional Tensor, not a scalar")

View File

@ -315,6 +315,15 @@ class InteropTest(tf.test.TestCase):
self.assertIsInstance(batch_jacobian, np.ndarray)
self.assertAllClose(batch_jacobian, answer)
def testMapFn(self):
x = np.asarray([1., 2.])
mapped_x = tf.map_fn(lambda x: (x[0]+1, x[1]+1), (x, x))
self.assertIsInstance(mapped_x[0], np.ndarray)
self.assertIsInstance(mapped_x[1], np.ndarray)
self.assertAllClose(mapped_x[0], [2., 3.])
self.assertAllClose(mapped_x[1], [2., 3.])
class FunctionTest(InteropTest):