One more refactoring of kernel.cc. This introduces OpInputs to match OpOutputs, which is
slightly enhanced. PiperOrigin-RevId: 227908106
This commit is contained in:
parent
cea2e9c4fa
commit
4614b04cc0
@ -51,24 +51,57 @@ namespace tflite {
|
||||
namespace flex {
|
||||
namespace kernel {
|
||||
|
||||
// Controls the lifetime of tensor handles in a vector.
|
||||
// A list of inputs of a given node of the TensorFlow/Eager graph.
|
||||
class OpInputs {
|
||||
public:
|
||||
explicit OpInputs(const TfLiteIntArray* indexes) {
|
||||
for (int index : TfLiteIntArrayView(indexes)) {
|
||||
inputs_.push_back(index);
|
||||
}
|
||||
}
|
||||
~OpInputs() {}
|
||||
|
||||
int Size() const { return inputs_.size(); }
|
||||
|
||||
int TfLiteIndex(int i) const { return inputs_[i]; }
|
||||
|
||||
private:
|
||||
std::vector<int> inputs_;
|
||||
};
|
||||
|
||||
// A list of outputs of a given node of the TensorFlow/Eager graph, along with
|
||||
// the actual outputs of the EagerOperation.
|
||||
class OpOutputs {
|
||||
public:
|
||||
explicit OpOutputs(int num_elements) : vector_(num_elements, nullptr) {}
|
||||
explicit OpOutputs(const TfLiteIntArray* indexes) {
|
||||
for (int index : TfLiteIntArrayView(indexes)) {
|
||||
outputs_.push_back(index);
|
||||
}
|
||||
vector_.resize(outputs_.size());
|
||||
}
|
||||
~OpOutputs() { ResetTensorHandles(); }
|
||||
|
||||
~OpOutputs() {
|
||||
for (auto* handle : vector_) {
|
||||
if (handle) handle->Unref();
|
||||
int Size() const { return outputs_.size(); }
|
||||
|
||||
int TfLiteIndex(int i) const { return outputs_[i]; }
|
||||
|
||||
// Carefully unreference all the handles in the eager output vector.
|
||||
void ResetTensorHandles() {
|
||||
for (int i = 0; i < vector_.size(); ++i) {
|
||||
if (vector_[i]) {
|
||||
vector_[i]->Unref();
|
||||
vector_[i] = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2>* GetVector() {
|
||||
tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2>*
|
||||
GetTensorHandles() {
|
||||
return &vector_;
|
||||
}
|
||||
|
||||
tensorflow::TensorHandle* GetHandle(int index) { return vector_[index]; }
|
||||
|
||||
private:
|
||||
std::vector<int> outputs_;
|
||||
tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> vector_;
|
||||
};
|
||||
|
||||
@ -76,7 +109,8 @@ class OpOutputs {
|
||||
// TensorFlow ops within a single TF Lite op.
|
||||
class OpNode {
|
||||
public:
|
||||
OpNode() {}
|
||||
OpNode(const TfLiteIntArray* inputs, const TfLiteIntArray* outputs)
|
||||
: inputs_(inputs), outputs_(outputs) {}
|
||||
~OpNode() {}
|
||||
|
||||
const string& name() const { return name_; }
|
||||
@ -87,19 +121,14 @@ class OpNode {
|
||||
|
||||
const tensorflow::NodeDef& nodedef() const { return nodedef_; }
|
||||
|
||||
const std::vector<int>& inputs() const { return inputs_; }
|
||||
void InitializeInputs(const TfLiteIntArray* inputs) {
|
||||
for (int index : TfLiteIntArrayView(inputs)) {
|
||||
inputs_.push_back(index);
|
||||
}
|
||||
}
|
||||
const OpInputs& inputs() const { return inputs_; }
|
||||
OpInputs* mutable_inputs() { return &inputs_; }
|
||||
|
||||
const std::vector<int>& outputs() const { return outputs_; }
|
||||
void InitializeOutputs(const TfLiteIntArray* outputs) {
|
||||
for (int index : TfLiteIntArrayView(outputs)) {
|
||||
outputs_.push_back(index);
|
||||
}
|
||||
}
|
||||
const OpOutputs& outputs() const { return outputs_; }
|
||||
OpOutputs* mutable_outputs() { return &outputs_; }
|
||||
|
||||
int NumInputs() const { return inputs_.Size(); }
|
||||
int NumOutputs() const { return outputs_.Size(); }
|
||||
|
||||
tensorflow::Status InitializeNodeDef(const void* custom_initial_data,
|
||||
int custom_initial_data_size) {
|
||||
@ -162,7 +191,8 @@ class OpNode {
|
||||
|
||||
tensorflow::Status BuildEagerInputs(BufferMap* buffer_map,
|
||||
tensorflow::EagerOperation* op) {
|
||||
for (int input_index : inputs_) {
|
||||
for (int i = 0; i < inputs_.Size(); ++i) {
|
||||
int input_index = inputs_.TfLiteIndex(i);
|
||||
if (!buffer_map->HasTensor(input_index)) {
|
||||
return tensorflow::errors::Internal(
|
||||
"Cannot read from invalid tensor index ", input_index);
|
||||
@ -181,17 +211,21 @@ class OpNode {
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
tensorflow::Status PersistEagerOutputs(BufferMap* buffer_map,
|
||||
OpOutputs* retvals) {
|
||||
for (int i = 0; i < outputs_.size(); ++i) {
|
||||
tensorflow::Status PersistEagerOutputs(BufferMap* buffer_map) {
|
||||
auto* handles = outputs_.GetTensorHandles();
|
||||
for (int i = 0; i < outputs_.Size(); ++i) {
|
||||
const tensorflow::Tensor* tensor = nullptr;
|
||||
TF_RETURN_IF_ERROR(retvals->GetHandle(i)->Tensor(&tensor));
|
||||
buffer_map->SetFromTensorFlow(outputs_[i], *tensor);
|
||||
TF_RETURN_IF_ERROR(handles->at(i)->Tensor(&tensor));
|
||||
buffer_map->SetFromTensorFlow(outputs_.TfLiteIndex(i), *tensor);
|
||||
}
|
||||
outputs_.ResetTensorHandles();
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
OpNode(const OpNode&) = delete;
|
||||
OpNode& operator=(const OpNode&) = delete;
|
||||
|
||||
// The name of the TensorFlow op to execute.
|
||||
string name_;
|
||||
// Index of this node into TF Lite's operator list.
|
||||
@ -199,9 +233,9 @@ class OpNode {
|
||||
// The corresponding NodeDef, containing the attributes for the op.
|
||||
tensorflow::NodeDef nodedef_;
|
||||
// List of inputs, as TF Lite tensor indices.
|
||||
std::vector<int> inputs_;
|
||||
OpInputs inputs_;
|
||||
// List of outputs, as TF Lite tensor indices.
|
||||
std::vector<int> outputs_;
|
||||
OpOutputs outputs_;
|
||||
};
|
||||
|
||||
// Executes the TensorFlow op given by 'op_name', with the attributes specified
|
||||
@ -212,18 +246,18 @@ tensorflow::Status ExecuteFlexOp(tensorflow::EagerContext* eager_context,
|
||||
TF_RETURN_IF_ERROR(node_data->BuildEagerOp(eager_context, &op));
|
||||
TF_RETURN_IF_ERROR(node_data->BuildEagerInputs(buffer_map, op.get()));
|
||||
|
||||
int num_retvals = node_data->outputs().size();
|
||||
OpOutputs retvals(num_retvals);
|
||||
int num_retvals = node_data->NumOutputs();
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
EagerExecute(op.get(), retvals.GetVector(), &num_retvals),
|
||||
EagerExecute(op.get(), node_data->mutable_outputs()->GetTensorHandles(),
|
||||
&num_retvals),
|
||||
" (while executing '", node_data->name(), "' via Eager)");
|
||||
|
||||
if (num_retvals != node_data->outputs().size()) {
|
||||
if (num_retvals != node_data->NumOutputs()) {
|
||||
return tensorflow::errors::Internal(
|
||||
"Unexpected number of outputs from EagerExecute");
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(node_data->PersistEagerOutputs(buffer_map, &retvals));
|
||||
TF_RETURN_IF_ERROR(node_data->PersistEagerOutputs(buffer_map));
|
||||
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
@ -232,7 +266,7 @@ tensorflow::Status ExecuteFlexOp(tensorflow::EagerContext* eager_context,
|
||||
struct OpData {
|
||||
tensorflow::EagerContext* eager_context;
|
||||
BufferMap* buffer_map;
|
||||
std::vector<OpNode> nodes;
|
||||
std::vector<std::unique_ptr<OpNode>> nodes;
|
||||
std::vector<int> subgraph_inputs;
|
||||
std::vector<int> subgraph_outputs;
|
||||
};
|
||||
@ -261,6 +295,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
op_data->subgraph_inputs.push_back(tensor_index);
|
||||
}
|
||||
|
||||
op_data->nodes.reserve(params->nodes_to_replace->size);
|
||||
|
||||
CHECK(params->nodes_to_replace);
|
||||
tensorflow::Status status;
|
||||
for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) {
|
||||
@ -268,8 +304,8 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
TfLiteRegistration* reg;
|
||||
context->GetNodeAndRegistration(context, node_index, &node, &reg);
|
||||
|
||||
op_data->nodes.push_back(OpNode());
|
||||
OpNode& node_data = op_data->nodes.back();
|
||||
op_data->nodes.emplace_back(new OpNode(node->inputs, node->outputs));
|
||||
OpNode& node_data = *op_data->nodes.back();
|
||||
|
||||
node_data.set_index(node_index);
|
||||
node_data.set_name("");
|
||||
@ -277,9 +313,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
status = node_data.InitializeNodeDef(node->custom_initial_data,
|
||||
node->custom_initial_data_size);
|
||||
if (!status.ok()) break;
|
||||
|
||||
node_data.InitializeInputs(node->inputs);
|
||||
node_data.InitializeOutputs(node->outputs);
|
||||
}
|
||||
|
||||
if (ConvertStatus(context, status) != kTfLiteOk) {
|
||||
@ -328,14 +361,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
for (const auto& node_data : op_data->nodes) {
|
||||
if (node_data.nodedef().op().empty()) {
|
||||
if (node_data->nodedef().op().empty()) {
|
||||
context->ReportError(context, "Invalid NodeDef in Flex op '%s'",
|
||||
node_data.name().c_str());
|
||||
node_data->name().c_str());
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
for (int tensor_index : node_data.inputs()) {
|
||||
++tensor_ref_count[tensor_index];
|
||||
for (int i = 0; i < node_data->inputs().Size(); ++i) {
|
||||
++tensor_ref_count[node_data->inputs().TfLiteIndex(i)];
|
||||
}
|
||||
}
|
||||
|
||||
@ -371,12 +404,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
|
||||
// Execute the TensorFlow Ops sequentially.
|
||||
for (OpNode& node_data : op_data->nodes) {
|
||||
for (auto& node_data : op_data->nodes) {
|
||||
SCOPED_TAGGED_OPERATOR_PROFILE(
|
||||
reinterpret_cast<profiling::Profiler*>(context->profiler),
|
||||
node_data.name().c_str(), node_data.index());
|
||||
node_data->name().c_str(), node_data->index());
|
||||
|
||||
auto status = ExecuteFlexOp(eager_context, buffer_map, &node_data);
|
||||
auto status = ExecuteFlexOp(eager_context, buffer_map, node_data.get());
|
||||
TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user