Merge pull request #46224 from MohamedNourArm:toupstream/broadcast_to
PiperOrigin-RevId: 351674255 Change-Id: Ib17e4343f1e838b291b03b8fbff38e42cad672ca
This commit is contained in:
commit
1d2540ce2d
tensorflow/lite
@ -308,7 +308,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
// of builtin op code shortage problem.
|
||||
AddBuiltin(BuiltinOperator_BROADCAST_TO, Register_BROADCAST_TO(),
|
||||
/* min_version = */ 2,
|
||||
/* max_version = */ 2);
|
||||
/* max_version = */ 3);
|
||||
AddBuiltin(BuiltinOperator_CALL_ONCE,
|
||||
tflite::ops::builtin::Register_CALL_ONCE());
|
||||
AddBuiltin(BuiltinOperator_RFFT2D, Register_RFFT2D());
|
||||
|
@ -268,7 +268,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
|
||||
// of builtin op code shortage problem.
|
||||
AddBuiltin(BuiltinOperator_BROADCAST_TO, Register_BROADCAST_TO(),
|
||||
/* min_version = */ 2,
|
||||
/* max_version = */ 2);
|
||||
/* max_version = */ 3);
|
||||
AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
|
||||
Register_LOCAL_RESPONSE_NORM_REF());
|
||||
AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1,
|
||||
|
@ -319,6 +319,7 @@ tf_cc_test(
|
||||
data = [
|
||||
"//tensorflow/lite/tools/optimize:testdata/add_with_const_input.bin",
|
||||
"//tensorflow/lite/tools/optimize:testdata/argmax.bin",
|
||||
"//tensorflow/lite/tools/optimize:testdata/broadcast_to.bin",
|
||||
"//tensorflow/lite/tools/optimize:testdata/concat.bin",
|
||||
"//tensorflow/lite/tools/optimize:testdata/fc.bin",
|
||||
"//tensorflow/lite/tools/optimize:testdata/fc_qat.bin",
|
||||
|
@ -106,6 +106,12 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) {
|
||||
property.version = 2;
|
||||
property.quantizable_int16 = false;
|
||||
break;
|
||||
case BuiltinOperator_BROADCAST_TO:
|
||||
property.inputs = {{0, {}}};
|
||||
property.outputs = {{0, {}}};
|
||||
property.restrict_same_input_output_scale = true;
|
||||
property.version = 3;
|
||||
break;
|
||||
case BuiltinOperator_DEPTH_TO_SPACE:
|
||||
property.inputs = {{0, {}}};
|
||||
property.outputs = {{0, {}}};
|
||||
|
@ -1713,6 +1713,73 @@ TEST_F(QuantizeQatTest, VerifySingleQuantize) {
|
||||
EXPECT_EQ(model_.operator_codes[2]->version, 4);
|
||||
}
|
||||
|
||||
class QuantizeBroadcastToModelTest
|
||||
: public QuantizeModelTest,
|
||||
public testing::WithParamInterface<TensorType> {
|
||||
protected:
|
||||
QuantizeBroadcastToModelTest() {
|
||||
tensor_type_ = GetParam();
|
||||
input_model_ = ReadModel(internal::kModelWithBroadcastToOp);
|
||||
readonly_model_ = input_model_->GetModel();
|
||||
readonly_model_->UnPackTo(&model_);
|
||||
}
|
||||
TensorType tensor_type_;
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(QuantizeBroadcastToModelTestInst,
|
||||
QuantizeBroadcastToModelTest,
|
||||
testing::ValuesIn({TensorType_INT8,
|
||||
TensorType_INT16}));
|
||||
|
||||
TEST_P(QuantizeBroadcastToModelTest, VerifyBroadcastToQuantization) {
|
||||
auto status =
|
||||
QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
|
||||
false, tensor_type_, &error_reporter_);
|
||||
EXPECT_EQ(status, kTfLiteOk);
|
||||
|
||||
// There is only one subgraph.
|
||||
const int32_t subgraph_idx = 0;
|
||||
const auto& subgraph = model_.subgraphs[subgraph_idx];
|
||||
const auto& readonly_subgraph =
|
||||
readonly_model_->subgraphs()->Get(subgraph_idx);
|
||||
|
||||
// There should be a single broadcast_to op.
|
||||
EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
|
||||
EXPECT_EQ(subgraph->operators.size(), 1);
|
||||
const auto& broadcast_to = subgraph->operators[0];
|
||||
EXPECT_EQ(model_.operator_codes[broadcast_to->opcode_index]->builtin_code,
|
||||
BuiltinOperator_BROADCAST_TO);
|
||||
|
||||
// There should be 3 tensors: input, output, and BroadcastTo/shape.
|
||||
EXPECT_EQ(subgraph->tensors.size(), 3);
|
||||
|
||||
// Input Tensor
|
||||
EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
|
||||
EXPECT_EQ(subgraph->tensors[0]->name, "input_1");
|
||||
EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
|
||||
EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
|
||||
|
||||
// Output Tensor. The name given in the generated
|
||||
// .bin test file is 'Identity' and should be preserved
|
||||
EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
|
||||
EXPECT_EQ(subgraph->tensors[2]->name, "Identity");
|
||||
EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
|
||||
EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
|
||||
|
||||
// The BroadCastTo shape is of type INT32 and should not be quantized
|
||||
EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT32);
|
||||
EXPECT_EQ(subgraph->tensors[1]->name,
|
||||
"model/tf.broadcast_to/BroadcastTo/shape");
|
||||
EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 0);
|
||||
EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 0);
|
||||
|
||||
// check op and versioning.
|
||||
EXPECT_EQ(model_.operator_codes.size(), 1);
|
||||
EXPECT_EQ(model_.operator_codes[0]->builtin_code,
|
||||
BuiltinOperator_BROADCAST_TO);
|
||||
EXPECT_EQ(model_.operator_codes[0]->version, 3);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace optimize
|
||||
} // namespace tflite
|
||||
|
@ -41,6 +41,8 @@ const char* kConstInputAddModel = "add_with_const_input.bin";
|
||||
|
||||
const char* kFloatConcatMax5Max10Max10 = "concat.bin";
|
||||
|
||||
const char* kModelWithBroadcastToOp = "broadcast_to.bin";
|
||||
|
||||
const char* kModelWithCustomOp = "custom_op.bin";
|
||||
|
||||
const char* kModelWithArgMaxOp = "argmax.bin";
|
||||
|
@ -63,6 +63,9 @@ extern const char* kConstInputAddModel;
|
||||
// 10] as output.
|
||||
extern const char* kFloatConcatMax5Max10Max10;
|
||||
|
||||
// Test model with broadcast_to op.
|
||||
extern const char* kModelWithBroadcastToOp;
|
||||
|
||||
// Test model with a custom op.
|
||||
extern const char* kModelWithCustomOp;
|
||||
|
||||
|
BIN
tensorflow/lite/tools/optimize/testdata/broadcast_to.bin
vendored
Normal file
BIN
tensorflow/lite/tools/optimize/testdata/broadcast_to.bin
vendored
Normal file
Binary file not shown.
@ -635,7 +635,12 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
||||
// The version one of broadcast to op won't be not supported since the
|
||||
// version one was rollbacked and the builtin op code number has been
|
||||
// changed because of builtin op code shortage problem.
|
||||
// Quantized broadcast_to is version 3
|
||||
case BuiltinOperator_BROADCAST_TO:
|
||||
if (op_sig.input_types.at(0) == TensorType_INT8 ||
|
||||
op_sig.input_types.at(0) == TensorType_INT16) {
|
||||
return 3;
|
||||
}
|
||||
return 2;
|
||||
default:
|
||||
return 1;
|
||||
|
@ -894,4 +894,27 @@ TEST(OpVersionTest, VersioningRsqrtTest) {
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
|
||||
}
|
||||
TEST(OpVersionTest, VersioningBroadcastToTest) {
|
||||
OpSignature fake_op_sig = {
|
||||
.op = BuiltinOperator_BROADCAST_TO,
|
||||
.input_types = std::vector<TensorType>{TensorType_FLOAT32},
|
||||
.output_types = std::vector<TensorType>{TensorType_FLOAT32},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
|
||||
|
||||
// Quantized broadcast_to op is version 3.
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_BROADCAST_TO,
|
||||
.input_types = std::vector<TensorType>{TensorType_INT8},
|
||||
.output_types = std::vector<TensorType>{TensorType_INT8},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
|
||||
|
||||
fake_op_sig = {
|
||||
.op = BuiltinOperator_BROADCAST_TO,
|
||||
.input_types = std::vector<TensorType>{TensorType_INT16},
|
||||
.output_types = std::vector<TensorType>{TensorType_INT16},
|
||||
};
|
||||
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
|
||||
}
|
||||
} // namespace tflite
|
||||
|
@ -66,6 +66,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
|
||||
// the version one was rollbacked and the builtin op code number
|
||||
// has been changed because of builtin op code shortage problem.
|
||||
{{BuiltinOperator_BROADCAST_TO, 2}, kPendingReleaseVersion},
|
||||
{{BuiltinOperator_BROADCAST_TO, 3}, kPendingReleaseVersion},
|
||||
{{BuiltinOperator_CONV_2D, 1}, "1.5.0"},
|
||||
{{BuiltinOperator_CONV_2D, 2}, "1.14.0"},
|
||||
{{BuiltinOperator_CONV_2D, 3}, "1.14.0"},
|
||||
|
Loading…
Reference in New Issue
Block a user