One more refactoring of kernel.cc. This introduces OpInputs to match OpOutputs, which is slightly enhanced.

PiperOrigin-RevId: 227908106
A. Unique TensorFlower 2019-01-04 13:49:49 -08:00 committed by TensorFlower Gardener
parent cea2e9c4fa
commit 4614b04cc0
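To make the shape of the new interface concrete outside the TF Lite tree, here is a minimal, self-contained sketch of the same pattern the change introduces: index lists wrapped behind a small Size()/TfLiteIndex() interface, iterated by position instead of ranging over a raw std::vector<int>. MiniIntArray and MiniOpInputs are hypothetical stand-ins, not the real TfLiteIntArray, TfLiteIntArrayView, or OpInputs types from kernel.cc.

// Hypothetical stand-ins for illustration only; the real code uses
// TfLiteIntArray, TfLiteIntArrayView, and the OpInputs class in this diff.
#include <iostream>
#include <vector>

struct MiniIntArray {  // stand-in for TfLiteIntArray
  std::vector<int> data;
};

class MiniOpInputs {  // mirrors the OpInputs interface introduced here
 public:
  explicit MiniOpInputs(const MiniIntArray& indexes)
      : inputs_(indexes.data.begin(), indexes.data.end()) {}

  int Size() const { return static_cast<int>(inputs_.size()); }
  int TfLiteIndex(int i) const { return inputs_[i]; }

 private:
  std::vector<int> inputs_;  // TF Lite tensor indices
};

int main() {
  MiniIntArray node_inputs{{3, 7, 9}};
  MiniOpInputs inputs(node_inputs);
  // Callers loop by position, as the new BuildEagerInputs does.
  for (int i = 0; i < inputs.Size(); ++i) {
    std::cout << "input " << i << " -> TF Lite tensor "
              << inputs.TfLiteIndex(i) << "\n";
  }
  return 0;
}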


@@ -51,24 +51,57 @@ namespace tflite {
 namespace flex {
 namespace kernel {
 
-// Controls the lifetime of tensor handles in a vector.
+// A list of inputs of a given node of the TensorFlow/Eager graph.
+class OpInputs {
+ public:
+  explicit OpInputs(const TfLiteIntArray* indexes) {
+    for (int index : TfLiteIntArrayView(indexes)) {
+      inputs_.push_back(index);
+    }
+  }
+  ~OpInputs() {}
+
+  int Size() const { return inputs_.size(); }
+
+  int TfLiteIndex(int i) const { return inputs_[i]; }
+
+ private:
+  std::vector<int> inputs_;
+};
+
+// A list of outputs of a given node of the TensorFlow/Eager graph, along with
+// the actual outputs of the EagerOperation.
 class OpOutputs {
  public:
-  explicit OpOutputs(int num_elements) : vector_(num_elements, nullptr) {}
-  ~OpOutputs() {
-    for (auto* handle : vector_) {
-      if (handle) handle->Unref();
+  explicit OpOutputs(const TfLiteIntArray* indexes) {
+    for (int index : TfLiteIntArrayView(indexes)) {
+      outputs_.push_back(index);
     }
+    vector_.resize(outputs_.size());
   }
+  ~OpOutputs() { ResetTensorHandles(); }
+
+  int Size() const { return outputs_.size(); }
+
+  int TfLiteIndex(int i) const { return outputs_[i]; }
+
+  // Carefully unreference all the handles in the eager output vector.
+  void ResetTensorHandles() {
+    for (int i = 0; i < vector_.size(); ++i) {
+      if (vector_[i]) {
+        vector_[i]->Unref();
+        vector_[i] = nullptr;
+      }
+    }
+  }
 
-  tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2>* GetVector() {
+  tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2>*
+  GetTensorHandles() {
     return &vector_;
   }
 
   tensorflow::TensorHandle* GetHandle(int index) { return vector_[index]; }
 
  private:
+  std::vector<int> outputs_;
   tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> vector_;
 };
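The Unref-and-null step in ResetTensorHandles is what lets the routine run more than once: the new PersistEagerOutputs (further down in this diff) resets the handles after persisting the tensors, and the destructor calls the same routine again, so already-cleared slots must be skipped. A minimal, self-contained sketch of that idea, with a hypothetical RefCounted type standing in for tensorflow::TensorHandle:

// Hypothetical RefCounted stand-in for tensorflow::TensorHandle,
// showing why each slot is nulled after Unref.
#include <cassert>
#include <vector>

struct RefCounted {
  int refs = 1;
  void Unref() {
    assert(refs > 0 && "double Unref");
    --refs;
  }
};

void ResetHandles(std::vector<RefCounted*>* handles) {
  for (auto& h : *handles) {
    if (h) {
      h->Unref();
      h = nullptr;  // a later Reset (e.g. from the destructor) is now a no-op
    }
  }
}

int main() {
  RefCounted a, b;
  std::vector<RefCounted*> handles = {&a, &b};
  ResetHandles(&handles);  // after the outputs have been persisted
  ResetHandles(&handles);  // called again from the destructor: safe
  assert(a.refs == 0 && b.refs == 0);
  return 0;
}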
@@ -76,7 +109,8 @@ class OpOutputs {
 // TensorFlow ops within a single TF Lite op.
 class OpNode {
  public:
-  OpNode() {}
+  OpNode(const TfLiteIntArray* inputs, const TfLiteIntArray* outputs)
+      : inputs_(inputs), outputs_(outputs) {}
   ~OpNode() {}
 
   const string& name() const { return name_; }
@@ -87,19 +121,14 @@ class OpNode {
   const tensorflow::NodeDef& nodedef() const { return nodedef_; }
 
-  const std::vector<int>& inputs() const { return inputs_; }
-  void InitializeInputs(const TfLiteIntArray* inputs) {
-    for (int index : TfLiteIntArrayView(inputs)) {
-      inputs_.push_back(index);
-    }
-  }
+  const OpInputs& inputs() const { return inputs_; }
+  OpInputs* mutable_inputs() { return &inputs_; }
 
-  const std::vector<int>& outputs() const { return outputs_; }
-  void InitializeOutputs(const TfLiteIntArray* outputs) {
-    for (int index : TfLiteIntArrayView(outputs)) {
-      outputs_.push_back(index);
-    }
-  }
+  const OpOutputs& outputs() const { return outputs_; }
+  OpOutputs* mutable_outputs() { return &outputs_; }
+
+  int NumInputs() const { return inputs_.Size(); }
+  int NumOutputs() const { return outputs_.Size(); }
 
   tensorflow::Status InitializeNodeDef(const void* custom_initial_data,
                                        int custom_initial_data_size) {
@@ -162,7 +191,8 @@ class OpNode {
   tensorflow::Status BuildEagerInputs(BufferMap* buffer_map,
                                       tensorflow::EagerOperation* op) {
-    for (int input_index : inputs_) {
+    for (int i = 0; i < inputs_.Size(); ++i) {
+      int input_index = inputs_.TfLiteIndex(i);
       if (!buffer_map->HasTensor(input_index)) {
         return tensorflow::errors::Internal(
             "Cannot read from invalid tensor index ", input_index);
@@ -181,17 +211,21 @@ class OpNode {
     return tensorflow::Status::OK();
   }
 
-  tensorflow::Status PersistEagerOutputs(BufferMap* buffer_map,
-                                         OpOutputs* retvals) {
-    for (int i = 0; i < outputs_.size(); ++i) {
+  tensorflow::Status PersistEagerOutputs(BufferMap* buffer_map) {
+    auto* handles = outputs_.GetTensorHandles();
+    for (int i = 0; i < outputs_.Size(); ++i) {
       const tensorflow::Tensor* tensor = nullptr;
-      TF_RETURN_IF_ERROR(retvals->GetHandle(i)->Tensor(&tensor));
-      buffer_map->SetFromTensorFlow(outputs_[i], *tensor);
+      TF_RETURN_IF_ERROR(handles->at(i)->Tensor(&tensor));
+      buffer_map->SetFromTensorFlow(outputs_.TfLiteIndex(i), *tensor);
     }
+    outputs_.ResetTensorHandles();
     return tensorflow::Status::OK();
   }
 
  private:
+  OpNode(const OpNode&) = delete;
+  OpNode& operator=(const OpNode&) = delete;
+
   // The name of the TensorFlow op to execute.
   string name_;
   // Index of this node into TF Lite's operator list.
@@ -199,9 +233,9 @@ class OpNode {
   // The corresponding NodeDef, containing the attributes for the op.
   tensorflow::NodeDef nodedef_;
   // List of inputs, as TF Lite tensor indices.
-  std::vector<int> inputs_;
+  OpInputs inputs_;
   // List of outputs, as TF Lite tensor indices.
-  std::vector<int> outputs_;
+  OpOutputs outputs_;
 };
 
 // Executes the TensorFlow op given by 'op_name', with the attributes specified
@@ -212,18 +246,18 @@ tensorflow::Status ExecuteFlexOp(tensorflow::EagerContext* eager_context,
   TF_RETURN_IF_ERROR(node_data->BuildEagerOp(eager_context, &op));
   TF_RETURN_IF_ERROR(node_data->BuildEagerInputs(buffer_map, op.get()));
 
-  int num_retvals = node_data->outputs().size();
-  OpOutputs retvals(num_retvals);
+  int num_retvals = node_data->NumOutputs();
 
   TF_RETURN_WITH_CONTEXT_IF_ERROR(
-      EagerExecute(op.get(), retvals.GetVector(), &num_retvals),
+      EagerExecute(op.get(), node_data->mutable_outputs()->GetTensorHandles(),
+                   &num_retvals),
       " (while executing '", node_data->name(), "' via Eager)");
 
-  if (num_retvals != node_data->outputs().size()) {
+  if (num_retvals != node_data->NumOutputs()) {
     return tensorflow::errors::Internal(
         "Unexpected number of outputs from EagerExecute");
   }
 
-  TF_RETURN_IF_ERROR(node_data->PersistEagerOutputs(buffer_map, &retvals));
+  TF_RETURN_IF_ERROR(node_data->PersistEagerOutputs(buffer_map));
 
   return tensorflow::Status::OK();
 }
@@ -232,7 +266,7 @@ tensorflow::Status ExecuteFlexOp(tensorflow::EagerContext* eager_context,
 struct OpData {
   tensorflow::EagerContext* eager_context;
   BufferMap* buffer_map;
-  std::vector<OpNode> nodes;
+  std::vector<std::unique_ptr<OpNode>> nodes;
   std::vector<int> subgraph_inputs;
   std::vector<int> subgraph_outputs;
 };
@@ -261,6 +295,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
     op_data->subgraph_inputs.push_back(tensor_index);
   }
 
+  op_data->nodes.reserve(params->nodes_to_replace->size);
+
   CHECK(params->nodes_to_replace);
   tensorflow::Status status;
   for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) {
@@ -268,8 +304,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
     TfLiteRegistration* reg;
     context->GetNodeAndRegistration(context, node_index, &node, &reg);
 
-    op_data->nodes.push_back(OpNode());
-    OpNode& node_data = op_data->nodes.back();
+    op_data->nodes.emplace_back(new OpNode(node->inputs, node->outputs));
+    OpNode& node_data = *op_data->nodes.back();
 
     node_data.set_index(node_index);
     node_data.set_name("");
@@ -277,9 +313,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
     status = node_data.InitializeNodeDef(node->custom_initial_data,
                                          node->custom_initial_data_size);
     if (!status.ok()) break;
-
-    node_data.InitializeInputs(node->inputs);
-    node_data.InitializeOutputs(node->outputs);
   }
 
   if (ConvertStatus(context, status) != kTfLiteOk) {
@@ -328,14 +361,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   }
 
   for (const auto& node_data : op_data->nodes) {
-    if (node_data.nodedef().op().empty()) {
+    if (node_data->nodedef().op().empty()) {
       context->ReportError(context, "Invalid NodeDef in Flex op '%s'",
-                           node_data.name().c_str());
+                           node_data->name().c_str());
       return kTfLiteError;
     }
-    for (int tensor_index : node_data.inputs()) {
-      ++tensor_ref_count[tensor_index];
+    for (int i = 0; i < node_data->inputs().Size(); ++i) {
+      ++tensor_ref_count[node_data->inputs().TfLiteIndex(i)];
     }
   }
@@ -371,12 +404,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   }
 
   // Execute the TensorFlow Ops sequentially.
-  for (OpNode& node_data : op_data->nodes) {
+  for (auto& node_data : op_data->nodes) {
     SCOPED_TAGGED_OPERATOR_PROFILE(
         reinterpret_cast<profiling::Profiler*>(context->profiler),
-        node_data.name().c_str(), node_data.index());
+        node_data->name().c_str(), node_data->index());
 
-    auto status = ExecuteFlexOp(eager_context, buffer_map, &node_data);
+    auto status = ExecuteFlexOp(eager_context, buffer_map, node_data.get());
     TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
   }
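A note on why OpData::nodes becomes std::vector<std::unique_ptr<OpNode>>: OpNode now deletes its copy constructor and copy assignment, so it can no longer be stored by value in a std::vector (which needs copyable or movable elements); holding each node behind a unique_ptr sidesteps that and keeps node addresses stable. A minimal sketch under that assumption, with a hypothetical NonCopyable type in place of OpNode:

// Hypothetical sketch: a type with deleted copy (like the new OpNode)
// is held through std::unique_ptr so the vector only moves pointers.
#include <memory>
#include <vector>

class NonCopyable {
 public:
  explicit NonCopyable(int id) : id_(id) {}
  NonCopyable(const NonCopyable&) = delete;
  NonCopyable& operator=(const NonCopyable&) = delete;
  int id() const { return id_; }

 private:
  int id_;
};

int main() {
  // std::vector<NonCopyable> v; v.push_back(NonCopyable(0));  // would not compile
  std::vector<std::unique_ptr<NonCopyable>> nodes;
  nodes.emplace_back(new NonCopyable(0));
  nodes.emplace_back(new NonCopyable(1));
  for (auto& node : nodes) {
    (void)node->id();  // callers dereference the pointer, as Eval() does with node_data->name()
  }
  return 0;
}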