From 0e2e8460ebfb9451c24a44617e836c9101ab2a62 Mon Sep 17 00:00:00 2001 From: Jay Shi Date: Mon, 14 Sep 2020 15:37:13 -0700 Subject: [PATCH] [tf.data] Fix the problem that the current `AsGraphDefInternal` function in `ModelDatasetOp` does not capture all attrs. PiperOrigin-RevId: 331643560 Change-Id: I221b7a5510ee78ffb7ccfbed051d398ca056f4eb --- .../core/kernels/data/model_dataset_op.cc | 50 +++++++++++++------ 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc index 5e15dc81f25..37b7594564c 100644 --- a/tensorflow/core/kernels/data/model_dataset_op.cc +++ b/tensorflow/core/kernels/data/model_dataset_op.cc @@ -34,31 +34,29 @@ constexpr double kRamBudgetShare = 0.5; class ModelDatasetOp : public UnaryDatasetOpKernel { public: + static constexpr const char* const kAlgorithm = "algorithm"; + static constexpr const char* const kCpuBudget = "cpu_budget"; + static constexpr const char* const kRamBudget = "ram_budget"; + explicit ModelDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { - if (ctx->HasAttr("algorithm")) { + if (ctx->HasAttr(kAlgorithm)) { int64 algorithm; - OP_REQUIRES_OK(ctx, ctx->GetAttr("algorithm", &algorithm)); + OP_REQUIRES_OK(ctx, ctx->GetAttr(kAlgorithm, &algorithm)); algorithm_ = model::AutotuneAlgorithm(algorithm); } else { algorithm_ = model::AutotuneAlgorithm::HILL_CLIMB; } - OP_REQUIRES_OK(ctx, ctx->GetAttr("cpu_budget", &cpu_budget_)); - if (cpu_budget_ == 0) { - cpu_budget_ = port::NumSchedulableCPUs(); - } - OP_REQUIRES(ctx, cpu_budget_ > 0, + OP_REQUIRES_OK(ctx, ctx->GetAttr(kCpuBudget, &cpu_budget_)); + OP_REQUIRES(ctx, cpu_budget_ >= 0, errors::InvalidArgument("CPU budget must be positive but is ", cpu_budget_, ".")); - if (ctx->HasAttr("ram_budget")) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("ram_budget", &ram_budget_)); + if (ctx->HasAttr(kRamBudget)) { + OP_REQUIRES_OK(ctx, ctx->GetAttr(kRamBudget, &ram_budget_)); } 
else { ram_budget_ = 0; } - if (ram_budget_ == 0) { - ram_budget_ = kRamBudgetShare * port::AvailableRam(); - } - OP_REQUIRES(ctx, ram_budget_ > 0, + OP_REQUIRES(ctx, ram_budget_ >= 0, errors::InvalidArgument("RAM budget must be positive but is ", ram_budget_, ".")); } @@ -112,6 +110,19 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output)); + AttrValue algorithm_attr; + b->BuildAttrValue(static_cast(algorithm_), &algorithm_attr); + AttrValue cpu_budget_attr; + b->BuildAttrValue(cpu_budget_, &cpu_budget_attr); + AttrValue ram_budget_attr; + b->BuildAttrValue(ram_budget_, &ram_budget_attr); + + TF_RETURN_IF_ERROR( + b->AddDataset(this, {input_graph_node}, + {std::make_pair(kAlgorithm, algorithm_attr), + std::make_pair(kCpuBudget, cpu_budget_attr), + std::make_pair(kRamBudget, ram_budget_attr)}, + output)); return Status::OK(); } @@ -119,7 +130,12 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params) { + : DatasetIterator(params), + cpu_budget_(dataset()->cpu_budget_ == 0 ? port::NumSchedulableCPUs() + : dataset()->cpu_budget_), + ram_budget_(dataset()->ram_budget_ == 0 + ? kRamBudgetShare * port::AvailableRam() + : dataset()->ram_budget_) { model_ = std::make_shared(); } @@ -213,8 +229,8 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { tf_shared_lock l(mu_); model_input_time = SelfInputTime(); } - model_->Optimize(dataset()->algorithm_, dataset()->cpu_budget_, - dataset()->ram_budget_, /*model_input_time=*/0); + model_->Optimize(dataset()->algorithm_, cpu_budget_, ram_budget_, + /*model_input_time=*/0); // Exponentially increase the period of running the optimization // until a threshold is reached. 
if (optimization_period_ms != kOptimizationPeriodThresholdMs) { @@ -256,6 +272,8 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { int64 num_input_events_ TF_GUARDED_BY(mu_) = 0; int64 input_time_ TF_GUARDED_BY(mu_) = 0; int64 last_output_time_ TF_GUARDED_BY(mu_) = 0; + const int64 cpu_budget_; + const int64 ram_budget_; }; const DatasetBase* input_;