1. Rename BenchmarkParam and BenchmarkParams to ToolParam and ToolParams respectively, and branch them to the upper-level directory.
2. Also branch the benchmark/logging.h to the upper level directory. Note we still keep those in benchmark directory for backward compatibility but by reusing the new definitions here. These utility classes are helpful to and going to be used in other tflite-related tools. PiperOrigin-RevId: 305834516 Change-Id: Id809398b060698b641cad0c8ac5fe9ca22b9ab40
This commit is contained in:
parent
105e028d48
commit
020d0f0a49
@ -205,6 +205,31 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "logging",
|
||||
hdrs = ["logging.h"],
|
||||
copts = common_copts,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tool_params",
|
||||
srcs = ["tool_params.cc"],
|
||||
hdrs = ["tool_params.h"],
|
||||
copts = tflite_copts(),
|
||||
deps = [":logging"],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "tool_params_test",
|
||||
srcs = ["tool_params_test.cc"],
|
||||
copts = tflite_copts(),
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":tool_params",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "command_line_flags",
|
||||
srcs = ["command_line_flags.cc"],
|
||||
|
@ -17,6 +17,9 @@ cc_library(
|
||||
name = "logging",
|
||||
hdrs = ["logging.h"],
|
||||
copts = common_copts,
|
||||
deps = [
|
||||
"//tensorflow/lite/tools:logging",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
@ -107,6 +110,7 @@ cc_test(
|
||||
":benchmark_performance_options",
|
||||
":benchmark_tflite_model_lib",
|
||||
":delegate_provider_hdr",
|
||||
":logging",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite/testing:util",
|
||||
@ -128,6 +132,7 @@ cc_library(
|
||||
"//tensorflow/lite/profiling:profile_summarizer",
|
||||
"//tensorflow/lite/profiling:profile_summary_formatter",
|
||||
"//tensorflow/lite/profiling:profiler",
|
||||
"//tensorflow/lite/tools:logging",
|
||||
],
|
||||
)
|
||||
|
||||
@ -197,12 +202,9 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "benchmark_params",
|
||||
srcs = [
|
||||
"benchmark_params.cc",
|
||||
],
|
||||
hdrs = ["benchmark_params.h"],
|
||||
copts = common_copts,
|
||||
deps = [":logging"],
|
||||
deps = ["//tensorflow/lite/tools:tool_params"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
@ -1,76 +0,0 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/tools/benchmark/benchmark_params.h"
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/tools/benchmark/logging.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace benchmark {
|
||||
|
||||
void BenchmarkParam::AssertHasSameType(BenchmarkParam::ParamType a,
|
||||
BenchmarkParam::ParamType b) {
|
||||
TFLITE_BENCHMARK_CHECK(a == b) << "Type mismatch while accessing parameter.";
|
||||
}
|
||||
|
||||
template <>
|
||||
BenchmarkParam::ParamType BenchmarkParam::GetValueType<int32_t>() {
|
||||
return BenchmarkParam::ParamType::TYPE_INT32;
|
||||
}
|
||||
|
||||
template <>
|
||||
BenchmarkParam::ParamType BenchmarkParam::GetValueType<bool>() {
|
||||
return BenchmarkParam::ParamType::TYPE_BOOL;
|
||||
}
|
||||
|
||||
template <>
|
||||
BenchmarkParam::ParamType BenchmarkParam::GetValueType<float>() {
|
||||
return BenchmarkParam::ParamType::TYPE_FLOAT;
|
||||
}
|
||||
|
||||
template <>
|
||||
BenchmarkParam::ParamType BenchmarkParam::GetValueType<std::string>() {
|
||||
return BenchmarkParam::ParamType::TYPE_STRING;
|
||||
}
|
||||
|
||||
void BenchmarkParams::AssertParamExists(const std::string& name) const {
|
||||
TFLITE_BENCHMARK_CHECK(HasParam(name)) << name << " was not found.";
|
||||
}
|
||||
|
||||
void BenchmarkParams::Set(const BenchmarkParams& other) {
|
||||
for (const auto& param : params_) {
|
||||
const BenchmarkParam* other_param = other.GetParam(param.first);
|
||||
if (other_param == nullptr) continue;
|
||||
param.second->Set(*other_param);
|
||||
}
|
||||
}
|
||||
|
||||
void BenchmarkParams::Merge(const BenchmarkParams& other, bool overwrite) {
|
||||
for (const auto& one : other.params_) {
|
||||
auto it = params_.find(one.first);
|
||||
if (it == params_.end()) {
|
||||
AddParam(one.first, one.second->Clone());
|
||||
} else if (overwrite) {
|
||||
it->second->Set(*one.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace benchmark
|
||||
} // namespace tflite
|
@ -15,123 +15,12 @@ limitations under the License.
|
||||
|
||||
#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_
|
||||
#define TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/tools/benchmark/logging.h"
|
||||
#include "tensorflow/lite/tools/tool_params.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace benchmark {
|
||||
|
||||
template <typename T>
|
||||
class TypedBenchmarkParam;
|
||||
|
||||
class BenchmarkParam {
|
||||
protected:
|
||||
enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING };
|
||||
template <typename T>
|
||||
static ParamType GetValueType();
|
||||
|
||||
public:
|
||||
template <typename T>
|
||||
static std::unique_ptr<BenchmarkParam> Create(const T& default_value) {
|
||||
return std::unique_ptr<BenchmarkParam>(
|
||||
new TypedBenchmarkParam<T>(default_value));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
TypedBenchmarkParam<T>* AsTyped() {
|
||||
AssertHasSameType(GetValueType<T>(), type_);
|
||||
return static_cast<TypedBenchmarkParam<T>*>(this);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const TypedBenchmarkParam<T>* AsConstTyped() const {
|
||||
AssertHasSameType(GetValueType<T>(), type_);
|
||||
return static_cast<const TypedBenchmarkParam<T>*>(this);
|
||||
}
|
||||
|
||||
virtual ~BenchmarkParam() {}
|
||||
explicit BenchmarkParam(ParamType type) : type_(type) {}
|
||||
|
||||
virtual void Set(const BenchmarkParam&) {}
|
||||
|
||||
virtual std::unique_ptr<BenchmarkParam> Clone() const = 0;
|
||||
|
||||
private:
|
||||
static void AssertHasSameType(ParamType a, ParamType b);
|
||||
|
||||
const ParamType type_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class TypedBenchmarkParam : public BenchmarkParam {
|
||||
public:
|
||||
explicit TypedBenchmarkParam(const T& value)
|
||||
: BenchmarkParam(GetValueType<T>()), value_(value) {}
|
||||
|
||||
void Set(const T& value) { value_ = value; }
|
||||
|
||||
T Get() const { return value_; }
|
||||
|
||||
void Set(const BenchmarkParam& other) override {
|
||||
Set(other.AsConstTyped<T>()->Get());
|
||||
}
|
||||
|
||||
std::unique_ptr<BenchmarkParam> Clone() const override {
|
||||
return std::unique_ptr<BenchmarkParam>(new TypedBenchmarkParam<T>(value_));
|
||||
}
|
||||
|
||||
private:
|
||||
T value_;
|
||||
};
|
||||
|
||||
class BenchmarkParams {
|
||||
public:
|
||||
void AddParam(const std::string& name,
|
||||
std::unique_ptr<BenchmarkParam> value) {
|
||||
params_[name] = std::move(value);
|
||||
}
|
||||
|
||||
bool HasParam(const std::string& name) const {
|
||||
return params_.find(name) != params_.end();
|
||||
}
|
||||
|
||||
bool Empty() const { return params_.empty(); }
|
||||
|
||||
const BenchmarkParam* GetParam(const std::string& name) const {
|
||||
const auto& entry = params_.find(name);
|
||||
if (entry == params_.end()) return nullptr;
|
||||
return entry->second.get();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Set(const std::string& name, const T& value) {
|
||||
AssertParamExists(name);
|
||||
params_.at(name)->AsTyped<T>()->Set(value);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T Get(const std::string& name) const {
|
||||
AssertParamExists(name);
|
||||
return params_.at(name)->AsTyped<T>()->Get();
|
||||
}
|
||||
|
||||
// Set the value of all same parameters from 'other'.
|
||||
void Set(const BenchmarkParams& other);
|
||||
|
||||
// Merge the value of all parameters from 'other'. 'overwrite' indicates
|
||||
// whether the value of the same paratmeter is overwrite or not.
|
||||
void Merge(const BenchmarkParams& other, bool overwrite = false);
|
||||
|
||||
private:
|
||||
void AssertParamExists(const std::string& name) const;
|
||||
std::unordered_map<std::string, std::unique_ptr<BenchmarkParam>> params_;
|
||||
};
|
||||
|
||||
using BenchmarkParam = tflite::tools::ToolParam;
|
||||
using BenchmarkParams = tflite::tools::ToolParams;
|
||||
} // namespace benchmark
|
||||
} // namespace tflite
|
||||
#endif // TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/tools/benchmark/benchmark_performance_options.h"
|
||||
#include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h"
|
||||
#include "tensorflow/lite/tools/benchmark/delegate_provider.h"
|
||||
#include "tensorflow/lite/tools/benchmark/logging.h"
|
||||
#include "tensorflow/lite/tools/command_line_flags.h"
|
||||
|
||||
namespace {
|
||||
|
@ -16,74 +16,10 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_LOGGING_H_
|
||||
#define TENSORFLOW_LITE_TOOLS_BENCHMARK_LOGGING_H_
|
||||
|
||||
// LOG and CHECK macros for benchmarks.
|
||||
// TODO(b/149482807): completely remove this file from the code base.
|
||||
#include "tensorflow/lite/tools/logging.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#ifdef _WIN32
|
||||
#undef ERROR
|
||||
#endif
|
||||
|
||||
namespace tflite {
|
||||
namespace logging {
|
||||
// A wrapper that logs to stderr.
|
||||
//
|
||||
// Used for TFLITE_LOG and TFLITE_BENCHMARK_CHECK macros.
|
||||
class LoggingWrapper {
|
||||
public:
|
||||
enum class LogSeverity : int {
|
||||
INFO = 0,
|
||||
WARN = 1,
|
||||
ERROR = 2,
|
||||
FATAL = 3,
|
||||
};
|
||||
LoggingWrapper(LogSeverity severity)
|
||||
: severity_(severity), should_log_(true) {}
|
||||
LoggingWrapper(LogSeverity severity, bool log)
|
||||
: severity_(severity), should_log_(log) {}
|
||||
std::stringstream& Stream() { return stream_; }
|
||||
~LoggingWrapper() {
|
||||
if (should_log_) {
|
||||
switch (severity_) {
|
||||
case LogSeverity::INFO:
|
||||
case LogSeverity::WARN:
|
||||
std::cout << stream_.str() << std::endl;
|
||||
break;
|
||||
case LogSeverity::ERROR:
|
||||
std::cerr << stream_.str() << std::endl;
|
||||
break;
|
||||
case LogSeverity::FATAL:
|
||||
std::cerr << stream_.str() << std::endl;
|
||||
std::flush(std::cerr);
|
||||
std::abort();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::stringstream stream_;
|
||||
LogSeverity severity_;
|
||||
bool should_log_;
|
||||
};
|
||||
|
||||
} // namespace logging
|
||||
|
||||
} // namespace tflite
|
||||
|
||||
#define TFLITE_LOG(severity) \
|
||||
tflite::logging::LoggingWrapper( \
|
||||
tflite::logging::LoggingWrapper::LogSeverity::severity) \
|
||||
.Stream()
|
||||
|
||||
#define TFLITE_BENCHMARK_CHECK(condition) \
|
||||
tflite::logging::LoggingWrapper( \
|
||||
tflite::logging::LoggingWrapper::LogSeverity::FATAL, \
|
||||
(condition) ? false : true) \
|
||||
.Stream()
|
||||
|
||||
#define TFLITE_BENCHMARK_CHECK_EQ(a, b) TFLITE_BENCHMARK_CHECK(a == b)
|
||||
#define TFLITE_BENCHMARK_CHECK(condition) TFLITE_TOOLS_CHECK(condition)
|
||||
#define TFLITE_BENCHMARK_CHECK_EQ(a, b) TFLITE_TOOLS_CHECK(a == b)
|
||||
|
||||
#endif // TENSORFLOW_LITE_TOOLS_BENCHMARK_LOGGING_H_
|
||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "tensorflow/lite/tools/logging.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace benchmark {
|
||||
|
||||
@ -29,7 +31,7 @@ ProfilingListener::ProfilingListener(
|
||||
csv_file_path_(csv_file_path),
|
||||
interpreter_(interpreter),
|
||||
profiler_(max_num_entries) {
|
||||
TFLITE_BENCHMARK_CHECK(interpreter);
|
||||
TFLITE_TOOLS_CHECK(interpreter);
|
||||
interpreter_->SetProfiler(&profiler_);
|
||||
|
||||
// We start profiling here in order to catch events that are recorded during
|
||||
|
87
tensorflow/lite/tools/logging.h
Normal file
87
tensorflow/lite/tools/logging.h
Normal file
@ -0,0 +1,87 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_TOOLS_LOGGING_H_
|
||||
#define TENSORFLOW_LITE_TOOLS_LOGGING_H_
|
||||
|
||||
// LOG and CHECK macros for tflite tooling.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#ifdef _WIN32
|
||||
#undef ERROR
|
||||
#endif
|
||||
|
||||
namespace tflite {
|
||||
namespace logging {
|
||||
// A wrapper that logs to stderr.
|
||||
//
|
||||
// Used for TFLITE_LOG and TFLITE_BENCHMARK_CHECK macros.
|
||||
class LoggingWrapper {
|
||||
public:
|
||||
enum class LogSeverity : int {
|
||||
INFO = 0,
|
||||
WARN = 1,
|
||||
ERROR = 2,
|
||||
FATAL = 3,
|
||||
};
|
||||
LoggingWrapper(LogSeverity severity)
|
||||
: severity_(severity), should_log_(true) {}
|
||||
LoggingWrapper(LogSeverity severity, bool log)
|
||||
: severity_(severity), should_log_(log) {}
|
||||
std::stringstream& Stream() { return stream_; }
|
||||
~LoggingWrapper() {
|
||||
if (should_log_) {
|
||||
switch (severity_) {
|
||||
case LogSeverity::INFO:
|
||||
case LogSeverity::WARN:
|
||||
std::cout << stream_.str() << std::endl;
|
||||
break;
|
||||
case LogSeverity::ERROR:
|
||||
std::cerr << stream_.str() << std::endl;
|
||||
break;
|
||||
case LogSeverity::FATAL:
|
||||
std::cerr << stream_.str() << std::endl;
|
||||
std::flush(std::cerr);
|
||||
std::abort();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::stringstream stream_;
|
||||
LogSeverity severity_;
|
||||
bool should_log_;
|
||||
};
|
||||
} // namespace logging
|
||||
} // namespace tflite
|
||||
|
||||
#define TFLITE_LOG(severity) \
|
||||
tflite::logging::LoggingWrapper( \
|
||||
tflite::logging::LoggingWrapper::LogSeverity::severity) \
|
||||
.Stream()
|
||||
|
||||
#define TFLITE_TOOLS_CHECK(condition) \
|
||||
tflite::logging::LoggingWrapper( \
|
||||
tflite::logging::LoggingWrapper::LogSeverity::FATAL, \
|
||||
(condition) ? false : true) \
|
||||
.Stream()
|
||||
|
||||
#define TFLITE_TOOLS_CHECK_EQ(a, b) TFLITE_TOOLS_CHECK((a) == (b))
|
||||
|
||||
#endif // TENSORFLOW_LITE_TOOLS_LOGGING_H_
|
@ -111,7 +111,8 @@ PROFILE_SUMMARIZER_SRCS := \
|
||||
tensorflow/core/util/stats_calculator.cc
|
||||
|
||||
CMD_LINE_TOOLS_SRCS := \
|
||||
tensorflow/lite/tools/command_line_flags.cc
|
||||
tensorflow/lite/tools/command_line_flags.cc \
|
||||
tensorflow/lite/tools/tool_params.cc
|
||||
|
||||
CORE_CC_ALL_SRCS := \
|
||||
$(wildcard tensorflow/lite/*.cc) \
|
||||
|
76
tensorflow/lite/tools/tool_params.cc
Normal file
76
tensorflow/lite/tools/tool_params.cc
Normal file
@ -0,0 +1,76 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/tools/tool_params.h"
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/tools/logging.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace tools {
|
||||
|
||||
void ToolParam::AssertHasSameType(ToolParam::ParamType a,
|
||||
ToolParam::ParamType b) {
|
||||
TFLITE_TOOLS_CHECK(a == b) << "Type mismatch while accessing parameter.";
|
||||
}
|
||||
|
||||
template <>
|
||||
ToolParam::ParamType ToolParam::GetValueType<int32_t>() {
|
||||
return ToolParam::ParamType::TYPE_INT32;
|
||||
}
|
||||
|
||||
template <>
|
||||
ToolParam::ParamType ToolParam::GetValueType<bool>() {
|
||||
return ToolParam::ParamType::TYPE_BOOL;
|
||||
}
|
||||
|
||||
template <>
|
||||
ToolParam::ParamType ToolParam::GetValueType<float>() {
|
||||
return ToolParam::ParamType::TYPE_FLOAT;
|
||||
}
|
||||
|
||||
template <>
|
||||
ToolParam::ParamType ToolParam::GetValueType<std::string>() {
|
||||
return ToolParam::ParamType::TYPE_STRING;
|
||||
}
|
||||
|
||||
void ToolParams::AssertParamExists(const std::string& name) const {
|
||||
TFLITE_TOOLS_CHECK(HasParam(name)) << name << " was not found.";
|
||||
}
|
||||
|
||||
void ToolParams::Set(const ToolParams& other) {
|
||||
for (const auto& param : params_) {
|
||||
const ToolParam* other_param = other.GetParam(param.first);
|
||||
if (other_param == nullptr) continue;
|
||||
param.second->Set(*other_param);
|
||||
}
|
||||
}
|
||||
|
||||
void ToolParams::Merge(const ToolParams& other, bool overwrite) {
|
||||
for (const auto& one : other.params_) {
|
||||
auto it = params_.find(one.first);
|
||||
if (it == params_.end()) {
|
||||
AddParam(one.first, one.second->Clone());
|
||||
} else if (overwrite) {
|
||||
it->second->Set(*one.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tools
|
||||
} // namespace tflite
|
134
tensorflow/lite/tools/tool_params.h
Normal file
134
tensorflow/lite/tools/tool_params.h
Normal file
@ -0,0 +1,134 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_
|
||||
#define TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace tflite {
|
||||
namespace tools {
|
||||
|
||||
template <typename T>
|
||||
class TypedToolParam;
|
||||
|
||||
class ToolParam {
|
||||
protected:
|
||||
enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING };
|
||||
template <typename T>
|
||||
static ParamType GetValueType();
|
||||
|
||||
public:
|
||||
template <typename T>
|
||||
static std::unique_ptr<ToolParam> Create(const T& default_value) {
|
||||
return std::unique_ptr<ToolParam>(new TypedToolParam<T>(default_value));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
TypedToolParam<T>* AsTyped() {
|
||||
AssertHasSameType(GetValueType<T>(), type_);
|
||||
return static_cast<TypedToolParam<T>*>(this);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const TypedToolParam<T>* AsConstTyped() const {
|
||||
AssertHasSameType(GetValueType<T>(), type_);
|
||||
return static_cast<const TypedToolParam<T>*>(this);
|
||||
}
|
||||
|
||||
virtual ~ToolParam() {}
|
||||
explicit ToolParam(ParamType type) : type_(type) {}
|
||||
|
||||
virtual void Set(const ToolParam&) {}
|
||||
|
||||
virtual std::unique_ptr<ToolParam> Clone() const = 0;
|
||||
|
||||
private:
|
||||
static void AssertHasSameType(ParamType a, ParamType b);
|
||||
|
||||
const ParamType type_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class TypedToolParam : public ToolParam {
|
||||
public:
|
||||
explicit TypedToolParam(const T& value)
|
||||
: ToolParam(GetValueType<T>()), value_(value) {}
|
||||
|
||||
void Set(const T& value) { value_ = value; }
|
||||
|
||||
T Get() const { return value_; }
|
||||
|
||||
void Set(const ToolParam& other) override {
|
||||
Set(other.AsConstTyped<T>()->Get());
|
||||
}
|
||||
|
||||
std::unique_ptr<ToolParam> Clone() const override {
|
||||
return std::unique_ptr<ToolParam>(new TypedToolParam<T>(value_));
|
||||
}
|
||||
|
||||
private:
|
||||
T value_;
|
||||
};
|
||||
|
||||
// A map-like container for holding values of different types.
|
||||
class ToolParams {
|
||||
public:
|
||||
void AddParam(const std::string& name, std::unique_ptr<ToolParam> value) {
|
||||
params_[name] = std::move(value);
|
||||
}
|
||||
|
||||
bool HasParam(const std::string& name) const {
|
||||
return params_.find(name) != params_.end();
|
||||
}
|
||||
|
||||
bool Empty() const { return params_.empty(); }
|
||||
|
||||
const ToolParam* GetParam(const std::string& name) const {
|
||||
const auto& entry = params_.find(name);
|
||||
if (entry == params_.end()) return nullptr;
|
||||
return entry->second.get();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Set(const std::string& name, const T& value) {
|
||||
AssertParamExists(name);
|
||||
params_.at(name)->AsTyped<T>()->Set(value);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T Get(const std::string& name) const {
|
||||
AssertParamExists(name);
|
||||
return params_.at(name)->AsTyped<T>()->Get();
|
||||
}
|
||||
|
||||
// Set the value of all same parameters from 'other'.
|
||||
void Set(const ToolParams& other);
|
||||
|
||||
// Merge the value of all parameters from 'other'. 'overwrite' indicates
|
||||
// whether the value of the same paratmeter is overwritten or not.
|
||||
void Merge(const ToolParams& other, bool overwrite = false);
|
||||
|
||||
private:
|
||||
void AssertParamExists(const std::string& name) const;
|
||||
std::unordered_map<std::string, std::unique_ptr<ToolParam>> params_;
|
||||
};
|
||||
|
||||
} // namespace tools
|
||||
} // namespace tflite
|
||||
#endif // TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_
|
71
tensorflow/lite/tools/tool_params_test.cc
Normal file
71
tensorflow/lite/tools/tool_params_test.cc
Normal file
@ -0,0 +1,71 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/tools/tool_params.h"
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace tflite {
|
||||
namespace tools {
|
||||
namespace {
|
||||
|
||||
TEST(ToolParams, SetTest) {
|
||||
ToolParams params;
|
||||
params.AddParam("some-int1", ToolParam::Create<int>(13));
|
||||
params.AddParam("some-int2", ToolParam::Create<int>(17));
|
||||
|
||||
ToolParams others;
|
||||
others.AddParam("some-int1", ToolParam::Create<int>(19));
|
||||
others.AddParam("some-bool", ToolParam::Create<bool>(true));
|
||||
|
||||
params.Set(others);
|
||||
EXPECT_EQ(19, params.Get<int>("some-int1"));
|
||||
EXPECT_EQ(17, params.Get<int>("some-int2"));
|
||||
EXPECT_FALSE(params.HasParam("some-bool"));
|
||||
}
|
||||
|
||||
TEST(ToolParams, MergeTestOverwriteTrue) {
|
||||
ToolParams params;
|
||||
params.AddParam("some-int1", ToolParam::Create<int>(13));
|
||||
params.AddParam("some-int2", ToolParam::Create<int>(17));
|
||||
|
||||
ToolParams others;
|
||||
others.AddParam("some-int1", ToolParam::Create<int>(19));
|
||||
others.AddParam("some-bool", ToolParam::Create<bool>(true));
|
||||
|
||||
params.Merge(others, true /* overwrite */);
|
||||
EXPECT_EQ(19, params.Get<int>("some-int1"));
|
||||
EXPECT_EQ(17, params.Get<int>("some-int2"));
|
||||
EXPECT_TRUE(params.Get<bool>("some-bool"));
|
||||
}
|
||||
|
||||
TEST(ToolParams, MergeTestOverwriteFalse) {
|
||||
ToolParams params;
|
||||
params.AddParam("some-int1", ToolParam::Create<int>(13));
|
||||
params.AddParam("some-int2", ToolParam::Create<int>(17));
|
||||
|
||||
ToolParams others;
|
||||
others.AddParam("some-int1", ToolParam::Create<int>(19));
|
||||
others.AddParam("some-bool", ToolParam::Create<bool>(true));
|
||||
|
||||
params.Merge(others); // default overwrite is false
|
||||
EXPECT_EQ(13, params.Get<int>("some-int1"));
|
||||
EXPECT_EQ(17, params.Get<int>("some-int2"));
|
||||
EXPECT_TRUE(params.Get<bool>("some-bool"));
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tools
|
||||
} // namespace tflite
|
Loading…
Reference in New Issue
Block a user