Internal Change

PiperOrigin-RevId: 237073663
This commit is contained in:
Sachin Joglekar 2019-03-06 10:28:48 -08:00 committed by TensorFlower Gardener
parent 7a415783c6
commit e70b131052
8 changed files with 227 additions and 249 deletions

View File

@ -34,23 +34,12 @@ cc_library(
hdrs = ["evaluation_stage.h"],
copts = tflite_copts(),
deps = [
"//tensorflow/core:lib",
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_proto_cc",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "evaluation_stage_factory",
hdrs = ["evaluation_stage_factory.h"],
copts = tflite_copts(),
deps = [
":evaluation_stage",
":identity_stage",
"//tensorflow/core:tflite_portable_logging",
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_proto_cc",
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_proto_cc",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)
@ -61,26 +50,10 @@ cc_library(
copts = tflite_copts(),
deps = [
":evaluation_stage",
"@com_google_absl//absl/container:flat_hash_map",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:scope",
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_proto_cc",
] + select(
{
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib",
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_proto_cc",
"@com_google_absl//absl/container:flat_hash_map",
],
"//conditions:default": [
"//tensorflow/core:tensorflow",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:ops",
],
},
),
alwayslink = 1,
)
tf_cc_test(
@ -88,31 +61,12 @@ tf_cc_test(
srcs = ["evaluation_stage_test.cc"],
linkopts = common_linkopts,
linkstatic = 1,
tags = [
"tflite_not_portable_android",
"tflite_not_portable_ios",
],
deps = [
":evaluation_stage",
":evaluation_stage_factory",
":identity_stage",
"@com_google_googletest//:gtest_main",
"@com_google_absl//absl/container:flat_hash_map",
"//tensorflow/core:protos_all_cc",
"//tensorflow/lite/tools/evaluation/proto:evaluation_config_proto_cc",
"//tensorflow/lite/tools/evaluation/proto:evaluation_stages_proto_cc",
] + select(
{
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib",
"//tensorflow/core:android_tensorflow_test_lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_googletest//:gtest_main",
],
"//conditions:default": [
"//tensorflow/core:framework",
"//tensorflow/core:core_cpu",
"//tensorflow/core:ops",
"//tensorflow/core:tensorflow",
],
},
),
)

View File

@ -73,5 +73,9 @@ bool EvaluationStage::ProcessExpectedTags(
return true;
}
std::map<ProcessClass, FactoryFunc>*
EvaluationStage::process_class_to_factory_map_ =
new std::map<ProcessClass, FactoryFunc>();
} // namespace evaluation
} // namespace tflite

View File

@ -15,16 +15,27 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_STAGE_H_
#define TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_STAGE_H_
#include <functional>
#include <map>
#include <regex> // NOLINT
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
namespace tflite {
namespace evaluation {
class EvaluationStage;
typedef std::function<std::unique_ptr<EvaluationStage>(
const EvaluationStageConfig&)>
FactoryFunc;
// Superclass for a single stage of an EvaluationPipeline.
// Provides basic functionality for construction and accessing
// initializers/inputs/outputs.
@ -43,17 +54,42 @@ class EvaluationStage {
bool Init(absl::flat_hash_map<std::string, void*>& object_map);
// An individual run of the EvaluationStage. Returns false if there was a
// failure, true otherwise. Populates metrics into the EvaluationStageMetrics
// proto.
// failure, true otherwise.
// Init() should be called before any calls to run().
// Inputs are acquired from and outputs are written to the incoming
// object_map, using appropriate TAGs.
//
// NOTE: The EvaluationStage should maintain ownership of outputs it
// populates into object_map. Ownership of inputs will be maintained
// populates into object_map. Ownership of inputs must be maintained
// elsewhere.
virtual bool Run(absl::flat_hash_map<std::string, void*>& object_map,
EvaluationStageMetrics& metrics) = 0;
virtual bool Run(absl::flat_hash_map<std::string, void*>& object_map) = 0;
// Returns the latest metrics based on all Run() calls made so far.
virtual EvaluationStageMetrics LatestMetrics() = 0;
// The canonical way to instantiate EvaluationStages.
// Remember to call <classname>_ENABLE() first.
static std::unique_ptr<EvaluationStage> Create(
const EvaluationStageConfig& config) {
if (!config.has_specification() ||
!config.specification().has_process_class()) {
LOG(ERROR) << "Process specification not present in config: "
<< config.name();
return nullptr;
}
auto& factory_ptr =
(*GetFactoryMapPtr())[config.specification().process_class()];
if (!factory_ptr) return nullptr;
return factory_ptr(config);
}
// Used by DEFINE_REGISTRATION.
// This method takes ownership of factory.
// Should only be used via DEFINE_REGISTRATION macro.
static void RegisterStage(const ProcessClass& process_class,
FactoryFunc class_factory) {
(*GetFactoryMapPtr())[process_class] = std::move(class_factory);
}
virtual ~EvaluationStage() = default;
@ -62,8 +98,7 @@ class EvaluationStage {
// Each subclass constructor must invoke this constructor.
//
// NOTE: Do NOT use constructors to obtain new EvaluationStages. Use
// tflite::evaluation::GetEvaluationStageFromConfig from
// evaluation_stage_factory.h instead.
// EvaluationStage::Create instead.
explicit EvaluationStage(const EvaluationStageConfig& config)
: config_(config) {}
@ -90,8 +125,9 @@ class EvaluationStage {
// Populates a pointer to the object corresponding to provided TAG.
// Returns true if success, false otherwise.
// object_map must contain {name : object pointer} mappings, with one of the
// names being mapped to the expected TAG in the EvaluationStageConfig.
// object_map contain a {name : object pointer} mapping, with the
// name being mapped to the expected TAG in the EvaluationStageConfig.
// NOTE: object pointer must be non-NULL.
template <class T>
bool GetObjectFromTag(const std::string& tag,
absl::flat_hash_map<std::string, void*>& object_map,
@ -100,7 +136,7 @@ class EvaluationStage {
// Find name corresponding to TAG.
auto mapping_iter = tags_to_names_map_.find(tag);
if (mapping_iter == tags_to_names_map_.end()) {
LOG(ERROR) << "Unexpected TAG: " << tag;
LOG(ERROR) << "Unexpected TAG to GetObjectFromTag: " << tag;
return false;
}
const std::string& expected_name = mapping_iter->second;
@ -111,6 +147,10 @@ class EvaluationStage {
LOG(ERROR) << "Could not find object for name: " << expected_name;
return false;
}
if (!object_iter->second) {
LOG(ERROR) << "Found null pointer for name: " << expected_name;
return false;
}
*object_ptr = static_cast<T*>(object_iter->second);
return true;
}
@ -126,7 +166,7 @@ class EvaluationStage {
// Find name corresponding to TAG.
auto mapping_iter = tags_to_names_map_.find(tag);
if (mapping_iter == tags_to_names_map_.end()) {
LOG(ERROR) << "Unexpected TAG: " << tag;
LOG(ERROR) << "Unexpected TAG to AssignObjectToTag: " << tag;
return false;
}
const std::string& expected_name = mapping_iter->second;
@ -135,7 +175,7 @@ class EvaluationStage {
return true;
}
const EvaluationStageConfig config_;
EvaluationStageConfig config_;
private:
// Verifies that all TAGs from expected_tags are present in
@ -148,6 +188,13 @@ class EvaluationStage {
bool ProcessExpectedTags(const std::vector<std::string>& expected_tags,
std::vector<std::string>& tag_to_name_mappings);
static std::map<ProcessClass, FactoryFunc>* GetFactoryMapPtr() {
return process_class_to_factory_map_;
}
// Used by factories.
static std::map<ProcessClass, FactoryFunc>* process_class_to_factory_map_;
// Maps expected TAGs to their names as defined by the EvaluationStageConfig.
absl::flat_hash_map<std::string, std::string> tags_to_names_map_;
@ -159,6 +206,34 @@ class EvaluationStage {
const std::regex kTagPattern{"^[A-Z0-9_]+$", std::regex::optimize};
};
// Add this to headers of new EvaluationStages.
#define DECLARE_FACTORY(classname) void classname##_ENABLE();
// Add this to implementation files of new EvaluationStages.
// Call <stage_name>_ENABLE() before using EvaluationStage::Create for the
// class.
#define DEFINE_FACTORY(classname, processclass) \
void classname##_ENABLE() { \
FactoryFunc classname##Factory = [](const EvaluationStageConfig& config) { \
return absl::make_unique<classname>(config); \
}; \
EvaluationStage::RegisterStage(processclass, classname##Factory); \
}
// Use this to assign a non-nullptr pointer to tag in object_map.
#define ASSIGN_OBJECT(tag, ptr, object_map) \
if (!AssignObjectToTag(tag, ptr, object_map)) { \
return false; \
}
// Use this to obtain pointers to required object.
// Will return false if name corresponding to tag is not found, or if the
// pointer found is nullptr.
#define GET_OBJECT(tag, object_map, location) \
if (!GetObjectFromTag(tag, object_map, location)) { \
return false; \
}
} // namespace evaluation
} // namespace tflite

View File

@ -1,48 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_STAGE_FACTORY_H_
#define TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_STAGE_FACTORY_H_
#include "absl/memory/memory.h"
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
#include "tensorflow/lite/tools/evaluation/identity_stage.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
namespace tflite {
namespace evaluation {
// The canonical way to generate EvaluationStages.
// TODO(b/122482115): Implement a Factory class for registration of classes.
std::unique_ptr<EvaluationStage> CreateEvaluationStageFromConfig(
const EvaluationStageConfig& config) {
if (!config.has_specification() ||
!config.specification().has_process_class()) {
LOG(ERROR) << "Process specification not present in config: "
<< config.name();
return nullptr;
}
switch (config.specification().process_class()) {
case UNKNOWN:
return nullptr;
case IDENTITY:
return absl::make_unique<IdentityStage>(config);
}
}
} // namespace evaluation
} // namespace tflite
#endif // TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_STAGE_FACTORY_H_

View File

@ -18,9 +18,6 @@ limitations under the License.
#include <gtest/gtest.h>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/lite/tools/evaluation/evaluation_stage_factory.h"
#include "tensorflow/lite/tools/evaluation/identity_stage.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
@ -29,18 +26,16 @@ namespace tflite {
namespace evaluation {
namespace {
using ::tensorflow::DataType;
using ::tensorflow::Tensor;
constexpr char kIdentityStageName[] = "identity_stage";
constexpr char kInputTypeName[] = "type";
constexpr char kInputTensorsName[] = "in";
constexpr char kOutputTensorsName[] = "out";
constexpr char kInitializerMapping[] = "INPUT_TYPE:type";
constexpr char kInputMapping[] = "INPUT_TENSORS:in";
constexpr char kOutputMapping[] = "OUTPUT_TENSORS:out";
constexpr char kDefaultValueName[] = "default";
constexpr char kInputValueName[] = "in";
constexpr char kOutputValueName[] = "out";
constexpr char kInitializerMapping[] = "DEFAULT_VALUE:default";
constexpr char kInputMapping[] = "INPUT_VALUE:in";
constexpr char kOutputMapping[] = "OUTPUT_VALUE:out";
EvaluationStageConfig GetIdentityStageConfig() {
IdentityStage_ENABLE();
EvaluationStageConfig config;
config.set_name(kIdentityStageName);
config.mutable_specification()->set_process_class(IDENTITY);
@ -50,16 +45,39 @@ EvaluationStageConfig GetIdentityStageConfig() {
return config;
}
TEST(EvaluationStage, CreateFailsForMissingSpecification) {
// Construct
EvaluationStageConfig config;
config.set_name(kIdentityStageName);
std::unique_ptr<EvaluationStage> stage_ptr = EvaluationStage::Create(config);
EXPECT_EQ(stage_ptr, nullptr);
}
TEST(EvaluationStage, StageEnableRequired) {
// Construct
EvaluationStageConfig config;
config.set_name(kIdentityStageName);
config.mutable_specification()->set_process_class(IDENTITY);
config.add_initializers(kInitializerMapping);
config.add_inputs(kInputMapping);
config.add_outputs(kOutputMapping);
config.clear_inputs();
std::unique_ptr<EvaluationStage> stage_ptr = EvaluationStage::Create(config);
EXPECT_EQ(stage_ptr, nullptr);
IdentityStage_ENABLE();
stage_ptr = EvaluationStage::Create(config);
EXPECT_NE(stage_ptr, nullptr);
}
TEST(EvaluationStage, IncompleteConfig) {
// Construct
EvaluationStageConfig config = GetIdentityStageConfig();
config.clear_inputs();
std::unique_ptr<EvaluationStage> stage_ptr =
CreateEvaluationStageFromConfig(config);
std::unique_ptr<EvaluationStage> stage_ptr = EvaluationStage::Create(config);
// Initialize
absl::flat_hash_map<std::string, void*> object_map;
DataType input_type = tensorflow::DT_FLOAT;
object_map[kInputTypeName] = &input_type;
float default_value = 1.0;
object_map[kDefaultValueName] = &default_value;
EXPECT_FALSE(stage_ptr->Init(object_map));
}
@ -67,13 +85,12 @@ TEST(EvaluationStage, IncorrectlyFormattedConfig) {
// Construct
EvaluationStageConfig config = GetIdentityStageConfig();
config.clear_initializers();
config.add_initializers("INPUT_TYPE-type");
std::unique_ptr<EvaluationStage> stage_ptr =
CreateEvaluationStageFromConfig(config);
config.add_initializers("DEFAULT_VALUE-default");
std::unique_ptr<EvaluationStage> stage_ptr = EvaluationStage::Create(config);
// Initialize
absl::flat_hash_map<std::string, void*> object_map;
DataType input_type = tensorflow::DT_FLOAT;
object_map[kInputTypeName] = &input_type;
float default_value = 1.0;
object_map[kDefaultValueName] = &default_value;
EXPECT_FALSE(stage_ptr->Init(object_map));
}
@ -81,63 +98,76 @@ TEST(EvaluationStage, ConstructFromConfig_UnknownProcess) {
// Construct
EvaluationStageConfig config = GetIdentityStageConfig();
config.mutable_specification()->clear_process_class();
std::unique_ptr<EvaluationStage> stage_ptr =
CreateEvaluationStageFromConfig(config);
std::unique_ptr<EvaluationStage> stage_ptr = EvaluationStage::Create(config);
EXPECT_EQ(stage_ptr.get(), nullptr);
}
TEST(EvaluationStage, NoInitializer) {
// Construct
EvaluationStageConfig config = GetIdentityStageConfig();
std::unique_ptr<EvaluationStage> stage_ptr =
CreateEvaluationStageFromConfig(config);
std::unique_ptr<EvaluationStage> stage_ptr = EvaluationStage::Create(config);
// Initialize
absl::flat_hash_map<std::string, void*> object_map;
EXPECT_FALSE(stage_ptr->Init(object_map));
}
TEST(EvaluationStage, InvalidInitializer) {
// Construct
EvaluationStageConfig config = GetIdentityStageConfig();
std::unique_ptr<EvaluationStage> stage_ptr = EvaluationStage::Create(config);
// Initialize
absl::flat_hash_map<std::string, void*> object_map;
object_map[kDefaultValueName] = nullptr;
EXPECT_FALSE(stage_ptr->Init(object_map));
}
TEST(EvaluationStage, NoInputs) {
// Construct
EvaluationStageConfig config = GetIdentityStageConfig();
std::unique_ptr<EvaluationStage> stage_ptr =
CreateEvaluationStageFromConfig(config);
std::unique_ptr<EvaluationStage> stage_ptr = EvaluationStage::Create(config);
// Initialize
absl::flat_hash_map<std::string, void*> object_map;
DataType input_type = tensorflow::DT_FLOAT;
object_map[kInputTypeName] = &input_type;
float default_value = 1.0;
object_map[kDefaultValueName] = &default_value;
EXPECT_TRUE(stage_ptr->Init(object_map));
// Run
EvaluationStageMetrics metrics;
EXPECT_FALSE(stage_ptr->Run(object_map, metrics));
EXPECT_FALSE(stage_ptr->Run(object_map));
}
TEST(EvaluationStage, ExpectedIdentityOutput) {
// Construct
EvaluationStageConfig config = GetIdentityStageConfig();
std::unique_ptr<EvaluationStage> stage_ptr =
CreateEvaluationStageFromConfig(config);
std::unique_ptr<EvaluationStage> stage_ptr = EvaluationStage::Create(config);
// Initialize
absl::flat_hash_map<std::string, void*> object_map;
DataType input_type = tensorflow::DT_FLOAT;
object_map[kInputTypeName] = &input_type;
float default_value = 1.0;
object_map[kDefaultValueName] = &default_value;
EXPECT_TRUE(stage_ptr->Init(object_map));
// Input Data
float float_value = 5.6f;
Tensor input_tensor(float_value);
std::vector<Tensor> input_tensors = {input_tensor};
float input_value = 2.0f;
// Run
object_map[kInputTensorsName] = &input_tensors;
EvaluationStageMetrics metrics;
EXPECT_TRUE(stage_ptr->Run(object_map, metrics));
object_map[kInputValueName] = &input_value;
EXPECT_TRUE(stage_ptr->Run(object_map));
EvaluationStageMetrics metrics = stage_ptr->LatestMetrics();
// Check output
std::vector<Tensor>* output_tensors_ptr =
static_cast<std::vector<Tensor>*>(object_map[kOutputTensorsName]);
EXPECT_TRUE(output_tensors_ptr != nullptr);
EXPECT_FLOAT_EQ(output_tensors_ptr->at(0).scalar<float>()(), float_value);
EXPECT_GE(metrics.total_latency_ms(), 0);
float* output_value_ptr = static_cast<float*>(object_map[kOutputValueName]);
EXPECT_NE(output_value_ptr, nullptr);
EXPECT_FLOAT_EQ(*output_value_ptr, input_value);
EXPECT_EQ(metrics.num_runs(), 1);
// Input Data
input_value = 0.0f;
// Run
object_map[kInputValueName] = &input_value;
EXPECT_TRUE(stage_ptr->Run(object_map));
metrics = stage_ptr->LatestMetrics();
// Check output
output_value_ptr = static_cast<float*>(object_map[kOutputValueName]);
EXPECT_NE(output_value_ptr, nullptr);
EXPECT_FLOAT_EQ(*output_value_ptr, default_value);
EXPECT_EQ(metrics.num_runs(), 2);
}
} // namespace

View File

@ -14,78 +14,42 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/evaluation/identity_stage.h"
#include <ctime>
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
namespace tflite {
namespace evaluation {
using ::tensorflow::Scope;
using ::tensorflow::SessionOptions;
using ::tensorflow::Tensor;
using ::tensorflow::ops::Identity;
using ::tensorflow::ops::Placeholder;
IdentityStage::IdentityStage(const EvaluationStageConfig& config)
: EvaluationStage(config) {
stage_input_name_ = config_.name() + "_identity_input";
stage_output_name_ = config_.name() + "_identity_output";
}
bool IdentityStage::DoInit(
absl::flat_hash_map<std::string, void*>& object_map) {
// Initialize TF Graph.
const Scope scope = Scope::NewRootScope();
if (!GetObjectFromTag(kInputTypeTag, object_map, &input_type_)) {
float* default_value;
if (!GetObjectFromTag(kDefaultValueTag, object_map, &default_value)) {
return false;
}
auto input_placeholder =
Placeholder(scope.WithOpName(stage_input_name_), *input_type_);
stage_output_ =
Identity(scope.WithOpName(stage_output_name_), input_placeholder);
if (!scope.status().ok() || !scope.ToGraphDef(&graph_def_).ok()) {
return false;
}
// Initialize TF Session.
session_.reset(NewSession(SessionOptions()));
if (!session_->Create(graph_def_).ok()) {
return false;
}
default_value_ = *default_value;
return true;
}
bool IdentityStage::Run(absl::flat_hash_map<std::string, void*>& object_map,
EvaluationStageMetrics& metrics) {
std::vector<Tensor>* input_tensors;
if (!GetObjectFromTag(kInputTensorsTag, object_map, &input_tensors)) {
return false;
}
tensor_outputs_.clear();
// TODO(b/122482115): Encapsulate timing into its own helper.
std::clock_t start = std::clock();
if (!session_
->Run({{stage_input_name_, input_tensors->at(0)}},
{stage_output_name_}, {}, &tensor_outputs_)
.ok()) {
return false;
}
metrics.set_total_latency_ms(
static_cast<float>((std::clock() - start) / (CLOCKS_PER_SEC / 1000)));
if (!AssignObjectToTag(kOutputTensorsTag, &tensor_outputs_, object_map)) {
return false;
}
bool IdentityStage::Run(absl::flat_hash_map<std::string, void*>& object_map) {
float* current_value;
GET_OBJECT(kInputValueTag, object_map, &current_value);
current_value_ = *current_value ? *current_value : default_value_;
ASSIGN_OBJECT(kOutputValueTag, &current_value_, object_map);
++num_runs_;
return true;
}
const char IdentityStage::kInputTypeTag[] = "INPUT_TYPE";
const char IdentityStage::kInputTensorsTag[] = "INPUT_TENSORS";
const char IdentityStage::kOutputTensorsTag[] = "OUTPUT_TENSORS";
EvaluationStageMetrics IdentityStage::LatestMetrics() {
EvaluationStageMetrics metrics;
metrics.set_num_runs(num_runs_);
return metrics;
}
const char IdentityStage::kDefaultValueTag[] = "DEFAULT_VALUE";
const char IdentityStage::kInputValueTag[] = "INPUT_VALUE";
const char IdentityStage::kOutputValueTag[] = "OUTPUT_VALUE";
DEFINE_FACTORY(IdentityStage, IDENTITY);
} // namespace evaluation
} // namespace tflite

View File

@ -16,26 +16,25 @@ limitations under the License.
#define TENSORFLOW_LITE_TOOLS_EVALUATION_IDENTITY_STAGE_H_
#include "absl/container/flat_hash_map.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
namespace tflite {
namespace evaluation {
// Simple EvaluationStage subclass that encapsulates the functionality of
// tensorflow::ops::Identity. Primarily used for tests.
// Initializer TAGs (Object Class): INPUT_TYPE (DataType)
// Input TAGs (Object Class): INPUT_TENSORS (std::vector<Tensor>)
// Output TAGs (Object Class): OUTPUT_TENSORS (std::vector<Tensor>)
// TODO(b/122482115): Migrate common TF-related code into an abstract class.
// Simple EvaluationStage that passes INPUT_VALUE to OUTPUT_VALUE if the former
// is non-zero, DEFAULT_VALUE otherwise. Primarily used for tests.
// Initializer TAGs (Object Class): DEFAULT_VALUE (float)
// Input TAGs (Object Class): INPUT_VALUE (float)
// Output TAGs (Object Class): OUTPUT_VALUE (float)
class IdentityStage : public EvaluationStage {
public:
explicit IdentityStage(const EvaluationStageConfig& config);
explicit IdentityStage(const EvaluationStageConfig& config)
: EvaluationStage(config) {}
bool Run(absl::flat_hash_map<std::string, void*>& object_map,
EvaluationStageMetrics& metrics) override;
bool Run(absl::flat_hash_map<std::string, void*>& object_map) override;
EvaluationStageMetrics LatestMetrics() override;
~IdentityStage() {}
@ -43,29 +42,25 @@ class IdentityStage : public EvaluationStage {
bool DoInit(absl::flat_hash_map<std::string, void*>& object_map) override;
std::vector<std::string> GetInitializerTags() override {
return {kInputTypeTag};
}
std::vector<std::string> GetInputTags() override {
return {kInputTensorsTag};
return {kDefaultValueTag};
}
std::vector<std::string> GetInputTags() override { return {kInputValueTag}; }
std::vector<std::string> GetOutputTags() override {
return {kOutputTensorsTag};
return {kOutputValueTag};
}
private:
::tensorflow::DataType* input_type_;
::tensorflow::GraphDef graph_def_;
::tensorflow::Output stage_output_;
std::unique_ptr<::tensorflow::Session> session_;
std::vector<::tensorflow::Tensor> tensor_outputs_;
std::string stage_input_name_;
std::string stage_output_name_;
float default_value_ = 0;
float current_value_ = 0;
int num_runs_ = 0;
static const char kInputTypeTag[];
static const char kInputTensorsTag[];
static const char kOutputTensorsTag[];
static const char kDefaultValueTag[];
static const char kInputValueTag[];
static const char kOutputValueTag[];
};
DECLARE_FACTORY(IdentityStage);
} // namespace evaluation
} // namespace tflite

View File

@ -34,15 +34,19 @@ message EvaluationStageConfig {
// Example mapping: "BITMAP1:image_in"
// It is up to individual EvaluationStage sub-classes to specify the
// initializer/input TAGs they require, and outputs TAGs they provide.
// For more information: go/mlperflite-framework#heading=h.fxpk50cps4zs
repeated string initializers = 3;
repeated string inputs = 4;
repeated string outputs = 5;
}
// Metrics returned from EvaluationStage.LatestMetrics() need not have all
// fields set.
message EvaluationStageMetrics {
// Total latency in ms.
optional double total_latency_ms = 1;
// Total number of times the EvaluationStage is run.
// Aka number of calls to EvaluationStage::Run().
optional int32 num_runs = 1;
// Process-specific numbers such as accuracy, step-latencies, etc.
// Process-specific numbers such as latencies, accuracy, etc.
optional ProcessMetrics process_metrics = 2;
}