[tf.data] Fix the problem for the current AsGraphDefInternal function in ModelDatasetOp is not capturing all attrs.

PiperOrigin-RevId: 331643560
Change-Id: I221b7a5510ee78ffb7ccfbed051d398ca056f4eb
This commit is contained in:
Jay Shi 2020-09-14 15:37:13 -07:00 committed by TensorFlower Gardener
parent 269b937042
commit 0e2e8460eb

View File

@ -34,31 +34,29 @@ constexpr double kRamBudgetShare = 0.5;
class ModelDatasetOp : public UnaryDatasetOpKernel {
public:
static constexpr const char* const kAlgorithm = "algorithm";
static constexpr const char* const kCpuBudget = "cpu_budget";
static constexpr const char* const kRamBudget = "ram_budget";
explicit ModelDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {
if (ctx->HasAttr("algorithm")) {
if (ctx->HasAttr(kAlgorithm)) {
int64 algorithm;
OP_REQUIRES_OK(ctx, ctx->GetAttr("algorithm", &algorithm));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kAlgorithm, &algorithm));
algorithm_ = model::AutotuneAlgorithm(algorithm);
} else {
algorithm_ = model::AutotuneAlgorithm::HILL_CLIMB;
}
OP_REQUIRES_OK(ctx, ctx->GetAttr("cpu_budget", &cpu_budget_));
if (cpu_budget_ == 0) {
cpu_budget_ = port::NumSchedulableCPUs();
}
OP_REQUIRES(ctx, cpu_budget_ > 0,
OP_REQUIRES_OK(ctx, ctx->GetAttr(kCpuBudget, &cpu_budget_));
OP_REQUIRES(ctx, cpu_budget_ >= 0,
errors::InvalidArgument("CPU budget must be positive but is ",
cpu_budget_, "."));
if (ctx->HasAttr("ram_budget")) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("ram_budget", &ram_budget_));
if (ctx->HasAttr(kRamBudget)) {
OP_REQUIRES_OK(ctx, ctx->GetAttr(kRamBudget, &ram_budget_));
} else {
ram_budget_ = 0;
}
if (ram_budget_ == 0) {
ram_budget_ = kRamBudgetShare * port::AvailableRam();
}
OP_REQUIRES(ctx, ram_budget_ > 0,
OP_REQUIRES(ctx, ram_budget_ >= 0,
errors::InvalidArgument("RAM budget must be positive but is ",
ram_budget_, "."));
}
@ -112,6 +110,19 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
AttrValue algorithm_attr;
b->BuildAttrValue(static_cast<int64>(algorithm_), &algorithm_attr);
AttrValue cpu_budget_attr;
b->BuildAttrValue(cpu_budget_, &cpu_budget_attr);
AttrValue ram_budget_attr;
b->BuildAttrValue(ram_budget_, &ram_budget_attr);
TF_RETURN_IF_ERROR(
b->AddDataset(this, {input_graph_node},
{std::make_pair(kAlgorithm, algorithm_attr),
std::make_pair(kCpuBudget, cpu_budget_attr),
std::make_pair(kRamBudget, ram_budget_attr)},
output));
return Status::OK();
}
@ -119,7 +130,12 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {
: DatasetIterator<Dataset>(params),
cpu_budget_(dataset()->cpu_budget_ == 0 ? port::NumSchedulableCPUs()
: dataset()->cpu_budget_),
ram_budget_(dataset()->ram_budget_ == 0
? kRamBudgetShare * port::AvailableRam()
: dataset()->ram_budget_) {
model_ = std::make_shared<model::Model>();
}
@ -213,8 +229,8 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
tf_shared_lock l(mu_);
model_input_time = SelfInputTime();
}
model_->Optimize(dataset()->algorithm_, dataset()->cpu_budget_,
dataset()->ram_budget_, /*model_input_time=*/0);
model_->Optimize(dataset()->algorithm_, cpu_budget_, ram_budget_,
/*model_input_time=*/0);
// Exponentially increase the period of running the optimization
// until a threshold is reached.
if (optimization_period_ms != kOptimizationPeriodThresholdMs) {
@ -256,6 +272,8 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
int64 num_input_events_ TF_GUARDED_BY(mu_) = 0;
int64 input_time_ TF_GUARDED_BY(mu_) = 0;
int64 last_output_time_ TF_GUARDED_BY(mu_) = 0;
const int64 cpu_budget_;
const int64 ram_budget_;
};
const DatasetBase* input_;