Fix some old TODO comments: delete methods of Subgraph that expose
mutable access to the private `tensors_` vector and to the private `nodes_and_registrations_` vector, replacing them with functions that provide mutable access to the individual tensors/nodes/registrations, and apply minor refactorings to the code that was previously using these methods to use other methods instead. PiperOrigin-RevId: 345649125 Change-Id: I1b38c22421d62ef85c7a1eadf9662234995c90fe
This commit is contained in:
parent
0fd465c057
commit
c37b2f1152
10
RELEASE.md
10
RELEASE.md
@ -42,6 +42,16 @@
|
||||
renamed `tf.function(jit_compile=True)`.
|
||||
|
||||
* `tf.lite`:
|
||||
* class `tflite::Subgraph`:
|
||||
* Removed the `tensors()` method and the non-const overload of the
|
||||
`nodes_and_registration()` method, both of which were previously
|
||||
documented as temporary and to be removed.
|
||||
* Uses of `tensors()` can be replaced by calling the existing
|
||||
methods `tensors_size()` and `tensor(int)`.
|
||||
* Uses of the non-const overload of `nodes_and_registration`
|
||||
can be replaced by calling the existing methods `nodes_size()`
|
||||
and `context()`, and then calling the `GetNodeAndRegistration`
|
||||
method in the `TfLiteContext` returned by `context()`.
|
||||
* NNAPI
|
||||
* Removed deprecated `Interpreter::UseNNAPI(bool)` C++ API.
|
||||
* Use `NnApiDelegate()` and related delegate configuration methods
|
||||
|
@ -162,9 +162,9 @@ class InterpreterInfo : public GraphInfo {
|
||||
public:
|
||||
explicit InterpreterInfo(Subgraph* subgraph) : subgraph_(subgraph) {}
|
||||
|
||||
size_t num_tensors() const override { return subgraph_->tensors().size(); }
|
||||
size_t num_tensors() const override { return subgraph_->tensors_size(); }
|
||||
TfLiteTensor* tensor(size_t index) override {
|
||||
return &subgraph_->tensors()[index];
|
||||
return subgraph_->tensor(index);
|
||||
}
|
||||
size_t num_execution_nodes() const override {
|
||||
return subgraph_->execution_plan().size();
|
||||
@ -218,7 +218,7 @@ Subgraph::Subgraph(ErrorReporter* error_reporter,
|
||||
|
||||
// Reserve some space for the tensors to avoid excessive resizing.
|
||||
tensors_.reserve(kTensorsReservedCapacity);
|
||||
nodes_and_registration().reserve(kTensorsReservedCapacity);
|
||||
nodes_and_registration_.reserve(kTensorsReservedCapacity);
|
||||
// Invalid to call these except from TfLiteDelegate
|
||||
SwitchToKernelContext();
|
||||
}
|
||||
@ -961,12 +961,13 @@ TfLiteStatus Subgraph::PrepareOpsAndTensors() {
|
||||
// overhead should be minimal since the number of custom-allocated tensors
|
||||
// will typically be low.
|
||||
for (int i = 0; i < custom_allocations_.size(); ++i) {
|
||||
auto idx_and_alloc = custom_allocations_[i];
|
||||
auto& tensor = tensors()[idx_and_alloc.first];
|
||||
const auto& alloc = idx_and_alloc.second;
|
||||
TF_LITE_ENSURE(context(), tensor.allocation_type == kTfLiteCustom);
|
||||
auto index_and_alloc = custom_allocations_[i];
|
||||
TfLiteTensor* tensor_at_index = tensor(index_and_alloc.first);
|
||||
const auto& alloc = index_and_alloc.second;
|
||||
TF_LITE_ENSURE(context(),
|
||||
tensor_at_index->allocation_type == kTfLiteCustom);
|
||||
TF_LITE_ENSURE_STATUS(
|
||||
ValidateCustomAllocationForTensor(context(), &tensor, alloc));
|
||||
ValidateCustomAllocationForTensor(context(), tensor_at_index, alloc));
|
||||
}
|
||||
|
||||
next_execution_plan_index_to_plan_allocation_ =
|
||||
|
@ -180,17 +180,6 @@ class Subgraph {
|
||||
// Return read-only vector of node indices in the order of execution.
|
||||
const std::vector<int>& execution_plan() const { return execution_plan_; }
|
||||
|
||||
// Mutable form of tensors (TEMPORARY for refactor).
|
||||
// TODO(b/119495520): remove when refactoring complete.
|
||||
std::vector<TfLiteTensor>& tensors() { return tensors_; }
|
||||
|
||||
// Mutable form of nodes_and_registration (TEMPORARY for refactor).
|
||||
// TODO(b/119495520): remove when refactoring complete.
|
||||
std::vector<std::pair<TfLiteNode, TfLiteRegistration>>&
|
||||
nodes_and_registration() {
|
||||
return nodes_and_registration_;
|
||||
}
|
||||
|
||||
const std::vector<std::pair<TfLiteNode, TfLiteRegistration>>&
|
||||
nodes_and_registration() const {
|
||||
return nodes_and_registration_;
|
||||
|
@ -141,12 +141,28 @@ class DelegatedInterpreter {
|
||||
}
|
||||
|
||||
// Get the TfLiteContext to be mocked for swapping out functions that have to
|
||||
// be called inside delegate (i.e. in delegat kernel mode).
|
||||
// be called inside delegate (i.e. in delegate kernel mode).
|
||||
TfLiteContext* context() { return interpreter_.primary_subgraph().context(); }
|
||||
|
||||
std::vector<std::pair<TfLiteNode, TfLiteRegistration>>&
|
||||
nodes_and_registration() {
|
||||
return interpreter_.primary_subgraph().nodes_and_registration();
|
||||
// node(int) and registration(int) are used to implement
|
||||
// GetNodeAndRegistration. We can't implement those using
|
||||
// TfLiteContext *context = interpreter_.primary_subgraph().context();
|
||||
// context->GetNodeAndRegistration(context, &node, ®istration);
|
||||
// here, because calling GetNodeAndRegistration from within it's own
|
||||
// implementation would lead to an infinite loop.
|
||||
// Instead, we just call node_and_registration and use a const_cast.
|
||||
// These const_casts are a bit ugly, but I think less ugly than exposing
|
||||
// the private GetNodeAndRegistration method in Subgraph as public,
|
||||
// or making this class a friend of Subgraph.
|
||||
TfLiteNode* node(int index) {
|
||||
const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration =
|
||||
interpreter_.primary_subgraph().node_and_registration(index);
|
||||
return const_cast<TfLiteNode*>(&node_and_registration->first);
|
||||
}
|
||||
TfLiteRegistration* registration(int index) {
|
||||
const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration =
|
||||
interpreter_.primary_subgraph().node_and_registration(index);
|
||||
return const_cast<TfLiteRegistration*>(&node_and_registration->second);
|
||||
}
|
||||
|
||||
TfLiteIntArray* exec_plan() const { return exec_plan_; }
|
||||
@ -280,10 +296,8 @@ TEST(ModelBuilderTest, GetOpsToReplaceAcceptsFp16DequantizeNodes) {
|
||||
context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
|
||||
TfLiteNode** node,
|
||||
TfLiteRegistration** registration) {
|
||||
auto& node_and_reg =
|
||||
interpreter_fp16_add_op->nodes_and_registration()[node_index];
|
||||
*node = &node_and_reg.first;
|
||||
*registration = &node_and_reg.second;
|
||||
*node = interpreter_fp16_add_op->node(node_index);
|
||||
*registration = interpreter_fp16_add_op->registration(node_index);
|
||||
return kTfLiteOk;
|
||||
};
|
||||
context->PreviewDelegatePartitioning =
|
||||
@ -346,10 +360,8 @@ TEST(ModelBuilderTest, GetOpsToReplaceRejectsFp16DequantizeNodes) {
|
||||
context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
|
||||
TfLiteNode** node,
|
||||
TfLiteRegistration** registration) {
|
||||
auto& node_and_reg =
|
||||
interpreter_fp16_gt_op->nodes_and_registration()[node_index];
|
||||
*node = &node_and_reg.first;
|
||||
*registration = &node_and_reg.second;
|
||||
*node = interpreter_fp16_gt_op->node(node_index);
|
||||
*registration = interpreter_fp16_gt_op->registration(node_index);
|
||||
return kTfLiteOk;
|
||||
};
|
||||
context->PreviewDelegatePartitioning =
|
||||
@ -462,9 +474,8 @@ TEST(ModelBuilderTest, GetOpsToReplaceDoesNotPruneUint8) {
|
||||
context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
|
||||
TfLiteNode** node,
|
||||
TfLiteRegistration** registration) {
|
||||
auto& node_and_reg = interpreter_fp32->nodes_and_registration()[node_index];
|
||||
*node = &node_and_reg.first;
|
||||
*registration = &node_and_reg.second;
|
||||
*node = interpreter_fp32->node(node_index);
|
||||
*registration = interpreter_fp32->registration(node_index);
|
||||
return kTfLiteOk;
|
||||
};
|
||||
context->PreviewDelegatePartitioning =
|
||||
@ -630,10 +641,8 @@ TEST(ModelBuilderTest, GetOpsToReplaceMultiplePartitions) {
|
||||
context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
|
||||
TfLiteNode** node,
|
||||
TfLiteRegistration** registration) {
|
||||
auto& node_and_reg =
|
||||
interpreter2_fp32->nodes_and_registration()[node_index];
|
||||
*node = &node_and_reg.first;
|
||||
*registration = &node_and_reg.second;
|
||||
*node = interpreter2_fp32->node(node_index);
|
||||
*registration = interpreter2_fp32->registration(node_index);
|
||||
return kTfLiteOk;
|
||||
};
|
||||
context->PreviewDelegatePartitioning =
|
||||
@ -845,9 +854,8 @@ TEST(ModelBuilderTest, GetOpsToReplaceSelectsCorrectFp16Nodes_SinglePartition) {
|
||||
context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
|
||||
TfLiteNode** node,
|
||||
TfLiteRegistration** registration) {
|
||||
auto& node_and_reg = interpreter_mn->nodes_and_registration()[node_index];
|
||||
*node = &node_and_reg.first;
|
||||
*registration = &node_and_reg.second;
|
||||
*node = interpreter_mn->node(node_index);
|
||||
*registration = interpreter_mn->registration(node_index);
|
||||
return kTfLiteOk;
|
||||
};
|
||||
context->PreviewDelegatePartitioning =
|
||||
@ -917,9 +925,8 @@ TEST(ModelBuilderTest,
|
||||
context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
|
||||
TfLiteNode** node,
|
||||
TfLiteRegistration** registration) {
|
||||
auto& node_and_reg = interpreter_mn2->nodes_and_registration()[node_index];
|
||||
*node = &node_and_reg.first;
|
||||
*registration = &node_and_reg.second;
|
||||
*node = interpreter_mn2->node(node_index);
|
||||
*registration = interpreter_mn2->registration(node_index);
|
||||
return kTfLiteOk;
|
||||
};
|
||||
|
||||
@ -1121,10 +1128,8 @@ TEST(ModelBuilderTest, GetOpsToReplace_AllowQuantOps) {
|
||||
context->GetNodeAndRegistration = [](struct TfLiteContext*, int node_index,
|
||||
TfLiteNode** node,
|
||||
TfLiteRegistration** registration) {
|
||||
auto& node_and_reg =
|
||||
interpreter_quant->nodes_and_registration()[node_index];
|
||||
*node = &node_and_reg.first;
|
||||
*registration = &node_and_reg.second;
|
||||
*node = interpreter_quant->node(node_index);
|
||||
*registration = interpreter_quant->registration(node_index);
|
||||
return kTfLiteOk;
|
||||
};
|
||||
context->PreviewDelegatePartitioning =
|
||||
|
@ -411,8 +411,7 @@ TfLiteStatus Interpreter::SetBufferHandle(int tensor_index,
|
||||
TfLiteBufferHandle buffer_handle,
|
||||
TfLiteDelegate* delegate) {
|
||||
TF_LITE_ENSURE(context_, tensor_index < tensors_size());
|
||||
std::vector<TfLiteTensor>& tensors = primary_subgraph().tensors();
|
||||
TfLiteTensor* tensor = &tensors[tensor_index];
|
||||
TfLiteTensor* tensor = primary_subgraph().tensor(tensor_index);
|
||||
|
||||
TF_LITE_ENSURE(context_,
|
||||
tensor->delegate == nullptr || tensor->delegate == delegate);
|
||||
@ -431,8 +430,7 @@ TfLiteStatus Interpreter::GetBufferHandle(int tensor_index,
|
||||
TfLiteBufferHandle* buffer_handle,
|
||||
TfLiteDelegate** delegate) {
|
||||
TF_LITE_ENSURE(context_, tensor_index < tensors_size());
|
||||
std::vector<TfLiteTensor>& tensors = primary_subgraph().tensors();
|
||||
TfLiteTensor* tensor = &tensors[tensor_index];
|
||||
TfLiteTensor* tensor = primary_subgraph().tensor(tensor_index);
|
||||
|
||||
*delegate = tensor->delegate;
|
||||
*buffer_handle = tensor->buffer_handle;
|
||||
|
Loading…
Reference in New Issue
Block a user