Merge pull request #32121 from ilhamfp:autograph-override-map

PiperOrigin-RevId: 267193122
This commit is contained in:
TensorFlower Gardener 2019-09-04 11:50:11 -07:00
commit 059fe88ded
2 changed files with 52 additions and 2 deletions

View File

@ -349,14 +349,28 @@ def zip_(*iterables):
def _tf_dataset_zip(*iterables):
return dataset_ops.DatasetV2.zip(tuple(iterables))
return dataset_ops.DatasetV2.zip(iterables)
def _py_zip(*iterables):
return zip(*iterables)
SUPPORTED_BUILTINS = (abs, float, int, len, print, range, enumerate, zip)
def map_(fn, *iterables):
if all(isinstance(x, dataset_ops.DatasetV2) for x in iterables):
return _tf_dataset_map(fn, *iterables)
return _py_map(fn, *iterables)
def _tf_dataset_map(fn, *iterables):
return dataset_ops.DatasetV2.zip(iterables).map(fn)
def _py_map(fn, *iterables):
return map(fn, *iterables)
SUPPORTED_BUILTINS = (abs, float, int, len, print, range, enumerate, zip, map)
if six.PY2:
SUPPORTED_BUILTINS += (xrange,)
@ -372,4 +386,5 @@ BUILTIN_FUINCTIONS_MAP = {
'xrange': range_,
'enumerate': enumerate_,
'zip': zip_,
'map': map_,
}

View File

@ -178,6 +178,41 @@ class PyBuiltinsTest(test.TestCase):
self.assertAllEqual(self.evaluate(iterator.get_next()), (-12, -22))
self.assertAllEqual(self.evaluate(iterator.get_next()), (4, 5))
def test_map(self):
def increment(x):
return x + 1
add_list = lambda x, y: x + y
self.assertListEqual(
list(py_builtins.map_(increment, [4, 5, 6])), [5, 6, 7])
self.assertListEqual(
list(py_builtins.map_(add_list, [3, 2, 1], [-1, -2, -3])), [2, 0, -2])
def test_map_dataset(self):
def increment(x):
return x + 1
ds1 = dataset_ops.DatasetV2.from_tensor_slices([4, 5, 6])
ds2 = py_builtins.map_(increment, ds1)
iterator = dataset_ops.make_one_shot_iterator(ds2)
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(iterator.get_next()), 5)
self.assertAllEqual(self.evaluate(iterator.get_next()), 6)
self.assertAllEqual(self.evaluate(iterator.get_next()), 7)
def test_map_multiple_datasets(self):
add_list = lambda x, y: x + y
ds1 = dataset_ops.DatasetV2.from_tensor_slices([-11, -12, 4])
ds2 = dataset_ops.DatasetV2.from_tensor_slices([-21, -22, 5])
ds3 = py_builtins.map_(add_list, ds1, ds2)
iterator = dataset_ops.make_one_shot_iterator(ds3)
with self.cached_session() as sess:
self.assertAllEqual(self.evaluate(iterator.get_next()), -32)
self.assertAllEqual(self.evaluate(iterator.get_next()), -34)
self.assertAllEqual(self.evaluate(iterator.get_next()), 9)
def _basic_function_scope(self):
return function_wrappers.FunctionScope(
'test_function_name',