Merge pull request #26655 from Dayananda-V:tflite_hybrid_op

PiperOrigin-RevId: 239124760
This commit is contained in:
TensorFlower Gardener 2019-03-18 22:06:32 -07:00
commit 9201689e42
8 changed files with 15 additions and 28 deletions

View File

@ -95,9 +95,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size_array));
bool is_hybrid =
input->type == kTfLiteFloat32 && (input_weights->type == kTfLiteUInt8 ||
input_weights->type == kTfLiteInt8);
const bool is_hybrid = IsHybridOp(input, input_weights);
// Allocate temporary tensors to store quantized values of input and
// hidden_state tensors.

View File

@ -509,8 +509,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, fw_output, fw_output_size));
// The weights are of consistent type, so it suffices to check one.
const bool is_hybrid_op = (fw_input_to_output_weights->type == kTfLiteUInt8 ||
fw_input_to_output_weights->type == kTfLiteInt8);
const bool is_hybrid_op = IsHybridOp(input, fw_input_to_output_weights);
TfLiteIntArrayFree(node->temporaries);
if (is_hybrid_op) {

View File

@ -168,11 +168,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
bw_aux_input_weights->dims->data[1]);
}
const bool is_hybrid_op = ((fw_input_weights->type == kTfLiteUInt8 ||
fw_input_weights->type == kTfLiteInt8) &&
input->type == kTfLiteFloat32);
if (is_hybrid_op) {
if (IsHybridOp(input, fw_input_weights)) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
TfLiteIntArrayFree(node->temporaries);

View File

@ -84,6 +84,13 @@ inline void SetTensorToDynamic(TfLiteTensor* tensor) {
}
}
// Determines whether it is a hybrid op - one that has float inputs and
// quantized weights.
inline bool IsHybridOp(const TfLiteTensor* input, const TfLiteTensor* weight) {
return ((weight->type == kTfLiteUInt8 || weight->type == kTfLiteInt8) &&
input->type == kTfLiteFloat32);
}
// Check dimensionality match and populate OpData for Conv and DepthwiseConv.
TfLiteStatus PopulateConvolutionQuantizationParams(
TfLiteContext* context, const TfLiteTensor* input,

View File

@ -381,10 +381,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, output, output_size));
// The weights are of consistent type, so it suffices to check one.
// TODO(mirkov): create a utility/macro for this check, so all Ops can use it.
const bool is_hybrid_op = ((input_to_output_weights->type == kTfLiteUInt8 ||
input_to_output_weights->type == kTfLiteInt8) &&
input->type == kTfLiteFloat32);
const bool is_hybrid_op = IsHybridOp(input, input_to_output_weights);
TfLiteIntArrayFree(node->temporaries);
if (is_hybrid_op) {

View File

@ -176,9 +176,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, output, output_size_array));
// The weights are of consistent type, so it suffices to check one.
const bool is_hybrid_op = (input->type == kTfLiteFloat32 &&
(weights_feature->type == kTfLiteUInt8 ||
weights_feature->type == kTfLiteInt8));
const bool is_hybrid_op = IsHybridOp(input, weights_feature);
// Resize scratch.
TfLiteIntArrayFree(node->temporaries);

View File

@ -304,14 +304,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size));
// The weights are of consistent type, so it suffices to check one.
// TODO(mirkov): create a utility/macro for this check, so all Ops can use it.
const bool is_hybrid_op = ((input_to_output_weights->type == kTfLiteUInt8 ||
input_to_output_weights->type == kTfLiteInt8) &&
input->type == kTfLiteFloat32);
TfLiteIntArrayFree(node->temporaries);
if (is_hybrid_op) {
if (IsHybridOp(input, input_to_output_weights)) {
node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
} else {
node->temporaries = TfLiteIntArrayCreate(1);
@ -338,7 +332,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
scratch_buffer_size));
if (is_hybrid_op) {
if (IsHybridOp(input, input_to_output_weights)) {
// Allocate temporary tensors to store quantized values of input,
// activation_state and cell_state tensors.
node->temporaries->data[kInputQuantized] =

View File

@ -96,9 +96,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size_array));
const bool is_hybrid =
input->type == kTfLiteFloat32 && (input_weights->type == kTfLiteUInt8 ||
input_weights->type == kTfLiteInt8);
const bool is_hybrid = IsHybridOp(input, input_weights);
// Allocate temporary tensors to store quantized values of input and
// hidden_state tensors.