Add a cache for const nodes in the Hexagon delegate.

On a few test models this cuts the const nodes roughly in half, which should reduce graph preparation time.
Fix a bug where casting was sometimes done incorrectly.
Remove some redundant const nodes.

PiperOrigin-RevId: 326046789
Change-Id: I462dd6702e0e02953c43ab47dd53589a653b3531
Authored by Karim Nosir on 2020-08-11 10:18:01 -07:00; committed by TensorFlower Gardener
parent aa9d2d80f4
commit 0d4f0584ef
11 changed files with 152 additions and 67 deletions
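
The core of the change is a lookup table keyed by a fingerprint of a const tensor's shape and data: before appending a new const node to the Hexagon graph, the builder checks whether identical data was already added and reuses that node instead. Below is a minimal sketch of the idea, assuming a stand-in FNV-1a hash in place of the farmhash Fingerprint64 the delegate actually links in, and hypothetical ConstNode/ConstNodeCache types in place of OpBuilder/GraphBuilder.

#include <cstdint>
#include <map>
#include <memory>
#include <vector>

// Hypothetical stand-in for the delegate's const-node OpBuilder.
struct ConstNode {
  int id = 0;
};

class ConstNodeCache {
 public:
  // Returns an existing node for identical (shape, data); creates one otherwise.
  ConstNode* GetOrCreate(const int shape[4], const char* data, int data_size) {
    uint64_t key = Hash(reinterpret_cast<const char*>(shape),
                        static_cast<int>(4 * sizeof(int)));
    // Combine the shape hash with the data hash (simple mixer, not farmhash).
    key ^= Hash(data, data_size) + 0x9e3779b97f4a7c15ULL + (key << 6) + (key >> 2);
    auto it = cache_.find(key);
    if (it != cache_.end()) return it->second;  // Duplicate const data: reuse node.
    nodes_.push_back(std::make_unique<ConstNode>());
    nodes_.back()->id = static_cast<int>(nodes_.size());
    cache_[key] = nodes_.back().get();
    return nodes_.back().get();
  }

 private:
  // FNV-1a 64-bit hash, used here only as a placeholder for Fingerprint64.
  static uint64_t Hash(const char* data, int len) {
    uint64_t h = 1469598103934665603ULL;
    for (int i = 0; i < len; ++i) {
      h ^= static_cast<unsigned char>(data[i]);
      h *= 1099511628211ULL;
    }
    return h;
  }

  std::map<uint64_t, ConstNode*> cache_;           // hash -> const node
  std::vector<std::unique_ptr<ConstNode>> nodes_;  // owns the nodes
};

With a cache like this, repeated weight or bias tensors map to a single Hexagon const node, which is what shrinks the set of const nodes and speeds up graph preparation.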


@@ -85,6 +85,7 @@ cc_library(
"//tensorflow/lite/kernels:padding",
"//tensorflow/lite/kernels/internal:optimized_base",
"//tensorflow/lite/kernels/internal:tensor",
"@farmhash_archive//:farmhash",
"@hexagon_nn//:hexagon_nn_ops",
],
)


@@ -267,13 +267,13 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
auto* conv_op = graph_builder_->AddNode(GetTFLiteNodeID());
conv_op->SetOpType(OP_DepthwiseSupernode_8x8p32to8);
conv_op->AddInput(space_to_batch_op_out);
conv_op->AddInput(TensorID(weights_data_node_->GetID(), 0));
conv_op->AddInput(graph_builder_->GetHexagonTensorId(inputs->data[1]));
conv_op->AddInput(TensorID(data_min_const->GetID(), 0));
conv_op->AddInput(TensorID(data_max_const->GetID(), 0));
conv_op->AddInput(TensorID(weights_min_node_->GetID(), 0));
conv_op->AddInput(TensorID(weights_max_node_->GetID(), 0));
conv_op->AddInput(TensorID(stride_node->GetID(), 0));
conv_op->AddInput(TensorID(bias_data_node_->GetID(), 0));
conv_op->AddInput(graph_builder_->GetHexagonTensorId(inputs->data[2]));
conv_op->AddInput(TensorID(bias_min_node_->GetID(), 0));
conv_op->AddInput(TensorID(bias_max_node_->GetID(), 0));
conv_op->AddInput(TensorID(conv_output_min_const->GetID(), 0));
@@ -330,13 +330,13 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
}
// Inputs
AddInput(graph_builder_->GetHexagonTensorId(inputs->data[0]));
AddInput(TensorID(weights_data_node_->GetID(), 0));
AddInput(graph_builder_->GetHexagonTensorId(inputs->data[1]));
AddInput(TensorID(data_min_const->GetID(), 0));
AddInput(TensorID(data_max_const->GetID(), 0));
AddInput(TensorID(weights_min_node_->GetID(), 0));
AddInput(TensorID(weights_max_node_->GetID(), 0));
AddInput(TensorID(stride_node->GetID(), 0));
AddInput(TensorID(bias_data_node_->GetID(), 0));
AddInput(graph_builder_->GetHexagonTensorId(inputs->data[2]));
AddInput(TensorID(bias_min_node_->GetID(), 0));
AddInput(TensorID(bias_max_node_->GetID(), 0));
AddInput(TensorID(conv_output_min_const->GetID(), 0));


@@ -62,10 +62,8 @@ class Conv2dOpBuilder : public OpBuilder {
std::vector<float> transposed_weights_;
std::vector<int> stride_shape_;
std::vector<int> weight_shape_;
OpBuilder* weights_data_node_ = nullptr;
OpBuilder* weights_min_node_ = nullptr;
OpBuilder* weights_max_node_ = nullptr;
OpBuilder* bias_data_node_ = nullptr;
OpBuilder* bias_min_node_ = nullptr;
OpBuilder* bias_max_node_ = nullptr;


@@ -106,6 +106,7 @@ TfLiteStatus Conv2dOpBuilder::InitializeWeightsNodes(
const bool is_per_channel_quant = weights_quant_params->scale->size > 1;
// WEIGHTS DATA.
OpBuilder* weights_data_node = nullptr;
if (op_node_.op_type == OP_Supernode_8x8p32to8) {
// Hexagon lib expects the weight tensor in HWCN, TFLite uses NHWC.
// Transpose NHWC -> HWCN
@@ -137,7 +138,7 @@ TfLiteStatus Conv2dOpBuilder::InitializeWeightsNodes(
weights_tensor.data.uint8, hwcn_shape,
hwcn.data());
}
weights_data_node_ = graph_builder_->AddConstNodeWithData(
weights_data_node = graph_builder_->AddConstNodeWithData(
weight_shape_.data(), reinterpret_cast<char*>(hwcn.data()),
hwcn.size() * sizeof(hwcn[0]));
} else if (op_node_.op_type == OP_DepthwiseSupernode_8x8p32to8) {
@@ -156,17 +157,17 @@ TfLiteStatus Conv2dOpBuilder::InitializeWeightsNodes(
for (int i = 0; i < converted_data.size(); ++i) {
converted_data[i] = weights_tensor.data.int8[i] ^ k8BitSignFlipConstant;
}
weights_data_node_ = graph_builder_->AddConstNodeWithData(
weights_data_node = graph_builder_->AddConstNodeWithData(
weight_shape_.data(), reinterpret_cast<char*>(converted_data.data()),
converted_data.size() * sizeof(converted_data[0]));
} else {
weights_data_node_ = graph_builder_->AddConstNodeWithData(
weights_data_node = graph_builder_->AddConstNodeWithData(
weight_shape_.data(), weights_tensor.data.raw,
NumElements(&weights_tensor) * sizeof(weights_tensor.data.uint8[0]));
}
}
graph_builder_->AddTensorWithID(inputs->data[1], weights_data_node_->GetID(),
0);
graph_builder_->AddTensorWithID(inputs->data[1], weights_data_node->GetID(),
0, /*overwrite=*/true);
// WEIGHTS QUANTIZATION.
float weights_min = 0;
@@ -229,9 +230,11 @@ TfLiteStatus Conv2dOpBuilder::ProcessPerChannelQuantizedBias(
}
// Add nodes for bias.
const std::vector<int> bias_shape = {1, 1, 1, bias_size};
bias_data_node_ = graph_builder_->AddConstNodeWithData(
auto* bias_data_node = graph_builder_->AddConstNodeWithData(
bias_shape.data(), reinterpret_cast<char*>(preprocessed_bias_data.data()),
preprocessed_bias_data.size() * sizeof(preprocessed_bias_data[0]));
graph_builder_->AddTensorWithID(inputs->data[2], bias_data_node->GetID(), 0,
/*overwrite=*/true);
return kTfLiteOk;
}
@@ -248,8 +251,10 @@ TfLiteStatus Conv2dOpBuilder::InitializeBiasNodes(const TfLiteIntArray* inputs,
ProcessPerChannelQuantizedBias(inputs, outputs, context, &bias_min,
&bias_max);
} else {
bias_data_node_ =
auto* bias_data_node =
graph_builder_->AddConstNodeWithData(inputs->data[2], bias_tensor);
graph_builder_->AddTensorWithID(inputs->data[2], bias_data_node->GetID(), 0,
/*overwrite=*/true);
TF_LITE_ENSURE_STATUS(
ComputeMinAndMaxQuantValues(bias_tensor, &bias_min, &bias_max));
}


@@ -27,10 +27,6 @@ TfLiteStatus MinMaxOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
int b_tensor_id = inputs->data[1];
const auto& a_tensor = context->tensors[a_tensor_id];
const auto& b_tensor = context->tensors[b_tensor_id];
if (a_tensor.allocation_type == kTfLiteMmapRo)
graph_builder_->AddConstNodeWithData(a_tensor_id, a_tensor);
if (b_tensor.allocation_type == kTfLiteMmapRo)
graph_builder_->AddConstNodeWithData(b_tensor_id, b_tensor);
AddInput(graph_builder_->GetHexagonTensorId(a_tensor_id));
AddInput(graph_builder_->GetHexagonTensorId(b_tensor_id));


@@ -18,10 +18,59 @@ limitations under the License.
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/hexagon/builders/op_factory.h"
#include <farmhash.h>
namespace tflite {
namespace delegates {
namespace hexagon {
namespace {
// Farmhash Fingerprint
inline uint64_t CombineFingerprints(uint64_t l, uint64_t h) {
// Murmur-inspired hashing.
const uint64_t kMul = 0x9ddfea08eb382d69ULL;
uint64_t a = (l ^ h) * kMul;
a ^= (a >> 47);
uint64_t b = (h ^ a) * kMul;
b ^= (b >> 44);
b *= kMul;
b ^= (b >> 41);
b *= kMul;
return b;
}
inline uint64_t ComputeHash(const int shape[], const char* data,
const int data_len) {
return CombineFingerprints(
::util::Fingerprint64(data, data_len),
::util::Fingerprint64(reinterpret_cast<const char*>(shape),
sizeof(shape[0]) * 4));
}
inline uint64_t ComputeHash(const TfLiteTensor& tensor, const int shape[],
int int8_to_uint8) {
auto data_hash = ComputeHash(shape, tensor.data.raw_const, tensor.bytes);
auto int8_to_uint8_hash = ::util::Fingerprint64(
reinterpret_cast<char*>(&int8_to_uint8), sizeof(int8_to_uint8));
return CombineFingerprints(data_hash, int8_to_uint8_hash);
}
int GetElementSize(TfLiteType type) {
switch (type) {
case kTfLiteFloat32:
return sizeof(float);
case kTfLiteBool:
return sizeof(bool);
case kTfLiteInt32:
return sizeof(int32_t);
case kTfLiteInt8:
return sizeof(int8_t);
case kTfLiteUInt8:
return sizeof(uint8_t);
default:
return sizeof(int8_t);
}
}
} // namespace
OpBuilder* GraphBuilder::CreateOpBuilderFromTfLiteOp(int op_type,
TfLiteNode* node) {
@@ -116,8 +165,20 @@ OpBuilder* GraphBuilder::CreateOpBuilderFromTfLiteOp(int op_type,
}
}
OpBuilder* GraphBuilder::LookupConstData(uint64_t cache_key) {
auto lookup_result = cache_.find(cache_key);
if (lookup_result != cache_.end()) return lookup_result->second;
return nullptr;
}
void GraphBuilder::AddToCache(uint64_t cache_key, OpBuilder* value) {
cache_[cache_key] = value;
}
OpBuilder* GraphBuilder::AddConstNodeWithData(const int shape[], char* data,
int data_size) {
auto cache_key = ComputeHash(shape, data, data_size);
if (auto lookup_result = LookupConstData(cache_key)) return lookup_result;
builders_.emplace_back(new OpBuilder(this, OP_Const));
builders_.back()->SetConstNode();
builders_.back()->SetNodeId(builders_.size());
@@ -125,22 +186,36 @@ OpBuilder* GraphBuilder::AddConstNodeWithData(const int shape[], char* data,
graph_id_, builders_.size(), shape[0], shape[1], shape[2], shape[3],
reinterpret_cast<const uint8_t*>(data), data_size);
if (error != 0) {
context_->ReportError(context_, "Error adding const node with shape id: %d",
(int)builders_.size());
TF_LITE_KERNEL_LOG(context_, "Error adding const node with shape id: %d",
static_cast<int>(builders_.size()));
return nullptr;
}
AddToCache(cache_key, builders_.back().get());
return builders_.back().get();
}
OpBuilder* GraphBuilder::AddConstNodeWithData(int tensor_id,
const TfLiteTensor& tensor,
bool int8_to_uint8) {
// Fetch shape of tensor and pad 1's so it is always 4D.
int batch_size, height_size, width_size, depth_size;
GetDims(&batch_size, &height_size, &width_size, &depth_size, tensor.dims);
const int shape[] = {batch_size, height_size, width_size, depth_size};
auto cache_key = ComputeHash(tensor, shape, int8_to_uint8 ? 1 : 0);
if (auto lookup_result = LookupConstData(cache_key)) {
// The tensor can be cached without an id when the same data was added
// from a constant value (not a tensor); in that case we reuse the cached
// data and assign this tensor to the cached const node before returning.
if (!HasTensor(tensor_id))
AddTensorWithID(tensor_id, lookup_result->GetID(), 0);
return lookup_result;
}
builders_.emplace_back(new OpBuilder(this, OP_Const));
const int node_id = builders_.size();
builders_.back()->SetConstNode();
builders_.back()->SetNodeId(node_id);
int batch_size, height_size, width_size, depth_size;
GetDims(&batch_size, &height_size, &width_size, &depth_size, tensor.dims);
int error = hexagon_nn_->hexagon_nn_append_const_node(
graph_id_, node_id, batch_size, height_size, width_size, depth_size,
reinterpret_cast<const uint8_t*>(tensor.data.raw), tensor.bytes);
@@ -150,19 +225,26 @@ OpBuilder* GraphBuilder::AddConstNodeWithData(int tensor_id,
return nullptr;
}
AddTensorWithID(tensor_id, node_id, 0);
// We need to return the builder that produces the result. builders_.back()
// can change when a cast op is appended, so we hold a pointer here and
// update it with the cast builder if needed.
OpBuilder* result_builder = builders_.back().get();
// Cast int8 to uint8 if requested.
// This will add a cast op to uint8 and update the tensor map to point
// to the cast tensor.
if (int8_to_uint8 && tensor.type == kTfLiteInt8) {
AddCastOp(context_, OP_Quantized_CastInt8ToUInt8, tensor_id);
AddCastOp(context_, OP_Quantized_CastInt8ToUInt8, tensor_id,
&result_builder);
}
return builders_.back().get();
AddToCache(cache_key, result_builder);
return result_builder;
}
// TODO(b/154604279): Support these casting ops in Hexagon op profiling (which
// seems to key tensors on a single op, which may not be the case now).
TfLiteStatus GraphBuilder::AddCastOp(TfLiteContext* context, int op_type,
int tensor_id) {
int tensor_id,
OpBuilder** cast_op_builder) {
// Create a new OpBuilder for casting the tensor.
OpBuilder* cast_builder = CreateCastBuilder(this, op_type);
builders_.emplace_back(cast_builder);
@@ -177,6 +259,7 @@ TfLiteStatus GraphBuilder::AddCastOp(TfLiteContext* context, int op_type,
TF_LITE_ENSURE_STATUS(cast_builder->RegisterOutputs(tensor_data, context));
TfLiteIntArrayFree(tensor_data);
if (cast_op_builder != nullptr) *cast_op_builder = cast_builder;
return kTfLiteOk;
}
@@ -192,12 +275,12 @@ TfLiteStatus GraphBuilder::AddInputTensors(const TfLiteIntArray* input_tensors,
const int tensor_id = input_tensors->data[i];
const auto& tensor = context->tensors[tensor_id];
if (tensor.allocation_type == kTfLiteMmapRo) continue;
input_op->AddOutput(tensor.dims);
input_op->AddOutput(tensor.dims, GetElementSize(tensor.type));
AddTensorWithID(tensor_id, input_op->GetID(), num_inputs);
// If tensor is of type int8, add an op to cast it to uint8.
if (tensor.type == kTfLiteInt8) {
TF_LITE_ENSURE_STATUS(
AddCastOp(context, OP_Quantized_CastInt8ToUInt8, tensor_id));
TF_LITE_ENSURE_STATUS(AddCastOp(context, OP_Quantized_CastInt8ToUInt8,
tensor_id, /*cast_op_builder=*/nullptr));
}
++num_inputs;
}
@@ -215,8 +298,8 @@ TfLiteStatus GraphBuilder::AddOutputTensors(
const auto& tensor = context->tensors[tensor_id];
// If tensor is of type int8, add an op to cast it to uint8.
if (tensor.type == kTfLiteInt8) {
TF_LITE_ENSURE_STATUS(
AddCastOp(context, OP_Quantized_CastUInt8ToInt8, tensor_id));
TF_LITE_ENSURE_STATUS(AddCastOp(context, OP_Quantized_CastUInt8ToInt8,
tensor_id, /*cast_op_builder=*/nullptr));
}
hexagon_output_ids.push_back(GetHexagonTensorId(tensor_id));
}
@@ -231,9 +314,10 @@ TfLiteStatus GraphBuilder::AddOutputTensors(
return kTfLiteOk;
}
OpBuilder::TensorID OpBuilder::AddOutput(const TfLiteIntArray* dims) {
OpBuilder::TensorID OpBuilder::AddOutput(const TfLiteIntArray* dims,
int element_size) {
op_node_.outputs.push_back(hexagon_nn_output());
op_node_.outputs.back().elementsize = sizeof(uint8_t);
op_node_.outputs.back().elementsize = element_size;
op_node_.outputs.back().rank = 4;
// TODO(karimnosseir): What is a good way to estimate the max size?
int batch_size, height_size, width_size, depth_size;


@@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_LITE_DELEGATES_HEXAGON_BUILDERS_OP_BUILDER_H_
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <utility>
@@ -131,9 +132,9 @@ class OpBuilder {
void AddInput(const TensorID& tensor_id) { input_ids_.push_back(tensor_id); }
// Adds Output to the current node, the output has shape defined in 'dims'.
// This assumes the data type is uint8.
// The size of each element is defined using 'element_size'.
// Returns the TensorID identifying this output in the graph.
TensorID AddOutput(const TfLiteIntArray* dims);
TensorID AddOutput(const TfLiteIntArray* dims, int element_size);
// Adds Output to the current node, each element in the output has
// size 'elementsize' and rank 'rank' and for each dimension in the output
@@ -316,11 +317,22 @@ class GraphBuilder {
bool AddTensorWithID(int tflite_tensor_id, int hexagon_node_id,
int hexagon_node_output_id, bool overwrite = false) {
if (!overwrite && HasTensor(tflite_tensor_id)) {
TF_LITE_KERNEL_LOG(
context_,
"Trying to add duplicate tensor without overwrite, tflite_tensor_id "
"%d, hexagon_node_id %d, hexagon_node_output_id %d",
tflite_tensor_id, hexagon_node_id, hexagon_node_output_id);
return false;
}
if (tensors_.size() <= tflite_tensor_id) {
tensors_.resize(tflite_tensor_id + 1);
}
if (hexagon_node_id == -1 || hexagon_node_output_id == -1)
TF_LITE_KERNEL_LOG(context_,
"Trying to add invalid id, tflite_tensor_id "
"%d, hexagon_node_id %d, hexagon_node_output_id %d",
tflite_tensor_id, hexagon_node_id,
hexagon_node_output_id);
tensors_[tflite_tensor_id] =
OpBuilder::TensorID(hexagon_node_id, hexagon_node_output_id);
return true;
@@ -348,6 +360,14 @@ class GraphBuilder {
int GetMaxBatchSize() const { return max_size_for_batch_; }
private:
// Looks up the cache for data with key 'cache_key'.
// Returns the OpBuilder* for the data if found, nullptr otherwise.
OpBuilder* LookupConstData(uint64_t cache_key);
// Inserts 'value' into the cache with key 'cache_key'.
// If an entry with the same key already exists, it is overwritten.
void AddToCache(uint64_t cache_key, OpBuilder* value);
// Helper method to fetch dimensions.
// TODO(karimnosseir): Move this method to a shared place.
void GetDims(int* batch_size, int* height_size, int* width_size,
@@ -360,7 +380,10 @@ class GraphBuilder {
}
// Adds a Cast op to convert a tensor from int8 to uint8 (or vice versa).
TfLiteStatus AddCastOp(TfLiteContext* context, int op_type, int tensor_id);
// If 'cast_op_builder' is not nullptr, it is filled with the builder
// that holds the casting operator.
TfLiteStatus AddCastOp(TfLiteContext* context, int op_type, int tensor_id,
OpBuilder** cast_op_builder);
const HexagonNN* hexagon_nn_ = nullptr;
TfLiteContext* context_ = nullptr;
@@ -373,6 +396,11 @@ class GraphBuilder {
// If the graph being built supports dynamic batch, this represents
// the maximum value for batch.
int max_size_for_batch_ = -1;
// Cache for const data in the graph.
// Key is a hash of the data; value is the OpBuilder* for the added data.
std::map<uint64_t, OpBuilder*> cache_;
};
} // namespace hexagon


@@ -29,15 +29,7 @@ TfLiteStatus TransposeOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
const auto& input_tensor = context->tensors[tensor_id];
AddInput(graph_builder_->GetHexagonTensorId(tensor_id));
// permutation tensor.
tensor_id = inputs->data[1];
const auto& control_tensor = context->tensors[tensor_id];
if (control_tensor.allocation_type == kTfLiteMmapRo) {
auto* const_control_tensor_node =
graph_builder_->AddConstNodeWithData(tensor_id, control_tensor);
AddInput(TensorID(const_control_tensor_node->GetID(), 0));
} else {
AddInput(graph_builder_->GetHexagonTensorId(tensor_id));
}
AddInput(graph_builder_->GetHexagonTensorId(inputs->data[1]));
TF_LITE_ENSURE_STATUS(ComputeAndAddMinAndMax(context, input_tensor));


@@ -97,8 +97,6 @@ TfLiteStatus TransposeConv2dOpBuilder::PopulateSubGraph(
filter_depth_size;
GetDims(&filter_batch_size, &filter_height_size, &filter_width_size,
&filter_depth_size, weights_tensor.dims);
weight_shape_ = {filter_batch_size, filter_height_size, filter_width_size,
filter_depth_size};
// Weights tensor could be int8 even for per-tensor quantization.
// Therefore, we look at the number of scale values to check if it is
// per-channel quantized.
@@ -106,25 +104,7 @@ TfLiteStatus TransposeConv2dOpBuilder::PopulateSubGraph(
reinterpret_cast<TfLiteAffineQuantization*>(
weights_tensor.quantization.params);
const bool is_per_channel_quant = weights_quant_params->scale->size > 1;
OpBuilder* const_weights_node;
if (weights_tensor.type == kTfLiteInt8) {
std::vector<uint8_t> weights_data(NumElements(&weights_tensor));
const int8_t* original_data = weights_tensor.data.int8;
// Flip bits on the weight values so that the int8 values are treated
// as uint8.
for (int i = 0; i < NumElements(&weights_tensor); ++i) {
weights_data[i] = original_data[i] ^ k8BitSignFlipConstant;
}
const_weights_node = graph_builder_->AddConstNodeWithData(
weight_shape_.data(), reinterpret_cast<char*>(weights_data.data()),
weights_data.size() * sizeof(weights_data[0]));
} else {
const_weights_node = graph_builder_->AddConstNodeWithData(
weight_shape_.data(), weights_tensor.data.raw, weights_tensor.bytes);
}
graph_builder_->AddTensorWithID(tensor_id, const_weights_node->GetID(), 0);
AddInput(TensorID(const_weights_node->GetID(), 0));
AddInput(graph_builder_->GetHexagonTensorId(tensor_id));
// Handle weights quantization.
float weights_min = 0;


@@ -47,7 +47,7 @@ class TransposeConv2dOpBuilder : public OpBuilder {
TensorID node_output_;
std::vector<float> transposed_weights_;
std::vector<int> stride_shape_;
std::vector<int> weight_shape_, bias_shape_;
std::vector<int> bias_shape_;
std::vector<int> bias_data_;
// Non-null only if node has per-channel quantized weights/biases.


@@ -264,8 +264,9 @@ TfLiteStatus HexagonDelegateKernel::BuildGraph(
if (tensor_id == -1) continue;
const auto& input_tensor = context->tensors[tensor_id];
if (input_tensor.allocation_type == kTfLiteMmapRo) {
builder_->AddConstNodeWithData(tensor_id, input_tensor,
/*int8_to_uint8*/ true);
builder_->AddConstNodeWithData(
tensor_id, input_tensor,
/*int8_to_uint8*/ (input_tensor.type == kTfLiteInt8));
}
}
auto* op_builder =