Create int8 space to batch ND.

PiperOrigin-RevId: 233460573
This commit is contained in:
Jian Li 2019-02-11 12:52:15 -08:00 committed by TensorFlower Gardener
parent 049848467b
commit 8367a99757
5 changed files with 82 additions and 22 deletions

View File

@ -212,7 +212,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* max_version */ 2);
AddBuiltin(BuiltinOperator_CONCATENATION, Register_CONCATENATION());
AddBuiltin(BuiltinOperator_ADD, Register_ADD());
AddBuiltin(BuiltinOperator_SPACE_TO_BATCH_ND, Register_SPACE_TO_BATCH_ND());
AddBuiltin(BuiltinOperator_SPACE_TO_BATCH_ND, Register_SPACE_TO_BATCH_ND(),
/* min_version */ 1,
/* max_version */ 2);
AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND, Register_BATCH_TO_SPACE_ND());
AddBuiltin(BuiltinOperator_MUL, Register_MUL());
AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION());

View File

@ -141,6 +141,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
op_context.output->params.zero_point);
}
break;
case kTfLiteInt8:
if (kernel_type == kReference) {
TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int8_t,
op_context.output->params.zero_point);
} else {
TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int8_t,
op_context.output->params.zero_point);
}
break;
case kTfLiteInt32:
if (kernel_type == kReference) {
TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t, 0);

View File

@ -31,8 +31,9 @@ class SpaceToBatchNDOpModel : public SingleOpModel {
PopulateTensor<float>(input_, data);
}
template <typename T>
void SetQuantizedInput(std::initializer_list<float> data) {
QuantizeAndPopulate<uint8_t>(input_, data);
QuantizeAndPopulate<T>(input_, data);
}
void SetBlockShape(std::initializer_list<int> data) {
@ -46,9 +47,10 @@ class SpaceToBatchNDOpModel : public SingleOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
template <typename T>
std::vector<float> GetDequantizedOutput() {
return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
GetScale(output_), GetZeroPoint(output_));
return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
GetZeroPoint(output_));
}
protected:
@ -233,29 +235,62 @@ TEST_F(QuantizedSpaceToBatchNDOpTest, ZeroNotInQuantizationRange) {
}
#endif
TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingConstTest) {
TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingConstTestUint8) {
SpaceToBatchNDOpConstModel m({TensorType_UINT8, {1, 5, 2, 1}, -1.0, 1.0},
{3, 2}, {1, 0, 2, 0},
{TensorType_UINT8, {}, -1.0, 1.0});
m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1});
m.SetQuantizedInput<uint8_t>(
{-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1}));
EXPECT_THAT(m.GetDequantizedOutput(),
EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(DequantizedArrayNear(
{0, 0, 0, -0.5, 0, 0, 0, 0.6, 0, -0.1, 0, -0.7,
0, 0.2, 0, 0.8, 0, -0.3, 0, -0.9, 0, 0.4, 0, 0.1},
-1.0, 1.0)));
}
TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingDynamicTest) {
TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingConstTestInt8) {
SpaceToBatchNDOpConstModel m({TensorType_INT8, {1, 5, 2, 1}, -1.0, 1.0},
{3, 2}, {1, 0, 2, 0},
{TensorType_INT8, {}, -1.0, 1.0});
m.SetQuantizedInput<int8_t>(
{-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1}));
EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
ElementsAreArray(DequantizedArrayNear(
{0, 0, 0, -0.5, 0, 0, 0, 0.6, 0, -0.1, 0, -0.7,
0, 0.2, 0, 0.8, 0, -0.3, 0, -0.9, 0, 0.4, 0, 0.1},
-1.0, 1.0)));
}
TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingDynamicTestUint8) {
SpaceToBatchNDOpDynamicModel m({TensorType_UINT8, {1, 5, 2, 1}, -1.0, 1.0},
{TensorType_UINT8, {}, -1.0, 1.0});
m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1});
m.SetQuantizedInput<uint8_t>(
{-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1});
m.SetBlockShape({3, 2});
m.SetPaddings({1, 0, 2, 0});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1}));
EXPECT_THAT(m.GetDequantizedOutput(),
EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(DequantizedArrayNear(
{0, 0, 0, -0.5, 0, 0, 0, 0.6, 0, -0.1, 0, -0.7,
0, 0.2, 0, 0.8, 0, -0.3, 0, -0.9, 0, 0.4, 0, 0.1},
-1.0, 1.0)));
}
TEST_F(QuantizedSpaceToBatchNDOpTest, SimplePaddingDynamicTestInt8) {
SpaceToBatchNDOpDynamicModel m({TensorType_INT8, {1, 5, 2, 1}, -1.0, 1.0},
{TensorType_INT8, {}, -1.0, 1.0});
m.SetQuantizedInput<int8_t>(
{-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 0.1});
m.SetBlockShape({3, 2});
m.SetPaddings({1, 0, 2, 0});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1}));
EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
ElementsAreArray(DequantizedArrayNear(
{0, 0, 0, -0.5, 0, 0, 0, 0.6, 0, -0.1, 0, -0.7,
0, 0.2, 0, 0.8, 0, -0.3, 0, -0.9, 0, 0.4, 0, 0.1},
@ -266,10 +301,10 @@ TEST_F(QuantizedSpaceToBatchNDOpTest, ComplexPaddingConstTest) {
SpaceToBatchNDOpConstModel m({TensorType_UINT8, {1, 4, 2, 1}, -1.0, 1.0},
{3, 2}, {1, 1, 2, 4},
{TensorType_UINT8, {}, -1.0, 1.0});
m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8});
m.SetQuantizedInput<uint8_t>({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1}));
EXPECT_THAT(m.GetDequantizedOutput(),
EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(DequantizedArrayNear(
{
0, 0, 0, 0, 0, -0.5, 0, 0, 0, 0, 0, 0, 0, 0.6, 0, 0,
@ -282,12 +317,12 @@ TEST_F(QuantizedSpaceToBatchNDOpTest, ComplexPaddingConstTest) {
TEST_F(QuantizedSpaceToBatchNDOpTest, ComplexPaddingDynamicTest) {
SpaceToBatchNDOpDynamicModel m({TensorType_UINT8, {1, 4, 2, 1}, -1.0, 1.0},
{TensorType_UINT8, {}, -1.0, 1.0});
m.SetQuantizedInput({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8});
m.SetQuantizedInput<uint8_t>({-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8});
m.SetBlockShape({3, 2});
m.SetPaddings({1, 1, 2, 4});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1}));
EXPECT_THAT(m.GetDequantizedOutput(),
EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(DequantizedArrayNear(
{
0, 0, 0, 0, 0, -0.5, 0, 0, 0, 0, 0, 0, 0, 0.6, 0, 0,

View File

@ -239,6 +239,12 @@ class SpaceToBatchND
TocoOperator* op) const override {}
int GetVersion(const OperatorSignature& op_signature) const override {
const string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name);
// If the op take int8 input, it is version 2.
if (input_array.data_type == ArrayDataType::kInt8) {
return 2;
}
return 1;
}
};

View File

@ -754,8 +754,10 @@ TEST_F(OperatorTest, BuiltinUnique) {
EXPECT_EQ(output_toco_op->idx_out_type, op.idx_out_type);
}
// Test version for a simple Op with 2 versions and the input type controls the
// version.
template <typename Op>
void VersioningTest() {
void SimpleVersioningTest() {
Op op;
op.inputs = {"input1"};
auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
@ -775,30 +777,36 @@ void VersioningTest() {
}
TEST_F(OperatorTest, VersioningEqualTest) {
VersioningTest<TensorFlowEqualOperator>();
SimpleVersioningTest<TensorFlowEqualOperator>();
}
TEST_F(OperatorTest, VersioningNotEqualTest) {
VersioningTest<TensorFlowNotEqualOperator>();
SimpleVersioningTest<TensorFlowNotEqualOperator>();
}
TEST_F(OperatorTest, VersioningLessTest) {
VersioningTest<TensorFlowLessOperator>();
SimpleVersioningTest<TensorFlowLessOperator>();
}
TEST_F(OperatorTest, VersioningLessEqualTest) {
VersioningTest<TensorFlowLessEqualOperator>();
SimpleVersioningTest<TensorFlowLessEqualOperator>();
}
TEST_F(OperatorTest, VersioningGreaterTest) {
VersioningTest<TensorFlowGreaterOperator>();
SimpleVersioningTest<TensorFlowGreaterOperator>();
}
TEST_F(OperatorTest, VersioningGreaterEqualTest) {
VersioningTest<TensorFlowGreaterEqualOperator>();
SimpleVersioningTest<TensorFlowGreaterEqualOperator>();
}
TEST_F(OperatorTest, VersioningPackTest) { VersioningTest<PackOperator>(); }
TEST_F(OperatorTest, VersioningSpaceToBatchNDTest) {
SimpleVersioningTest<SpaceToBatchNDOperator>();
}
TEST_F(OperatorTest, VersioningPackTest) {
SimpleVersioningTest<PackOperator>();
}
TEST_F(OperatorTest, VersioningSelectTest) {
SelectOperator select_op;