[tf.data] Fix the problem for the current AsGraphDefInternal
function in ModelDatasetOp
is not capturing all attrs.
PiperOrigin-RevId: 331643560 Change-Id: I221b7a5510ee78ffb7ccfbed051d398ca056f4eb
This commit is contained in:
parent
269b937042
commit
0e2e8460eb
@ -34,31 +34,29 @@ constexpr double kRamBudgetShare = 0.5;
|
||||
|
||||
class ModelDatasetOp : public UnaryDatasetOpKernel {
|
||||
public:
|
||||
static constexpr const char* const kAlgorithm = "algorithm";
|
||||
static constexpr const char* const kCpuBudget = "cpu_budget";
|
||||
static constexpr const char* const kRamBudget = "ram_budget";
|
||||
|
||||
explicit ModelDatasetOp(OpKernelConstruction* ctx)
|
||||
: UnaryDatasetOpKernel(ctx) {
|
||||
if (ctx->HasAttr("algorithm")) {
|
||||
if (ctx->HasAttr(kAlgorithm)) {
|
||||
int64 algorithm;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("algorithm", &algorithm));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kAlgorithm, &algorithm));
|
||||
algorithm_ = model::AutotuneAlgorithm(algorithm);
|
||||
} else {
|
||||
algorithm_ = model::AutotuneAlgorithm::HILL_CLIMB;
|
||||
}
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("cpu_budget", &cpu_budget_));
|
||||
if (cpu_budget_ == 0) {
|
||||
cpu_budget_ = port::NumSchedulableCPUs();
|
||||
}
|
||||
OP_REQUIRES(ctx, cpu_budget_ > 0,
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kCpuBudget, &cpu_budget_));
|
||||
OP_REQUIRES(ctx, cpu_budget_ >= 0,
|
||||
errors::InvalidArgument("CPU budget must be positive but is ",
|
||||
cpu_budget_, "."));
|
||||
if (ctx->HasAttr("ram_budget")) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("ram_budget", &ram_budget_));
|
||||
if (ctx->HasAttr(kRamBudget)) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kRamBudget, &ram_budget_));
|
||||
} else {
|
||||
ram_budget_ = 0;
|
||||
}
|
||||
if (ram_budget_ == 0) {
|
||||
ram_budget_ = kRamBudgetShare * port::AvailableRam();
|
||||
}
|
||||
OP_REQUIRES(ctx, ram_budget_ > 0,
|
||||
OP_REQUIRES(ctx, ram_budget_ >= 0,
|
||||
errors::InvalidArgument("RAM budget must be positive but is ",
|
||||
ram_budget_, "."));
|
||||
}
|
||||
@ -112,6 +110,19 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
|
||||
Node* input_graph_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
|
||||
TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
|
||||
AttrValue algorithm_attr;
|
||||
b->BuildAttrValue(static_cast<int64>(algorithm_), &algorithm_attr);
|
||||
AttrValue cpu_budget_attr;
|
||||
b->BuildAttrValue(cpu_budget_, &cpu_budget_attr);
|
||||
AttrValue ram_budget_attr;
|
||||
b->BuildAttrValue(ram_budget_, &ram_budget_attr);
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddDataset(this, {input_graph_node},
|
||||
{std::make_pair(kAlgorithm, algorithm_attr),
|
||||
std::make_pair(kCpuBudget, cpu_budget_attr),
|
||||
std::make_pair(kRamBudget, ram_budget_attr)},
|
||||
output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -119,7 +130,12 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
|
||||
class Iterator : public DatasetIterator<Dataset> {
|
||||
public:
|
||||
explicit Iterator(const Params& params)
|
||||
: DatasetIterator<Dataset>(params) {
|
||||
: DatasetIterator<Dataset>(params),
|
||||
cpu_budget_(dataset()->cpu_budget_ == 0 ? port::NumSchedulableCPUs()
|
||||
: dataset()->cpu_budget_),
|
||||
ram_budget_(dataset()->ram_budget_ == 0
|
||||
? kRamBudgetShare * port::AvailableRam()
|
||||
: dataset()->ram_budget_) {
|
||||
model_ = std::make_shared<model::Model>();
|
||||
}
|
||||
|
||||
@ -213,8 +229,8 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
|
||||
tf_shared_lock l(mu_);
|
||||
model_input_time = SelfInputTime();
|
||||
}
|
||||
model_->Optimize(dataset()->algorithm_, dataset()->cpu_budget_,
|
||||
dataset()->ram_budget_, /*model_input_time=*/0);
|
||||
model_->Optimize(dataset()->algorithm_, cpu_budget_, ram_budget_,
|
||||
/*model_input_time=*/0);
|
||||
// Exponentially increase the period of running the optimization
|
||||
// until a threshold is reached.
|
||||
if (optimization_period_ms != kOptimizationPeriodThresholdMs) {
|
||||
@ -256,6 +272,8 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
|
||||
int64 num_input_events_ TF_GUARDED_BY(mu_) = 0;
|
||||
int64 input_time_ TF_GUARDED_BY(mu_) = 0;
|
||||
int64 last_output_time_ TF_GUARDED_BY(mu_) = 0;
|
||||
const int64 cpu_budget_;
|
||||
const int64 ram_budget_;
|
||||
};
|
||||
|
||||
const DatasetBase* input_;
|
||||
|
Loading…
Reference in New Issue
Block a user