Adds support for half_pixel_centers in TFLite's resize nearest neighbor op.

PiperOrigin-RevId: 310035959
Change-Id: I83238d568e2b4ebe0d844847d72c274e89651faf
This commit is contained in:
Sachin Joglekar 2020-05-05 15:43:48 -07:00 committed by TensorFlower Gardener
parent 24ef7cae6d
commit 48baba71cf
24 changed files with 205 additions and 37 deletions

View File

@ -2999,7 +2999,8 @@ def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor",
let arguments = (ins
TFL_TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$input,
TFL_TensorOf<[I32]>:$size,
BoolAttr:$align_corners
BoolAttr:$align_corners,
DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
);
let results = (outs

View File

@ -1213,15 +1213,14 @@ func @resize_nearest_neighbor(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi3
%0 = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: resize_nearest_neighbor
// CHECK: "tfl.resize_nearest_neighbor"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
// CHECK: "tfl.resize_nearest_neighbor"(%arg0, %arg1) {align_corners = true, half_pixel_centers = false} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
}
// Note: half_pixel_centers isn't supported by TFLite, so it's not legalized.
func @resize_nearest_neighbor_with_half_pixel_centers(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
%0 = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = true, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
%0 = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = false, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
// CHECK-LABEL: resize_nearest_neighbor_with_half_pixel_centers
// CHECK: "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = true, half_pixel_centers = true}
// CHECK: "tfl.resize_nearest_neighbor"(%arg0, %arg1) {align_corners = false, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
}
func @sparse_to_dense_with_scalar_sparse_indices(%arg0: tensor<i32>, %arg1: tensor<3xi32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<?x?x?xf32> {

View File

@ -192,7 +192,7 @@ func @testSquare(tensor<? x f32>) -> tensor<? x f32> {
func @testQuantizedResizeNearestNeighbor(tensor<? x !quant.uniform<u8:f32, 0.1>>, tensor<? x i32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>> {
^bb0(%arg0: tensor<? x !quant.uniform<u8:f32, 0.1>>, %arg1: tensor<? x i32>):
%0 = "tfl.resize_nearest_neighbor"(%arg0, %arg1) { align_corners = false } : (tensor<? x !quant.uniform<u8:f32, 0.1>>, tensor<? x i32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>>
%0 = "tfl.resize_nearest_neighbor"(%arg0, %arg1) { align_corners = false, half_pixel_centers = false } : (tensor<? x !quant.uniform<u8:f32, 0.1>>, tensor<? x i32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>>
return %0 : tensor<? x !quant.uniform<u8:f32, 0.1>>
}

View File

@ -302,7 +302,7 @@ def : Pat<(TF_DepthToSpaceOp $input, $block_size, IsDataFormatNHWC:$data_format)
(TFL_DepthToSpaceOp $input, (convertIntAttrTo32Bit $block_size))>;
def : Pat<(TF_ResizeBilinearOp $images, $size, $align_corners, $half_pixel_centers), (TFL_ResizeBilinearOp $images, $size, $align_corners, $half_pixel_centers)>;
def : Pat<(TF_ResizeNearestNeighborOp $images, $size, $align_corners, ConstBoolAttrFalse:$half_pixel_centers), (TFL_ResizeNearestNeighborOp $images, $size, $align_corners)>;
def : Pat<(TF_ResizeNearestNeighborOp $images, $size, $align_corners, $half_pixel_centers), (TFL_ResizeNearestNeighborOp $images, $size, $align_corners, $half_pixel_centers)>;
def : Pat<(TF_MirrorPadOp $arg0, $arg1, $cst), (TFL_MirrorPadOp $arg0, $arg1, $cst)>;

View File

@ -326,6 +326,7 @@ def generated_test_models():
"relu6",
"reshape",
"resize_bilinear",
"resize_nearest_neighbor",
"resolve_constant_strided_slice",
"reverse_sequence",
"reverse_v2",

View File

@ -297,6 +297,7 @@ typedef struct {
typedef struct {
bool align_corners;
bool half_pixel_centers;
} TfLiteResizeNearestNeighborParams;
typedef struct {

View File

@ -536,6 +536,10 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
if (const auto* schema_params =
op->builtin_options_as_ResizeNearestNeighborOptions()) {
params->align_corners = schema_params->align_corners();
params->half_pixel_centers = schema_params->half_pixel_centers();
} else {
params->align_corners = false;
params->half_pixel_centers = false;
}
*builtin_data = params.release();
return kTfLiteOk;

View File

@ -304,6 +304,9 @@ VariedShapeSpec/ReshapeOpTest/WithStretchDimension/1
ResizeBilinearOpTest/ResizeBilinearOpTest/.+/0,29
# resize_nearest_neighbor_test
// align_corners & half_pixel_centers are not implemented in NNAPI.
-ResizeNearestNeighborOpTest/ResizeNearestNeighborOpTest.+AlignCorners.*,29
-ResizeNearestNeighborOpTest/ResizeNearestNeighborOpTest.+HalfPixelCenters.*,29
// Only models with constant size tensor are accelerated
ResizeNearestNeighborOpTest/ResizeNearestNeighborOpTest/.+/0,29

View File

@ -1668,9 +1668,14 @@ bool NNAPIDelegateKernel::Validate(
ExpectIsFloatOrQuant8Operator(context, node, &val_ctx);
auto builtin = reinterpret_cast<TfLiteResizeNearestNeighborParams*>(
node->builtin_data);
// TODO(b/149823713): Update when NNAPI delegate can support align_corners
// & half_pixel_centers.
Expect(!builtin->align_corners,
NNAPIValidationFailureType::kUnsupportedOperandValue,
"NNAPI does not support align_corners == true.", &val_ctx);
Expect(!builtin->half_pixel_centers,
NNAPIValidationFailureType::kUnsupportedOperandValue,
"NNAPI does not support half_pixel_centers == true.", &val_ctx);
} break;
case kTfLiteBuiltinSqueeze: {
ExpectOpVersion(version, 1, &val_ctx);

View File

@ -121,7 +121,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
Register_RESIZE_NEAREST_NEIGHBOR(),
/* min_version = */ 1,
/* max_version = */ 2);
/* max_version = */ 3);
AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM());
AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH(),
/* min_version = */ 1,

View File

@ -89,7 +89,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::ResizeNearestNeighborParams op_params;
op_params.align_corners = params->align_corners;
op_params.half_pixel_centers = false;
op_params.half_pixel_centers = params->half_pixel_centers;
if (output->type == kTfLiteFloat32) {
reference_ops::ResizeNearestNeighbor(

View File

@ -33,7 +33,9 @@ class ResizeNearestNeighborOpModel : public SingleOpModel {
public:
explicit ResizeNearestNeighborOpModel(const TensorData& input,
std::initializer_list<int> size_data,
TestType test_type) {
TestType test_type,
bool align_corners = false,
bool half_pixel_centers = false) {
bool const_size = (test_type == TestType::kConst);
input_ = AddInput(input);
@ -45,7 +47,10 @@ class ResizeNearestNeighborOpModel : public SingleOpModel {
output_ = AddOutput(input.type);
SetBuiltinOp(BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
BuiltinOptions_ResizeNearestNeighborOptions,
CreateResizeNearestNeighborOptions(builder_).Union());
CreateResizeNearestNeighborOptions(
builder_, /*align_corners*/ align_corners,
/*half_pixel_centers*/ half_pixel_centers)
.Union());
if (const_size) {
BuildInterpreter({GetShape(input_)});
} else {
@ -182,6 +187,47 @@ TEST_P(ResizeNearestNeighborOpTest, TwoDimensionalResizeWithTwoBatches) {
10, 10, 16, //
})));
}
TEST_P(ResizeNearestNeighborOpTest,
TwoDimensionalResizeWithTwoBatches_AlignCorners) {
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3},
GetParam(), /**align_corners**/ true);
m.SetInput<float>({
3, 6, //
9, 12, //
4, 10, //
10, 16 //
});
m.Invoke();
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
3, 6, 6, //
9, 12, 12, //
9, 12, 12, //
4, 10, 10, //
10, 16, 16, //
10, 16, 16, //
})));
}
TEST_P(ResizeNearestNeighborOpTest,
TwoDimensionalResizeWithTwoBatches_HalfPixelCenters) {
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3},
GetParam(), /**align_corners**/ false,
/**half_pixel_centers**/ true);
m.SetInput<float>({
3, 6, //
9, 12, //
4, 10, //
10, 16 //
});
m.Invoke();
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
3, 6, 6, //
9, 12, 12, //
9, 12, 12, //
4, 10, 10, //
10, 16, 16, //
10, 16, 16, //
})));
}
TEST_P(ResizeNearestNeighborOpTest, ThreeDimensionalResize) {
ResizeNearestNeighborOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3},
GetParam());
@ -248,6 +294,36 @@ TEST_P(ResizeNearestNeighborOpTest, ThreeDimensionalResizeUInt8) {
10, 12, 10, 12, 14, 16, //
})));
}
TEST_P(ResizeNearestNeighborOpTest, ThreeDimensionalResizeUInt8_AlignCorners) {
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3},
GetParam(), /**align_corners**/ true);
m.SetInput<uint8>({
3, 4, 6, 10, //
10, 12, 14, 16, //
});
m.Invoke();
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
3, 4, 6, 10, 6, 10, //
10, 12, 14, 16, 14, 16, //
10, 12, 14, 16, 14, 16, //
})));
}
TEST_P(ResizeNearestNeighborOpTest,
ThreeDimensionalResizeUInt8_HalfPixelCenters) {
ResizeNearestNeighborOpModel m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3},
GetParam(), /**align_corners**/ false,
/**half_pixel_centers**/ true);
m.SetInput<uint8>({
3, 4, 6, 10, //
10, 12, 14, 16, //
});
m.Invoke();
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
3, 4, 6, 10, 6, 10, //
10, 12, 14, 16, 14, 16, //
10, 12, 14, 16, 14, 16, //
})));
}
TEST_P(ResizeNearestNeighborOpTest, ThreeDimensionalResizeInt8) {
ResizeNearestNeighborOpModel m({TensorType_INT8, {1, 2, 2, 2}}, {3, 3},
GetParam());

View File

@ -662,6 +662,7 @@ table ResizeBilinearOptions {
table ResizeNearestNeighborOptions {
align_corners: bool;
half_pixel_centers: bool;
}
// A call operation options

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
// automatically generated by the FlatBuffers compiler, do not modify
#ifndef FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_
#define FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_
@ -5375,22 +5376,29 @@ flatbuffers::Offset<ResizeBilinearOptions> CreateResizeBilinearOptions(flatbuffe
struct ResizeNearestNeighborOptionsT : public flatbuffers::NativeTable {
typedef ResizeNearestNeighborOptions TableType;
bool align_corners;
bool half_pixel_centers;
ResizeNearestNeighborOptionsT()
: align_corners(false) {
: align_corners(false),
half_pixel_centers(false) {
}
};
struct ResizeNearestNeighborOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef ResizeNearestNeighborOptionsT NativeTableType;
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
VT_ALIGN_CORNERS = 4
VT_ALIGN_CORNERS = 4,
VT_HALF_PIXEL_CENTERS = 6
};
bool align_corners() const {
return GetField<uint8_t>(VT_ALIGN_CORNERS, 0) != 0;
}
bool half_pixel_centers() const {
return GetField<uint8_t>(VT_HALF_PIXEL_CENTERS, 0) != 0;
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<uint8_t>(verifier, VT_ALIGN_CORNERS) &&
VerifyField<uint8_t>(verifier, VT_HALF_PIXEL_CENTERS) &&
verifier.EndTable();
}
ResizeNearestNeighborOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@ -5404,6 +5412,9 @@ struct ResizeNearestNeighborOptionsBuilder {
void add_align_corners(bool align_corners) {
fbb_.AddElement<uint8_t>(ResizeNearestNeighborOptions::VT_ALIGN_CORNERS, static_cast<uint8_t>(align_corners), 0);
}
void add_half_pixel_centers(bool half_pixel_centers) {
fbb_.AddElement<uint8_t>(ResizeNearestNeighborOptions::VT_HALF_PIXEL_CENTERS, static_cast<uint8_t>(half_pixel_centers), 0);
}
explicit ResizeNearestNeighborOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@ -5418,8 +5429,10 @@ struct ResizeNearestNeighborOptionsBuilder {
inline flatbuffers::Offset<ResizeNearestNeighborOptions> CreateResizeNearestNeighborOptions(
flatbuffers::FlatBufferBuilder &_fbb,
bool align_corners = false) {
bool align_corners = false,
bool half_pixel_centers = false) {
ResizeNearestNeighborOptionsBuilder builder_(_fbb);
builder_.add_half_pixel_centers(half_pixel_centers);
builder_.add_align_corners(align_corners);
return builder_.Finish();
}
@ -11634,6 +11647,7 @@ inline void ResizeNearestNeighborOptions::UnPackTo(ResizeNearestNeighborOptionsT
(void)_o;
(void)_resolver;
{ auto _e = align_corners(); _o->align_corners = _e; }
{ auto _e = half_pixel_centers(); _o->half_pixel_centers = _e; }
}
inline flatbuffers::Offset<ResizeNearestNeighborOptions> ResizeNearestNeighborOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ResizeNearestNeighborOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@ -11645,9 +11659,11 @@ inline flatbuffers::Offset<ResizeNearestNeighborOptions> CreateResizeNearestNeig
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ResizeNearestNeighborOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _align_corners = _o->align_corners;
auto _half_pixel_centers = _o->half_pixel_centers;
return tflite::CreateResizeNearestNeighborOptions(
_fbb,
_align_corners);
_align_corners,
_half_pixel_centers);
}
inline CallOptionsT *CallOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {

View File

@ -31,35 +31,35 @@ def make_resize_bilinear_tests(options):
"dtype": [tf.float32, tf.int32],
"input_shape": [[1, 3, 4, 3], [1, 10, 2, 1]],
"size": [[1, 1], [4, 3], [2, 2], [5, 6]],
"align_corners": [None, True, False],
"align_corners": [True, False],
"half_pixel_centers": [False],
"fully_quantize": [False]
}, {
"dtype": [tf.float32],
"input_shape": [[1, 3, 4, 3], [1, 10, 2, 1]],
"size": [[1, 1], [4, 3], [2, 2], [5, 6]],
"align_corners": [None, True, False],
"align_corners": [True, False],
"half_pixel_centers": [False],
"fully_quantize": [True]
}, {
"dtype": [tf.float32],
"input_shape": [[1, 16, 24, 3], [1, 12, 18, 3]],
"size": [[8, 12], [12, 18]],
"align_corners": [None, True, False],
"align_corners": [True, False],
"half_pixel_centers": [False],
"fully_quantize": [True]
}, {
"dtype": [tf.float32],
"input_shape": [[1, 16, 24, 3], [1, 12, 18, 3]],
"size": [[8, 12]],
"align_corners": [None, False],
"align_corners": [False],
"half_pixel_centers": [True],
"fully_quantize": [True]
}, {
"dtype": [tf.float32, tf.int32],
"input_shape": [[1, 3, 4, 3], [1, 10, 2, 1]],
"size": [[1, 1], [4, 3], [2, 2], [5, 6]],
"align_corners": [None, False],
"align_corners": [False],
"half_pixel_centers": [True],
"fully_quantize": [False]
}]

View File

@ -28,23 +28,26 @@ def make_resize_nearest_neighbor_tests(options):
"""Make a set of tests to do resize_nearest_neighbor."""
test_parameters = [{
"dtype": [tf.float32, tf.int32],
"input_shape": [[1, 3, 4, 3], [1, 10, 2, 1]],
"size": [[1, 1], [4, 3], [2, 2], [5, 6]],
"align_corners": [False],
"fully_quantize": [False],
}, {
"dtype": [tf.float32],
"input_shape": [[1, 3, 4, 3], [1, 10, 2, 1]],
"size": [[1, 1], [4, 3], [2, 2], [5, 6]],
"align_corners": [False],
"fully_quantize": [True],
"half_pixel_centers": [False],
"fully_quantize": [True, False],
}, {
"dtype": [tf.float32],
"input_shape": [[1, 16, 24, 3], [1, 12, 18, 3]],
"size": [[8, 12], [12, 18]],
"align_corners": [None, True, False],
"fully_quantize": [True]
"align_corners": [True],
"half_pixel_centers": [False],
"fully_quantize": [True, False]
}, {
"dtype": [tf.float32],
"input_shape": [[1, 16, 24, 3], [1, 12, 18, 3]],
"size": [[8, 12], [12, 18]],
"align_corners": [False],
"half_pixel_centers": [True],
"fully_quantize": [True, False]
}]
def build_graph(parameters):
@ -55,7 +58,8 @@ def make_resize_nearest_neighbor_tests(options):
out = tf.image.resize_nearest_neighbor(
input_tensor,
size=parameters["size"],
align_corners=parameters["align_corners"])
align_corners=parameters["align_corners"],
half_pixel_centers=parameters["half_pixel_centers"])
return [input_tensor], [out]
def build_inputs(parameters, sess, inputs, outputs):

View File

@ -1417,6 +1417,21 @@ void ConvertResizeBilinearOperator(const Model& model,
src_op.half_pixel_centers);
}
void ConvertResizeNearestNeighborOperator(
const Model& model, const ResizeNearestNeighborOperator& src_op,
GraphDef* tensorflow_graph) {
tensorflow::NodeDef* resize_op = tensorflow_graph->add_node();
resize_op->set_op("ResizeNearestNeighbor");
resize_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
*resize_op->add_input() = src_op.inputs[0];
*resize_op->add_input() = src_op.inputs[1];
(*resize_op->mutable_attr())["T"].set_type(DT_FLOAT);
(*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners);
(*resize_op->mutable_attr())["half_pixel_centers"].set_b(
src_op.half_pixel_centers);
}
void ConvertOneHotOperator(const Model& model, const OneHotOperator& src_op,
GraphDef* tensorflow_graph) {
tensorflow::NodeDef* onehot_op = tensorflow_graph->add_node();
@ -2227,6 +2242,10 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertResizeBilinearOperator(
model, static_cast<const ResizeBilinearOperator&>(src_op),
tensorflow_graph);
} else if (src_op.type == OperatorType::kResizeNearestNeighbor) {
ConvertResizeNearestNeighborOperator(
model, static_cast<const ResizeNearestNeighborOperator&>(src_op),
tensorflow_graph);
} else if (src_op.type == OperatorType::kSpaceToBatchND) {
ConvertSpaceToBatchNDOperator(
model, static_cast<const SpaceToBatchNDOperator&>(src_op),

View File

@ -1731,9 +1731,13 @@ tensorflow::Status ConvertResizeNearestNeighborOperator(
auto* op = new ResizeNearestNeighborOperator;
op->align_corners = false;
op->half_pixel_centers = false;
if (HasAttr(node, "align_corners")) {
op->align_corners = GetBoolAttr(node, "align_corners");
}
if (HasAttr(node, "half_pixel_centers")) {
op->half_pixel_centers = GetBoolAttr(node, "half_pixel_centers");
}
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));

View File

@ -1861,6 +1861,7 @@ struct ResizeNearestNeighborOperator : Operator {
: Operator(OperatorType::kResizeNearestNeighbor) {}
bool align_corners = false;
bool half_pixel_centers = false;
};
// SpaceToBatchND operator. It divides spatial dimensions into a grid of

View File

@ -133,6 +133,7 @@ string GetMinimumRuntimeVersionForModel(const Model& model) {
{{OperatorType::kResizeBilinear, 3}, "2.2.0"},
{{OperatorType::kResizeNearestNeighbor, 1}, "1.13.1"},
{{OperatorType::kResizeNearestNeighbor, 2}, "1.14.0"},
{{OperatorType::kResizeNearestNeighbor, 3}, kPendingReleaseOpVersion},
{{OperatorType::kSqueeze, 1}, "1.6.0"},
{{OperatorType::kSplit, 1}, "1.5.0"},
{{OperatorType::kSplit, 2}, "1.14.0"},

View File

@ -1116,7 +1116,7 @@ class ResizeBilinear
static_cast<const ResizeBilinearOperator&>(*op_signature.op);
::tflite::OpSignature op_sig =
GetVersioningOpSig(builtin_op(), op_signature);
op_sig.options.resize_bilinear.half_pixel_centers =
op_sig.options.resize.half_pixel_centers =
resize_bilinear_op.half_pixel_centers;
return ::tflite::GetBuiltinOperatorVersion(op_sig);
}
@ -1131,13 +1131,23 @@ class ResizeNearestNeighbor
flatbuffers::Offset<TfLiteOptions> WriteOptions(
const TocoOperator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
return ::tflite::CreateResizeNearestNeighborOptions(*builder,
op.align_corners);
return ::tflite::CreateResizeNearestNeighborOptions(
*builder, op.align_corners, op.half_pixel_centers);
}
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {
op->align_corners = options.align_corners();
op->half_pixel_centers = options.half_pixel_centers();
}
int GetVersion(const OperatorSignature& op_signature) const override {
const auto& resize_nn_op =
static_cast<const ResizeNearestNeighborOperator&>(*op_signature.op);
::tflite::OpSignature op_sig =
GetVersioningOpSig(builtin_op(), op_signature);
op_sig.options.resize.half_pixel_centers = resize_nn_op.half_pixel_centers;
return ::tflite::GetBuiltinOperatorVersion(op_sig);
}
};

View File

@ -436,11 +436,25 @@ TEST_F(OperatorTest, ResizeBilinear_HalfPixelCenters) {
TEST_F(OperatorTest, ResizeNearestNeighbor) {
ResizeNearestNeighborOperator op;
op.align_corners = true;
op.half_pixel_centers = false;
auto output_toco_op =
SerializeAndDeserialize(GetOperator("RESIZE_NEAREST_NEIGHBOR",
OperatorType::kResizeNearestNeighbor),
op);
EXPECT_EQ(op.align_corners, output_toco_op->align_corners);
EXPECT_EQ(op.half_pixel_centers, output_toco_op->half_pixel_centers);
}
TEST_F(OperatorTest, ResizeNearestNeighbor_HalfPixelCenters) {
ResizeNearestNeighborOperator op;
op.align_corners = true;
op.half_pixel_centers = true;
auto output_toco_op =
SerializeAndDeserialize(GetOperator("RESIZE_NEAREST_NEIGHBOR",
OperatorType::kResizeNearestNeighbor),
op);
EXPECT_EQ(op.align_corners, output_toco_op->align_corners);
EXPECT_EQ(op.half_pixel_centers, output_toco_op->half_pixel_centers);
}
TEST_F(OperatorTest, Svdf) {

View File

@ -332,7 +332,8 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
}
return 1;
case BuiltinOperator_RESIZE_BILINEAR:
if (op_sig.options.resize_bilinear.half_pixel_centers) {
case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
if (op_sig.options.resize.half_pixel_centers) {
return 3;
} else if (op_sig.input_types.at(0) == TensorType_INT8) {
return 2;
@ -438,7 +439,6 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
case BuiltinOperator_REDUCE_MAX:
case BuiltinOperator_REDUCE_MIN:
case BuiltinOperator_RELU6:
case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
case BuiltinOperator_TANH:
case BuiltinOperator_LOGISTIC:
case BuiltinOperator_LOG_SOFTMAX:
@ -554,10 +554,18 @@ OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
auto resize_bilinear_option =
op->builtin_options_as_ResizeBilinearOptions();
if (resize_bilinear_option) {
op_sig.options.resize_bilinear.half_pixel_centers =
op_sig.options.resize.half_pixel_centers =
resize_bilinear_option->half_pixel_centers();
}
} break;
case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: {
auto resize_nn_option =
op->builtin_options_as_ResizeNearestNeighborOptions();
if (resize_nn_option) {
op_sig.options.resize.half_pixel_centers =
resize_nn_option->half_pixel_centers();
}
} break;
// TODO(b/150176627): Add tests for GetOpSignature.
case BuiltinOperator_STRIDED_SLICE:
case BuiltinOperator_SPACE_TO_BATCH_ND:

View File

@ -48,7 +48,7 @@ typedef struct {
} lstm;
struct {
bool half_pixel_centers;
} resize_bilinear;
} resize;
struct {
int32_t num_dims;
} single_input_op;