[tf.data] Options-related changes.
This CL:
- refactors all options classes to use a shared options utility
- introduces `tf.data.experimental.ThreadingOptions` for threading configuration and surfaces it through the `experimental_threading` property of `tf.data.Options`

PiperOrigin-RevId: 222462977
This commit is contained in: parent 78ef03f348, commit fa5590b4ad
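For orientation (not part of the diff itself): a minimal sketch of the user-facing API this CL introduces, using the public names exported below; the specific values are illustrative.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(1000).map(lambda x: x * x,
                                          num_parallel_calls=32)

# Configure threading through the new options surface.
threading_opts = tf.data.experimental.ThreadingOptions()
threading_opts.private_threadpool_size = 8   # run the pipeline on a private pool
threading_opts.max_intra_op_parallelism = 1  # cap per-op parallelism

options = tf.data.Options()
options.experimental_threading = threading_opts
dataset = dataset.with_options(options)
```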
@@ -0,0 +1,13 @@
op {
  graph_op_name: "ExperimentalMaxIntraOpParallelismDataset"
  in_arg {
    name: "max_intra_op_parallelism"
    description: <<END
Identifies the maximum intra-op parallelism to use.
END
  }
  summary: <<END
Creates a dataset that overrides the maximum intra-op parallelism.
END
  visibility: HIDDEN
}
@@ -0,0 +1,13 @@
op {
  graph_op_name: "ExperimentalPrivateThreadPoolDataset"
  in_arg {
    name: "num_threads"
    description: <<END
Identifies the number of threads to use for the private threadpool.
END
  }
  summary: <<END
Creates a dataset that uses a custom thread pool to compute `input_dataset`.
END
  visibility: HIDDEN
}
@@ -13,10 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
@@ -225,6 +227,221 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
  };
};

class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit MaxIntraOpParallelismDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {}

  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    int64 max_intra_op_parallelism;
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument<int64>(ctx, "max_intra_op_parallelism",
                                              &max_intra_op_parallelism));
    OP_REQUIRES(
        ctx, max_intra_op_parallelism >= 0,
        errors::InvalidArgument("`max_intra_op_parallelism` must be >= 0"));
    *output = new Dataset(ctx, input, max_intra_op_parallelism);
  }

 private:
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, const DatasetBase* input,
            int max_intra_op_parallelism)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          max_intra_op_parallelism_(max_intra_op_parallelism) {
      input_->Ref();
    }

    ~Dataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return std::unique_ptr<IteratorBase>(new Iterator(
          {this, strings::StrCat(prefix, "::MaxIntraOpParallelism")}));
    }

    const DataTypeVector& output_dtypes() const override {
      return input_->output_dtypes();
    }
    const std::vector<PartialTensorShape>& output_shapes() const override {
      return input_->output_shapes();
    }

    string DebugString() const override {
      return "MaxIntraOpParallelismDatasetOp::Dataset";
    }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_graph_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
      Node* max_intra_op_parallelism_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(max_intra_op_parallelism_,
                                      &max_intra_op_parallelism_node));
      TF_RETURN_IF_ERROR(b->AddDataset(
          this, {input_graph_node, max_intra_op_parallelism_node}, output));
      return Status::OK();
    }

   private:
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params) {}

      Status Initialize(IteratorContext* ctx) override {
        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        IteratorContext::Params params(ctx);
        auto max_parallelism = dataset()->max_intra_op_parallelism_;
        params.runner = std::bind(
            [max_parallelism](
                const std::function<void(std::function<void()>)>& runner,
                std::function<void()> fn) {
              std::function<void()> scoped_fn = std::bind(
                  [max_parallelism](const std::function<void()>& fn) {
                    ScopedPerThreadMaxParallelism scope(max_parallelism);
                    fn();
                  },
                  std::move(fn));
              (runner)(std::move(scoped_fn));
            },
            std::move(*ctx->runner()), std::placeholders::_1);
        return input_impl_->GetNext(IteratorContext{std::move(params)},
                                    out_tensors, end_of_sequence);
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

     private:
      std::unique_ptr<IteratorBase> input_impl_;
    };

    const DatasetBase* const input_;
    const int max_intra_op_parallelism_;
  };
};

class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit PrivateThreadPoolDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {}

  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    int64 num_threads;
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<int64>(ctx, "num_threads", &num_threads));
    OP_REQUIRES(ctx, num_threads >= 1,
                errors::InvalidArgument("`num_threads` must be >= 1"));
    *output = new Dataset(ctx, input, num_threads);
  }

 private:
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, const DatasetBase* input, int num_threads)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          num_threads_(num_threads) {
      thread_pool_ = MakeUnique<thread::ThreadPool>(
          ctx->env(), ThreadOptions{}, "tf_data_private_threadpool",
          num_threads,
          /*low_latency_hint=*/false);
      input_->Ref();
    }

    ~Dataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return std::unique_ptr<IteratorBase>(
          new Iterator({this, strings::StrCat(prefix, "::PrivateThreadPool")}));
    }

    const DataTypeVector& output_dtypes() const override {
      return input_->output_dtypes();
    }
    const std::vector<PartialTensorShape>& output_shapes() const override {
      return input_->output_shapes();
    }

    string DebugString() const override {
      return "PrivateThreadPoolDatasetOp::Dataset";
    }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_graph_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
      Node* num_threads_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(num_threads_, &num_threads_node));
      TF_RETURN_IF_ERROR(b->AddDataset(
          this, {input_graph_node, num_threads_node}, output));
      return Status::OK();
    }

   private:
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params) {}

      Status Initialize(IteratorContext* ctx) override {
        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        thread::ThreadPool* pool = dataset()->thread_pool_.get();
        IteratorContext::Params params(ctx);
        params.runner = [pool](std::function<void()> c) {
          pool->Schedule(std::move(c));
        };
        params.runner_threadpool_size = dataset()->num_threads_;
        return input_impl_->GetNext(IteratorContext{std::move(params)},
                                    out_tensors, end_of_sequence);
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

     private:
      std::unique_ptr<IteratorBase> input_impl_;
    };

    const DatasetBase* const input_;
    const int64 num_threads_;
    std::unique_ptr<thread::ThreadPool> thread_pool_;
  };
};

REGISTER_KERNEL_BUILDER(
    Name("ExperimentalMaxIntraOpParallelismDataset").Device(DEVICE_CPU),
    MaxIntraOpParallelismDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalPrivateThreadPoolDataset").Device(DEVICE_CPU),
    PrivateThreadPoolDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalThreadPoolHandle").Device(DEVICE_CPU),
                        ThreadPoolHandleOp);
REGISTER_KERNEL_BUILDER(
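Note: the nested `std::bind` in `MaxIntraOpParallelismDatasetOp`'s `GetNextInternal` above is dense. Conceptually it decorates the existing runner so that every closure it schedules first enters a scope that caps intra-op parallelism. A rough Python analogue of that decoration, for illustration only (the thread-local value stands in for `ScopedPerThreadMaxParallelism`; none of this is TensorFlow API):

```python
import threading

_state = threading.local()  # holds the per-thread parallelism cap


def wrap_runner(runner, max_parallelism):
  """Decorates `runner` so each scheduled closure runs under a cap.

  `runner` is any callable that schedules a zero-argument closure; the
  wrapper installs the cap around every closure it forwards, mirroring
  the nested std::bind in the C++ kernel above.
  """
  def scoped_runner(fn):
    def scoped_fn():
      previous = getattr(_state, "max_parallelism", None)
      _state.max_parallelism = max_parallelism  # enter the "scope"
      try:
        fn()
      finally:
        _state.max_parallelism = previous  # restore on exit
    runner(scoped_fn)
  return scoped_runner
```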
@@ -140,6 +140,22 @@ REGISTER_OP("ExperimentalFunctionBufferingResourceReset")
    .Input("function_buffer_resource: resource")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("ExperimentalMaxIntraOpParallelismDataset")
    .Input("input_dataset: variant")
    .Input("max_intra_op_parallelism: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("ExperimentalPrivateThreadPoolDataset")
    .Input("input_dataset: variant")
    .Input("num_threads: int64")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("ExperimentalThreadPoolDataset")
    .Input("input_dataset: variant")
    .Input("thread_pool: resource")
@@ -32,6 +32,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@StatsAggregator
@@StatsOptions
@@TFRecordWriter
@@ThreadingOptions

@@bucket_by_sequence_length
@@choose_from_datasets
@@ -101,6 +102,7 @@ from tensorflow.python.data.experimental.ops.shuffle_ops import shuffle_and_repe
from tensorflow.python.data.experimental.ops.stats_aggregator import StatsAggregator
from tensorflow.python.data.experimental.ops.stats_ops import latency_stats
from tensorflow.python.data.experimental.ops.stats_options import StatsOptions
from tensorflow.python.data.experimental.ops.threading_options import ThreadingOptions
from tensorflow.python.data.experimental.ops.unique import unique
from tensorflow.python.data.experimental.ops.writers import TFRecordWriter
from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
@@ -60,7 +60,8 @@ class LatencyAllEdgesTest(stats_dataset_test_base.StatsDatasetTestBase):
        ["LatencyStats", "Map", "LatencyStats", "Prefetch",
         "LatencyStats"])).map(lambda x: x * x).prefetch(1)
    options = dataset_ops.Options()
    options.experimental_stats = stats_options.StatsOptions(aggregator)
    options.experimental_stats = stats_options.StatsOptions()
    options.experimental_stats.aggregator = aggregator
    dataset = dataset.with_options(options)
    self.assertDatasetProduces(
        dataset,
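The hunk above shows the removed line followed by its replacement, which reflects the API change this CL makes: `StatsOptions` no longer takes the aggregator as a constructor argument; it is assigned to a type-checked attribute instead. Paraphrasing the two patterns side by side:

```python
# Before this CL: aggregator passed to the constructor.
options.experimental_stats = stats_options.StatsOptions(aggregator)

# After this CL: construct first, then assign the attribute.
options.experimental_stats = stats_options.StatsOptions()
options.experimental_stats.aggregator = aggregator
```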
@@ -22,6 +22,7 @@ import threading
from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.experimental.ops import threading_options
from tensorflow.python.data.experimental.ops import threadpool
from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.data.kernel_tests import test_base
@@ -35,18 +36,7 @@ from tensorflow.python.platform import test
class OverrideThreadpoolTest(test_base.DatasetTestBase,
                             parameterized.TestCase):

  @parameterized.named_parameters(
      ("1", 1, None),
      ("2", 2, None),
      ("3", 4, None),
      ("4", 8, None),
      ("5", 16, None),
      ("6", 4, -1),
      ("7", 4, 0),
      ("8", 4, 1),
      ("9", 4, 4),
  )
  def testNumThreads(self, num_threads, max_intra_op_parallelism):
  def _testNumThreadsHelper(self, num_threads, override_threadpool_fn):

    def get_thread_id(_):
      # Python creates a dummy thread object to represent the current
@@ -60,14 +50,7 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
        dataset_ops.Dataset.range(1000).map(
            lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
            num_parallel_calls=32).apply(unique.unique()))

    dataset = threadpool.override_threadpool(
        dataset,
        threadpool.PrivateThreadPool(
            num_threads,
            max_intra_op_parallelism=max_intra_op_parallelism,
            display_name="private_thread_pool_%d" % num_threads))

    dataset = override_threadpool_fn(dataset)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

@@ -79,12 +62,64 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
          thread_ids.append(sess.run(next_element))
      except errors.OutOfRangeError:
        pass
    self.assertEqual(len(thread_ids), len(set(thread_ids)))
    self.assertGreater(len(thread_ids), 0)
    # NOTE(mrry): We don't control the thread pool scheduling, and
    # so cannot guarantee that all of the threads in the pool will
    # perform work.
    self.assertLessEqual(len(thread_ids), num_threads)
    self.assertLen(thread_ids, len(set(thread_ids)))
    self.assertNotEmpty(thread_ids)
    if num_threads:
      # NOTE(mrry): We don't control the thread pool scheduling, and
      # so cannot guarantee that all of the threads in the pool will
      # perform work.
      self.assertLessEqual(len(thread_ids), num_threads)

  @parameterized.named_parameters(
      ("1", 1, None),
      ("2", 2, None),
      ("3", 4, None),
      ("4", 8, None),
      ("5", 16, None),
      ("6", 4, -1),
      ("7", 4, 0),
      ("8", 4, 1),
      ("9", 4, 4),
  )
  def testNumThreadsDeprecated(self, num_threads, max_intra_op_parallelism):

    def override_threadpool_fn(dataset):
      return threadpool.override_threadpool(
          dataset,
          threadpool.PrivateThreadPool(
              num_threads,
              max_intra_op_parallelism=max_intra_op_parallelism,
              display_name="private_thread_pool_%d" % num_threads))

    self._testNumThreadsHelper(num_threads, override_threadpool_fn)

  @parameterized.named_parameters(
      ("1", 1, None),
      ("2", 2, None),
      ("3", 4, None),
      ("4", 8, None),
      ("5", 16, None),
      ("6", None, 0),
      ("7", None, 1),
      ("8", None, 4),
      ("9", 4, 0),
      ("10", 4, 1),
      ("11", 4, 4),
      ("12", None, None),
  )
  def testNumThreads(self, num_threads, max_intra_op_parallelism):

    def override_threadpool_fn(dataset):
      t_options = threading_options.ThreadingOptions()
      if max_intra_op_parallelism is not None:
        t_options.max_intra_op_parallelism = max_intra_op_parallelism
      if num_threads is not None:
        t_options.private_threadpool_size = num_threads
      options = dataset_ops.Options()
      options.experimental_threading = t_options
      return dataset.with_options(options)

    self._testNumThreadsHelper(num_threads, override_threadpool_fn)


if __name__ == "__main__":
@@ -45,22 +45,18 @@ def function_set_stats_aggregator(dataset,


def function_apply_options(dataset, aggregator, prefix="", counter_prefix=""):
  options = dataset_ops.Options()
  options.experimental_stats = stats_options.StatsOptions(aggregator)
  options.experimental_stats = stats_options.StatsOptions()
  options.experimental_stats.aggregator = aggregator
  options.experimental_stats.prefix = prefix
  options.experimental_stats.counter_prefix = counter_prefix
  options.experimental_stats.latency_all_edges = False
  if prefix:
    options.experimental_stats.prefix = prefix
  if counter_prefix:
    options.experimental_stats.counter_prefix = counter_prefix
  return dataset.with_options(options)


@parameterized.named_parameters(
    dict(
        testcase_name="SetStatsAggregator",
        dataset_transformation=function_set_stats_aggregator),
    dict(
        testcase_name="StatsOptions",
        dataset_transformation=function_apply_options))
    ("SetStatsAggregator", function_set_stats_aggregator),
    ("StatsOptions", function_apply_options),
)
class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):

  def testBytesProduced(self, dataset_transformation):
@@ -188,6 +188,17 @@ py_library(
    ],
)

py_library(
    name = "map_defun",
    srcs = ["map_defun.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:dataset_ops_gen",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:tensor_shape",
    ],
)

py_library(
    name = "optimization",
    srcs = ["optimization.py"],
@@ -217,17 +228,6 @@ py_library(
    ],
)

py_library(
    name = "map_defun",
    srcs = ["map_defun.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:dataset_ops_gen",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:tensor_shape",
    ],
)

py_library(
    name = "resampling",
    srcs = ["resampling.py"],
@@ -303,6 +303,18 @@ py_library(
    srcs_version = "PY2AND3",
    deps = [
        ":stats_aggregator",
        "//tensorflow/python:util",
        "//tensorflow/python/data/util:options",
    ],
)

py_library(
    name = "threading_options",
    srcs = ["threading_options.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:util",
        "//tensorflow/python/data/util:options",
    ],
)

@@ -313,9 +325,8 @@ py_library(
    deps = [
        "//tensorflow/python:experimental_dataset_ops_gen",
        "//tensorflow/python:resource_variable_ops",
        "//tensorflow/python:util",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/util:nest",
        "//tensorflow/python/data/util:sparse",
        "//tensorflow/python/eager:context",
    ],
)
@@ -20,11 +20,12 @@ from __future__ import division
from __future__ import print_function

from tensorflow.python.data.experimental.ops import stats_aggregator
from tensorflow.python.data.util import options
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.StatsOptions")
class StatsOptions(object):
class StatsOptions(options.OptionsBase):
  """Represents options for collecting dataset stats using `StatsAggregator`.

  To apply `StatsOptions` with a `tf.data.Dataset` object, use the following
@@ -52,52 +53,29 @@ class StatsOptions(object):
  ```
  """

  for _name, _ty, _default, _docstring in [
      ("aggregator", stats_aggregator.StatsAggregator, None,
       "Associate the given statistics options with the dataset pipeline."),
      ("prefix", str, "",
       "Prefix to prepend all statistics recorded for the input `dataset` with."
      ),
      ("counter_prefix", str, "",
       "Prefix for the statistics recorded as counter."),
      ("latency_all_edges", bool, True,
       "Whether to add latency measurements on all edges."),
  ]:
  aggregator = options.create_option(
      name="aggregator",
      ty=stats_aggregator.StatsAggregator,
      docstring=
      "Associates the given statistics aggregator with the dataset pipeline.")

  def _make_getter(name):  # pylint: disable=no-self-argument
  prefix = options.create_option(
      name="prefix",
      ty=str,
      docstring=
      "Prefix to prepend all statistics recorded for the input `dataset` with.",
      default="")

    def getter(self):
      return getattr(self, "_" + name)
  counter_prefix = options.create_option(
      name="counter_prefix",
      ty=str,
      docstring=
      "Prefix for the statistics recorded as counter.",
      default="")

    return getter

  def _make_setter(name, ty):  # pylint: disable=no-self-argument

    def setter(self, value):
      if not isinstance(value, ty):
        raise TypeError(
            "Attempting to set the option %s to incompatible value: %r when "
            "it expects %r" % (name, value, ty))
      setattr(self, "_" + name, value)

    return setter

  vars()["_" + _name] = _default
  vars()[_name] = property(
      _make_getter(_name), _make_setter(_name, _ty), _default, _docstring)

  def __init__(self, aggregator=None):
    if aggregator:
      self.aggregator = aggregator

  def __eq__(self, other):
    if isinstance(other, self.__class__):
      return self.__dict__ == other.__dict__
    else:
      return False

  def __ne__(self, other):
    return not self.__eq__(other)

  def __str__(self):
    return str(self.__dict__)
  latency_all_edges = options.create_option(
      name="latency_all_edges",
      ty=bool,
      docstring=
      "Whether to add latency measurements on all edges.",
      default=True)
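With the shared utility, each attribute above is a `property` built by `options.create_option`, so assignments are type-checked at set time rather than via the old hand-rolled getter/setter loop. A small illustration of the resulting behavior (values are hypothetical):

```python
opts = StatsOptions()
opts.latency_all_edges = False  # OK: bool matches the declared type.
opts.prefix = "pipeline"        # OK: str matches.
opts.latency_all_edges = "no"   # Raises TypeError: not a bool.
```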
tensorflow/python/data/experimental/ops/threading_options.py (new file, 50 lines)
@@ -0,0 +1,50 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for controlling threading in `tf.data` pipelines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.util import options
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.ThreadingOptions")
class ThreadingOptions(options.OptionsBase):
  """Represents options for dataset threading.

  To apply `ThreadingOptions` to a `dataset` object, use the following pattern:

  ```python
  options = tf.data.Options()
  options.experimental_threading = tf.data.experimental.ThreadingOptions()
  options.experimental_threading.private_threadpool_size = 10
  dataset = dataset.with_options(options)
  ```
  """

  max_intra_op_parallelism = options.create_option(
      name="max_intra_op_parallelism",
      ty=int,
      docstring=
      "If set, it overrides the maximum degree of intra-op parallelism.")

  private_threadpool_size = options.create_option(
      name="private_threadpool_size",
      ty=int,
      docstring=
      "If set, the dataset will use a private threadpool of the given size.",
      default=None)
@@ -239,7 +239,7 @@ class DatasetOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
    options2 = dataset_ops.Options()
    options2.experimental_autotune = False
    with self.assertRaisesRegexp(ValueError,
                                 "Cannot merge incompatible values of option"):
                                 "Cannot merge incompatible values"):
      dataset_ops.Dataset.range(0).with_options(options1).with_options(options2)

  def testOptionsMergeOptionsFromMultipleInputs(self):
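The loosened regexp matches the error raised by the shared `merge_options` utility introduced later in this CL. A minimal sketch of the conflict this test exercises, assuming `options1.experimental_autotune = True` was set earlier in the test:

```python
options1 = dataset_ops.Options()
options1.experimental_autotune = True
options2 = dataset_ops.Options()
options2.experimental_autotune = False

# Applying both to one pipeline merges them; two different non-default
# values for the same option cannot be reconciled.
dataset_ops.Dataset.range(0).with_options(options1).with_options(options2)
# ValueError: Cannot merge incompatible values (True and False) of option:
# experimental_autotune
```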
@@ -14,6 +14,7 @@ py_library(
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:dataset_ops_gen",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:experimental_dataset_ops_gen",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:function",
        "//tensorflow/python:math_ops",
@@ -26,7 +27,9 @@ py_library(
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:util",
        "//tensorflow/python/data/experimental/ops:stats_options",
        "//tensorflow/python/data/experimental/ops:threading_options",
        "//tensorflow/python/data/util:nest",
        "//tensorflow/python/data/util:options",
        "//tensorflow/python/data/util:random_seed",
        "//tensorflow/python/data/util:sparse",
        "//tensorflow/python/data/util:structure",
@@ -27,8 +27,10 @@ import six

from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import stats_options
from tensorflow.python.data.experimental.ops import threading_options
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import options as options_lib
from tensorflow.python.data.util import random_seed
from tensorflow.python.data.util import sparse
from tensorflow.python.data.util import structure as structure_lib
@@ -45,6 +47,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
@@ -107,6 +110,14 @@ class DatasetV2(object):

    dataset = self
    options = self.options()
    if options.experimental_threading is not None:
      t_options = options.experimental_threading
      if t_options.private_threadpool_size is not None:
        dataset = _PrivateThreadPoolDataset(dataset,
                                            t_options.private_threadpool_size)
      if t_options.max_intra_op_parallelism is not None:
        dataset = _MaxIntraOpParallelismDataset(
            dataset, t_options.max_intra_op_parallelism)
    static_optimizations = options._static_optimizations()  # pylint: disable=protected-access
    if static_optimizations:
      dataset = _OptimizeDataset(dataset, static_optimizations)
@@ -1371,10 +1382,9 @@ class DatasetV2(object):
  def with_options(self, options):
    """Returns a new `tf.data.Dataset` with the given options set.

    The options are "global" in the sense they apply to the entire input
    pipeline in which the `with_options` transformation is used. If options are
    set multiple times, they are merged if possible (see
    `tf.data.Options.merge()` for details).
    The options are "global" in the sense they apply to the entire dataset.
    If options are set multiple times, they are merged as long as different
    options do not use different non-default values.

    Args:
      options: A `tf.data.Options` that identifies the options to use.
@@ -1383,7 +1393,7 @@ class DatasetV2(object):
      Dataset: A `Dataset` with the given options.

    Raises:
      ValueError: if options are set more than once
      ValueError: when an option is set more than once to a non-default value
    """
    return _OptionsDataset(self, options)

@@ -1571,7 +1581,7 @@ class DatasetV1Adapter(DatasetV1):


@tf_export("data.Options")
class Options(object):
class Options(options_lib.OptionsBase):
  """Represents options for tf.data.Dataset.

  An `Options` object can be for instance used to control which static
@@ -1579,69 +1589,81 @@ class Options(object):
  tune the parallelism of operations such as `tf.data.Dataset.map` or
  `tf.data.Dataset.interleave`.
  """
  for _name, _ty, _docstring in [
      ("experimental_autotune", bool,
       "Whether to dynamically adjust the values of tunable parameters (e.g. "
       "degrees of parallelism)."),
      ("experimental_deterministic", bool,
       "Whether the outputs need to be produced in deterministic order."),
      ("experimental_filter_fusion", bool,
       "Whether to fuse filter transformations."),
      ("experimental_hoist_random_uniform", bool,
       "Whether to hoist `tf.random_uniform()` ops out of map transformations."
      ),
      ("experimental_stats", stats_options.StatsOptions,
       "Associate the given statistics options with the dataset pipeline."),
      ("experimental_map_and_batch_fusion", bool,
       "Whether to fuse map and batch transformations."),
      ("experimental_map_and_filter_fusion", bool,
       "Whether to fuse map and filter transformations."),
      ("experimental_map_fusion", bool, "Whether to fuse map transformations."),
      ("experimental_map_parallelization", bool,
       "Whether to parallelize stateless map transformations."),
      ("experimental_map_vectorization", bool,
       "Whether to vectorize map transformations."),
      ("experimental_noop_elimination", bool,
       "Whether to eliminate no-op transformations."),
      ("experimental_shuffle_and_repeat_fusion", bool,
       "Whether to fuse shuffle and repeat transformations."),
      ("experimental_numa_aware", bool,
       "Whether to use NUMA-aware operations."),
  ]:

  def _make_getter(name):  # pylint: disable=no-self-argument
  experimental_autotune = options_lib.create_option(
      name="experimental_autotune",
      ty=bool,
      docstring=
      "Whether to dynamically adjust the values of tunable parameters (e.g. "
      "degrees of parallelism).")

    def getter(self):
      return getattr(self, "_" + name)
  experimental_deterministic = options_lib.create_option(
      name="experimental_deterministic",
      ty=bool,
      docstring=
      "Whether the outputs need to be produced in deterministic order.")

    return getter
  experimental_filter_fusion = options_lib.create_option(
      name="experimental_filter_fusion",
      ty=bool,
      docstring="Whether to fuse filter transformations.")

  def _make_setter(name, ty):  # pylint: disable=no-self-argument
  experimental_hoist_random_uniform = options_lib.create_option(
      name="experimental_hoist_random_uniform",
      ty=bool,
      docstring=
      "Whether to hoist `tf.random_uniform()` ops out of map transformations.")

    def setter(self, value):
      if not isinstance(value, ty):
        raise TypeError(
            "Attempting to set the option %s to incompatible value: %r when "
            "it expects %r" % (name, value, ty))
      setattr(self, "_" + name, value)
  experimental_map_and_batch_fusion = options_lib.create_option(
      name="experimental_map_and_batch_fusion",
      ty=bool,
      docstring="Whether to fuse map and batch transformations.")

    return setter
  experimental_map_and_filter_fusion = options_lib.create_option(
      name="experimental_map_and_filter_fusion",
      ty=bool,
      docstring="Whether to fuse map and filter transformations.")

  vars()["_" + _name] = None
  vars()[_name] = property(
      _make_getter(_name), _make_setter(_name, _ty), None, _docstring)
  experimental_map_fusion = options_lib.create_option(
      name="experimental_map_fusion",
      ty=bool,
      docstring="Whether to fuse map transformations.")

  def __init__(self):
    pass
  experimental_map_parallelization = options_lib.create_option(
      name="experimental_map_parallelization",
      ty=bool,
      docstring="Whether to parallelize stateless map transformations.")

  def __eq__(self, other):
    if isinstance(other, self.__class__):
      return self.__dict__ == other.__dict__
    else:
      return False
  experimental_map_vectorization = options_lib.create_option(
      name="experimental_map_vectorization",
      ty=bool,
      docstring="Whether to vectorize map transformations.")

  def __ne__(self, other):
    return not self.__eq__(other)
  experimental_noop_elimination = options_lib.create_option(
      name="experimental_noop_elimination",
      ty=bool,
      docstring="Whether to eliminate no-op transformations.")

  experimental_numa_aware = options_lib.create_option(
      name="experimental_numa_aware",
      ty=bool,
      docstring="Whether to use NUMA-aware operations.")

  experimental_shuffle_and_repeat_fusion = options_lib.create_option(
      name="experimental_shuffle_and_repeat_fusion",
      ty=bool,
      docstring="Whether to fuse shuffle and repeat transformations.")

  experimental_stats = options_lib.create_option(
      name="experimental_stats",
      ty=stats_options.StatsOptions,
      docstring="Associates the given statistics options with the dataset.")

  experimental_threading = options_lib.create_option(
      name="experimental_threading",
      ty=threading_options.ThreadingOptions,
      docstring="Associates the given threading options with the dataset.")

  def _static_optimizations(self):
    """Produces the list of enabled static optimizations."""
@@ -1687,32 +1709,7 @@ class Options(object):
      New `tf.data.Options()` object which is the result of merging self with
      the input `tf.data.Options`.
    """
    result = Options()
    for other in [self, options]:
      for name in [
          "experimental_autotune",
          "experimental_deterministic",
          "experimental_filter_fusion",
          "experimental_hoist_random_uniform",
          "experimental_map_and_batch_fusion",
          "experimental_map_and_filter_fusion",
          "experimental_map_fusion",
          "experimental_map_parallelization",
          "experimental_map_vectorization",
          "experimental_noop_elimination",
          "experimental_numa_aware",
          "experimental_shuffle_and_repeat_fusion",
          "experimental_stats",
      ]:
        this = getattr(result, name)
        that = getattr(other, name)
        if that is not None:
          if this is None:
            setattr(result, name, that)
          elif this != that:
            raise ValueError(
                "Cannot merge incompatible values of option: %s" % (name))
    return result
    return options_lib.merge_options(self, options)


class DatasetSource(DatasetV2):
@@ -3065,7 +3062,7 @@ class _OptimizeDataset(UnaryUnchangedStructureDataset):


class _SetStatsAggregatorDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, and sets stats aggregator."""
  """A `Dataset` that acts as an identity, and sets a stats aggregator."""

  def __init__(self, input_dataset, aggregator, prefix, counter_prefix):
    super(_SetStatsAggregatorDataset, self).__init__(input_dataset)
@@ -3081,3 +3078,37 @@ class _SetStatsAggregatorDataset(UnaryUnchangedStructureDataset):
          self._prefix,
          self._counter_prefix,
          **flat_structure(self))


class _MaxIntraOpParallelismDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, overriding intra-op parallelism."""

  def __init__(self, input_dataset, max_intra_op_parallelism):
    super(_MaxIntraOpParallelismDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset
    self._max_intra_op_parallelism = ops.convert_to_tensor(
        max_intra_op_parallelism,
        dtype=dtypes.int64,
        name="max_intra_op_parallelism")

  def _as_variant_tensor(self):
    return ged_ops.experimental_max_intra_op_parallelism_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        self._max_intra_op_parallelism,
        **flat_structure(self))


class _PrivateThreadPoolDataset(UnaryUnchangedStructureDataset):
  """A `Dataset` that acts as an identity, setting a private threadpool."""

  def __init__(self, input_dataset, num_threads):
    super(_PrivateThreadPoolDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset
    self._num_threads = ops.convert_to_tensor(
        num_threads, dtype=dtypes.int64, name="num_threads")

  def _as_variant_tensor(self):
    return ged_ops.experimental_private_thread_pool_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        self._num_threads,
        **flat_structure(self))
@@ -97,6 +97,23 @@ py_test(
    ],
)

py_library(
    name = "options",
    srcs = ["options.py"],
    srcs_version = "PY2AND3",
)

py_test(
    name = "options_test",
    size = "small",
    srcs = ["options_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":options",
        "//tensorflow/python:client_testlib",
    ],
)

py_library(
    name = "convert",
    srcs = ["convert.py"],
tensorflow/python/data/util/options.py (new file, 131 lines)
@@ -0,0 +1,131 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for tf.data options."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


def _internal_attr_name(name):
  return "_" + name


class OptionsBase(object):
  """Base class for representing a set of tf.data options.

  Attributes:
    _options: Stores the option values.
  """

  def __init__(self):
    self._options = {}

  def __eq__(self, other):
    if not isinstance(other, self.__class__):
      return NotImplemented
    for name in set(self._options) | set(other._options):  # pylint: disable=protected-access
      if getattr(self, name) != getattr(other, name):
        return False
    return True

  def __ne__(self, other):
    if isinstance(other, self.__class__):
      return not self.__eq__(other)
    else:
      return NotImplemented


def create_option(name, ty, docstring, default=None):
  """Creates a type-checked property.

  Args:
    name: the name to use
    ty: the type to use
    docstring: the docstring to use
    default: the default value to use

  Returns:
    A type-checked property.
  """

  def get_fn(self):
    return self._options.get(name, default)  # pylint: disable=protected-access

  def set_fn(self, value):
    if not isinstance(value, ty):
      raise TypeError("Property \"%s\" must be of type %s, got: %r (type: %r)" %
                      (name, ty, value, type(value)))
    self._options[name] = value  # pylint: disable=protected-access

  return property(get_fn, set_fn, None, docstring)


def merge_options(*options_list):
  """Merges the given options, returning the result as a new options object.

  The input arguments are expected to have a matching type that derives from
  `OptionsBase` (and thus each represent a set of options). The method outputs
  an object of the same type created by merging the sets of options represented
  by the input arguments.

  The sets of options can be merged as long as there does not exist an option
  with different non-default values.

  If an option is an instance of `OptionsBase` itself, then this method is
  applied recursively to the set of options represented by this option.

  Args:
    *options_list: options to merge

  Raises:
    TypeError: if the input arguments are incompatible or not derived from
      `OptionsBase`
    ValueError: if the given options cannot be merged

  Returns:
    A new options object which is the result of merging the given options.
  """
  if len(options_list) < 1:
    raise ValueError("At least one options object must be provided")
  result_type = type(options_list[0])

  for options in options_list:
    if not isinstance(options, result_type):
      raise TypeError("Incompatible options type: %r vs %r" % (type(options),
                                                               result_type))

  if not isinstance(options_list[0], OptionsBase):
    raise TypeError("The inputs should inherit from `OptionsBase`")

  default_options = result_type()
  result = result_type()
  for options in options_list:
    # Iterate over all set options and merge them into the result.
    for name in options._options:  # pylint: disable=protected-access
      this = getattr(result, name)
      that = getattr(options, name)
      default = getattr(default_options, name)
      if that == default:
        continue
      elif this == default:
        setattr(result, name, that)
      elif isinstance(this, OptionsBase):
        setattr(result, name, merge_options(this, that))
      elif this != that:
        raise ValueError(
            "Cannot merge incompatible values (%r and %r) of option: %s" %
            (this, that, name))
  return result
tensorflow/python/data/util/options_test.py (new file, 96 lines)
@@ -0,0 +1,96 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for dataset options utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.util import options
from tensorflow.python.platform import test


class _TestOptions(options.OptionsBase):
  x = options.create_option(
      name="x", ty=int, docstring="the answer to everything", default=42)
  y = options.create_option(
      name="y", ty=float, docstring="a tasty pie", default=3.14)


class _NestedTestOptions(options.OptionsBase):
  opts = options.create_option(
      name="opts", ty=_TestOptions, docstring="nested options")


class OptionsTest(test.TestCase):

  def testDocumentation(self):
    self.assertEqual(_TestOptions.x.__doc__, "the answer to everything")
    self.assertEqual(_TestOptions.y.__doc__, "a tasty pie")

  def testCreateOption(self):
    opts = _TestOptions()
    self.assertEqual(opts.x, 42)
    self.assertEqual(opts.y, 3.14)
    self.assertIsInstance(opts.x, int)
    self.assertIsInstance(opts.y, float)
    opts.x = 0
    self.assertEqual(opts.x, 0)
    with self.assertRaises(TypeError):
      opts.x = 3.14
    opts.y = 0.0
    self.assertEqual(opts.y, 0.0)
    with self.assertRaises(TypeError):
      opts.y = 42

  def testMergeOptions(self):
    options1, options2 = _TestOptions(), _TestOptions()
    with self.assertRaises(ValueError):
      options.merge_options()
    merged_options = options.merge_options(options1, options2)
    self.assertEqual(merged_options.x, 42)
    self.assertEqual(merged_options.y, 3.14)
    options1.x = 0
    options2.y = 0.0
    merged_options = options.merge_options(options1, options2)
    self.assertEqual(merged_options.x, 0)
    self.assertEqual(merged_options.y, 0.0)

  def testMergeNestedOptions(self):
    options1, options2 = _NestedTestOptions(), _NestedTestOptions()
    merged_options = options.merge_options(options1, options2)
    self.assertEqual(merged_options.opts, None)
    options1.opts = _TestOptions()
    merged_options = options.merge_options(options1, options2)
    self.assertEqual(merged_options.opts, _TestOptions())
    options2.opts = _TestOptions()
    merged_options = options.merge_options(options1, options2)
    self.assertEqual(merged_options.opts, _TestOptions())
    options1.opts.x = 0
    options2.opts.y = 0.0
    merged_options = options.merge_options(options1, options2)
    self.assertEqual(merged_options.opts.x, 0)
    self.assertEqual(merged_options.opts.y, 0.0)

  def testMergeOptionsInvalid(self):
    with self.assertRaises(TypeError):
      options.merge_options(0)
    options1, options2 = _TestOptions(), _NestedTestOptions()
    with self.assertRaises(TypeError):
      options.merge_options(options1, options2)


if __name__ == "__main__":
  test.main()
@@ -1,6 +1,7 @@
path: "tensorflow.data.Options"
tf_class {
  is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Options\'>"
  is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "experimental_autotune"
@@ -54,6 +55,10 @@ tf_class {
    name: "experimental_stats"
    mtype: "<type \'property\'>"
  }
  member {
    name: "experimental_threading"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

@@ -1,6 +1,7 @@
path: "tensorflow.data.experimental.StatsOptions"
tf_class {
  is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_options.StatsOptions\'>"
  is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "aggregator"
@@ -20,6 +21,6 @@ tf_class {
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'aggregator\'], varargs=None, keywords=None, defaults=[\'None\'], "
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}

@@ -0,0 +1,18 @@
path: "tensorflow.data.experimental.ThreadingOptions"
tf_class {
  is_instance: "<class \'tensorflow.python.data.experimental.ops.threading_options.ThreadingOptions\'>"
  is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "max_intra_op_parallelism"
    mtype: "<type \'property\'>"
  }
  member {
    name: "private_threadpool_size"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}

@@ -40,6 +40,10 @@ tf_module {
    name: "TFRecordWriter"
    mtype: "<type \'type\'>"
  }
  member {
    name: "ThreadingOptions"
    mtype: "<type \'type\'>"
  }
  member_method {
    name: "Counter"
    argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "

@@ -1,6 +1,7 @@
path: "tensorflow.data.Options"
tf_class {
  is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Options\'>"
  is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "experimental_autotune"
@@ -54,6 +55,10 @@ tf_class {
    name: "experimental_stats"
    mtype: "<type \'property\'>"
  }
  member {
    name: "experimental_threading"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"

@@ -1,6 +1,7 @@
path: "tensorflow.data.experimental.StatsOptions"
tf_class {
  is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_options.StatsOptions\'>"
  is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "aggregator"
@@ -20,6 +21,6 @@ tf_class {
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'aggregator\'], varargs=None, keywords=None, defaults=[\'None\'], "
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}

@@ -0,0 +1,18 @@
path: "tensorflow.data.experimental.ThreadingOptions"
tf_class {
  is_instance: "<class \'tensorflow.python.data.experimental.ops.threading_options.ThreadingOptions\'>"
  is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "max_intra_op_parallelism"
    mtype: "<type \'property\'>"
  }
  member {
    name: "private_threadpool_size"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}

@@ -40,6 +40,10 @@ tf_module {
    name: "TFRecordWriter"
    mtype: "<type \'type\'>"
  }
  member {
    name: "ThreadingOptions"
    mtype: "<type \'type\'>"
  }
  member_method {
    name: "Counter"
    argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "