One more refactoring of kernel.cc. This introduces OpInputs to match OpOutputs, which is slightly enhanced.

PiperOrigin-RevId: 227908106
This commit is contained in:
A. Unique TensorFlower 2019-01-04 13:49:49 -08:00 committed by TensorFlower Gardener
parent cea2e9c4fa
commit 4614b04cc0

View File

@ -51,24 +51,57 @@ namespace tflite {
namespace flex { namespace flex {
namespace kernel { namespace kernel {
// Controls the lifetime of tensor handles in a vector. // A list of inputs of a given node of the TensorFlow/Eager graph.
class OpInputs {
public:
explicit OpInputs(const TfLiteIntArray* indexes) {
for (int index : TfLiteIntArrayView(indexes)) {
inputs_.push_back(index);
}
}
~OpInputs() {}
int Size() const { return inputs_.size(); }
int TfLiteIndex(int i) const { return inputs_[i]; }
private:
std::vector<int> inputs_;
};
// A list of outputs of a given node of the TensorFlow/Eager graph, along with
// the actual outputs of the EagerOperation.
class OpOutputs { class OpOutputs {
public: public:
explicit OpOutputs(int num_elements) : vector_(num_elements, nullptr) {} explicit OpOutputs(const TfLiteIntArray* indexes) {
for (int index : TfLiteIntArrayView(indexes)) {
outputs_.push_back(index);
}
vector_.resize(outputs_.size());
}
~OpOutputs() { ResetTensorHandles(); }
~OpOutputs() { int Size() const { return outputs_.size(); }
for (auto* handle : vector_) {
if (handle) handle->Unref(); int TfLiteIndex(int i) const { return outputs_[i]; }
// Carefully unreference all the handles in the eager output vector.
void ResetTensorHandles() {
for (int i = 0; i < vector_.size(); ++i) {
if (vector_[i]) {
vector_[i]->Unref();
vector_[i] = nullptr;
}
} }
} }
tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2>* GetVector() { tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2>*
GetTensorHandles() {
return &vector_; return &vector_;
} }
tensorflow::TensorHandle* GetHandle(int index) { return vector_[index]; }
private: private:
std::vector<int> outputs_;
tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> vector_; tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> vector_;
}; };
@ -76,7 +109,8 @@ class OpOutputs {
// TensorFlow ops within a single TF Lite op. // TensorFlow ops within a single TF Lite op.
class OpNode { class OpNode {
public: public:
OpNode() {} OpNode(const TfLiteIntArray* inputs, const TfLiteIntArray* outputs)
: inputs_(inputs), outputs_(outputs) {}
~OpNode() {} ~OpNode() {}
const string& name() const { return name_; } const string& name() const { return name_; }
@ -87,19 +121,14 @@ class OpNode {
const tensorflow::NodeDef& nodedef() const { return nodedef_; } const tensorflow::NodeDef& nodedef() const { return nodedef_; }
const std::vector<int>& inputs() const { return inputs_; } const OpInputs& inputs() const { return inputs_; }
void InitializeInputs(const TfLiteIntArray* inputs) { OpInputs* mutable_inputs() { return &inputs_; }
for (int index : TfLiteIntArrayView(inputs)) {
inputs_.push_back(index);
}
}
const std::vector<int>& outputs() const { return outputs_; } const OpOutputs& outputs() const { return outputs_; }
void InitializeOutputs(const TfLiteIntArray* outputs) { OpOutputs* mutable_outputs() { return &outputs_; }
for (int index : TfLiteIntArrayView(outputs)) {
outputs_.push_back(index); int NumInputs() const { return inputs_.Size(); }
} int NumOutputs() const { return outputs_.Size(); }
}
tensorflow::Status InitializeNodeDef(const void* custom_initial_data, tensorflow::Status InitializeNodeDef(const void* custom_initial_data,
int custom_initial_data_size) { int custom_initial_data_size) {
@ -162,7 +191,8 @@ class OpNode {
tensorflow::Status BuildEagerInputs(BufferMap* buffer_map, tensorflow::Status BuildEagerInputs(BufferMap* buffer_map,
tensorflow::EagerOperation* op) { tensorflow::EagerOperation* op) {
for (int input_index : inputs_) { for (int i = 0; i < inputs_.Size(); ++i) {
int input_index = inputs_.TfLiteIndex(i);
if (!buffer_map->HasTensor(input_index)) { if (!buffer_map->HasTensor(input_index)) {
return tensorflow::errors::Internal( return tensorflow::errors::Internal(
"Cannot read from invalid tensor index ", input_index); "Cannot read from invalid tensor index ", input_index);
@ -181,17 +211,21 @@ class OpNode {
return tensorflow::Status::OK(); return tensorflow::Status::OK();
} }
tensorflow::Status PersistEagerOutputs(BufferMap* buffer_map, tensorflow::Status PersistEagerOutputs(BufferMap* buffer_map) {
OpOutputs* retvals) { auto* handles = outputs_.GetTensorHandles();
for (int i = 0; i < outputs_.size(); ++i) { for (int i = 0; i < outputs_.Size(); ++i) {
const tensorflow::Tensor* tensor = nullptr; const tensorflow::Tensor* tensor = nullptr;
TF_RETURN_IF_ERROR(retvals->GetHandle(i)->Tensor(&tensor)); TF_RETURN_IF_ERROR(handles->at(i)->Tensor(&tensor));
buffer_map->SetFromTensorFlow(outputs_[i], *tensor); buffer_map->SetFromTensorFlow(outputs_.TfLiteIndex(i), *tensor);
} }
outputs_.ResetTensorHandles();
return tensorflow::Status::OK(); return tensorflow::Status::OK();
} }
private: private:
OpNode(const OpNode&) = delete;
OpNode& operator=(const OpNode&) = delete;
// The name of the TensorFlow op to execute. // The name of the TensorFlow op to execute.
string name_; string name_;
// Index of this node into TF Lite's operator list. // Index of this node into TF Lite's operator list.
@ -199,9 +233,9 @@ class OpNode {
// The corresponding NodeDef, containing the attributes for the op. // The corresponding NodeDef, containing the attributes for the op.
tensorflow::NodeDef nodedef_; tensorflow::NodeDef nodedef_;
// List of inputs, as TF Lite tensor indices. // List of inputs, as TF Lite tensor indices.
std::vector<int> inputs_; OpInputs inputs_;
// List of outputs, as TF Lite tensor indices. // List of outputs, as TF Lite tensor indices.
std::vector<int> outputs_; OpOutputs outputs_;
}; };
// Executes the TensorFlow op given by 'op_name', with the attributes specified // Executes the TensorFlow op given by 'op_name', with the attributes specified
@ -212,18 +246,18 @@ tensorflow::Status ExecuteFlexOp(tensorflow::EagerContext* eager_context,
TF_RETURN_IF_ERROR(node_data->BuildEagerOp(eager_context, &op)); TF_RETURN_IF_ERROR(node_data->BuildEagerOp(eager_context, &op));
TF_RETURN_IF_ERROR(node_data->BuildEagerInputs(buffer_map, op.get())); TF_RETURN_IF_ERROR(node_data->BuildEagerInputs(buffer_map, op.get()));
int num_retvals = node_data->outputs().size(); int num_retvals = node_data->NumOutputs();
OpOutputs retvals(num_retvals);
TF_RETURN_WITH_CONTEXT_IF_ERROR( TF_RETURN_WITH_CONTEXT_IF_ERROR(
EagerExecute(op.get(), retvals.GetVector(), &num_retvals), EagerExecute(op.get(), node_data->mutable_outputs()->GetTensorHandles(),
&num_retvals),
" (while executing '", node_data->name(), "' via Eager)"); " (while executing '", node_data->name(), "' via Eager)");
if (num_retvals != node_data->outputs().size()) { if (num_retvals != node_data->NumOutputs()) {
return tensorflow::errors::Internal( return tensorflow::errors::Internal(
"Unexpected number of outputs from EagerExecute"); "Unexpected number of outputs from EagerExecute");
} }
TF_RETURN_IF_ERROR(node_data->PersistEagerOutputs(buffer_map, &retvals)); TF_RETURN_IF_ERROR(node_data->PersistEagerOutputs(buffer_map));
return tensorflow::Status::OK(); return tensorflow::Status::OK();
} }
@ -232,7 +266,7 @@ tensorflow::Status ExecuteFlexOp(tensorflow::EagerContext* eager_context,
struct OpData { struct OpData {
tensorflow::EagerContext* eager_context; tensorflow::EagerContext* eager_context;
BufferMap* buffer_map; BufferMap* buffer_map;
std::vector<OpNode> nodes; std::vector<std::unique_ptr<OpNode>> nodes;
std::vector<int> subgraph_inputs; std::vector<int> subgraph_inputs;
std::vector<int> subgraph_outputs; std::vector<int> subgraph_outputs;
}; };
@ -261,6 +295,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
op_data->subgraph_inputs.push_back(tensor_index); op_data->subgraph_inputs.push_back(tensor_index);
} }
op_data->nodes.reserve(params->nodes_to_replace->size);
CHECK(params->nodes_to_replace); CHECK(params->nodes_to_replace);
tensorflow::Status status; tensorflow::Status status;
for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) { for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) {
@ -268,8 +304,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TfLiteRegistration* reg; TfLiteRegistration* reg;
context->GetNodeAndRegistration(context, node_index, &node, &reg); context->GetNodeAndRegistration(context, node_index, &node, &reg);
op_data->nodes.push_back(OpNode()); op_data->nodes.emplace_back(new OpNode(node->inputs, node->outputs));
OpNode& node_data = op_data->nodes.back(); OpNode& node_data = *op_data->nodes.back();
node_data.set_index(node_index); node_data.set_index(node_index);
node_data.set_name(""); node_data.set_name("");
@ -277,9 +313,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
status = node_data.InitializeNodeDef(node->custom_initial_data, status = node_data.InitializeNodeDef(node->custom_initial_data,
node->custom_initial_data_size); node->custom_initial_data_size);
if (!status.ok()) break; if (!status.ok()) break;
node_data.InitializeInputs(node->inputs);
node_data.InitializeOutputs(node->outputs);
} }
if (ConvertStatus(context, status) != kTfLiteOk) { if (ConvertStatus(context, status) != kTfLiteOk) {
@ -328,14 +361,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} }
for (const auto& node_data : op_data->nodes) { for (const auto& node_data : op_data->nodes) {
if (node_data.nodedef().op().empty()) { if (node_data->nodedef().op().empty()) {
context->ReportError(context, "Invalid NodeDef in Flex op '%s'", context->ReportError(context, "Invalid NodeDef in Flex op '%s'",
node_data.name().c_str()); node_data->name().c_str());
return kTfLiteError; return kTfLiteError;
} }
for (int tensor_index : node_data.inputs()) { for (int i = 0; i < node_data->inputs().Size(); ++i) {
++tensor_ref_count[tensor_index]; ++tensor_ref_count[node_data->inputs().TfLiteIndex(i)];
} }
} }
@ -371,12 +404,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} }
// Execute the TensorFlow Ops sequentially. // Execute the TensorFlow Ops sequentially.
for (OpNode& node_data : op_data->nodes) { for (auto& node_data : op_data->nodes) {
SCOPED_TAGGED_OPERATOR_PROFILE( SCOPED_TAGGED_OPERATOR_PROFILE(
reinterpret_cast<profiling::Profiler*>(context->profiler), reinterpret_cast<profiling::Profiler*>(context->profiler),
node_data.name().c_str(), node_data.index()); node_data->name().c_str(), node_data->index());
auto status = ExecuteFlexOp(eager_context, buffer_map, &node_data); auto status = ExecuteFlexOp(eager_context, buffer_map, node_data.get());
TF_LITE_ENSURE_OK(context, ConvertStatus(context, status)); TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
} }