Internal change
PiperOrigin-RevId: 323507899 Change-Id: I5ede4ed67d92f132a4e8b7fbe175084daa4181ec
This commit is contained in:
parent 8d35859243
commit a1d78970aa
@@ -0,0 +1,32 @@
+op {
+  graph_op_name: "OptimizeDatasetV2"
+  visibility: HIDDEN
+  in_arg {
+    name: "input_dataset"
+    description: <<END
+A variant tensor representing the input dataset.
+END
+  }
+  in_arg {
+    name: "optimizations_enabled"
+    description: <<END
+A `tf.string` vector `tf.Tensor` identifying user enabled optimizations.
+END
+  }
+  in_arg {
+    name: "optimizations_disabled"
+    description: <<END
+A `tf.string` vector `tf.Tensor` identifying user disabled optimizations.
+END
+  }
+  in_arg {
+    name: "optimizations_default"
+    description: <<END
+A `tf.string` vector `tf.Tensor` identifying optimizations that are applied by default.
+END
+  }
+  summary: "Creates a dataset by applying related optimizations to `input_dataset`."
+  description: <<END
+Creates a dataset by applying related optimizations to `input_dataset`.
+END
+}
@@ -1243,6 +1243,7 @@ tf_kernel_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/platform:platform_port",
     ],
 )
@@ -905,5 +905,135 @@ bool MatchesAnyVersionRE(StringPiece op_prefix, StringPiece op_to_match) {
   return RE2::FullMatch(op_to_match, expected_re);
 }
 
+std::vector<tstring> SelectOptimizations(
+    const string& job_name, const string& opt_ins_raw,
+    const string& opt_outs_raw,
+    const absl::flat_hash_map<string, uint64>& live_experiments,
+    const std::vector<tstring>& optimizations_enabled,
+    const std::vector<tstring>& optimizations_disabled,
+    const std::vector<tstring>& optimizations_default,
+    std::function<uint64(const string&)> hash_func) {
+  // Creates a set of optimizations.
+  absl::flat_hash_set<tstring> optimizations_set;
+
+  // Creates the opt in and opt out settings.
+  std::vector<string> opt_ins, opt_outs;
+  if (opt_ins_raw == "all") {
+    for (auto& pair : live_experiments) {
+      opt_ins.push_back(pair.first);
+    }
+  } else {
+    opt_ins = str_util::Split(opt_ins_raw, ',', str_util::SkipEmpty());
+  }
+  if (opt_outs_raw == "all") {
+    for (auto& pair : live_experiments) {
+      opt_outs.push_back(pair.first);
+    }
+  } else {
+    opt_outs = str_util::Split(opt_outs_raw, ',', str_util::SkipEmpty());
+  }
+
+  // Checks that the opt in and opt out experiments are live experiments.
+  for (auto& optimization : opt_ins) {
+    if (live_experiments.find(optimization) == live_experiments.end()) {
+      LOG(WARNING) << "The experiment \"" << optimization
+                   << "\" is opted in but it is not a live experiment.";
+    }
+  }
+  for (auto& optimization : opt_outs) {
+    if (live_experiments.find(optimization) == live_experiments.end()) {
+      LOG(WARNING) << "The experiment \"" << optimization
+                   << "\" is opted out but it is not a live experiment.";
+    }
+  }
+
+  // Checks whether the opt in settings conflict with the opt out settings.
+  for (auto& optimization : opt_ins) {
+    if (std::find(opt_outs.begin(), opt_outs.end(), optimization) !=
+        opt_outs.end()) {
+      LOG(WARNING) << "The experiment \"" << optimization
+                   << "\" is set in both \"TF_DATA_EXPERIMENT_OPT_IN\" and "
+                      "\"TF_DATA_EXPERIMENT_OPT_OUT\". Unless the experiment "
+                      "corresponds to an explicitly enabled optimization, it "
+                      "is not applied.";
+    }
+  }
+
+  // Checks whether the enable/disable settings from tf.data.Options conflict
+  // with the user opt in/out settings. If they do, the tf.data.Options
+  // settings take precedence.
+  for (auto& optimization : optimizations_enabled) {
+    if (std::find(opt_outs.begin(), opt_outs.end(), optimization) !=
+        opt_outs.end()) {
+      LOG(WARNING) << "The optimization \"" << optimization
+                   << "\" is opted out, but is still applied since"
+                      " it is enabled through tf.data.Options.";
+    }
+  }
+  for (auto& optimization : optimizations_disabled) {
+    if (std::find(opt_ins.begin(), opt_ins.end(), optimization) !=
+        opt_ins.end()) {
+      LOG(WARNING) << "The optimization \"" << optimization
+                   << "\" is opted in, but is not applied since"
+                      " it is disabled through tf.data.Options.";
+    }
+  }
+
+  // Add the enabled optimizations.
+  optimizations_set.insert(optimizations_enabled.begin(),
+                           optimizations_enabled.end());
+
+  // Add the default optimizations that are not explicitly opted out.
+  for (auto& optimization : optimizations_default) {
+    if (std::find(opt_outs.begin(), opt_outs.end(), optimization) ==
+        opt_outs.end()) {
+      optimizations_set.insert(optimization);
+    }
+  }
+
+  // Add the live experiments stochastically if they are neither opted in nor
+  // opted out.
+  for (auto& pair : live_experiments) {
+    string experiment = pair.first;
+    // Skip experiments that are explicitly opted out.
+    if (std::find(opt_outs.begin(), opt_outs.end(), experiment) !=
+        opt_outs.end()) {
+      continue;
+    }
+    // Skip experiments whose transformations are explicitly disabled.
+    if (std::find(optimizations_disabled.begin(), optimizations_disabled.end(),
+                  experiment) != optimizations_disabled.end()) {
+      continue;
+    }
+    // Apply experiments that are explicitly opted in.
+    if (std::find(opt_ins.begin(), opt_ins.end(), experiment) !=
+        opt_ins.end()) {
+      optimizations_set.insert(experiment);
+      continue;
+    }
+    // Otherwise, apply the experiment stochastically based on the job name
+    // and the experiment's rollout percentage.
+    if (hash_func(strings::StrCat(job_name, experiment)) % 100 < pair.second) {
+      optimizations_set.insert(experiment);
+    }
+  }
+
+  // Log the experiments that will be applied.
+  if (VLOG_IS_ON(1)) {
+    for (auto& pair : live_experiments) {
+      string experiment = pair.first;
+      if (std::find(optimizations_set.begin(), optimizations_set.end(),
+                    experiment) != optimizations_set.end()) {
+        VLOG(1) << "The experiment \"" << experiment << "\" is applied.";
+      }
+    }
+  }
+
+  std::vector<tstring> optimizations;
+  optimizations.insert(optimizations.end(), optimizations_set.begin(),
+                       optimizations_set.end());
+  return optimizations;
+}
+
 }  // namespace data
 }  // namespace tensorflow
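In short, the precedence implemented above is: explicitly enabled optimizations always apply; default optimizations apply unless the job opts out; live experiments apply when opted in, and otherwise stochastically per job. A minimal Python restatement of the same rules (illustrative only; `select_optimizations` and its arguments are hypothetical mirrors of the C++ above, not part of the commit):

    def select_optimizations(job_name, opt_ins, opt_outs, live_experiments,
                             enabled, disabled, default, hash_func):
      # Optimizations enabled through tf.data.Options always apply.
      chosen = set(enabled)
      # Default optimizations apply unless the job explicitly opts out.
      chosen |= {opt for opt in default if opt not in opt_outs}
      for experiment, rollout_pct in live_experiments.items():
        if experiment in opt_outs or experiment in disabled:
          continue  # Explicit opt-out or tf.data.Options disable wins.
        if (experiment in opt_ins or
            hash_func(job_name + experiment) % 100 < rollout_pct):
          chosen.add(experiment)
      return sorted(chosen)  # The C++ returns an unordered vector.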
@@ -304,6 +304,18 @@ class DummyResourceOp : public OpKernel {
 // MatchesAnyVersionRE("PaddedBatchDataset", "BatchDataset") == false
 bool MatchesAnyVersionRE(StringPiece op_prefix, StringPiece op_to_match);
 
+// Based on `optimizations_enabled`, `optimizations_disabled`, and
+// `optimizations_default`, returns the list of optimizations that will be
+// applied.
+std::vector<tstring> SelectOptimizations(
+    const string& job_name, const string& opt_ins_raw,
+    const string& opt_outs_raw,
+    const absl::flat_hash_map<string, uint64>& live_experiments,
+    const std::vector<tstring>& optimizations_enabled,
+    const std::vector<tstring>& optimizations_disabled,
+    const std::vector<tstring>& optimizations_default,
+    std::function<uint64(const string&)> hash_func);
+
 }  // namespace data
 }  // namespace tensorflow
@@ -30,6 +30,8 @@ namespace tensorflow {
 namespace data {
 namespace {
 
+using ::testing::UnorderedElementsAre;
+
 class DatasetHashUtilsTest : public ::testing::Test {
  protected:
   uint64 GetHash(const FunctionDefLibrary& library, const FunctionDef& fn) {
@@ -1131,6 +1133,141 @@ TEST_F(DatasetHashUtilsTest, HashStringTensor) {
   EXPECT_NE(GetHash(v1), GetHash(v3));
 }
 
+class SelectOptimizationsHashTest : public ::testing::TestWithParam<uint64> {};
+
+TEST_P(SelectOptimizationsHashTest, DatasetUtils) {
+  const uint64 hash_result = GetParam();
+  string job_name = "job";
+  const string opt_ins_raw = "";
+  const string opt_outs_raw = "";
+  auto hash_func = [hash_result](const string& str) { return hash_result; };
+  absl::flat_hash_map<string, uint64> live_experiments = {
+      {"exp1", 0}, {"exp2", 20}, {"exp3", 33}, {"exp4", 45},
+      {"exp5", 67}, {"exp6", 88}, {"exp7", 100}};
+  std::vector<tstring> optimizations_enabled, optimizations_disabled,
+      optimizations_default;
+  std::vector<tstring> optimizations =
+      SelectOptimizations(job_name, opt_ins_raw, opt_outs_raw, live_experiments,
+                          optimizations_enabled, optimizations_disabled,
+                          optimizations_default, hash_func);
+
+  int tested_times = 0;
+  switch (hash_result) {
+    case 0:
+    case 100:
+    case 200:
+      tested_times++;
+      EXPECT_THAT(optimizations, UnorderedElementsAre("exp2", "exp3", "exp4",
+                                                      "exp5", "exp6", "exp7"));
+      break;
+    case 33:
+    case 133:
+      tested_times++;
+      EXPECT_THAT(optimizations,
+                  UnorderedElementsAre("exp4", "exp5", "exp6", "exp7"));
+      break;
+    case 67:
+    case 167:
+      tested_times++;
+      EXPECT_THAT(optimizations, UnorderedElementsAre("exp6", "exp7"));
+      break;
+  }
+  EXPECT_EQ(tested_times, 1);
+}
+
+INSTANTIATE_TEST_SUITE_P(Test, SelectOptimizationsHashTest,
+                         ::testing::Values(0, 33, 67, 100, 133, 167, 200));
+
+class SelectOptimizationsOptTest
+    : public ::testing::TestWithParam<std::tuple<string, string>> {};
+
+TEST_P(SelectOptimizationsOptTest, DatasetUtils) {
+  string job_name = "job";
+  const string opt_ins_raw = std::get<0>(GetParam());
+  const string opt_outs_raw = std::get<1>(GetParam());
+  auto hash_func = [](const string& str) { return 50; };
+  absl::flat_hash_map<string, uint64> live_experiments = {
+      {"exp1", 0}, {"exp2", 25}, {"exp3", 50}, {"exp4", 75}, {"exp5", 100}};
+  std::vector<tstring> optimizations_enabled, optimizations_disabled,
+      optimizations_default;
+  std::vector<tstring> optimizations =
+      SelectOptimizations(job_name, opt_ins_raw, opt_outs_raw, live_experiments,
+                          optimizations_enabled, optimizations_disabled,
+                          optimizations_default, hash_func);
+
+  int tested_times = 0;
+  if (opt_outs_raw == "all") {
+    EXPECT_THAT(optimizations, UnorderedElementsAre());
+    tested_times++;
+  } else if (opt_outs_raw.empty()) {
+    if (opt_ins_raw == "all") {
+      EXPECT_THAT(optimizations,
+                  UnorderedElementsAre("exp1", "exp2", "exp3", "exp4", "exp5"));
+      tested_times++;
+    } else if (opt_ins_raw.empty()) {
+      EXPECT_THAT(optimizations, UnorderedElementsAre("exp4", "exp5"));
+      tested_times++;
+    } else if (opt_ins_raw == "exp2,exp4") {
+      EXPECT_THAT(optimizations, UnorderedElementsAre("exp2", "exp4", "exp5"));
+      tested_times++;
+    }
+  } else if (opt_outs_raw == "exp1,exp5") {
+    if (opt_ins_raw == "all") {
+      EXPECT_THAT(optimizations, UnorderedElementsAre("exp2", "exp3", "exp4"));
+      tested_times++;
+    } else if (opt_ins_raw.empty()) {
+      EXPECT_THAT(optimizations, UnorderedElementsAre("exp4"));
+      tested_times++;
+    } else if (opt_ins_raw == "exp2,exp4") {
+      EXPECT_THAT(optimizations, UnorderedElementsAre("exp2", "exp4"));
+      tested_times++;
+    }
+  }
+  EXPECT_EQ(tested_times, 1);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    Test, SelectOptimizationsOptTest,
+    ::testing::Combine(::testing::Values("all", "", "exp2,exp4"),
+                       ::testing::Values("all", "", "exp1,exp5")));
+
+class SelectOptimizationsConflictTest
+    : public ::testing::TestWithParam<std::tuple<string, string, uint64>> {};
+
+TEST_P(SelectOptimizationsConflictTest, DatasetUtils) {
+  string job_name = "job";
+  const string opt_ins_raw = std::get<0>(GetParam());
+  const string opt_outs_raw = std::get<1>(GetParam());
+  const uint64 hash_result = std::get<2>(GetParam());
+  auto hash_func = [hash_result](const string& str) { return hash_result; };
+  absl::flat_hash_map<string, uint64> live_experiments = {
+      {"exp1", 20}, {"exp2", 30}, {"exp3", 40},
+      {"exp4", 60}, {"exp5", 70}, {"exp6", 80}};
+  std::vector<tstring> optimizations_enabled = {"exp1", "exp4"},
+                       optimizations_disabled = {"exp2", "exp5"},
+                       optimizations_default = {"exp3", "exp6"};
+  std::vector<tstring> optimizations =
+      SelectOptimizations(job_name, opt_ins_raw, opt_outs_raw, live_experiments,
+                          optimizations_enabled, optimizations_disabled,
+                          optimizations_default, hash_func);
+
+  int tested_times = 0;
+  if (opt_outs_raw.empty()) {
+    EXPECT_THAT(optimizations,
+                UnorderedElementsAre("exp1", "exp3", "exp4", "exp6"));
+    tested_times++;
+  } else if (opt_outs_raw == "exp1,exp3") {
+    EXPECT_THAT(optimizations, UnorderedElementsAre("exp1", "exp4", "exp6"));
+    tested_times++;
+  }
+  EXPECT_EQ(tested_times, 1);
+}
+
+INSTANTIATE_TEST_SUITE_P(Test, SelectOptimizationsConflictTest,
+                         ::testing::Combine(::testing::Values("", "exp2"),
+                                            ::testing::Values("", "exp1,exp3"),
+                                            ::testing::Values(10, 50, 90)));
+
 }  // namespace
 }  // namespace data
 }  // namespace tensorflow
@@ -18,8 +18,10 @@ limitations under the License.
 
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
 #include "tensorflow/core/kernels/data/rewrite_utils.h"
 #include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/host_info.h"
 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
 
 namespace tensorflow {
@@ -31,10 +33,18 @@ namespace data {
 /* static */ constexpr const char* const OptimizeDatasetOp::kDatasetType;
 /* static */ constexpr const char* const OptimizeDatasetOp::kInputDataset;
 /* static */ constexpr const char* const OptimizeDatasetOp::kOptimizations;
+/* static */ constexpr const char* const
+    OptimizeDatasetOp::kOptimizationsEnabled;
+/* static */ constexpr const char* const
+    OptimizeDatasetOp::kOptimizationsDisabled;
+/* static */ constexpr const char* const
+    OptimizeDatasetOp::kOptimizationsDefault;
 /* static */ constexpr const char* const OptimizeDatasetOp::kOutputTypes;
 /* static */ constexpr const char* const OptimizeDatasetOp::kOutputShapes;
 /* static */ constexpr const char* const
     OptimizeDatasetOp::kOptimizationConfigs;
+/* static */ constexpr const char* const OptimizeDatasetOp::kOptimizeDatasetV1;
+/* static */ constexpr const char* const OptimizeDatasetOp::kOptimizeDatasetV2;
 
 constexpr char kOptimizerName[] = "tf_data_meta_optimizer";
 constexpr char kOptimizers[] = "optimizers";
@@ -42,6 +52,12 @@ constexpr char kOptimizerConfigs[] = "optimizer_configs";
 
 OptimizeDatasetOp::OptimizeDatasetOp(OpKernelConstruction* ctx)
     : UnaryDatasetOpKernel(ctx) {
+  auto& op_name = ctx->def().op();
+  if (op_name == kOptimizeDatasetV1) {
+    op_version_ = 1;
+  } else if (op_name == kOptimizeDatasetV2) {
+    op_version_ = 2;
+  }
   OP_REQUIRES_OK(ctx,
                  ctx->GetAttr(kOptimizationConfigs, &optimization_configs_));
 }
@@ -49,8 +65,44 @@ OptimizeDatasetOp::OptimizeDatasetOp(OpKernelConstruction* ctx)
 void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                     DatasetBase** output) {
   std::vector<tstring> optimizations;
-  OP_REQUIRES_OK(
-      ctx, ParseVectorArgument<tstring>(ctx, kOptimizations, &optimizations));
+  if (op_version_ == 1) {
+    OP_REQUIRES_OK(
+        ctx, ParseVectorArgument<tstring>(ctx, kOptimizations, &optimizations));
+  } else if (op_version_ == 2) {
+    std::vector<tstring> optimizations_enabled, optimizations_disabled,
+        optimizations_default;
+    OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, kOptimizationsEnabled,
+                                                     &optimizations_enabled));
+    OP_REQUIRES_OK(ctx,
+                   ParseVectorArgument<tstring>(ctx, kOptimizationsDisabled,
+                                                &optimizations_disabled));
+    OP_REQUIRES_OK(ctx, ParseVectorArgument<tstring>(ctx, kOptimizationsDefault,
+                                                     &optimizations_default));
+
+    string job_name = port::JobName();
+    if (job_name.empty()) {
+      // If `job_name` is empty, apply the enabled and default optimizations
+      // directly.
+      optimizations.insert(optimizations.end(), optimizations_enabled.begin(),
+                           optimizations_enabled.end());
+      optimizations.insert(optimizations.end(), optimizations_default.begin(),
+                           optimizations_default.end());
+    } else {
+      // The map from live experiment names to the percentage of jobs in which
+      // each experiment is randomly turned on.
+      //
+      // This is currently empty; there are no live experiments yet.
+      absl::flat_hash_map<string, uint64> live_experiments;
+
+      // getenv() returns nullptr when a variable is unset; guard before
+      // constructing strings.
+      const char* opt_ins = std::getenv("TF_DATA_EXPERIMENT_OPT_IN");
+      const char* opt_outs = std::getenv("TF_DATA_EXPERIMENT_OPT_OUT");
+      const string opt_ins_raw = opt_ins ? string(opt_ins) : "";
+      const string opt_outs_raw = opt_outs ? string(opt_outs) : "";
+      auto hash_func = [](const string& str) { return Hash64(str); };
+      optimizations = SelectOptimizations(
+          job_name, opt_ins_raw, opt_outs_raw, live_experiments,
+          optimizations_enabled, optimizations_disabled, optimizations_default,
+          hash_func);
+    }
+  }
 
   auto config_factory = [this, &optimizations]() {
     return CreateConfig(optimizations, optimization_configs_);
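The kernel reads the opt-in/out settings from two environment variables. A hedged usage sketch (the variable names and the "all" keyword come from the diff above; the experiment names are hypothetical):

    import os
    # Opt this job into every live tf.data experiment...
    os.environ["TF_DATA_EXPERIMENT_OPT_IN"] = "all"
    # ...except these two (a comma-separated list of experiment names).
    os.environ["TF_DATA_EXPERIMENT_OPT_OUT"] = "exp1,exp5"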
@@ -95,6 +147,8 @@ RewriterConfig OptimizeDatasetOp::CreateConfig(
 namespace {
 REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU),
                         OptimizeDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("OptimizeDatasetV2").Device(DEVICE_CPU),
+                        OptimizeDatasetOp);
 }  // namespace
 }  // namespace data
 }  // namespace tensorflow
@@ -25,10 +25,18 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
   static constexpr const char* const kDatasetType = "Optimize";
   static constexpr const char* const kInputDataset = "input_dataset";
   static constexpr const char* const kOptimizations = "optimizations";
+  static constexpr const char* const kOptimizationsEnabled =
+      "optimizations_enabled";
+  static constexpr const char* const kOptimizationsDisabled =
+      "optimizations_disabled";
+  static constexpr const char* const kOptimizationsDefault =
+      "optimizations_default";
   static constexpr const char* const kOutputTypes = "output_types";
   static constexpr const char* const kOutputShapes = "output_shapes";
   static constexpr const char* const kOptimizationConfigs =
       "optimization_configs";
+  static constexpr const char* const kOptimizeDatasetV1 = "OptimizeDataset";
+  static constexpr const char* const kOptimizeDatasetV2 = "OptimizeDatasetV2";
 
   explicit OptimizeDatasetOp(OpKernelConstruction* ctx);
 
@@ -41,6 +49,7 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
       std::vector<string> optimizations_configs);
 
   std::vector<string> optimization_configs_;
+  int op_version_ = 0;
 };
 
 }  // namespace data
@@ -837,6 +837,17 @@ REGISTER_OP("OptimizeDataset")
     .Attr("optimization_configs: list(string) = []")
     .SetShapeFn(shape_inference::ScalarShape);
 
+REGISTER_OP("OptimizeDatasetV2")
+    .Input("input_dataset: variant")
+    .Input("optimizations_enabled: string")
+    .Input("optimizations_disabled: string")
+    .Input("optimizations_default: string")
+    .Output("handle: variant")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .Attr("optimization_configs: list(string) = []")
+    .SetShapeFn(shape_inference::ScalarShape);
+
 REGISTER_OP("OptionalFromValue")
     .Input("components: Toutput_types")
     .Output("optional: variant")
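For illustration, a sketch of driving the new op directly through its generated Python wrapper; normally `dataset_ops._OptimizeDataset` (later in this diff) does this. Assumes TensorFlow at this commit; the chosen optimization name is just an example:

    import tensorflow.compat.v2 as tf
    from tensorflow.python.data.ops import dataset_ops
    from tensorflow.python.ops import gen_dataset_ops

    ds = tf.data.Dataset.range(5)
    # Three string vectors: enabled, disabled, and default optimizations.
    handle = gen_dataset_ops.optimize_dataset_v2(
        dataset_ops.to_variant(ds),
        optimizations_enabled=["noop_elimination"],
        optimizations_disabled=[],
        optimizations_default=[],
        output_types=[tf.int64],
        output_shapes=[tf.TensorShape([])])
    optimized = dataset_ops.from_variant(handle, ds.element_spec)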
@@ -61,6 +61,8 @@ string Hostname() {
   return string(hostname);
 }
 
+string JobName() { return ""; }
+
 int NumSchedulableCPUs() {
 #if defined(__linux__) && !defined(__ANDROID__)
   cpu_set_t cpuset;
@@ -21,9 +21,13 @@ limitations under the License.
 namespace tensorflow {
 namespace port {
 
-// Return the hostname of the machine on which this process is running
+// Return the hostname of the machine on which this process is running.
 string Hostname();
 
+// Return the job name as a string if it exists, otherwise return an empty
+// string.
+string JobName();
+
 }  // namespace port
 }  // namespace tensorflow
@@ -225,11 +225,14 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
     optimized_it = dataset_ops.make_initializable_iterator(optimized_dataset)
 
     self.assertGreaterEqual(len(w), 1)
-    expected = ("tf.data graph rewrites are not compatible with "
-                "tf.Variable. The following rewrites will be disabled: %s."
-                " To enable rewrites, use resource variables instead by "
-                "calling `tf.enable_resource_variables()` at the start of the "
-                "program." % (", ".join(options._graph_rewrites())))
+    graph_rewrites = options._graph_rewrites()
+    expected = (
+        "tf.data graph rewrites are not compatible with "
+        "tf.Variable. The following rewrites will be disabled: %s."
+        " To enable rewrites, use resource variables instead by "
+        "calling `tf.enable_resource_variables()` at the start of the "
+        "program." %
+        (", ".join(graph_rewrites.enabled + graph_rewrites.default)))
     self.assertTrue(any(expected in str(warning) for warning in w))
 
     # Check that outputs are the same in the optimized and unoptimized cases,
@@ -251,34 +254,136 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
         break
 
   @combinations.generate(test_base.default_test_combinations())
-  def testOptimizationEnabledByDefault(self):
-    """Tests that some optimizations are applied to datasets by default."""
+  def testOptimizationDefault(self):
+    """Tests the optimization settings by default."""
     options = dataset_ops.Options()
-    expected_optimizations = [
+    expected_optimizations_enabled = []
+    expected_optimizations_disabled = []
+    expected_optimizations_default = [
         "map_and_batch_fusion",
         "noop_elimination",
         "shuffle_and_repeat_fusion",
     ]
-    self.assertEqual(
-        set(options._graph_rewrites()), set(expected_optimizations))
+    graph_rewrites = options._graph_rewrites()
+    self.assertEqual(set(graph_rewrites.enabled),
+                     set(expected_optimizations_enabled))
+    self.assertEqual(set(graph_rewrites.disabled),
+                     set(expected_optimizations_disabled))
+    self.assertEqual(set(graph_rewrites.default),
+                     set(expected_optimizations_default))
+
+    options.experimental_optimization.apply_default_optimizations = True
+    graph_rewrites = options._graph_rewrites()
+    self.assertEqual(set(graph_rewrites.enabled),
+                     set(expected_optimizations_enabled))
+    self.assertEqual(set(graph_rewrites.disabled),
+                     set(expected_optimizations_disabled))
+    self.assertEqual(set(graph_rewrites.default),
+                     set(expected_optimizations_default))
+
+    options.experimental_optimization.apply_default_optimizations = False
+    expected_optimizations_default = []
+    graph_rewrites = options._graph_rewrites()
+    self.assertEqual(set(graph_rewrites.enabled),
+                     set(expected_optimizations_enabled))
+    self.assertEqual(set(graph_rewrites.disabled),
+                     set(expected_optimizations_disabled))
+    self.assertEqual(set(graph_rewrites.default),
+                     set(expected_optimizations_default))
 
   @combinations.generate(test_base.default_test_combinations())
-  def testOptimizationDisableDefault(self):
-    """Tests that we can disable all graph optimizations enabled by default.
-
-    If the `apply_default_optimizations` optimization options flag is False,
-    only explicitly enabled optimizations will be applied.
-    """
+  def testOptimizationEnabled(self):
+    """Tests the optimization settings by enabling all."""
     options = dataset_ops.Options()
-    options.experimental_optimization.apply_default_optimizations = False
+    options.experimental_optimization.filter_fusion = True
+    options.experimental_optimization.filter_with_random_uniform_fusion = True
+    options.experimental_optimization.hoist_random_uniform = True
     options.experimental_optimization.map_and_batch_fusion = True
+    options.experimental_optimization.map_and_filter_fusion = True
+    options.experimental_optimization.map_parallelization = True
+    options.experimental_optimization.map_fusion = True
     options.experimental_optimization.noop_elimination = True
-    expected_optimizations = [
+    options.experimental_optimization.parallel_batch = True
+    options.experimental_optimization.shuffle_and_repeat_fusion = True
+    options.experimental_optimization.map_vectorization.enabled = True
+    options.experimental_optimization.autotune_buffers = True
+    options.experimental_deterministic = False
+    options.experimental_stats.latency_all_edges = True
+    options.experimental_slack = True
+
+    expected_optimizations_enabled = [
+        "filter_fusion",
+        "filter_with_random_uniform_fusion",
+        "hoist_random_uniform",
         "map_and_batch_fusion",
+        "map_and_filter_fusion",
+        "map_parallelization",
+        "map_fusion",
         "noop_elimination",
+        "parallel_batch",
+        "shuffle_and_repeat_fusion",
+        "map_vectorization",
+        "inject_prefetch",
+        "make_sloppy",
+        "latency_all_edges",
+        "slack",
     ]
-    self.assertEqual(
-        set(options._graph_rewrites()), set(expected_optimizations))
+    expected_optimizations_disabled = []
+    expected_optimizations_default = []
+    graph_rewrites = options._graph_rewrites()
+    self.assertEqual(set(graph_rewrites.enabled),
+                     set(expected_optimizations_enabled))
+    self.assertEqual(set(graph_rewrites.disabled),
+                     set(expected_optimizations_disabled))
+    self.assertEqual(set(graph_rewrites.default),
+                     set(expected_optimizations_default))
+
+  @combinations.generate(test_base.default_test_combinations())
+  def testOptimizationDisabled(self):
+    """Tests the optimization settings by disabling all."""
+    options = dataset_ops.Options()
+    options.experimental_optimization.filter_fusion = False
+    options.experimental_optimization.filter_with_random_uniform_fusion = False
+    options.experimental_optimization.hoist_random_uniform = False
+    options.experimental_optimization.map_and_batch_fusion = False
+    options.experimental_optimization.map_and_filter_fusion = False
+    options.experimental_optimization.map_parallelization = False
+    options.experimental_optimization.map_fusion = False
+    options.experimental_optimization.noop_elimination = False
+    options.experimental_optimization.parallel_batch = False
+    options.experimental_optimization.shuffle_and_repeat_fusion = False
+    options.experimental_optimization.map_vectorization.enabled = False
+    options.experimental_optimization.autotune = False
+    options.experimental_deterministic = True
+    options.experimental_stats.latency_all_edges = False
+    options.experimental_slack = False
+
+    expected_optimizations_enabled = []
+    expected_optimizations_disabled = [
+        "filter_fusion",
+        "filter_with_random_uniform_fusion",
+        "hoist_random_uniform",
+        "map_and_batch_fusion",
+        "map_and_filter_fusion",
+        "map_parallelization",
+        "map_fusion",
+        "noop_elimination",
+        "parallel_batch",
+        "shuffle_and_repeat_fusion",
+        "map_vectorization",
+        "inject_prefetch",
+        "make_sloppy",
+        "latency_all_edges",
+        "slack",
+    ]
+    expected_optimizations_default = []
+    graph_rewrites = options._graph_rewrites()
+    self.assertEqual(set(graph_rewrites.enabled),
+                     set(expected_optimizations_enabled))
+    self.assertEqual(set(graph_rewrites.disabled),
+                     set(expected_optimizations_disabled))
+    self.assertEqual(set(graph_rewrites.default),
+                     set(expected_optimizations_default))
 
   @combinations.generate(test_base.default_test_combinations())
   def testAutotuningDefaults(self):
@@ -295,7 +400,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
   def testAutotuningBufferSizes(self):
     options = dataset_ops.Options()
     options.experimental_optimization.autotune_buffers = True
-    self.assertIn("inject_prefetch", options._graph_rewrites())
+    self.assertIn("inject_prefetch", options._graph_rewrites().enabled)
     autotune, algorithm, cpu_budget = options._autotune_settings()
     self.assertTrue(autotune)
     self.assertEqual(algorithm,
@@ -45,7 +45,7 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
     multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
         dataset, ["/cpu:1", "/cpu:2"])
     dataset = multi_device_iterator._dataset  # pylint: disable=protected-access
-    self.assertIn("slack", dataset.options()._graph_rewrites())
+    self.assertIn("slack", dataset.options()._graph_rewrites().enabled)
     self.assertIn("slack:slack_period:2",
                   dataset.options()._graph_rewrite_configs())
 
@@ -69,7 +69,7 @@ class PrefetchWithSlackTest(test_base.DatasetTestBase, parameterized.TestCase):
     options = dataset_ops.Options()
     options.experimental_slack = True
     dataset = dataset.with_options(options)
-    self.assertIn("slack", dataset.options()._graph_rewrites())
+    self.assertIn("slack", dataset.options()._graph_rewrites().enabled)
     self.assertIn("slack:slack_period:1",
                   dataset.options()._graph_rewrite_configs())
     self.assertDatasetProduces(dataset, range(10))
@@ -36,7 +36,8 @@ class OptimizeDatasetSerializationTest(
 
     def build_dataset(num_elements, batch_size):
       return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch(
-          batch_size).apply(optimization.optimize(["map_and_batch_fusion"]))
+          batch_size).apply(
+              optimization.optimize(["map_and_batch_fusion"], None, None))
 
     self.run_core_tests(lambda: build_dataset(200, 10), 20)
@@ -50,7 +51,8 @@ class OptimizeDatasetSerializationTest(
       dataset = dataset.batch(5)
       # map_vectorization adds a new vectorized function to the function
       # library.
-      dataset = dataset.apply(optimization.optimize(["map_vectorization"]))
+      dataset = dataset.apply(
+          optimization.optimize(["map_vectorization"], None, None))
       return dataset
 
     self.run_core_tests(build_dataset, 20)
@@ -36,13 +36,19 @@ def model():
   return _apply_fn
 
 
-def optimize(optimizations=None):
+def optimize(optimizations_enabled=None, optimizations_disabled=None,
+             optimizations_default=None):
   """A transformation that applies optimizations.
 
   Args:
-    optimizations: (Optional.) A `tf.string` vector `tf.Tensor` identifying
-      optimizations to use. If not specified, the default set of optimizations
-      is applied.
+    optimizations_enabled: (Optional.) A `tf.string` vector `tf.Tensor`
+      identifying enabled optimizations. If not specified, set to be empty.
+
+    optimizations_disabled: (Optional.) A `tf.string` vector `tf.Tensor`
+      identifying disabled optimizations. If not specified, set to be empty.
+
+    optimizations_default: (Optional.) A `tf.string` vector `tf.Tensor`
+      identifying default optimizations. If not specified, set to be empty.
 
   Returns:
     A `Dataset` transformation function, which can be passed to
@@ -51,7 +57,11 @@ def optimize(optimizations=None):
 
   def _apply_fn(dataset):
     """Function from `Dataset` to `Dataset` that applies the transformation."""
-    return dataset_ops._OptimizeDataset(dataset, optimizations)  # pylint: disable=protected-access
+    return dataset_ops._OptimizeDataset(  # pylint: disable=protected-access
+        dataset,
+        optimizations_enabled,
+        optimizations_disabled,
+        optimizations_default)
 
   return _apply_fn
@@ -53,9 +53,13 @@ class MapVectorizationOptions(options.OptionsBase):
       "defaults to False.")
 
   def _graph_rewrites(self):
-    if self.enabled:
-      return ["map_vectorization"]
-    return []
+    graph_rewrites = options.graph_rewrites()
+    result = graph_rewrites(enabled=[], disabled=[], default=[])
+    if self.enabled is True:  # pylint: disable=g-bool-id-comparison
+      result.enabled.append("map_vectorization")
+    elif self.enabled is False:  # pylint: disable=g-bool-id-comparison
+      result.disabled.append("map_vectorization")
+    return result
 
   def _graph_rewrite_configs(self):
     if not self.enabled:
@@ -229,8 +233,20 @@ class OptimizationOptions(options.OptionsBase):
     return autotune, algorithm, cpu_budget
 
   def _graph_rewrites(self):
-    """Produces the list of enabled graph optimizations."""
-    result = set()
+    """Produces lists of enabled, disabled and default graph optimizations.
+
+    Returns:
+      result: a namedtuple with three attributes. `result.enabled` is the list
+        of user enabled optimizations. `result.disabled` is the list of user
+        disabled optimizations. `result.default` is the list of optimizations
+        that are enabled by default (the user has not explicitly enabled or
+        disabled them).
+    """
+    if self.map_vectorization is not None:
+      result = self.map_vectorization._graph_rewrites()  # pylint: disable=protected-access
+    else:
+      result = MapVectorizationOptions()._graph_rewrites()  # pylint: disable=protected-access
+
     all_optimizations = [
         "filter_fusion",
         "filter_with_random_uniform_fusion",
@@ -244,11 +260,8 @@ class OptimizationOptions(options.OptionsBase):
         "reorder_data_discarding_ops",
         "shuffle_and_repeat_fusion",
     ]
-    for optimization in all_optimizations:
-      if getattr(self, optimization):
-        result.add(optimization)
 
-    if self.apply_default_optimizations is not False:
+    if self.apply_default_optimizations is not False:  # pylint: disable=g-bool-id-comparison
       # The following optimizations are turned on by default, unless the user
      # explicitly disables them.
       optimizations_to_disable = [
@@ -257,21 +270,29 @@ class OptimizationOptions(options.OptionsBase):
         "shuffle_and_repeat_fusion",
     ]
     for optimization in optimizations_to_disable:
-      if getattr(self, optimization) is not False:
-        result.add(optimization)
+      if getattr(self, optimization) is None:
+        result.default.append(optimization)
 
-    if self.map_vectorization is not None:
-      result.update(self.map_vectorization._graph_rewrites())  # pylint: disable=protected-access
+    # Each of these attributes on the Options object is either True (explicitly
+    # enabled), False (explicitly disabled), or None (default).
+    for optimization in all_optimizations:
+      if getattr(self, optimization) is True:  # pylint: disable=g-bool-id-comparison
+        result.enabled.append(optimization)
+      elif getattr(self, optimization) is False:  # pylint: disable=g-bool-id-comparison
+        result.disabled.append(optimization)
 
     autotune_buffers = self._autotune_buffers()
-    if self.autotune is not False and autotune_buffers:  # pylint: disable=g-bool-id-comparison
+    if self.autotune is not False and autotune_buffers is True:  # pylint: disable=g-bool-id-comparison
       # When autotuning buffer sizes is enabled, we inject a `prefetch`
       # transformation after asynchronous dataset ops. Only the buffer sizes of
       # prefetch transformations will be autotuned, though this is practically
       # equivalent to tuning the buffer sizes of the other asynchronous
       # transformations.
-      result.add("inject_prefetch")
-    return sorted(list(result))
+      result.enabled.append("inject_prefetch")
+    if self.autotune is False:  # pylint: disable=g-bool-id-comparison
+      result.disabled.append("inject_prefetch")
+
+    return result
 
   def _graph_rewrite_configs(self):
     if self.map_vectorization is not None:
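The bucketing rule implemented above: an option set to True lands in `enabled`, False in `disabled`, and a default-on rewrite left as None lands in `default`. For example (assumes TensorFlow at this commit; `_graph_rewrites` is the internal helper exercised by the tests in this diff):

    from tensorflow.python.data.ops import dataset_ops

    options = dataset_ops.Options()
    options.experimental_optimization.map_and_batch_fusion = True   # -> enabled
    options.experimental_optimization.noop_elimination = False      # -> disabled
    # shuffle_and_repeat_fusion is left as None                     # -> default
    rewrites = options._graph_rewrites()  # pylint: disable=protected-access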
@@ -30,11 +30,13 @@ from six.moves import queue as Queue  # pylint: disable=redefined-builtin
 
 from tensorflow.core.framework import graph_pb2
 from tensorflow.python import tf2
+from tensorflow.python.compat import compat
 from tensorflow.python.data.experimental.ops import distribute_options
 from tensorflow.python.data.experimental.ops import optimization_options
 from tensorflow.python.data.experimental.ops import stats_options
 from tensorflow.python.data.experimental.ops import threading_options
 from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.util import convert
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import options as options_lib
 from tensorflow.python.data.util import random_seed
@@ -374,16 +376,18 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
     graph_rewrites = options._graph_rewrites()
     graph_rewrite_configs = options._graph_rewrite_configs()
     # pylint: enable=protected-access
-    if graph_rewrites:
+    if graph_rewrites.enabled or graph_rewrites.default:
       if self._has_captured_ref():
         warnings.warn(
             "tf.data graph rewrites are not compatible with tf.Variable. "
             "The following rewrites will be disabled: %s. To enable "
             "rewrites, use resource variables instead by calling "
             "`tf.enable_resource_variables()` at the start of the program." %
-            ", ".join(graph_rewrites))
+            ", ".join(graph_rewrites.enabled + graph_rewrites.default))
       else:
-        dataset = _OptimizeDataset(dataset, graph_rewrites,
+        dataset = _OptimizeDataset(dataset, graph_rewrites.enabled,
+                                   graph_rewrites.disabled,
+                                   graph_rewrites.default,
                                    graph_rewrite_configs)
 
     # (3) Apply autotune options
@@ -2887,22 +2891,39 @@ class Options(options_lib.OptionsBase):
           "is being captured.")
 
   def _graph_rewrites(self):
-    """Produces the list of enabled static graph rewrites."""
-    result = []
+    """Produces lists of enabled, disabled, default static graph rewrites.
+
+    Returns:
+      result: a namedtuple with three attributes. `result.enabled` is the list
+        of user enabled graph rewrites. `result.disabled` is the list of user
+        disabled graph rewrites. `result.default` is the list of graph
+        rewrites that are enabled by default (the user has not explicitly
+        enabled or disabled them).
+    """
     if self.experimental_optimization is not None:
-      result.extend(self.experimental_optimization._graph_rewrites())  # pylint: disable=protected-access
+      result = self.experimental_optimization._graph_rewrites()  # pylint: disable=protected-access
     else:
       # Apply default options
-      result.extend(
-          optimization_options.OptimizationOptions()._graph_rewrites())  # pylint: disable=protected-access
+      result = optimization_options.OptimizationOptions()._graph_rewrites()  # pylint: disable=protected-access
 
     if self.experimental_deterministic is False:  # pylint: disable=g-bool-id-comparison
-      result.append("make_sloppy")
-    if self.experimental_stats and self.experimental_stats.latency_all_edges:
-      result.append("latency_all_edges")
-    if self.experimental_slack:
-      result.append("slack")
-    return result
+      result.enabled.append("make_sloppy")
+    elif self.experimental_deterministic is True:  # pylint: disable=g-bool-id-comparison
+      result.disabled.append("make_sloppy")
+    if self.experimental_stats:
+      if self.experimental_stats.latency_all_edges is True:  # pylint: disable=g-bool-id-comparison
+        result.enabled.append("latency_all_edges")
+      elif self.experimental_stats.latency_all_edges is False:  # pylint: disable=g-bool-id-comparison
+        result.disabled.append("latency_all_edges")
+    if self.experimental_slack is True:  # pylint: disable=g-bool-id-comparison
+      result.enabled.append("slack")
+    elif self.experimental_slack is False:  # pylint: disable=g-bool-id-comparison
+      result.disabled.append("slack")
+
+    graph_rewrites = options_lib.graph_rewrites()
+    return graph_rewrites(enabled=list(set(result.enabled)),
+                          disabled=list(set(result.disabled)),
+                          default=list(set(result.default)))
 
   def _graph_rewrite_configs(self):
     """Produces the list of configurations for enabled graph optimizations."""
@@ -4387,19 +4408,55 @@ class _ModelDataset(UnaryUnchangedStructureDataset):
 class _OptimizeDataset(UnaryUnchangedStructureDataset):
   """A `Dataset` that acts as an identity, and applies optimizations."""
 
-  def __init__(self, input_dataset, optimizations, optimization_configs=None):
+  def __init__(self,
+               input_dataset,
+               optimizations_enabled,
+               optimizations_disabled,
+               optimizations_default,
+               optimization_configs=None):
     self._input_dataset = input_dataset
-    if optimizations is None:
-      optimizations = []
     if optimization_configs is None:
       optimization_configs = []
-    self._optimizations = ops.convert_to_tensor(
-        optimizations, dtype=dtypes.string, name="optimizations")
-    variant_tensor = gen_dataset_ops.optimize_dataset(
-        input_dataset._variant_tensor,  # pylint: disable=protected-access
-        self._optimizations,
-        optimization_configs=optimization_configs,
-        **self._flat_structure)
+
+    if compat.forward_compatible(2020, 8, 6):
+      self._optimizations_enabled = convert.optional_param_to_tensor(
+          argument_name="optimizations_enabled",
+          argument_value=optimizations_enabled,
+          argument_default=[],
+          argument_dtype=dtypes.string)
+      self._optimizations_disabled = convert.optional_param_to_tensor(
+          argument_name="optimizations_disabled",
+          argument_value=optimizations_disabled,
+          argument_default=[],
+          argument_dtype=dtypes.string)
+      self._optimizations_default = convert.optional_param_to_tensor(
+          argument_name="optimizations_default",
+          argument_value=optimizations_default,
+          argument_default=[],
+          argument_dtype=dtypes.string)
+
+      variant_tensor = gen_dataset_ops.optimize_dataset_v2(
+          input_dataset._variant_tensor,  # pylint: disable=protected-access
+          self._optimizations_enabled,
+          self._optimizations_disabled,
+          self._optimizations_default,
+          optimization_configs=optimization_configs,
+          **self._flat_structure)
+    else:
+      if optimizations_enabled is None:
+        optimizations_enabled = []
+      if optimizations_default is None:
+        optimizations_default = []
+
+      self._optimizations = ops.convert_to_tensor(
+          optimizations_enabled + optimizations_default,
+          dtype=dtypes.string,
+          name="optimizations")
+      variant_tensor = gen_dataset_ops.optimize_dataset(
+          input_dataset._variant_tensor,  # pylint: disable=protected-access
+          self._optimizations,
+          optimization_configs=optimization_configs,
+          **self._flat_structure)
     super(_OptimizeDataset, self).__init__(input_dataset, variant_tensor)
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
+
 
 def _internal_attr_name(name):
   return "_" + name
@@ -56,6 +58,12 @@ class OptionsBase(object):
         "Cannot set the property %s on %s." % (name, type(self).__name__))
 
 
+# Creates a namedtuple with three keys for optimization graph rewrites settings.
+def graph_rewrites():
+  return collections.namedtuple("GraphRewrites",
+                                ["enabled", "disabled", "default"])
+
+
 def create_option(name, ty, docstring, default_factory=lambda: None):
   """Creates a type-checked property.
 
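Note that graph_rewrites() returns a namedtuple class, so each call site constructs its own instances. A minimal sketch of how it is consumed elsewhere in this commit (the example rewrite names are illustrative):

    import collections

    def graph_rewrites():
      return collections.namedtuple("GraphRewrites",
                                    ["enabled", "disabled", "default"])

    GraphRewrites = graph_rewrites()
    result = GraphRewrites(enabled=["make_sloppy"], disabled=["slack"],
                           default=["noop_elimination"])
    # Rewrites actually applied downstream: user-enabled plus defaults.
    applied = result.enabled + result.default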
@@ -2660,6 +2660,10 @@ tf_module {
     name: "OptimizeDataset"
     argspec: "args=[\'input_dataset\', \'optimizations\', \'output_types\', \'output_shapes\', \'optimization_configs\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'None\'], "
   }
+  member_method {
+    name: "OptimizeDatasetV2"
+    argspec: "args=[\'input_dataset\', \'optimizations_enabled\', \'optimizations_disabled\', \'optimizations_default\', \'output_types\', \'output_shapes\', \'optimization_configs\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'None\'], "
+  }
   member_method {
     name: "OptionalFromValue"
     argspec: "args=[\'components\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -2660,6 +2660,10 @@ tf_module {
     name: "OptimizeDataset"
     argspec: "args=[\'input_dataset\', \'optimizations\', \'output_types\', \'output_shapes\', \'optimization_configs\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'None\'], "
   }
+  member_method {
+    name: "OptimizeDatasetV2"
+    argspec: "args=[\'input_dataset\', \'optimizations_enabled\', \'optimizations_disabled\', \'optimizations_default\', \'output_types\', \'output_shapes\', \'optimization_configs\', \'name\'], varargs=None, keywords=None, defaults=[\'[]\', \'None\'], "
+  }
 member_method {
     name: "OptionalFromValue"
     argspec: "args=[\'components\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "