[tf.data] Reduce noise in the Dataset.from_tensor_slices(tf.SparseTensor).

This change uses the single-threaded executor in FlatMapDataset to avoid thread-scheduling noise from function invocation.

PiperOrigin-RevId: 288720757
Change-Id: I7d0e7982d90274201449820c2c949a98ad612335
This commit is contained in:
Derek Murray 2020-01-08 10:07:50 -08:00 committed by TensorFlower Gardener
parent 2ad9dd652f
commit 5da98db80c

View File

@ -22,7 +22,40 @@ import numpy as np
from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.experimental.ops import get_single_element
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import def_function
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_dataset_ops
class SingleThreadedFlatMapDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that maps a function over its input and flattens the result."""
def __init__(self, input_dataset, map_func):
"""See `Dataset.flat_map()` for details."""
self._input_dataset = input_dataset
self._map_func = dataset_ops.StructuredFunctionWrapper(
map_func,
self._transformation_name(),
dataset=input_dataset,
defun_kwargs={"_executor": "SINGLE_THREADED_EXECUTOR"})
self._structure = self._map_func.output_structure._element_spec # pylint: disable=protected-access
variant_tensor = gen_dataset_ops.flat_map_dataset(
input_dataset._variant_tensor, # pylint: disable=protected-access
self._map_func.function.captured_inputs,
f=self._map_func.function,
**self._flat_structure)
super(SingleThreadedFlatMapDataset, self).__init__(input_dataset,
variant_tensor)
def _functions(self):
return [self._map_func]
@property
def element_spec(self):
return self._structure
def _transformation_name(self):
return "SingleThreadedFlatMapDataset"
# TODO(b/119837791): Add eager benchmarks.
@ -76,14 +109,21 @@ class FromTensorSlicesBenchmark(benchmark_base.DatasetBenchmarkBase):
dense_shape=[1000])
for num_rows in num_rows_values:
# TODO(b/147153744): Function-valued attributes with their own
# attributes are currently only supported in graph mode.
@def_function.function
def make_dataset():
batched = dataset_ops.Dataset.from_tensors(
tensor).repeat(num_rows).batch(num_rows)
tensor).repeat(num_rows).batch(num_rows) # pylint: disable=cell-var-from-loop
batched_tensor = get_single_element.get_single_element(batched)
dataset = dataset_ops.Dataset.from_tensors(batched_tensor).flat_map(
dataset_ops.Dataset.from_tensor_slices).repeat()
dataset = dataset_ops.Dataset.from_tensors(batched_tensor).repeat()
return SingleThreadedFlatMapDataset(
dataset, dataset_ops.Dataset.from_tensor_slices)
self.run_and_report_benchmark(
dataset,
make_dataset(),
num_elements=100000,
iters=5,
name="slice_repeat_sparse_elements_per_row_%d_num_rows_%d" % (