diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index a384f4a0320..9f351edf11a 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -173,6 +173,7 @@ tf_kernel_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", + "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:graph_topology_view", "//tensorflow/core/grappler/utils:traversal", @@ -540,6 +541,7 @@ tf_cc_test( tf_kernel_library( name = "model_dataset_op", srcs = ["model_dataset_op.cc"], + hdrs = ["model_dataset_op.h"], deps = [ "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", @@ -1444,15 +1446,11 @@ filegroup( "*.h", ], exclude = [ - "dataset_ops*", # includes grappler dependency, which isn't supported on mobile. - "optimize_dataset_op.*", # includes grappler dependency, which isn't supported on mobile. - "model_dataset_op.*", # not supported on mobile. - "rewrite_utils*", # includes grappler dependency, which isn't supported on mobile. + "dataset_test_base.*", "*test.cc", - "*test.h", - "*_test_*", ], ), + visibility = ["//tensorflow:__subpackages__"], ) tf_kernel_library( diff --git a/tensorflow/core/kernels/data/dataset_ops.cc b/tensorflow/core/kernels/data/dataset_ops.cc index 7c701b7886a..597e2587e66 100644 --- a/tensorflow/core/kernels/data/dataset_ops.cc +++ b/tensorflow/core/kernels/data/dataset_ops.cc @@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #include "tensorflow/core/kernels/data/dataset_ops.h" +// On mobile we do not provide this functionality because not all of its +// dependencies are available there. +#if !defined(IS_MOBILE_PLATFORM) #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" @@ -168,3 +170,4 @@ REGISTER_KERNEL_BUILDER(Name("DatasetFromGraph").Device(DEVICE_CPU), } // namespace data } // namespace tensorflow +#endif // !IS_MOBILE_PLATFORM diff --git a/tensorflow/core/kernels/data/dataset_ops.h b/tensorflow/core/kernels/data/dataset_ops.h index 9895585f3de..576e018d320 100644 --- a/tensorflow/core/kernels/data/dataset_ops.h +++ b/tensorflow/core/kernels/data/dataset_ops.h @@ -12,10 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_OPS_H_ #define TENSORFLOW_CORE_KERNELS_DATA_DATASET_OPS_H_ +#include "tensorflow/core/platform/platform.h" + +// On mobile we do not provide this functionality because not all of its +// dependencies are available there. +#if !defined(IS_MOBILE_PLATFORM) #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op_kernel.h" @@ -61,5 +65,6 @@ class DatasetFromGraphOp : public OpKernel { } // namespace data } // namespace tensorflow +#endif // !IS_MOBILE_PLATFORM #endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_OPS_H_ diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc index d325a3dcf66..f790b4bf07f 100644 --- a/tensorflow/core/kernels/data/model_dataset_op.cc +++ b/tensorflow/core/kernels/data/model_dataset_op.cc @@ -12,7 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/kernels/data/model_dataset_op.h" +// On mobile we do not provide model dataset op because not all of its +// dependencies are available there. The op is replaced with a no-op. +#if !defined(IS_MOBILE_PLATFORM) #include "absl/memory/memory.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/metrics.h" @@ -32,269 +36,283 @@ constexpr int64 kOptimizationPeriodThresholdMs = 60 * EnvTime::kSecondsToMillis; // Default share of available RAM that can be used by model's internal buffers. constexpr double kRamBudgetShare = 0.5; -class ModelDatasetOp : public UnaryDatasetOpKernel { - public: - static constexpr const char* const kAlgorithm = "algorithm"; - static constexpr const char* const kCpuBudget = "cpu_budget"; - static constexpr const char* const kRamBudget = "ram_budget"; +} // namespace - explicit ModelDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx) { - if (ctx->HasAttr(kAlgorithm)) { - int64 algorithm; - OP_REQUIRES_OK(ctx, ctx->GetAttr(kAlgorithm, &algorithm)); - algorithm_ = model::AutotuneAlgorithm(algorithm); - } else { - algorithm_ = model::AutotuneAlgorithm::HILL_CLIMB; - } - OP_REQUIRES_OK(ctx, ctx->GetAttr(kCpuBudget, &cpu_budget_)); - OP_REQUIRES(ctx, cpu_budget_ >= 0, - errors::InvalidArgument("CPU budget must be positive but is ", - cpu_budget_, ".")); - if (ctx->HasAttr(kRamBudget)) { - OP_REQUIRES_OK(ctx, ctx->GetAttr(kRamBudget, &ram_budget_)); - } else { - ram_budget_ = 0; - } - OP_REQUIRES(ctx, ram_budget_ >= 0, - errors::InvalidArgument("RAM budget must be positive but is ", - ram_budget_, ".")); +/* static */ constexpr const char* const ModelDatasetOp::kAlgorithm; +/* static */ constexpr const char* const ModelDatasetOp::kCpuBudget; +/* static */ constexpr const char* const ModelDatasetOp::kRamBudget; + +class ModelDatasetOp::Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, + model::AutotuneAlgorithm algorithm, int64 cpu_budget, + int64 ram_budget) + : DatasetBase(DatasetContext(ctx)), + input_(input), + algorithm_(algorithm), + cpu_budget_(cpu_budget), + ram_budget_(ram_budget) { + input_->Ref(); } - void MakeDataset(OpKernelContext* ctx, DatasetBase* input, - DatasetBase** output) override { - *output = new Dataset(ctx, input, algorithm_, cpu_budget_, ram_budget_); + ~Dataset() override { input_->Unref(); } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return absl::make_unique( + Iterator::Params{this, strings::StrCat(prefix, "::Model")}); + } + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + const std::vector& output_shapes() const override { + return input_->output_shapes(); + } + + string DebugString() const override { return "ModelDatasetOp::Dataset"; } + + int64 Cardinality() const override { return input_->Cardinality(); } + + Status InputDatasets(std::vector* inputs) const override { + inputs->push_back(input_); + return Status::OK(); + } + + Status CheckExternalState() const override { + return input_->CheckExternalState(); + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output)); + AttrValue algorithm_attr; + b->BuildAttrValue(static_cast(algorithm_), &algorithm_attr); + AttrValue cpu_budget_attr; + b->BuildAttrValue(cpu_budget_, &cpu_budget_attr); + AttrValue ram_budget_attr; + b->BuildAttrValue(ram_budget_, &ram_budget_attr); + + TF_RETURN_IF_ERROR( + b->AddDataset(this, {input_graph_node}, + {std::make_pair(kAlgorithm, algorithm_attr), + std::make_pair(kCpuBudget, cpu_budget_attr), + std::make_pair(kRamBudget, ram_budget_attr)}, + output)); + return Status::OK(); } private: - class Dataset : public DatasetBase { + class Iterator : public DatasetIterator { public: - Dataset(OpKernelContext* ctx, const DatasetBase* input, - model::AutotuneAlgorithm algorithm, int64 cpu_budget, - int64 ram_budget) - : DatasetBase(DatasetContext(ctx)), - input_(input), - algorithm_(algorithm), - cpu_budget_(cpu_budget), - ram_budget_(ram_budget) { - input_->Ref(); + explicit Iterator(const Params& params) + : DatasetIterator(params), + cpu_budget_(dataset()->cpu_budget_ == 0 ? port::NumSchedulableCPUs() + : dataset()->cpu_budget_), + ram_budget_(dataset()->ram_budget_ == 0 + ? kRamBudgetShare * port::AvailableRam() + : dataset()->ram_budget_) { + model_ = std::make_shared(); } - ~Dataset() override { input_->Unref(); } - - std::unique_ptr MakeIteratorInternal( - const string& prefix) const override { - return absl::make_unique( - Iterator::Params{this, strings::StrCat(prefix, "::Model")}); + ~Iterator() override { + // Signal the optimize thread to terminate it. We will then join that + // thread when we delete `this->optimize_thread_`. + mutex_lock l(mu_); + cancelled_ = true; + cond_var_.notify_all(); } - const DataTypeVector& output_dtypes() const override { - return input_->output_dtypes(); - } - const std::vector& output_shapes() const override { - return input_->output_shapes(); + Status Initialize(IteratorContext* ctx) override { + IteratorContext::Params params(ctx); + params.model = model_; + return dataset()->input_->MakeIterator(IteratorContext(std::move(params)), + this, prefix(), &input_impl_); } - string DebugString() const override { return "ModelDatasetOp::Dataset"; } - - int64 Cardinality() const override { return input_->Cardinality(); } - - Status InputDatasets( - std::vector* inputs) const override { - inputs->push_back(input_); - return Status::OK(); - } - - Status CheckExternalState() const override { - return input_->CheckExternalState(); + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + IteratorContext::Params params(ctx); + { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(EnsureOptimizeThreadStarted(ctx)); + params.model = model_; + int64 now_nanos = EnvTime::NowNanos(); + RecordInput(now_nanos); + } + Status s = input_impl_->GetNext(IteratorContext(std::move(params)), + out_tensors, end_of_sequence); + int64 now_nanos = EnvTime::NowNanos(); + mutex_lock l(mu_); + RecordOutput(now_nanos); + return s; } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); - TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output)); - AttrValue algorithm_attr; - b->BuildAttrValue(static_cast(algorithm_), &algorithm_attr); - AttrValue cpu_budget_attr; - b->BuildAttrValue(cpu_budget_, &cpu_budget_attr); - AttrValue ram_budget_attr; - b->BuildAttrValue(ram_budget_, &ram_budget_attr); + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeKnownRatioNode(std::move(args), + /*ratio=*/1); + } - TF_RETURN_IF_ERROR( - b->AddDataset(this, {input_graph_node}, - {std::make_pair(kAlgorithm, algorithm_attr), - std::make_pair(kCpuBudget, cpu_budget_attr), - std::make_pair(kRamBudget, ram_budget_attr)}, - output)); + Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); return Status::OK(); } private: - class Iterator : public DatasetIterator { - public: - explicit Iterator(const Params& params) - : DatasetIterator(params), - cpu_budget_(dataset()->cpu_budget_ == 0 ? port::NumSchedulableCPUs() - : dataset()->cpu_budget_), - ram_budget_(dataset()->ram_budget_ == 0 - ? kRamBudgetShare * port::AvailableRam() - : dataset()->ram_budget_) { - model_ = std::make_shared(); + Status EnsureOptimizeThreadStarted(IteratorContext* ctx) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!model_thread_) { + std::shared_ptr new_ctx = + std::make_shared(*ctx); + model_thread_ = ctx->StartThread( + "tf_data_model", [this, new_ctx]() { ModelThread(new_ctx); }); } + return Status::OK(); + } - ~Iterator() override { - // Signal the optimize thread to terminate it. We will then join that - // thread when we delete `this->optimize_thread_`. - mutex_lock l(mu_); - cancelled_ = true; - cond_var_.notify_all(); - } - - Status Initialize(IteratorContext* ctx) override { - IteratorContext::Params params(ctx); - params.model = model_; - return dataset()->input_->MakeIterator( - IteratorContext(std::move(params)), this, prefix(), &input_impl_); - } - - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - IteratorContext::Params params(ctx); + void ModelThread(const std::shared_ptr& ctx) { + int64 last_optimization_ms = 0; + int64 optimization_period_ms = 10; + int64 current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros; + while (true) { { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(EnsureOptimizeThreadStarted(ctx)); - params.model = model_; - int64 now_nanos = EnvTime::NowNanos(); - RecordInput(now_nanos); - } - Status s = input_impl_->GetNext(IteratorContext(std::move(params)), - out_tensors, end_of_sequence); - int64 now_nanos = EnvTime::NowNanos(); - mutex_lock l(mu_); - RecordOutput(now_nanos); - return s; - } - - protected: - std::shared_ptr CreateNode( - IteratorContext* ctx, model::Node::Args args) const override { - return model::MakeKnownRatioNode(std::move(args), - /*ratio=*/1); - } - - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); - return Status::OK(); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - return Status::OK(); - } - - private: - Status EnsureOptimizeThreadStarted(IteratorContext* ctx) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (!model_thread_) { - std::shared_ptr new_ctx = - std::make_shared(*ctx); - model_thread_ = ctx->StartThread( - "tf_data_model", [this, new_ctx]() { ModelThread(new_ctx); }); - } - return Status::OK(); - } - - void ModelThread(const std::shared_ptr& ctx) { - int64 last_optimization_ms = 0; - int64 optimization_period_ms = 10; - int64 current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros; - while (true) { - { - mutex_lock l(mu_); - while (!cancelled_ && - last_optimization_ms + optimization_period_ms > - current_time_ms) { - auto wait_ms = last_optimization_ms + optimization_period_ms - - current_time_ms; - VLOG(2) << "Waiting for " << wait_ms << " ms."; - cond_var_.wait_for(l, std::chrono::milliseconds(wait_ms)); - current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros; - } - if (cancelled_) return; + while (!cancelled_ && last_optimization_ms + optimization_period_ms > + current_time_ms) { + auto wait_ms = + last_optimization_ms + optimization_period_ms - current_time_ms; + VLOG(2) << "Waiting for " << wait_ms << " ms."; + cond_var_.wait_for(l, std::chrono::milliseconds(wait_ms)); + current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros; } - double model_input_time; - { - tf_shared_lock l(mu_); - model_input_time = SelfInputTime(); - } - model_->Optimize(dataset()->algorithm_, cpu_budget_, ram_budget_, - /*model_input_time=*/0); - // Exponentially increase the period of running the optimization - // until a threshold is reached. - if (optimization_period_ms != kOptimizationPeriodThresholdMs) { - optimization_period_ms = std::min(optimization_period_ms << 1, - kOptimizationPeriodThresholdMs); - } - current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros; - last_optimization_ms = current_time_ms; - model_->FlushMetrics(); + if (cancelled_) return; } - } - - void RecordInput(int64 time_nanos) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (last_output_time_ != 0) { - DCHECK_LE(last_output_time_, time_nanos); - input_time_ += time_nanos - last_output_time_; - num_input_events_++; + double model_input_time; + { + tf_shared_lock l(mu_); + model_input_time = SelfInputTime(); } - } - - void RecordOutput(int64 time_nanos) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - last_output_time_ = time_nanos; - } - - double SelfInputTime() const TF_SHARED_LOCKS_REQUIRED(mu_) { - if (num_input_events_ == 0) { - return 0; + model_->Optimize(dataset()->algorithm_, cpu_budget_, ram_budget_, + /*model_input_time=*/0); + // Exponentially increase the period of running the optimization + // until a threshold is reached. + if (optimization_period_ms != kOptimizationPeriodThresholdMs) { + optimization_period_ms = std::min(optimization_period_ms << 1, + kOptimizationPeriodThresholdMs); } - return static_cast(input_time_) / - static_cast(num_input_events_); + current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros; + last_optimization_ms = current_time_ms; + model_->FlushMetrics(); } + } - mutex mu_; - condition_variable cond_var_; - std::shared_ptr model_; - std::unique_ptr model_thread_ TF_GUARDED_BY(mu_); - bool cancelled_ TF_GUARDED_BY(mu_) = false; - std::unique_ptr input_impl_; - int64 num_input_events_ TF_GUARDED_BY(mu_) = 0; - int64 input_time_ TF_GUARDED_BY(mu_) = 0; - int64 last_output_time_ TF_GUARDED_BY(mu_) = 0; - const int64 cpu_budget_; - const int64 ram_budget_; - }; + void RecordInput(int64 time_nanos) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (last_output_time_ != 0) { + DCHECK_LE(last_output_time_, time_nanos); + input_time_ += time_nanos - last_output_time_; + num_input_events_++; + } + } - const DatasetBase* input_; - const model::AutotuneAlgorithm algorithm_; + void RecordOutput(int64 time_nanos) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + last_output_time_ = time_nanos; + } + + double SelfInputTime() const TF_SHARED_LOCKS_REQUIRED(mu_) { + if (num_input_events_ == 0) { + return 0; + } + return static_cast(input_time_) / + static_cast(num_input_events_); + } + + mutex mu_; + condition_variable cond_var_; + std::shared_ptr model_; + std::unique_ptr model_thread_ TF_GUARDED_BY(mu_); + bool cancelled_ TF_GUARDED_BY(mu_) = false; + std::unique_ptr input_impl_; + int64 num_input_events_ TF_GUARDED_BY(mu_) = 0; + int64 input_time_ TF_GUARDED_BY(mu_) = 0; + int64 last_output_time_ TF_GUARDED_BY(mu_) = 0; const int64 cpu_budget_; const int64 ram_budget_; }; - model::AutotuneAlgorithm algorithm_; - int64 cpu_budget_; - int64 ram_budget_; + const DatasetBase* input_; + const model::AutotuneAlgorithm algorithm_; + const int64 cpu_budget_; + const int64 ram_budget_; }; +ModelDatasetOp::ModelDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) { + if (ctx->HasAttr(kAlgorithm)) { + int64 algorithm; + OP_REQUIRES_OK(ctx, ctx->GetAttr(kAlgorithm, &algorithm)); + algorithm_ = model::AutotuneAlgorithm(algorithm); + } else { + algorithm_ = model::AutotuneAlgorithm::HILL_CLIMB; + } + OP_REQUIRES_OK(ctx, ctx->GetAttr(kCpuBudget, &cpu_budget_)); + OP_REQUIRES(ctx, cpu_budget_ >= 0, + errors::InvalidArgument("CPU budget must be positive but is ", + cpu_budget_, ".")); + if (ctx->HasAttr(kRamBudget)) { + OP_REQUIRES_OK(ctx, ctx->GetAttr(kRamBudget, &ram_budget_)); + } else { + ram_budget_ = 0; + } + OP_REQUIRES(ctx, ram_budget_ >= 0, + errors::InvalidArgument("RAM budget must be positive but is ", + ram_budget_, ".")); +} + +void ModelDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) { + *output = new ModelDatasetOp::Dataset(ctx, input, algorithm_, cpu_budget_, + ram_budget_); +} + +namespace { REGISTER_KERNEL_BUILDER(Name("ModelDataset").Device(DEVICE_CPU), ModelDatasetOp); } // namespace } // namespace data } // namespace tensorflow +#else // !IS_MOBILE_PLATFORM +namespace tensorflow { +namespace data { + +ModelDatasetOp::ModelDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + +void ModelDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) { + input->Ref(); + *output = input; +} + +namespace { +REGISTER_KERNEL_BUILDER(Name("ModelDataset").Device(DEVICE_CPU), + ModelDatasetOp); +} // namespace +} // namespace data +} // namespace tensorflow +#endif // !IS_MOBILE_PLATFORM diff --git a/tensorflow/core/kernels/data/model_dataset_op.h b/tensorflow/core/kernels/data/model_dataset_op.h new file mode 100644 index 00000000000..09935e36586 --- /dev/null +++ b/tensorflow/core/kernels/data/model_dataset_op.h @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_MODEL_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_MODEL_DATASET_OP_H_ + +#include "tensorflow/core/platform/platform.h" + +// On mobile we do not provide model dataset op because not all of its +// dependencies are available there. The op is replaced with a no-op. +#if !defined(IS_MOBILE_PLATFORM) +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/model.h" + +namespace tensorflow { +namespace data { + +class ModelDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kAlgorithm = "algorithm"; + static constexpr const char* const kCpuBudget = "cpu_budget"; + static constexpr const char* const kRamBudget = "ram_budget"; + + explicit ModelDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + + model::AutotuneAlgorithm algorithm_; + int64 cpu_budget_; + int64 ram_budget_; +}; + +} // namespace data +} // namespace tensorflow +#else // !IS_MOBILE_PLATFORM +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class ModelDatasetOp : public UnaryDatasetOpKernel { + public: + explicit ModelDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; +}; + +} // namespace data +} // namespace tensorflow +#endif // !IS_MOBILE_PLATFORM + +#endif // TENSORFLOW_CORE_KERNELS_DATA_MODEL_DATASET_OP_H_ diff --git a/tensorflow/core/kernels/data/model_dataset_op_mobile.cc b/tensorflow/core/kernels/data/model_dataset_op_mobile.cc deleted file mode 100644 index 4de0e84130c..00000000000 --- a/tensorflow/core/kernels/data/model_dataset_op_mobile.cc +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/framework/dataset.h" - -namespace tensorflow { -namespace data { - -class ModelDatasetOp : public UnaryDatasetOpKernel { - public: - explicit ModelDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx) {} - - void MakeDataset(OpKernelContext* ctx, DatasetBase* input, - DatasetBase** output) { - input->Ref(); - *output = input; - } -}; - -namespace { -REGISTER_KERNEL_BUILDER(Name("ModelDataset").Device(DEVICE_CPU), - ModelDatasetOp); -} // namespace -} // namespace data -} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index 63cd1b9fd81..61383789f60 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/data/optimize_dataset_op.h" +// On mobile we do not provide optimize dataset op because not all of its +// dependencies are available there. The op is replaced with a no-op. +#if !defined(IS_MOBILE_PLATFORM) #include #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -27,9 +30,6 @@ limitations under the License. namespace tensorflow { namespace data { -// See documentation in ../../ops/dataset_ops.cc for a high-level -// description of the following op. - /* static */ constexpr const char* const OptimizeDatasetOp::kDatasetType; /* static */ constexpr const char* const OptimizeDatasetOp::kInputDataset; /* static */ constexpr const char* const OptimizeDatasetOp::kOptimizations; @@ -178,3 +178,25 @@ REGISTER_KERNEL_BUILDER(Name("OptimizeDatasetV2").Device(DEVICE_CPU), } // namespace } // namespace data } // namespace tensorflow +#else // !IS_MOBILE_PLATFORM +namespace tensorflow { +namespace data { + +OptimizeDatasetOp::OptimizeDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + +void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) { + input->Ref(); + *output = input; +} + +namespace { +REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU), + OptimizeDatasetOp); +REGISTER_KERNEL_BUILDER(Name("OptimizeDatasetV2").Device(DEVICE_CPU), + OptimizeDatasetOp); +} // namespace +} // namespace data +} // namespace tensorflow +#endif // !IS_MOBILE_PLATFORM diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.h b/tensorflow/core/kernels/data/optimize_dataset_op.h index d9e366f1ad5..a65cf588fef 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.h +++ b/tensorflow/core/kernels/data/optimize_dataset_op.h @@ -15,6 +15,11 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_DATA_OPTIMIZE_DATASET_OP_H_ #define TENSORFLOW_CORE_KERNELS_DATA_OPTIMIZE_DATASET_OP_H_ +#include "tensorflow/core/platform/platform.h" + +// On mobile we do not provide optimize dataset op because not all of its +// dependencies are available there. The op is replaced with a no-op. +#if !defined(IS_MOBILE_PLATFORM) #include "tensorflow/core/framework/dataset.h" namespace tensorflow { @@ -54,5 +59,23 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { } // namespace data } // namespace tensorflow +#else // !IS_MOBILE_PLATFORM +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class OptimizeDatasetOp : public UnaryDatasetOpKernel { + public: + explicit OptimizeDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; +}; + +} // namespace data +} // namespace tensorflow +#endif // !IS_MOBILE_PLATFORM #endif // TENSORFLOW_CORE_KERNELS_DATA_OPTIMIZE_DATASET_OP_H_ diff --git a/tensorflow/core/kernels/data/optimize_dataset_op_mobile.cc b/tensorflow/core/kernels/data/optimize_dataset_op_mobile.cc deleted file mode 100644 index 3679ae716a4..00000000000 --- a/tensorflow/core/kernels/data/optimize_dataset_op_mobile.cc +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/framework/dataset.h" - -namespace tensorflow { -namespace data { - -class OptimizeDatasetOp : public UnaryDatasetOpKernel { - public: - explicit OptimizeDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx) {} - - void MakeDataset(OpKernelContext* ctx, DatasetBase* input, - DatasetBase** output) { - input->Ref(); - *output = input; - } -}; - -namespace { -REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU), - OptimizeDatasetOp); -REGISTER_KERNEL_BUILDER(Name("OptimizeDatasetV2").Device(DEVICE_CPU), - OptimizeDatasetOp); -} // namespace -} // namespace data -} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/rewrite_utils.cc b/tensorflow/core/kernels/data/rewrite_utils.cc index c04c3445ec8..0ed3e1f75ec 100644 --- a/tensorflow/core/kernels/data/rewrite_utils.cc +++ b/tensorflow/core/kernels/data/rewrite_utils.cc @@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #include "tensorflow/core/kernels/data/rewrite_utils.h" +// On mobile we do not provide this functionality because not all of its +// dependencies are available there. +#if !defined(IS_MOBILE_PLATFORM) #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/common_runtime/metrics.h" @@ -222,3 +224,4 @@ Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, } // namespace data } // namespace tensorflow +#endif // !IS_MOBILE_PLATFORM diff --git a/tensorflow/core/kernels/data/rewrite_utils.h b/tensorflow/core/kernels/data/rewrite_utils.h index aed878e79cf..0de413c77ed 100644 --- a/tensorflow/core/kernels/data/rewrite_utils.h +++ b/tensorflow/core/kernels/data/rewrite_utils.h @@ -15,6 +15,11 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_DATA_REWRITE_UTILS_H_ #define TENSORFLOW_CORE_KERNELS_DATA_REWRITE_UTILS_H_ +#include "tensorflow/core/platform/platform.h" + +// On mobile we do not provide this functionality because not all of its +// dependencies are available there. +#if !defined(IS_MOBILE_PLATFORM) #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/function.h" @@ -31,5 +36,6 @@ Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, } // namespace data } // namespace tensorflow +#endif // !IS_MOBILE_PLATFORM #endif // TENSORFLOW_CORE_KERNELS_DATA_REWRITE_UTILS_H_