Merge pull request #32121 from ilhamfp:autograph-override-map
PiperOrigin-RevId: 267193122
This commit is contained in:
commit
059fe88ded
@ -349,14 +349,28 @@ def zip_(*iterables):
|
|||||||
|
|
||||||
|
|
||||||
def _tf_dataset_zip(*iterables):
|
def _tf_dataset_zip(*iterables):
|
||||||
return dataset_ops.DatasetV2.zip(tuple(iterables))
|
return dataset_ops.DatasetV2.zip(iterables)
|
||||||
|
|
||||||
|
|
||||||
def _py_zip(*iterables):
|
def _py_zip(*iterables):
|
||||||
return zip(*iterables)
|
return zip(*iterables)
|
||||||
|
|
||||||
|
|
||||||
SUPPORTED_BUILTINS = (abs, float, int, len, print, range, enumerate, zip)
|
def map_(fn, *iterables):
|
||||||
|
if all(isinstance(x, dataset_ops.DatasetV2) for x in iterables):
|
||||||
|
return _tf_dataset_map(fn, *iterables)
|
||||||
|
return _py_map(fn, *iterables)
|
||||||
|
|
||||||
|
|
||||||
|
def _tf_dataset_map(fn, *iterables):
|
||||||
|
return dataset_ops.DatasetV2.zip(iterables).map(fn)
|
||||||
|
|
||||||
|
|
||||||
|
def _py_map(fn, *iterables):
|
||||||
|
return map(fn, *iterables)
|
||||||
|
|
||||||
|
|
||||||
|
SUPPORTED_BUILTINS = (abs, float, int, len, print, range, enumerate, zip, map)
|
||||||
|
|
||||||
if six.PY2:
|
if six.PY2:
|
||||||
SUPPORTED_BUILTINS += (xrange,)
|
SUPPORTED_BUILTINS += (xrange,)
|
||||||
@ -372,4 +386,5 @@ BUILTIN_FUINCTIONS_MAP = {
|
|||||||
'xrange': range_,
|
'xrange': range_,
|
||||||
'enumerate': enumerate_,
|
'enumerate': enumerate_,
|
||||||
'zip': zip_,
|
'zip': zip_,
|
||||||
|
'map': map_,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -178,6 +178,41 @@ class PyBuiltinsTest(test.TestCase):
|
|||||||
self.assertAllEqual(self.evaluate(iterator.get_next()), (-12, -22))
|
self.assertAllEqual(self.evaluate(iterator.get_next()), (-12, -22))
|
||||||
self.assertAllEqual(self.evaluate(iterator.get_next()), (4, 5))
|
self.assertAllEqual(self.evaluate(iterator.get_next()), (4, 5))
|
||||||
|
|
||||||
|
def test_map(self):
|
||||||
|
|
||||||
|
def increment(x):
|
||||||
|
return x + 1
|
||||||
|
|
||||||
|
add_list = lambda x, y: x + y
|
||||||
|
self.assertListEqual(
|
||||||
|
list(py_builtins.map_(increment, [4, 5, 6])), [5, 6, 7])
|
||||||
|
self.assertListEqual(
|
||||||
|
list(py_builtins.map_(add_list, [3, 2, 1], [-1, -2, -3])), [2, 0, -2])
|
||||||
|
|
||||||
|
def test_map_dataset(self):
|
||||||
|
|
||||||
|
def increment(x):
|
||||||
|
return x + 1
|
||||||
|
|
||||||
|
ds1 = dataset_ops.DatasetV2.from_tensor_slices([4, 5, 6])
|
||||||
|
ds2 = py_builtins.map_(increment, ds1)
|
||||||
|
iterator = dataset_ops.make_one_shot_iterator(ds2)
|
||||||
|
with self.cached_session() as sess:
|
||||||
|
self.assertAllEqual(self.evaluate(iterator.get_next()), 5)
|
||||||
|
self.assertAllEqual(self.evaluate(iterator.get_next()), 6)
|
||||||
|
self.assertAllEqual(self.evaluate(iterator.get_next()), 7)
|
||||||
|
|
||||||
|
def test_map_multiple_datasets(self):
|
||||||
|
add_list = lambda x, y: x + y
|
||||||
|
ds1 = dataset_ops.DatasetV2.from_tensor_slices([-11, -12, 4])
|
||||||
|
ds2 = dataset_ops.DatasetV2.from_tensor_slices([-21, -22, 5])
|
||||||
|
ds3 = py_builtins.map_(add_list, ds1, ds2)
|
||||||
|
iterator = dataset_ops.make_one_shot_iterator(ds3)
|
||||||
|
with self.cached_session() as sess:
|
||||||
|
self.assertAllEqual(self.evaluate(iterator.get_next()), -32)
|
||||||
|
self.assertAllEqual(self.evaluate(iterator.get_next()), -34)
|
||||||
|
self.assertAllEqual(self.evaluate(iterator.get_next()), 9)
|
||||||
|
|
||||||
def _basic_function_scope(self):
|
def _basic_function_scope(self):
|
||||||
return function_wrappers.FunctionScope(
|
return function_wrappers.FunctionScope(
|
||||||
'test_function_name',
|
'test_function_name',
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user