From f6e8f7a1fb7c530a98ba93d40dba1f6c859a873b Mon Sep 17 00:00:00 2001
From: Brennan Saeta <saeta@google.com>
Date: Thu, 29 Nov 2018 09:53:22 -0800
Subject: [PATCH] [tf.data]: Deprecate dataset.shard.

In preparation for a distributed-aware tf.data system and for TF 2.0, this CL
deprecates the legacy `.shard` method on `tf.data.Dataset` objects. The old
behavior is preserved under the experimental `filter_for_shard` dataset
transformation, and remains available through the V1 API (now marked with a
deprecation warning).
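
A minimal migration sketch (`d` is any `tf.data.Dataset`; the `FLAGS` values
are placeholders, as in the docstring examples):

    # Before (deprecated in V1, removed from the V2 API):
    d = d.shard(FLAGS.num_workers, FLAGS.worker_index)

    # After:
    d = d.apply(tf.data.experimental.filter_for_shard(
        FLAGS.num_workers, FLAGS.worker_index))

Semantics are unchanged: each worker keeps exactly the elements whose
position in the dataset is congruent to its worker index modulo the number
of workers.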

PiperOrigin-RevId: 223362539
---
 .../python/data/experimental/__init__.py      |   2 +
 tensorflow/python/data/experimental/ops/BUILD |  13 ++
 .../experimental/ops/filter_for_shard_ops.py  | 106 +++++++++++++++
 tensorflow/python/data/ops/BUILD              |   1 +
 tensorflow/python/data/ops/dataset_ops.py     | 128 ++++++++----------
 tensorflow/python/distribute/BUILD            |   1 +
 tensorflow/python/distribute/input_ops.py     |  13 +-
 .../v1/tensorflow.data.experimental.pbtxt     |   4 +
 .../golden/v2/tensorflow.data.-dataset.pbtxt  |   4 -
 ...ow.data.-fixed-length-record-dataset.pbtxt |   4 -
 .../tensorflow.data.-t-f-record-dataset.pbtxt |   4 -
 .../tensorflow.data.-text-line-dataset.pbtxt  |   4 -
 ...rflow.data.experimental.-csv-dataset.pbtxt |   4 -
 ...ow.data.experimental.-random-dataset.pbtxt |   4 -
 ...rflow.data.experimental.-sql-dataset.pbtxt |   4 -
 .../v2/tensorflow.data.experimental.pbtxt     |   4 +
 16 files changed, 194 insertions(+), 106 deletions(-)
 create mode 100644 tensorflow/python/data/experimental/ops/filter_for_shard_ops.py
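
As a concrete illustration of the filtering semantics (hypothetical values,
not part of the change):

    # A dataset yielding 0..9, sharded 3 ways:
    d = tf.data.Dataset.range(10)
    d0 = d.apply(tf.data.experimental.filter_for_shard(3, 0))  # 0, 3, 6, 9
    d1 = d.apply(tf.data.experimental.filter_for_shard(3, 1))  # 1, 4, 7
    d2 = d.apply(tf.data.experimental.filter_for_shard(3, 2))  # 2, 5, 8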

diff --git a/tensorflow/python/data/experimental/__init__.py b/tensorflow/python/data/experimental/__init__.py
index 12aa0a09e11..8cec75b5992 100644
--- a/tensorflow/python/data/experimental/__init__.py
+++ b/tensorflow/python/data/experimental/__init__.py
@@ -39,6 +39,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
 @@copy_to_device
 @@dense_to_sparse_batch
 @@enumerate_dataset
+@@filter_for_shard
 @@get_next_as_optional
 @@get_single_element
 @@group_by_reducer
@@ -74,6 +75,7 @@ from tensorflow.python.data.experimental.ops.batching import unbatch
 from tensorflow.python.data.experimental.ops.counter import Counter
 from tensorflow.python.data.experimental.ops.enumerate_ops import enumerate_dataset
 from tensorflow.python.data.experimental.ops.error_ops import ignore_errors
+from tensorflow.python.data.experimental.ops.filter_for_shard_ops import filter_for_shard
 from tensorflow.python.data.experimental.ops.get_single_element import get_single_element
 from tensorflow.python.data.experimental.ops.grouping import bucket_by_sequence_length
 from tensorflow.python.data.experimental.ops.grouping import group_by_reducer
diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD
index f9544857a13..f85e774887d 100644
--- a/tensorflow/python/data/experimental/ops/BUILD
+++ b/tensorflow/python/data/experimental/ops/BUILD
@@ -139,6 +139,18 @@ py_library(
     ],
 )
 
+py_library(
+    name = "filter_for_shard_ops",
+    srcs = ["filter_for_shard_ops.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:ops",
+        "//tensorflow/python:tensor_util",
+    ],
+)
+
 py_library(
     name = "error_ops",
     srcs = ["error_ops.py"],
@@ -403,6 +415,7 @@ py_library(
         ":counter",
         ":enumerate_ops",
         ":error_ops",
+        ":filter_for_shard_ops",
         ":get_single_element",
         ":grouping",
         ":indexed_dataset_ops",
diff --git a/tensorflow/python/data/experimental/ops/filter_for_shard_ops.py b/tensorflow/python/data/experimental/ops/filter_for_shard_ops.py
new file mode 100644
index 00000000000..91d3dca3e9a
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/filter_for_shard_ops.py
@@ -0,0 +1,106 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Naive shard dataset transformation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import math_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.filter_for_shard")
+def filter_for_shard(num_shards, shard_index):
+  """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
+
+  This dataset operator is very useful when running distributed training, as
+  it allows each worker to read a unique subset.
+
+  When reading a single input file, you can skip elements as follows:
+
+  ```python
+  d = tf.data.TFRecordDataset(FLAGS.input_file)
+  d = d.apply(tf.data.experimental.filter_for_shard(FLAGS.num_workers,
+                                                    FLAGS.worker_index))
+  d = d.repeat(FLAGS.num_epochs)
+  d = d.shuffle(FLAGS.shuffle_buffer_size)
+  d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
+  ```
+
+  Important caveats:
+
+  - Be sure to shard before you use any randomizing operator (such as
+    shuffle).
+  - Generally it is best if the shard operator is used early in the dataset
+    pipeline. For example, when reading from a set of TFRecord files, shard
+    before converting the dataset to input samples. This avoids reading every
+    file on every worker. The following is an example of an efficient
+    sharding strategy within a complete pipeline:
+
+  ```python
+  d = Dataset.list_files(FLAGS.pattern)
+  d = d.apply(tf.data.experimental.filter_for_shard(FLAGS.num_workers,
+                                                    FLAGS.worker_index))
+  d = d.repeat(FLAGS.num_epochs)
+  d = d.shuffle(FLAGS.shuffle_buffer_size)
+  d = d.interleave(tf.data.TFRecordDataset,
+                   cycle_length=FLAGS.num_readers, block_length=1)
+  d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
+  ```
+
+  Args:
+    num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
+      shards operating in parallel.
+    shard_index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
+
+  Returns:
+    A `Dataset` transformation function, which can be passed to
+    `tf.data.Dataset.apply`.
+
+  Raises:
+    ValueError: if `num_shards` or `shard_index` are illegal values. Note:
+      error checking is done on a best-effort basis, and errors aren't
+      guaranteed to be caught upon dataset creation. (e.g. passing a
+      placeholder tensor bypasses the early checking and will instead result
+      in an error during a `session.run` call.)
+  """
+  num_shards = ops.convert_to_tensor(
+      num_shards, name="num_shards", dtype=dtypes.int64)
+  num_shards_static = tensor_util.constant_value(num_shards)
+  shard_index = ops.convert_to_tensor(shard_index, name="shard_index",
+                                      dtype=dtypes.int64)
+  shard_index_static = tensor_util.constant_value(shard_index)
+
+  if num_shards_static is not None and num_shards_static < 1:
+    raise ValueError("num_shards must be >= 1; got: %s" % num_shards_static)
+  if shard_index_static is not None and shard_index_static < 0:
+    raise ValueError("shard_index must be >= 0; got: %s" % shard_index_static)
+  if (shard_index_static is not None and num_shards_static is not None and
+      shard_index_static >= num_shards_static):
+    raise ValueError("shard_index must be < num_shards; %s is not < %s" %
+                     (shard_index_static, num_shards_static))
+
+  def filter_fn(elem_index, _):
+    mod_result = math_ops.mod(elem_index, num_shards)
+    return math_ops.equal(mod_result, shard_index)
+
+  def _apply_fn(dataset):
+    # pylint: disable=protected-access
+    return dataset._enumerate().filter(filter_fn).map(lambda _, elem: elem)
+
+  return _apply_fn
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index dcbb0f1868b..27c9175ccbf 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -26,6 +26,7 @@ py_library(
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:tensor_util",
         "//tensorflow/python:util",
+        "//tensorflow/python/data/experimental/ops:filter_for_shard_ops",
         "//tensorflow/python/data/experimental/ops:stats_options",
         "//tensorflow/python/data/experimental/ops:threading_options",
         "//tensorflow/python/data/util:nest",
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 71175fc6a20..51123aaf443 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -26,6 +26,7 @@ import numpy as np
 import six
 
 from tensorflow.python.compat import compat
+from tensorflow.python.data.experimental.ops import filter_for_shard_ops
 from tensorflow.python.data.experimental.ops import stats_options
 from tensorflow.python.data.experimental.ops import threading_options
 from tensorflow.python.data.ops import iterator_ops
@@ -821,78 +822,6 @@ class DatasetV2(object):
     """
     return SkipDataset(self, count)
 
-  def shard(self, num_shards, index):
-    """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
-
-    This dataset operator is very useful when running distributed training, as
-    it allows each worker to read a unique subset.
-
-    When reading a single input file, you can skip elements as follows:
-
-    ```python
-    d = tf.data.TFRecordDataset(FLAGS.input_file)
-    d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
-    d = d.repeat(FLAGS.num_epochs)
-    d = d.shuffle(FLAGS.shuffle_buffer_size)
-    d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
-    ```
-
-    Important caveats:
-
-    - Be sure to shard before you use any randomizing operator (such as
-      shuffle).
-    - Generally it is best if the shard operator is used early in the dataset
-      pipeline. For example, when reading from a set of TFRecord files, shard
-      before converting the dataset to input samples. This avoids reading every
-      file on every worker. The following is an example of an efficient
-      sharding strategy within a complete pipeline:
-
-    ```python
-    d = Dataset.list_files(FLAGS.pattern)
-    d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
-    d = d.repeat(FLAGS.num_epochs)
-    d = d.shuffle(FLAGS.shuffle_buffer_size)
-    d = d.interleave(tf.data.TFRecordDataset,
-                     cycle_length=FLAGS.num_readers, block_length=1)
-    d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
-    ```
-
-    Args:
-      num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
-        shards operating in parallel.
-      index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
-
-    Returns:
-      Dataset: A `Dataset`.
-
-    Raises:
-      ValueError: if `num_shards` or `index` are illegal values. Note: error
-        checking is done on a best-effort basis, and errors aren't guaranteed
-        to be caught upon dataset creation. (e.g. providing in a placeholder
-        tensor bypasses the early checking, and will instead result in an error
-        during a session.run call.)
-    """
-    num_shards = ops.convert_to_tensor(
-        num_shards, name="num_shards", dtype=dtypes.int64)
-    num_shards_static = tensor_util.constant_value(num_shards)
-    index = ops.convert_to_tensor(index, name="index", dtype=dtypes.int64)
-    index_static = tensor_util.constant_value(index)
-
-    if num_shards_static is not None and num_shards_static < 1:
-      raise ValueError("num_shards must be >= 1; got: %s" % num_shards_static)
-    if index_static is not None and index_static < 0:
-      raise ValueError("index must be >= 0; got: %s" % index_static)
-    if (index_static is not None and num_shards_static is not None and
-        index_static >= num_shards_static):
-      raise ValueError("index must be <= num_shards; %s is not < %s" %
-                       (index_static, num_shards_static))
-
-    def filter_fn(elem_index, _):
-      mod_result = math_ops.mod(elem_index, num_shards)
-      return math_ops.equal(mod_result, index)
-
-    return self._enumerate().filter(filter_fn).map(lambda _, elem: elem)
-
   def batch(self, batch_size, drop_remainder=False):
     """Combines consecutive elements of this dataset into batches.
 
@@ -1486,9 +1415,60 @@ class DatasetV1(DatasetV2):
   def skip(self, count):
     return DatasetV1Adapter(super(DatasetV1, self).skip(count))
 
-  @functools.wraps(DatasetV2.shard)
+  @deprecation.deprecated(
+      None, "Use `dataset.apply(tf.data.experimental.filter_for_shard(...))`.")
   def shard(self, num_shards, index):
-    return DatasetV1Adapter(super(DatasetV1, self).shard(num_shards, index))
+    """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
+
+    This dataset operator is very useful when running distributed training, as
+    it allows each worker to read a unique subset.
+
+    When reading a single input file, you can skip elements as follows:
+
+    ```python
+    d = tf.data.TFRecordDataset(FLAGS.input_file)
+    d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
+    d = d.repeat(FLAGS.num_epochs)
+    d = d.shuffle(FLAGS.shuffle_buffer_size)
+    d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
+    ```
+
+    Important caveats:
+
+    - Be sure to shard before you use any randomizing operator (such as
+      shuffle).
+    - Generally it is best if the shard operator is used early in the dataset
+      pipeline. For example, when reading from a set of TFRecord files, shard
+      before converting the dataset to input samples. This avoids reading every
+      file on every worker. The following is an example of an efficient
+      sharding strategy within a complete pipeline:
+
+    ```python
+    d = Dataset.list_files(FLAGS.pattern)
+    d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
+    d = d.repeat(FLAGS.num_epochs)
+    d = d.shuffle(FLAGS.shuffle_buffer_size)
+    d = d.interleave(tf.data.TFRecordDataset,
+                     cycle_length=FLAGS.num_readers, block_length=1)
+    d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
+    ```
+
+    Args:
+      num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
+        shards operating in parallel.
+      index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
+
+    Returns:
+      Dataset: A `Dataset`.
+
+    Raises:
+      ValueError: if `num_shards` or `index` are illegal values. Note: error
+        checking is done on a best-effort basis, and errors aren't guaranteed
+        to be caught upon dataset creation. (e.g. passing a placeholder tensor
+        bypasses the early checking and will instead result in an error during
+        a `session.run` call.)
+    """
+    return self.apply(filter_for_shard_ops.filter_for_shard(num_shards, index))
 
   @functools.wraps(DatasetV2.batch)
   def batch(self, batch_size, drop_remainder=False):
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 5afbcec3a9e..ec438d00f0f 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -245,6 +245,7 @@ py_library(
     srcs = ["input_ops.py"],
     deps = [
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python/data/experimental/ops:filter_for_shard_ops",
         "//tensorflow/python/data/util:nest",
     ],
 )
diff --git a/tensorflow/python/distribute/input_ops.py b/tensorflow/python/distribute/input_ops.py
index c40b2bf27a8..d7974942a14 100644
--- a/tensorflow/python/distribute/input_ops.py
+++ b/tensorflow/python/distribute/input_ops.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.data.experimental.ops import filter_for_shard_ops
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import readers
 from tensorflow.python.data.util import nest
@@ -41,7 +42,8 @@ def auto_shard_dataset(dataset, num_shards, index):
     dataset: A `tf.data.Dataset` instance, typically the result of a bunch of
       dataset transformations.
     num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
-        shards operating in parallel. Same usage as in `Dataset.shard`.
+        shards operating in parallel. Same usage as in
+        `tf.data.experimental.filter_for_shard`.
-    index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
-      Same usage as in `Dataset.shard`.
+    index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
+      Same usage as in `tf.data.experimental.filter_for_shard`.
 
@@ -74,9 +76,11 @@ def auto_shard_dataset(dataset, num_shards, index):
         # constructor. Eventually we will change all cases to clone datasets
         # instead of updating in-place.
         return dataset._clone(
-            filenames=dataset._filenames.shard(num_shards, index))
+            filenames=dataset._filenames.apply(
+                filter_for_shard_ops.filter_for_shard(num_shards, index)))
       elif isinstance(dataset, dataset_ops.RangeDataset):
-        return dataset.shard(num_shards, index)
+        return dataset.apply(
+            filter_for_shard_ops.filter_for_shard(num_shards, index))
       elif hasattr(dataset, "_map_func"):
         # TODO(priyag): Make this check more robust by enforcing some common
         # property on all map/flatmap/interleave datasets.
@@ -142,6 +146,7 @@ def auto_shard_dataset(dataset, num_shards, index):
     # TODO(priyag): This will shard the filenames before any shuffling of the
     # filename dataset. It might be desirable to shard after shuffling
     # filenames? If so, how do we achieve that?
-    return dataset.shard(num_shards, index)
+    return dataset.apply(
+        filter_for_shard_ops.filter_for_shard(num_shards, index))
 
   return _auto_shard_impl(dataset=dataset, found_reader_op=False)
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt
index 7bc3faaedc5..a3cb799fc3d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt
@@ -68,6 +68,10 @@ tf_module {
     name: "enumerate_dataset"
     argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], "
   }
+  member_method {
+    name: "filter_for_shard"
+    argspec: "args=[\'num_shards\', \'shard_index\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "get_next_as_optional"
     argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
index 9394f4b767b..39a6e1ee717 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
@@ -97,10 +97,6 @@ tf_class {
     name: "repeat"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
-  member_method {
-    name: "shard"
-    argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
-  }
   member_method {
     name: "shuffle"
     argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 8c32c773b7c..ef367238d04 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -100,10 +100,6 @@ tf_class {
     name: "repeat"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
-  member_method {
-    name: "shard"
-    argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
-  }
   member_method {
     name: "shuffle"
     argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
index 9f32bce1093..a8fc6fbec1d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -99,10 +99,6 @@ tf_class {
     name: "repeat"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
-  member_method {
-    name: "shard"
-    argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
-  }
   member_method {
     name: "shuffle"
     argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
index 0eedfdbfe13..697f3713444 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
@@ -100,10 +100,6 @@ tf_class {
     name: "repeat"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
-  member_method {
-    name: "shard"
-    argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
-  }
   member_method {
     name: "shuffle"
     argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt
index 08214ec3cf7..17ac098910d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt
@@ -100,10 +100,6 @@ tf_class {
     name: "repeat"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
-  member_method {
-    name: "shard"
-    argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
-  }
   member_method {
     name: "shuffle"
     argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt
index 608253298ee..f005d36e1a0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt
@@ -100,10 +100,6 @@ tf_class {
     name: "repeat"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
-  member_method {
-    name: "shard"
-    argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
-  }
   member_method {
     name: "shuffle"
     argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt
index 3335eb1dc79..b0c0b73ad68 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt
@@ -100,10 +100,6 @@ tf_class {
     name: "repeat"
     argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
-  member_method {
-    name: "shard"
-    argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
-  }
   member_method {
     name: "shuffle"
     argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt
index 7bc3faaedc5..a3cb799fc3d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt
@@ -68,6 +68,10 @@ tf_module {
     name: "enumerate_dataset"
     argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], "
   }
+  member_method {
+    name: "filter_for_shard"
+    argspec: "args=[\'num_shards\', \'shard_index\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "get_next_as_optional"
     argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"