Create int8 batch_to_space_nd.

PiperOrigin-RevId: 233532684
This commit is contained in:
Jian Li 2019-02-11 20:37:26 -08:00 committed by TensorFlower Gardener
parent 07cea8cc4c
commit 078228441f
5 changed files with 67 additions and 17 deletions

View File

@ -148,6 +148,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, uint8_t);
}
break;
case kTfLiteInt8:
if (kernel_type == kReference) {
TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int8_t);
} else {
TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int8_t);
}
break;
case kTfLiteInt32:
if (kernel_type == kReference) {
TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int32_t);

View File

@ -26,8 +26,9 @@ using ::testing::ElementsAreArray;
class BatchToSpaceNDOpModel : public SingleOpModel {
public:
void SetInput(std::initializer_list<float> data) {
PopulateTensor<float>(input_, data);
template <typename T>
void SetInput(std::initializer_list<T> data) {
PopulateTensor<T>(input_, data);
}
void SetBlockShape(std::initializer_list<int> data) {
@ -38,7 +39,10 @@ class BatchToSpaceNDOpModel : public SingleOpModel {
PopulateTensor<int>(crops_, data);
}
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
template <typename T>
std::vector<T> GetOutput() {
return ExtractVector<T>(output_);
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
protected:
@ -58,11 +62,12 @@ class BatchToSpaceNDOpConstModel : public BatchToSpaceNDOpModel {
public:
BatchToSpaceNDOpConstModel(std::initializer_list<int> input_shape,
std::initializer_list<int> block_shape,
std::initializer_list<int> crops) {
input_ = AddInput(TensorType_FLOAT32);
std::initializer_list<int> crops,
const TensorType& type = TensorType_FLOAT32) {
input_ = AddInput(type);
block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2});
crops_ = AddConstInput(TensorType_INT32, crops, {2, 2});
output_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(type);
SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND,
BuiltinOptions_BatchToSpaceNDOptions,
@ -81,11 +86,12 @@ class BatchToSpaceNDOpConstModel : public BatchToSpaceNDOpModel {
// m.Invoke();
class BatchToSpaceNDOpDynamicModel : public BatchToSpaceNDOpModel {
public:
BatchToSpaceNDOpDynamicModel(std::initializer_list<int> input_shape) {
input_ = AddInput(TensorType_FLOAT32);
BatchToSpaceNDOpDynamicModel(std::initializer_list<int> input_shape,
const TensorType& type = TensorType_FLOAT32) {
input_ = AddInput(type);
block_shape_ = AddInput(TensorType_INT32);
crops_ = AddInput(TensorType_INT32);
output_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(type);
SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND,
BuiltinOptions_BatchToSpaceNDOptions,
@ -96,22 +102,47 @@ class BatchToSpaceNDOpDynamicModel : public BatchToSpaceNDOpModel {
TEST(BatchToSpaceNDOpTest, SimpleConstTest) {
BatchToSpaceNDOpConstModel m({4, 2, 2, 1}, {2, 2}, {0, 0, 0, 0});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.SetInput<float>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7,
4, 8, 11, 15, 12, 16}));
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(
{1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16}));
}
TEST(BatchToSpaceNDOpTest, SimpleConstTestInt8) {
BatchToSpaceNDOpConstModel m({4, 2, 2, 1}, {2, 2}, {0, 0, 0, 0},
TensorType_INT8);
m.SetInput<int8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
EXPECT_THAT(m.GetOutput<int8_t>(),
ElementsAreArray(
{1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16}));
}
TEST(BatchToSpaceNDOpTest, SimpleDynamicTest) {
BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.SetInput<float>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.SetBlockShape({2, 2});
m.SetCrops({0, 0, 0, 0});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7,
4, 8, 11, 15, 12, 16}));
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(
{1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16}));
}
TEST(BatchToSpaceNDOpTest, SimpleDynamicTestInt8) {
BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1}, TensorType_INT8);
m.SetInput<int8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.SetBlockShape({2, 2});
m.SetCrops({0, 0, 0, 0});
m.Invoke();
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
EXPECT_THAT(m.GetOutput<int8_t>(),
ElementsAreArray(
{1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16}));
}
#ifdef GTEST_HAS_DEATH_TEST
@ -127,7 +158,7 @@ TEST(BatchToSpaceNDOpTest, InvalidCropsConstTest) {
TEST(BatchToSpaceNDOpTest, InvalidCropsDynamicTest) {
BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1});
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.SetInput<float>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.SetBlockShape({2, 2});
m.SetCrops({0, 0, -1, 0});
EXPECT_DEATH(m.Invoke(), "crops.2. >= 0 was not true.");

View File

@ -215,7 +215,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_SPACE_TO_BATCH_ND, Register_SPACE_TO_BATCH_ND(),
/* min_version */ 1,
/* max_version */ 2);
AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND, Register_BATCH_TO_SPACE_ND());
AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND, Register_BATCH_TO_SPACE_ND(),
/* min_version */ 1,
/* max_version */ 2);
AddBuiltin(BuiltinOperator_MUL, Register_MUL());
AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION());
AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,

View File

@ -314,6 +314,12 @@ class BatchToSpaceND
TocoOperator* op) const override {}
int GetVersion(const OperatorSignature& op_signature) const override {
const string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name);
// If the op take int8 input, it is version 2.
if (input_array.data_type == ArrayDataType::kInt8) {
return 2;
}
return 1;
}
};

View File

@ -808,6 +808,10 @@ TEST_F(OperatorTest, VersioningPackTest) {
SimpleVersioningTest<PackOperator>();
}
TEST_F(OperatorTest, VersioningBatchToSpaceNDTest) {
SimpleVersioningTest<BatchToSpaceNDOperator>();
}
TEST_F(OperatorTest, VersioningStridedSliceTest) {
SimpleVersioningTest<StridedSliceOperator>();
}