Add bool support in TFL unpack op
This bool type support is required for tf.layers.keras.RNN with unroll=True. PiperOrigin-RevId: 286955661 Change-Id: Ie557febc346e44978c4e3e96d828b21098ab3afb
This commit is contained in:
parent
c2e9f671a7
commit
a1aa335328
@ -2359,14 +2359,14 @@ def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect]> {
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TensorOf<[F32, I8, I32, QI8, QUI8]>:$input,
|
TensorOf<[F32, I1, I8, I32, QI8, QUI8]>:$input,
|
||||||
|
|
||||||
I32Attr:$num,
|
I32Attr:$num,
|
||||||
I32Attr:$axis
|
I32Attr:$axis
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
Variadic<TensorOf<[F32, I8, I32, QI8, QUI8]>>:$outputs
|
Variadic<TensorOf<[F32, I1, I8, I32, QI8, QUI8]>>:$outputs
|
||||||
);
|
);
|
||||||
|
|
||||||
let verifier = [{ return Verify(*this); }];
|
let verifier = [{ return Verify(*this); }];
|
||||||
|
@ -247,7 +247,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
|
|||||||
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
|
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
|
||||||
AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK(),
|
AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK(),
|
||||||
/* min_version */ 1,
|
/* min_version */ 1,
|
||||||
/* max_version */ 2);
|
/* max_version */ 3);
|
||||||
AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV(),
|
AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV(),
|
||||||
/* min_version */ 1,
|
/* min_version */ 1,
|
||||||
/* max_version */ 2);
|
/* max_version */ 2);
|
||||||
|
@ -43,7 +43,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
}
|
}
|
||||||
TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input));
|
TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input));
|
||||||
if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
|
if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
|
||||||
input->type != kTfLiteUInt8 && input->type != kTfLiteInt8) {
|
input->type != kTfLiteUInt8 && input->type != kTfLiteInt8 &&
|
||||||
|
input->type != kTfLiteBool) {
|
||||||
context->ReportError(context, "Type '%s' is not supported by unpack.",
|
context->ReportError(context, "Type '%s' is not supported by unpack.",
|
||||||
TfLiteTypeGetName(input->type));
|
TfLiteTypeGetName(input->type));
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
@ -112,6 +113,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
UnpackImpl<int8_t>(context, node, input, data->num, data->axis);
|
UnpackImpl<int8_t>(context, node, input, data->num, data->axis);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case kTfLiteBool: {
|
||||||
|
UnpackImpl<bool>(context, node, input, data->num, data->axis);
|
||||||
|
break;
|
||||||
|
}
|
||||||
default: {
|
default: {
|
||||||
context->ReportError(context, "Type '%s' is not supported by unpack.",
|
context->ReportError(context, "Type '%s' is not supported by unpack.",
|
||||||
TfLiteTypeGetName(input->type));
|
TfLiteTypeGetName(input->type));
|
||||||
|
@ -87,43 +87,43 @@ void Check(int axis, const std::initializer_list<int>& input_shape,
|
|||||||
TEST(UnpackOpTest, FloatThreeOutputs) {
|
TEST(UnpackOpTest, FloatThreeOutputs) {
|
||||||
Check<float>(/*axis=*/0, /*input_shape=*/{3, 2},
|
Check<float>(/*axis=*/0, /*input_shape=*/{3, 2},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
/*expected_output_shape=*/{{2}, {2}, {2}},
|
/*exp_output_shape=*/{{2}, {2}, {2}},
|
||||||
/*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}});
|
/*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnpackOpTest, FloatThreeOutputsAxisOne) {
|
TEST(UnpackOpTest, FloatThreeOutputsAxisOne) {
|
||||||
Check<float>(/*axis=*/1, /*input_shape=*/{3, 2},
|
Check<float>(/*axis=*/1, /*input_shape=*/{3, 2},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
/*expected_output_shape=*/{{3}, {3}},
|
/*exp_output_shape=*/{{3}, {3}},
|
||||||
/*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}});
|
/*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnpackOpTest, FloatThreeOutputsNegativeAxisOne) {
|
TEST(UnpackOpTest, FloatThreeOutputsNegativeAxisOne) {
|
||||||
Check<float>(/*axis=*/-1, /*input_shape=*/{3, 2},
|
Check<float>(/*axis=*/-1, /*input_shape=*/{3, 2},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
/*expected_output_shape=*/{{3}, {3}},
|
/*exp_output_shape=*/{{3}, {3}},
|
||||||
/*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}});
|
/*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnpackOpTest, FloatThreeOutputsNegativeAxisTwo) {
|
TEST(UnpackOpTest, FloatThreeOutputsNegativeAxisTwo) {
|
||||||
Check<float>(/*axis=*/-2, /*input_shape=*/{3, 2},
|
Check<float>(/*axis=*/-2, /*input_shape=*/{3, 2},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
/*expected_output_shape=*/{{2}, {2}, {2}},
|
/*exp_output_shape=*/{{2}, {2}, {2}},
|
||||||
/*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}});
|
/*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnpackOpTest, FloatOneOutput) {
|
TEST(UnpackOpTest, FloatOneOutput) {
|
||||||
Check<float>(/*axis=*/0, /*input_shape=*/{1, 6},
|
Check<float>(/*axis=*/0, /*input_shape=*/{1, 6},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
/*expected_output_shape=*/{{6}},
|
/*exp_output_shape=*/{{6}},
|
||||||
/*expected_output_data=*/{{1, 2, 3, 4, 5, 6}});
|
/*exp_output_data=*/{{1, 2, 3, 4, 5, 6}});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnpackOpTest, FloatThreeDimensionsOutputs) {
|
TEST(UnpackOpTest, FloatThreeDimensionsOutputs) {
|
||||||
Check<float>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
|
Check<float>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
|
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
|
||||||
/*expected_output_shape=*/{{2, 2}, {2, 2}},
|
/*exp_output_shape=*/{{2, 2}, {2, 2}},
|
||||||
/*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}});
|
/*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnpackOpTest, FloatVectorToScalar) {
|
TEST(UnpackOpTest, FloatVectorToScalar) {
|
||||||
@ -137,32 +137,32 @@ TEST(UnpackOpTest, FloatVectorToScalar) {
|
|||||||
TEST(UnpackOpTest, IntThreeOutputs) {
|
TEST(UnpackOpTest, IntThreeOutputs) {
|
||||||
Check<int32_t>(/*axis=*/0, /*input_shape=*/{3, 2},
|
Check<int32_t>(/*axis=*/0, /*input_shape=*/{3, 2},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
/*expected_output_shape=*/{{2}, {2}, {2}},
|
/*exp_output_shape=*/{{2}, {2}, {2}},
|
||||||
/*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
|
/*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
|
||||||
/*type=*/TensorType_INT32);
|
/*type=*/TensorType_INT32);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnpackOpTest, IntThreeOutputsAxisOne) {
|
TEST(UnpackOpTest, IntThreeOutputsAxisOne) {
|
||||||
Check<int32_t>(/*axis=*/1, /*input_shape=*/{3, 2},
|
Check<int32_t>(/*axis=*/1, /*input_shape=*/{3, 2},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
/*expected_output_shape=*/{{3}, {3}},
|
/*exp_output_shape=*/{{3}, {3}},
|
||||||
/*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
|
/*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}},
|
||||||
/*type=*/TensorType_INT32);
|
/*type=*/TensorType_INT32);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnpackOpTest, IntOneOutput) {
|
TEST(UnpackOpTest, IntOneOutput) {
|
||||||
Check<int32_t>(/*axis=*/0, /*input_shape=*/{1, 6},
|
Check<int32_t>(/*axis=*/0, /*input_shape=*/{1, 6},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
/*expected_output_shape=*/{{6}},
|
/*exp_output_shape=*/{{6}},
|
||||||
/*expected_output_data=*/{{1, 2, 3, 4, 5, 6}},
|
/*exp_output_data=*/{{1, 2, 3, 4, 5, 6}},
|
||||||
/*type=*/TensorType_INT32);
|
/*type=*/TensorType_INT32);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnpackOpTest, IntThreeDimensionsOutputs) {
|
TEST(UnpackOpTest, IntThreeDimensionsOutputs) {
|
||||||
Check<int32_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
|
Check<int32_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
|
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
|
||||||
/*expected_output_shape=*/{{2, 2}, {2, 2}},
|
/*exp_output_shape=*/{{2, 2}, {2, 2}},
|
||||||
/*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
|
/*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
|
||||||
/*type=*/TensorType_INT32);
|
/*type=*/TensorType_INT32);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -178,48 +178,48 @@ TEST(UnpackOpTest, IntVectorToScalar) {
|
|||||||
TEST(UnpackOpTest, Uint8ThreeOutputs) {
|
TEST(UnpackOpTest, Uint8ThreeOutputs) {
|
||||||
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
|
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
/*expected_output_shape=*/{{2}, {2}, {2}},
|
/*exp_output_shape=*/{{2}, {2}, {2}},
|
||||||
/*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
|
/*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
|
||||||
/*type=*/TensorType_UINT8);
|
/*type=*/TensorType_UINT8);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnpackOpTest, Uint8ThreeOutputsAxisOne) {
|
TEST(UnpackOpTest, Uint8ThreeOutputsAxisOne) {
|
||||||
Check<uint8_t>(/*axis=*/1, /*input_shape=*/{3, 2},
|
Check<uint8_t>(/*axis=*/1, /*input_shape=*/{3, 2},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
/*expected_output_shape=*/{{3}, {3}},
|
/*exp_output_shape=*/{{3}, {3}},
|
||||||
/*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
|
/*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}},
|
||||||
/*type=*/TensorType_UINT8);
|
/*type=*/TensorType_UINT8);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnpackOpTest, Uint8ThreeOutputsNegativeAxisOne) {
|
TEST(UnpackOpTest, Uint8ThreeOutputsNegativeAxisOne) {
|
||||||
Check<uint8_t>(/*axis=*/-1, /*input_shape=*/{3, 2},
|
Check<uint8_t>(/*axis=*/-1, /*input_shape=*/{3, 2},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
/*expected_output_shape=*/{{3}, {3}},
|
/*exp_output_shape=*/{{3}, {3}},
|
||||||
/*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
|
/*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}},
|
||||||
/*type=*/TensorType_UINT8);
|
/*type=*/TensorType_UINT8);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnpackOpTest, Uint8ThreeOutputsNegativeAxisTwo) {
|
TEST(UnpackOpTest, Uint8ThreeOutputsNegativeAxisTwo) {
|
||||||
Check<uint8_t>(/*axis=*/-2, /*input_shape=*/{3, 2},
|
Check<uint8_t>(/*axis=*/-2, /*input_shape=*/{3, 2},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
/*expected_output_shape=*/{{2}, {2}, {2}},
|
/*exp_output_shape=*/{{2}, {2}, {2}},
|
||||||
/*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
|
/*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
|
||||||
/*type=*/TensorType_UINT8);
|
/*type=*/TensorType_UINT8);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnpackOpTest, Uint8OneOutput) {
|
TEST(UnpackOpTest, Uint8OneOutput) {
|
||||||
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{1, 6},
|
Check<uint8_t>(/*axis=*/0, /*input_shape=*/{1, 6},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
/*expected_output_shape=*/{{6}},
|
/*exp_output_shape=*/{{6}},
|
||||||
/*expected_output_data=*/{{1, 2, 3, 4, 5, 6}},
|
/*exp_output_data=*/{{1, 2, 3, 4, 5, 6}},
|
||||||
/*type=*/TensorType_UINT8);
|
/*type=*/TensorType_UINT8);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnpackOpTest, Uint8ThreeDimensionsOutputs) {
|
TEST(UnpackOpTest, Uint8ThreeDimensionsOutputs) {
|
||||||
Check<uint8_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
|
Check<uint8_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
|
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
|
||||||
/*expected_output_shape=*/{{2, 2}, {2, 2}},
|
/*exp_output_shape=*/{{2, 2}, {2, 2}},
|
||||||
/*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
|
/*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
|
||||||
/*type=*/TensorType_UINT8);
|
/*type=*/TensorType_UINT8);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -235,48 +235,48 @@ TEST(UnpackOpTest, Uint8VectorToScalar) {
|
|||||||
TEST(UnpackOpTest, Int8ThreeOutputs) {
|
TEST(UnpackOpTest, Int8ThreeOutputs) {
|
||||||
Check<int8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
|
Check<int8_t>(/*axis=*/0, /*input_shape=*/{3, 2},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
/*expected_output_shape=*/{{2}, {2}, {2}},
|
/*exp_output_shape=*/{{2}, {2}, {2}},
|
||||||
/*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
|
/*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
|
||||||
/*type=*/TensorType_INT8);
|
/*type=*/TensorType_INT8);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnpackOpTest, Int8ThreeOutputsAxisOne) {
|
TEST(UnpackOpTest, Int8ThreeOutputsAxisOne) {
|
||||||
Check<int8_t>(/*axis=*/1, /*input_shape=*/{3, 2},
|
Check<int8_t>(/*axis=*/1, /*input_shape=*/{3, 2},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
/*expected_output_shape=*/{{3}, {3}},
|
/*exp_output_shape=*/{{3}, {3}},
|
||||||
/*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
|
/*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}},
|
||||||
/*type=*/TensorType_INT8);
|
/*type=*/TensorType_INT8);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnpackOpTest, Int8ThreeOutputsNegativeAxisOne) {
|
TEST(UnpackOpTest, Int8ThreeOutputsNegativeAxisOne) {
|
||||||
Check<int8_t>(/*axis=*/-1, /*input_shape=*/{3, 2},
|
Check<int8_t>(/*axis=*/-1, /*input_shape=*/{3, 2},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
/*expected_output_shape=*/{{3}, {3}},
|
/*exp_output_shape=*/{{3}, {3}},
|
||||||
/*expected_output_data=*/{{1, 3, 5}, {2, 4, 6}},
|
/*exp_output_data=*/{{1, 3, 5}, {2, 4, 6}},
|
||||||
/*type=*/TensorType_INT8);
|
/*type=*/TensorType_INT8);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnpackOpTest, Int8ThreeOutputsNegativeAxisTwo) {
|
TEST(UnpackOpTest, Int8ThreeOutputsNegativeAxisTwo) {
|
||||||
Check<int8_t>(/*axis=*/-2, /*input_shape=*/{3, 2},
|
Check<int8_t>(/*axis=*/-2, /*input_shape=*/{3, 2},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
/*expected_output_shape=*/{{2}, {2}, {2}},
|
/*exp_output_shape=*/{{2}, {2}, {2}},
|
||||||
/*expected_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
|
/*exp_output_data=*/{{1, 2}, {3, 4}, {5, 6}},
|
||||||
/*type=*/TensorType_INT8);
|
/*type=*/TensorType_INT8);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnpackOpTest, Int8OneOutput) {
|
TEST(UnpackOpTest, Int8OneOutput) {
|
||||||
Check<int8_t>(/*axis=*/0, /*input_shape=*/{1, 6},
|
Check<int8_t>(/*axis=*/0, /*input_shape=*/{1, 6},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
/*input_data=*/{1, 2, 3, 4, 5, 6},
|
||||||
/*expected_output_shape=*/{{6}},
|
/*exp_output_shape=*/{{6}},
|
||||||
/*expected_output_data=*/{{1, 2, 3, 4, 5, 6}},
|
/*exp_output_data=*/{{1, 2, 3, 4, 5, 6}},
|
||||||
/*type=*/TensorType_INT8);
|
/*type=*/TensorType_INT8);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnpackOpTest, Int8ThreeDimensionsOutputs) {
|
TEST(UnpackOpTest, Int8ThreeDimensionsOutputs) {
|
||||||
Check<int8_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
|
Check<int8_t>(/*axis=*/2, /*input_shape=*/{2, 2, 2},
|
||||||
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
|
/*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8},
|
||||||
/*expected_output_shape=*/{{2, 2}, {2, 2}},
|
/*exp_output_shape=*/{{2, 2}, {2, 2}},
|
||||||
/*expected_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
|
/*exp_output_data=*/{{1, 3, 5, 7}, {2, 4, 6, 8}},
|
||||||
/*type=*/TensorType_INT8);
|
/*type=*/TensorType_INT8);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -288,5 +288,69 @@ TEST(UnpackOpTest, Int8VectorToScalar) {
|
|||||||
/*type=*/TensorType_INT8);
|
/*type=*/TensorType_INT8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// bool tests.
|
||||||
|
TEST(UnpackOpTest, BoolThreeOutputs) {
|
||||||
|
Check<bool>(
|
||||||
|
/*axis=*/0, /*input_shape=*/{3, 2},
|
||||||
|
/*input_data=*/{true, false, true, false, true, false},
|
||||||
|
/*exp_output_shape=*/{{2}, {2}, {2}},
|
||||||
|
/*exp_output_data=*/{{true, false}, {true, false}, {true, false}},
|
||||||
|
/*type=*/TensorType_BOOL);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, BoolThreeOutputsAxisOne) {
|
||||||
|
Check<bool>(
|
||||||
|
/*axis=*/1, /*input_shape=*/{3, 2},
|
||||||
|
/*input_data=*/{true, false, true, false, true, false},
|
||||||
|
/*exp_output_shape=*/{{3}, {3}},
|
||||||
|
/*exp_output_data=*/{{true, true, true}, {false, false, false}},
|
||||||
|
/*type=*/TensorType_BOOL);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, BoolThreeOutputsNegativeAxisOne) {
|
||||||
|
Check<bool>(
|
||||||
|
/*axis=*/-1, /*input_shape=*/{3, 2},
|
||||||
|
/*input_data=*/{true, false, true, false, true, false},
|
||||||
|
/*exp_output_shape=*/{{3}, {3}},
|
||||||
|
/*exp_output_data=*/{{true, true, true}, {false, false, false}},
|
||||||
|
/*type=*/TensorType_BOOL);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, BoolThreeOutputsNegativeAxisTwo) {
|
||||||
|
Check<bool>(
|
||||||
|
/*axis=*/-2, /*input_shape=*/{3, 2},
|
||||||
|
/*input_data=*/{true, false, true, false, true, false},
|
||||||
|
/*exp_output_shape=*/{{2}, {2}, {2}},
|
||||||
|
/*exp_output_data=*/{{true, false}, {true, false}, {true, false}},
|
||||||
|
/*type=*/TensorType_BOOL);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, BoolOneOutput) {
|
||||||
|
Check<bool>(
|
||||||
|
/*axis=*/0, /*input_shape=*/{1, 6},
|
||||||
|
/*input_data=*/{true, false, true, false, true, false},
|
||||||
|
/*exp_output_shape=*/{{6}},
|
||||||
|
/*exp_output_data=*/{{true, false, true, false, true, false}},
|
||||||
|
/*type=*/TensorType_BOOL);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, BoolThreeDimensionsOutputs) {
|
||||||
|
Check<bool>(
|
||||||
|
/*axis=*/2, /*input_shape=*/{2, 2, 2},
|
||||||
|
/*input_data=*/{true, false, true, false, true, false, true, false},
|
||||||
|
/*exp_output_shape=*/{{2, 2}, {2, 2}},
|
||||||
|
/*exp_output_data=*/
|
||||||
|
{{true, true, true, true}, {false, false, false, false}},
|
||||||
|
/*type=*/TensorType_BOOL);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(UnpackOpTest, BoolVectorToScalar) {
|
||||||
|
Check<bool>(/*axis=*/0, /*input_shape=*/{5},
|
||||||
|
/*input_data=*/{true, false, true, false, true},
|
||||||
|
/*exp_output_shape=*/{{}, {}, {}, {}, {}},
|
||||||
|
/*exp_output_data=*/{{true}, {false}, {true}, {false}, {true}},
|
||||||
|
/*type=*/TensorType_BOOL);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -17,7 +17,6 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.lite.testing.zip_test_utils import create_tensor_data
|
from tensorflow.lite.testing.zip_test_utils import create_tensor_data
|
||||||
from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
|
from tensorflow.lite.testing.zip_test_utils import make_zip_of_tests
|
||||||
@ -31,6 +30,7 @@ def make_unpack_tests(options):
|
|||||||
test_parameters = [{
|
test_parameters = [{
|
||||||
"base_shape": [[3, 4, 3], [3, 4], [5, 6, 7, 8]],
|
"base_shape": [[3, 4, 3], [3, 4], [5, 6, 7, 8]],
|
||||||
"axis": [0, 1, 2, 3],
|
"axis": [0, 1, 2, 3],
|
||||||
|
"dtype": [tf.int32, tf.bool, tf.float32],
|
||||||
}]
|
}]
|
||||||
|
|
||||||
def get_valid_axis(parameters):
|
def get_valid_axis(parameters):
|
||||||
@ -43,12 +43,15 @@ def make_unpack_tests(options):
|
|||||||
|
|
||||||
def build_graph(parameters):
|
def build_graph(parameters):
|
||||||
input_tensor = tf.compat.v1.placeholder(
|
input_tensor = tf.compat.v1.placeholder(
|
||||||
dtype=tf.float32, name=("input"), shape=parameters["base_shape"])
|
dtype=parameters["dtype"],
|
||||||
|
name=("input"),
|
||||||
|
shape=parameters["base_shape"])
|
||||||
outs = tf.unstack(input_tensor, axis=get_valid_axis(parameters))
|
outs = tf.unstack(input_tensor, axis=get_valid_axis(parameters))
|
||||||
return [input_tensor], [outs[0]]
|
return [input_tensor], [outs[0]]
|
||||||
|
|
||||||
def build_inputs(parameters, sess, inputs, outputs):
|
def build_inputs(parameters, sess, inputs, outputs):
|
||||||
input_value = create_tensor_data(np.float32, shape=parameters["base_shape"])
|
input_value = create_tensor_data(
|
||||||
|
parameters["dtype"], shape=parameters["base_shape"])
|
||||||
return [input_value], sess.run(
|
return [input_value], sess.run(
|
||||||
outputs, feed_dict=dict(zip(inputs, [input_value])))
|
outputs, feed_dict=dict(zip(inputs, [input_value])))
|
||||||
|
|
||||||
|
@ -168,6 +168,8 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
|
|||||||
{{OperatorType::kOneHot, 1}, "1.11.0"},
|
{{OperatorType::kOneHot, 1}, "1.11.0"},
|
||||||
{{OperatorType::kCTCBeamSearchDecoder, 1}, "1.11.0"},
|
{{OperatorType::kCTCBeamSearchDecoder, 1}, "1.11.0"},
|
||||||
{{OperatorType::kUnpack, 1}, "1.11.0"},
|
{{OperatorType::kUnpack, 1}, "1.11.0"},
|
||||||
|
{{OperatorType::kUnpack, 2}, "1.14.0"},
|
||||||
|
{{OperatorType::kUnpack, 3}, kPendingReleaseOpVersion},
|
||||||
{{OperatorType::kLeakyRelu, 1}, "1.13.1"},
|
{{OperatorType::kLeakyRelu, 1}, "1.13.1"},
|
||||||
{{OperatorType::kLogistic, 1}, "1.14.0"},
|
{{OperatorType::kLogistic, 1}, "1.14.0"},
|
||||||
{{OperatorType::kLogistic, 2}, "1.14.0"},
|
{{OperatorType::kLogistic, 2}, "1.14.0"},
|
||||||
|
@ -1349,6 +1349,21 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
|
|||||||
op->num = options.num();
|
op->num = options.num();
|
||||||
op->axis = options.axis();
|
op->axis = options.axis();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int GetVersion(const OperatorSignature& op_signature) const override {
|
||||||
|
const string& input_name = op_signature.op->inputs[0];
|
||||||
|
const Array& input_array = op_signature.model->GetArray(input_name);
|
||||||
|
// If the op take int8/uint8 input, it is version 2.
|
||||||
|
if (input_array.data_type == ArrayDataType::kInt8 ||
|
||||||
|
input_array.data_type == ArrayDataType::kUint8) {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
// If the op take bool input, it is version 3.
|
||||||
|
if (input_array.data_type == ArrayDataType::kBool) {
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class LeakyRelu
|
class LeakyRelu
|
||||||
|
@ -219,6 +219,10 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
|
|||||||
op_sig.input_types.at(0) == TensorType_UINT8) {
|
op_sig.input_types.at(0) == TensorType_UINT8) {
|
||||||
return 2;
|
return 2;
|
||||||
}
|
}
|
||||||
|
// If the op take bool input, it is version 3.
|
||||||
|
if (op_sig.input_types.at(0) == TensorType_BOOL) {
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
return 1;
|
return 1;
|
||||||
|
|
||||||
case BuiltinOperator_DEQUANTIZE:
|
case BuiltinOperator_DEQUANTIZE:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user