Create int8 space to depth.
PiperOrigin-RevId: 233633513
This commit is contained in:
parent
7141c42808
commit
370f10f3b9
@ -241,7 +241,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||||||
/* min_version */ 1,
|
/* min_version */ 1,
|
||||||
/* max_version */ 2);
|
/* max_version */ 2);
|
||||||
AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM());
|
AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM());
|
||||||
AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH());
|
AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH(),
|
||||||
|
/* min_version */ 1,
|
||||||
|
/* max_version */ 2);
|
||||||
AddBuiltin(BuiltinOperator_GATHER, Register_GATHER(),
|
AddBuiltin(BuiltinOperator_GATHER, Register_GATHER(),
|
||||||
/* min_version */ 1,
|
/* min_version */ 1,
|
||||||
/* max_version */ 2);
|
/* max_version */ 2);
|
||||||
|
@ -50,7 +50,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
auto data_type = output->type;
|
auto data_type = output->type;
|
||||||
TF_LITE_ENSURE(context,
|
TF_LITE_ENSURE(context,
|
||||||
data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 ||
|
data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 ||
|
||||||
data_type == kTfLiteInt32 || data_type == kTfLiteInt64);
|
data_type == kTfLiteInt8 || data_type == kTfLiteInt32 ||
|
||||||
|
data_type == kTfLiteInt64);
|
||||||
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
TF_LITE_ENSURE_EQ(context, input->type, output->type);
|
||||||
|
|
||||||
const int block_size = params->block_size;
|
const int block_size = params->block_size;
|
||||||
@ -100,6 +101,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_SPACE_TO_DEPTH(optimized_ops, uint8_t);
|
TF_LITE_SPACE_TO_DEPTH(optimized_ops, uint8_t);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
case kTfLiteInt8:
|
||||||
|
if (kernel_type == kReference) {
|
||||||
|
TF_LITE_SPACE_TO_DEPTH(reference_ops, int8_t);
|
||||||
|
} else {
|
||||||
|
TF_LITE_SPACE_TO_DEPTH(optimized_ops, int8_t);
|
||||||
|
}
|
||||||
|
break;
|
||||||
case kTfLiteInt32:
|
case kTfLiteInt32:
|
||||||
if (kernel_type == kReference) {
|
if (kernel_type == kReference) {
|
||||||
TF_LITE_SPACE_TO_DEPTH(reference_ops, int32_t);
|
TF_LITE_SPACE_TO_DEPTH(reference_ops, int32_t);
|
||||||
|
@ -74,6 +74,14 @@ TEST(SpaceToDepthOpModel, Uint8) {
|
|||||||
EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 4));
|
EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 4));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(SpaceToDepthOpModel, int8) {
|
||||||
|
SpaceToDepthOpModel m({TensorType_INT8, {1, 2, 2, 1}}, 2);
|
||||||
|
m.SetInput<int8_t>({1, 2, 3, 4});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({1, 2, 3, 4}));
|
||||||
|
EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 4));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(SpaceToDepthOpModel, Int32) {
|
TEST(SpaceToDepthOpModel, Int32) {
|
||||||
SpaceToDepthOpModel m({TensorType_INT32, {1, 2, 2, 3}}, 2);
|
SpaceToDepthOpModel m({TensorType_INT32, {1, 2, 2, 3}}, 2);
|
||||||
m.SetInput<int32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
m.SetInput<int32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||||
|
@ -841,6 +841,12 @@ class SpaceToDepth
|
|||||||
}
|
}
|
||||||
|
|
||||||
int GetVersion(const OperatorSignature& op_signature) const override {
|
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||||
|
const string& input_name = op_signature.op->inputs[0];
|
||||||
|
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||||
|
// If the op take int8 input, it is version 2.
|
||||||
|
if (input_array.data_type == ArrayDataType::kInt8) {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -816,6 +816,10 @@ TEST_F(OperatorTest, VersioningStridedSliceTest) {
|
|||||||
SimpleVersioningTest<StridedSliceOperator>();
|
SimpleVersioningTest<StridedSliceOperator>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(OperatorTest, VersioningSpaceToDepthTest) {
|
||||||
|
SimpleVersioningTest<SpaceToDepthOperator>();
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(OperatorTest, VersioningSliceTest) {
|
TEST_F(OperatorTest, VersioningSliceTest) {
|
||||||
SimpleVersioningTest<SliceOperator>();
|
SimpleVersioningTest<SliceOperator>();
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user