Merge pull request #46224 from MohamedNourArm:toupstream/broadcast_to
PiperOrigin-RevId: 351674255
Change-Id: Ib17e4343f1e838b291b03b8fbff38e42cad672ca
commit 1d2540ce2d
@@ -308,7 +308,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   // of builtin op code shortage problem.
   AddBuiltin(BuiltinOperator_BROADCAST_TO, Register_BROADCAST_TO(),
              /* min_version = */ 2,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_CALL_ONCE,
              tflite::ops::builtin::Register_CALL_ONCE());
   AddBuiltin(BuiltinOperator_RFFT2D, Register_RFFT2D());
@@ -268,7 +268,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
   // of builtin op code shortage problem.
   AddBuiltin(BuiltinOperator_BROADCAST_TO, Register_BROADCAST_TO(),
              /* min_version = */ 2,
-             /* max_version = */ 2);
+             /* max_version = */ 3);
   AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
              Register_LOCAL_RESPONSE_NORM_REF());
   AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1,
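Not part of the diff: a minimal usage sketch of how the resolver registered above is consumed when loading a quantized model that contains BROADCAST_TO. The BuildInterpreter helper name and the model path parameter are illustrative assumptions, not code from this change.

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

// Builds an interpreter; with the change above the builtin resolver accepts
// BROADCAST_TO up to version 3, so int8/int16-quantized models can load.
std::unique_ptr<tflite::Interpreter> BuildInterpreter(const char* model_path) {
  auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
  if (!model) return nullptr;
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk ||
      interpreter->AllocateTensors() != kTfLiteOk) {
    return nullptr;
  }
  return interpreter;
}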
@@ -319,6 +319,7 @@ tf_cc_test(
     data = [
         "//tensorflow/lite/tools/optimize:testdata/add_with_const_input.bin",
         "//tensorflow/lite/tools/optimize:testdata/argmax.bin",
+        "//tensorflow/lite/tools/optimize:testdata/broadcast_to.bin",
         "//tensorflow/lite/tools/optimize:testdata/concat.bin",
         "//tensorflow/lite/tools/optimize:testdata/fc.bin",
         "//tensorflow/lite/tools/optimize:testdata/fc_qat.bin",
@@ -106,6 +106,12 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) {
       property.version = 2;
       property.quantizable_int16 = false;
       break;
+    case BuiltinOperator_BROADCAST_TO:
+      property.inputs = {{0, {}}};
+      property.outputs = {{0, {}}};
+      property.restrict_same_input_output_scale = true;
+      property.version = 3;
+      break;
    case BuiltinOperator_DEPTH_TO_SPACE:
      property.inputs = {{0, {}}};
      property.outputs = {{0, {}}};
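The property block above marks BROADCAST_TO as quantizable (version 3) with restrict_same_input_output_scale set, i.e. the quantizer is expected to give the output tensor the same scale and zero point as the input. A standalone sketch of that constraint follows; QuantParams and PropagateBroadcastToParams are hypothetical names used only for illustration, not TFLite tooling APIs.

#include <cassert>

// Hypothetical per-tensor quantization parameters (illustration only).
struct QuantParams {
  float scale;
  int zero_point;
};

// BROADCAST_TO only replicates values, so the output reuses the input's
// quantization parameters instead of being requantized independently.
QuantParams PropagateBroadcastToParams(const QuantParams& input) {
  return input;
}

int main() {
  const QuantParams input{0.05f, -3};
  const QuantParams output = PropagateBroadcastToParams(input);
  assert(output.scale == input.scale && output.zero_point == input.zero_point);
  return 0;
}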
@@ -1713,6 +1713,73 @@ TEST_F(QuantizeQatTest, VerifySingleQuantize) {
   EXPECT_EQ(model_.operator_codes[2]->version, 4);
 }
 
+class QuantizeBroadcastToModelTest
+    : public QuantizeModelTest,
+      public testing::WithParamInterface<TensorType> {
+ protected:
+  QuantizeBroadcastToModelTest() {
+    tensor_type_ = GetParam();
+    input_model_ = ReadModel(internal::kModelWithBroadcastToOp);
+    readonly_model_ = input_model_->GetModel();
+    readonly_model_->UnPackTo(&model_);
+  }
+  TensorType tensor_type_;
+};
+
+INSTANTIATE_TEST_SUITE_P(QuantizeBroadcastToModelTestInst,
+                         QuantizeBroadcastToModelTest,
+                         testing::ValuesIn({TensorType_INT8,
+                                            TensorType_INT16}));
+
+TEST_P(QuantizeBroadcastToModelTest, VerifyBroadcastToQuantization) {
+  auto status =
+      QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
+                                false, tensor_type_, &error_reporter_);
+  EXPECT_EQ(status, kTfLiteOk);
+
+  // There is only one subgraph.
+  const int32_t subgraph_idx = 0;
+  const auto& subgraph = model_.subgraphs[subgraph_idx];
+  const auto& readonly_subgraph =
+      readonly_model_->subgraphs()->Get(subgraph_idx);
+
+  // There should be a single broadcast_to op.
+  EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
+  EXPECT_EQ(subgraph->operators.size(), 1);
+  const auto& broadcast_to = subgraph->operators[0];
+  EXPECT_EQ(model_.operator_codes[broadcast_to->opcode_index]->builtin_code,
+            BuiltinOperator_BROADCAST_TO);
+
+  // There should be 3 tensors: input, output, and BroadcastTo/shape.
+  EXPECT_EQ(subgraph->tensors.size(), 3);
+
+  // Input Tensor
+  EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
+  EXPECT_EQ(subgraph->tensors[0]->name, "input_1");
+  EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
+  EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
+
+  // Output Tensor. The name given in the generated
+  // .bin test file is 'Identity' and should be preserved
+  EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
+  EXPECT_EQ(subgraph->tensors[2]->name, "Identity");
+  EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
+  EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
+
+  // The BroadCastTo shape is of type INT32 and should not be quantized
+  EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT32);
+  EXPECT_EQ(subgraph->tensors[1]->name,
+            "model/tf.broadcast_to/BroadcastTo/shape");
+  EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 0);
+  EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 0);
+
+  // check op and versioning.
+  EXPECT_EQ(model_.operator_codes.size(), 1);
+  EXPECT_EQ(model_.operator_codes[0]->builtin_code,
+            BuiltinOperator_BROADCAST_TO);
+  EXPECT_EQ(model_.operator_codes[0]->version, 3);
+}
+
 }  // namespace
 }  // namespace optimize
 }  // namespace tflite
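To run only the new parameterized cases, a googletest name filter keyed on the instantiation above works. The sketch below assumes a self-contained gtest binary and is not taken from the test file itself.

#include "gtest/gtest.h"

// Restricts the run to both instantiations (INT8 and INT16) of the new
// BROADCAST_TO quantization test.
int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  ::testing::GTEST_FLAG(filter) =
      "QuantizeBroadcastToModelTestInst/QuantizeBroadcastToModelTest.*";
  return RUN_ALL_TESTS();
}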
@@ -41,6 +41,8 @@ const char* kConstInputAddModel = "add_with_const_input.bin";
 
 const char* kFloatConcatMax5Max10Max10 = "concat.bin";
 
+const char* kModelWithBroadcastToOp = "broadcast_to.bin";
+
 const char* kModelWithCustomOp = "custom_op.bin";
 
 const char* kModelWithArgMaxOp = "argmax.bin";
@@ -63,6 +63,9 @@ extern const char* kConstInputAddModel;
 // 10] as output.
 extern const char* kFloatConcatMax5Max10Max10;
 
+// Test model with broadcast_to op.
+extern const char* kModelWithBroadcastToOp;
+
 // Test model with a custom op.
 extern const char* kModelWithCustomOp;
 
BIN  tensorflow/lite/tools/optimize/testdata/broadcast_to.bin  (vendored, new file)
Binary file not shown.
@@ -635,7 +635,12 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
     // The version one of broadcast to op won't be not supported since the
     // version one was rollbacked and the builtin op code number has been
    // changed because of builtin op code shortage problem.
+    // Quantized broadcast_to is version 3
    case BuiltinOperator_BROADCAST_TO:
+      if (op_sig.input_types.at(0) == TensorType_INT8 ||
+          op_sig.input_types.at(0) == TensorType_INT16) {
+        return 3;
+      }
      return 2;
    default:
      return 1;
@@ -894,4 +894,27 @@ TEST(OpVersionTest, VersioningRsqrtTest) {
   };
   EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
 }
+TEST(OpVersionTest, VersioningBroadcastToTest) {
+  OpSignature fake_op_sig = {
+      .op = BuiltinOperator_BROADCAST_TO,
+      .input_types = std::vector<TensorType>{TensorType_FLOAT32},
+      .output_types = std::vector<TensorType>{TensorType_FLOAT32},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
+
+  // Quantized broadcast_to op is version 3.
+  fake_op_sig = {
+      .op = BuiltinOperator_BROADCAST_TO,
+      .input_types = std::vector<TensorType>{TensorType_INT8},
+      .output_types = std::vector<TensorType>{TensorType_INT8},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+
+  fake_op_sig = {
+      .op = BuiltinOperator_BROADCAST_TO,
+      .input_types = std::vector<TensorType>{TensorType_INT16},
+      .output_types = std::vector<TensorType>{TensorType_INT16},
+  };
+  EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
+}
 }  // namespace tflite
@@ -66,6 +66,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
          // the version one was rollbacked and the builtin op code number
          // has been changed because of builtin op code shortage problem.
          {{BuiltinOperator_BROADCAST_TO, 2}, kPendingReleaseVersion},
+          {{BuiltinOperator_BROADCAST_TO, 3}, kPendingReleaseVersion},
          {{BuiltinOperator_CONV_2D, 1}, "1.5.0"},
          {{BuiltinOperator_CONV_2D, 2}, "1.14.0"},
          {{BuiltinOperator_CONV_2D, 3}, "1.14.0"},
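With the extra map entry above, tooling can report the minimum TFLite runtime for the quantized op version. A small sketch follows, assuming the function in the hunk header is declared in tensorflow/lite/tools/versioning/runtime_version.h; the helper name here is illustrative.

#include <string>

#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/tools/versioning/runtime_version.h"

// Looks up the minimum runtime for quantized BROADCAST_TO (version 3). Per the
// table above it maps to kPendingReleaseVersion until the next release pins a
// concrete version string.
std::string MinRuntimeForQuantizedBroadcastTo() {
  return tflite::FindMinimumRuntimeVersionForOp(
      tflite::BuiltinOperator_BROADCAST_TO, /*op_version=*/3);
}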