From 370f10f3b9711bd03367f32c4426b7d73a5c0d4c Mon Sep 17 00:00:00 2001 From: Jian Li Date: Tue, 12 Feb 2019 10:34:53 -0800 Subject: [PATCH] Create int8 space to depth. PiperOrigin-RevId: 233633513 --- tensorflow/lite/kernels/register.cc | 4 +++- tensorflow/lite/kernels/space_to_depth.cc | 10 +++++++++- tensorflow/lite/kernels/space_to_depth_test.cc | 8 ++++++++ tensorflow/lite/toco/tflite/operator.cc | 6 ++++++ tensorflow/lite/toco/tflite/operator_test.cc | 4 ++++ 5 files changed, 30 insertions(+), 2 deletions(-) diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 2290ed477cf..f7ab8edc5c5 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -241,7 +241,9 @@ BuiltinOpResolver::BuiltinOpResolver() { /* min_version */ 1, /* max_version */ 2); AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM()); - AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH()); + AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH(), + /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_GATHER, Register_GATHER(), /* min_version */ 1, /* max_version */ 2); diff --git a/tensorflow/lite/kernels/space_to_depth.cc b/tensorflow/lite/kernels/space_to_depth.cc index 79e28bf47d9..cf6b0bd4d3d 100644 --- a/tensorflow/lite/kernels/space_to_depth.cc +++ b/tensorflow/lite/kernels/space_to_depth.cc @@ -50,7 +50,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { auto data_type = output->type; TF_LITE_ENSURE(context, data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 || - data_type == kTfLiteInt32 || data_type == kTfLiteInt64); + data_type == kTfLiteInt8 || data_type == kTfLiteInt32 || + data_type == kTfLiteInt64); TF_LITE_ENSURE_EQ(context, input->type, output->type); const int block_size = params->block_size; @@ -100,6 +101,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_SPACE_TO_DEPTH(optimized_ops, uint8_t); } break; + case kTfLiteInt8: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_DEPTH(reference_ops, int8_t); + } else { + TF_LITE_SPACE_TO_DEPTH(optimized_ops, int8_t); + } + break; case kTfLiteInt32: if (kernel_type == kReference) { TF_LITE_SPACE_TO_DEPTH(reference_ops, int32_t); diff --git a/tensorflow/lite/kernels/space_to_depth_test.cc b/tensorflow/lite/kernels/space_to_depth_test.cc index 3fa8d86348e..58665fc9d83 100644 --- a/tensorflow/lite/kernels/space_to_depth_test.cc +++ b/tensorflow/lite/kernels/space_to_depth_test.cc @@ -74,6 +74,14 @@ TEST(SpaceToDepthOpModel, Uint8) { EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } +TEST(SpaceToDepthOpModel, int8) { + SpaceToDepthOpModel m({TensorType_INT8, {1, 2, 2, 1}}, 2); + m.SetInput({1, 2, 3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4})); + EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + TEST(SpaceToDepthOpModel, Int32) { SpaceToDepthOpModel m({TensorType_INT32, {1, 2, 2, 3}}, 2); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index 961bda28648..adb7f504f17 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -841,6 +841,12 @@ class SpaceToDepth } int GetVersion(const OperatorSignature& op_signature) const override { + const string& input_name = op_signature.op->inputs[0]; + const Array& input_array = op_signature.model->GetArray(input_name); + // If the op take int8 input, it is version 2. + if (input_array.data_type == ArrayDataType::kInt8) { + return 2; + } return 1; } }; diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc index b977270bda2..b4f7c48d75d 100644 --- a/tensorflow/lite/toco/tflite/operator_test.cc +++ b/tensorflow/lite/toco/tflite/operator_test.cc @@ -816,6 +816,10 @@ TEST_F(OperatorTest, VersioningStridedSliceTest) { SimpleVersioningTest(); } +TEST_F(OperatorTest, VersioningSpaceToDepthTest) { + SimpleVersioningTest(); +} + TEST_F(OperatorTest, VersioningSliceTest) { SimpleVersioningTest(); }