Merge pull request #32121 from ilhamfp:autograph-override-map
PiperOrigin-RevId: 267193122
This commit is contained in:
commit
059fe88ded
@ -349,14 +349,28 @@ def zip_(*iterables):
|
||||
|
||||
|
||||
def _tf_dataset_zip(*iterables):
|
||||
return dataset_ops.DatasetV2.zip(tuple(iterables))
|
||||
return dataset_ops.DatasetV2.zip(iterables)
|
||||
|
||||
|
||||
def _py_zip(*iterables):
|
||||
return zip(*iterables)
|
||||
|
||||
|
||||
SUPPORTED_BUILTINS = (abs, float, int, len, print, range, enumerate, zip)
|
||||
def map_(fn, *iterables):
|
||||
if all(isinstance(x, dataset_ops.DatasetV2) for x in iterables):
|
||||
return _tf_dataset_map(fn, *iterables)
|
||||
return _py_map(fn, *iterables)
|
||||
|
||||
|
||||
def _tf_dataset_map(fn, *iterables):
|
||||
return dataset_ops.DatasetV2.zip(iterables).map(fn)
|
||||
|
||||
|
||||
def _py_map(fn, *iterables):
|
||||
return map(fn, *iterables)
|
||||
|
||||
|
||||
SUPPORTED_BUILTINS = (abs, float, int, len, print, range, enumerate, zip, map)
|
||||
|
||||
if six.PY2:
|
||||
SUPPORTED_BUILTINS += (xrange,)
|
||||
@ -372,4 +386,5 @@ BUILTIN_FUINCTIONS_MAP = {
|
||||
'xrange': range_,
|
||||
'enumerate': enumerate_,
|
||||
'zip': zip_,
|
||||
'map': map_,
|
||||
}
|
||||
|
||||
@ -178,6 +178,41 @@ class PyBuiltinsTest(test.TestCase):
|
||||
self.assertAllEqual(self.evaluate(iterator.get_next()), (-12, -22))
|
||||
self.assertAllEqual(self.evaluate(iterator.get_next()), (4, 5))
|
||||
|
||||
def test_map(self):
|
||||
|
||||
def increment(x):
|
||||
return x + 1
|
||||
|
||||
add_list = lambda x, y: x + y
|
||||
self.assertListEqual(
|
||||
list(py_builtins.map_(increment, [4, 5, 6])), [5, 6, 7])
|
||||
self.assertListEqual(
|
||||
list(py_builtins.map_(add_list, [3, 2, 1], [-1, -2, -3])), [2, 0, -2])
|
||||
|
||||
def test_map_dataset(self):
|
||||
|
||||
def increment(x):
|
||||
return x + 1
|
||||
|
||||
ds1 = dataset_ops.DatasetV2.from_tensor_slices([4, 5, 6])
|
||||
ds2 = py_builtins.map_(increment, ds1)
|
||||
iterator = dataset_ops.make_one_shot_iterator(ds2)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(self.evaluate(iterator.get_next()), 5)
|
||||
self.assertAllEqual(self.evaluate(iterator.get_next()), 6)
|
||||
self.assertAllEqual(self.evaluate(iterator.get_next()), 7)
|
||||
|
||||
def test_map_multiple_datasets(self):
|
||||
add_list = lambda x, y: x + y
|
||||
ds1 = dataset_ops.DatasetV2.from_tensor_slices([-11, -12, 4])
|
||||
ds2 = dataset_ops.DatasetV2.from_tensor_slices([-21, -22, 5])
|
||||
ds3 = py_builtins.map_(add_list, ds1, ds2)
|
||||
iterator = dataset_ops.make_one_shot_iterator(ds3)
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual(self.evaluate(iterator.get_next()), -32)
|
||||
self.assertAllEqual(self.evaluate(iterator.get_next()), -34)
|
||||
self.assertAllEqual(self.evaluate(iterator.get_next()), 9)
|
||||
|
||||
def _basic_function_scope(self):
|
||||
return function_wrappers.FunctionScope(
|
||||
'test_function_name',
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user