Add the ability for the OpenCL delegate to update a variable input tensor in place.

PiperOrigin-RevId: 330551753
Change-Id: Ibc784c9a1890b75e13a950176c57c7424775dfd5
This commit is contained in:
parent aa9c03e948
commit 1d7f71f9d6
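In short: the model builder marks graph values backed by TfLite variable tensors, op parsers report which new value ids hold the updated state, and the OpenCL inference context allocates one device tensor per variable tensor ref so that the read and the in-place update share the same memory. Below is a minimal standalone sketch of the core graph trick only, with hypothetical types rather than the TFLite GPU API: an in-place update is expressed by producing a fresh value that aliases the same tensor ref, which keeps the graph acyclic while leaving the updated state as a graph output to be written back.

// Minimal standalone sketch (hypothetical types, not the TFLite GPU API):
// an "in-place" update is expressed by writing the new state to a fresh graph
// value that aliases the same tensor ref, so the graph stays acyclic and the
// updated state is left as a graph output that can be written back.
#include <cstdint>
#include <iostream>
#include <vector>

using ValueId = uint32_t;

struct Value {
  ValueId id;
  int64_t ref;               // TfLite tensor this value aliases
  bool is_variable = false;  // mirrors TfLiteTensor::is_variable
  bool has_consumer = false;
};

struct ToyGraph {
  std::vector<Value> values;
  ValueId NewValue(int64_t ref, bool is_variable) {
    const ValueId id = static_cast<ValueId>(values.size());
    values.push_back({id, ref, is_variable, false});
    return id;
  }
};

int main() {
  ToyGraph graph;
  // Initial LSTM-like state, backed by TfLite variable tensor ref 7.
  const ValueId state_in = graph.NewValue(/*ref=*/7, /*is_variable=*/true);
  graph.values[state_in].has_consumer = true;  // consumed by the stateful op

  // Instead of making the op a producer of state_in (which would be a cycle),
  // a fresh value with the same ref receives the updated state.
  const ValueId state_out =
      graph.NewValue(graph.values[state_in].ref, /*is_variable=*/true);

  // Nothing consumes state_out, so it is a graph output; the runtime can copy
  // it back into tensor ref 7 after inference.
  std::cout << "state_out=" << state_out << " aliases ref "
            << graph.values[state_out].ref << ", is_output="
            << !graph.values[state_out].has_consumer << "\n";
  return 0;
}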
@@ -373,6 +373,7 @@ cc_library(
        ":cl_command_queue",
        ":cl_device",
        ":environment",
        ":gpu_object",
        ":model_hints",
        ":opencl_wrapper",
        ":precision",
@@ -206,6 +206,11 @@ void InferenceContext::CopyInAndOutIds(const GraphFloat32& graph) {
    input_ids_.push_back(input->id);
  }

+  const auto variable_inputs = graph.variable_inputs();
+  for (const auto& variable_input : variable_inputs) {
+    variable_ids_and_refs_[variable_input->id] = variable_input->tensor.ref;
+  }
+
  const auto outputs = graph.outputs();
  for (const auto& output : outputs) {
    output_ids_.push_back(output->id);
@@ -387,41 +392,71 @@ absl::Status InferenceContext::Merge() {
  return absl::OkStatus();
}

-void InferenceContext::GetUsages(
-    const std::function<bool(const TensorDescriptor&)>& functor,
-    std::map<ValueId, int2>* usages) {
+void InferenceContext::GetUsages(const std::function<bool(ValueId)>& functor,
+                                 std::map<ValueId, int2>* usages) {
  for (ValueId in_id : input_ids_) {
-    const auto& desc = tensor_reserver_.Get(in_id).descriptor;
-    if (functor(desc)) {
+    if (functor(in_id)) {
      AddUsage(in_id, 0, usages);
    }
  }
  for (int op_index = 0; op_index < nodes_.size(); ++op_index) {
    auto tensors = GetCLNodeTensors(nodes_[op_index]);
    for (auto& tensor : tensors) {
-      if (functor(tensor.second)) {
+      if (functor(tensor.first)) {
        AddUsage(tensor.first, op_index, usages);
      }
    }
  }
  for (ValueId out_id : output_ids_) {
-    const auto& desc = tensor_reserver_.Get(out_id).descriptor;
-    if (functor(desc)) {
+    if (functor(out_id)) {
      AddUsage(out_id, nodes_.size(), usages);
    }
  }
}

+InferenceContext::TensorMemoryType InferenceContext::GetTensorMemoryType(
+    ValueId id) {
+  if (variable_ids_and_refs_.find(id) != variable_ids_and_refs_.end()) {
+    return TensorMemoryType::VARIABLE;
+  } else if (IsBufferBased(tensor_reserver_.Get(id).descriptor.storage_type)) {
+    return TensorMemoryType::BUFFER;
+  } else {
+    return TensorMemoryType::STRONG_SHAPE;
+  }
+}
+
absl::Status InferenceContext::AllocateMemory(CLContext* context) {
+  RETURN_IF_ERROR(AllocateMemoryForVariableTensors(context));
  RETURN_IF_ERROR(AllocateMemoryForBuffers(context));
  RETURN_IF_ERROR(AllocateMemoryForStrongShapes(context));
  return absl::OkStatus();
}

+absl::Status InferenceContext::AllocateMemoryForVariableTensors(
+    CLContext* context) {
+  std::map<ValueId, int> ref_value_to_tensor_index;
+
+  for (auto value_and_ref_value : variable_ids_and_refs_) {
+    if (ref_value_to_tensor_index.find(value_and_ref_value.second) ==
+        ref_value_to_tensor_index.end()) {
+      const auto& t = tensor_reserver_.Get(value_and_ref_value.first);
+      const auto& shape = t.shape;
+      const auto& descriptor = t.descriptor;
+
+      RETURN_IF_ERROR(
+          CreateTensor(*context, shape, descriptor,
+                       &variable_tensors_[value_and_ref_value.second]));
+    }
+  }
+  return absl::OkStatus();
+}
+
absl::Status InferenceContext::AllocateMemoryForBuffers(CLContext* context) {
  std::map<ValueId, int2> buffer_usages;
  GetUsages(
-      [](const TensorDescriptor& t) { return IsBufferBased(t.storage_type); },
+      [this](ValueId id) {
+        return GetTensorMemoryType(id) == TensorMemoryType::BUFFER;
+      },
      &buffer_usages);

  std::vector<TensorUsageRecord<size_t>> buffer_usage_records;
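A note on AllocateMemoryForVariableTensors above: variable_ids_and_refs_ can map several value ids (the state before and after the update) to the same TfLite tensor ref, and variable_tensors_ is keyed by that ref, so all of those value ids resolve to a single device allocation; the updated GetTensor further down routes variable ids through the same map. A minimal standalone illustration of the idea, with toy types and made-up numbers rather than the real CL Tensor/CreateTensor API:

// Toy illustration (hypothetical ids/refs, std::map instead of the real CL
// Tensor machinery): two value ids aliasing the same tensor ref end up backed
// by a single allocation because the container is keyed by the ref.
#include <cstdint>
#include <iostream>
#include <map>

using ValueId = uint32_t;

struct ToyDeviceTensor {
  int bytes = 0;  // stand-in for a real OpenCL allocation
};

int main() {
  // Graph value ids 3 (state before update) and 9 (state after update) both
  // alias TfLite variable tensor ref 7.
  const std::map<ValueId, int64_t> variable_ids_and_refs = {{3, 7}, {9, 7}};

  std::map<int64_t, ToyDeviceTensor> variable_tensors;  // keyed by ref
  for (const auto& id_and_ref : variable_ids_and_refs) {
    // Allocate only the first time a given ref is seen.
    if (variable_tensors.find(id_and_ref.second) == variable_tensors.end()) {
      variable_tensors[id_and_ref.second] = ToyDeviceTensor{/*bytes=*/1024};
    }
  }
  // Prints "1 allocation(s) for 2 value ids": both ids map to the same memory,
  // which is what makes the update effectively in place.
  std::cout << variable_tensors.size() << " allocation(s) for "
            << variable_ids_and_refs.size() << " value ids\n";
  return 0;
}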
@@ -455,7 +490,7 @@ absl::Status InferenceContext::AllocateMemoryForBuffers(CLContext* context) {
  for (auto& node : nodes_) {
    auto tensors = GetCLNodeTensors(node);
    for (auto& t : tensors) {
-      if (!IsBufferBased(t.second.storage_type)) continue;
+      if (GetTensorMemoryType(t.first) != TensorMemoryType::BUFFER) continue;
      const int tensor_index = graph_ids_to_shared_buffer_tensors_[t.first];
      if (created_tensors[tensor_index]) continue;
      const auto& shape = tensor_reserver_.Get(t.first).shape;
@@ -473,7 +508,9 @@ absl::Status InferenceContext::AllocateMemoryForStrongShapes(
    CLContext* context) {
  std::map<ValueId, int2> usages;
  GetUsages(
-      [](const TensorDescriptor& t) { return !IsBufferBased(t.storage_type); },
+      [this](ValueId id) {
+        return GetTensorMemoryType(id) == TensorMemoryType::STRONG_SHAPE;
+      },
      &usages);

  std::vector<TensorUsageRecord<DummyTensor>> usage_records;
@@ -492,7 +529,9 @@ absl::Status InferenceContext::AllocateMemoryForStrongShapes(
  for (auto& node : nodes_) {
    auto tensors = GetCLNodeTensors(node);
    for (auto& t : tensors) {
-      if (IsBufferBased(t.second.storage_type)) continue;
+      if (GetTensorMemoryType(t.first) != TensorMemoryType::STRONG_SHAPE) {
+        continue;
+      }
      const auto& shape = tensor_reserver_.Get(t.first).shape;
      const auto id = assignment.object_ids[remap_from_graph_ids[t.first]];
      graph_ids_to_strong_shape_tensors_[t.first] = id;
@@ -586,8 +625,10 @@ uint64_t InferenceContext::GetSizeOfMemoryAllocatedForIntermediateTensors()
}

Tensor* InferenceContext::GetTensor(ValueId id) {
-  if (graph_ids_to_shared_buffer_tensors_.find(id) !=
-      graph_ids_to_shared_buffer_tensors_.end()) {
+  if (variable_ids_and_refs_.find(id) != variable_ids_and_refs_.end()) {
+    return &variable_tensors_[variable_ids_and_refs_[id]];
+  } else if (graph_ids_to_shared_buffer_tensors_.find(id) !=
+             graph_ids_to_shared_buffer_tensors_.end()) {
    return &shared_buffer_tensors_[graph_ids_to_shared_buffer_tensors_[id]];
  } else {
    return &strong_shape_tensors_[graph_ids_to_strong_shape_tensors_[id]];
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
#include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h"
#include "tensorflow/lite/delegates/gpu/cl/environment.h"
#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/cl/model_hints.h"
#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
@@ -62,6 +63,7 @@ class InferenceContext {
    TensorStorageType storage_type;
    ModelHints hints;
  };

  absl::Status InitFromGraph(const CreateInferenceInfo& create_info,
                             const GraphFloat32& graph, Environment* env);
@@ -88,6 +90,8 @@ class InferenceContext {
                              TensorFloat32* result);

 private:
+  enum TensorMemoryType { STRONG_SHAPE = 0, BUFFER = 1, VARIABLE = 2 };
+
  void CopyInAndOutIds(const GraphFloat32& graph);
  absl::Status ConvertOperations(const DeviceInfo& device_info,
                                 const GraphFloat32& graph, ModelHints hints);
@@ -98,14 +102,18 @@ class InferenceContext {
  absl::Status Merge();
  absl::Status AllocateMemory(CLContext* context);

+  absl::Status AllocateMemoryForVariableTensors(CLContext* context);
+
  absl::Status AllocateMemoryForBuffers(CLContext* context);

  absl::Status AllocateMemoryForStrongShapes(CLContext* context);

  // utility function
-  void GetUsages(const std::function<bool(const TensorDescriptor&)>& functor,
+  void GetUsages(const std::function<bool(ValueId)>& functor,
                 std::map<ValueId, int2>* usages);

+  TensorMemoryType GetTensorMemoryType(ValueId id);
+
  void BindMemoryToOperations();
  absl::Status Compile(const CreationContext& creation_context);
  absl::Status Tune(const TuningParameters& tuning_parameters);
@@ -160,6 +168,7 @@ class InferenceContext {
  };
  TensorReserver tensor_reserver_;

+  std::map<ValueId, Tensor> variable_tensors_;
  std::vector<Buffer> shared_buffers_;
  std::vector<Tensor>
      shared_buffer_tensors_;  // use references to memory from shared_buffers_
@@ -169,6 +178,7 @@ class InferenceContext {
  std::map<ValueId, ValueId> graph_ids_to_strong_shape_tensors_;

  std::vector<ValueId> input_ids_;
+  std::map<ValueId, ValueId> variable_ids_and_refs_;
  std::vector<ValueId> output_ids_;
};
@@ -46,6 +46,11 @@ std::vector<Value*> GraphFloat32::inputs() const {
  return FilterValues([](const ValueDef& v) { return v.producer == nullptr; });
}

+std::vector<Value*> GraphFloat32::variable_inputs() const {
+  return FilterValues(
+      [](const ValueDef& v) { return v.value->tensor.is_variable_input; });
+}
+
std::vector<Value*> GraphFloat32::outputs() const {
  return FilterValues([](const ValueDef& v) { return v.consumers.empty(); });
}
@@ -90,6 +90,9 @@ class GraphFloat32 {
  // @return graph outputs, that are values without consumers.
  std::vector<Value*> outputs() const;

+  // @return values updated in place with a previously defined tensor reference.
+  std::vector<Value*> variable_inputs() const;
+
  // @return inputs into the given node. Returns empty vector for deleted node.
  std::vector<Value*> FindInputs(NodeId id) const;
@@ -82,6 +82,13 @@ class TFLiteOperationParser {
  virtual absl::Status IsSupported(const TfLiteContext* context,
                                   const TfLiteNode* tflite_node,
                                   const TfLiteRegistration* registration) = 0;
+
+  // Return the value ids in the graph that correspond to the updated values of
+  // the variable input tensor.
+  virtual absl::flat_hash_map<int, ValueId>
+  GetNewValueIdsForVariableInputNodes() {
+    return absl::flat_hash_map<int, ValueId>();
+  }
};

HW ToHW(int32_t h, int32_t w) { return HW(h > 0 ? h : 1, w > 0 ? w : 1); }
@@ -2803,6 +2810,44 @@ absl::Status PrecreateIOTensors(
  return absl::OkStatus();
}

+absl::Status CopyVariableTensorOutputs(
+    TfLiteNode* tflite_node, TfLiteRegistration* registration,
+    GraphFloat32* graph, ObjectReader& reader,
+    const absl::flat_hash_map<int, ValueId>& new_variable_tensor_values) {
+  absl::flat_hash_map<int, ValueId> new_variable_tensor_values_copy(
+      new_variable_tensor_values);
+  // Retrieve the final value id for the variable input tensors.
+  for (int i = 0; i < tflite_node->inputs->size; i++) {
+    int tensor_idx = tflite_node->inputs->data[i];
+    Value* value;
+    if (!reader.ReadValueByTensorIdx(tensor_idx, &value).ok()) continue;
+    if (value->tensor.is_variable_input) {
+      if (new_variable_tensor_values_copy.find(i) ==
+          new_variable_tensor_values_copy.end()) {
+        return absl::InvalidArgumentError(
+            absl::StrCat(GetOpNameByRegistration(*registration),
+                         " did not provide a new value for the variable input "
+                         "tensor with index ",
+                         tensor_idx));
+      } else {
+        Node* node = graph->NewNode();
+        node->operation.type = ToString(OperationType::COPY);
+        RETURN_IF_ERROR(graph->AddConsumer(
+            node->id, new_variable_tensor_values_copy.at(i)));
+        RETURN_IF_ERROR(reader.AddUpdate(node, i));
+        new_variable_tensor_values_copy.erase(
+            new_variable_tensor_values_copy.find(i));
+      }
+    }
+  }
+  if (!new_variable_tensor_values_copy.empty()) {
+    return absl::InvalidArgumentError(
+        "More input variable tensors asked to be copied than present on the "
+        "node");
+  }
+  return absl::OkStatus();
+}
+
absl::Status BuildModel(TfLiteContext* context,
                        const TfLiteDelegateParams* delegate_params,
                        GraphFloat32* graph,
@@ -2833,6 +2878,7 @@ absl::Status BuildModel(TfLiteContext* context,
    tflite_nodes.push_back(i);
  }
  absl::flat_hash_map<int, Value*> tensor_to_value;
+  std::vector<ValueId> variable_inputs_to_value_id;
  RETURN_IF_ERROR(PrecreateIOTensors(context, graph,
                                     delegate_params->input_tensors,
                                     quant_conversion_map, &tensor_to_value));
@@ -2853,6 +2899,23 @@ absl::Status BuildModel(TfLiteContext* context,
      return absl::InternalError(absl::StrCat(
          GetOpNameByRegistration(*registration), ": ", status.message()));
    }
+
+    absl::flat_hash_map<int, ValueId> new_value_for_variable_input_tensors =
+        operations[i]->GetNewValueIdsForVariableInputNodes();
+
+    RETURN_IF_ERROR(
+        CopyVariableTensorOutputs(tflite_node, registration, graph, reader,
+                                  new_value_for_variable_input_tensors));
  }
+
+  // Variable input tensors expect to be unchanged throughout model execution.
+  // They need to be an output of the graph in order to have them unchanged.
+  for (auto value_id : variable_inputs_to_value_id) {
+    if (!graph->IsGraphOutput(value_id)) {
+      return absl::InvalidArgumentError(
+          absl::StrCat("Variable input tensors must be a graph output. Value ",
+                       value_id, " is not a graph output"));
+    }
+  }
  return absl::OkStatus();
}
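Taken together, the hunks above define a small contract between an op parser and BuildModel: the parser's GetNewValueIdsForVariableInputNodes() override returns a map from the op's variable input index to the value id holding the updated state, and CopyVariableTensorOutputs turns each entry into a COPY node plus an ObjectReader::AddUpdate call. The snippet below is a hypothetical sketch of what such an override could look like; the class and member names are invented and the types are mocked so it stands alone (the real base class is TFLiteOperationParser and the map type is absl::flat_hash_map<int, ValueId>).

// Hypothetical sketch, not code from this commit: the shape of a parser for a
// stateful op with one variable state tensor. Types and names are mocked so
// the snippet stands alone.
#include <cstdint>
#include <unordered_map>

using ValueId = uint32_t;

class MockOperationParser {
 public:
  virtual ~MockOperationParser() = default;
  // Mirrors GetNewValueIdsForVariableInputNodes(): by default an op has no
  // variable inputs to report.
  virtual std::unordered_map<int, ValueId>
  GetNewValueIdsForVariableInputNodes() {
    return {};
  }
};

class MockStatefulOpParser : public MockOperationParser {
 public:
  // Called from the parser's Parse step once the graph value that will hold
  // the updated state is known.
  void RecordUpdatedState(int input_index, ValueId updated_value_id) {
    new_state_values_[input_index] = updated_value_id;
  }
  // Key: index of the variable tensor among the TfLite node's inputs.
  // Value: graph value id that holds the updated state.
  std::unordered_map<int, ValueId> GetNewValueIdsForVariableInputNodes()
      override {
    return new_state_values_;
  }

 private:
  std::unordered_map<int, ValueId> new_state_values_;
};

int main() {
  MockStatefulOpParser parser;
  parser.RecordUpdatedState(/*input_index=*/2, /*updated_value_id=*/41);
  // BuildModel would hand this map to CopyVariableTensorOutputs, which adds a
  // COPY node consuming value 41 and calls ObjectReader::AddUpdate(node, 2).
  const auto updates = parser.GetNewValueIdsForVariableInputNodes();
  return updates.count(2) == 1 ? 0 : 1;
}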
@@ -72,6 +72,7 @@ absl::Status ObjectReader::ReadNonConstantTensor(
      RETURN_IF_ERROR(
          ConvertTfLiteTensorToTensorRef(*fp_tflite_tensor, &value->tensor));
      value->tensor.ref = fp_tensor_index;
+      value->tensor.is_variable_input = tflite_tensor->is_variable;
      value->quant_params.emplace();
      // tflite_tensor from the outer scope is invalidated due to calling
      // CreateNewTensorWithDifferentType
@@ -89,6 +90,7 @@ absl::Status ObjectReader::ReadNonConstantTensor(
      RETURN_IF_ERROR(
          ConvertTfLiteTensorToTensorRef(*tflite_tensor, &value->tensor));
      value->tensor.ref = tensor_idx;
+      value->tensor.is_variable_input = tflite_tensor->is_variable;
      (*tensor_to_value)[tensor_idx] = value;
    }
  }
@@ -159,6 +161,53 @@ absl::Status ObjectReader::AddInput(const Node* node, uint32_t idx) {
  return graph_->AddConsumer(node->id, input->id);
}

+absl::Status ObjectReader::AddUpdate(const Node* node, uint32_t idx) {
+  if (node_->inputs->size <= idx) {
+    return absl::InvalidArgumentError(absl::StrCat(
+        "Data id ", idx, " must be less than tflite node inputs size ",
+        node_->inputs->size));
+  }
+
+  int update_tensor_idx = node_->inputs->data[idx];
+  TfLiteTensor* update_tensor = context_->tensors + update_tensor_idx;
+  if (!update_tensor->is_variable) {
+    return absl::InvalidArgumentError(
+        "The tensor must be a variable tensor to update it in place");
+  }
+
+  Value* value;
+  RETURN_IF_ERROR(ReadValueByTensorIdx(update_tensor_idx, &value));
+  if (!value->tensor.is_variable_input) {
+    return absl::InternalError(
+        "Variable input tensor is not marked as variable");
+  }
+
+  // We cannot create a cycle in the graph. The way around this when a node
+  // updates a tensor in place would be to add a new value to the graph that
+  // points to the same tensor.
+  Value* updated_value = graph_->NewValue();
+  updated_value->tensor = value->tensor;
+  updated_value->quant_params = value->quant_params;
+  RETURN_IF_ERROR(graph_->SetProducer(node->id, updated_value->id));
+
+  // We also need to update the tensor_to_value arrays so that the nodes added
+  // after the current node will access the tensor with the updated value rather
+  // than the initial value.
+  if (quant_conversion_map_ != nullptr &&
+      quant_conversion_map_->find(update_tensor_idx) !=
+          quant_conversion_map_->end()) {
+    // If quantization conversion map exists, then the index provided is not the
+    // actual tensor idx. We need to find the float version of the tensor from
+    // the map.
+    tensor_to_value_->at(quant_conversion_map_->at(update_tensor_idx)) =
+        updated_value;
+  } else {
+    tensor_to_value_->at(update_tensor_idx) = updated_value;
+  }
+
+  return absl::OkStatus();
+}
+
TfLiteTensor* ObjectReader::GetInputTensor(int index) const {
  return index >= 0 && index < node_->inputs->size
             ? context_->tensors + node_->inputs->data[index]
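The tail of AddUpdate above is easy to overlook: once the node is made the producer of the fresh value, tensor_to_value_ is repointed so that ops parsed later that read the same TfLite tensor index see the updated value rather than the stale one, going through quant_conversion_map_ first when a quantized graph runs on dequantized float tensors. A minimal standalone sketch of just that remapping, using plain std maps and made-up indices rather than the real ObjectReader members:

// Standalone sketch of the remapping only (plain std maps, made-up indices;
// not the real ObjectReader members).
#include <cassert>
#include <cstdint>
#include <unordered_map>

using ValueId = uint32_t;

int main() {
  // TfLite tensor index -> graph value currently representing that tensor.
  std::unordered_map<int, ValueId> tensor_to_value = {{5, 10}};
  // Present only when a quantized graph runs in float: quantized tensor
  // index -> index of its dequantized float twin.
  const std::unordered_map<int, int> quant_conversion_map = {{4, 5}};

  const int update_tensor_idx = 4;   // variable tensor the op just updated
  const ValueId updated_value = 27;  // fresh value holding the new state

  // Mirrors the tail of ObjectReader::AddUpdate: ops parsed after this one
  // that read tensor 4 (through its float twin 5) must resolve to value 27.
  const auto it = quant_conversion_map.find(update_tensor_idx);
  if (it != quant_conversion_map.end()) {
    tensor_to_value.at(it->second) = updated_value;
  } else {
    tensor_to_value.at(update_tensor_idx) = updated_value;
  }

  assert(tensor_to_value.at(5) == 27);
  return 0;
}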
@@ -86,6 +86,8 @@ class ObjectReader {

  absl::Status AddInput(const Node* node, uint32_t idx);

+  absl::Status AddUpdate(const Node* node, uint32_t idx);
+
  TfLiteTensor* GetInputTensor(int index) const;

  TfLiteTensor* GetOutputTensor(int index) const;
@@ -72,6 +72,10 @@ struct TensorRef {
  // Opaque reference to a tensor. Upstream component is responsible for
  // resolving this reference into an actual tensor.
  int64_t ref = -1;
+
+  // Specifies if the tensor should be a variable input tensor that must be an
+  // output as well as an input to the graph.
+  bool is_variable_input = false;
};

template <typename ShapeT, DataType Type>