diff --git a/tensorflow/lite/delegates/hexagon/builders/BUILD b/tensorflow/lite/delegates/hexagon/builders/BUILD index 63ff274c7b7..ef4b0e957c1 100644 --- a/tensorflow/lite/delegates/hexagon/builders/BUILD +++ b/tensorflow/lite/delegates/hexagon/builders/BUILD @@ -85,6 +85,7 @@ cc_library( "//tensorflow/lite/kernels:padding", "//tensorflow/lite/kernels/internal:optimized_base", "//tensorflow/lite/kernels/internal:tensor", + "@farmhash_archive//:farmhash", "@hexagon_nn//:hexagon_nn_ops", ], ) diff --git a/tensorflow/lite/delegates/hexagon/builders/conv_2d_builder.cc b/tensorflow/lite/delegates/hexagon/builders/conv_2d_builder.cc index cfddd2c2b97..c6d20004227 100644 --- a/tensorflow/lite/delegates/hexagon/builders/conv_2d_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/conv_2d_builder.cc @@ -267,13 +267,13 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, auto* conv_op = graph_builder_->AddNode(GetTFLiteNodeID()); conv_op->SetOpType(OP_DepthwiseSupernode_8x8p32to8); conv_op->AddInput(space_to_batch_op_out); - conv_op->AddInput(TensorID(weights_data_node_->GetID(), 0)); + conv_op->AddInput(graph_builder_->GetHexagonTensorId(inputs->data[1])); conv_op->AddInput(TensorID(data_min_const->GetID(), 0)); conv_op->AddInput(TensorID(data_max_const->GetID(), 0)); conv_op->AddInput(TensorID(weights_min_node_->GetID(), 0)); conv_op->AddInput(TensorID(weights_max_node_->GetID(), 0)); conv_op->AddInput(TensorID(stride_node->GetID(), 0)); - conv_op->AddInput(TensorID(bias_data_node_->GetID(), 0)); + conv_op->AddInput(graph_builder_->GetHexagonTensorId(inputs->data[2])); conv_op->AddInput(TensorID(bias_min_node_->GetID(), 0)); conv_op->AddInput(TensorID(bias_max_node_->GetID(), 0)); conv_op->AddInput(TensorID(conv_output_min_const->GetID(), 0)); @@ -330,13 +330,13 @@ TfLiteStatus Conv2dOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, } // Inputs AddInput(graph_builder_->GetHexagonTensorId(inputs->data[0])); - AddInput(TensorID(weights_data_node_->GetID(), 0)); + AddInput(graph_builder_->GetHexagonTensorId(inputs->data[1])); AddInput(TensorID(data_min_const->GetID(), 0)); AddInput(TensorID(data_max_const->GetID(), 0)); AddInput(TensorID(weights_min_node_->GetID(), 0)); AddInput(TensorID(weights_max_node_->GetID(), 0)); AddInput(TensorID(stride_node->GetID(), 0)); - AddInput(TensorID(bias_data_node_->GetID(), 0)); + AddInput(graph_builder_->GetHexagonTensorId(inputs->data[2])); AddInput(TensorID(bias_min_node_->GetID(), 0)); AddInput(TensorID(bias_max_node_->GetID(), 0)); AddInput(TensorID(conv_output_min_const->GetID(), 0)); diff --git a/tensorflow/lite/delegates/hexagon/builders/conv_2d_builder.h b/tensorflow/lite/delegates/hexagon/builders/conv_2d_builder.h index 4980b294481..1407f06154b 100644 --- a/tensorflow/lite/delegates/hexagon/builders/conv_2d_builder.h +++ b/tensorflow/lite/delegates/hexagon/builders/conv_2d_builder.h @@ -62,10 +62,8 @@ class Conv2dOpBuilder : public OpBuilder { std::vector transposed_weights_; std::vector stride_shape_; std::vector weight_shape_; - OpBuilder* weights_data_node_ = nullptr; OpBuilder* weights_min_node_ = nullptr; OpBuilder* weights_max_node_ = nullptr; - OpBuilder* bias_data_node_ = nullptr; OpBuilder* bias_min_node_ = nullptr; OpBuilder* bias_max_node_ = nullptr; diff --git a/tensorflow/lite/delegates/hexagon/builders/conv_2d_helpers.cc b/tensorflow/lite/delegates/hexagon/builders/conv_2d_helpers.cc index bf68bbe5a25..b33e28f4e71 100644 --- a/tensorflow/lite/delegates/hexagon/builders/conv_2d_helpers.cc +++ b/tensorflow/lite/delegates/hexagon/builders/conv_2d_helpers.cc @@ -106,6 +106,7 @@ TfLiteStatus Conv2dOpBuilder::InitializeWeightsNodes( const bool is_per_channel_quant = weights_quant_params->scale->size > 1; // WEIGHTS DATA. + OpBuilder* weights_data_node = nullptr; if (op_node_.op_type == OP_Supernode_8x8p32to8) { // Hexagon lib expects the weight tensor in HWCN, TFLite uses NHWC. // Transpose NHWC -> HWCN @@ -137,7 +138,7 @@ TfLiteStatus Conv2dOpBuilder::InitializeWeightsNodes( weights_tensor.data.uint8, hwcn_shape, hwcn.data()); } - weights_data_node_ = graph_builder_->AddConstNodeWithData( + weights_data_node = graph_builder_->AddConstNodeWithData( weight_shape_.data(), reinterpret_cast(hwcn.data()), hwcn.size() * sizeof(hwcn[0])); } else if (op_node_.op_type == OP_DepthwiseSupernode_8x8p32to8) { @@ -156,17 +157,17 @@ TfLiteStatus Conv2dOpBuilder::InitializeWeightsNodes( for (int i = 0; i < converted_data.size(); ++i) { converted_data[i] = weights_tensor.data.int8[i] ^ k8BitSignFlipConstant; } - weights_data_node_ = graph_builder_->AddConstNodeWithData( + weights_data_node = graph_builder_->AddConstNodeWithData( weight_shape_.data(), reinterpret_cast(converted_data.data()), converted_data.size() * sizeof(converted_data[0])); } else { - weights_data_node_ = graph_builder_->AddConstNodeWithData( + weights_data_node = graph_builder_->AddConstNodeWithData( weight_shape_.data(), weights_tensor.data.raw, NumElements(&weights_tensor) * sizeof(weights_tensor.data.uint8[0])); } } - graph_builder_->AddTensorWithID(inputs->data[1], weights_data_node_->GetID(), - 0); + graph_builder_->AddTensorWithID(inputs->data[1], weights_data_node->GetID(), + 0, /*overwrite=*/true); // WEIGHTS QUANTIZATION. float weights_min = 0; @@ -229,9 +230,11 @@ TfLiteStatus Conv2dOpBuilder::ProcessPerChannelQuantizedBias( } // Add nodes for bias. const std::vector bias_shape = {1, 1, 1, bias_size}; - bias_data_node_ = graph_builder_->AddConstNodeWithData( + auto* bias_data_node = graph_builder_->AddConstNodeWithData( bias_shape.data(), reinterpret_cast(preprocessed_bias_data.data()), preprocessed_bias_data.size() * sizeof(preprocessed_bias_data[0])); + graph_builder_->AddTensorWithID(inputs->data[2], bias_data_node->GetID(), 0, + /*overwrite=*/true); return kTfLiteOk; } @@ -248,8 +251,10 @@ TfLiteStatus Conv2dOpBuilder::InitializeBiasNodes(const TfLiteIntArray* inputs, ProcessPerChannelQuantizedBias(inputs, outputs, context, &bias_min, &bias_max); } else { - bias_data_node_ = + auto* bias_data_node = graph_builder_->AddConstNodeWithData(inputs->data[2], bias_tensor); + graph_builder_->AddTensorWithID(inputs->data[2], bias_data_node->GetID(), 0, + /*overwrite=*/true); TF_LITE_ENSURE_STATUS( ComputeMinAndMaxQuantValues(bias_tensor, &bias_min, &bias_max)); } diff --git a/tensorflow/lite/delegates/hexagon/builders/min_max_builder.cc b/tensorflow/lite/delegates/hexagon/builders/min_max_builder.cc index bcfae6032c8..0c6dea2096d 100644 --- a/tensorflow/lite/delegates/hexagon/builders/min_max_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/min_max_builder.cc @@ -27,10 +27,6 @@ TfLiteStatus MinMaxOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, int b_tensor_id = inputs->data[1]; const auto& a_tensor = context->tensors[a_tensor_id]; const auto& b_tensor = context->tensors[b_tensor_id]; - if (a_tensor.allocation_type == kTfLiteMmapRo) - graph_builder_->AddConstNodeWithData(a_tensor_id, a_tensor); - if (b_tensor.allocation_type == kTfLiteMmapRo) - graph_builder_->AddConstNodeWithData(b_tensor_id, b_tensor); AddInput(graph_builder_->GetHexagonTensorId(a_tensor_id)); AddInput(graph_builder_->GetHexagonTensorId(b_tensor_id)); diff --git a/tensorflow/lite/delegates/hexagon/builders/op_builder.cc b/tensorflow/lite/delegates/hexagon/builders/op_builder.cc index 0f32a4de6e1..80aa4c8155c 100644 --- a/tensorflow/lite/delegates/hexagon/builders/op_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/op_builder.cc @@ -18,10 +18,59 @@ limitations under the License. #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/delegates/hexagon/builders/op_factory.h" +#include namespace tflite { namespace delegates { namespace hexagon { +namespace { +// Farmhash Fingerprint +inline uint64_t CombineFingerprints(uint64_t l, uint64_t h) { + // Murmur-inspired hashing. + const uint64_t kMul = 0x9ddfea08eb382d69ULL; + uint64_t a = (l ^ h) * kMul; + a ^= (a >> 47); + uint64_t b = (h ^ a) * kMul; + b ^= (b >> 44); + b *= kMul; + b ^= (b >> 41); + b *= kMul; + return b; +} + +inline uint64_t ComputeHash(const int shape[], const char* data, + const int data_len) { + return CombineFingerprints( + ::util::Fingerprint64(data, data_len), + ::util::Fingerprint64(reinterpret_cast(shape), + sizeof(shape[0]) * 4)); +} + +inline uint64_t ComputeHash(const TfLiteTensor& tensor, const int shape[], + int int8_to_uint8) { + auto data_hash = ComputeHash(shape, tensor.data.raw_const, tensor.bytes); + auto int8_to_uint8_hash = ::util::Fingerprint64( + reinterpret_cast(&int8_to_uint8), sizeof(int8_to_uint8)); + return CombineFingerprints(data_hash, int8_to_uint8_hash); +} + +int GetElementSize(TfLiteType type) { + switch (type) { + case kTfLiteFloat32: + return sizeof(float); + case kTfLiteBool: + return sizeof(bool); + case kTfLiteInt32: + return sizeof(int32_t); + case kTfLiteInt8: + return sizeof(int8_t); + case kTfLiteUInt8: + return sizeof(uint8_t); + default: + return sizeof(int8_t); + } +} +} // namespace OpBuilder* GraphBuilder::CreateOpBuilderFromTfLiteOp(int op_type, TfLiteNode* node) { @@ -116,8 +165,20 @@ OpBuilder* GraphBuilder::CreateOpBuilderFromTfLiteOp(int op_type, } } +OpBuilder* GraphBuilder::LookupConstData(uint64_t cache_key) { + auto lookup_result = cache_.find(cache_key); + if (lookup_result != cache_.end()) return lookup_result->second; + return nullptr; +} + +void GraphBuilder::AddToCache(uint64_t cache_key, OpBuilder* value) { + cache_[cache_key] = value; +} + OpBuilder* GraphBuilder::AddConstNodeWithData(const int shape[], char* data, int data_size) { + auto cache_key = ComputeHash(shape, data, data_size); + if (auto lookup_result = LookupConstData(cache_key)) return lookup_result; builders_.emplace_back(new OpBuilder(this, OP_Const)); builders_.back()->SetConstNode(); builders_.back()->SetNodeId(builders_.size()); @@ -125,22 +186,36 @@ OpBuilder* GraphBuilder::AddConstNodeWithData(const int shape[], char* data, graph_id_, builders_.size(), shape[0], shape[1], shape[2], shape[3], reinterpret_cast(data), data_size); if (error != 0) { - context_->ReportError(context_, "Error adding const node with shape id: %d", - (int)builders_.size()); + TF_LITE_KERNEL_LOG(context_, "Error adding const node with shape id: %d", + static_cast(builders_.size())); return nullptr; } + AddToCache(cache_key, builders_.back().get()); return builders_.back().get(); } OpBuilder* GraphBuilder::AddConstNodeWithData(int tensor_id, const TfLiteTensor& tensor, bool int8_to_uint8) { + // Fetch shape of tensor and pad 1's so it is always 4D. + int batch_size, height_size, width_size, depth_size; + GetDims(&batch_size, &height_size, &width_size, &depth_size, tensor.dims); + const int shape[] = {batch_size, height_size, width_size, depth_size}; + + auto cache_key = ComputeHash(tensor, shape, int8_to_uint8 ? 1 : 0); + if (auto lookup_result = LookupConstData(cache_key)) { + // If tensor is cached but with no id, that can happen when the same + // data is added from a constant value (not tensor). We can cache the data + // and reuse it. + // We assign the tensor to this cached const node before returning. + if (!HasTensor(tensor_id)) + AddTensorWithID(tensor_id, lookup_result->GetID(), 0); + return lookup_result; + } builders_.emplace_back(new OpBuilder(this, OP_Const)); const int node_id = builders_.size(); builders_.back()->SetConstNode(); builders_.back()->SetNodeId(node_id); - int batch_size, height_size, width_size, depth_size; - GetDims(&batch_size, &height_size, &width_size, &depth_size, tensor.dims); int error = hexagon_nn_->hexagon_nn_append_const_node( graph_id_, node_id, batch_size, height_size, width_size, depth_size, reinterpret_cast(tensor.data.raw), tensor.bytes); @@ -150,19 +225,26 @@ OpBuilder* GraphBuilder::AddConstNodeWithData(int tensor_id, return nullptr; } AddTensorWithID(tensor_id, node_id, 0); + // We need to return the builder with result, so we can't rely + // on builders_.back() as it can change while casting, so we hold pointer + // and update with value from casting if needed. + OpBuilder* result_builder = builders_.back().get(); // Cast int8 to uint8 if requested. // This will add cast op to uint8 and update tensor map to point // to the casted tensor. if (int8_to_uint8 && tensor.type == kTfLiteInt8) { - AddCastOp(context_, OP_Quantized_CastInt8ToUInt8, tensor_id); + AddCastOp(context_, OP_Quantized_CastInt8ToUInt8, tensor_id, + &result_builder); } - return builders_.back().get(); + AddToCache(cache_key, result_builder); + return result_builder; } // TODO(b/154604279): Support these casting ops in Hexagon op profiling (which // seems to key tensors on a single op, which may not be the case now). TfLiteStatus GraphBuilder::AddCastOp(TfLiteContext* context, int op_type, - int tensor_id) { + int tensor_id, + OpBuilder** cast_op_builder) { // Create a new OpBuilder for casting the tensor. OpBuilder* cast_builder = CreateCastBuilder(this, op_type); builders_.emplace_back(cast_builder); @@ -177,6 +259,7 @@ TfLiteStatus GraphBuilder::AddCastOp(TfLiteContext* context, int op_type, TF_LITE_ENSURE_STATUS(cast_builder->RegisterOutputs(tensor_data, context)); TfLiteIntArrayFree(tensor_data); + if (cast_op_builder != nullptr) *cast_op_builder = cast_builder; return kTfLiteOk; } @@ -192,12 +275,12 @@ TfLiteStatus GraphBuilder::AddInputTensors(const TfLiteIntArray* input_tensors, const int tensor_id = input_tensors->data[i]; const auto& tensor = context->tensors[tensor_id]; if (tensor.allocation_type == kTfLiteMmapRo) continue; - input_op->AddOutput(tensor.dims); + input_op->AddOutput(tensor.dims, GetElementSize(tensor.type)); AddTensorWithID(tensor_id, input_op->GetID(), num_inputs); // If tensor is of type int8, add an op to cast it to uint8. if (tensor.type == kTfLiteInt8) { - TF_LITE_ENSURE_STATUS( - AddCastOp(context, OP_Quantized_CastInt8ToUInt8, tensor_id)); + TF_LITE_ENSURE_STATUS(AddCastOp(context, OP_Quantized_CastInt8ToUInt8, + tensor_id, /*cast_op_builder=*/nullptr)); } ++num_inputs; } @@ -215,8 +298,8 @@ TfLiteStatus GraphBuilder::AddOutputTensors( const auto& tensor = context->tensors[tensor_id]; // If tensor is of type int8, add an op to cast it to uint8. if (tensor.type == kTfLiteInt8) { - TF_LITE_ENSURE_STATUS( - AddCastOp(context, OP_Quantized_CastUInt8ToInt8, tensor_id)); + TF_LITE_ENSURE_STATUS(AddCastOp(context, OP_Quantized_CastUInt8ToInt8, + tensor_id, /*cast_op_builder=*/nullptr)); } hexagon_output_ids.push_back(GetHexagonTensorId(tensor_id)); } @@ -231,9 +314,10 @@ TfLiteStatus GraphBuilder::AddOutputTensors( return kTfLiteOk; } -OpBuilder::TensorID OpBuilder::AddOutput(const TfLiteIntArray* dims) { +OpBuilder::TensorID OpBuilder::AddOutput(const TfLiteIntArray* dims, + int element_size) { op_node_.outputs.push_back(hexagon_nn_output()); - op_node_.outputs.back().elementsize = sizeof(uint8_t); + op_node_.outputs.back().elementsize = element_size; op_node_.outputs.back().rank = 4; // TODO(karimnosseir): What is a good to estimate the max size ? int batch_size, height_size, width_size, depth_size; diff --git a/tensorflow/lite/delegates/hexagon/builders/op_builder.h b/tensorflow/lite/delegates/hexagon/builders/op_builder.h index 52b130c756f..c2a2889b142 100644 --- a/tensorflow/lite/delegates/hexagon/builders/op_builder.h +++ b/tensorflow/lite/delegates/hexagon/builders/op_builder.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_HEXAGON_BUILDERS_OP_BUILDER_H_ #include +#include #include #include #include @@ -131,9 +132,9 @@ class OpBuilder { void AddInput(const TensorID& tensor_id) { input_ids_.push_back(tensor_id); } // Adds Output to the current node, the output has shape defined in 'dims'. - // This assumes the data type is uint8. + // The size of each element is defined using 'element_size'. // Returns the TensorID identifying this output in the graph. - TensorID AddOutput(const TfLiteIntArray* dims); + TensorID AddOutput(const TfLiteIntArray* dims, int element_size); // Adds Output to the current node, each element in the output has // size 'elementsize' and rank 'rank' and for each dimension in the output @@ -316,11 +317,22 @@ class GraphBuilder { bool AddTensorWithID(int tflite_tensor_id, int hexagon_node_id, int hexagon_node_output_id, bool overwrite = false) { if (!overwrite && HasTensor(tflite_tensor_id)) { + TF_LITE_KERNEL_LOG( + context_, + "Trying to add duplicate tensor without overwrite, tflite_tensor_id " + "%d, hexagon_node_id %d, hexagon_node_output_id %d", + tflite_tensor_id, hexagon_node_id, hexagon_node_output_id); return false; } if (tensors_.size() <= tflite_tensor_id) { tensors_.resize(tflite_tensor_id + 1); } + if (hexagon_node_id == -1 || hexagon_node_output_id == -1) + TF_LITE_KERNEL_LOG(context_, + "Trying to add invalid id, tflite_tensor_id " + "%d, hexagon_node_id %d, hexagon_node_output_id %d", + tflite_tensor_id, hexagon_node_id, + hexagon_node_output_id); tensors_[tflite_tensor_id] = OpBuilder::TensorID(hexagon_node_id, hexagon_node_output_id); return true; @@ -348,6 +360,14 @@ class GraphBuilder { int GetMaxBatchSize() const { return max_size_for_batch_; } private: + // Lookup in cache if data with key 'cache_key' is present. + // Return OpBuilder* for the data if found, nullptr otherwise. + OpBuilder* LookupConstData(uint64_t cache_key); + + // Inserts 'value' in cache, with key equals 'cache_key'. + // If data in cache with same key then it will be overwritten. + void AddToCache(uint64_t cache_key, OpBuilder* value); + // Helper method to fetch dimensions. // TODO(karimnosseir): Move this method to shared place. void GetDims(int* batch_size, int* height_size, int* width_size, @@ -360,7 +380,10 @@ class GraphBuilder { } // Adds a Cast op to convert a tensor from int8 to uint8 (or vice versa). - TfLiteStatus AddCastOp(TfLiteContext* context, int op_type, int tensor_id); + // The builder which has the casting operator is filled in 'cast_op_builder' + // if not nullptr. + TfLiteStatus AddCastOp(TfLiteContext* context, int op_type, int tensor_id, + OpBuilder** cast_op_builder); const HexagonNN* hexagon_nn_ = nullptr; TfLiteContext* context_ = nullptr; @@ -373,6 +396,11 @@ class GraphBuilder { // If the graph being built supports dynamic batch, this represents // the maximum value for batch. int max_size_for_batch_ = -1; + + // Cache for const data in the graph. + // Key is hash of the data, value is pointer to the OpBuilder* for the added + // data. + std::map cache_; }; } // namespace hexagon diff --git a/tensorflow/lite/delegates/hexagon/builders/transpose_builder.cc b/tensorflow/lite/delegates/hexagon/builders/transpose_builder.cc index 4a7304d011e..eb0c2668edc 100644 --- a/tensorflow/lite/delegates/hexagon/builders/transpose_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/transpose_builder.cc @@ -29,15 +29,7 @@ TfLiteStatus TransposeOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs, const auto& input_tensor = context->tensors[tensor_id]; AddInput(graph_builder_->GetHexagonTensorId(tensor_id)); // permutation tensor. - tensor_id = inputs->data[1]; - const auto& control_tensor = context->tensors[tensor_id]; - if (control_tensor.allocation_type == kTfLiteMmapRo) { - auto* const_control_tensor_node = - graph_builder_->AddConstNodeWithData(tensor_id, control_tensor); - AddInput(TensorID(const_control_tensor_node->GetID(), 0)); - } else { - AddInput(graph_builder_->GetHexagonTensorId(tensor_id)); - } + AddInput(graph_builder_->GetHexagonTensorId(inputs->data[1])); TF_LITE_ENSURE_STATUS(ComputeAndAddMinAndMax(context, input_tensor)); diff --git a/tensorflow/lite/delegates/hexagon/builders/transpose_conv_2d_builder.cc b/tensorflow/lite/delegates/hexagon/builders/transpose_conv_2d_builder.cc index d2620f71007..3e852533394 100644 --- a/tensorflow/lite/delegates/hexagon/builders/transpose_conv_2d_builder.cc +++ b/tensorflow/lite/delegates/hexagon/builders/transpose_conv_2d_builder.cc @@ -97,8 +97,6 @@ TfLiteStatus TransposeConv2dOpBuilder::PopulateSubGraph( filter_depth_size; GetDims(&filter_batch_size, &filter_height_size, &filter_width_size, &filter_depth_size, weights_tensor.dims); - weight_shape_ = {filter_batch_size, filter_height_size, filter_width_size, - filter_depth_size}; // Weights tensor could be int8 even for per-tensor quantization. // Therefore, we look at the number of scale values to check if it is // per-channel quantized. @@ -106,25 +104,7 @@ TfLiteStatus TransposeConv2dOpBuilder::PopulateSubGraph( reinterpret_cast( weights_tensor.quantization.params); const bool is_per_channel_quant = weights_quant_params->scale->size > 1; - - OpBuilder* const_weights_node; - if (weights_tensor.type == kTfLiteInt8) { - std::vector weights_data(NumElements(&weights_tensor)); - const int8_t* original_data = weights_tensor.data.int8; - // Flip bits on the weight values so that the int8 values are treated - // as uint8. - for (int i = 0; i < NumElements(&weights_tensor); ++i) { - weights_data[i] = original_data[i] ^ k8BitSignFlipConstant; - } - const_weights_node = graph_builder_->AddConstNodeWithData( - weight_shape_.data(), reinterpret_cast(weights_data.data()), - weights_data.size() * sizeof(weights_data[0])); - } else { - const_weights_node = graph_builder_->AddConstNodeWithData( - weight_shape_.data(), weights_tensor.data.raw, weights_tensor.bytes); - } - graph_builder_->AddTensorWithID(tensor_id, const_weights_node->GetID(), 0); - AddInput(TensorID(const_weights_node->GetID(), 0)); + AddInput(graph_builder_->GetHexagonTensorId(tensor_id)); // Handle weights quantization. float weights_min = 0; diff --git a/tensorflow/lite/delegates/hexagon/builders/transpose_conv_2d_builder.h b/tensorflow/lite/delegates/hexagon/builders/transpose_conv_2d_builder.h index 0a6a90a0297..4afab9894f0 100644 --- a/tensorflow/lite/delegates/hexagon/builders/transpose_conv_2d_builder.h +++ b/tensorflow/lite/delegates/hexagon/builders/transpose_conv_2d_builder.h @@ -47,7 +47,7 @@ class TransposeConv2dOpBuilder : public OpBuilder { TensorID node_output_; std::vector transposed_weights_; std::vector stride_shape_; - std::vector weight_shape_, bias_shape_; + std::vector bias_shape_; std::vector bias_data_; // Non-null only if node has per-channel quantized weights/biases. diff --git a/tensorflow/lite/delegates/hexagon/hexagon_delegate_kernel.cc b/tensorflow/lite/delegates/hexagon/hexagon_delegate_kernel.cc index cdf6b555929..83ebc15510e 100644 --- a/tensorflow/lite/delegates/hexagon/hexagon_delegate_kernel.cc +++ b/tensorflow/lite/delegates/hexagon/hexagon_delegate_kernel.cc @@ -264,8 +264,9 @@ TfLiteStatus HexagonDelegateKernel::BuildGraph( if (tensor_id == -1) continue; const auto& input_tensor = context->tensors[tensor_id]; if (input_tensor.allocation_type == kTfLiteMmapRo) { - builder_->AddConstNodeWithData(tensor_id, input_tensor, - /*int8_to_uint8*/ true); + builder_->AddConstNodeWithData( + tensor_id, input_tensor, + /*int8_to_uint8*/ (input_tensor.type == kTfLiteInt8)); } } auto* op_builder =