diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 9694fc2574b..7cd838b926c 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -1152,6 +1152,18 @@ class DatasetV2(object): def filter(self, predicate): """Filters this dataset according to `predicate`. + ```python + d = tf.data.Dataset.from_tensor_slices([1, 2, 3]) + + d = d.filter(lambda x: x < 3) # [1, 2] + + # `tf.math.equal(x, y)` is required for equality comparison + def filter_fn(x): + return tf.math.equal(x, 1) + + d = d.filter(filter_fn) # [1] + ``` + Args: predicate: A function mapping a nested structure of tensors (having shapes and types defined by `self.output_shapes` and `self.output_types`) to a