Internal Change
PiperOrigin-RevId: 237073663
This commit is contained in:
parent
7a415783c6
commit
e70b131052
@ -34,23 +34,12 @@ cc_library(
|
||||
hdrs = ["evaluation_stage.h"],
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_proto_cc",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "evaluation_stage_factory",
|
||||
hdrs = ["evaluation_stage_factory.h"],
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
":evaluation_stage",
|
||||
":identity_stage",
|
||||
"//tensorflow/core:tflite_portable_logging",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_proto_cc",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_proto_cc",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
@ -61,26 +50,10 @@ cc_library(
|
||||
copts = tflite_copts(),
|
||||
deps = [
|
||||
":evaluation_stage",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_proto_cc",
|
||||
] + select(
|
||||
{
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:ops",
|
||||
],
|
||||
},
|
||||
),
|
||||
alwayslink = 1,
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_proto_cc",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
@ -88,31 +61,12 @@ tf_cc_test(
|
||||
srcs = ["evaluation_stage_test.cc"],
|
||||
linkopts = common_linkopts,
|
||||
linkstatic = 1,
|
||||
tags = [
|
||||
"tflite_not_portable_android",
|
||||
"tflite_not_portable_ios",
|
||||
],
|
||||
deps = [
|
||||
":evaluation_stage",
|
||||
":evaluation_stage_factory",
|
||||
":identity_stage",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_proto_cc",
|
||||
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_proto_cc",
|
||||
] + select(
|
||||
{
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/core:android_tensorflow_lib",
|
||||
"//tensorflow/core:android_tensorflow_test_lib",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:tensorflow",
|
||||
],
|
||||
},
|
||||
),
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
@ -73,5 +73,9 @@ bool EvaluationStage::ProcessExpectedTags(
|
||||
return true;
|
||||
}
|
||||
|
||||
std::map<ProcessClass, FactoryFunc>*
|
||||
EvaluationStage::process_class_to_factory_map_ =
|
||||
new std::map<ProcessClass, FactoryFunc>();
|
||||
|
||||
} // namespace evaluation
|
||||
} // namespace tflite
|
||||
|
@ -15,16 +15,27 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_STAGE_H_
|
||||
#define TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_STAGE_H_
|
||||
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <regex> // NOLINT
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace evaluation {
|
||||
|
||||
class EvaluationStage;
|
||||
|
||||
typedef std::function<std::unique_ptr<EvaluationStage>(
|
||||
const EvaluationStageConfig&)>
|
||||
FactoryFunc;
|
||||
|
||||
// Superclass for a single stage of an EvaluationPipeline.
|
||||
// Provides basic functionality for construction and accessing
|
||||
// initializers/inputs/outputs.
|
||||
@ -43,17 +54,42 @@ class EvaluationStage {
|
||||
bool Init(absl::flat_hash_map<std::string, void*>& object_map);
|
||||
|
||||
// An individual run of the EvaluationStage. Returns false if there was a
|
||||
// failure, true otherwise. Populates metrics into the EvaluationStageMetrics
|
||||
// proto.
|
||||
// failure, true otherwise.
|
||||
// Init() should be called before any calls to run().
|
||||
// Inputs are acquired from and outputs are written to the incoming
|
||||
// object_map, using appropriate TAGs.
|
||||
//
|
||||
// NOTE: The EvaluationStage should maintain ownership of outputs it
|
||||
// populates into object_map. Ownership of inputs will be maintained
|
||||
// populates into object_map. Ownership of inputs must be maintained
|
||||
// elsewhere.
|
||||
virtual bool Run(absl::flat_hash_map<std::string, void*>& object_map,
|
||||
EvaluationStageMetrics& metrics) = 0;
|
||||
virtual bool Run(absl::flat_hash_map<std::string, void*>& object_map) = 0;
|
||||
|
||||
// Returns the latest metrics based on all Run() calls made so far.
|
||||
virtual EvaluationStageMetrics LatestMetrics() = 0;
|
||||
|
||||
// The canonical way to instantiate EvaluationStages.
|
||||
// Remember to call <classname>_ENABLE() first.
|
||||
static std::unique_ptr<EvaluationStage> Create(
|
||||
const EvaluationStageConfig& config) {
|
||||
if (!config.has_specification() ||
|
||||
!config.specification().has_process_class()) {
|
||||
LOG(ERROR) << "Process specification not present in config: "
|
||||
<< config.name();
|
||||
return nullptr;
|
||||
}
|
||||
auto& factory_ptr =
|
||||
(*GetFactoryMapPtr())[config.specification().process_class()];
|
||||
if (!factory_ptr) return nullptr;
|
||||
return factory_ptr(config);
|
||||
}
|
||||
|
||||
// Used by DEFINE_REGISTRATION.
|
||||
// This method takes ownership of factory.
|
||||
// Should only be used via DEFINE_REGISTRATION macro.
|
||||
static void RegisterStage(const ProcessClass& process_class,
|
||||
FactoryFunc class_factory) {
|
||||
(*GetFactoryMapPtr())[process_class] = std::move(class_factory);
|
||||
}
|
||||
|
||||
virtual ~EvaluationStage() = default;
|
||||
|
||||
@ -62,8 +98,7 @@ class EvaluationStage {
|
||||
// Each subclass constructor must invoke this constructor.
|
||||
//
|
||||
// NOTE: Do NOT use constructors to obtain new EvaluationStages. Use
|
||||
// tflite::evaluation::GetEvaluationStageFromConfig from
|
||||
// evaluation_stage_factory.h instead.
|
||||
// EvaluationStage::Create instead.
|
||||
explicit EvaluationStage(const EvaluationStageConfig& config)
|
||||
: config_(config) {}
|
||||
|
||||
@ -90,8 +125,9 @@ class EvaluationStage {
|
||||
|
||||
// Populates a pointer to the object corresponding to provided TAG.
|
||||
// Returns true if success, false otherwise.
|
||||
// object_map must contain {name : object pointer} mappings, with one of the
|
||||
// names being mapped to the expected TAG in the EvaluationStageConfig.
|
||||
// object_map contain a {name : object pointer} mapping, with the
|
||||
// name being mapped to the expected TAG in the EvaluationStageConfig.
|
||||
// NOTE: object pointer must be non-NULL.
|
||||
template <class T>
|
||||
bool GetObjectFromTag(const std::string& tag,
|
||||
absl::flat_hash_map<std::string, void*>& object_map,
|
||||
@ -100,7 +136,7 @@ class EvaluationStage {
|
||||
// Find name corresponding to TAG.
|
||||
auto mapping_iter = tags_to_names_map_.find(tag);
|
||||
if (mapping_iter == tags_to_names_map_.end()) {
|
||||
LOG(ERROR) << "Unexpected TAG: " << tag;
|
||||
LOG(ERROR) << "Unexpected TAG to GetObjectFromTag: " << tag;
|
||||
return false;
|
||||
}
|
||||
const std::string& expected_name = mapping_iter->second;
|
||||
@ -111,6 +147,10 @@ class EvaluationStage {
|
||||
LOG(ERROR) << "Could not find object for name: " << expected_name;
|
||||
return false;
|
||||
}
|
||||
if (!object_iter->second) {
|
||||
LOG(ERROR) << "Found null pointer for name: " << expected_name;
|
||||
return false;
|
||||
}
|
||||
*object_ptr = static_cast<T*>(object_iter->second);
|
||||
return true;
|
||||
}
|
||||
@ -126,7 +166,7 @@ class EvaluationStage {
|
||||
// Find name corresponding to TAG.
|
||||
auto mapping_iter = tags_to_names_map_.find(tag);
|
||||
if (mapping_iter == tags_to_names_map_.end()) {
|
||||
LOG(ERROR) << "Unexpected TAG: " << tag;
|
||||
LOG(ERROR) << "Unexpected TAG to AssignObjectToTag: " << tag;
|
||||
return false;
|
||||
}
|
||||
const std::string& expected_name = mapping_iter->second;
|
||||
@ -135,7 +175,7 @@ class EvaluationStage {
|
||||
return true;
|
||||
}
|
||||
|
||||
const EvaluationStageConfig config_;
|
||||
EvaluationStageConfig config_;
|
||||
|
||||
private:
|
||||
// Verifies that all TAGs from expected_tags are present in
|
||||
@ -148,6 +188,13 @@ class EvaluationStage {
|
||||
bool ProcessExpectedTags(const std::vector<std::string>& expected_tags,
|
||||
std::vector<std::string>& tag_to_name_mappings);
|
||||
|
||||
static std::map<ProcessClass, FactoryFunc>* GetFactoryMapPtr() {
|
||||
return process_class_to_factory_map_;
|
||||
}
|
||||
|
||||
// Used by factories.
|
||||
static std::map<ProcessClass, FactoryFunc>* process_class_to_factory_map_;
|
||||
|
||||
// Maps expected TAGs to their names as defined by the EvaluationStageConfig.
|
||||
absl::flat_hash_map<std::string, std::string> tags_to_names_map_;
|
||||
|
||||
@ -159,6 +206,34 @@ class EvaluationStage {
|
||||
const std::regex kTagPattern{"^[A-Z0-9_]+$", std::regex::optimize};
|
||||
};
|
||||
|
||||
// Add this to headers of new EvaluationStages.
|
||||
#define DECLARE_FACTORY(classname) void classname##_ENABLE();
|
||||
|
||||
// Add this to implementation files of new EvaluationStages.
|
||||
// Call <stage_name>_ENABLE() before using EvaluationStage::Create for the
|
||||
// class.
|
||||
#define DEFINE_FACTORY(classname, processclass) \
|
||||
void classname##_ENABLE() { \
|
||||
FactoryFunc classname##Factory = [](const EvaluationStageConfig& config) { \
|
||||
return absl::make_unique<classname>(config); \
|
||||
}; \
|
||||
EvaluationStage::RegisterStage(processclass, classname##Factory); \
|
||||
}
|
||||
|
||||
// Use this to assign a non-nullptr pointer to tag in object_map.
|
||||
#define ASSIGN_OBJECT(tag, ptr, object_map) \
|
||||
if (!AssignObjectToTag(tag, ptr, object_map)) { \
|
||||
return false; \
|
||||
}
|
||||
|
||||
// Use this to obtain pointers to required object.
|
||||
// Will return false if name corresponding to tag is not found, or if the
|
||||
// pointer found is nullptr.
|
||||
#define GET_OBJECT(tag, object_map, location) \
|
||||
if (!GetObjectFromTag(tag, object_map, location)) { \
|
||||
return false; \
|
||||
}
|
||||
|
||||
} // namespace evaluation
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -1,48 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_STAGE_FACTORY_H_
|
||||
#define TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_STAGE_FACTORY_H_
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
|
||||
#include "tensorflow/lite/tools/evaluation/identity_stage.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace evaluation {
|
||||
|
||||
// The canonical way to generate EvaluationStages.
|
||||
// TODO(b/122482115): Implement a Factory class for registration of classes.
|
||||
std::unique_ptr<EvaluationStage> CreateEvaluationStageFromConfig(
|
||||
const EvaluationStageConfig& config) {
|
||||
if (!config.has_specification() ||
|
||||
!config.specification().has_process_class()) {
|
||||
LOG(ERROR) << "Process specification not present in config: "
|
||||
<< config.name();
|
||||
return nullptr;
|
||||
}
|
||||
switch (config.specification().process_class()) {
|
||||
case UNKNOWN:
|
||||
return nullptr;
|
||||
case IDENTITY:
|
||||
return absl::make_unique<IdentityStage>(config);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace evaluation
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_STAGE_FACTORY_H_
|
@ -18,9 +18,6 @@ limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_stage_factory.h"
|
||||
#include "tensorflow/lite/tools/evaluation/identity_stage.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||
@ -29,18 +26,16 @@ namespace tflite {
|
||||
namespace evaluation {
|
||||
namespace {
|
||||
|
||||
using ::tensorflow::DataType;
|
||||
using ::tensorflow::Tensor;
|
||||
|
||||
constexpr char kIdentityStageName[] = "identity_stage";
|
||||
constexpr char kInputTypeName[] = "type";
|
||||
constexpr char kInputTensorsName[] = "in";
|
||||
constexpr char kOutputTensorsName[] = "out";
|
||||
constexpr char kInitializerMapping[] = "INPUT_TYPE:type";
|
||||
constexpr char kInputMapping[] = "INPUT_TENSORS:in";
|
||||
constexpr char kOutputMapping[] = "OUTPUT_TENSORS:out";
|
||||
constexpr char kDefaultValueName[] = "default";
|
||||
constexpr char kInputValueName[] = "in";
|
||||
constexpr char kOutputValueName[] = "out";
|
||||
constexpr char kInitializerMapping[] = "DEFAULT_VALUE:default";
|
||||
constexpr char kInputMapping[] = "INPUT_VALUE:in";
|
||||
constexpr char kOutputMapping[] = "OUTPUT_VALUE:out";
|
||||
|
||||
EvaluationStageConfig GetIdentityStageConfig() {
|
||||
IdentityStage_ENABLE();
|
||||
EvaluationStageConfig config;
|
||||
config.set_name(kIdentityStageName);
|
||||
config.mutable_specification()->set_process_class(IDENTITY);
|
||||
@ -50,16 +45,39 @@ EvaluationStageConfig GetIdentityStageConfig() {
|
||||
return config;
|
||||
}
|
||||
|
||||
TEST(EvaluationStage, CreateFailsForMissingSpecification) {
|
||||
// Construct
|
||||
EvaluationStageConfig config;
|
||||
config.set_name(kIdentityStageName);
|
||||
std::unique_ptr<EvaluationStage> stage_ptr = EvaluationStage::Create(config);
|
||||
EXPECT_EQ(stage_ptr, nullptr);
|
||||
}
|
||||
|
||||
TEST(EvaluationStage, StageEnableRequired) {
|
||||
// Construct
|
||||
EvaluationStageConfig config;
|
||||
config.set_name(kIdentityStageName);
|
||||
config.mutable_specification()->set_process_class(IDENTITY);
|
||||
config.add_initializers(kInitializerMapping);
|
||||
config.add_inputs(kInputMapping);
|
||||
config.add_outputs(kOutputMapping);
|
||||
config.clear_inputs();
|
||||
std::unique_ptr<EvaluationStage> stage_ptr = EvaluationStage::Create(config);
|
||||
EXPECT_EQ(stage_ptr, nullptr);
|
||||
IdentityStage_ENABLE();
|
||||
stage_ptr = EvaluationStage::Create(config);
|
||||
EXPECT_NE(stage_ptr, nullptr);
|
||||
}
|
||||
|
||||
TEST(EvaluationStage, IncompleteConfig) {
|
||||
// Construct
|
||||
EvaluationStageConfig config = GetIdentityStageConfig();
|
||||
config.clear_inputs();
|
||||
std::unique_ptr<EvaluationStage> stage_ptr =
|
||||
CreateEvaluationStageFromConfig(config);
|
||||
std::unique_ptr<EvaluationStage> stage_ptr = EvaluationStage::Create(config);
|
||||
// Initialize
|
||||
absl::flat_hash_map<std::string, void*> object_map;
|
||||
DataType input_type = tensorflow::DT_FLOAT;
|
||||
object_map[kInputTypeName] = &input_type;
|
||||
float default_value = 1.0;
|
||||
object_map[kDefaultValueName] = &default_value;
|
||||
EXPECT_FALSE(stage_ptr->Init(object_map));
|
||||
}
|
||||
|
||||
@ -67,13 +85,12 @@ TEST(EvaluationStage, IncorrectlyFormattedConfig) {
|
||||
// Construct
|
||||
EvaluationStageConfig config = GetIdentityStageConfig();
|
||||
config.clear_initializers();
|
||||
config.add_initializers("INPUT_TYPE-type");
|
||||
std::unique_ptr<EvaluationStage> stage_ptr =
|
||||
CreateEvaluationStageFromConfig(config);
|
||||
config.add_initializers("DEFAULT_VALUE-default");
|
||||
std::unique_ptr<EvaluationStage> stage_ptr = EvaluationStage::Create(config);
|
||||
// Initialize
|
||||
absl::flat_hash_map<std::string, void*> object_map;
|
||||
DataType input_type = tensorflow::DT_FLOAT;
|
||||
object_map[kInputTypeName] = &input_type;
|
||||
float default_value = 1.0;
|
||||
object_map[kDefaultValueName] = &default_value;
|
||||
EXPECT_FALSE(stage_ptr->Init(object_map));
|
||||
}
|
||||
|
||||
@ -81,63 +98,76 @@ TEST(EvaluationStage, ConstructFromConfig_UnknownProcess) {
|
||||
// Construct
|
||||
EvaluationStageConfig config = GetIdentityStageConfig();
|
||||
config.mutable_specification()->clear_process_class();
|
||||
std::unique_ptr<EvaluationStage> stage_ptr =
|
||||
CreateEvaluationStageFromConfig(config);
|
||||
std::unique_ptr<EvaluationStage> stage_ptr = EvaluationStage::Create(config);
|
||||
EXPECT_EQ(stage_ptr.get(), nullptr);
|
||||
}
|
||||
|
||||
TEST(EvaluationStage, NoInitializer) {
|
||||
// Construct
|
||||
EvaluationStageConfig config = GetIdentityStageConfig();
|
||||
std::unique_ptr<EvaluationStage> stage_ptr =
|
||||
CreateEvaluationStageFromConfig(config);
|
||||
std::unique_ptr<EvaluationStage> stage_ptr = EvaluationStage::Create(config);
|
||||
// Initialize
|
||||
absl::flat_hash_map<std::string, void*> object_map;
|
||||
EXPECT_FALSE(stage_ptr->Init(object_map));
|
||||
}
|
||||
|
||||
TEST(EvaluationStage, InvalidInitializer) {
|
||||
// Construct
|
||||
EvaluationStageConfig config = GetIdentityStageConfig();
|
||||
std::unique_ptr<EvaluationStage> stage_ptr = EvaluationStage::Create(config);
|
||||
// Initialize
|
||||
absl::flat_hash_map<std::string, void*> object_map;
|
||||
object_map[kDefaultValueName] = nullptr;
|
||||
EXPECT_FALSE(stage_ptr->Init(object_map));
|
||||
}
|
||||
|
||||
TEST(EvaluationStage, NoInputs) {
|
||||
// Construct
|
||||
EvaluationStageConfig config = GetIdentityStageConfig();
|
||||
std::unique_ptr<EvaluationStage> stage_ptr =
|
||||
CreateEvaluationStageFromConfig(config);
|
||||
std::unique_ptr<EvaluationStage> stage_ptr = EvaluationStage::Create(config);
|
||||
// Initialize
|
||||
absl::flat_hash_map<std::string, void*> object_map;
|
||||
DataType input_type = tensorflow::DT_FLOAT;
|
||||
object_map[kInputTypeName] = &input_type;
|
||||
float default_value = 1.0;
|
||||
object_map[kDefaultValueName] = &default_value;
|
||||
EXPECT_TRUE(stage_ptr->Init(object_map));
|
||||
|
||||
// Run
|
||||
EvaluationStageMetrics metrics;
|
||||
EXPECT_FALSE(stage_ptr->Run(object_map, metrics));
|
||||
EXPECT_FALSE(stage_ptr->Run(object_map));
|
||||
}
|
||||
|
||||
TEST(EvaluationStage, ExpectedIdentityOutput) {
|
||||
// Construct
|
||||
EvaluationStageConfig config = GetIdentityStageConfig();
|
||||
std::unique_ptr<EvaluationStage> stage_ptr =
|
||||
CreateEvaluationStageFromConfig(config);
|
||||
std::unique_ptr<EvaluationStage> stage_ptr = EvaluationStage::Create(config);
|
||||
// Initialize
|
||||
absl::flat_hash_map<std::string, void*> object_map;
|
||||
DataType input_type = tensorflow::DT_FLOAT;
|
||||
object_map[kInputTypeName] = &input_type;
|
||||
float default_value = 1.0;
|
||||
object_map[kDefaultValueName] = &default_value;
|
||||
EXPECT_TRUE(stage_ptr->Init(object_map));
|
||||
|
||||
// Input Data
|
||||
float float_value = 5.6f;
|
||||
Tensor input_tensor(float_value);
|
||||
std::vector<Tensor> input_tensors = {input_tensor};
|
||||
float input_value = 2.0f;
|
||||
// Run
|
||||
object_map[kInputTensorsName] = &input_tensors;
|
||||
EvaluationStageMetrics metrics;
|
||||
EXPECT_TRUE(stage_ptr->Run(object_map, metrics));
|
||||
|
||||
object_map[kInputValueName] = &input_value;
|
||||
EXPECT_TRUE(stage_ptr->Run(object_map));
|
||||
EvaluationStageMetrics metrics = stage_ptr->LatestMetrics();
|
||||
// Check output
|
||||
std::vector<Tensor>* output_tensors_ptr =
|
||||
static_cast<std::vector<Tensor>*>(object_map[kOutputTensorsName]);
|
||||
EXPECT_TRUE(output_tensors_ptr != nullptr);
|
||||
EXPECT_FLOAT_EQ(output_tensors_ptr->at(0).scalar<float>()(), float_value);
|
||||
EXPECT_GE(metrics.total_latency_ms(), 0);
|
||||
float* output_value_ptr = static_cast<float*>(object_map[kOutputValueName]);
|
||||
EXPECT_NE(output_value_ptr, nullptr);
|
||||
EXPECT_FLOAT_EQ(*output_value_ptr, input_value);
|
||||
EXPECT_EQ(metrics.num_runs(), 1);
|
||||
|
||||
// Input Data
|
||||
input_value = 0.0f;
|
||||
// Run
|
||||
object_map[kInputValueName] = &input_value;
|
||||
EXPECT_TRUE(stage_ptr->Run(object_map));
|
||||
metrics = stage_ptr->LatestMetrics();
|
||||
// Check output
|
||||
output_value_ptr = static_cast<float*>(object_map[kOutputValueName]);
|
||||
EXPECT_NE(output_value_ptr, nullptr);
|
||||
EXPECT_FLOAT_EQ(*output_value_ptr, default_value);
|
||||
EXPECT_EQ(metrics.num_runs(), 2);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -14,78 +14,42 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/tools/evaluation/identity_stage.h"
|
||||
|
||||
#include <ctime>
|
||||
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace evaluation {
|
||||
|
||||
using ::tensorflow::Scope;
|
||||
using ::tensorflow::SessionOptions;
|
||||
using ::tensorflow::Tensor;
|
||||
using ::tensorflow::ops::Identity;
|
||||
using ::tensorflow::ops::Placeholder;
|
||||
|
||||
IdentityStage::IdentityStage(const EvaluationStageConfig& config)
|
||||
: EvaluationStage(config) {
|
||||
stage_input_name_ = config_.name() + "_identity_input";
|
||||
stage_output_name_ = config_.name() + "_identity_output";
|
||||
}
|
||||
|
||||
bool IdentityStage::DoInit(
|
||||
absl::flat_hash_map<std::string, void*>& object_map) {
|
||||
// Initialize TF Graph.
|
||||
const Scope scope = Scope::NewRootScope();
|
||||
if (!GetObjectFromTag(kInputTypeTag, object_map, &input_type_)) {
|
||||
float* default_value;
|
||||
if (!GetObjectFromTag(kDefaultValueTag, object_map, &default_value)) {
|
||||
return false;
|
||||
}
|
||||
auto input_placeholder =
|
||||
Placeholder(scope.WithOpName(stage_input_name_), *input_type_);
|
||||
stage_output_ =
|
||||
Identity(scope.WithOpName(stage_output_name_), input_placeholder);
|
||||
if (!scope.status().ok() || !scope.ToGraphDef(&graph_def_).ok()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Initialize TF Session.
|
||||
session_.reset(NewSession(SessionOptions()));
|
||||
if (!session_->Create(graph_def_).ok()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
default_value_ = *default_value;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IdentityStage::Run(absl::flat_hash_map<std::string, void*>& object_map,
|
||||
EvaluationStageMetrics& metrics) {
|
||||
std::vector<Tensor>* input_tensors;
|
||||
if (!GetObjectFromTag(kInputTensorsTag, object_map, &input_tensors)) {
|
||||
return false;
|
||||
}
|
||||
tensor_outputs_.clear();
|
||||
// TODO(b/122482115): Encapsulate timing into its own helper.
|
||||
std::clock_t start = std::clock();
|
||||
if (!session_
|
||||
->Run({{stage_input_name_, input_tensors->at(0)}},
|
||||
{stage_output_name_}, {}, &tensor_outputs_)
|
||||
.ok()) {
|
||||
return false;
|
||||
}
|
||||
metrics.set_total_latency_ms(
|
||||
static_cast<float>((std::clock() - start) / (CLOCKS_PER_SEC / 1000)));
|
||||
|
||||
if (!AssignObjectToTag(kOutputTensorsTag, &tensor_outputs_, object_map)) {
|
||||
return false;
|
||||
}
|
||||
bool IdentityStage::Run(absl::flat_hash_map<std::string, void*>& object_map) {
|
||||
float* current_value;
|
||||
GET_OBJECT(kInputValueTag, object_map, ¤t_value);
|
||||
current_value_ = *current_value ? *current_value : default_value_;
|
||||
ASSIGN_OBJECT(kOutputValueTag, ¤t_value_, object_map);
|
||||
++num_runs_;
|
||||
return true;
|
||||
}
|
||||
|
||||
const char IdentityStage::kInputTypeTag[] = "INPUT_TYPE";
|
||||
const char IdentityStage::kInputTensorsTag[] = "INPUT_TENSORS";
|
||||
const char IdentityStage::kOutputTensorsTag[] = "OUTPUT_TENSORS";
|
||||
EvaluationStageMetrics IdentityStage::LatestMetrics() {
|
||||
EvaluationStageMetrics metrics;
|
||||
metrics.set_num_runs(num_runs_);
|
||||
return metrics;
|
||||
}
|
||||
|
||||
const char IdentityStage::kDefaultValueTag[] = "DEFAULT_VALUE";
|
||||
const char IdentityStage::kInputValueTag[] = "INPUT_VALUE";
|
||||
const char IdentityStage::kOutputValueTag[] = "OUTPUT_VALUE";
|
||||
|
||||
DEFINE_FACTORY(IdentityStage, IDENTITY);
|
||||
|
||||
} // namespace evaluation
|
||||
} // namespace tflite
|
||||
|
@ -16,26 +16,25 @@ limitations under the License.
|
||||
#define TENSORFLOW_LITE_TOOLS_EVALUATION_IDENTITY_STAGE_H_
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
|
||||
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace evaluation {
|
||||
|
||||
// Simple EvaluationStage subclass that encapsulates the functionality of
|
||||
// tensorflow::ops::Identity. Primarily used for tests.
|
||||
// Initializer TAGs (Object Class): INPUT_TYPE (DataType)
|
||||
// Input TAGs (Object Class): INPUT_TENSORS (std::vector<Tensor>)
|
||||
// Output TAGs (Object Class): OUTPUT_TENSORS (std::vector<Tensor>)
|
||||
// TODO(b/122482115): Migrate common TF-related code into an abstract class.
|
||||
// Simple EvaluationStage that passes INPUT_VALUE to OUTPUT_VALUE if the former
|
||||
// is non-zero, DEFAULT_VALUE otherwise. Primarily used for tests.
|
||||
// Initializer TAGs (Object Class): DEFAULT_VALUE (float)
|
||||
// Input TAGs (Object Class): INPUT_VALUE (float)
|
||||
// Output TAGs (Object Class): OUTPUT_VALUE (float)
|
||||
class IdentityStage : public EvaluationStage {
|
||||
public:
|
||||
explicit IdentityStage(const EvaluationStageConfig& config);
|
||||
explicit IdentityStage(const EvaluationStageConfig& config)
|
||||
: EvaluationStage(config) {}
|
||||
|
||||
bool Run(absl::flat_hash_map<std::string, void*>& object_map,
|
||||
EvaluationStageMetrics& metrics) override;
|
||||
bool Run(absl::flat_hash_map<std::string, void*>& object_map) override;
|
||||
|
||||
EvaluationStageMetrics LatestMetrics() override;
|
||||
|
||||
~IdentityStage() {}
|
||||
|
||||
@ -43,29 +42,25 @@ class IdentityStage : public EvaluationStage {
|
||||
bool DoInit(absl::flat_hash_map<std::string, void*>& object_map) override;
|
||||
|
||||
std::vector<std::string> GetInitializerTags() override {
|
||||
return {kInputTypeTag};
|
||||
}
|
||||
std::vector<std::string> GetInputTags() override {
|
||||
return {kInputTensorsTag};
|
||||
return {kDefaultValueTag};
|
||||
}
|
||||
std::vector<std::string> GetInputTags() override { return {kInputValueTag}; }
|
||||
std::vector<std::string> GetOutputTags() override {
|
||||
return {kOutputTensorsTag};
|
||||
return {kOutputValueTag};
|
||||
}
|
||||
|
||||
private:
|
||||
::tensorflow::DataType* input_type_;
|
||||
::tensorflow::GraphDef graph_def_;
|
||||
::tensorflow::Output stage_output_;
|
||||
std::unique_ptr<::tensorflow::Session> session_;
|
||||
std::vector<::tensorflow::Tensor> tensor_outputs_;
|
||||
std::string stage_input_name_;
|
||||
std::string stage_output_name_;
|
||||
float default_value_ = 0;
|
||||
float current_value_ = 0;
|
||||
int num_runs_ = 0;
|
||||
|
||||
static const char kInputTypeTag[];
|
||||
static const char kInputTensorsTag[];
|
||||
static const char kOutputTensorsTag[];
|
||||
static const char kDefaultValueTag[];
|
||||
static const char kInputValueTag[];
|
||||
static const char kOutputValueTag[];
|
||||
};
|
||||
|
||||
DECLARE_FACTORY(IdentityStage);
|
||||
|
||||
} // namespace evaluation
|
||||
} // namespace tflite
|
||||
|
||||
|
@ -34,15 +34,19 @@ message EvaluationStageConfig {
|
||||
// Example mapping: "BITMAP1:image_in"
|
||||
// It is up to individual EvaluationStage sub-classes to specify the
|
||||
// initializer/input TAGs they require, and outputs TAGs they provide.
|
||||
// For more information: go/mlperflite-framework#heading=h.fxpk50cps4zs
|
||||
repeated string initializers = 3;
|
||||
repeated string inputs = 4;
|
||||
repeated string outputs = 5;
|
||||
}
|
||||
|
||||
// Metrics returned from EvaluationStage.LatestMetrics() need not have all
|
||||
// fields set.
|
||||
message EvaluationStageMetrics {
|
||||
// Total latency in ms.
|
||||
optional double total_latency_ms = 1;
|
||||
// Total number of times the EvaluationStage is run.
|
||||
// Aka number of calls to EvaluationStage::Run().
|
||||
optional int32 num_runs = 1;
|
||||
|
||||
// Process-specific numbers such as accuracy, step-latencies, etc.
|
||||
// Process-specific numbers such as latencies, accuracy, etc.
|
||||
optional ProcessMetrics process_metrics = 2;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user