Fix tf.map_fn interop with np ndarray
PiperOrigin-RevId: 323628461 Change-Id: Ib56efd25a2cd200cb1f2cb2469940f24b93cfcd2
This commit is contained in:
parent
f72b707dfc
commit
cdac4cf5c5
@ -38,10 +38,16 @@ from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import lazy_loader
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
np_arrays = lazy_loader.LazyLoader(
|
||||
"np_arrays", globals(),
|
||||
"tensorflow.python.ops.numpy_ops.np_arrays")
|
||||
|
||||
|
||||
@tf_export(v1=["map_fn"])
|
||||
@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype")
|
||||
def map_fn(fn,
|
||||
@ -419,7 +425,10 @@ def map_fn(fn,
|
||||
]
|
||||
|
||||
# Check that inputs are not scalars.
|
||||
elems_static_shape = elems_flat[0].shape
|
||||
first_elem = elems_flat[0]
|
||||
if isinstance(first_elem, np_arrays.ndarray):
|
||||
first_elem = first_elem.data
|
||||
elems_static_shape = first_elem.shape
|
||||
if elems_static_shape.ndims is not None and elems_static_shape.ndims < 1:
|
||||
if len(elems_flat) == 1:
|
||||
raise ValueError("elems must be a 1+ dimensional Tensor, not a scalar")
|
||||
|
@ -315,6 +315,15 @@ class InteropTest(tf.test.TestCase):
|
||||
self.assertIsInstance(batch_jacobian, np.ndarray)
|
||||
self.assertAllClose(batch_jacobian, answer)
|
||||
|
||||
def testMapFn(self):
|
||||
x = np.asarray([1., 2.])
|
||||
mapped_x = tf.map_fn(lambda x: (x[0]+1, x[1]+1), (x, x))
|
||||
|
||||
self.assertIsInstance(mapped_x[0], np.ndarray)
|
||||
self.assertIsInstance(mapped_x[1], np.ndarray)
|
||||
self.assertAllClose(mapped_x[0], [2., 3.])
|
||||
self.assertAllClose(mapped_x[1], [2., 3.])
|
||||
|
||||
|
||||
class FunctionTest(InteropTest):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user