Create int8 batch_to_space_nd.
PiperOrigin-RevId: 233532684
This commit is contained in:
parent
07cea8cc4c
commit
078228441f
@ -148,6 +148,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, uint8_t);
|
||||
}
|
||||
break;
|
||||
case kTfLiteInt8:
|
||||
if (kernel_type == kReference) {
|
||||
TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int8_t);
|
||||
} else {
|
||||
TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int8_t);
|
||||
}
|
||||
break;
|
||||
case kTfLiteInt32:
|
||||
if (kernel_type == kReference) {
|
||||
TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int32_t);
|
||||
|
@ -26,8 +26,9 @@ using ::testing::ElementsAreArray;
|
||||
|
||||
class BatchToSpaceNDOpModel : public SingleOpModel {
|
||||
public:
|
||||
void SetInput(std::initializer_list<float> data) {
|
||||
PopulateTensor<float>(input_, data);
|
||||
template <typename T>
|
||||
void SetInput(std::initializer_list<T> data) {
|
||||
PopulateTensor<T>(input_, data);
|
||||
}
|
||||
|
||||
void SetBlockShape(std::initializer_list<int> data) {
|
||||
@ -38,7 +39,10 @@ class BatchToSpaceNDOpModel : public SingleOpModel {
|
||||
PopulateTensor<int>(crops_, data);
|
||||
}
|
||||
|
||||
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
|
||||
template <typename T>
|
||||
std::vector<T> GetOutput() {
|
||||
return ExtractVector<T>(output_);
|
||||
}
|
||||
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||
|
||||
protected:
|
||||
@ -58,11 +62,12 @@ class BatchToSpaceNDOpConstModel : public BatchToSpaceNDOpModel {
|
||||
public:
|
||||
BatchToSpaceNDOpConstModel(std::initializer_list<int> input_shape,
|
||||
std::initializer_list<int> block_shape,
|
||||
std::initializer_list<int> crops) {
|
||||
input_ = AddInput(TensorType_FLOAT32);
|
||||
std::initializer_list<int> crops,
|
||||
const TensorType& type = TensorType_FLOAT32) {
|
||||
input_ = AddInput(type);
|
||||
block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2});
|
||||
crops_ = AddConstInput(TensorType_INT32, crops, {2, 2});
|
||||
output_ = AddOutput(TensorType_FLOAT32);
|
||||
output_ = AddOutput(type);
|
||||
|
||||
SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND,
|
||||
BuiltinOptions_BatchToSpaceNDOptions,
|
||||
@ -81,11 +86,12 @@ class BatchToSpaceNDOpConstModel : public BatchToSpaceNDOpModel {
|
||||
// m.Invoke();
|
||||
class BatchToSpaceNDOpDynamicModel : public BatchToSpaceNDOpModel {
|
||||
public:
|
||||
BatchToSpaceNDOpDynamicModel(std::initializer_list<int> input_shape) {
|
||||
input_ = AddInput(TensorType_FLOAT32);
|
||||
BatchToSpaceNDOpDynamicModel(std::initializer_list<int> input_shape,
|
||||
const TensorType& type = TensorType_FLOAT32) {
|
||||
input_ = AddInput(type);
|
||||
block_shape_ = AddInput(TensorType_INT32);
|
||||
crops_ = AddInput(TensorType_INT32);
|
||||
output_ = AddOutput(TensorType_FLOAT32);
|
||||
output_ = AddOutput(type);
|
||||
|
||||
SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND,
|
||||
BuiltinOptions_BatchToSpaceNDOptions,
|
||||
@ -96,22 +102,47 @@ class BatchToSpaceNDOpDynamicModel : public BatchToSpaceNDOpModel {
|
||||
|
||||
TEST(BatchToSpaceNDOpTest, SimpleConstTest) {
|
||||
BatchToSpaceNDOpConstModel m({4, 2, 2, 1}, {2, 2}, {0, 0, 0, 0});
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
m.SetInput<float>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7,
|
||||
4, 8, 11, 15, 12, 16}));
|
||||
EXPECT_THAT(m.GetOutput<float>(),
|
||||
ElementsAreArray(
|
||||
{1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16}));
|
||||
}
|
||||
|
||||
TEST(BatchToSpaceNDOpTest, SimpleConstTestInt8) {
|
||||
BatchToSpaceNDOpConstModel m({4, 2, 2, 1}, {2, 2}, {0, 0, 0, 0},
|
||||
TensorType_INT8);
|
||||
m.SetInput<int8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
|
||||
EXPECT_THAT(m.GetOutput<int8_t>(),
|
||||
ElementsAreArray(
|
||||
{1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16}));
|
||||
}
|
||||
|
||||
TEST(BatchToSpaceNDOpTest, SimpleDynamicTest) {
|
||||
BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1});
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
m.SetInput<float>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
m.SetBlockShape({2, 2});
|
||||
m.SetCrops({0, 0, 0, 0});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
|
||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7,
|
||||
4, 8, 11, 15, 12, 16}));
|
||||
EXPECT_THAT(m.GetOutput<float>(),
|
||||
ElementsAreArray(
|
||||
{1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16}));
|
||||
}
|
||||
|
||||
TEST(BatchToSpaceNDOpTest, SimpleDynamicTestInt8) {
|
||||
BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1}, TensorType_INT8);
|
||||
m.SetInput<int8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
m.SetBlockShape({2, 2});
|
||||
m.SetCrops({0, 0, 0, 0});
|
||||
m.Invoke();
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
|
||||
EXPECT_THAT(m.GetOutput<int8_t>(),
|
||||
ElementsAreArray(
|
||||
{1, 5, 2, 6, 9, 13, 10, 14, 3, 7, 4, 8, 11, 15, 12, 16}));
|
||||
}
|
||||
|
||||
#ifdef GTEST_HAS_DEATH_TEST
|
||||
@ -127,7 +158,7 @@ TEST(BatchToSpaceNDOpTest, InvalidCropsConstTest) {
|
||||
|
||||
TEST(BatchToSpaceNDOpTest, InvalidCropsDynamicTest) {
|
||||
BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1});
|
||||
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
m.SetInput<float>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
m.SetBlockShape({2, 2});
|
||||
m.SetCrops({0, 0, -1, 0});
|
||||
EXPECT_DEATH(m.Invoke(), "crops.2. >= 0 was not true.");
|
||||
|
@ -215,7 +215,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
||||
AddBuiltin(BuiltinOperator_SPACE_TO_BATCH_ND, Register_SPACE_TO_BATCH_ND(),
|
||||
/* min_version */ 1,
|
||||
/* max_version */ 2);
|
||||
AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND, Register_BATCH_TO_SPACE_ND());
|
||||
AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND, Register_BATCH_TO_SPACE_ND(),
|
||||
/* min_version */ 1,
|
||||
/* max_version */ 2);
|
||||
AddBuiltin(BuiltinOperator_MUL, Register_MUL());
|
||||
AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION());
|
||||
AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
|
||||
|
@ -314,6 +314,12 @@ class BatchToSpaceND
|
||||
TocoOperator* op) const override {}
|
||||
|
||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||
const string& input_name = op_signature.op->inputs[0];
|
||||
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||
// If the op take int8 input, it is version 2.
|
||||
if (input_array.data_type == ArrayDataType::kInt8) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
@ -808,6 +808,10 @@ TEST_F(OperatorTest, VersioningPackTest) {
|
||||
SimpleVersioningTest<PackOperator>();
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, VersioningBatchToSpaceNDTest) {
|
||||
SimpleVersioningTest<BatchToSpaceNDOperator>();
|
||||
}
|
||||
|
||||
TEST_F(OperatorTest, VersioningStridedSliceTest) {
|
||||
SimpleVersioningTest<StridedSliceOperator>();
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user