From eb001c716506815ebc3bcd5b1ac6eb2cadb6d244 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 5 Feb 2021 15:49:32 -0800 Subject: [PATCH] [tf.data] Add a new proto to store tf.data.Options data and conversion functions between tf.data.Options and their proto representation. PiperOrigin-RevId: 355945339 Change-Id: I3eba0c05899aeda629e34483b75fc61c435cb4d8 --- tensorflow/core/BUILD | 1 + tensorflow/core/framework/BUILD | 9 + .../core/framework/dataset_options.proto | 179 ++++++++++++++++++ .../experimental/ops/distribute_options.py | 65 +++++++ .../experimental/ops/optimization_options.py | 89 +++++++++ .../experimental/ops/threading_options.py | 15 ++ .../python/data/kernel_tests/options_test.py | 63 ++++++ tensorflow/python/data/ops/dataset_ops.py | 29 +++ tensorflow/python/data/util/options.py | 8 + 9 files changed, 458 insertions(+) create mode 100644 tensorflow/core/framework/dataset_options.proto diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 744c0b75ae5..eafee44d2ef 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -203,6 +203,7 @@ FRAMEWORK_PROTO_SRCS = [ "//tensorflow/core/framework:model.proto", "//tensorflow/core/framework:node_def.proto", "//tensorflow/core/framework:op_def.proto", + "//tensorflow/core/framework:dataset_options.proto", "//tensorflow/core/framework:reader_base.proto", "//tensorflow/core/framework:remote_fused_graph_execute_info.proto", "//tensorflow/core/framework:resource_handle.proto", diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 12d637ad30a..30bf857163b 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -119,6 +119,7 @@ exports_files( "api_def.proto", "attr_value.proto", "cost_graph.proto", + "dataset_options.proto", "device_attributes.proto", "function.proto", "graph.proto", @@ -1660,6 +1661,13 @@ tf_proto_library( make_default_target_header_only = True, ) +tf_proto_library( + name = "dataset_options_proto", + srcs = ["dataset_options.proto"], + cc_api_version = 2, + make_default_target_header_only = True, +) + tf_proto_library( name = "protos_all", cc_api_version = 2, @@ -1678,6 +1686,7 @@ tf_proto_library( ":model_proto", ":node_def_proto", ":op_def_proto", + ":dataset_options_proto", ":reader_base_proto", ":remote_fused_graph_execute_info_proto", ":resource_handle_proto", diff --git a/tensorflow/core/framework/dataset_options.proto b/tensorflow/core/framework/dataset_options.proto new file mode 100644 index 00000000000..05e15e15625 --- /dev/null +++ b/tensorflow/core/framework/dataset_options.proto @@ -0,0 +1,179 @@ +syntax = "proto3"; + +package tensorflow.data; + +// Represents the type of auto-sharding we enable. +enum AutoShardPolicy { + AUTO = 0; + FILE = 1; + DATA = 2; + OFF = -1; +} + +message DistributeOptions { + // The type of sharding that auto-shard should attempt. If this is set to + // FILE, then we will attempt to shard by files (each worker will get a set of + // files to process). If we cannot find a set of files to shard for at least + // one file per worker, we will error out. When this option is selected, make + // sure that you have enough files so that each worker gets at least one file. + // There will be a runtime error thrown if there are insufficient files. If + // this is set to DATA, then we will shard by elements produced by the + // dataset, and each worker will process the whole dataset and discard the + // portion that is not for itself. If this is set to OFF, then we will not + // autoshard, and each worker will receive a copy of the full dataset. This + // option is set to AUTO by default, AUTO will attempt to first shard by FILE, + // and fall back to sharding by DATA if we cannot find a set of files to + // shard. + AutoShardPolicy auto_shard_policy = 1; + // The number of devices attached to this input pipeline. + oneof optional_num_devices { + int32 num_devices = 2; + } +} + +message MapVectorization { + // Whether to vectorize map transformations. + oneof optional_enabled { + bool enabled = 1; + } + // Whether to use ChooseFastestBranchDataset with this transformation. If + // True, the pipeline picks between the vectorized and original segment at + // runtime based on their iterations speed. + oneof optional_use_choose_fastest { + bool use_choose_fastest = 2; + } +} + +message OptimizationOptions { + // Whether to apply default graph optimizations. If False, only graph + // optimizations that have been explicitly enabled will be applied. + oneof optional_apply_default_optimizations { + bool apply_default_optimizations = 1; + } + // Whether to automatically tune performance knobs. + oneof optional_autotune { + bool autotune = 2; + } + // When autotuning is enabled (through autotune), determines whether to also + // autotune buffer sizes for datasets with parallelism. + oneof optional_autotune_buffers { + bool autotune_buffers = 3; + } + // When autotuning is enabled (through autotune), determines the CPU budget to + // use. Values greater than the number of schedulable CPU cores are allowed + // but may result in CPU contention. + oneof optional_autotune_cpu_budget { + int32 autotune_cpu_budget = 4; + } + // When autotuning is enabled (through autotune), determines the RAM budget to + // use. Values greater than the available RAM in bytes may result in OOM. If + // 0, defaults to half of the available RAM in bytes. + oneof optional_autotune_ram_budget { + int32 autotune_ram_budget = 5; + } + // Whether to fuse filter transformations. + oneof optional_filter_fusion { + bool filter_fusion = 6; + } + // Whether to fuse filter dataset that predicts random_uniform < rate into a + // sampling dataset. + oneof optional_filter_with_random_uniform_fusion { + bool filter_with_random_uniform_fusion = 7; + } + // Whether to hoist tf.random_uniform() ops out of map transformations. + oneof optional_hoist_random_uniform { + bool hoist_random_uniform = 8; + } + // Whether to fuse map and batch transformations. + oneof optional_map_and_batch_fusion { + bool map_and_batch_fusion = 9; + } + // Whether to fuse map and filter transformations. + oneof optional_map_and_filter_fusion { + bool map_and_filter_fusion = 10; + } + // Whether to fuse map transformations. + oneof optional_map_fusion { + bool map_fusion = 11; + } + // Whether to parallelize stateless map transformations. + oneof optional_map_parallelization { + bool map_parallelization = 12; + } + // The map vectorization options associated with the dataset. + MapVectorization map_vectorization = 13; + // Whether to eliminate no-op transformations. + oneof optional_noop_elimination { + bool noop_elimination = 14; + } + // Whether to parallelize copying of batch elements. This optimization is + // highly experimental and can cause performance degradation (e.g. when the + // parallelization overhead exceeds the benefits of performing the data copies + // in parallel). You should only enable this optimization if a) your input + // pipeline is bottlenecked on batching and b) you have validated that this + // optimization improves performance. + oneof optional_parallel_batch { + bool parallel_batch = 15; + } + // Whether to reorder ops that will discard data to the front of unary + // cardinality preserving transformations, e.g. dataset.map(...).take(3) will + // be optimized to dataset.take(3).map(...). For now this optimization will + // move `skip`, `shard` and `take` to the front of `map` and `prefetch`. This + // optimization is only for performance; it will not affect the output of the + // dataset. + oneof optional_reorder_data_discarding_ops { + bool reorder_data_discarding_ops = 16; + } + // Whether to fuse shuffle and repeat transformations. + oneof optional_shuffle_and_repeat_fusion { + bool shuffle_and_repeat_fusion = 17; + } +} + +message ThreadingOptions { + // If set, it overrides the maximum degree of intra-op parallelism. + oneof optional_max_intra_op_parallelism { + int32 max_intra_op_parallelism = 1; + } + // If set, the dataset will use a private threadpool of the given size. + oneof optional_private_threadpool_size { + int32 private_threadpool_size = 2; + } +} + +// Represents how to handle external state during serialization. +enum ExternalStatePolicy { + WARN = 0; + IGNORE = 1; + FAIL = 2; +} + +// Message stored with Dataset objects to control how datasets are processed and +// optimized. +message Options { + // Whether the outputs need to be produced in deterministic order. + oneof optional_deterministic { + bool deterministic = 1; + } + // The distribution strategy options associated with the dataset. + DistributeOptions distribute_options = 2; + // The optimization options associated with the dataset. + OptimizationOptions optimization_options = 3; + // Whether to introduce 'slack' in the last `prefetch` of the input pipeline, + // if it exists. This may reduce CPU contention with accelerator host-side + // activity at the start of a step. The slack frequency is determined by the + // number of devices attached to this input pipeline. + oneof optional_slack { + bool slack = 4; + } + // The threading options associated with the dataset. + ThreadingOptions threading_options = 5; + // This option can be used to override the default policy for how to handle + // external state when serializing a dataset or checkpointing its iterator. + // There are three settings available - IGNORE: External state is ignored + // without a warning; WARN: External state is ignored and a warning is logged; + // FAIL: External state results in an error. + oneof optional_external_state_policy { + ExternalStatePolicy external_state_policy = 6; + } +} diff --git a/tensorflow/python/data/experimental/ops/distribute_options.py b/tensorflow/python/data/experimental/ops/distribute_options.py index 82c498ff993..9a18528513d 100644 --- a/tensorflow/python/data/experimental/ops/distribute_options.py +++ b/tensorflow/python/data/experimental/ops/distribute_options.py @@ -19,6 +19,7 @@ from __future__ import print_function import enum +from tensorflow.core.framework import dataset_options_pb2 from tensorflow.python.data.util import options from tensorflow.python.util.tf_export import tf_export @@ -35,6 +36,34 @@ class AutoShardPolicy(enum.IntEnum): FILE = 1 DATA = 2 + @classmethod + def _to_proto(cls, obj): + """Convert enum to proto.""" + if obj == cls.OFF: + return dataset_options_pb2.AutoShardPolicy.OFF + if obj == cls.FILE: + return dataset_options_pb2.AutoShardPolicy.FILE + if obj == cls.DATA: + return dataset_options_pb2.AutoShardPolicy.DATA + if obj == cls.AUTO: + return dataset_options_pb2.AutoShardPolicy.AUTO + raise ValueError("%s._to_proto() is called with undefined enum %s." % + (cls.__name__, obj.name)) + + @classmethod + def _from_proto(cls, pb): + """Convert proto to enum.""" + if pb == dataset_options_pb2.AutoShardPolicy.OFF: + return cls.OFF + if pb == dataset_options_pb2.AutoShardPolicy.FILE: + return cls.FILE + if pb == dataset_options_pb2.AutoShardPolicy.DATA: + return cls.DATA + if pb == dataset_options_pb2.AutoShardPolicy.AUTO: + return cls.AUTO + raise ValueError("%s._from_proto() is called with undefined enum %s." % + (cls.__name__, pb)) + @tf_export("data.experimental.ExternalStatePolicy") class ExternalStatePolicy(enum.Enum): @@ -47,6 +76,30 @@ class ExternalStatePolicy(enum.Enum): IGNORE = 1 FAIL = 2 + @classmethod + def _to_proto(cls, obj): + """Convert enum to proto.""" + if obj == cls.IGNORE: + return dataset_options_pb2.ExternalStatePolicy.IGNORE + if obj == cls.FAIL: + return dataset_options_pb2.ExternalStatePolicy.FAIL + if obj == cls.WARN: + return dataset_options_pb2.ExternalStatePolicy.WARN + raise ValueError("%s._to_proto() is called with undefined enum %s." % + (cls.__name__, obj.name)) + + @classmethod + def _from_proto(cls, pb): + """Convert proto to enum.""" + if pb == dataset_options_pb2.ExternalStatePolicy.IGNORE: + return cls.IGNORE + if pb == dataset_options_pb2.ExternalStatePolicy.FAIL: + return cls.FAIL + if pb == dataset_options_pb2.ExternalStatePolicy.WARN: + return cls.WARN + raise ValueError("%s._from_proto() is called with undefined enum %s." % + (cls.__name__, pb)) + @tf_export("data.experimental.DistributeOptions") class DistributeOptions(options.OptionsBase): @@ -89,3 +142,15 @@ class DistributeOptions(options.OptionsBase): docstring= "The number of devices attached to this input pipeline. This will be " "automatically set by MultiDeviceIterator.") + + def _to_proto(self): + pb = dataset_options_pb2.DistributeOptions() + pb.auto_shard_policy = AutoShardPolicy._to_proto(self.auto_shard_policy) # pylint: disable=protected-access + if self.num_devices is not None: + pb.num_devices = self.num_devices + return pb + + def _from_proto(self, pb): + self.auto_shard_policy = AutoShardPolicy._from_proto(pb.auto_shard_policy) # pylint: disable=protected-access + if pb.WhichOneof("optional_num_devices") is not None: + self.num_devices = pb.num_devices diff --git a/tensorflow/python/data/experimental/ops/optimization_options.py b/tensorflow/python/data/experimental/ops/optimization_options.py index 5c69855e15f..992ea647955 100644 --- a/tensorflow/python/data/experimental/ops/optimization_options.py +++ b/tensorflow/python/data/experimental/ops/optimization_options.py @@ -19,6 +19,7 @@ from __future__ import print_function import enum +from tensorflow.core.framework import dataset_options_pb2 from tensorflow.python.data.util import options from tensorflow.python.util.tf_export import tf_export @@ -69,6 +70,20 @@ class MapVectorizationOptions(options.OptionsBase): else: return ["map_vectorization:use_choose_fastest:false"] + def _to_proto(self): + pb = dataset_options_pb2.MapVectorization() + if self.enabled is not None: + pb.enabled = self.enabled + if self.use_choose_fastest is not None: + pb.use_choose_fastest = self.use_choose_fastest + return pb + + def _from_proto(self, pb): + if pb.WhichOneof("optional_enabled") is not None: + self.enabled = pb.enabled + if pb.WhichOneof("optional_use_choose_fastest") is not None: + self.use_choose_fastest = pb.use_choose_fastest + @tf_export("data.experimental.OptimizationOptions") class OptimizationOptions(options.OptionsBase): @@ -327,3 +342,77 @@ class OptimizationOptions(options.OptionsBase): graph_rewrite_configs.append(optimization + ":autotune:true") return graph_rewrite_configs + + def _to_proto(self): + pb = dataset_options_pb2.OptimizationOptions() + if self.apply_default_optimizations is not None: + pb.apply_default_optimizations = self.apply_default_optimizations + if self.autotune is not None: + pb.autotune = self.autotune + if self.autotune_buffers is not None: + pb.autotune_buffers = self.autotune_buffers + if self.autotune_cpu_budget is not None: + pb.autotune_cpu_budget = self.autotune_cpu_budget + if self.autotune_ram_budget is not None: + pb.autotune_ram_budget = self.autotune_ram_budget + if self.filter_fusion is not None: + pb.filter_fusion = self.filter_fusion + if self.filter_with_random_uniform_fusion is not None: + pb.filter_with_random_uniform_fusion = ( + self.filter_with_random_uniform_fusion) + if self.hoist_random_uniform is not None: + pb.hoist_random_uniform = self.hoist_random_uniform + if self.map_and_batch_fusion is not None: + pb.map_and_batch_fusion = self.map_and_batch_fusion + if self.map_and_filter_fusion is not None: + pb.map_and_filter_fusion = self.map_and_filter_fusion + if self.map_fusion is not None: + pb.map_fusion = self.map_fusion + if self.map_parallelization is not None: + pb.map_parallelization = self.map_parallelization + pb.map_vectorization.CopyFrom(self.map_vectorization._to_proto()) # pylint: disable=protected-access + if self.noop_elimination is not None: + pb.noop_elimination = self.noop_elimination + if self.parallel_batch is not None: + pb.parallel_batch = self.parallel_batch + if self.reorder_data_discarding_ops is not None: + pb.reorder_data_discarding_ops = self.reorder_data_discarding_ops + if self.shuffle_and_repeat_fusion is not None: + pb.shuffle_and_repeat_fusion = self.shuffle_and_repeat_fusion + return pb + + def _from_proto(self, pb): + if pb.WhichOneof("optional_apply_default_optimizations") is not None: + self.apply_default_optimizations = pb.apply_default_optimizations + if pb.WhichOneof("optional_autotune") is not None: + self.autotune = pb.autotune + if pb.WhichOneof("optional_autotune_buffers") is not None: + self.autotune_buffers = pb.autotune_buffers + if pb.WhichOneof("optional_autotune_cpu_budget") is not None: + self.autotune_cpu_budget = pb.autotune_cpu_budget + if pb.WhichOneof("optional_autotune_ram_budget") is not None: + self.autotune_ram_budget = pb.autotune_ram_budget + if pb.WhichOneof("optional_filter_fusion") is not None: + self.filter_fusion = pb.filter_fusion + if pb.WhichOneof("optional_filter_with_random_uniform_fusion") is not None: + self.filter_with_random_uniform_fusion = ( + pb.filter_with_random_uniform_fusion) + if pb.WhichOneof("optional_hoist_random_uniform") is not None: + self.hoist_random_uniform = pb.hoist_random_uniform + if pb.WhichOneof("optional_map_and_batch_fusion") is not None: + self.map_and_batch_fusion = pb.map_and_batch_fusion + if pb.WhichOneof("optional_map_and_filter_fusion") is not None: + self.map_and_filter_fusion = pb.map_and_filter_fusion + if pb.WhichOneof("optional_map_fusion") is not None: + self.map_fusion = pb.map_fusion + if pb.WhichOneof("optional_map_parallelization") is not None: + self.map_parallelization = pb.map_parallelization + self.map_vectorization._from_proto(pb.map_vectorization) # pylint: disable=protected-access + if pb.WhichOneof("optional_noop_elimination") is not None: + self.noop_elimination = pb.noop_elimination + if pb.WhichOneof("optional_parallel_batch") is not None: + self.parallel_batch = pb.parallel_batch + if pb.WhichOneof("optional_reorder_data_discarding_ops") is not None: + self.reorder_data_discarding_ops = pb.reorder_data_discarding_ops + if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None: + self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion diff --git a/tensorflow/python/data/experimental/ops/threading_options.py b/tensorflow/python/data/experimental/ops/threading_options.py index d713b9ae075..39da39353d6 100644 --- a/tensorflow/python/data/experimental/ops/threading_options.py +++ b/tensorflow/python/data/experimental/ops/threading_options.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function +from tensorflow.core.framework import dataset_options_pb2 from tensorflow.python.data.util import options from tensorflow.python.util.tf_export import tf_export @@ -48,3 +49,17 @@ class ThreadingOptions(options.OptionsBase): ty=int, docstring= "If set, the dataset will use a private threadpool of the given size.") + + def _to_proto(self): + pb = dataset_options_pb2.ThreadingOptions() + if self.max_intra_op_parallelism is not None: + pb.max_intra_op_parallelism = self.max_intra_op_parallelism + if self.private_threadpool_size is not None: + pb.private_threadpool_size = self.private_threadpool_size + return pb + + def _from_proto(self, pb): + if pb.WhichOneof("optional_max_intra_op_parallelism") is not None: + self.max_intra_op_parallelism = pb.max_intra_op_parallelism + if pb.WhichOneof("optional_private_threadpool_size") is not None: + self.private_threadpool_size = pb.private_threadpool_size diff --git a/tensorflow/python/data/kernel_tests/options_test.py b/tensorflow/python/data/kernel_tests/options_test.py index 31220c69d9e..efd3f598a1f 100644 --- a/tensorflow/python/data/kernel_tests/options_test.py +++ b/tensorflow/python/data/kernel_tests/options_test.py @@ -23,6 +23,8 @@ import sys from absl.testing import parameterized +from tensorflow.core.framework import dataset_options_pb2 +from tensorflow.python.data.experimental.ops import distribute_options from tensorflow.python.data.experimental.ops import optimization_options from tensorflow.python.data.experimental.ops import stats_options from tensorflow.python.data.experimental.ops import threading_options @@ -127,6 +129,67 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase): result = result.concatenate(ds) self.assertDatasetProduces(result, [0]*1000) + @combinations.generate(test_base.default_test_combinations()) + def testOptionsProtoRoundTrip(self): + options = dataset_ops.Options() + options.experimental_deterministic = True + options.experimental_external_state_policy = ( + distribute_options.ExternalStatePolicy.FAIL) + options.experimental_distribute.auto_shard_policy = ( + distribute_options.AutoShardPolicy.DATA) + options.experimental_distribute.num_devices = 1000 + options.experimental_optimization.apply_default_optimizations = True + options.experimental_optimization.autotune = True + options.experimental_optimization.autotune_buffers = True + options.experimental_optimization.autotune_cpu_budget = 10 + options.experimental_optimization.autotune_ram_budget = 20 + options.experimental_optimization.filter_fusion = True + options.experimental_optimization.filter_with_random_uniform_fusion = True + options.experimental_optimization.hoist_random_uniform = True + options.experimental_optimization.map_and_batch_fusion = True + options.experimental_optimization.map_and_filter_fusion = True + options.experimental_optimization.map_fusion = True + options.experimental_optimization.map_parallelization = True + options.experimental_optimization.map_vectorization.enabled = True + options.experimental_optimization.map_vectorization.use_choose_fastest = ( + True) + options.experimental_optimization.noop_elimination = True + options.experimental_optimization.parallel_batch = True + options.experimental_optimization.reorder_data_discarding_ops = True + options.experimental_optimization.shuffle_and_repeat_fusion = True + options.experimental_slack = True + options.experimental_threading.max_intra_op_parallelism = 30 + options.experimental_threading.private_threadpool_size = 40 + pb = options._to_proto() + result = dataset_ops.Options() + result._from_proto(pb) + self.assertEqual(options, result) + + @combinations.generate(test_base.default_test_combinations()) + def testOptionsProtoDefaultValuesRoundTrip(self): + options = dataset_ops.Options() + pb = options._to_proto() + result = dataset_ops.Options() + result._from_proto(pb) + self.assertEqual(options, result) + + @combinations.generate(test_base.default_test_combinations()) + def testProtoOptionsDefaultValuesRoundTrip(self): + pb = dataset_options_pb2.Options() + options = dataset_ops.Options() + options._from_proto(pb) + result = options._to_proto() + expected_pb = dataset_options_pb2.Options() + expected_pb.distribute_options.CopyFrom( + dataset_options_pb2.DistributeOptions()) + expected_pb.optimization_options.CopyFrom( + dataset_options_pb2.OptimizationOptions()) + expected_pb.optimization_options.map_vectorization.CopyFrom( + dataset_options_pb2.MapVectorization()) + expected_pb.threading_options.CopyFrom( + dataset_options_pb2.ThreadingOptions()) + self.assertProtoEquals(expected_pb, result) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 9b5aa9f6dda..6497cb2143b 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -28,6 +28,7 @@ import numpy as np import six from six.moves import queue as Queue # pylint: disable=redefined-builtin +from tensorflow.core.framework import dataset_options_pb2 from tensorflow.core.framework import graph_pb2 from tensorflow.python import tf2 from tensorflow.python.data.experimental.ops import distribute_options @@ -3039,6 +3040,34 @@ class Options(options_lib.OptionsBase): "state is ignored and a warning is logged; FAIL: External state results " "in an error.") + def _to_proto(self): + pb = dataset_options_pb2.Options() + if self.experimental_deterministic is not None: + pb.deterministic = self.experimental_deterministic + pb.distribute_options.CopyFrom(self.experimental_distribute._to_proto()) # pylint: disable=protected-access + if self.experimental_external_state_policy is not None: + pb.external_state_policy = ( + distribute_options.ExternalStatePolicy._to_proto( # pylint: disable=protected-access + self.experimental_external_state_policy)) + pb.optimization_options.CopyFrom(self.experimental_optimization._to_proto()) # pylint: disable=protected-access + if self.experimental_slack is not None: + pb.slack = self.experimental_slack + pb.threading_options.CopyFrom(self.experimental_threading._to_proto()) # pylint: disable=protected-access + return pb + + def _from_proto(self, pb): + if pb.WhichOneof("optional_deterministic") is not None: + self.experimental_deterministic = pb.deterministic + self.experimental_distribute._from_proto(pb.distribute_options) # pylint: disable=protected-access + if pb.WhichOneof("optional_external_state_policy") is not None: + self.experimental_external_state_policy = ( + distribute_options.ExternalStatePolicy._from_proto( # pylint: disable=protected-access + pb.external_state_policy)) + self.experimental_optimization._from_proto(pb.optimization_options) # pylint: disable=protected-access + if pb.WhichOneof("optional_slack") is not None: + self.experimental_slack = pb.slack + self.experimental_threading._from_proto(pb.threading_options) # pylint: disable=protected-access + def _graph_rewrites(self): """Produces lists of enabled, disabled, default static graph rewrites. diff --git a/tensorflow/python/data/util/options.py b/tensorflow/python/data/util/options.py index 8af773ed68b..3df6f000bb6 100644 --- a/tensorflow/python/data/util/options.py +++ b/tensorflow/python/data/util/options.py @@ -59,6 +59,14 @@ class OptionsBase(object): raise AttributeError( "Cannot set the property %s on %s." % (name, type(self).__name__)) + def _to_proto(self): + """Convert options to protocol buffer.""" + raise NotImplementedError("%s._to_proto()" % type(self).__name__) + + def _from_proto(self, pb): + """Convert protocol buffer to options.""" + raise NotImplementedError("%s._from_proto()" % type(self).__name__) + # Creates a namedtuple with three keys for optimization graph rewrites settings. def graph_rewrites():