[tf.data] Reduce noise in the Dataset.from_tensor_slices(tf.SparseTensor) benchmark.
This change uses the single-threaded executor in FlatMapDataset to avoid thread-scheduling noise from function invocation.

PiperOrigin-RevId: 288720757
Change-Id: I7d0e7982d90274201449820c2c949a98ad612335
This commit is contained in:
parent
2ad9dd652f
commit
5da98db80c
@ -22,7 +22,40 @@ import numpy as np
|
||||
from tensorflow.python.data.benchmarks import benchmark_base
|
||||
from tensorflow.python.data.experimental.ops import get_single_element
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
|
||||
|
||||
class SingleThreadedFlatMapDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over its input and flattens the result.

  The user function is wrapped with
  `defun_kwargs={"_executor": "SINGLE_THREADED_EXECUTOR"}`. Per the commit
  description, this is done to avoid thread-scheduling noise from function
  invocation when benchmarking.
  """

  def __init__(self, input_dataset, map_func):
    """See `Dataset.flat_map()` for details.

    Args:
      input_dataset: The dataset whose elements are passed to `map_func`.
      map_func: A function applied to each element of `input_dataset`; it must
        return a `Dataset`.

    Raises:
      TypeError: If `map_func` does not return a `Dataset` object.
    """
    self._input_dataset = input_dataset
    self._map_func = dataset_ops.StructuredFunctionWrapper(
        map_func,
        self._transformation_name(),
        dataset=input_dataset,
        defun_kwargs={"_executor": "SINGLE_THREADED_EXECUTOR"})

    # Fail with a clear message instead of an opaque AttributeError on
    # `_element_spec` when `map_func` returns something other than a dataset.
    output_structure = self._map_func.output_structure
    if not isinstance(output_structure, dataset_ops.DatasetSpec):
      raise TypeError(
          "`map_func` must return a `Dataset` object. Got {}".format(
              type(output_structure)))

    # NOTE(review): `_structure` backs `element_spec`, and `_flat_structure`
    # below is presumably derived from it by the base class, so it is assigned
    # before the op is built -- confirm against `dataset_ops`.
    self._structure = output_structure._element_spec  # pylint: disable=protected-access
    variant_tensor = gen_dataset_ops.flat_map_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        **self._flat_structure)
    super(SingleThreadedFlatMapDataset, self).__init__(input_dataset,
                                                       variant_tensor)

  def _functions(self):
    """Returns the list holding the wrapped `map_func`."""
    return [self._map_func]

  @property
  def element_spec(self):
    """The element spec taken from the dataset structure `map_func` returns."""
    return self._structure

  def _transformation_name(self):
    """Returns the name passed to `StructuredFunctionWrapper` for this op."""
    return "SingleThreadedFlatMapDataset"
|
||||
|
||||
|
||||
# TODO(b/119837791): Add eager benchmarks.
|
||||
@ -76,14 +109,21 @@ class FromTensorSlicesBenchmark(benchmark_base.DatasetBenchmarkBase):
|
||||
dense_shape=[1000])
|
||||
|
||||
for num_rows in num_rows_values:
|
||||
|
||||
# TODO(b/147153744): Function-valued attributes with their own
|
||||
# attributes are currently only supported in graph mode.
|
||||
@def_function.function
|
||||
def make_dataset():
|
||||
batched = dataset_ops.Dataset.from_tensors(
|
||||
tensor).repeat(num_rows).batch(num_rows)
|
||||
tensor).repeat(num_rows).batch(num_rows) # pylint: disable=cell-var-from-loop
|
||||
batched_tensor = get_single_element.get_single_element(batched)
|
||||
|
||||
dataset = dataset_ops.Dataset.from_tensors(batched_tensor).flat_map(
|
||||
dataset_ops.Dataset.from_tensor_slices).repeat()
|
||||
dataset = dataset_ops.Dataset.from_tensors(batched_tensor).repeat()
|
||||
return SingleThreadedFlatMapDataset(
|
||||
dataset, dataset_ops.Dataset.from_tensor_slices)
|
||||
|
||||
self.run_and_report_benchmark(
|
||||
dataset,
|
||||
make_dataset(),
|
||||
num_elements=100000,
|
||||
iters=5,
|
||||
name="slice_repeat_sparse_elements_per_row_%d_num_rows_%d" % (
|
||||
|
Loading…
Reference in New Issue
Block a user