Merge pull request from MohamedNourArm:toupstream/broadcast_to

PiperOrigin-RevId: 351674255
Change-Id: Ib17e4343f1e838b291b03b8fbff38e42cad672ca
This commit is contained in:
TensorFlower Gardener 2021-01-13 15:15:05 -08:00
commit 1d2540ce2d
11 changed files with 110 additions and 2 deletions

View File

@ -308,7 +308,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
// of builtin op code shortage problem. // of builtin op code shortage problem.
AddBuiltin(BuiltinOperator_BROADCAST_TO, Register_BROADCAST_TO(), AddBuiltin(BuiltinOperator_BROADCAST_TO, Register_BROADCAST_TO(),
/* min_version = */ 2, /* min_version = */ 2,
/* max_version = */ 2); /* max_version = */ 3);
AddBuiltin(BuiltinOperator_CALL_ONCE, AddBuiltin(BuiltinOperator_CALL_ONCE,
tflite::ops::builtin::Register_CALL_ONCE()); tflite::ops::builtin::Register_CALL_ONCE());
AddBuiltin(BuiltinOperator_RFFT2D, Register_RFFT2D()); AddBuiltin(BuiltinOperator_RFFT2D, Register_RFFT2D());

View File

@ -268,7 +268,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
// of builtin op code shortage problem. // of builtin op code shortage problem.
AddBuiltin(BuiltinOperator_BROADCAST_TO, Register_BROADCAST_TO(), AddBuiltin(BuiltinOperator_BROADCAST_TO, Register_BROADCAST_TO(),
/* min_version = */ 2, /* min_version = */ 2,
/* max_version = */ 2); /* max_version = */ 3);
AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
Register_LOCAL_RESPONSE_NORM_REF()); Register_LOCAL_RESPONSE_NORM_REF());
AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1, AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1,

View File

@ -319,6 +319,7 @@ tf_cc_test(
data = [ data = [
"//tensorflow/lite/tools/optimize:testdata/add_with_const_input.bin", "//tensorflow/lite/tools/optimize:testdata/add_with_const_input.bin",
"//tensorflow/lite/tools/optimize:testdata/argmax.bin", "//tensorflow/lite/tools/optimize:testdata/argmax.bin",
"//tensorflow/lite/tools/optimize:testdata/broadcast_to.bin",
"//tensorflow/lite/tools/optimize:testdata/concat.bin", "//tensorflow/lite/tools/optimize:testdata/concat.bin",
"//tensorflow/lite/tools/optimize:testdata/fc.bin", "//tensorflow/lite/tools/optimize:testdata/fc.bin",
"//tensorflow/lite/tools/optimize:testdata/fc_qat.bin", "//tensorflow/lite/tools/optimize:testdata/fc_qat.bin",

View File

@ -106,6 +106,12 @@ OperatorProperty GetOperatorProperty(OpVariant op_variant) {
property.version = 2; property.version = 2;
property.quantizable_int16 = false; property.quantizable_int16 = false;
break; break;
case BuiltinOperator_BROADCAST_TO:
property.inputs = {{0, {}}};
property.outputs = {{0, {}}};
property.restrict_same_input_output_scale = true;
property.version = 3;
break;
case BuiltinOperator_DEPTH_TO_SPACE: case BuiltinOperator_DEPTH_TO_SPACE:
property.inputs = {{0, {}}}; property.inputs = {{0, {}}};
property.outputs = {{0, {}}}; property.outputs = {{0, {}}};

View File

@ -1713,6 +1713,73 @@ TEST_F(QuantizeQatTest, VerifySingleQuantize) {
EXPECT_EQ(model_.operator_codes[2]->version, 4); EXPECT_EQ(model_.operator_codes[2]->version, 4);
} }
class QuantizeBroadcastToModelTest
: public QuantizeModelTest,
public testing::WithParamInterface<TensorType> {
protected:
QuantizeBroadcastToModelTest() {
tensor_type_ = GetParam();
input_model_ = ReadModel(internal::kModelWithBroadcastToOp);
readonly_model_ = input_model_->GetModel();
readonly_model_->UnPackTo(&model_);
}
TensorType tensor_type_;
};
INSTANTIATE_TEST_SUITE_P(QuantizeBroadcastToModelTestInst,
QuantizeBroadcastToModelTest,
testing::ValuesIn({TensorType_INT8,
TensorType_INT16}));
TEST_P(QuantizeBroadcastToModelTest, VerifyBroadcastToQuantization) {
auto status =
QuantizeModelAllOperators(&builder_, &model_, tensor_type_, tensor_type_,
false, tensor_type_, &error_reporter_);
EXPECT_EQ(status, kTfLiteOk);
// There is only one subgraph.
const int32_t subgraph_idx = 0;
const auto& subgraph = model_.subgraphs[subgraph_idx];
const auto& readonly_subgraph =
readonly_model_->subgraphs()->Get(subgraph_idx);
// There should be a single broadcast_to op.
EXPECT_EQ(readonly_subgraph->operators()->size(), 1);
EXPECT_EQ(subgraph->operators.size(), 1);
const auto& broadcast_to = subgraph->operators[0];
EXPECT_EQ(model_.operator_codes[broadcast_to->opcode_index]->builtin_code,
BuiltinOperator_BROADCAST_TO);
// There should be 3 tensors: input, output, and BroadcastTo/shape.
EXPECT_EQ(subgraph->tensors.size(), 3);
// Input Tensor
EXPECT_EQ(subgraph->tensors[0]->type, tensor_type_);
EXPECT_EQ(subgraph->tensors[0]->name, "input_1");
EXPECT_EQ(subgraph->tensors[0]->quantization->scale.size(), 1);
EXPECT_EQ(subgraph->tensors[0]->quantization->zero_point.size(), 1);
// Output Tensor. The name given in the generated
// .bin test file is 'Identity' and should be preserved
EXPECT_EQ(subgraph->tensors[2]->type, tensor_type_);
EXPECT_EQ(subgraph->tensors[2]->name, "Identity");
EXPECT_EQ(subgraph->tensors[2]->quantization->scale.size(), 1);
EXPECT_EQ(subgraph->tensors[2]->quantization->zero_point.size(), 1);
// The BroadCastTo shape is of type INT32 and should not be quantized
EXPECT_EQ(subgraph->tensors[1]->type, TensorType_INT32);
EXPECT_EQ(subgraph->tensors[1]->name,
"model/tf.broadcast_to/BroadcastTo/shape");
EXPECT_EQ(subgraph->tensors[1]->quantization->scale.size(), 0);
EXPECT_EQ(subgraph->tensors[1]->quantization->zero_point.size(), 0);
// check op and versioning.
EXPECT_EQ(model_.operator_codes.size(), 1);
EXPECT_EQ(model_.operator_codes[0]->builtin_code,
BuiltinOperator_BROADCAST_TO);
EXPECT_EQ(model_.operator_codes[0]->version, 3);
}
} // namespace } // namespace
} // namespace optimize } // namespace optimize
} // namespace tflite } // namespace tflite

View File

@ -41,6 +41,8 @@ const char* kConstInputAddModel = "add_with_const_input.bin";
const char* kFloatConcatMax5Max10Max10 = "concat.bin"; const char* kFloatConcatMax5Max10Max10 = "concat.bin";
const char* kModelWithBroadcastToOp = "broadcast_to.bin";
const char* kModelWithCustomOp = "custom_op.bin"; const char* kModelWithCustomOp = "custom_op.bin";
const char* kModelWithArgMaxOp = "argmax.bin"; const char* kModelWithArgMaxOp = "argmax.bin";

View File

@ -63,6 +63,9 @@ extern const char* kConstInputAddModel;
// 10] as output. // 10] as output.
extern const char* kFloatConcatMax5Max10Max10; extern const char* kFloatConcatMax5Max10Max10;
// Test model with broadcast_to op.
extern const char* kModelWithBroadcastToOp;
// Test model with a custom op. // Test model with a custom op.
extern const char* kModelWithCustomOp; extern const char* kModelWithCustomOp;

Binary file not shown.

View File

@ -635,7 +635,12 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
// The version one of broadcast to op won't be not supported since the // The version one of broadcast to op won't be not supported since the
// version one was rollbacked and the builtin op code number has been // version one was rollbacked and the builtin op code number has been
// changed because of builtin op code shortage problem. // changed because of builtin op code shortage problem.
// Quantized broadcast_to is version 3
case BuiltinOperator_BROADCAST_TO: case BuiltinOperator_BROADCAST_TO:
if (op_sig.input_types.at(0) == TensorType_INT8 ||
op_sig.input_types.at(0) == TensorType_INT16) {
return 3;
}
return 2; return 2;
default: default:
return 1; return 1;

View File

@ -894,4 +894,27 @@ TEST(OpVersionTest, VersioningRsqrtTest) {
}; };
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2); EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
} }
TEST(OpVersionTest, VersioningBroadcastToTest) {
OpSignature fake_op_sig = {
.op = BuiltinOperator_BROADCAST_TO,
.input_types = std::vector<TensorType>{TensorType_FLOAT32},
.output_types = std::vector<TensorType>{TensorType_FLOAT32},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 2);
// Quantized broadcast_to op is version 3.
fake_op_sig = {
.op = BuiltinOperator_BROADCAST_TO,
.input_types = std::vector<TensorType>{TensorType_INT8},
.output_types = std::vector<TensorType>{TensorType_INT8},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
fake_op_sig = {
.op = BuiltinOperator_BROADCAST_TO,
.input_types = std::vector<TensorType>{TensorType_INT16},
.output_types = std::vector<TensorType>{TensorType_INT16},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
}
} // namespace tflite } // namespace tflite

View File

@ -66,6 +66,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
// the version one was rollbacked and the builtin op code number // the version one was rollbacked and the builtin op code number
// has been changed because of builtin op code shortage problem. // has been changed because of builtin op code shortage problem.
{{BuiltinOperator_BROADCAST_TO, 2}, kPendingReleaseVersion}, {{BuiltinOperator_BROADCAST_TO, 2}, kPendingReleaseVersion},
{{BuiltinOperator_BROADCAST_TO, 3}, kPendingReleaseVersion},
{{BuiltinOperator_CONV_2D, 1}, "1.5.0"}, {{BuiltinOperator_CONV_2D, 1}, "1.5.0"},
{{BuiltinOperator_CONV_2D, 2}, "1.14.0"}, {{BuiltinOperator_CONV_2D, 2}, "1.14.0"},
{{BuiltinOperator_CONV_2D, 3}, "1.14.0"}, {{BuiltinOperator_CONV_2D, 3}, "1.14.0"},