[tf.data] Add a new proto to store tf.data.Options data and conversion functions between tf.data.Options and their proto representation.
PiperOrigin-RevId: 355945339 Change-Id: I3eba0c05899aeda629e34483b75fc61c435cb4d8
This commit is contained in:
parent
7dbb6cf44a
commit
eb001c7165
@ -203,6 +203,7 @@ FRAMEWORK_PROTO_SRCS = [
|
||||
"//tensorflow/core/framework:model.proto",
|
||||
"//tensorflow/core/framework:node_def.proto",
|
||||
"//tensorflow/core/framework:op_def.proto",
|
||||
"//tensorflow/core/framework:dataset_options.proto",
|
||||
"//tensorflow/core/framework:reader_base.proto",
|
||||
"//tensorflow/core/framework:remote_fused_graph_execute_info.proto",
|
||||
"//tensorflow/core/framework:resource_handle.proto",
|
||||
|
@ -119,6 +119,7 @@ exports_files(
|
||||
"api_def.proto",
|
||||
"attr_value.proto",
|
||||
"cost_graph.proto",
|
||||
"dataset_options.proto",
|
||||
"device_attributes.proto",
|
||||
"function.proto",
|
||||
"graph.proto",
|
||||
@ -1660,6 +1661,13 @@ tf_proto_library(
|
||||
make_default_target_header_only = True,
|
||||
)
|
||||
|
||||
tf_proto_library(
|
||||
name = "dataset_options_proto",
|
||||
srcs = ["dataset_options.proto"],
|
||||
cc_api_version = 2,
|
||||
make_default_target_header_only = True,
|
||||
)
|
||||
|
||||
tf_proto_library(
|
||||
name = "protos_all",
|
||||
cc_api_version = 2,
|
||||
@ -1678,6 +1686,7 @@ tf_proto_library(
|
||||
":model_proto",
|
||||
":node_def_proto",
|
||||
":op_def_proto",
|
||||
":dataset_options_proto",
|
||||
":reader_base_proto",
|
||||
":remote_fused_graph_execute_info_proto",
|
||||
":resource_handle_proto",
|
||||
|
179
tensorflow/core/framework/dataset_options.proto
Normal file
179
tensorflow/core/framework/dataset_options.proto
Normal file
@ -0,0 +1,179 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow.data;
|
||||
|
||||
// Represents the type of auto-sharding we enable.
|
||||
enum AutoShardPolicy {
|
||||
AUTO = 0;
|
||||
FILE = 1;
|
||||
DATA = 2;
|
||||
OFF = -1;
|
||||
}
|
||||
|
||||
message DistributeOptions {
|
||||
// The type of sharding that auto-shard should attempt. If this is set to
|
||||
// FILE, then we will attempt to shard by files (each worker will get a set of
|
||||
// files to process). If we cannot find a set of files to shard for at least
|
||||
// one file per worker, we will error out. When this option is selected, make
|
||||
// sure that you have enough files so that each worker gets at least one file.
|
||||
// There will be a runtime error thrown if there are insufficient files. If
|
||||
// this is set to DATA, then we will shard by elements produced by the
|
||||
// dataset, and each worker will process the whole dataset and discard the
|
||||
// portion that is not for itself. If this is set to OFF, then we will not
|
||||
// autoshard, and each worker will receive a copy of the full dataset. This
|
||||
// option is set to AUTO by default, AUTO will attempt to first shard by FILE,
|
||||
// and fall back to sharding by DATA if we cannot find a set of files to
|
||||
// shard.
|
||||
AutoShardPolicy auto_shard_policy = 1;
|
||||
// The number of devices attached to this input pipeline.
|
||||
oneof optional_num_devices {
|
||||
int32 num_devices = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message MapVectorization {
|
||||
// Whether to vectorize map transformations.
|
||||
oneof optional_enabled {
|
||||
bool enabled = 1;
|
||||
}
|
||||
// Whether to use ChooseFastestBranchDataset with this transformation. If
|
||||
// True, the pipeline picks between the vectorized and original segment at
|
||||
// runtime based on their iterations speed.
|
||||
oneof optional_use_choose_fastest {
|
||||
bool use_choose_fastest = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message OptimizationOptions {
|
||||
// Whether to apply default graph optimizations. If False, only graph
|
||||
// optimizations that have been explicitly enabled will be applied.
|
||||
oneof optional_apply_default_optimizations {
|
||||
bool apply_default_optimizations = 1;
|
||||
}
|
||||
// Whether to automatically tune performance knobs.
|
||||
oneof optional_autotune {
|
||||
bool autotune = 2;
|
||||
}
|
||||
// When autotuning is enabled (through autotune), determines whether to also
|
||||
// autotune buffer sizes for datasets with parallelism.
|
||||
oneof optional_autotune_buffers {
|
||||
bool autotune_buffers = 3;
|
||||
}
|
||||
// When autotuning is enabled (through autotune), determines the CPU budget to
|
||||
// use. Values greater than the number of schedulable CPU cores are allowed
|
||||
// but may result in CPU contention.
|
||||
oneof optional_autotune_cpu_budget {
|
||||
int32 autotune_cpu_budget = 4;
|
||||
}
|
||||
// When autotuning is enabled (through autotune), determines the RAM budget to
|
||||
// use. Values greater than the available RAM in bytes may result in OOM. If
|
||||
// 0, defaults to half of the available RAM in bytes.
|
||||
oneof optional_autotune_ram_budget {
|
||||
int32 autotune_ram_budget = 5;
|
||||
}
|
||||
// Whether to fuse filter transformations.
|
||||
oneof optional_filter_fusion {
|
||||
bool filter_fusion = 6;
|
||||
}
|
||||
// Whether to fuse filter dataset that predicts random_uniform < rate into a
|
||||
// sampling dataset.
|
||||
oneof optional_filter_with_random_uniform_fusion {
|
||||
bool filter_with_random_uniform_fusion = 7;
|
||||
}
|
||||
// Whether to hoist tf.random_uniform() ops out of map transformations.
|
||||
oneof optional_hoist_random_uniform {
|
||||
bool hoist_random_uniform = 8;
|
||||
}
|
||||
// Whether to fuse map and batch transformations.
|
||||
oneof optional_map_and_batch_fusion {
|
||||
bool map_and_batch_fusion = 9;
|
||||
}
|
||||
// Whether to fuse map and filter transformations.
|
||||
oneof optional_map_and_filter_fusion {
|
||||
bool map_and_filter_fusion = 10;
|
||||
}
|
||||
// Whether to fuse map transformations.
|
||||
oneof optional_map_fusion {
|
||||
bool map_fusion = 11;
|
||||
}
|
||||
// Whether to parallelize stateless map transformations.
|
||||
oneof optional_map_parallelization {
|
||||
bool map_parallelization = 12;
|
||||
}
|
||||
// The map vectorization options associated with the dataset.
|
||||
MapVectorization map_vectorization = 13;
|
||||
// Whether to eliminate no-op transformations.
|
||||
oneof optional_noop_elimination {
|
||||
bool noop_elimination = 14;
|
||||
}
|
||||
// Whether to parallelize copying of batch elements. This optimization is
|
||||
// highly experimental and can cause performance degradation (e.g. when the
|
||||
// parallelization overhead exceeds the benefits of performing the data copies
|
||||
// in parallel). You should only enable this optimization if a) your input
|
||||
// pipeline is bottlenecked on batching and b) you have validated that this
|
||||
// optimization improves performance.
|
||||
oneof optional_parallel_batch {
|
||||
bool parallel_batch = 15;
|
||||
}
|
||||
// Whether to reorder ops that will discard data to the front of unary
|
||||
// cardinality preserving transformations, e.g. dataset.map(...).take(3) will
|
||||
// be optimized to dataset.take(3).map(...). For now this optimization will
|
||||
// move `skip`, `shard` and `take` to the front of `map` and `prefetch`. This
|
||||
// optimization is only for performance; it will not affect the output of the
|
||||
// dataset.
|
||||
oneof optional_reorder_data_discarding_ops {
|
||||
bool reorder_data_discarding_ops = 16;
|
||||
}
|
||||
// Whether to fuse shuffle and repeat transformations.
|
||||
oneof optional_shuffle_and_repeat_fusion {
|
||||
bool shuffle_and_repeat_fusion = 17;
|
||||
}
|
||||
}
|
||||
|
||||
message ThreadingOptions {
|
||||
// If set, it overrides the maximum degree of intra-op parallelism.
|
||||
oneof optional_max_intra_op_parallelism {
|
||||
int32 max_intra_op_parallelism = 1;
|
||||
}
|
||||
// If set, the dataset will use a private threadpool of the given size.
|
||||
oneof optional_private_threadpool_size {
|
||||
int32 private_threadpool_size = 2;
|
||||
}
|
||||
}
|
||||
|
||||
// Represents how to handle external state during serialization.
|
||||
enum ExternalStatePolicy {
|
||||
WARN = 0;
|
||||
IGNORE = 1;
|
||||
FAIL = 2;
|
||||
}
|
||||
|
||||
// Message stored with Dataset objects to control how datasets are processed and
|
||||
// optimized.
|
||||
message Options {
|
||||
// Whether the outputs need to be produced in deterministic order.
|
||||
oneof optional_deterministic {
|
||||
bool deterministic = 1;
|
||||
}
|
||||
// The distribution strategy options associated with the dataset.
|
||||
DistributeOptions distribute_options = 2;
|
||||
// The optimization options associated with the dataset.
|
||||
OptimizationOptions optimization_options = 3;
|
||||
// Whether to introduce 'slack' in the last `prefetch` of the input pipeline,
|
||||
// if it exists. This may reduce CPU contention with accelerator host-side
|
||||
// activity at the start of a step. The slack frequency is determined by the
|
||||
// number of devices attached to this input pipeline.
|
||||
oneof optional_slack {
|
||||
bool slack = 4;
|
||||
}
|
||||
// The threading options associated with the dataset.
|
||||
ThreadingOptions threading_options = 5;
|
||||
// This option can be used to override the default policy for how to handle
|
||||
// external state when serializing a dataset or checkpointing its iterator.
|
||||
// There are three settings available - IGNORE: External state is ignored
|
||||
// without a warning; WARN: External state is ignored and a warning is logged;
|
||||
// FAIL: External state results in an error.
|
||||
oneof optional_external_state_policy {
|
||||
ExternalStatePolicy external_state_policy = 6;
|
||||
}
|
||||
}
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||
|
||||
import enum
|
||||
|
||||
from tensorflow.core.framework import dataset_options_pb2
|
||||
from tensorflow.python.data.util import options
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -35,6 +36,34 @@ class AutoShardPolicy(enum.IntEnum):
|
||||
FILE = 1
|
||||
DATA = 2
|
||||
|
||||
@classmethod
|
||||
def _to_proto(cls, obj):
|
||||
"""Convert enum to proto."""
|
||||
if obj == cls.OFF:
|
||||
return dataset_options_pb2.AutoShardPolicy.OFF
|
||||
if obj == cls.FILE:
|
||||
return dataset_options_pb2.AutoShardPolicy.FILE
|
||||
if obj == cls.DATA:
|
||||
return dataset_options_pb2.AutoShardPolicy.DATA
|
||||
if obj == cls.AUTO:
|
||||
return dataset_options_pb2.AutoShardPolicy.AUTO
|
||||
raise ValueError("%s._to_proto() is called with undefined enum %s." %
|
||||
(cls.__name__, obj.name))
|
||||
|
||||
@classmethod
|
||||
def _from_proto(cls, pb):
|
||||
"""Convert proto to enum."""
|
||||
if pb == dataset_options_pb2.AutoShardPolicy.OFF:
|
||||
return cls.OFF
|
||||
if pb == dataset_options_pb2.AutoShardPolicy.FILE:
|
||||
return cls.FILE
|
||||
if pb == dataset_options_pb2.AutoShardPolicy.DATA:
|
||||
return cls.DATA
|
||||
if pb == dataset_options_pb2.AutoShardPolicy.AUTO:
|
||||
return cls.AUTO
|
||||
raise ValueError("%s._from_proto() is called with undefined enum %s." %
|
||||
(cls.__name__, pb))
|
||||
|
||||
|
||||
@tf_export("data.experimental.ExternalStatePolicy")
|
||||
class ExternalStatePolicy(enum.Enum):
|
||||
@ -47,6 +76,30 @@ class ExternalStatePolicy(enum.Enum):
|
||||
IGNORE = 1
|
||||
FAIL = 2
|
||||
|
||||
@classmethod
|
||||
def _to_proto(cls, obj):
|
||||
"""Convert enum to proto."""
|
||||
if obj == cls.IGNORE:
|
||||
return dataset_options_pb2.ExternalStatePolicy.IGNORE
|
||||
if obj == cls.FAIL:
|
||||
return dataset_options_pb2.ExternalStatePolicy.FAIL
|
||||
if obj == cls.WARN:
|
||||
return dataset_options_pb2.ExternalStatePolicy.WARN
|
||||
raise ValueError("%s._to_proto() is called with undefined enum %s." %
|
||||
(cls.__name__, obj.name))
|
||||
|
||||
@classmethod
|
||||
def _from_proto(cls, pb):
|
||||
"""Convert proto to enum."""
|
||||
if pb == dataset_options_pb2.ExternalStatePolicy.IGNORE:
|
||||
return cls.IGNORE
|
||||
if pb == dataset_options_pb2.ExternalStatePolicy.FAIL:
|
||||
return cls.FAIL
|
||||
if pb == dataset_options_pb2.ExternalStatePolicy.WARN:
|
||||
return cls.WARN
|
||||
raise ValueError("%s._from_proto() is called with undefined enum %s." %
|
||||
(cls.__name__, pb))
|
||||
|
||||
|
||||
@tf_export("data.experimental.DistributeOptions")
|
||||
class DistributeOptions(options.OptionsBase):
|
||||
@ -89,3 +142,15 @@ class DistributeOptions(options.OptionsBase):
|
||||
docstring=
|
||||
"The number of devices attached to this input pipeline. This will be "
|
||||
"automatically set by MultiDeviceIterator.")
|
||||
|
||||
def _to_proto(self):
|
||||
pb = dataset_options_pb2.DistributeOptions()
|
||||
pb.auto_shard_policy = AutoShardPolicy._to_proto(self.auto_shard_policy) # pylint: disable=protected-access
|
||||
if self.num_devices is not None:
|
||||
pb.num_devices = self.num_devices
|
||||
return pb
|
||||
|
||||
def _from_proto(self, pb):
|
||||
self.auto_shard_policy = AutoShardPolicy._from_proto(pb.auto_shard_policy) # pylint: disable=protected-access
|
||||
if pb.WhichOneof("optional_num_devices") is not None:
|
||||
self.num_devices = pb.num_devices
|
||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||
|
||||
import enum
|
||||
|
||||
from tensorflow.core.framework import dataset_options_pb2
|
||||
from tensorflow.python.data.util import options
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -69,6 +70,20 @@ class MapVectorizationOptions(options.OptionsBase):
|
||||
else:
|
||||
return ["map_vectorization:use_choose_fastest:false"]
|
||||
|
||||
def _to_proto(self):
|
||||
pb = dataset_options_pb2.MapVectorization()
|
||||
if self.enabled is not None:
|
||||
pb.enabled = self.enabled
|
||||
if self.use_choose_fastest is not None:
|
||||
pb.use_choose_fastest = self.use_choose_fastest
|
||||
return pb
|
||||
|
||||
def _from_proto(self, pb):
|
||||
if pb.WhichOneof("optional_enabled") is not None:
|
||||
self.enabled = pb.enabled
|
||||
if pb.WhichOneof("optional_use_choose_fastest") is not None:
|
||||
self.use_choose_fastest = pb.use_choose_fastest
|
||||
|
||||
|
||||
@tf_export("data.experimental.OptimizationOptions")
|
||||
class OptimizationOptions(options.OptionsBase):
|
||||
@ -327,3 +342,77 @@ class OptimizationOptions(options.OptionsBase):
|
||||
graph_rewrite_configs.append(optimization + ":autotune:true")
|
||||
|
||||
return graph_rewrite_configs
|
||||
|
||||
def _to_proto(self):
|
||||
pb = dataset_options_pb2.OptimizationOptions()
|
||||
if self.apply_default_optimizations is not None:
|
||||
pb.apply_default_optimizations = self.apply_default_optimizations
|
||||
if self.autotune is not None:
|
||||
pb.autotune = self.autotune
|
||||
if self.autotune_buffers is not None:
|
||||
pb.autotune_buffers = self.autotune_buffers
|
||||
if self.autotune_cpu_budget is not None:
|
||||
pb.autotune_cpu_budget = self.autotune_cpu_budget
|
||||
if self.autotune_ram_budget is not None:
|
||||
pb.autotune_ram_budget = self.autotune_ram_budget
|
||||
if self.filter_fusion is not None:
|
||||
pb.filter_fusion = self.filter_fusion
|
||||
if self.filter_with_random_uniform_fusion is not None:
|
||||
pb.filter_with_random_uniform_fusion = (
|
||||
self.filter_with_random_uniform_fusion)
|
||||
if self.hoist_random_uniform is not None:
|
||||
pb.hoist_random_uniform = self.hoist_random_uniform
|
||||
if self.map_and_batch_fusion is not None:
|
||||
pb.map_and_batch_fusion = self.map_and_batch_fusion
|
||||
if self.map_and_filter_fusion is not None:
|
||||
pb.map_and_filter_fusion = self.map_and_filter_fusion
|
||||
if self.map_fusion is not None:
|
||||
pb.map_fusion = self.map_fusion
|
||||
if self.map_parallelization is not None:
|
||||
pb.map_parallelization = self.map_parallelization
|
||||
pb.map_vectorization.CopyFrom(self.map_vectorization._to_proto()) # pylint: disable=protected-access
|
||||
if self.noop_elimination is not None:
|
||||
pb.noop_elimination = self.noop_elimination
|
||||
if self.parallel_batch is not None:
|
||||
pb.parallel_batch = self.parallel_batch
|
||||
if self.reorder_data_discarding_ops is not None:
|
||||
pb.reorder_data_discarding_ops = self.reorder_data_discarding_ops
|
||||
if self.shuffle_and_repeat_fusion is not None:
|
||||
pb.shuffle_and_repeat_fusion = self.shuffle_and_repeat_fusion
|
||||
return pb
|
||||
|
||||
def _from_proto(self, pb):
|
||||
if pb.WhichOneof("optional_apply_default_optimizations") is not None:
|
||||
self.apply_default_optimizations = pb.apply_default_optimizations
|
||||
if pb.WhichOneof("optional_autotune") is not None:
|
||||
self.autotune = pb.autotune
|
||||
if pb.WhichOneof("optional_autotune_buffers") is not None:
|
||||
self.autotune_buffers = pb.autotune_buffers
|
||||
if pb.WhichOneof("optional_autotune_cpu_budget") is not None:
|
||||
self.autotune_cpu_budget = pb.autotune_cpu_budget
|
||||
if pb.WhichOneof("optional_autotune_ram_budget") is not None:
|
||||
self.autotune_ram_budget = pb.autotune_ram_budget
|
||||
if pb.WhichOneof("optional_filter_fusion") is not None:
|
||||
self.filter_fusion = pb.filter_fusion
|
||||
if pb.WhichOneof("optional_filter_with_random_uniform_fusion") is not None:
|
||||
self.filter_with_random_uniform_fusion = (
|
||||
pb.filter_with_random_uniform_fusion)
|
||||
if pb.WhichOneof("optional_hoist_random_uniform") is not None:
|
||||
self.hoist_random_uniform = pb.hoist_random_uniform
|
||||
if pb.WhichOneof("optional_map_and_batch_fusion") is not None:
|
||||
self.map_and_batch_fusion = pb.map_and_batch_fusion
|
||||
if pb.WhichOneof("optional_map_and_filter_fusion") is not None:
|
||||
self.map_and_filter_fusion = pb.map_and_filter_fusion
|
||||
if pb.WhichOneof("optional_map_fusion") is not None:
|
||||
self.map_fusion = pb.map_fusion
|
||||
if pb.WhichOneof("optional_map_parallelization") is not None:
|
||||
self.map_parallelization = pb.map_parallelization
|
||||
self.map_vectorization._from_proto(pb.map_vectorization) # pylint: disable=protected-access
|
||||
if pb.WhichOneof("optional_noop_elimination") is not None:
|
||||
self.noop_elimination = pb.noop_elimination
|
||||
if pb.WhichOneof("optional_parallel_batch") is not None:
|
||||
self.parallel_batch = pb.parallel_batch
|
||||
if pb.WhichOneof("optional_reorder_data_discarding_ops") is not None:
|
||||
self.reorder_data_discarding_ops = pb.reorder_data_discarding_ops
|
||||
if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None:
|
||||
self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion
|
||||
|
@ -18,6 +18,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
from tensorflow.core.framework import dataset_options_pb2
|
||||
from tensorflow.python.data.util import options
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -48,3 +49,17 @@ class ThreadingOptions(options.OptionsBase):
|
||||
ty=int,
|
||||
docstring=
|
||||
"If set, the dataset will use a private threadpool of the given size.")
|
||||
|
||||
def _to_proto(self):
|
||||
pb = dataset_options_pb2.ThreadingOptions()
|
||||
if self.max_intra_op_parallelism is not None:
|
||||
pb.max_intra_op_parallelism = self.max_intra_op_parallelism
|
||||
if self.private_threadpool_size is not None:
|
||||
pb.private_threadpool_size = self.private_threadpool_size
|
||||
return pb
|
||||
|
||||
def _from_proto(self, pb):
|
||||
if pb.WhichOneof("optional_max_intra_op_parallelism") is not None:
|
||||
self.max_intra_op_parallelism = pb.max_intra_op_parallelism
|
||||
if pb.WhichOneof("optional_private_threadpool_size") is not None:
|
||||
self.private_threadpool_size = pb.private_threadpool_size
|
||||
|
@ -23,6 +23,8 @@ import sys
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.core.framework import dataset_options_pb2
|
||||
from tensorflow.python.data.experimental.ops import distribute_options
|
||||
from tensorflow.python.data.experimental.ops import optimization_options
|
||||
from tensorflow.python.data.experimental.ops import stats_options
|
||||
from tensorflow.python.data.experimental.ops import threading_options
|
||||
@ -127,6 +129,67 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
result = result.concatenate(ds)
|
||||
self.assertDatasetProduces(result, [0]*1000)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptionsProtoRoundTrip(self):
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_deterministic = True
|
||||
options.experimental_external_state_policy = (
|
||||
distribute_options.ExternalStatePolicy.FAIL)
|
||||
options.experimental_distribute.auto_shard_policy = (
|
||||
distribute_options.AutoShardPolicy.DATA)
|
||||
options.experimental_distribute.num_devices = 1000
|
||||
options.experimental_optimization.apply_default_optimizations = True
|
||||
options.experimental_optimization.autotune = True
|
||||
options.experimental_optimization.autotune_buffers = True
|
||||
options.experimental_optimization.autotune_cpu_budget = 10
|
||||
options.experimental_optimization.autotune_ram_budget = 20
|
||||
options.experimental_optimization.filter_fusion = True
|
||||
options.experimental_optimization.filter_with_random_uniform_fusion = True
|
||||
options.experimental_optimization.hoist_random_uniform = True
|
||||
options.experimental_optimization.map_and_batch_fusion = True
|
||||
options.experimental_optimization.map_and_filter_fusion = True
|
||||
options.experimental_optimization.map_fusion = True
|
||||
options.experimental_optimization.map_parallelization = True
|
||||
options.experimental_optimization.map_vectorization.enabled = True
|
||||
options.experimental_optimization.map_vectorization.use_choose_fastest = (
|
||||
True)
|
||||
options.experimental_optimization.noop_elimination = True
|
||||
options.experimental_optimization.parallel_batch = True
|
||||
options.experimental_optimization.reorder_data_discarding_ops = True
|
||||
options.experimental_optimization.shuffle_and_repeat_fusion = True
|
||||
options.experimental_slack = True
|
||||
options.experimental_threading.max_intra_op_parallelism = 30
|
||||
options.experimental_threading.private_threadpool_size = 40
|
||||
pb = options._to_proto()
|
||||
result = dataset_ops.Options()
|
||||
result._from_proto(pb)
|
||||
self.assertEqual(options, result)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptionsProtoDefaultValuesRoundTrip(self):
|
||||
options = dataset_ops.Options()
|
||||
pb = options._to_proto()
|
||||
result = dataset_ops.Options()
|
||||
result._from_proto(pb)
|
||||
self.assertEqual(options, result)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testProtoOptionsDefaultValuesRoundTrip(self):
|
||||
pb = dataset_options_pb2.Options()
|
||||
options = dataset_ops.Options()
|
||||
options._from_proto(pb)
|
||||
result = options._to_proto()
|
||||
expected_pb = dataset_options_pb2.Options()
|
||||
expected_pb.distribute_options.CopyFrom(
|
||||
dataset_options_pb2.DistributeOptions())
|
||||
expected_pb.optimization_options.CopyFrom(
|
||||
dataset_options_pb2.OptimizationOptions())
|
||||
expected_pb.optimization_options.map_vectorization.CopyFrom(
|
||||
dataset_options_pb2.MapVectorization())
|
||||
expected_pb.threading_options.CopyFrom(
|
||||
dataset_options_pb2.ThreadingOptions())
|
||||
self.assertProtoEquals(expected_pb, result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -28,6 +28,7 @@ import numpy as np
|
||||
import six
|
||||
from six.moves import queue as Queue # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.core.framework import dataset_options_pb2
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.data.experimental.ops import distribute_options
|
||||
@ -3039,6 +3040,34 @@ class Options(options_lib.OptionsBase):
|
||||
"state is ignored and a warning is logged; FAIL: External state results "
|
||||
"in an error.")
|
||||
|
||||
def _to_proto(self):
|
||||
pb = dataset_options_pb2.Options()
|
||||
if self.experimental_deterministic is not None:
|
||||
pb.deterministic = self.experimental_deterministic
|
||||
pb.distribute_options.CopyFrom(self.experimental_distribute._to_proto()) # pylint: disable=protected-access
|
||||
if self.experimental_external_state_policy is not None:
|
||||
pb.external_state_policy = (
|
||||
distribute_options.ExternalStatePolicy._to_proto( # pylint: disable=protected-access
|
||||
self.experimental_external_state_policy))
|
||||
pb.optimization_options.CopyFrom(self.experimental_optimization._to_proto()) # pylint: disable=protected-access
|
||||
if self.experimental_slack is not None:
|
||||
pb.slack = self.experimental_slack
|
||||
pb.threading_options.CopyFrom(self.experimental_threading._to_proto()) # pylint: disable=protected-access
|
||||
return pb
|
||||
|
||||
def _from_proto(self, pb):
|
||||
if pb.WhichOneof("optional_deterministic") is not None:
|
||||
self.experimental_deterministic = pb.deterministic
|
||||
self.experimental_distribute._from_proto(pb.distribute_options) # pylint: disable=protected-access
|
||||
if pb.WhichOneof("optional_external_state_policy") is not None:
|
||||
self.experimental_external_state_policy = (
|
||||
distribute_options.ExternalStatePolicy._from_proto( # pylint: disable=protected-access
|
||||
pb.external_state_policy))
|
||||
self.experimental_optimization._from_proto(pb.optimization_options) # pylint: disable=protected-access
|
||||
if pb.WhichOneof("optional_slack") is not None:
|
||||
self.experimental_slack = pb.slack
|
||||
self.experimental_threading._from_proto(pb.threading_options) # pylint: disable=protected-access
|
||||
|
||||
def _graph_rewrites(self):
|
||||
"""Produces lists of enabled, disabled, default static graph rewrites.
|
||||
|
||||
|
@ -59,6 +59,14 @@ class OptionsBase(object):
|
||||
raise AttributeError(
|
||||
"Cannot set the property %s on %s." % (name, type(self).__name__))
|
||||
|
||||
def _to_proto(self):
|
||||
"""Convert options to protocol buffer."""
|
||||
raise NotImplementedError("%s._to_proto()" % type(self).__name__)
|
||||
|
||||
def _from_proto(self, pb):
|
||||
"""Convert protocol buffer to options."""
|
||||
raise NotImplementedError("%s._from_proto()" % type(self).__name__)
|
||||
|
||||
|
||||
# Creates a namedtuple with three keys for optimization graph rewrites settings.
|
||||
def graph_rewrites():
|
||||
|
Loading…
Reference in New Issue
Block a user