Refactor model_builder: Move out ObjectReader.
PiperOrigin-RevId: 306750661
Change-Id: I654d2c313280016f2ea02b8106fe70da82c0d4b4
This commit is contained in:
parent f290fd7f48
commit 713e7a5722
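After this refactor, ObjectReader lives in its own build target and header, so operation parsers include it directly instead of reaching into model_builder.cc. The snippet below is a minimal, hypothetical parser sketch showing the intended usage of the moved class; the parser name, its signature, and GraphFloat32::NewNode() are assumptions based on the surrounding code, not part of this commit.

#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"

namespace tflite {
namespace gpu {

// Hypothetical element-wise parser: two runtime inputs, no constant inputs,
// one output, wired into the GPU graph through the relocated ObjectReader.
absl::Status ParseBinaryOp(const TfLiteNode* tflite_node, GraphFloat32* graph,
                           ObjectReader* reader) {
  RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(
      tflite_node, /*runtime_inputs=*/2, /*const_inputs=*/0, /*outputs=*/1));
  Node* node = graph->NewNode();  // Assumed GraphFloat32 factory method.
  RETURN_IF_ERROR(reader->AddInput(node, 0));
  RETURN_IF_ERROR(reader->AddInput(node, 1));
  return reader->AddOutputs(node);
}

}  // namespace gpu
}  // namespace tflite
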
@@ -101,20 +101,6 @@ cc_test(
    ],
)

cc_library(
    name = "model_builder_helper",
    hdrs = ["model_builder_helper.h"],
    deps = [
        ":status",
        "//tensorflow/lite:context",
        "//tensorflow/lite:kernel_api",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/delegates:utils",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "model_builder",
    srcs = ["model_builder.cc"],
@@ -123,11 +109,11 @@ cc_library(
        ":data_type",
        ":model",
        ":model_builder_helper",
        ":object_reader",
        ":operations",
        ":shape",
        ":status",
        ":tensor",
        "@FP16",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
@@ -155,6 +141,29 @@ cc_test(
    ],
)

cc_library(
    name = "model_builder_helper",
    srcs = ["model_builder_helper.cc"],
    hdrs = ["model_builder_helper.h"],
    deps = [
        ":data_type",
        ":model",
        ":shape",
        ":status",
        ":tensor",
        "//tensorflow/lite:context",
        "//tensorflow/lite:kernel_api",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/delegates:utils",
        "//tensorflow/lite/kernels:kernel_util",
        "//tensorflow/lite/kernels/internal:reference_base",
        "//tensorflow/lite/kernels/internal:tensor",
        "//tensorflow/lite/kernels/internal:types",
        "@FP16",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "model_transformer",
    srcs = ["model_transformer.cc"],
@@ -167,6 +176,20 @@ cc_library(

# TODO(impjdi): Add unit test for model_transformer.

cc_library(
    name = "object_reader",
    srcs = ["object_reader.cc"],
    hdrs = ["object_reader.h"],
    deps = [
        ":model",
        ":model_builder_helper",
        ":status",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/delegates:utils",
        "//tensorflow/lite/kernels:kernel_util",
    ],
)

cc_library(
    name = "operations",
    srcs = ["operations.cc"],

@@ -26,7 +26,6 @@ limitations under the License.
#include <utility>
#include <vector>

#include <fp16.h>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@@ -42,6 +41,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
@@ -79,24 +79,6 @@ absl::Status NewPassthroughNode(GraphFloat32* graph, Node* node,
  return absl::OkStatus();
}

template <typename T>
absl::Status CreateVectorCopyData(const TfLiteTensor& tensor, T* tensor_data) {
  if (tensor.bytes % sizeof(T) != 0) {
    return absl::InvalidArgumentError(
        absl::StrCat("Input data size ", tensor.bytes,
                     " is not aligned to expected type: ", sizeof(T)));
  }
  std::memcpy(tensor_data, tensor.data.uint8, tensor.bytes);
  return absl::OkStatus();
}

void ConvertFloat16ToFloat32(size_t num_elements, const uint16_t* src,
                             float* dst) {
  for (size_t i = 0; i < num_elements; i++) {
    *dst++ = fp16_ieee_to_fp32_value(*src++);
  }
}

template <typename T>
inline void DequantizeConstantTensor(const TfLiteTensor& tensor,
                                     const T* source_data,
@@ -121,166 +103,6 @@ inline void DequantizeConstantTensor(const TfLiteTensor& tensor,
  }
}

template <>
absl::Status CreateVectorCopyData<float>(const TfLiteTensor& tensor,
                                         float* tensor_data) {
  switch (tensor.type) {
    case kTfLiteFloat32:
      std::memcpy(tensor_data, tensor.data.f, tensor.bytes);
      break;
    case kTfLiteFloat16:
      ConvertFloat16ToFloat32(
          NumElements(&tensor),
          reinterpret_cast<uint16_t const*>(tensor.data.f16), tensor_data);
      break;
    case kTfLiteInt8:
      DequantizeConstantTensor(tensor, tensor.data.int8, tensor_data);
      break;
    case kTfLiteUInt8:
      DequantizeConstantTensor(tensor, tensor.data.uint8, tensor_data);
      break;
    case kTfLiteInt32:
      DequantizeConstantTensor(tensor, tensor.data.i32, tensor_data);
      break;
    default:
      return absl::InvalidArgumentError(
          "Unsupported data type for float32 tensor");
  }
  return absl::OkStatus();
}

template <typename ShapeT>
absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, ShapeT* shape);

template <>
absl::Status SetAllDimensions<Scalar>(const TfLiteIntArray* dimensions,
                                      Scalar* shape) {
  if (dimensions->size < 0) {
    return absl::InvalidArgumentError("Invalid Scalar dimensions");
  }
  for (int i = 0; i < dimensions->size; ++i) {
    if (dimensions->data[i] != 1) {
      return absl::InvalidArgumentError(
          "Dimension can not be reduced to scalar.");
    }
  }
  shape->v = 1;
  return absl::OkStatus();
}

template <>
absl::Status SetAllDimensions<Linear>(const TfLiteIntArray* dimensions,
                                      Linear* shape) {
  if (dimensions->size <= 0) {
    return absl::InvalidArgumentError("Dimension is empty.");
  }
  for (int i = 0; i < dimensions->size - 1; ++i) {
    if (dimensions->data[i] != 1) {
      return absl::InvalidArgumentError(
          "Dimension can not be reduced to linear.");
    }
  }
  shape->v = dimensions->data[dimensions->size - 1];
  return absl::OkStatus();
}

template <>
absl::Status SetAllDimensions<HWC>(const TfLiteIntArray* dimensions,
                                   HWC* shape) {
  if (dimensions->size != 4) {
    return absl::InvalidArgumentError("Dimensions are not HWC");
  }
  if (dimensions->data[0] != 1) {
    return absl::UnimplementedError("Batch size is not equal to 1.");
  }
  shape->h = dimensions->data[1];
  shape->w = dimensions->data[2];
  shape->c = dimensions->data[3];
  return absl::OkStatus();
}

template <>
absl::Status SetAllDimensions<HW>(const TfLiteIntArray* dimensions, HW* shape) {
  if (dimensions->size != 2) {
    return absl::InvalidArgumentError("Dimensions are not HW");
  }
  shape->h = dimensions->data[0];
  shape->w = dimensions->data[1];
  return absl::OkStatus();
}

template <>
absl::Status SetAllDimensions<OHWI>(const TfLiteIntArray* dimensions,
                                    OHWI* shape) {
  if (dimensions->size != 4) {
    return absl::InvalidArgumentError(
        absl::StrCat("Dimensions are not OHWI: ", dimensions->size));
  }
  shape->o = dimensions->data[0];
  shape->h = dimensions->data[1];
  shape->w = dimensions->data[2];
  shape->i = dimensions->data[3];
  return absl::OkStatus();
}

template <>
absl::Status SetAllDimensions<BHWC>(const TfLiteIntArray* dimensions,
                                    BHWC* shape) {
  if (dimensions->size != 4) {
    return absl::InvalidArgumentError("Dimensions are not BHWC");
  }
  shape->b = dimensions->data[0];
  shape->h = dimensions->data[1];
  shape->w = dimensions->data[2];
  shape->c = dimensions->data[3];
  return absl::OkStatus();
}

DataType ToDataType(TfLiteType type) {
  switch (type) {
    case kTfLiteFloat32:
      return DataType::FLOAT32;
    case kTfLiteInt32:
      return DataType::INT32;
    case kTfLiteInt64:
      return DataType::INT64;
    case kTfLiteInt8:
      return DataType::INT8;
    case kTfLiteUInt8:
      return DataType::UINT8;
    default:
      return DataType::UNKNOWN;
  }
}

int GetNumberOfRuntimeInputsForNode(const TfLiteContext* context,
                                    const TfLiteNode* tflite_node) {
  int number_of_runtime_inputs = 0;
  for (int i = 0; i < tflite_node->inputs->size; i++) {
    if (!IsConstantTensor(&context->tensors[tflite_node->inputs->data[i]])) {
      number_of_runtime_inputs++;
    }
  }
  return number_of_runtime_inputs;
}

int GetNumberOfConstInputsForNode(const TfLiteContext* context,
                                  const TfLiteNode* tflite_node) {
  return tflite_node->inputs->size -
         GetNumberOfRuntimeInputsForNode(context, tflite_node);
}

int GetNumberOfRuntimeOutputsForNode(const TfLiteContext* context,
                                     const TfLiteNode* tflite_node) {
  int number_of_runtime_outputs = 0;
  for (int i = 0; i < tflite_node->outputs->size; i++) {
    if (!IsConstantTensor(&context->tensors[tflite_node->outputs->data[i]])) {
      number_of_runtime_outputs++;
    }
  }
  return number_of_runtime_outputs;
}

absl::Status CheckTensorIsAvailable(const TfLiteContext* context,
                                    const TfLiteNode* tflite_node, int idx) {
  // If tensor id is in range, it's guaranteed that it'll be available.
@@ -292,261 +114,6 @@ absl::Status CheckTensorIsAvailable(const TfLiteContext* context,
  return absl::OkStatus();
}

absl::Status CheckInputsOutputs(const TfLiteContext* context,
                                const TfLiteNode* tflite_node,
                                int runtime_inputs, int outputs) {
  int runtime_inputs_from_model =
      GetNumberOfRuntimeInputsForNode(context, tflite_node);
  if (runtime_inputs_from_model != runtime_inputs) {
    return absl::InternalError(absl::StrFormat(
        "Expected %d runtime input tensor(s), but node has %d runtime "
        "input(s).",
        runtime_inputs, runtime_inputs_from_model));
  }
  int runtime_outputs = GetNumberOfRuntimeOutputsForNode(context, tflite_node);
  if (runtime_outputs != outputs) {
    return absl::InternalError(
        absl::StrFormat("Expected %d output tensor(s), but node has %d "
                        "output(s).",
                        outputs, runtime_outputs));
  }
  return absl::OkStatus();
}

absl::Status CheckInputsConstsOutputs(const TfLiteContext* context,
                                      const TfLiteNode* tflite_node,
                                      int runtime_inputs, int const_inputs,
                                      int outputs) {
  int const_inputs_from_model =
      GetNumberOfConstInputsForNode(context, tflite_node);
  if (const_inputs_from_model != const_inputs) {
    return absl::InternalError(absl::StrFormat(
        "Expected %d const input tensor(s), but node has %d const "
        "input(s).",
        const_inputs, const_inputs_from_model));
  }
  return CheckInputsOutputs(context, tflite_node, runtime_inputs, outputs);
}

// Populates quantization parameters for non-constant UInt8/Int8 tensors.
// This helps the delegate emulate quantized inference with
// QuantizeAndDequantize.
absl::Status PopulateQuantParams(const TfLiteTensor& tensor,
                                 QuantizationParams* quant_params) {
  const TfLiteQuantization& quant = tensor.quantization;
  if (quant.type != TfLiteQuantizationType::kTfLiteAffineQuantization) {
    return absl::InvalidArgumentError(
        absl::StrCat("Tensor not quantized: ", std::string(tensor.name)));
  }
  const TfLiteAffineQuantization* params =
      static_cast<const TfLiteAffineQuantization*>(quant.params);
  if (params->scale->size > 1) {
    return absl::InvalidArgumentError(
        absl::StrCat("Non-constant per-channel quantized tensor: ",
                     std::string(tensor.name)));
  }
  const float scale = params->scale->data[0];
  const float zero_point = static_cast<float>(params->zero_point->data[0]);

  float qmin_value = 0;
  float qmax_value = 0;
  if (tensor.type == kTfLiteUInt8) {
    qmin_value = static_cast<float>(std::numeric_limits<uint8_t>::min());
    qmax_value = static_cast<float>(std::numeric_limits<uint8_t>::max());
  } else if (tensor.type == kTfLiteInt8) {
    qmin_value = static_cast<float>(std::numeric_limits<int8_t>::min());
    qmax_value = static_cast<float>(std::numeric_limits<int8_t>::max());
  } else {
    return absl::InvalidArgumentError(absl::StrCat(
        "Type invalid for quantized tensor: ", std::string(tensor.name)));
  }
  quant_params->min = scale * (static_cast<float>(qmin_value) - zero_point);
  quant_params->max = scale * (static_cast<float>(qmax_value) - zero_point);
  quant_params->scale = scale;

  return absl::OkStatus();
}

// If quantized tensors exist in the graph & quant_conversion_map is non-null,
// the mapping between the original tensors (fixed-point) & GPU values (fp) is
// stored in quant_conversion_map.
class ObjectReader {
 public:
  ObjectReader(
      GraphFloat32* graph, TfLiteContext* context,
      const TfLiteNode* tflite_node,
      std::unordered_map<int, Value<TensorRef<BHWC>>*>* tensor_to_value,
      std::unordered_map<int, int>* quant_conversion_map = nullptr)
      : graph_(graph),
        context_(context),
        tflite_node_(tflite_node),
        tensor_to_value_(tensor_to_value),
        quant_conversion_map_(quant_conversion_map) {}

  static absl::Status ReadNonConstantTensor(
      TfLiteContext* context,
      std::unordered_map<int, Value<TensorRef<BHWC>>*>* tensor_to_value,
      std::unordered_map<int, int>* quant_conversion_map, GraphFloat32* graph,
      uint32_t tensor_idx, Value<TensorRef<BHWC>>** value = nullptr) {
    if (tensor_idx >= context->tensors_size) {
      return absl::OutOfRangeError(
          absl::StrCat("ReadNonConstTensor: input tensor index: ", tensor_idx));
    }

    if (tensor_to_value->find(tensor_idx) == tensor_to_value->end()) {
      const TfLiteTensor& tflite_tensor = context->tensors[tensor_idx];
      if (tflite::IsConstantTensor(&tflite_tensor)) {
        return absl::InvalidArgumentError(absl::StrCat(
            "ReadNonConstantTensor: value is a constant tensor: ", tensor_idx));
      }

      if ((tflite_tensor.type == kTfLiteInt8 ||
           tflite_tensor.type == kTfLiteUInt8) &&
          quant_conversion_map) {
        // Quantized case
        if (quant_conversion_map->find(tensor_idx) ==
            quant_conversion_map->end()) {
          // Since the original tensor is fixed-point, add a new float tensor to
          // the TFLite graph to represent the dequantized data.
          int fp_tensor_index = 0;
          TfLiteTensor* fp_tflite_tensor;
          if (delegates::CreateNewTensorWithDifferentType(
                  context, tensor_idx, kTfLiteFloat32, &fp_tflite_tensor,
                  &fp_tensor_index) != kTfLiteOk) {
            return absl::InternalError("Could not add new tensor to graph");
          }
          // Remember this tensor for later.
          (*quant_conversion_map)[fp_tensor_index] = tensor_idx;
          (*quant_conversion_map)[tensor_idx] = fp_tensor_index;
          // Add a new GPU Value for the new dequantized floating-point tensor.
          Value<TensorRef<BHWC>>* value = graph->NewValue();
          RETURN_IF_ERROR(ConvertTfLiteTensorToTensorRef(*fp_tflite_tensor,
                                                         &value->tensor));
          value->tensor.ref = fp_tensor_index;
          value->quant_params.emplace();
          RETURN_IF_ERROR(
              PopulateQuantParams(tflite_tensor, &value->quant_params.value()));
          (*tensor_to_value)[fp_tensor_index] = value;
        }
        // We do not use the original tensor index as reference for the GPU
        // Value, instead pointing at the corresponding float version.
        tensor_idx = quant_conversion_map->at(tensor_idx);
      } else {
        // Floating-point case.
        Value<TensorRef<BHWC>>* value = graph->NewValue();
        RETURN_IF_ERROR(
            ConvertTfLiteTensorToTensorRef(tflite_tensor, &value->tensor));
        value->tensor.ref = tensor_idx;
        (*tensor_to_value)[tensor_idx] = value;
      }
    }

    if (value) {
      *value = (*tensor_to_value)[tensor_idx];
    }
    return absl::OkStatus();
  }

  absl::Status ReadValue(uint32_t idx, Value<TensorRef<BHWC>>** value) {
    if (idx >= tflite_node_->inputs->size) {
      return absl::OutOfRangeError(
          absl::StrCat("ReadValue: input tensor index: ", idx));
    }
    return ReadValueByTensorIdx(tflite_node_->inputs->data[idx], value);
  }

  absl::Status ReadValueByTensorIdx(uint32_t tensor_idx,
                                    Value<TensorRef<BHWC>>** value) {
    // Constant tensors should be handled by ReadTensor.
    return ReadNonConstantTensor(context_, tensor_to_value_,
                                 quant_conversion_map_, graph_, tensor_idx,
                                 value);
  }

  int GetNumberOfRuntimeInputs() const {
    return GetNumberOfRuntimeInputsForNode(context_, tflite_node_);
  }

  absl::Status GetTensorDims(uint32_t idx, TfLiteIntArray* dimensions) const {
    if (idx >= tflite_node_->inputs->size) {
      return absl::OutOfRangeError(absl::StrCat("Input tensor index: ", idx));
    }
    const int tensor_idx = tflite_node_->inputs->data[idx];
    if (tensor_idx < 0 || tensor_idx > context_->tensors_size) {
      return absl::OutOfRangeError(absl::StrCat("Tensor index: ", tensor_idx));
    }
    const TfLiteTensor& tflite_tensor = context_->tensors[tensor_idx];
    *dimensions = *tflite_tensor.dims;
    return absl::OkStatus();
  }

  template <typename TensorT>
  absl::Status ReadTensor(uint32_t idx, TensorT* t) const {
    RETURN_IF_ERROR(CheckTensorIsAvailable(context_, tflite_node_, idx));
    const int32_t tensor_idx = tflite_node_->inputs->data[idx];
    const TfLiteTensor* tflite_tensor = context_->tensors + tensor_idx;
    t->data.resize(NumElements(tflite_tensor));
    RETURN_IF_ERROR(CreateVectorCopyData(*tflite_tensor, &t->data[0]));

    // Axis and data layout depend on operation this tensor is used in. So,
    // postpone resolutions until operations are parsed.
    t->id = tensor_idx;
    return SetAllDimensions(tflite_tensor->dims, &t->shape);
  }

  absl::Status AddOutput(const Node* node, int id) {
    if (tflite_node_->outputs->size <= id) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Data id ", id, " must be less than tflite node outputs size ",
          tflite_node_->outputs->size));
    }
    int output_tensor_idx = tflite_node_->outputs->data[id];
    Value<TensorRef<BHWC>>* value;
    RETURN_IF_ERROR(ReadValueByTensorIdx(output_tensor_idx, &value));
    RETURN_IF_ERROR(graph_->SetProducer(node->id, value->id));
    return absl::OkStatus();
  }

  absl::Status AddOutputs(const Node* node) {
    for (int i = 0; i < tflite_node_->outputs->size; ++i) {
      RETURN_IF_ERROR(AddOutput(node, i));
    }
    return absl::OkStatus();
  }

  absl::Status AddInput(const Node* node, uint32_t idx) {
    Value<TensorRef<BHWC>>* input;
    RETURN_IF_ERROR(ReadValue(idx, &input));
    return graph_->AddConsumer(node->id, input->id);
  }

  TfLiteTensor* GetInputTensor(int index) const {
    return index >= 0 && index < tflite_node_->inputs->size
               ? context_->tensors + tflite_node_->inputs->data[index]
               : nullptr;
  }

  TfLiteTensor* GetOutputTensor(int index) const {
    return index >= 0 && index < tflite_node_->outputs->size
               ? context_->tensors + tflite_node_->outputs->data[index]
               : nullptr;
  }

  absl::Status VerifyInputsConstsOutputs(const TfLiteNode* tflite_node,
                                         int runtime_inputs, int const_inputs,
                                         int outputs) {
    return CheckInputsConstsOutputs(context_, tflite_node, runtime_inputs,
                                    const_inputs, outputs);
  }

 private:
  GraphFloat32* graph_ = nullptr;
  TfLiteContext* context_ = nullptr;
  const TfLiteNode* tflite_node_ = nullptr;
  std::unordered_map<int, Value<TensorRef<BHWC>>*>* tensor_to_value_;
  std::unordered_map<int, int>* quant_conversion_map_;
};

// A parser responsible for parsing TFLite operation and adding it to a graph.
class TFLiteOperationParser {
 public:
@@ -781,28 +348,6 @@ absl::Status ParsePoolingAttributes(const TfLitePoolParams* tf_options,
  return absl::OkStatus();
}

absl::Status ExtractTensorShape(const TfLiteTensor& tflite_tensor, BHWC* bhwc) {
  const TfLiteIntArray* dims = tflite_tensor.dims;
  switch (dims->size) {
    case 1:
      *bhwc = BHWC(dims->data[0], 1, 1, 1);
      return absl::OkStatus();
    case 2:
      *bhwc = BHWC(dims->data[0], 1, 1, dims->data[1]);
      return absl::OkStatus();
    case 3:
      *bhwc = BHWC(dims->data[0], 1, dims->data[1], dims->data[2]);
      return absl::OkStatus();
    case 4:
      *bhwc = BHWC(dims->data[0], dims->data[1], dims->data[2], dims->data[3]);
      return absl::OkStatus();
    default:
      return absl::InvalidArgumentError(absl::StrCat(
          "Tensor \"", tflite_tensor.name ? tflite_tensor.name : "nullptr",
          "\" has bad input dims size: ", dims->size, "."));
  }
}

absl::Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader,
                                        TensorOrScalar* tensor_or_scalar) {
  const std::string& opname = node->operation.type;
@@ -3018,24 +2563,19 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
      if (custom_name == "RoIToTransformMatrix") {
        return absl::make_unique<RoIToTransformMatrixOperationParser>();
      }

      if (custom_name == "TransformTensor") {
        return absl::make_unique<TransformTensorOperationParser>();
      }

      if (custom_name == "TransformLandmarks") {
        return absl::make_unique<TransformLandmarksOperationParser>();
      }

      if (custom_name == "Landmarks2TransformMatrix") {
        return absl::make_unique<Landmarks2TransformMatrixOperationParser>();
      }

      if (custom_name == "AlignmentPointsToTransformMatrix") {
        return absl::make_unique<
            AlignmentPointsToTransformMatrixOperationParser>();
      }

      break;
    }
  return absl::make_unique<UnsupportedOperationParser>();
@@ -3067,16 +2607,10 @@ bool IsAllAllowedTensors(TfLiteContext* context, const TfLiteIntArray* array,
}
}  // namespace

absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
                                            TensorRef<BHWC>* tensor_ref) {
  tensor_ref->type = ToDataType(tflite_tensor.type);
  return ExtractTensorShape(tflite_tensor, &tensor_ref->shape);
}

// TODO(impjdi): Check number of input/output tensors and their dimensions.
// TODO(impjdi): Check ops' parameters.
TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops) {
  IsNodeSupportedFn node_supported_fn =
  delegates::IsNodeSupportedFn node_supported_fn =
      [=](TfLiteContext* context, TfLiteNode* node,
          TfLiteRegistration* registration,
          std::string* unsupported_details) -> bool {

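The truncated GetOpsToReplace hunk above is the main consumer of the GraphWithDequantPartitionHelper that the new model_builder_helper files introduce below. As a rough, hypothetical sketch of that flow (the real call site stays in model_builder.cc and may differ in detail), the delegate partitions the graph, takes the largest supported partition, and leaves the fp16 Dequantize nodes behind:

// Sketch only: 'context' and 'node_supported_fn' come from GetOpsToReplace.
GraphWithDequantPartitionHelper partition_helper(context, node_supported_fn);
std::set<std::string> unsupported_nodes_info;
if (partition_helper.Partition(&unsupported_nodes_info) != kTfLiteOk) {
  return TfLiteIntArrayCreate(0);  // Nothing gets delegated on failure.
}
// All nodes of the single largest partition, minus the Dequantize nodes that
// have to stay in the TFLite graph (see RemoveReservedDequantsFromNodes).
const std::vector<int> ops_to_replace =
    partition_helper.GetNodesOfFirstNLargestPartitions(1);
const int num_ops = static_cast<int>(ops_to_replace.size());
TfLiteIntArray* result = TfLiteIntArrayCreate(num_ops);
for (int i = 0; i < num_ops; ++i) {
  result->data[i] = ops_to_replace[i];
}
return result;
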
tensorflow/lite/delegates/gpu/common/model_builder_helper.cc (new file, 450 lines)
@@ -0,0 +1,450 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"

#include <set>
#include <string>
#include <unordered_map>

#include <fp16.h>
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace gpu {

TfLiteStatus GraphWithDequantPartitionHelper::Partition(
    std::set<std::string>* unsupported_nodes_info) {
  const auto status = GraphPartitionHelper::Partition(unsupported_nodes_info);
  // Clean up those partitions that have a single dequant op. Note: those
  // removed dequant ops have to be preserved in the graph and should not be
  // delegated.
  RemoveSingleDequantNodePartitions();
  return status;
}

std::vector<int>
GraphWithDequantPartitionHelper::GetNodesOfFirstNLargestPartitions(int n) {
  // We first get partitions to reduce the number of nodes to be checked in
  // deciding which dequant ops could actually be replaced. And then we
  // remap input-tensor to dequant nodes' inputs and remove those
  // to-be-reserved dequant nodes.
  auto first_nps = GetFirstNLargestPartitions(n);
  std::vector<int> ops_to_replace;
  for (const auto p : first_nps) {
    auto nodes = p->nodes_to_replace;
    ops_to_replace.insert(ops_to_replace.end(), nodes->data,
                          nodes->data + nodes->size);
  }
  RemapInputTensors(ops_to_replace);
  RemoveReservedDequantsFromNodes(&ops_to_replace);
  return ops_to_replace;
}

bool GraphWithDequantPartitionHelper::IsNodeSupported(
    TfLiteContext* context, TfLiteNode* node, TfLiteRegistration* registration,
    int node_id, std::string* unsupported_details) {
  // If we need to handle dequant nodes, we have to remap input tensors of
  // this node if some of them come from a dequant node before testing if
  // the node is supported.
  std::vector<int> orig_inputs;
  if (RecordAndRemapInputTensors(registration->builtin_code, node_id, node,
                                 &orig_inputs)) {
    // We have a dequant op here. Note that we return an Ok status because a
    // dequant node is first added as supported. Later, this dequant node
    // will be removed if it has to be preserved in the graph which happens
    // when its immediate downstream nodes cannot be supported.
    return true;
  }
  const auto status = GraphPartitionHelper::IsNodeSupported(
      context, node, registration, node_id, unsupported_details);
  RestoreToOrigInputTensors(node, orig_inputs);
  return status;
}

bool GraphWithDequantPartitionHelper::RecordAndRemapInputTensors(
    int32_t op_code, int node_id, TfLiteNode* node,
    std::vector<int>* orig_inputs) {
  orig_inputs->clear();
  // Record the dequant node.
  if (op_code == kTfLiteBuiltinDequantize &&
      context_->tensors[node->inputs->data[0]].type ==
          TfLiteType::kTfLiteFloat16) {
    dequant_nodes_[node->outputs->data[0]] = node->inputs->data[0];
    return true;
  }
  // For a dequantize op, there's no need to remap its input tensors.
  if (dequant_nodes_.empty()) return false;
  RemapInputTensors(node, orig_inputs);
  return false;
}

void GraphWithDequantPartitionHelper::RestoreToOrigInputTensors(
    TfLiteNode* node, const std::vector<int>& orig_inputs) {
  if (node->inputs->size != orig_inputs.size()) return;
  for (int j = 0; j < node->inputs->size; ++j) {
    node->inputs->data[j] = orig_inputs[j];
  }
}

void GraphWithDequantPartitionHelper::RemapInputTensors(
    const std::vector<int>& nodes) const {
  for (int node_id : nodes) {
    TfLiteNode* node;
    TfLiteRegistration* registration;
    GetNodeAndRegistration(context_, node_id, &node, &registration)
        .IgnoreError();
    RemapInputTensors(node, nullptr /* orig_inputs */);
  }
}

void GraphWithDequantPartitionHelper::RemoveSingleDequantNodePartitions() {
  auto it = partitions_.begin();
  while (it != partitions_.end()) {
    auto p = *it;
    if (p->nodes_to_replace->size != 1) {
      ++it;
      continue;
    }
    int node_id = p->nodes_to_replace->data[0];
    TfLiteNode* node = nullptr;
    TfLiteRegistration* registration = nullptr;
    GetNodeAndRegistration(context_, node_id, &node, &registration)
        .IgnoreError();
    if (registration->builtin_code != kTfLiteBuiltinDequantize ||
        context_->tensors[node->inputs->data[0]].type !=
            TfLiteType::kTfLiteFloat16) {
      ++it;
      continue;
    }
    // Note such dequant nodes have to be preserved in the graph as dequant
    // ops are not actually supported in the GPU delegate.
    dequant_nodes_to_save_.insert(node_id);
    it = partitions_.erase(it);
  }
}

void GraphWithDequantPartitionHelper::RemoveReservedDequantsFromNodes(
    std::vector<int>* nodes) {
  if (dequant_nodes_to_save_.empty()) return;
  auto it = nodes->begin();
  while (it != nodes->end()) {
    if (dequant_nodes_to_save_.find(*it) == dequant_nodes_to_save_.end()) {
      ++it;
      continue;
    }
    it = nodes->erase(it);
  }
}

void GraphWithDequantPartitionHelper::RemapInputTensors(
    TfLiteNode* node, std::vector<int>* orig_inputs) const {
  TfLiteIntArray* inputs = node->inputs;
  auto inputs_view = TfLiteIntArrayView(inputs);
  // Prepopulate 'orig_inputs' first and clear it if there's no input from a
  // dequant op.
  if (orig_inputs) {
    orig_inputs->clear();
    orig_inputs->reserve(inputs->size);
    for (auto tid : inputs_view) {
      orig_inputs->push_back(tid);
    }
  }
  // Fix this node's inputs (i.e. prune out the preceding dequantize node) in
  // order to test if it is supported.
  bool is_remapped = false;
  for (int j = 0; j < inputs->size; ++j) {
    const int input_tid = inputs->data[j];
    const auto it = dequant_nodes_.find(input_tid);
    if (it != dequant_nodes_.end()) {
      inputs->data[j] = it->second;
      is_remapped = true;
    }
  }
  if (!is_remapped && orig_inputs) orig_inputs->clear();
}

absl::Status GetNodeAndRegistration(TfLiteContext* context, int node_id,
                                    TfLiteNode** tflite_node,
                                    TfLiteRegistration** registration) {
  if (context->GetNodeAndRegistration(context, node_id, tflite_node,
                                      registration) != kTfLiteOk) {
    return absl::InvalidArgumentError(absl::StrCat(
        "Couldn't get node and registration info for op: ", node_id));
  }
  return absl::OkStatus();
}

DataType ToDataType(TfLiteType type) {
  switch (type) {
    case kTfLiteFloat32:
      return DataType::FLOAT32;
    case kTfLiteInt32:
      return DataType::INT32;
    case kTfLiteInt64:
      return DataType::INT64;
    case kTfLiteInt8:
      return DataType::INT8;
    case kTfLiteUInt8:
      return DataType::UINT8;
    default:
      return DataType::UNKNOWN;
  }
}

absl::Status ExtractTensorShape(const TfLiteTensor& tflite_tensor, BHWC* bhwc) {
  const TfLiteIntArray* dims = tflite_tensor.dims;
  switch (dims->size) {
    case 1:
      *bhwc = BHWC(dims->data[0], 1, 1, 1);
      return absl::OkStatus();
    case 2:
      *bhwc = BHWC(dims->data[0], 1, 1, dims->data[1]);
      return absl::OkStatus();
    case 3:
      *bhwc = BHWC(dims->data[0], 1, dims->data[1], dims->data[2]);
      return absl::OkStatus();
    case 4:
      *bhwc = BHWC(dims->data[0], dims->data[1], dims->data[2], dims->data[3]);
      return absl::OkStatus();
    default:
      return absl::InvalidArgumentError(absl::StrCat(
          "Tensor \"", tflite_tensor.name ? tflite_tensor.name : "nullptr",
          "\" has bad input dims size: ", dims->size, "."));
  }
}

absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
                                            TensorRef<BHWC>* tensor_ref) {
  tensor_ref->type = ToDataType(tflite_tensor.type);
  return ExtractTensorShape(tflite_tensor, &tensor_ref->shape);
}

absl::Status PopulateQuantParams(const TfLiteTensor& tensor,
                                 QuantizationParams* quant_params) {
  const TfLiteQuantization& quant = tensor.quantization;
  if (quant.type != TfLiteQuantizationType::kTfLiteAffineQuantization) {
    return absl::InvalidArgumentError(
        absl::StrCat("Tensor not quantized: ", std::string(tensor.name)));
  }
  const TfLiteAffineQuantization* params =
      static_cast<const TfLiteAffineQuantization*>(quant.params);
  if (params->scale->size > 1) {
    return absl::InvalidArgumentError(
        absl::StrCat("Non-constant per-channel quantized tensor: ",
                     std::string(tensor.name)));
  }
  const float scale = params->scale->data[0];
  const float zero_point = static_cast<float>(params->zero_point->data[0]);

  float qmin_value = 0;
  float qmax_value = 0;
  if (tensor.type == kTfLiteUInt8) {
    qmin_value = static_cast<float>(std::numeric_limits<uint8_t>::min());
    qmax_value = static_cast<float>(std::numeric_limits<uint8_t>::max());
  } else if (tensor.type == kTfLiteInt8) {
    qmin_value = static_cast<float>(std::numeric_limits<int8_t>::min());
    qmax_value = static_cast<float>(std::numeric_limits<int8_t>::max());
  } else {
    return absl::InvalidArgumentError(absl::StrCat(
        "Type invalid for quantized tensor: ", std::string(tensor.name)));
  }
  quant_params->min = scale * (static_cast<float>(qmin_value) - zero_point);
  quant_params->max = scale * (static_cast<float>(qmax_value) - zero_point);
  quant_params->scale = scale;

  return absl::OkStatus();
}

int GetNumberOfRuntimeInputsForNode(const TfLiteContext* context,
                                    const TfLiteNode* tflite_node) {
  int number_of_runtime_inputs = 0;
  for (int i = 0; i < tflite_node->inputs->size; i++) {
    if (!IsConstantTensor(&context->tensors[tflite_node->inputs->data[i]])) {
      number_of_runtime_inputs++;
    }
  }
  return number_of_runtime_inputs;
}

int GetNumberOfConstInputsForNode(const TfLiteContext* context,
                                  const TfLiteNode* tflite_node) {
  return tflite_node->inputs->size -
         GetNumberOfRuntimeInputsForNode(context, tflite_node);
}

int GetNumberOfRuntimeOutputsForNode(const TfLiteContext* context,
                                     const TfLiteNode* tflite_node) {
  int number_of_runtime_outputs = 0;
  for (int i = 0; i < tflite_node->outputs->size; i++) {
    if (!IsConstantTensor(&context->tensors[tflite_node->outputs->data[i]])) {
      number_of_runtime_outputs++;
    }
  }
  return number_of_runtime_outputs;
}

absl::Status CheckInputsOutputs(const TfLiteContext* context,
                                const TfLiteNode* tflite_node,
                                int runtime_inputs, int outputs) {
  const int runtime_inputs_from_model =
      GetNumberOfRuntimeInputsForNode(context, tflite_node);
  if (runtime_inputs_from_model != runtime_inputs) {
    return absl::InternalError(absl::StrCat(
        "Expected ", runtime_inputs, " runtime input tensor(s), but node has ",
        runtime_inputs_from_model, " runtime input(s)."));
  }
  const int runtime_outputs =
      GetNumberOfRuntimeOutputsForNode(context, tflite_node);
  if (runtime_outputs != outputs) {
    return absl::InternalError(absl::StrCat("Expected ", outputs,
                                            " output tensor(s), but node has ",
                                            runtime_outputs, " output(s)."));
  }
  return absl::OkStatus();
}

absl::Status CheckInputsConstsOutputs(const TfLiteContext* context,
                                      const TfLiteNode* tflite_node,
                                      int runtime_inputs, int const_inputs,
                                      int outputs) {
  const int const_inputs_from_model =
      GetNumberOfConstInputsForNode(context, tflite_node);
  if (const_inputs_from_model != const_inputs) {
    return absl::InternalError(absl::StrCat(
        "Expected ", const_inputs, " const input tensor(s), but node has ",
        const_inputs_from_model, " const input(s)."));
  }
  return CheckInputsOutputs(context, tflite_node, runtime_inputs, outputs);
}

void ConvertFloat16ToFloat32(size_t num_elements, const uint16_t* src,
                             float* dst) {
  for (size_t i = 0; i < num_elements; i++) {
    *dst++ = fp16_ieee_to_fp32_value(*src++);
  }
}

template <>
absl::Status CreateVectorCopyData<float>(const TfLiteTensor& tensor,
                                         float* tensor_data) {
  switch (tensor.type) {
    case kTfLiteFloat32:
      std::memcpy(tensor_data, tensor.data.f, tensor.bytes);
      break;
    case kTfLiteFloat16:
      ConvertFloat16ToFloat32(
          NumElements(&tensor),
          reinterpret_cast<uint16_t const*>(tensor.data.f16), tensor_data);
      break;
    case kTfLiteInt8:
      DequantizeConstantTensor(tensor, tensor.data.int8, tensor_data);
      break;
    case kTfLiteUInt8:
      DequantizeConstantTensor(tensor, tensor.data.uint8, tensor_data);
      break;
    case kTfLiteInt32:
      DequantizeConstantTensor(tensor, tensor.data.i32, tensor_data);
      break;
    default:
      return absl::InvalidArgumentError(
          "Unsupported data type for float32 tensor");
  }
  return absl::OkStatus();
}

absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Scalar* shape) {
  if (dimensions->size < 0) {
    return absl::InvalidArgumentError("Invalid Scalar dimensions");
  }
  for (int i = 0; i < dimensions->size; ++i) {
    if (dimensions->data[i] != 1) {
      return absl::InvalidArgumentError(
          "Dimension can not be reduced to scalar.");
    }
  }
  shape->v = 1;
  return absl::OkStatus();
}

absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape) {
  if (dimensions->size <= 0) {
    return absl::InvalidArgumentError("Dimension is empty.");
  }
  for (int i = 0; i < dimensions->size - 1; ++i) {
    if (dimensions->data[i] != 1) {
      return absl::InvalidArgumentError(
          "Dimension can not be reduced to linear.");
    }
  }
  shape->v = dimensions->data[dimensions->size - 1];
  return absl::OkStatus();
}

absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape) {
  if (dimensions->size != 4) {
    return absl::InvalidArgumentError("Dimensions are not HWC");
  }
  if (dimensions->data[0] != 1) {
    return absl::UnimplementedError("Batch size is not equal to 1.");
  }
  shape->h = dimensions->data[1];
  shape->w = dimensions->data[2];
  shape->c = dimensions->data[3];
  return absl::OkStatus();
}

absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape) {
  if (dimensions->size != 2) {
    return absl::InvalidArgumentError("Dimensions are not HW");
  }
  shape->h = dimensions->data[0];
  shape->w = dimensions->data[1];
  return absl::OkStatus();
}

absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, OHWI* shape) {
  if (dimensions->size != 4) {
    return absl::InvalidArgumentError(
        absl::StrCat("Dimensions are not OHWI: ", dimensions->size));
  }
  shape->o = dimensions->data[0];
  shape->h = dimensions->data[1];
  shape->w = dimensions->data[2];
  shape->i = dimensions->data[3];
  return absl::OkStatus();
}

absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, BHWC* shape) {
  if (dimensions->size != 4) {
    return absl::InvalidArgumentError("Dimensions are not BHWC");
  }
  shape->b = dimensions->data[0];
  shape->h = dimensions->data[1];
  shape->w = dimensions->data[2];
  shape->c = dimensions->data[3];
  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace tflite

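For a concrete sense of what PopulateQuantParams produces, here is a small worked example with made-up numbers (not taken from the commit): a kTfLiteUInt8 tensor quantized with scale 0.5 and zero point 128 maps its full 0..255 range onto the float interval used by the delegate's QuantizeAndDequantize emulation.

// Illustrative arithmetic only; QuantizationParams is the struct filled in
// by PopulateQuantParams above.
QuantizationParams params;
const float scale = 0.5f;
const float zero_point = 128.0f;
params.min = scale * (0.0f - zero_point);    // 0.5 * (0   - 128) = -64.0
params.max = scale * (255.0f - zero_point);  // 0.5 * (255 - 128) =  63.5
params.scale = scale;
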
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_HELPER_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_MODEL_BUILDER_HELPER_H_

@@ -19,88 +20,39 @@ limitations under the License.
#include <string>
#include <unordered_map>

#include "absl/strings/str_cat.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/utils.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace gpu {
inline absl::Status GetNodeAndRegistration(TfLiteContext* context, int node_id,
                                           TfLiteNode** tflite_node,
                                           TfLiteRegistration** registration) {
  if (context->GetNodeAndRegistration(context, node_id, tflite_node,
                                      registration) != kTfLiteOk) {
    return absl::InvalidArgumentError(absl::StrCat(
        "Couldn't get node and registration info for op: ", node_id));
  }
  return absl::OkStatus();
}

using IsNodeSupportedFn = tflite::delegates::IsNodeSupportedFn;

class GraphWithDequantPartitionHelper
    : public tflite::delegates::GraphPartitionHelper {
class GraphWithDequantPartitionHelper : public delegates::GraphPartitionHelper {
 public:
  GraphWithDequantPartitionHelper(TfLiteContext* context,
                                  IsNodeSupportedFn is_node_supported_fn)
  GraphWithDequantPartitionHelper(
      TfLiteContext* context, delegates::IsNodeSupportedFn is_node_supported_fn)
      : GraphPartitionHelper(context, std::move(is_node_supported_fn)) {}

  TfLiteStatus Partition(
      std::set<std::string>* unsupported_nodes_info) override {
    const auto status = GraphPartitionHelper::Partition(unsupported_nodes_info);
    // Clean up those partitions that have a single dequant op. Note: those
    // removed dequant ops have to be preserved in the graph and should not be
    // delegated.
    RemoveSingleDequantNodePartitions();
    return status;
  }
      std::set<std::string>* unsupported_nodes_info) override;

  // Returns a list of node indices of all nodes from the first n largest
  // partitions. If there are fewer partitions than n, all nodes will be
  // returned. The partition is ranked according to the number of nodes.
  std::vector<int> GetNodesOfFirstNLargestPartitions(int n) {
    // We first get partitions to reduce the number of nodes to be checked in
    // deciding which dequant ops could actually be replaced. And then we
    // remap input-tensor to dequant nodes' inputs and remove those
    // to-be-reserved dequant nodes.
    auto first_nps = GetFirstNLargestPartitions(n);
    std::vector<int> ops_to_replace;
    for (const auto p : first_nps) {
      auto nodes = p->nodes_to_replace;
      ops_to_replace.insert(ops_to_replace.end(), nodes->data,
                            nodes->data + nodes->size);
    }
    RemapInputTensors(ops_to_replace);
    RemoveReservedDequantsFromNodes(&ops_to_replace);
    return ops_to_replace;
  }
  std::vector<int> GetNodesOfFirstNLargestPartitions(int n);

 protected:
  bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node,
                       TfLiteRegistration* registration, int node_id,
                       std::string* unsupported_details) override {
    // If we need to handle dequant nodes, we have to remap input tensors of
    // this node if some of them come from a dequant node before testing if
    // the node is supported.
    std::vector<int> orig_inputs;
    if (RecordAndRemapInputTensors(registration->builtin_code, node_id, node,
                                   &orig_inputs)) {
      // We have a dequant op here. Note that we return an Ok status because a
      // dequant node is first added as supported. Later, this dequant node
      // will be removed if it has to be preserved in the graph which happens
      // when its immediate downstream nodes cannot be supported.
      return true;
    }
    const auto status = GraphPartitionHelper::IsNodeSupported(
        context, node, registration, node_id, unsupported_details);
    RestoreToOrigInputTensors(node, orig_inputs);
    return status;
  }
                       std::string* unsupported_details) override;

 private:
  // Record 'node' if it is a dequant op (i.e. a fp16 one here) and return true.
@@ -109,109 +61,24 @@ class GraphWithDequantPartitionHelper
  // input tensor ids of this node if any input is remapped.
  bool RecordAndRemapInputTensors(int32_t op_code, int node_id,
                                  TfLiteNode* node,
                                  std::vector<int>* orig_inputs) {
    orig_inputs->clear();
    // Record the dequant node.
    if (op_code == kTfLiteBuiltinDequantize &&
        context_->tensors[node->inputs->data[0]].type ==
            TfLiteType::kTfLiteFloat16) {
      dequant_nodes_[node->outputs->data[0]] = node->inputs->data[0];
      return true;
    }
    // For a dequantize op, there's no need to remap its input tensors.
    if (dequant_nodes_.empty()) return false;
    RemapInputTensors(node, orig_inputs);
    return false;
  }
                                  std::vector<int>* orig_inputs);

  // Restore inputs of 'node' to 'orig_inputs' only if two sizes match.
  void RestoreToOrigInputTensors(TfLiteNode* node,
                                 const std::vector<int>& orig_inputs) {
    if (node->inputs->size != orig_inputs.size()) return;
    for (int j = 0; j < node->inputs->size; ++j) {
      node->inputs->data[j] = orig_inputs[j];
    }
  }
                                 const std::vector<int>& orig_inputs);

  // Remap input tensors of every node in 'nodes' (i.e. node indices) if some of
  // them are from dequant ops.
  void RemapInputTensors(const std::vector<int>& nodes) const {
    for (int node_id : nodes) {
      TfLiteNode* node;
      TfLiteRegistration* registration;
      GetNodeAndRegistration(context_, node_id, &node, &registration)
          .IgnoreError();
      RemapInputTensors(node, nullptr /* orig_inputs */);
    }
  }
  void RemapInputTensors(const std::vector<int>& nodes) const;

  void RemoveSingleDequantNodePartitions() {
    auto it = partitions_.begin();
    while (it != partitions_.end()) {
      auto p = *it;
      if (p->nodes_to_replace->size != 1) {
        ++it;
        continue;
      }
      int node_id = p->nodes_to_replace->data[0];
      TfLiteNode* node = nullptr;
      TfLiteRegistration* registration = nullptr;
      GetNodeAndRegistration(context_, node_id, &node, &registration)
          .IgnoreError();
      if (registration->builtin_code != kTfLiteBuiltinDequantize ||
          context_->tensors[node->inputs->data[0]].type !=
              TfLiteType::kTfLiteFloat16) {
        ++it;
        continue;
      }
      // Note such dequant nodes have to be preserved in the graph as dequant
      // ops are not actually supported in the GPU delegate.
      dequant_nodes_to_save_.insert(node_id);
      it = partitions_.erase(it);
    }
  }
  void RemoveSingleDequantNodePartitions();

  void RemoveReservedDequantsFromNodes(std::vector<int>* nodes) {
    if (dequant_nodes_to_save_.empty()) return;
    auto it = nodes->begin();
    while (it != nodes->end()) {
      if (dequant_nodes_to_save_.find(*it) == dequant_nodes_to_save_.end()) {
        ++it;
        continue;
      }
      it = nodes->erase(it);
    }
  }
  void RemoveReservedDequantsFromNodes(std::vector<int>* nodes);

  // Remap input tensors of a single 'node' if some of them come from a dequant
  // op. If 'orig_inputs' isn't nullptr, it records original input tensor ids of
  // this node if any input is remapped.
  void RemapInputTensors(TfLiteNode* node,
                         std::vector<int>* orig_inputs) const {
    TfLiteIntArray* inputs = node->inputs;
    auto inputs_view = TfLiteIntArrayView(inputs);
    // Prepopulate 'orig_inputs' first and clear it if there's no input from a
    // dequant op.
    if (orig_inputs) {
      orig_inputs->clear();
      orig_inputs->reserve(inputs->size);
      for (auto tid : inputs_view) {
        orig_inputs->push_back(tid);
      }
    }
    // Fix this node's inputs (i.e. prune out the preceding dequantize node) in
    // order to test if it is supported.
    bool is_remapped = false;
    for (int j = 0; j < inputs->size; ++j) {
      const int input_tid = inputs->data[j];
      const auto it = dequant_nodes_.find(input_tid);
      if (it != dequant_nodes_.end()) {
        inputs->data[j] = it->second;
        is_remapped = true;
      }
    }
    if (!is_remapped && orig_inputs) orig_inputs->clear();
  }
  void RemapInputTensors(TfLiteNode* node, std::vector<int>* orig_inputs) const;

  // A map recording dequantize nodes' input/output tensors of this selected
  // graph. The key is the output tensor id, and the value is the input tensor
@@ -222,6 +89,95 @@ class GraphWithDequantPartitionHelper
  // graph.
  std::set<int> dequant_nodes_to_save_;
};

absl::Status GetNodeAndRegistration(TfLiteContext* context, int node_id,
                                    TfLiteNode** tflite_node,
                                    TfLiteRegistration** registration);

DataType ToDataType(TfLiteType type);

absl::Status ExtractTensorShape(const TfLiteTensor& tflite_tensor, BHWC* bhwc);

absl::Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
                                            TensorRef<BHWC>* tensor_ref);

// Populates quantization parameters for non-constant UInt8/Int8 tensors.
// This helps the delegate emulate quantized inference with
// QuantizeAndDequantize.
absl::Status PopulateQuantParams(const TfLiteTensor& tensor,
                                 QuantizationParams* quant_params);

int GetNumberOfRuntimeInputsForNode(const TfLiteContext* context,
                                    const TfLiteNode* tflite_node);

int GetNumberOfConstInputsForNode(const TfLiteContext* context,
                                  const TfLiteNode* tflite_node);

int GetNumberOfRuntimeOutputsForNode(const TfLiteContext* context,
                                     const TfLiteNode* tflite_node);

absl::Status CheckInputsOutputs(const TfLiteContext* context,
                                const TfLiteNode* tflite_node,
                                int runtime_inputs, int outputs);

absl::Status CheckInputsConstsOutputs(const TfLiteContext* context,
                                      const TfLiteNode* tflite_node,
                                      int runtime_inputs, int const_inputs,
                                      int outputs);

void ConvertFloat16ToFloat32(size_t num_elements, const uint16_t* src,
                             float* dst);

template <typename T>
void DequantizeConstantTensor(const TfLiteTensor& tensor, const T* source_data,
                              float* dequantized_data) {
  TfLiteAffineQuantization* quant_params =
      reinterpret_cast<TfLiteAffineQuantization*>(tensor.quantization.params);
  if (quant_params->scale->size > 1) {
    // Tensor is per-channel quantized.
    PerChannelDequantizationParams op_params;
    op_params.zero_point = quant_params->zero_point->data;
    op_params.scale = quant_params->scale->data;
    op_params.quantized_dimension = quant_params->quantized_dimension;
    reference_ops::PerChannelDequantize(op_params, GetTensorShape(&tensor),
                                        source_data, GetTensorShape(&tensor),
                                        dequantized_data);
  } else {
    DequantizationParams op_params;
    op_params.zero_point = tensor.params.zero_point;
    op_params.scale = tensor.params.scale;
    reference_ops::Dequantize(op_params, GetTensorShape(&tensor), source_data,
                              GetTensorShape(&tensor), dequantized_data);
  }
}

template <typename T>
absl::Status CreateVectorCopyData(const TfLiteTensor& tensor, T* tensor_data) {
  if (tensor.bytes % sizeof(T) != 0) {
    return absl::InvalidArgumentError(
        absl::StrCat("Input data size ", tensor.bytes,
                     " is not aligned to expected type: ", sizeof(T)));
  }
  std::memcpy(tensor_data, tensor.data.uint8, tensor.bytes);
  return absl::OkStatus();
}

template <>
absl::Status CreateVectorCopyData<float>(const TfLiteTensor& tensor,
                                         float* tensor_data);

absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Scalar* shape);

absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, Linear* shape);

absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HWC* shape);

absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, HW* shape);

absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, OHWI* shape);

absl::Status SetAllDimensions(const TfLiteIntArray* dimensions, BHWC* shape);

}  // namespace gpu
}  // namespace tflite

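To make the intent of these helpers concrete, the following hedged sketch reads a constant bias tensor the same way ObjectReader::ReadTensor does: copy or dequantize the payload with CreateVectorCopyData, then derive the shape with SetAllDimensions. The function name, the 'tflite_tensor' variable, and the Tensor<Linear, DataType::FLOAT32> alias are assumptions drawn from tensor.h, not part of this diff.

// Sketch: 'tflite_tensor' is assumed to point at a constant TfLiteTensor.
absl::Status ReadConstantBias(const TfLiteTensor* tflite_tensor,
                              Tensor<Linear, DataType::FLOAT32>* bias) {
  bias->data.resize(NumElements(tflite_tensor));
  // Copies float data directly; fp16/int8/uint8/int32 payloads are converted
  // or dequantized by the CreateVectorCopyData<float> specialization above.
  RETURN_IF_ERROR(CreateVectorCopyData(*tflite_tensor, &bias->data[0]));
  return SetAllDimensions(tflite_tensor->dims, &bias->shape);
}
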
tensorflow/lite/delegates/gpu/common/object_reader.cc (new file, 176 lines)
@@ -0,0 +1,176 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/object_reader.h"

#include <cstdint>
#include <unordered_map>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/utils.h"

namespace tflite {
namespace gpu {

absl::Status ObjectReader::ReadNonConstantTensor(
    TfLiteContext* context,
    std::unordered_map<int, Value<TensorRef<BHWC>>*>* tensor_to_value,
    std::unordered_map<int, int>* quant_conversion_map, GraphFloat32* graph,
    uint32_t tensor_idx, Value<TensorRef<BHWC>>** value) {
  if (tensor_idx >= context->tensors_size) {
    return absl::OutOfRangeError(
        absl::StrCat("ReadNonConstTensor: input tensor index: ", tensor_idx));
  }

  if (tensor_to_value->find(tensor_idx) == tensor_to_value->end()) {
    const TfLiteTensor& tflite_tensor = context->tensors[tensor_idx];
    if (tflite::IsConstantTensor(&tflite_tensor)) {
      return absl::InvalidArgumentError(absl::StrCat(
          "ReadNonConstantTensor: value is a constant tensor: ", tensor_idx));
    }

    if ((tflite_tensor.type == kTfLiteInt8 ||
         tflite_tensor.type == kTfLiteUInt8) &&
        quant_conversion_map) {
      // Quantized case
      if (quant_conversion_map->find(tensor_idx) ==
          quant_conversion_map->end()) {
        // Since the original tensor is fixed-point, add a new float tensor to
        // the TFLite graph to represent the dequantized data.
        int fp_tensor_index = 0;
        TfLiteTensor* fp_tflite_tensor;
        if (delegates::CreateNewTensorWithDifferentType(
                context, tensor_idx, kTfLiteFloat32, &fp_tflite_tensor,
                &fp_tensor_index) != kTfLiteOk) {
          return absl::InternalError("Could not add new tensor to graph");
        }
        // Remember this tensor for later.
        (*quant_conversion_map)[fp_tensor_index] = tensor_idx;
        (*quant_conversion_map)[tensor_idx] = fp_tensor_index;
        // Add a new GPU Value for the new dequantized floating-point tensor.
        Value<TensorRef<BHWC>>* value = graph->NewValue();
        RETURN_IF_ERROR(
            ConvertTfLiteTensorToTensorRef(*fp_tflite_tensor, &value->tensor));
        value->tensor.ref = fp_tensor_index;
        value->quant_params.emplace();
        RETURN_IF_ERROR(
            PopulateQuantParams(tflite_tensor, &value->quant_params.value()));
        (*tensor_to_value)[fp_tensor_index] = value;
      }
      // We do not use the original tensor index as reference for the GPU
      // Value, instead pointing at the corresponding float version.
      tensor_idx = quant_conversion_map->at(tensor_idx);
    } else {
      // Floating-point case.
      Value<TensorRef<BHWC>>* value = graph->NewValue();
      RETURN_IF_ERROR(
          ConvertTfLiteTensorToTensorRef(tflite_tensor, &value->tensor));
      value->tensor.ref = tensor_idx;
      (*tensor_to_value)[tensor_idx] = value;
    }
  }

  if (value) {
    *value = (*tensor_to_value)[tensor_idx];
  }
  return absl::OkStatus();
}

absl::Status ObjectReader::ReadValue(uint32_t idx,
                                     Value<TensorRef<BHWC>>** value) {
  if (idx >= node_->inputs->size) {
    return absl::OutOfRangeError(
        absl::StrCat("ReadValue: input tensor index: ", idx));
  }
  return ReadValueByTensorIdx(node_->inputs->data[idx], value);
|
||||
}
|
||||
|
||||
absl::Status ObjectReader::ReadValueByTensorIdx(
|
||||
uint32_t tensor_idx, Value<TensorRef<BHWC>>** value) {
|
||||
// Constant tensors should be handled by ReadTensor.
|
||||
return ReadNonConstantTensor(context_, tensor_to_value_,
|
||||
quant_conversion_map_, graph_, tensor_idx,
|
||||
value);
|
||||
}
|
||||
|
||||
int ObjectReader::GetNumberOfRuntimeInputs() const {
|
||||
return GetNumberOfRuntimeInputsForNode(context_, node_);
|
||||
}
|
||||
|
||||
absl::Status ObjectReader::GetTensorDims(uint32_t idx,
|
||||
TfLiteIntArray* dimensions) const {
|
||||
if (idx >= node_->inputs->size) {
|
||||
return absl::OutOfRangeError(absl::StrCat("Input tensor index: ", idx));
|
||||
}
|
||||
const int tensor_idx = node_->inputs->data[idx];
|
||||
if (tensor_idx < 0 || tensor_idx > context_->tensors_size) {
|
||||
return absl::OutOfRangeError(absl::StrCat("Tensor index: ", tensor_idx));
|
||||
}
|
||||
const TfLiteTensor& tflite_tensor = context_->tensors[tensor_idx];
|
||||
*dimensions = *tflite_tensor.dims;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status ObjectReader::AddOutput(const Node* node, int id) {
|
||||
if (node_->outputs->size <= id) {
|
||||
return absl::InvalidArgumentError(absl::StrCat(
|
||||
"Data id ", id, " must be less than tflite node outputs size ",
|
||||
node_->outputs->size));
|
||||
}
|
||||
int output_tensor_idx = node_->outputs->data[id];
|
||||
Value<TensorRef<BHWC>>* value;
|
||||
RETURN_IF_ERROR(ReadValueByTensorIdx(output_tensor_idx, &value));
|
||||
RETURN_IF_ERROR(graph_->SetProducer(node->id, value->id));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status ObjectReader::AddOutputs(const Node* node) {
|
||||
for (int i = 0; i < node_->outputs->size; ++i) {
|
||||
RETURN_IF_ERROR(AddOutput(node, i));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status ObjectReader::AddInput(const Node* node, uint32_t idx) {
|
||||
Value<TensorRef<BHWC>>* input;
|
||||
RETURN_IF_ERROR(ReadValue(idx, &input));
|
||||
return graph_->AddConsumer(node->id, input->id);
|
||||
}
|
||||
|
||||
TfLiteTensor* ObjectReader::GetInputTensor(int index) const {
|
||||
return index >= 0 && index < node_->inputs->size
|
||||
? context_->tensors + node_->inputs->data[index]
|
||||
: nullptr;
|
||||
}
|
||||
|
||||
TfLiteTensor* ObjectReader::GetOutputTensor(int index) const {
|
||||
return index >= 0 && index < node_->outputs->size
|
||||
? context_->tensors + node_->outputs->data[index]
|
||||
: nullptr;
|
||||
}
|
||||
|
||||
absl::Status ObjectReader::VerifyInputsConstsOutputs(const TfLiteNode* node,
|
||||
int runtime_inputs,
|
||||
int const_inputs,
|
||||
int outputs) {
|
||||
return CheckInputsConstsOutputs(context_, node, runtime_inputs, const_inputs,
|
||||
outputs);
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
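A rough sketch of the bookkeeping ReadNonConstantTensor performs for quantized inputs; the tensor indices below are made up for illustration and do not come from this change:

#include <unordered_map>

// Illustrative only: if ReadNonConstantTensor sees an int8/uint8 tensor at
// TFLite index 3 and a float32 companion tensor is created at index 17, the
// map records both directions so dequantize/requantize steps can be applied
// at invoke time, and the GPU Value's tensor.ref points at 17 rather than 3.
void QuantConversionMapSketch() {
  std::unordered_map<int, int> quant_conversion_map;
  const int original_idx = 3;   // hypothetical quantized tensor index
  const int fp_idx = 17;        // hypothetical float32 companion index
  quant_conversion_map[fp_idx] = original_idx;  // float -> original
  quant_conversion_map[original_idx] = fp_idx;  // original -> float
}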
99 tensorflow/lite/delegates/gpu/common/object_reader.h Normal file
@ -0,0 +1,99 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OBJECT_READER_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OBJECT_READER_H_

#include <cstdint>
#include <unordered_map>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace gpu {

// If quantized tensors exist in the graph & quant_conversion_map is non-null,
// the mapping between the original tensors (fixed-point) & GPU values (fp) is
// stored in quant_conversion_map.
class ObjectReader {
 public:
  static absl::Status ReadNonConstantTensor(
      TfLiteContext* context,
      std::unordered_map<int, Value<TensorRef<BHWC>>*>* tensor_to_value,
      std::unordered_map<int, int>* quant_conversion_map, GraphFloat32* graph,
      uint32_t tensor_idx, Value<TensorRef<BHWC>>** value = nullptr);

  ObjectReader(
      GraphFloat32* graph, TfLiteContext* context, const TfLiteNode* node,
      std::unordered_map<int, Value<TensorRef<BHWC>>*>* tensor_to_value,
      std::unordered_map<int, int>* quant_conversion_map = nullptr)
      : graph_(graph),
        context_(context),
        node_(node),
        tensor_to_value_(tensor_to_value),
        quant_conversion_map_(quant_conversion_map) {}

  absl::Status ReadValue(uint32_t idx, Value<TensorRef<BHWC>>** value);

  absl::Status ReadValueByTensorIdx(uint32_t tensor_idx,
                                    Value<TensorRef<BHWC>>** value);

  int GetNumberOfRuntimeInputs() const;

  absl::Status GetTensorDims(uint32_t idx, TfLiteIntArray* dimensions) const;

  template <typename TensorT>
  absl::Status ReadTensor(uint32_t idx, TensorT* t) const {
    const int32_t tensor_idx = node_->inputs->data[idx];
    const TfLiteTensor* tflite_tensor = context_->tensors + tensor_idx;
    t->data.resize(NumElements(tflite_tensor));
    RETURN_IF_ERROR(CreateVectorCopyData(*tflite_tensor, &t->data[0]));

    // Axis and data layout depend on operation this tensor is used in. So,
    // postpone resolutions until operations are parsed.
    t->id = tensor_idx;
    return SetAllDimensions(tflite_tensor->dims, &t->shape);
  }

  absl::Status AddOutput(const Node* node, int id);

  absl::Status AddOutputs(const Node* node);

  absl::Status AddInput(const Node* node, uint32_t idx);

  TfLiteTensor* GetInputTensor(int index) const;

  TfLiteTensor* GetOutputTensor(int index) const;

  absl::Status VerifyInputsConstsOutputs(const TfLiteNode* node,
                                         int runtime_inputs, int const_inputs,
                                         int outputs);

 private:
  GraphFloat32* graph_;
  TfLiteContext* context_;
  const TfLiteNode* node_;
  std::unordered_map<int, Value<TensorRef<BHWC>>*>* tensor_to_value_;
  std::unordered_map<int, int>* quant_conversion_map_;
};

}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_COMMON_OBJECT_READER_H_
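For context, a minimal sketch of how a per-operation parser is expected to drive this class when translating one TfLiteNode into a GPU graph node. The element-wise ADD shape (two runtime inputs, one output) and the surrounding variables (graph, context, tflite_node, tensor_to_value) are assumptions chosen for illustration, not code from this change:

// Assumed context: inside a parser with a live TfLiteContext* context,
// const TfLiteNode* tflite_node, a GraphFloat32 graph under construction,
// and the shared tensor-index -> Value map.
std::unordered_map<int, Value<TensorRef<BHWC>>*> tensor_to_value;
ObjectReader reader(&graph, context, tflite_node, &tensor_to_value);

// Check the node shape first: two runtime inputs, no constant inputs,
// one output (an element-wise ADD, chosen purely as an example).
RETURN_IF_ERROR(reader.VerifyInputsConstsOutputs(tflite_node,
                                                 /*runtime_inputs=*/2,
                                                 /*const_inputs=*/0,
                                                 /*outputs=*/1));

Node* node = graph.NewNode();
RETURN_IF_ERROR(reader.AddInput(node, 0));   // wire first TFLite input
RETURN_IF_ERROR(reader.AddInput(node, 1));   // wire second TFLite input
RETURN_IF_ERROR(reader.AddOutputs(node));    // wire all TFLite outputs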