Merge pull request #42059 from Tessil:toupstream/16x8_batch_matmul
PiperOrigin-RevId: 336667926
Change-Id: I0d33c9daf62372606b59bcad89f29157c61b3fc7
commit 296393e947
@@ -314,7 +314,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Note that quantized inference requires that all tensors have their
   // parameters set. This is usually done during quantized training.
-  if (lhs_data->type == kTfLiteInt8) {
+  if (lhs_data->type == kTfLiteInt8 || lhs_data->type == kTfLiteInt16) {
     double real_multiplier = 0.0;
     TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
         context, lhs_data, rhs_data, output, &real_multiplier));
@@ -322,16 +322,34 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     QuantizeMultiplier(real_multiplier, &op_data->output_multiplier, &exponent);
     op_data->output_shift = exponent;
     // BatchMatMul has no fused activation functions. Therefore, set
-    // output activation min and max to min and max of int8_t type,
-    // respecitvely.
-    op_data->output_activation_min = std::numeric_limits<int8_t>::min();
-    op_data->output_activation_max = std::numeric_limits<int8_t>::max();
+    // output activation min and max to min and max of int8_t or int16_t
+    // type.
+    if (lhs_data->type == kTfLiteInt8) {
+      op_data->output_activation_min = std::numeric_limits<int8_t>::min();
+      op_data->output_activation_max = std::numeric_limits<int8_t>::max();
+    } else {
+      op_data->output_activation_min = std::numeric_limits<int16_t>::min();
+      op_data->output_activation_max = std::numeric_limits<int16_t>::max();
+    }
+  }
+
+  if (lhs_data->type == kTfLiteInt16) {
+    TF_LITE_ENSURE_EQ(context, lhs_data->params.zero_point, 0);
+    TF_LITE_ENSURE_EQ(context, rhs_data->params.zero_point, 0);
+    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
   }
 
   TF_LITE_ENSURE(context, lhs_data->type == kTfLiteFloat32 ||
-                              lhs_data->type == kTfLiteInt8);
+                              lhs_data->type == kTfLiteInt8 ||
+                              lhs_data->type == kTfLiteInt16);
   TF_LITE_ENSURE(context, rhs_data->type == kTfLiteFloat32 ||
-                              rhs_data->type == kTfLiteInt8);
+                              rhs_data->type == kTfLiteInt8 ||
+                              rhs_data->type == kTfLiteInt16);
+  // Either we have a hybrid quantization with a float32 and an int8 input,
+  // otherwise both inputs should be of the same type.
+  TF_LITE_ENSURE(context, (lhs_data->type == kTfLiteFloat32 &&
+                           rhs_data->type == kTfLiteInt8) ||
+                              lhs_data->type == rhs_data->type);
   // Support dimensions between 2 and 4, inclusive.
   TF_LITE_ENSURE(context, NumDimensions(lhs_data) >= 2);
   TF_LITE_ENSURE(context, NumDimensions(lhs_data) <= 4);
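Note on the new zero-point checks above: the int16 path is symmetric, so the zero point of the inputs and the output must be 0 and the affine mapping reduces to pure scaling. A minimal sketch (not part of the commit; the helper name QuantizeSymmetricInt16 is made up for illustration) of what that convention means for a single value:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// With zero_point fixed at 0, quantization is just q = round(x / scale),
// clamped to the int16 range.
int16_t QuantizeSymmetricInt16(float x, float scale) {
  float q = std::round(x / scale);
  q = std::max(q, static_cast<float>(std::numeric_limits<int16_t>::min()));
  q = std::min(q, static_cast<float>(std::numeric_limits<int16_t>::max()));
  return static_cast<int16_t>(q);
}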
@@ -402,9 +420,14 @@ TfLiteStatus TransposeRowsColumns(TfLiteContext* context,
         tensor_in, GetTensorData<int8_t>(tensor_in), tensor_out,
         GetTensorData<int8_t>(tensor_out));
     return kTfLiteOk;
+  } else if (tensor_in->type == kTfLiteInt16) {
+    TransposeRowsColumnsImpl<int16_t>(
+        tensor_in, GetTensorData<int16_t>(tensor_in), tensor_out,
+        GetTensorData<int16_t>(tensor_out));
+    return kTfLiteOk;
   } else {
-    TF_LITE_KERNEL_LOG(context,
-                       "Can only transpose tensors with float and int8 type.");
+    TF_LITE_KERNEL_LOG(
+        context, "Can only transpose tensors with float, int8 or int16 type.");
     return kTfLiteError;
   }
 }
@@ -501,10 +524,10 @@ TfLiteStatus EvalInt8(TfLiteContext* context, const OpData* data,
   op_params.rhs_cacheable = IsConstantTensor(rhs);
 
   if (kernel_type == kReference) {
-    reference_ops::BatchMatMul(op_params, rhs_shape, GetTensorData<int8_t>(rhs),
-                               lhs_shape, GetTensorData<int8_t>(lhs),
-                               GetTensorShape(output),
-                               GetTensorData<int8_t>(output));
+    reference_ops::BatchMatMul<int8_t, int32_t>(
+        op_params, rhs_shape, GetTensorData<int8_t>(rhs), lhs_shape,
+        GetTensorData<int8_t>(lhs), GetTensorShape(output),
+        GetTensorData<int8_t>(output));
   } else {
     optimized_ops::BatchMatMul(op_params, rhs_shape, GetTensorData<int8_t>(rhs),
                                lhs_shape, GetTensorData<int8_t>(lhs),
@@ -515,13 +538,40 @@ TfLiteStatus EvalInt8(TfLiteContext* context, const OpData* data,
   return kTfLiteOk;
 }
 
+template <KernelType kernel_type>
+TfLiteStatus EvalInt16(TfLiteContext* context, const OpData* data,
+                       const RuntimeShape& lhs_shape, const TfLiteTensor* lhs,
+                       const RuntimeShape& rhs_shape, const TfLiteTensor* rhs,
+                       const RuntimeShape& output_shape, TfLiteTensor* output) {
+  // Reuse params struct from FullyConnected Op.
+  FullyConnectedParams op_params;
+  int32_t input_offset = -lhs->params.zero_point;
+  int32_t filter_offset = -rhs->params.zero_point;
+  int32_t output_offset = output->params.zero_point;
+  op_params.input_offset = input_offset;
+  op_params.weights_offset = filter_offset;
+  op_params.output_offset = output_offset;
+  op_params.output_multiplier = data->output_multiplier;
+  op_params.output_shift = data->output_shift;
+  op_params.quantized_activation_min = data->output_activation_min;
+  op_params.quantized_activation_max = data->output_activation_max;
+
+  // optimized_ops not yet implemnted for int16_t, use reference_ops in all
+  // cases.
+  reference_ops::BatchMatMul<int16_t, int64_t>(
+      op_params, rhs_shape, GetTensorData<int16_t>(rhs), lhs_shape,
+      GetTensorData<int16_t>(lhs), GetTensorShape(output),
+      GetTensorData<int16_t>(output));
+  return kTfLiteOk;
+}
+
 template <KernelType kernel_type>
 TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                            OpData* data, const RuntimeShape& lhs_shape,
                            const TfLiteTensor* lhs,
                            const RuntimeShape& rhs_shape,
                            const TfLiteTensor* rhs, TfLiteTensor* output) {
-  if (lhs->type == kTfLiteFloat32) {
+  if (lhs->type == kTfLiteFloat32 && rhs->type == kTfLiteInt8) {
     TfLiteTensor* input_quantized;
     TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
                                                 &input_quantized));
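Worth noting from the hunk above: the int8 path instantiates the reference kernel as BatchMatMul<int8_t, int32_t>, while the int16 path uses BatchMatMul<int16_t, int64_t>. The wider accumulator follows from simple arithmetic: a worst-case int8 product is at most 2^14, so an int32 accumulator leaves roughly 2^17 worst-case additions of headroom, but a worst-case int16 product is 2^30 and a second one already overflows int32. A small back-of-the-envelope check (illustration only, not code from the commit):

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t max_product = 32768LL * 32768LL;           // 2^30, worst-case |int16 * int16|
  const int64_t int32_headroom = INT32_MAX / max_product;   // = 1: a second product overflows int32
  const int64_t int64_headroom = INT64_MAX / max_product;   // ~8.6e9 worst-case additions fit in int64
  std::printf("int32 fits %lld worst-case products, int64 fits %lld\n",
              static_cast<long long>(int32_headroom),
              static_cast<long long>(int64_headroom));
  return 0;
}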
@@ -540,12 +590,16 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
     return EvalHybrid<kernel_type>(
         context, node, data, lhs_shape, lhs, rhs_shape, rhs, input_quantized,
         scaling_factors, accum_scratch, row_sums, input_offsets, output);
-  } else if (lhs->type == kTfLiteInt8) {
+  } else if (lhs->type == kTfLiteInt8 && rhs->type == kTfLiteInt8) {
     return EvalInt8<kernel_type>(context, data, lhs_shape, lhs, rhs_shape, rhs,
                                  GetTensorShape(output), output);
+  } else if (lhs->type == kTfLiteInt16 && rhs->type == kTfLiteInt16) {
+    return EvalInt16<kernel_type>(context, data, lhs_shape, lhs, rhs_shape, rhs,
+                                  GetTensorShape(output), output);
   } else {
     TF_LITE_KERNEL_LOG(
-        context, "Currently only hybrid and int8 quantization is supported.\n");
+        context,
+        "Currently only hybrid, int8 and int16 quantization are supported.\n");
     return kTfLiteError;
   }
   return kTfLiteOk;
@@ -558,7 +612,7 @@ TfLiteTensor* GetTempRhs(TfLiteContext* context, TfLiteNode* node,
     return nullptr;
   }
 
-  if (rhs->type == kTfLiteInt8) {
+  if (rhs->type == kTfLiteInt8 || rhs->type == kTfLiteInt16) {
     // Get the quantization params from the RHS tensor.
     transposed_rhs->params.scale = rhs->params.scale;
     transposed_rhs->params.zero_point = rhs->params.zero_point;
@@ -573,7 +627,7 @@ TfLiteTensor* GetTempLhs(TfLiteContext* context, TfLiteNode* node,
     return nullptr;
   }
 
-  if (lhs->type == kTfLiteInt8) {
+  if (lhs->type == kTfLiteInt8 || lhs->type == kTfLiteInt16) {
     // Get the quantization params from the LHS tensor.
     transposed_lhs->params.scale = lhs->params.scale;
     transposed_lhs->params.zero_point = lhs->params.zero_point;
@@ -646,6 +700,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       }
       break;
     case kTfLiteInt8:
+    case kTfLiteInt16:
       EvalQuantized<kernel_type>(context, node, op_data, lhs_shape, lhs_tensor,
                                  rhs_shape, rhs_tensor, output);
       break;
@@ -483,7 +483,12 @@ class QuantizedBatchMatMulOpModel : public SingleOpModel {
     input_size_ = total_input_size / batches_;
 
     lhs_id_ = AddInput(lhs);
-    rhs_id_ = AddInput({lhs.type, {input_size_, units_}, lhs.min, lhs.max});
+    rhs_id_ = AddInput({lhs.type,
+                        {input_size_, units_},
+                        0,
+                        0,
+                        GetScale(lhs_id_),
+                        GetZeroPoint(lhs_id_)});
 
     output_id_ = AddOutput(output);
@@ -553,6 +558,35 @@ TEST_P(QuantizedBatchMatMulOpTest, SimpleTestQuantizedInt8) {
   EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAre(22, 22, 22, 56, 56, 56));
 }
 
+TEST_P(QuantizedBatchMatMulOpTest, SimpleTestQuantizedInt16) {
+  const float inputs_scale = 10.0 / std::numeric_limits<int16_t>::max();
+  const float output_scale = 1.0;
+  const int32_t zero_point = 0;
+
+  QuantizedBatchMatMulOpModel m(
+      /*units=*/3, /*batches*/ 2,
+      /*lhs=*/
+      {TensorType_INT16, {2, 10}, 0, 0, inputs_scale, zero_point},
+      /*output=*/
+      {TensorType_INT16, {}, 0, 0, output_scale, zero_point});
+
+  m.SetWeights<int16_t>({
+      1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5,
+      6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10,
+  });
+
+  m.SetInput<int16_t>({
+      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
+      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
+              ElementsAreArray(ArrayFloatNear({23, 23, 23, 57, 57, 57})));
+  EXPECT_THAT(m.GetOutput<int16_t>(), ElementsAre(23, 23, 23, 57, 57, 57));
+}
+
 INSTANTIATE_TEST_SUITE_P(
     QuantizedBatchMatMulOpTest, QuantizedBatchMatMulOpTest,
     ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
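A quick way to see where the expected values 23 and 57 in the new test come from (illustration only, not part of the commit): the weights fill a {10, 3} tensor so every output unit shares the same column {1, 2, ..., 10}, and with output_scale = 1.0 and zero points of 0 the result is just the dot product of each input row with 1..10.

#include <cstdio>

int main() {
  const int weights[10] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  const int batch0[10] = {1, 2, 3, 4, 5, 6, 7, 8, -9, -10};
  const int batch1[10] = {1, 2, 3, 4, 5, 6, 7, -8, 9, -10};
  int acc0 = 0, acc1 = 0;
  for (int k = 0; k < 10; ++k) {
    acc0 += batch0[k] * weights[k];  // accumulate row 0
    acc1 += batch1[k] * weights[k];  // accumulate row 1
  }
  std::printf("b0 = %d, b1 = %d\n", acc0, acc1);  // prints b0 = 23, b1 = 57
  return 0;
}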
@@ -217,10 +217,11 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
   }
 }
 
+template <typename T, typename AccumT>
 inline void BatchMatMul(const FullyConnectedParams& params,
-                        const RuntimeShape& lhs_shape, const int8_t* lhs_data,
-                        const RuntimeShape& rhs_shape, const int8_t* rhs_data,
-                        const RuntimeShape& output_shape, int8_t* output_data) {
+                        const RuntimeShape& lhs_shape, const T* lhs_data,
+                        const RuntimeShape& rhs_shape, const T* rhs_data,
+                        const RuntimeShape& output_shape, T* output_data) {
   const RuntimeShape extended_lhs_shape =
       RuntimeShape::ExtendedShape(5, lhs_shape);
   const RuntimeShape extended_rhs_shape =
@@ -276,33 +277,33 @@ inline void BatchMatMul(const FullyConnectedParams& params,
   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
 
   for (int b0 = 0; b0 < batch_dim0; ++b0) {
-    const int8_t* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
-    const int8_t* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
+    const T* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
+    const T* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
     for (int b1 = 0; b1 < batch_dim1; ++b1) {
-      const int8_t* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
-      const int8_t* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
+      const T* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
+      const T* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
       for (int b2 = 0; b2 < batch_dim2; ++b2) {
-        const int8_t* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
-        const int8_t* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
-        int8_t* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
-                                         b1 * batch_dim2 + b2) *
-                                            lhs_rows * rhs_cols;
+        const T* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
+        const T* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
+        T* out_ptr = output_data +
+                     ((b0 * batch_dim1 * batch_dim2) + b1 * batch_dim2 + b2) *
+                         lhs_rows * rhs_cols;
 
         for (int j = 0; j < rhs_cols; ++j) {
           for (int i = 0; i < lhs_rows; ++i) {
-            int32_t total = 0;
+            AccumT total = 0;
             for (int k = 0; k < accum_depth; ++k) {
-              int32_t lhs_val = lhs_ptr2[accum_depth * i + k];
-              int32_t rhs_val = rhs_ptr2[accum_depth * j + k];
+              AccumT lhs_val = lhs_ptr2[accum_depth * i + k];
+              AccumT rhs_val = rhs_ptr2[accum_depth * j + k];
               total += (lhs_val + filter_offset) * (rhs_val + input_offset);
             }
-            total = MultiplyByQuantizedMultiplier(total, output_multiplier,
-                                                  output_shift);
-            total += output_offset;
-            total = std::max(total, output_activation_min);
-            total = std::min(total, output_activation_max);
+            int32_t total_scaled = MultiplyByQuantizedMultiplier(
+                total, output_multiplier, output_shift);
+            total_scaled += output_offset;
+            total_scaled = std::max(total_scaled, output_activation_min);
+            total_scaled = std::min(total_scaled, output_activation_max);
             const int idx = lhs_rows * j + i;
-            out_ptr[idx] = static_cast<int8_t>(total);
+            out_ptr[idx] = static_cast<T>(total_scaled);
           }
         }
       }
     }
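For orientation, a rough floating-point rendering of the requantization step in the loop above (a sketch, not the fixed-point MultiplyByQuantizedMultiplier routine TFLite actually uses; the helper name RequantizeApprox is hypothetical): Prepare() folds the real multiplier, lhs scale times rhs scale divided by output scale, into output_multiplier and output_shift, and the loop maps the wide accumulator back into the output's quantized range before the offset and clamp.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Approximates: rescale accumulator, add output offset, clamp to activation range.
int32_t RequantizeApprox(int64_t total, double real_multiplier,
                         int32_t output_offset, int32_t activation_min,
                         int32_t activation_max) {
  int32_t scaled = static_cast<int32_t>(std::lround(total * real_multiplier));
  scaled += output_offset;
  return std::min(std::max(scaled, activation_min), activation_max);
}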
@@ -293,7 +293,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
   AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL(),
              /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
   // custom ops aren't always included by default.
@@ -447,7 +447,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
   AddBuiltin(BuiltinOperator_DENSIFY, Register_DENSIFY());
   AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL_REF(),
              /* min_version = */ 1,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddCustom("NumericVerify",
             tflite::ops::custom::Register_NUMERIC_VERIFY_REF());
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
@@ -99,6 +99,7 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       property.inputs = {{0, {}}, {1, {}}};
       property.outputs = {{0, {}}};
       property.version = 2;
+      property.quantize_input_as_activations = true;
       break;
     }
     case BuiltinOperator_BATCH_TO_SPACE_ND:
@@ -543,6 +543,7 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
       return 1;
 
     case BuiltinOperator_CONCATENATION:
+    case BuiltinOperator_BATCH_MATMUL:
     case BuiltinOperator_SOFTMAX:
     case BuiltinOperator_MEAN:
     case BuiltinOperator_PAD:
|
|||||||
case BuiltinOperator_LESS:
|
case BuiltinOperator_LESS:
|
||||||
case BuiltinOperator_LESS_EQUAL:
|
case BuiltinOperator_LESS_EQUAL:
|
||||||
case BuiltinOperator_SELECT:
|
case BuiltinOperator_SELECT:
|
||||||
case BuiltinOperator_BATCH_MATMUL:
|
|
||||||
if (op_sig.input_types.at(0) == TensorType_INT8) {
|
if (op_sig.input_types.at(0) == TensorType_INT8) {
|
||||||
return 2;
|
return 2;
|
||||||
}
|
}
|
||||||
|
@@ -59,6 +59,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
           {{BuiltinOperator_AVERAGE_POOL_2D, 3}, "2.3.0"},
           {{BuiltinOperator_BATCH_MATMUL, 1}, "2.3.0"},
           {{BuiltinOperator_BATCH_MATMUL, 2}, "2.3.0"},
+          {{BuiltinOperator_BATCH_MATMUL, 3}, kPendingReleaseVersion},
          {{BuiltinOperator_CONV_2D, 1}, "1.5.0"},
          {{BuiltinOperator_CONV_2D, 2}, "1.14.0"},
          {{BuiltinOperator_CONV_2D, 3}, "1.14.0"},