Add the ability for the OpenCL delegate to update a variable input tensor in place

PiperOrigin-RevId: 330551753
Change-Id: Ibc784c9a1890b75e13a950176c57c7424775dfd5
A. Unique TensorFlower 2020-09-08 11:36:01 -07:00 committed by TensorFlower Gardener
parent aa9c03e948
commit 1d7f71f9d6
9 changed files with 193 additions and 15 deletions

View File

@ -373,6 +373,7 @@ cc_library(
":cl_command_queue",
":cl_device",
":environment",
":gpu_object",
":model_hints",
":opencl_wrapper",
":precision",

View File

@ -206,6 +206,11 @@ void InferenceContext::CopyInAndOutIds(const GraphFloat32& graph) {
input_ids_.push_back(input->id);
}
const auto variable_inputs = graph.variable_inputs();
for (const auto& variable_input : variable_inputs) {
variable_ids_and_refs_[variable_input->id] = variable_input->tensor.ref;
}
const auto outputs = graph.outputs();
for (const auto& output : outputs) {
output_ids_.push_back(output->id);
@ -387,41 +392,71 @@ absl::Status InferenceContext::Merge() {
return absl::OkStatus();
}
void InferenceContext::GetUsages(
const std::function<bool(const TensorDescriptor&)>& functor,
std::map<ValueId, int2>* usages) {
void InferenceContext::GetUsages(const std::function<bool(ValueId)>& functor,
std::map<ValueId, int2>* usages) {
for (ValueId in_id : input_ids_) {
const auto& desc = tensor_reserver_.Get(in_id).descriptor;
if (functor(desc)) {
if (functor(in_id)) {
AddUsage(in_id, 0, usages);
}
}
for (int op_index = 0; op_index < nodes_.size(); ++op_index) {
auto tensors = GetCLNodeTensors(nodes_[op_index]);
for (auto& tensor : tensors) {
if (functor(tensor.second)) {
if (functor(tensor.first)) {
AddUsage(tensor.first, op_index, usages);
}
}
}
for (ValueId out_id : output_ids_) {
const auto& desc = tensor_reserver_.Get(out_id).descriptor;
if (functor(desc)) {
if (functor(out_id)) {
AddUsage(out_id, nodes_.size(), usages);
}
}
}
InferenceContext::TensorMemoryType InferenceContext::GetTensorMemoryType(
ValueId id) {
if (variable_ids_and_refs_.find(id) != variable_ids_and_refs_.end()) {
return TensorMemoryType::VARIABLE;
} else if (IsBufferBased(tensor_reserver_.Get(id).descriptor.storage_type)) {
return TensorMemoryType::BUFFER;
} else {
return TensorMemoryType::STRONG_SHAPE;
}
}
absl::Status InferenceContext::AllocateMemory(CLContext* context) {
RETURN_IF_ERROR(AllocateMemoryForVariableTensors(context));
RETURN_IF_ERROR(AllocateMemoryForBuffers(context));
RETURN_IF_ERROR(AllocateMemoryForStrongShapes(context));
return absl::OkStatus();
}
absl::Status InferenceContext::AllocateMemoryForVariableTensors(
CLContext* context) {
std::map<ValueId, int> ref_value_to_tensor_index;
for (auto value_and_ref_value : variable_ids_and_refs_) {
if (ref_value_to_tensor_index.find(value_and_ref_value.second) ==
ref_value_to_tensor_index.end()) {
const auto& t = tensor_reserver_.Get(value_and_ref_value.first);
const auto& shape = t.shape;
const auto& descriptor = t.descriptor;
RETURN_IF_ERROR(
CreateTensor(*context, shape, descriptor,
&variable_tensors_[value_and_ref_value.second]));
}
}
return absl::OkStatus();
}
absl::Status InferenceContext::AllocateMemoryForBuffers(CLContext* context) {
std::map<ValueId, int2> buffer_usages;
GetUsages(
[](const TensorDescriptor& t) { return IsBufferBased(t.storage_type); },
[this](ValueId id) {
return GetTensorMemoryType(id) == TensorMemoryType::BUFFER;
},
&buffer_usages);
std::vector<TensorUsageRecord<size_t>> buffer_usage_records;
@ -455,7 +490,7 @@ absl::Status InferenceContext::AllocateMemoryForBuffers(CLContext* context) {
for (auto& node : nodes_) {
auto tensors = GetCLNodeTensors(node);
for (auto& t : tensors) {
if (!IsBufferBased(t.second.storage_type)) continue;
if (GetTensorMemoryType(t.first) != TensorMemoryType::BUFFER) continue;
const int tensor_index = graph_ids_to_shared_buffer_tensors_[t.first];
if (created_tensors[tensor_index]) continue;
const auto& shape = tensor_reserver_.Get(t.first).shape;
@ -473,7 +508,9 @@ absl::Status InferenceContext::AllocateMemoryForStrongShapes(
CLContext* context) {
std::map<ValueId, int2> usages;
GetUsages(
[](const TensorDescriptor& t) { return !IsBufferBased(t.storage_type); },
[this](ValueId id) {
return GetTensorMemoryType(id) == TensorMemoryType::STRONG_SHAPE;
},
&usages);
std::vector<TensorUsageRecord<DummyTensor>> usage_records;
@ -492,7 +529,9 @@ absl::Status InferenceContext::AllocateMemoryForStrongShapes(
for (auto& node : nodes_) {
auto tensors = GetCLNodeTensors(node);
for (auto& t : tensors) {
if (IsBufferBased(t.second.storage_type)) continue;
if (GetTensorMemoryType(t.first) != TensorMemoryType::STRONG_SHAPE) {
continue;
}
const auto& shape = tensor_reserver_.Get(t.first).shape;
const auto id = assignment.object_ids[remap_from_graph_ids[t.first]];
graph_ids_to_strong_shape_tensors_[t.first] = id;
@ -586,8 +625,10 @@ uint64_t InferenceContext::GetSizeOfMemoryAllocatedForIntermediateTensors()
}
Tensor* InferenceContext::GetTensor(ValueId id) {
if (graph_ids_to_shared_buffer_tensors_.find(id) !=
graph_ids_to_shared_buffer_tensors_.end()) {
if (variable_ids_and_refs_.find(id) != variable_ids_and_refs_.end()) {
return &variable_tensors_[variable_ids_and_refs_[id]];
} else if (graph_ids_to_shared_buffer_tensors_.find(id) !=
graph_ids_to_shared_buffer_tensors_.end()) {
return &shared_buffer_tensors_[graph_ids_to_shared_buffer_tensors_[id]];
} else {
return &strong_shape_tensors_[graph_ids_to_strong_shape_tensors_[id]];
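Taken together, the inference_context changes above key every variable value id back to its original TfLite tensor ref, so all value ids that alias one variable tensor share a single allocation, and GetTensor resolves them to the same storage. A minimal standalone sketch of that lookup pattern, using simplified stand-in types rather than the delegate's real classes:

#include <map>

using ValueId = int;

struct Tensor {};  // stand-in for the real OpenCL tensor class

struct VariablePool {
  // value id -> original TfLite tensor ref (mirrors variable_ids_and_refs_).
  std::map<ValueId, int> variable_ids_and_refs;
  // tensor ref -> backing storage, allocated once per ref
  // (mirrors variable_tensors_ in AllocateMemoryForVariableTensors).
  std::map<int, Tensor> variable_tensors;

  // Two value ids that share a ref resolve to the same Tensor; this is what
  // makes an update written through one value id visible through the other.
  Tensor* GetTensor(ValueId id) {
    auto it = variable_ids_and_refs.find(id);
    if (it == variable_ids_and_refs.end()) return nullptr;  // not a variable
    return &variable_tensors[it->second];
  }
};

The same keying explains the deduplication in AllocateMemoryForVariableTensors: CreateTensor runs once per distinct ref, not once per value id.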

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
#include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h"
#include "tensorflow/lite/delegates/gpu/cl/environment.h"
#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
#include "tensorflow/lite/delegates/gpu/cl/kernels/gpu_operation.h"
#include "tensorflow/lite/delegates/gpu/cl/model_hints.h"
#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"
@ -62,6 +63,7 @@ class InferenceContext {
TensorStorageType storage_type;
ModelHints hints;
};
absl::Status InitFromGraph(const CreateInferenceInfo& create_info,
const GraphFloat32& graph, Environment* env);
@ -88,6 +90,8 @@ class InferenceContext {
TensorFloat32* result);
private:
enum TensorMemoryType { STRONG_SHAPE = 0, BUFFER = 1, VARIABLE = 2 };
void CopyInAndOutIds(const GraphFloat32& graph);
absl::Status ConvertOperations(const DeviceInfo& device_info,
const GraphFloat32& graph, ModelHints hints);
@ -98,14 +102,18 @@ class InferenceContext {
absl::Status Merge();
absl::Status AllocateMemory(CLContext* context);
absl::Status AllocateMemoryForVariableTensors(CLContext* context);
absl::Status AllocateMemoryForBuffers(CLContext* context);
absl::Status AllocateMemoryForStrongShapes(CLContext* context);
// utility function
void GetUsages(const std::function<bool(const TensorDescriptor&)>& functor,
void GetUsages(const std::function<bool(ValueId)>& functor,
std::map<ValueId, int2>* usages);
TensorMemoryType GetTensorMemoryType(ValueId id);
void BindMemoryToOperations();
absl::Status Compile(const CreationContext& creation_context);
absl::Status Tune(const TuningParameters& tuning_parameters);
@ -160,6 +168,7 @@ class InferenceContext {
};
TensorReserver tensor_reserver_;
std::map<ValueId, Tensor> variable_tensors_;
std::vector<Buffer> shared_buffers_;
std::vector<Tensor>
shared_buffer_tensors_; // use references to memory from shared_buffers_
@ -169,6 +178,7 @@ class InferenceContext {
std::map<ValueId, ValueId> graph_ids_to_strong_shape_tensors_;
std::vector<ValueId> input_ids_;
std::map<ValueId, ValueId> variable_ids_and_refs_;
std::vector<ValueId> output_ids_;
};

View File

@ -46,6 +46,11 @@ std::vector<Value*> GraphFloat32::inputs() const {
return FilterValues([](const ValueDef& v) { return v.producer == nullptr; });
}
std::vector<Value*> GraphFloat32::variable_inputs() const {
return FilterValues(
[](const ValueDef& v) { return v.value->tensor.is_variable_input; });
}
std::vector<Value*> GraphFloat32::outputs() const {
return FilterValues([](const ValueDef& v) { return v.consumers.empty(); });
}

View File

@ -90,6 +90,9 @@ class GraphFloat32 {
// @return graph outputs, that are values without consumers.
std::vector<Value*> outputs() const;
// @return values updated in place with a previously defined tensor reference.
std::vector<Value*> variable_inputs() const;
// @return inputs into the given node. Returns empty vector for deleted node.
std::vector<Value*> FindInputs(NodeId id) const;

View File

@ -82,6 +82,13 @@ class TFLiteOperationParser {
virtual absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) = 0;
// Return the value ids in the graph that correspond to the updated values of
// the variable input tensor.
virtual absl::flat_hash_map<int, ValueId>
GetNewValueIdsForVariableInputNodes() {
return absl::flat_hash_map<int, ValueId>();
}
};
HW ToHW(int32_t h, int32_t w) { return HW(h > 0 ? h : 1, w > 0 ? w : 1); }
@ -2803,6 +2810,44 @@ absl::Status PrecreateIOTensors(
return absl::OkStatus();
}
absl::Status CopyVariableTensorOutputs(
TfLiteNode* tflite_node, TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader& reader,
const absl::flat_hash_map<int, ValueId>& new_variable_tensor_values) {
absl::flat_hash_map<int, ValueId> new_variable_tensor_values_copy(
new_variable_tensor_values);
// Retrieve the final value id for the variable input tensors.
for (int i = 0; i < tflite_node->inputs->size; i++) {
int tensor_idx = tflite_node->inputs->data[i];
Value* value;
if (!reader.ReadValueByTensorIdx(tensor_idx, &value).ok()) continue;
if (value->tensor.is_variable_input) {
if (new_variable_tensor_values_copy.find(i) ==
new_variable_tensor_values_copy.end()) {
return absl::InvalidArgumentError(
absl::StrCat(GetOpNameByRegistration(*registration),
" did not provide a new value for the variable input "
"tensor with index ",
tensor_idx));
} else {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::COPY);
RETURN_IF_ERROR(graph->AddConsumer(
node->id, new_variable_tensor_values_copy.at(i)));
RETURN_IF_ERROR(reader.AddUpdate(node, i));
new_variable_tensor_values_copy.erase(
new_variable_tensor_values_copy.find(i));
}
}
}
if (!new_variable_tensor_values_copy.empty()) {
return absl::InvalidArgumentError(
"More input variable tensors asked to be copied than present on the "
"node");
}
return absl::OkStatus();
}
absl::Status BuildModel(TfLiteContext* context,
const TfLiteDelegateParams* delegate_params,
GraphFloat32* graph,
@ -2833,6 +2878,7 @@ absl::Status BuildModel(TfLiteContext* context,
tflite_nodes.push_back(i);
}
absl::flat_hash_map<int, Value*> tensor_to_value;
std::vector<ValueId> variable_inputs_to_value_id;
RETURN_IF_ERROR(PrecreateIOTensors(context, graph,
delegate_params->input_tensors,
quant_conversion_map, &tensor_to_value));
@ -2853,6 +2899,23 @@ absl::Status BuildModel(TfLiteContext* context,
return absl::InternalError(absl::StrCat(
GetOpNameByRegistration(*registration), ": ", status.message()));
}
absl::flat_hash_map<int, ValueId> new_value_for_variable_input_tensors =
operations[i]->GetNewValueIdsForVariableInputNodes();
RETURN_IF_ERROR(
CopyVariableTensorOutputs(tflite_node, registration, graph, reader,
new_value_for_variable_input_tensors));
}
// Variable input tensors expect to be unchanged throughout model execution.
// They need to be an output of the graph in order to have them unchanged.
for (auto value_id : variable_inputs_to_value_id) {
if (!graph->IsGraphOutput(value_id)) {
return absl::InvalidArgumentError(
absl::StrCat("Variable input tensors must be a graph output. Value ",
value_id, " is not a graph output"));
}
}
return absl::OkStatus();
}
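The contract the diff establishes between parsers and BuildModel is narrow: for each node that rewrites a variable input, GetNewValueIdsForVariableInputNodes must map that input's index to the value id holding the updated contents, and CopyVariableTensorOutputs then attaches a COPY node via ObjectReader::AddUpdate while checking that the two sides agree. A hedged sketch of just that pairing check, with simplified types (the real code walks tflite_node->inputs and builds graph nodes instead of returning a string):

#include <map>
#include <set>
#include <string>

using ValueId = int;

// Mirrors the bookkeeping in CopyVariableTensorOutputs: every variable input
// must receive exactly one updated value id, and no reported update may be
// left unconsumed.
std::string CheckVariableUpdates(
    const std::set<int>& variable_input_indices,
    const std::map<int, ValueId>& new_values_by_input_index) {
  std::map<int, ValueId> pending = new_values_by_input_index;
  for (int input_index : variable_input_indices) {
    auto it = pending.find(input_index);
    if (it == pending.end()) {
      return "op did not provide a new value for variable input " +
             std::to_string(input_index);
    }
    // The real code creates a COPY node here, makes it consume it->second,
    // and calls reader.AddUpdate(node, input_index).
    pending.erase(it);
  }
  return pending.empty()
             ? "ok"
             : "more variable updates reported than variable inputs";
}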

View File

@ -72,6 +72,7 @@ absl::Status ObjectReader::ReadNonConstantTensor(
RETURN_IF_ERROR(
ConvertTfLiteTensorToTensorRef(*fp_tflite_tensor, &value->tensor));
value->tensor.ref = fp_tensor_index;
value->tensor.is_variable_input = tflite_tensor->is_variable;
value->quant_params.emplace();
// tflite_tensor from the outer scope is invalidated due to calling
// CreateNewTensorWithDifferentType
@ -89,6 +90,7 @@ absl::Status ObjectReader::ReadNonConstantTensor(
RETURN_IF_ERROR(
ConvertTfLiteTensorToTensorRef(*tflite_tensor, &value->tensor));
value->tensor.ref = tensor_idx;
value->tensor.is_variable_input = tflite_tensor->is_variable;
(*tensor_to_value)[tensor_idx] = value;
}
}
@ -159,6 +161,53 @@ absl::Status ObjectReader::AddInput(const Node* node, uint32_t idx) {
return graph_->AddConsumer(node->id, input->id);
}
absl::Status ObjectReader::AddUpdate(const Node* node, uint32_t idx) {
if (node_->inputs->size <= idx) {
return absl::InvalidArgumentError(absl::StrCat(
"Data id ", idx, " must be less than tflite node inputs size ",
node_->inputs->size));
}
int update_tensor_idx = node_->inputs->data[idx];
TfLiteTensor* update_tensor = context_->tensors + update_tensor_idx;
if (!update_tensor->is_variable) {
return absl::InvalidArgumentError(
"The tensor must be a variable tensor to update it in place");
}
Value* value;
RETURN_IF_ERROR(ReadValueByTensorIdx(update_tensor_idx, &value));
if (!value->tensor.is_variable_input) {
return absl::InternalError(
"Variable input tensor is not marked as variable");
}
// We cannot create a cycle in the graph. The way around this when a node
// updates a tensor in place would be to add a new value to the graph that
// points to the same tensor.
Value* updated_value = graph_->NewValue();
updated_value->tensor = value->tensor;
updated_value->quant_params = value->quant_params;
RETURN_IF_ERROR(graph_->SetProducer(node->id, updated_value->id));
// We also need to update the tensor_to_value arrays so that the nodes added
// after the current node will access the tensor with the updated value rather
// than the initial value.
if (quant_conversion_map_ != nullptr &&
quant_conversion_map_->find(update_tensor_idx) !=
quant_conversion_map_->end()) {
// If quantization conversion map exists, then the index provided is not the
// actual tensor idx. We need to find the float version of the tensor from
// the map.
tensor_to_value_->at(quant_conversion_map_->at(update_tensor_idx)) =
updated_value;
} else {
tensor_to_value_->at(update_tensor_idx) = updated_value;
}
return absl::OkStatus();
}
TfLiteTensor* ObjectReader::GetInputTensor(int index) const {
return index >= 0 && index < node_->inputs->size
? context_->tensors + node_->inputs->data[index]
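The comments in AddUpdate carry the key design point: a node may not both consume and produce the same value, so the in-place update is modeled by minting a fresh value that aliases the same tensor reference, making the node its producer, and repointing tensor_to_value so downstream nodes read the updated value. A minimal illustration of that pattern with simplified stand-in types (not the real GraphFloat32 or ObjectReader API):

#include <deque>
#include <map>

// Simplified stand-ins; the real code uses GraphFloat32, Value, and
// ObjectReader::tensor_to_value_.
struct Value {
  int id = 0;
  int tensor_ref = -1;          // both values point at the same underlying tensor
  bool is_variable_input = false;
};

struct Graph {
  std::deque<Value> values;     // deque keeps pointers stable across NewValue calls
  Value* NewValue() {
    values.push_back(Value{static_cast<int>(values.size())});
    return &values.back();
  }
};

// The updating node consumes `current` and produces `updated`; because both
// share tensor_ref, the runtime's ref-keyed variable storage observes the
// write, while the graph itself stays acyclic.
Value* UpdateInPlace(Graph* graph, std::map<int, Value*>* tensor_to_value,
                     Value* current, int tflite_tensor_idx) {
  Value* updated = graph->NewValue();
  updated->tensor_ref = current->tensor_ref;
  updated->is_variable_input = current->is_variable_input;
  (*tensor_to_value)[tflite_tensor_idx] = updated;  // later ops read the new value
  return updated;
}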

View File

@ -86,6 +86,8 @@ class ObjectReader {
absl::Status AddInput(const Node* node, uint32_t idx);
absl::Status AddUpdate(const Node* node, uint32_t idx);
TfLiteTensor* GetInputTensor(int index) const;
TfLiteTensor* GetOutputTensor(int index) const;

View File

@ -72,6 +72,10 @@ struct TensorRef {
// Opaque reference to a tensor. Upstream component is responsible for
// resolving this reference into an actual tensor.
int64_t ref = -1;
// Specifies if the tensor should be a variable input tensor that must be an
// output as well as an input to the graph.
bool is_variable_input = false;
};
template <typename ShapeT, DataType Type>