diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 744c0b75ae5..eafee44d2ef 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -203,6 +203,7 @@ FRAMEWORK_PROTO_SRCS = [
     "//tensorflow/core/framework:model.proto",
     "//tensorflow/core/framework:node_def.proto",
     "//tensorflow/core/framework:op_def.proto",
+    "//tensorflow/core/framework:dataset_options.proto",
     "//tensorflow/core/framework:reader_base.proto",
     "//tensorflow/core/framework:remote_fused_graph_execute_info.proto",
     "//tensorflow/core/framework:resource_handle.proto",
diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD
index 12d637ad30a..30bf857163b 100644
--- a/tensorflow/core/framework/BUILD
+++ b/tensorflow/core/framework/BUILD
@@ -119,6 +119,7 @@ exports_files(
         "api_def.proto",
         "attr_value.proto",
         "cost_graph.proto",
+        "dataset_options.proto",
         "device_attributes.proto",
         "function.proto",
         "graph.proto",
@@ -1660,6 +1661,13 @@ tf_proto_library(
     make_default_target_header_only = True,
 )
 
+tf_proto_library(
+    name = "dataset_options_proto",
+    srcs = ["dataset_options.proto"],
+    cc_api_version = 2,
+    make_default_target_header_only = True,
+)
+
 tf_proto_library(
     name = "protos_all",
     cc_api_version = 2,
@@ -1678,6 +1686,7 @@ tf_proto_library(
         ":model_proto",
         ":node_def_proto",
         ":op_def_proto",
+        ":dataset_options_proto",
         ":reader_base_proto",
         ":remote_fused_graph_execute_info_proto",
         ":resource_handle_proto",
diff --git a/tensorflow/core/framework/dataset_options.proto b/tensorflow/core/framework/dataset_options.proto
new file mode 100644
index 00000000000..05e15e15625
--- /dev/null
+++ b/tensorflow/core/framework/dataset_options.proto
@@ -0,0 +1,179 @@
+syntax = "proto3";
+
+package tensorflow.data;
+
+// Represents the type of auto-sharding we enable.
+enum AutoShardPolicy {
+  AUTO = 0;
+  FILE = 1;
+  DATA = 2;
+  OFF = -1;
+}
+
+message DistributeOptions {
+  // The type of sharding that auto-shard should attempt. If this is set to
+  // FILE, then we will attempt to shard by files (each worker will get a set of
+  // files to process). If we cannot find a set of files to shard for at least
+  // one file per worker, we will error out. When this option is selected, make
+  // sure that you have enough files so that each worker gets at least one file.
+  // There will be a runtime error thrown if there are insufficient files. If
+  // this is set to DATA, then we will shard by elements produced by the
+  // dataset, and each worker will process the whole dataset and discard the
+  // portion that is not for itself. If this is set to OFF, then we will not
+  // autoshard, and each worker will receive a copy of the full dataset. This
+  // option is set to AUTO by default, AUTO will attempt to first shard by FILE,
+  // and fall back to sharding by DATA if we cannot find a set of files to
+  // shard.
+  AutoShardPolicy auto_shard_policy = 1;
+  // The number of devices attached to this input pipeline.
+  oneof optional_num_devices {
+    int32 num_devices = 2;
+  }
+}
+
+message MapVectorization {
+  // Whether to vectorize map transformations.
+  oneof optional_enabled {
+    bool enabled = 1;
+  }
+  // Whether to use ChooseFastestBranchDataset with this transformation. If
+  // True, the pipeline picks between the vectorized and original segment at
+  // runtime based on their iterations speed.
+  oneof optional_use_choose_fastest {
+    bool use_choose_fastest = 2;
+  }
+}
+
+message OptimizationOptions {
+  // Whether to apply default graph optimizations. If False, only graph
+  // optimizations that have been explicitly enabled will be applied.
+  oneof optional_apply_default_optimizations {
+    bool apply_default_optimizations = 1;
+  }
+  // Whether to automatically tune performance knobs.
+  oneof optional_autotune {
+    bool autotune = 2;
+  }
+  // When autotuning is enabled (through autotune), determines whether to also
+  // autotune buffer sizes for datasets with parallelism.
+  oneof optional_autotune_buffers {
+    bool autotune_buffers = 3;
+  }
+  // When autotuning is enabled (through autotune), determines the CPU budget to
+  // use. Values greater than the number of schedulable CPU cores are allowed
+  // but may result in CPU contention.
+  oneof optional_autotune_cpu_budget {
+    int32 autotune_cpu_budget = 4;
+  }
+  // When autotuning is enabled (through autotune), determines the RAM budget to
+  // use. Values greater than the available RAM in bytes may result in OOM. If
+  // 0, defaults to half of the available RAM in bytes.
+  oneof optional_autotune_ram_budget {
+    int32 autotune_ram_budget = 5;
+  }
+  // Whether to fuse filter transformations.
+  oneof optional_filter_fusion {
+    bool filter_fusion = 6;
+  }
+  // Whether to fuse filter dataset that predicts random_uniform < rate into a
+  // sampling dataset.
+  oneof optional_filter_with_random_uniform_fusion {
+    bool filter_with_random_uniform_fusion = 7;
+  }
+  // Whether to hoist tf.random_uniform() ops out of map transformations.
+  oneof optional_hoist_random_uniform {
+    bool hoist_random_uniform = 8;
+  }
+  // Whether to fuse map and batch transformations.
+  oneof optional_map_and_batch_fusion {
+    bool map_and_batch_fusion = 9;
+  }
+  // Whether to fuse map and filter transformations.
+  oneof optional_map_and_filter_fusion {
+    bool map_and_filter_fusion = 10;
+  }
+  // Whether to fuse map transformations.
+  oneof optional_map_fusion {
+    bool map_fusion = 11;
+  }
+  // Whether to parallelize stateless map transformations.
+  oneof optional_map_parallelization {
+    bool map_parallelization = 12;
+  }
+  // The map vectorization options associated with the dataset.
+  MapVectorization map_vectorization = 13;
+  // Whether to eliminate no-op transformations.
+  oneof optional_noop_elimination {
+    bool noop_elimination = 14;
+  }
+  // Whether to parallelize copying of batch elements. This optimization is
+  // highly experimental and can cause performance degradation (e.g. when the
+  // parallelization overhead exceeds the benefits of performing the data copies
+  // in parallel). You should only enable this optimization if a) your input
+  // pipeline is bottlenecked on batching and b) you have validated that this
+  // optimization improves performance.
+  oneof optional_parallel_batch {
+    bool parallel_batch = 15;
+  }
+  // Whether to reorder ops that will discard data to the front of unary
+  // cardinality preserving transformations, e.g. dataset.map(...).take(3) will
+  // be optimized to dataset.take(3).map(...). For now this optimization will
+  // move `skip`, `shard` and `take` to the front of `map` and `prefetch`. This
+  // optimization is only for performance; it will not affect the output of the
+  // dataset.
+  oneof optional_reorder_data_discarding_ops {
+    bool reorder_data_discarding_ops = 16;
+  }
+  // Whether to fuse shuffle and repeat transformations.
+  oneof optional_shuffle_and_repeat_fusion {
+    bool shuffle_and_repeat_fusion = 17;
+  }
+}
+
+message ThreadingOptions {
+  // If set, it overrides the maximum degree of intra-op parallelism.
+  oneof optional_max_intra_op_parallelism {
+    int32 max_intra_op_parallelism = 1;
+  }
+  // If set, the dataset will use a private threadpool of the given size.
+  oneof optional_private_threadpool_size {
+    int32 private_threadpool_size = 2;
+  }
+}
+
+// Represents how to handle external state during serialization.
+enum ExternalStatePolicy {
+  WARN = 0;
+  IGNORE = 1;
+  FAIL = 2;
+}
+
+// Message stored with Dataset objects to control how datasets are processed and
+// optimized.
+message Options {
+  // Whether the outputs need to be produced in deterministic order.
+  oneof optional_deterministic {
+    bool deterministic = 1;
+  }
+  // The distribution strategy options associated with the dataset.
+  DistributeOptions distribute_options = 2;
+  // The optimization options associated with the dataset.
+  OptimizationOptions optimization_options = 3;
+  // Whether to introduce 'slack' in the last `prefetch` of the input pipeline,
+  // if it exists. This may reduce CPU contention with accelerator host-side
+  // activity at the start of a step. The slack frequency is determined by the
+  // number of devices attached to this input pipeline.
+  oneof optional_slack {
+    bool slack = 4;
+  }
+  // The threading options associated with the dataset.
+  ThreadingOptions threading_options = 5;
+  // This option can be used to override the default policy for how to handle
+  // external state when serializing a dataset or checkpointing its iterator.
+  // There are three settings available - IGNORE: External state is ignored
+  // without a warning; WARN: External state is ignored and a warning is logged;
+  // FAIL: External state results in an error.
+  oneof optional_external_state_policy {
+    ExternalStatePolicy external_state_policy = 6;
+  }
+}
diff --git a/tensorflow/python/data/experimental/ops/distribute_options.py b/tensorflow/python/data/experimental/ops/distribute_options.py
index 82c498ff993..9a18528513d 100644
--- a/tensorflow/python/data/experimental/ops/distribute_options.py
+++ b/tensorflow/python/data/experimental/ops/distribute_options.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import enum
 
+from tensorflow.core.framework import dataset_options_pb2
 from tensorflow.python.data.util import options
 from tensorflow.python.util.tf_export import tf_export
 
@@ -35,6 +36,34 @@ class AutoShardPolicy(enum.IntEnum):
   FILE = 1
   DATA = 2
 
+  @classmethod
+  def _to_proto(cls, obj):
+    """Convert enum to proto."""
+    if obj == cls.OFF:
+      return dataset_options_pb2.AutoShardPolicy.OFF
+    if obj == cls.FILE:
+      return dataset_options_pb2.AutoShardPolicy.FILE
+    if obj == cls.DATA:
+      return dataset_options_pb2.AutoShardPolicy.DATA
+    if obj == cls.AUTO:
+      return dataset_options_pb2.AutoShardPolicy.AUTO
+    raise ValueError("%s._to_proto() is called with undefined enum %s." %
+                     (cls.__name__, obj.name))
+
+  @classmethod
+  def _from_proto(cls, pb):
+    """Convert proto to enum."""
+    if pb == dataset_options_pb2.AutoShardPolicy.OFF:
+      return cls.OFF
+    if pb == dataset_options_pb2.AutoShardPolicy.FILE:
+      return cls.FILE
+    if pb == dataset_options_pb2.AutoShardPolicy.DATA:
+      return cls.DATA
+    if pb == dataset_options_pb2.AutoShardPolicy.AUTO:
+      return cls.AUTO
+    raise ValueError("%s._from_proto() is called with undefined enum %s." %
+                     (cls.__name__, pb))
+
 
 @tf_export("data.experimental.ExternalStatePolicy")
 class ExternalStatePolicy(enum.Enum):
@@ -47,6 +76,30 @@ class ExternalStatePolicy(enum.Enum):
   IGNORE = 1
   FAIL = 2
 
+  @classmethod
+  def _to_proto(cls, obj):
+    """Convert enum to proto."""
+    if obj == cls.IGNORE:
+      return dataset_options_pb2.ExternalStatePolicy.IGNORE
+    if obj == cls.FAIL:
+      return dataset_options_pb2.ExternalStatePolicy.FAIL
+    if obj == cls.WARN:
+      return dataset_options_pb2.ExternalStatePolicy.WARN
+    raise ValueError("%s._to_proto() is called with undefined enum %s." %
+                     (cls.__name__, obj.name))
+
+  @classmethod
+  def _from_proto(cls, pb):
+    """Convert proto to enum."""
+    if pb == dataset_options_pb2.ExternalStatePolicy.IGNORE:
+      return cls.IGNORE
+    if pb == dataset_options_pb2.ExternalStatePolicy.FAIL:
+      return cls.FAIL
+    if pb == dataset_options_pb2.ExternalStatePolicy.WARN:
+      return cls.WARN
+    raise ValueError("%s._from_proto() is called with undefined enum %s." %
+                     (cls.__name__, pb))
+
 
 @tf_export("data.experimental.DistributeOptions")
 class DistributeOptions(options.OptionsBase):
@@ -89,3 +142,15 @@ class DistributeOptions(options.OptionsBase):
       docstring=
       "The number of devices attached to this input pipeline. This will be "
       "automatically set by MultiDeviceIterator.")
+
+  def _to_proto(self):
+    pb = dataset_options_pb2.DistributeOptions()
+    pb.auto_shard_policy = AutoShardPolicy._to_proto(self.auto_shard_policy)  # pylint: disable=protected-access
+    if self.num_devices is not None:
+      pb.num_devices = self.num_devices
+    return pb
+
+  def _from_proto(self, pb):
+    self.auto_shard_policy = AutoShardPolicy._from_proto(pb.auto_shard_policy)  # pylint: disable=protected-access
+    if pb.WhichOneof("optional_num_devices") is not None:
+      self.num_devices = pb.num_devices
diff --git a/tensorflow/python/data/experimental/ops/optimization_options.py b/tensorflow/python/data/experimental/ops/optimization_options.py
index 5c69855e15f..992ea647955 100644
--- a/tensorflow/python/data/experimental/ops/optimization_options.py
+++ b/tensorflow/python/data/experimental/ops/optimization_options.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import enum
 
+from tensorflow.core.framework import dataset_options_pb2
 from tensorflow.python.data.util import options
 from tensorflow.python.util.tf_export import tf_export
 
@@ -69,6 +70,20 @@ class MapVectorizationOptions(options.OptionsBase):
     else:
       return ["map_vectorization:use_choose_fastest:false"]
 
+  def _to_proto(self):
+    pb = dataset_options_pb2.MapVectorization()
+    if self.enabled is not None:
+      pb.enabled = self.enabled
+    if self.use_choose_fastest is not None:
+      pb.use_choose_fastest = self.use_choose_fastest
+    return pb
+
+  def _from_proto(self, pb):
+    if pb.WhichOneof("optional_enabled") is not None:
+      self.enabled = pb.enabled
+    if pb.WhichOneof("optional_use_choose_fastest") is not None:
+      self.use_choose_fastest = pb.use_choose_fastest
+
 
 @tf_export("data.experimental.OptimizationOptions")
 class OptimizationOptions(options.OptionsBase):
@@ -327,3 +342,77 @@ class OptimizationOptions(options.OptionsBase):
         graph_rewrite_configs.append(optimization + ":autotune:true")
 
     return graph_rewrite_configs
+
+  def _to_proto(self):
+    pb = dataset_options_pb2.OptimizationOptions()
+    if self.apply_default_optimizations is not None:
+      pb.apply_default_optimizations = self.apply_default_optimizations
+    if self.autotune is not None:
+      pb.autotune = self.autotune
+    if self.autotune_buffers is not None:
+      pb.autotune_buffers = self.autotune_buffers
+    if self.autotune_cpu_budget is not None:
+      pb.autotune_cpu_budget = self.autotune_cpu_budget
+    if self.autotune_ram_budget is not None:
+      pb.autotune_ram_budget = self.autotune_ram_budget
+    if self.filter_fusion is not None:
+      pb.filter_fusion = self.filter_fusion
+    if self.filter_with_random_uniform_fusion is not None:
+      pb.filter_with_random_uniform_fusion = (
+          self.filter_with_random_uniform_fusion)
+    if self.hoist_random_uniform is not None:
+      pb.hoist_random_uniform = self.hoist_random_uniform
+    if self.map_and_batch_fusion is not None:
+      pb.map_and_batch_fusion = self.map_and_batch_fusion
+    if self.map_and_filter_fusion is not None:
+      pb.map_and_filter_fusion = self.map_and_filter_fusion
+    if self.map_fusion is not None:
+      pb.map_fusion = self.map_fusion
+    if self.map_parallelization is not None:
+      pb.map_parallelization = self.map_parallelization
+    pb.map_vectorization.CopyFrom(self.map_vectorization._to_proto())  # pylint: disable=protected-access
+    if self.noop_elimination is not None:
+      pb.noop_elimination = self.noop_elimination
+    if self.parallel_batch is not None:
+      pb.parallel_batch = self.parallel_batch
+    if self.reorder_data_discarding_ops is not None:
+      pb.reorder_data_discarding_ops = self.reorder_data_discarding_ops
+    if self.shuffle_and_repeat_fusion is not None:
+      pb.shuffle_and_repeat_fusion = self.shuffle_and_repeat_fusion
+    return pb
+
+  def _from_proto(self, pb):
+    if pb.WhichOneof("optional_apply_default_optimizations") is not None:
+      self.apply_default_optimizations = pb.apply_default_optimizations
+    if pb.WhichOneof("optional_autotune") is not None:
+      self.autotune = pb.autotune
+    if pb.WhichOneof("optional_autotune_buffers") is not None:
+      self.autotune_buffers = pb.autotune_buffers
+    if pb.WhichOneof("optional_autotune_cpu_budget") is not None:
+      self.autotune_cpu_budget = pb.autotune_cpu_budget
+    if pb.WhichOneof("optional_autotune_ram_budget") is not None:
+      self.autotune_ram_budget = pb.autotune_ram_budget
+    if pb.WhichOneof("optional_filter_fusion") is not None:
+      self.filter_fusion = pb.filter_fusion
+    if pb.WhichOneof("optional_filter_with_random_uniform_fusion") is not None:
+      self.filter_with_random_uniform_fusion = (
+          pb.filter_with_random_uniform_fusion)
+    if pb.WhichOneof("optional_hoist_random_uniform") is not None:
+      self.hoist_random_uniform = pb.hoist_random_uniform
+    if pb.WhichOneof("optional_map_and_batch_fusion") is not None:
+      self.map_and_batch_fusion = pb.map_and_batch_fusion
+    if pb.WhichOneof("optional_map_and_filter_fusion") is not None:
+      self.map_and_filter_fusion = pb.map_and_filter_fusion
+    if pb.WhichOneof("optional_map_fusion") is not None:
+      self.map_fusion = pb.map_fusion
+    if pb.WhichOneof("optional_map_parallelization") is not None:
+      self.map_parallelization = pb.map_parallelization
+    self.map_vectorization._from_proto(pb.map_vectorization)  # pylint: disable=protected-access
+    if pb.WhichOneof("optional_noop_elimination") is not None:
+      self.noop_elimination = pb.noop_elimination
+    if pb.WhichOneof("optional_parallel_batch") is not None:
+      self.parallel_batch = pb.parallel_batch
+    if pb.WhichOneof("optional_reorder_data_discarding_ops") is not None:
+      self.reorder_data_discarding_ops = pb.reorder_data_discarding_ops
+    if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None:
+      self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion
diff --git a/tensorflow/python/data/experimental/ops/threading_options.py b/tensorflow/python/data/experimental/ops/threading_options.py
index d713b9ae075..39da39353d6 100644
--- a/tensorflow/python/data/experimental/ops/threading_options.py
+++ b/tensorflow/python/data/experimental/ops/threading_options.py
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 
+from tensorflow.core.framework import dataset_options_pb2
 from tensorflow.python.data.util import options
 from tensorflow.python.util.tf_export import tf_export
 
@@ -48,3 +49,17 @@ class ThreadingOptions(options.OptionsBase):
       ty=int,
       docstring=
       "If set, the dataset will use a private threadpool of the given size.")
+
+  def _to_proto(self):
+    pb = dataset_options_pb2.ThreadingOptions()
+    if self.max_intra_op_parallelism is not None:
+      pb.max_intra_op_parallelism = self.max_intra_op_parallelism
+    if self.private_threadpool_size is not None:
+      pb.private_threadpool_size = self.private_threadpool_size
+    return pb
+
+  def _from_proto(self, pb):
+    if pb.WhichOneof("optional_max_intra_op_parallelism") is not None:
+      self.max_intra_op_parallelism = pb.max_intra_op_parallelism
+    if pb.WhichOneof("optional_private_threadpool_size") is not None:
+      self.private_threadpool_size = pb.private_threadpool_size
diff --git a/tensorflow/python/data/kernel_tests/options_test.py b/tensorflow/python/data/kernel_tests/options_test.py
index 31220c69d9e..efd3f598a1f 100644
--- a/tensorflow/python/data/kernel_tests/options_test.py
+++ b/tensorflow/python/data/kernel_tests/options_test.py
@@ -23,6 +23,8 @@ import sys
 
 from absl.testing import parameterized
 
+from tensorflow.core.framework import dataset_options_pb2
+from tensorflow.python.data.experimental.ops import distribute_options
 from tensorflow.python.data.experimental.ops import optimization_options
 from tensorflow.python.data.experimental.ops import stats_options
 from tensorflow.python.data.experimental.ops import threading_options
@@ -127,6 +129,67 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
       result = result.concatenate(ds)
     self.assertDatasetProduces(result, [0]*1000)
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testOptionsProtoRoundTrip(self):
+    options = dataset_ops.Options()
+    options.experimental_deterministic = True
+    options.experimental_external_state_policy = (
+        distribute_options.ExternalStatePolicy.FAIL)
+    options.experimental_distribute.auto_shard_policy = (
+        distribute_options.AutoShardPolicy.DATA)
+    options.experimental_distribute.num_devices = 1000
+    options.experimental_optimization.apply_default_optimizations = True
+    options.experimental_optimization.autotune = True
+    options.experimental_optimization.autotune_buffers = True
+    options.experimental_optimization.autotune_cpu_budget = 10
+    options.experimental_optimization.autotune_ram_budget = 20
+    options.experimental_optimization.filter_fusion = True
+    options.experimental_optimization.filter_with_random_uniform_fusion = True
+    options.experimental_optimization.hoist_random_uniform = True
+    options.experimental_optimization.map_and_batch_fusion = True
+    options.experimental_optimization.map_and_filter_fusion = True
+    options.experimental_optimization.map_fusion = True
+    options.experimental_optimization.map_parallelization = True
+    options.experimental_optimization.map_vectorization.enabled = True
+    options.experimental_optimization.map_vectorization.use_choose_fastest = (
+        True)
+    options.experimental_optimization.noop_elimination = True
+    options.experimental_optimization.parallel_batch = True
+    options.experimental_optimization.reorder_data_discarding_ops = True
+    options.experimental_optimization.shuffle_and_repeat_fusion = True
+    options.experimental_slack = True
+    options.experimental_threading.max_intra_op_parallelism = 30
+    options.experimental_threading.private_threadpool_size = 40
+    pb = options._to_proto()
+    result = dataset_ops.Options()
+    result._from_proto(pb)
+    self.assertEqual(options, result)
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testOptionsProtoDefaultValuesRoundTrip(self):
+    options = dataset_ops.Options()
+    pb = options._to_proto()
+    result = dataset_ops.Options()
+    result._from_proto(pb)
+    self.assertEqual(options, result)
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testProtoOptionsDefaultValuesRoundTrip(self):
+    pb = dataset_options_pb2.Options()
+    options = dataset_ops.Options()
+    options._from_proto(pb)
+    result = options._to_proto()
+    expected_pb = dataset_options_pb2.Options()
+    expected_pb.distribute_options.CopyFrom(
+        dataset_options_pb2.DistributeOptions())
+    expected_pb.optimization_options.CopyFrom(
+        dataset_options_pb2.OptimizationOptions())
+    expected_pb.optimization_options.map_vectorization.CopyFrom(
+        dataset_options_pb2.MapVectorization())
+    expected_pb.threading_options.CopyFrom(
+        dataset_options_pb2.ThreadingOptions())
+    self.assertProtoEquals(expected_pb, result)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 9b5aa9f6dda..6497cb2143b 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -28,6 +28,7 @@ import numpy as np
 import six
 from six.moves import queue as Queue  # pylint: disable=redefined-builtin
 
+from tensorflow.core.framework import dataset_options_pb2
 from tensorflow.core.framework import graph_pb2
 from tensorflow.python import tf2
 from tensorflow.python.data.experimental.ops import distribute_options
@@ -3039,6 +3040,34 @@ class Options(options_lib.OptionsBase):
       "state is ignored and a warning is logged; FAIL: External state results "
       "in an error.")
 
+  def _to_proto(self):
+    pb = dataset_options_pb2.Options()
+    if self.experimental_deterministic is not None:
+      pb.deterministic = self.experimental_deterministic
+    pb.distribute_options.CopyFrom(self.experimental_distribute._to_proto())  # pylint: disable=protected-access
+    if self.experimental_external_state_policy is not None:
+      pb.external_state_policy = (
+          distribute_options.ExternalStatePolicy._to_proto(  # pylint: disable=protected-access
+              self.experimental_external_state_policy))
+    pb.optimization_options.CopyFrom(self.experimental_optimization._to_proto())  # pylint: disable=protected-access
+    if self.experimental_slack is not None:
+      pb.slack = self.experimental_slack
+    pb.threading_options.CopyFrom(self.experimental_threading._to_proto())  # pylint: disable=protected-access
+    return pb
+
+  def _from_proto(self, pb):
+    if pb.WhichOneof("optional_deterministic") is not None:
+      self.experimental_deterministic = pb.deterministic
+    self.experimental_distribute._from_proto(pb.distribute_options)  # pylint: disable=protected-access
+    if pb.WhichOneof("optional_external_state_policy") is not None:
+      self.experimental_external_state_policy = (
+          distribute_options.ExternalStatePolicy._from_proto(  # pylint: disable=protected-access
+              pb.external_state_policy))
+    self.experimental_optimization._from_proto(pb.optimization_options)  # pylint: disable=protected-access
+    if pb.WhichOneof("optional_slack") is not None:
+      self.experimental_slack = pb.slack
+    self.experimental_threading._from_proto(pb.threading_options)  # pylint: disable=protected-access
+
   def _graph_rewrites(self):
     """Produces lists of enabled, disabled, default static graph rewrites.
 
diff --git a/tensorflow/python/data/util/options.py b/tensorflow/python/data/util/options.py
index 8af773ed68b..3df6f000bb6 100644
--- a/tensorflow/python/data/util/options.py
+++ b/tensorflow/python/data/util/options.py
@@ -59,6 +59,14 @@ class OptionsBase(object):
       raise AttributeError(
           "Cannot set the property %s on %s." % (name, type(self).__name__))
 
+  def _to_proto(self):
+    """Convert options to protocol buffer."""
+    raise NotImplementedError("%s._to_proto()" % type(self).__name__)
+
+  def _from_proto(self, pb):
+    """Convert protocol buffer to options."""
+    raise NotImplementedError("%s._from_proto()" % type(self).__name__)
+
 
 # Creates a namedtuple with three keys for optimization graph rewrites settings.
 def graph_rewrites():