[tf.data] Add a new proto to store tf.data.Options data and conversion functions between tf.data.Options and their proto representation.

PiperOrigin-RevId: 355945339
Change-Id: I3eba0c05899aeda629e34483b75fc61c435cb4d8
A. Unique TensorFlower 2021-02-05 15:49:32 -08:00 committed by TensorFlower Gardener
parent 7dbb6cf44a
commit eb001c7165
9 changed files with 458 additions and 0 deletions
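In short, this change adds a dataset_options.proto message and private _to_proto()/_from_proto() helpers on tf.data.Options (and its nested option classes) so that option settings can be serialized and restored. A minimal usage sketch of the round trip this enables; it assumes a TensorFlow build that includes this commit, and note that _to_proto()/_from_proto() are private helpers introduced here, not public API:

import tensorflow as tf
from tensorflow.core.framework import dataset_options_pb2

options = tf.data.Options()
options.experimental_deterministic = True
options.experimental_threading.private_threadpool_size = 8

# Convert the options into the new dataset_options_pb2.Options message and
# serialize it to bytes that can be stored alongside a dataset.
pb = options._to_proto()
serialized = pb.SerializeToString()

# Reconstruct an equivalent Options object from the stored bytes.
restored_pb = dataset_options_pb2.Options()
restored_pb.ParseFromString(serialized)
restored = tf.data.Options()
restored._from_proto(restored_pb)
assert options == restored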


@@ -203,6 +203,7 @@ FRAMEWORK_PROTO_SRCS = [
"//tensorflow/core/framework:model.proto",
"//tensorflow/core/framework:node_def.proto",
"//tensorflow/core/framework:op_def.proto",
"//tensorflow/core/framework:dataset_options.proto",
"//tensorflow/core/framework:reader_base.proto",
"//tensorflow/core/framework:remote_fused_graph_execute_info.proto",
"//tensorflow/core/framework:resource_handle.proto",


@@ -119,6 +119,7 @@ exports_files(
"api_def.proto",
"attr_value.proto",
"cost_graph.proto",
"dataset_options.proto",
"device_attributes.proto",
"function.proto",
"graph.proto",
@@ -1660,6 +1661,13 @@ tf_proto_library(
make_default_target_header_only = True,
)
tf_proto_library(
name = "dataset_options_proto",
srcs = ["dataset_options.proto"],
cc_api_version = 2,
make_default_target_header_only = True,
)
tf_proto_library(
name = "protos_all",
cc_api_version = 2,
@@ -1678,6 +1686,7 @@ tf_proto_library(
":model_proto",
":node_def_proto",
":op_def_proto",
":dataset_options_proto",
":reader_base_proto",
":remote_fused_graph_execute_info_proto",
":resource_handle_proto",


@@ -0,0 +1,179 @@
syntax = "proto3";

package tensorflow.data;

// Represents the type of auto-sharding we enable.
enum AutoShardPolicy {
  AUTO = 0;
  FILE = 1;
  DATA = 2;
  OFF = -1;
}

message DistributeOptions {
  // The type of sharding that auto-shard should attempt. If this is set to
  // FILE, then we will attempt to shard by files (each worker will get a set of
  // files to process). If we cannot find a set of files to shard for at least
  // one file per worker, we will error out. When this option is selected, make
  // sure that you have enough files so that each worker gets at least one file.
  // There will be a runtime error thrown if there are insufficient files. If
  // this is set to DATA, then we will shard by elements produced by the
  // dataset, and each worker will process the whole dataset and discard the
  // portion that is not for itself. If this is set to OFF, then we will not
  // autoshard, and each worker will receive a copy of the full dataset. This
  // option is set to AUTO by default; AUTO will attempt to first shard by
  // FILE, and fall back to sharding by DATA if we cannot find a set of files
  // to shard.
  AutoShardPolicy auto_shard_policy = 1;

  // The number of devices attached to this input pipeline.
  oneof optional_num_devices {
    int32 num_devices = 2;
  }
}

message MapVectorization {
  // Whether to vectorize map transformations.
  oneof optional_enabled {
    bool enabled = 1;
  }

  // Whether to use ChooseFastestBranchDataset with this transformation. If
  // True, the pipeline picks between the vectorized and original segment at
  // runtime based on their iteration speed.
  oneof optional_use_choose_fastest {
    bool use_choose_fastest = 2;
  }
}

message OptimizationOptions {
  // Whether to apply default graph optimizations. If False, only graph
  // optimizations that have been explicitly enabled will be applied.
  oneof optional_apply_default_optimizations {
    bool apply_default_optimizations = 1;
  }

  // Whether to automatically tune performance knobs.
  oneof optional_autotune {
    bool autotune = 2;
  }

  // When autotuning is enabled (through autotune), determines whether to also
  // autotune buffer sizes for datasets with parallelism.
  oneof optional_autotune_buffers {
    bool autotune_buffers = 3;
  }

  // When autotuning is enabled (through autotune), determines the CPU budget to
  // use. Values greater than the number of schedulable CPU cores are allowed
  // but may result in CPU contention.
  oneof optional_autotune_cpu_budget {
    int32 autotune_cpu_budget = 4;
  }

  // When autotuning is enabled (through autotune), determines the RAM budget to
  // use. Values greater than the available RAM in bytes may result in OOM. If
  // 0, defaults to half of the available RAM in bytes.
  oneof optional_autotune_ram_budget {
    int32 autotune_ram_budget = 5;
  }

  // Whether to fuse filter transformations.
  oneof optional_filter_fusion {
    bool filter_fusion = 6;
  }

  // Whether to fuse filter dataset that predicts random_uniform < rate into a
  // sampling dataset.
  oneof optional_filter_with_random_uniform_fusion {
    bool filter_with_random_uniform_fusion = 7;
  }

  // Whether to hoist tf.random_uniform() ops out of map transformations.
  oneof optional_hoist_random_uniform {
    bool hoist_random_uniform = 8;
  }

  // Whether to fuse map and batch transformations.
  oneof optional_map_and_batch_fusion {
    bool map_and_batch_fusion = 9;
  }

  // Whether to fuse map and filter transformations.
  oneof optional_map_and_filter_fusion {
    bool map_and_filter_fusion = 10;
  }

  // Whether to fuse map transformations.
  oneof optional_map_fusion {
    bool map_fusion = 11;
  }

  // Whether to parallelize stateless map transformations.
  oneof optional_map_parallelization {
    bool map_parallelization = 12;
  }

  // The map vectorization options associated with the dataset.
  MapVectorization map_vectorization = 13;

  // Whether to eliminate no-op transformations.
  oneof optional_noop_elimination {
    bool noop_elimination = 14;
  }

  // Whether to parallelize copying of batch elements. This optimization is
  // highly experimental and can cause performance degradation (e.g. when the
  // parallelization overhead exceeds the benefits of performing the data copies
  // in parallel). You should only enable this optimization if a) your input
  // pipeline is bottlenecked on batching and b) you have validated that this
  // optimization improves performance.
  oneof optional_parallel_batch {
    bool parallel_batch = 15;
  }

  // Whether to reorder ops that will discard data to the front of unary
  // cardinality preserving transformations, e.g. dataset.map(...).take(3) will
  // be optimized to dataset.take(3).map(...). For now this optimization will
  // move `skip`, `shard` and `take` to the front of `map` and `prefetch`. This
  // optimization is only for performance; it will not affect the output of the
  // dataset.
  oneof optional_reorder_data_discarding_ops {
    bool reorder_data_discarding_ops = 16;
  }

  // Whether to fuse shuffle and repeat transformations.
  oneof optional_shuffle_and_repeat_fusion {
    bool shuffle_and_repeat_fusion = 17;
  }
}

message ThreadingOptions {
  // If set, it overrides the maximum degree of intra-op parallelism.
  oneof optional_max_intra_op_parallelism {
    int32 max_intra_op_parallelism = 1;
  }

  // If set, the dataset will use a private threadpool of the given size.
  oneof optional_private_threadpool_size {
    int32 private_threadpool_size = 2;
  }
}

// Represents how to handle external state during serialization.
enum ExternalStatePolicy {
  WARN = 0;
  IGNORE = 1;
  FAIL = 2;
}

// Message stored with Dataset objects to control how datasets are processed and
// optimized.
message Options {
  // Whether the outputs need to be produced in deterministic order.
  oneof optional_deterministic {
    bool deterministic = 1;
  }

  // The distribution strategy options associated with the dataset.
  DistributeOptions distribute_options = 2;

  // The optimization options associated with the dataset.
  OptimizationOptions optimization_options = 3;

  // Whether to introduce 'slack' in the last `prefetch` of the input pipeline,
  // if it exists. This may reduce CPU contention with accelerator host-side
  // activity at the start of a step. The slack frequency is determined by the
  // number of devices attached to this input pipeline.
  oneof optional_slack {
    bool slack = 4;
  }

  // The threading options associated with the dataset.
  ThreadingOptions threading_options = 5;

  // This option can be used to override the default policy for how to handle
  // external state when serializing a dataset or checkpointing its iterator.
  // There are three settings available - IGNORE: External state is ignored
  // without a warning; WARN: External state is ignored and a warning is logged;
  // FAIL: External state results in an error.
  oneof optional_external_state_policy {
    ExternalStatePolicy external_state_policy = 6;
  }
}
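Note the design choice above: every scalar option that tf.data treats as tri-state (unset vs. an explicit value) is wrapped in a single-field oneof named optional_<field>. Proto3 scalars cannot otherwise distinguish "never set" from the type default (false/0), so the Python conversion code in this commit checks pb.WhichOneof("optional_...") before copying a value back into the Options object. A small illustrative sketch of that behavior, assuming the generated dataset_options_pb2 module is importable:

from tensorflow.core.framework import dataset_options_pb2

pb = dataset_options_pb2.ThreadingOptions()
# Nothing set yet: the oneof reports no active field, so _from_proto leaves
# the corresponding Options attribute as None.
assert pb.WhichOneof("optional_private_threadpool_size") is None

# Explicitly setting the field, even to its type default (0), marks the oneof,
# so "unset" and "set to 0" remain distinguishable after a round trip.
pb.private_threadpool_size = 0
assert pb.WhichOneof("optional_private_threadpool_size") == "private_threadpool_size"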


@@ -19,6 +19,7 @@ from __future__ import print_function
import enum
from tensorflow.core.framework import dataset_options_pb2
from tensorflow.python.data.util import options
from tensorflow.python.util.tf_export import tf_export
@@ -35,6 +36,34 @@ class AutoShardPolicy(enum.IntEnum):
FILE = 1
DATA = 2
@classmethod
def _to_proto(cls, obj):
"""Convert enum to proto."""
if obj == cls.OFF:
return dataset_options_pb2.AutoShardPolicy.OFF
if obj == cls.FILE:
return dataset_options_pb2.AutoShardPolicy.FILE
if obj == cls.DATA:
return dataset_options_pb2.AutoShardPolicy.DATA
if obj == cls.AUTO:
return dataset_options_pb2.AutoShardPolicy.AUTO
raise ValueError("%s._to_proto() is called with undefined enum %s." %
(cls.__name__, obj.name))
@classmethod
def _from_proto(cls, pb):
"""Convert proto to enum."""
if pb == dataset_options_pb2.AutoShardPolicy.OFF:
return cls.OFF
if pb == dataset_options_pb2.AutoShardPolicy.FILE:
return cls.FILE
if pb == dataset_options_pb2.AutoShardPolicy.DATA:
return cls.DATA
if pb == dataset_options_pb2.AutoShardPolicy.AUTO:
return cls.AUTO
raise ValueError("%s._from_proto() is called with undefined enum %s." %
(cls.__name__, pb))
@tf_export("data.experimental.ExternalStatePolicy")
class ExternalStatePolicy(enum.Enum):
@@ -47,6 +76,30 @@ class ExternalStatePolicy(enum.Enum):
IGNORE = 1
FAIL = 2
@classmethod
def _to_proto(cls, obj):
"""Convert enum to proto."""
if obj == cls.IGNORE:
return dataset_options_pb2.ExternalStatePolicy.IGNORE
if obj == cls.FAIL:
return dataset_options_pb2.ExternalStatePolicy.FAIL
if obj == cls.WARN:
return dataset_options_pb2.ExternalStatePolicy.WARN
raise ValueError("%s._to_proto() is called with undefined enum %s." %
(cls.__name__, obj.name))
@classmethod
def _from_proto(cls, pb):
"""Convert proto to enum."""
if pb == dataset_options_pb2.ExternalStatePolicy.IGNORE:
return cls.IGNORE
if pb == dataset_options_pb2.ExternalStatePolicy.FAIL:
return cls.FAIL
if pb == dataset_options_pb2.ExternalStatePolicy.WARN:
return cls.WARN
raise ValueError("%s._from_proto() is called with undefined enum %s." %
(cls.__name__, pb))
@tf_export("data.experimental.DistributeOptions")
class DistributeOptions(options.OptionsBase):
@@ -89,3 +142,15 @@ class DistributeOptions(options.OptionsBase):
docstring=
"The number of devices attached to this input pipeline. This will be "
"automatically set by MultiDeviceIterator.")
def _to_proto(self):
pb = dataset_options_pb2.DistributeOptions()
pb.auto_shard_policy = AutoShardPolicy._to_proto(self.auto_shard_policy) # pylint: disable=protected-access
if self.num_devices is not None:
pb.num_devices = self.num_devices
return pb
def _from_proto(self, pb):
self.auto_shard_policy = AutoShardPolicy._from_proto(pb.auto_shard_policy) # pylint: disable=protected-access
if pb.WhichOneof("optional_num_devices") is not None:
self.num_devices = pb.num_devices


@@ -19,6 +19,7 @@ from __future__ import print_function
import enum
from tensorflow.core.framework import dataset_options_pb2
from tensorflow.python.data.util import options
from tensorflow.python.util.tf_export import tf_export
@@ -69,6 +70,20 @@ class MapVectorizationOptions(options.OptionsBase):
else:
return ["map_vectorization:use_choose_fastest:false"]
def _to_proto(self):
pb = dataset_options_pb2.MapVectorization()
if self.enabled is not None:
pb.enabled = self.enabled
if self.use_choose_fastest is not None:
pb.use_choose_fastest = self.use_choose_fastest
return pb
def _from_proto(self, pb):
if pb.WhichOneof("optional_enabled") is not None:
self.enabled = pb.enabled
if pb.WhichOneof("optional_use_choose_fastest") is not None:
self.use_choose_fastest = pb.use_choose_fastest
@tf_export("data.experimental.OptimizationOptions")
class OptimizationOptions(options.OptionsBase):
@@ -327,3 +342,77 @@ class OptimizationOptions(options.OptionsBase):
graph_rewrite_configs.append(optimization + ":autotune:true")
return graph_rewrite_configs
def _to_proto(self):
pb = dataset_options_pb2.OptimizationOptions()
if self.apply_default_optimizations is not None:
pb.apply_default_optimizations = self.apply_default_optimizations
if self.autotune is not None:
pb.autotune = self.autotune
if self.autotune_buffers is not None:
pb.autotune_buffers = self.autotune_buffers
if self.autotune_cpu_budget is not None:
pb.autotune_cpu_budget = self.autotune_cpu_budget
if self.autotune_ram_budget is not None:
pb.autotune_ram_budget = self.autotune_ram_budget
if self.filter_fusion is not None:
pb.filter_fusion = self.filter_fusion
if self.filter_with_random_uniform_fusion is not None:
pb.filter_with_random_uniform_fusion = (
self.filter_with_random_uniform_fusion)
if self.hoist_random_uniform is not None:
pb.hoist_random_uniform = self.hoist_random_uniform
if self.map_and_batch_fusion is not None:
pb.map_and_batch_fusion = self.map_and_batch_fusion
if self.map_and_filter_fusion is not None:
pb.map_and_filter_fusion = self.map_and_filter_fusion
if self.map_fusion is not None:
pb.map_fusion = self.map_fusion
if self.map_parallelization is not None:
pb.map_parallelization = self.map_parallelization
pb.map_vectorization.CopyFrom(self.map_vectorization._to_proto()) # pylint: disable=protected-access
if self.noop_elimination is not None:
pb.noop_elimination = self.noop_elimination
if self.parallel_batch is not None:
pb.parallel_batch = self.parallel_batch
if self.reorder_data_discarding_ops is not None:
pb.reorder_data_discarding_ops = self.reorder_data_discarding_ops
if self.shuffle_and_repeat_fusion is not None:
pb.shuffle_and_repeat_fusion = self.shuffle_and_repeat_fusion
return pb
def _from_proto(self, pb):
if pb.WhichOneof("optional_apply_default_optimizations") is not None:
self.apply_default_optimizations = pb.apply_default_optimizations
if pb.WhichOneof("optional_autotune") is not None:
self.autotune = pb.autotune
if pb.WhichOneof("optional_autotune_buffers") is not None:
self.autotune_buffers = pb.autotune_buffers
if pb.WhichOneof("optional_autotune_cpu_budget") is not None:
self.autotune_cpu_budget = pb.autotune_cpu_budget
if pb.WhichOneof("optional_autotune_ram_budget") is not None:
self.autotune_ram_budget = pb.autotune_ram_budget
if pb.WhichOneof("optional_filter_fusion") is not None:
self.filter_fusion = pb.filter_fusion
if pb.WhichOneof("optional_filter_with_random_uniform_fusion") is not None:
self.filter_with_random_uniform_fusion = (
pb.filter_with_random_uniform_fusion)
if pb.WhichOneof("optional_hoist_random_uniform") is not None:
self.hoist_random_uniform = pb.hoist_random_uniform
if pb.WhichOneof("optional_map_and_batch_fusion") is not None:
self.map_and_batch_fusion = pb.map_and_batch_fusion
if pb.WhichOneof("optional_map_and_filter_fusion") is not None:
self.map_and_filter_fusion = pb.map_and_filter_fusion
if pb.WhichOneof("optional_map_fusion") is not None:
self.map_fusion = pb.map_fusion
if pb.WhichOneof("optional_map_parallelization") is not None:
self.map_parallelization = pb.map_parallelization
self.map_vectorization._from_proto(pb.map_vectorization) # pylint: disable=protected-access
if pb.WhichOneof("optional_noop_elimination") is not None:
self.noop_elimination = pb.noop_elimination
if pb.WhichOneof("optional_parallel_batch") is not None:
self.parallel_batch = pb.parallel_batch
if pb.WhichOneof("optional_reorder_data_discarding_ops") is not None:
self.reorder_data_discarding_ops = pb.reorder_data_discarding_ops
if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None:
self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion


@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import dataset_options_pb2
from tensorflow.python.data.util import options
from tensorflow.python.util.tf_export import tf_export
@@ -48,3 +49,17 @@ class ThreadingOptions(options.OptionsBase):
ty=int,
docstring=
"If set, the dataset will use a private threadpool of the given size.")
def _to_proto(self):
pb = dataset_options_pb2.ThreadingOptions()
if self.max_intra_op_parallelism is not None:
pb.max_intra_op_parallelism = self.max_intra_op_parallelism
if self.private_threadpool_size is not None:
pb.private_threadpool_size = self.private_threadpool_size
return pb
def _from_proto(self, pb):
if pb.WhichOneof("optional_max_intra_op_parallelism") is not None:
self.max_intra_op_parallelism = pb.max_intra_op_parallelism
if pb.WhichOneof("optional_private_threadpool_size") is not None:
self.private_threadpool_size = pb.private_threadpool_size


@@ -23,6 +23,8 @@ import sys
from absl.testing import parameterized
from tensorflow.core.framework import dataset_options_pb2
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.experimental.ops import optimization_options
from tensorflow.python.data.experimental.ops import stats_options
from tensorflow.python.data.experimental.ops import threading_options
@@ -127,6 +129,67 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
result = result.concatenate(ds)
self.assertDatasetProduces(result, [0]*1000)
@combinations.generate(test_base.default_test_combinations())
def testOptionsProtoRoundTrip(self):
options = dataset_ops.Options()
options.experimental_deterministic = True
options.experimental_external_state_policy = (
distribute_options.ExternalStatePolicy.FAIL)
options.experimental_distribute.auto_shard_policy = (
distribute_options.AutoShardPolicy.DATA)
options.experimental_distribute.num_devices = 1000
options.experimental_optimization.apply_default_optimizations = True
options.experimental_optimization.autotune = True
options.experimental_optimization.autotune_buffers = True
options.experimental_optimization.autotune_cpu_budget = 10
options.experimental_optimization.autotune_ram_budget = 20
options.experimental_optimization.filter_fusion = True
options.experimental_optimization.filter_with_random_uniform_fusion = True
options.experimental_optimization.hoist_random_uniform = True
options.experimental_optimization.map_and_batch_fusion = True
options.experimental_optimization.map_and_filter_fusion = True
options.experimental_optimization.map_fusion = True
options.experimental_optimization.map_parallelization = True
options.experimental_optimization.map_vectorization.enabled = True
options.experimental_optimization.map_vectorization.use_choose_fastest = (
True)
options.experimental_optimization.noop_elimination = True
options.experimental_optimization.parallel_batch = True
options.experimental_optimization.reorder_data_discarding_ops = True
options.experimental_optimization.shuffle_and_repeat_fusion = True
options.experimental_slack = True
options.experimental_threading.max_intra_op_parallelism = 30
options.experimental_threading.private_threadpool_size = 40
pb = options._to_proto()
result = dataset_ops.Options()
result._from_proto(pb)
self.assertEqual(options, result)
@combinations.generate(test_base.default_test_combinations())
def testOptionsProtoDefaultValuesRoundTrip(self):
options = dataset_ops.Options()
pb = options._to_proto()
result = dataset_ops.Options()
result._from_proto(pb)
self.assertEqual(options, result)
@combinations.generate(test_base.default_test_combinations())
def testProtoOptionsDefaultValuesRoundTrip(self):
pb = dataset_options_pb2.Options()
options = dataset_ops.Options()
options._from_proto(pb)
result = options._to_proto()
expected_pb = dataset_options_pb2.Options()
expected_pb.distribute_options.CopyFrom(
dataset_options_pb2.DistributeOptions())
expected_pb.optimization_options.CopyFrom(
dataset_options_pb2.OptimizationOptions())
expected_pb.optimization_options.map_vectorization.CopyFrom(
dataset_options_pb2.MapVectorization())
expected_pb.threading_options.CopyFrom(
dataset_options_pb2.ThreadingOptions())
self.assertProtoEquals(expected_pb, result)
if __name__ == "__main__":
test.main()


@@ -28,6 +28,7 @@ import numpy as np
import six
from six.moves import queue as Queue # pylint: disable=redefined-builtin
from tensorflow.core.framework import dataset_options_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import distribute_options
@@ -3039,6 +3040,34 @@ class Options(options_lib.OptionsBase):
"state is ignored and a warning is logged; FAIL: External state results "
"in an error.")
def _to_proto(self):
pb = dataset_options_pb2.Options()
if self.experimental_deterministic is not None:
pb.deterministic = self.experimental_deterministic
pb.distribute_options.CopyFrom(self.experimental_distribute._to_proto()) # pylint: disable=protected-access
if self.experimental_external_state_policy is not None:
pb.external_state_policy = (
distribute_options.ExternalStatePolicy._to_proto( # pylint: disable=protected-access
self.experimental_external_state_policy))
pb.optimization_options.CopyFrom(self.experimental_optimization._to_proto()) # pylint: disable=protected-access
if self.experimental_slack is not None:
pb.slack = self.experimental_slack
pb.threading_options.CopyFrom(self.experimental_threading._to_proto()) # pylint: disable=protected-access
return pb
def _from_proto(self, pb):
if pb.WhichOneof("optional_deterministic") is not None:
self.experimental_deterministic = pb.deterministic
self.experimental_distribute._from_proto(pb.distribute_options) # pylint: disable=protected-access
if pb.WhichOneof("optional_external_state_policy") is not None:
self.experimental_external_state_policy = (
distribute_options.ExternalStatePolicy._from_proto( # pylint: disable=protected-access
pb.external_state_policy))
self.experimental_optimization._from_proto(pb.optimization_options) # pylint: disable=protected-access
if pb.WhichOneof("optional_slack") is not None:
self.experimental_slack = pb.slack
self.experimental_threading._from_proto(pb.threading_options) # pylint: disable=protected-access
def _graph_rewrites(self):
"""Produces lists of enabled, disabled, default static graph rewrites.


@@ -59,6 +59,14 @@ class OptionsBase(object):
raise AttributeError(
"Cannot set the property %s on %s." % (name, type(self).__name__))
def _to_proto(self):
"""Convert options to protocol buffer."""
raise NotImplementedError("%s._to_proto()" % type(self).__name__)
def _from_proto(self, pb):
"""Convert protocol buffer to options."""
raise NotImplementedError("%s._from_proto()" % type(self).__name__)
# Creates a namedtuple with three keys for optimization graph rewrites settings.
def graph_rewrites():