diff --git a/tensorflow/python/autograph/operators/py_builtins.py b/tensorflow/python/autograph/operators/py_builtins.py index cd5f69bbce2..93f7e8fed5b 100644 --- a/tensorflow/python/autograph/operators/py_builtins.py +++ b/tensorflow/python/autograph/operators/py_builtins.py @@ -349,14 +349,28 @@ def zip_(*iterables): def _tf_dataset_zip(*iterables): - return dataset_ops.DatasetV2.zip(tuple(iterables)) + return dataset_ops.DatasetV2.zip(iterables) def _py_zip(*iterables): return zip(*iterables) -SUPPORTED_BUILTINS = (abs, float, int, len, print, range, enumerate, zip) +def map_(fn, *iterables): + if all(isinstance(x, dataset_ops.DatasetV2) for x in iterables): + return _tf_dataset_map(fn, *iterables) + return _py_map(fn, *iterables) + + +def _tf_dataset_map(fn, *iterables): + return dataset_ops.DatasetV2.zip(iterables).map(fn) + + +def _py_map(fn, *iterables): + return map(fn, *iterables) + + +SUPPORTED_BUILTINS = (abs, float, int, len, print, range, enumerate, zip, map) if six.PY2: SUPPORTED_BUILTINS += (xrange,) @@ -372,4 +386,5 @@ BUILTIN_FUINCTIONS_MAP = { 'xrange': range_, 'enumerate': enumerate_, 'zip': zip_, + 'map': map_, } diff --git a/tensorflow/python/autograph/operators/py_builtins_test.py b/tensorflow/python/autograph/operators/py_builtins_test.py index be77495daed..4257510f8d1 100644 --- a/tensorflow/python/autograph/operators/py_builtins_test.py +++ b/tensorflow/python/autograph/operators/py_builtins_test.py @@ -178,6 +178,41 @@ class PyBuiltinsTest(test.TestCase): self.assertAllEqual(self.evaluate(iterator.get_next()), (-12, -22)) self.assertAllEqual(self.evaluate(iterator.get_next()), (4, 5)) + def test_map(self): + + def increment(x): + return x + 1 + + add_list = lambda x, y: x + y + self.assertListEqual( + list(py_builtins.map_(increment, [4, 5, 6])), [5, 6, 7]) + self.assertListEqual( + list(py_builtins.map_(add_list, [3, 2, 1], [-1, -2, -3])), [2, 0, -2]) + + def test_map_dataset(self): + + def increment(x): + return x + 1 + + ds1 = dataset_ops.DatasetV2.from_tensor_slices([4, 5, 6]) + ds2 = py_builtins.map_(increment, ds1) + iterator = dataset_ops.make_one_shot_iterator(ds2) + with self.cached_session() as sess: + self.assertAllEqual(self.evaluate(iterator.get_next()), 5) + self.assertAllEqual(self.evaluate(iterator.get_next()), 6) + self.assertAllEqual(self.evaluate(iterator.get_next()), 7) + + def test_map_multiple_datasets(self): + add_list = lambda x, y: x + y + ds1 = dataset_ops.DatasetV2.from_tensor_slices([-11, -12, 4]) + ds2 = dataset_ops.DatasetV2.from_tensor_slices([-21, -22, 5]) + ds3 = py_builtins.map_(add_list, ds1, ds2) + iterator = dataset_ops.make_one_shot_iterator(ds3) + with self.cached_session() as sess: + self.assertAllEqual(self.evaluate(iterator.get_next()), -32) + self.assertAllEqual(self.evaluate(iterator.get_next()), -34) + self.assertAllEqual(self.evaluate(iterator.get_next()), 9) + def _basic_function_scope(self): return function_wrappers.FunctionScope( 'test_function_name',