Create int8 space to depth.

PiperOrigin-RevId: 233633513
This commit is contained in:
Jian Li 2019-02-12 10:34:53 -08:00 committed by TensorFlower Gardener
parent 7141c42808
commit 370f10f3b9
5 changed files with 30 additions and 2 deletions

View File

@ -241,7 +241,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* min_version */ 1,
/* max_version */ 2);
AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM());
AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH());
AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH(),
/* min_version */ 1,
/* max_version */ 2);
AddBuiltin(BuiltinOperator_GATHER, Register_GATHER(),
/* min_version */ 1,
/* max_version */ 2);

View File

@ -50,7 +50,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto data_type = output->type;
TF_LITE_ENSURE(context,
data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 ||
data_type == kTfLiteInt32 || data_type == kTfLiteInt64);
data_type == kTfLiteInt8 || data_type == kTfLiteInt32 ||
data_type == kTfLiteInt64);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
const int block_size = params->block_size;
@ -100,6 +101,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_SPACE_TO_DEPTH(optimized_ops, uint8_t);
}
break;
case kTfLiteInt8:
if (kernel_type == kReference) {
TF_LITE_SPACE_TO_DEPTH(reference_ops, int8_t);
} else {
TF_LITE_SPACE_TO_DEPTH(optimized_ops, int8_t);
}
break;
case kTfLiteInt32:
if (kernel_type == kReference) {
TF_LITE_SPACE_TO_DEPTH(reference_ops, int32_t);

View File

@ -74,6 +74,14 @@ TEST(SpaceToDepthOpModel, Uint8) {
EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(SpaceToDepthOpModel, int8) {
SpaceToDepthOpModel m({TensorType_INT8, {1, 2, 2, 1}}, 2);
m.SetInput<int8_t>({1, 2, 3, 4});
m.Invoke();
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({1, 2, 3, 4}));
EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(SpaceToDepthOpModel, Int32) {
SpaceToDepthOpModel m({TensorType_INT32, {1, 2, 2, 3}}, 2);
m.SetInput<int32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});

View File

@ -841,6 +841,12 @@ class SpaceToDepth
}
int GetVersion(const OperatorSignature& op_signature) const override {
const string& input_name = op_signature.op->inputs[0];
const Array& input_array = op_signature.model->GetArray(input_name);
// If the op take int8 input, it is version 2.
if (input_array.data_type == ArrayDataType::kInt8) {
return 2;
}
return 1;
}
};

View File

@ -816,6 +816,10 @@ TEST_F(OperatorTest, VersioningStridedSliceTest) {
SimpleVersioningTest<StridedSliceOperator>();
}
TEST_F(OperatorTest, VersioningSpaceToDepthTest) {
SimpleVersioningTest<SpaceToDepthOperator>();
}
TEST_F(OperatorTest, VersioningSliceTest) {
SimpleVersioningTest<SliceOperator>();
}