Make TFLite BatchToSpaceND op have parity with TF BatchToSpaceND op.
PiperOrigin-RevId: 183687487
commit 730071d0dc (parent 1f26c65254)
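Note: with this change the TFLite op reads block_shape and crops from its second and third input tensors (matching tf.batch_to_space_nd) instead of from TfLiteBatchToSpaceNDParams. Below is a minimal illustrative sketch, not part of the commit, of the output-shape rule that ResizeOutputTensor in the kernel applies, using the values from the unit tests; crops are not yet folded into the shape (see the TODO in the kernel).

// Illustrative sketch only: shape rule for a 4D NHWC input, mirroring
// ResizeOutputTensor in the kernel diff below.
#include <cassert>

int main() {
  const int input_shape[4] = {4, 2, 2, 1};  // [batch, height, width, depth]
  const int block_shape[2] = {2, 2};        // supplied as a 1-D int32 tensor
  // crops = {{0, 0}, {0, 0}} is supplied as a 2x2 int32 tensor but is not
  // yet applied to the output shape.
  const int output_batch = input_shape[0] / (block_shape[0] * block_shape[1]);
  const int output_height = input_shape[1] * block_shape[0];
  const int output_width = input_shape[2] * block_shape[1];
  const int output_depth = input_shape[3];
  // Matches the [1, 4, 4, 1] shape expected by SimpleConstTest and
  // SimpleDynamicTest.
  assert(output_batch == 1 && output_height == 4 && output_width == 4 &&
         output_depth == 1);
  return 0;
}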
@@ -127,14 +127,6 @@ typedef struct {
} TfLiteSpaceToBatchNDParams;

typedef struct {
  // Number of spatial dimensions.
  // For now only NHWC is supported, and the value should always be 2.
  int num_spatial_dimensions;
  // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
  // For now we will fix the maximum possible number of dimensions.
  int block_shape[2];
  int before_crops[2];
  int after_crops[2];
} TfLiteBatchToSpaceNDParams;

typedef struct {
@@ -35,12 +35,14 @@ enum KernelType {

struct BatchToSpaceNDContext {
  BatchToSpaceNDContext(TfLiteContext* context, TfLiteNode* node) {
    params = reinterpret_cast<TfLiteBatchToSpaceNDParams*>(node->builtin_data);
    input = GetInput(context, node, 0);
    block_shape = GetInput(context, node, 1);
    crops = GetInput(context, node, 2);
    output = GetOutput(context, node, 0);
  }
  TfLiteBatchToSpaceNDParams* params;
  TfLiteTensor* input;
  TfLiteTensor* block_shape;
  TfLiteTensor* crops;
  TfLiteTensor* output;
};

@@ -48,24 +50,22 @@ struct BatchToSpaceNDContext {
// The 4D array needs to have exactly 2 spatial dimensions.
// TODO(ycling): Support arbitrary dimension in BatchToSpaceND.
const int kInputDimensionNum = 4;
const int kOutputDimensionNum = 4;
const int kBlockSizeDimensionNum = 1;
const int kSpatialDimensionNum = 2;

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // The 2nd tensor (block_shape) and the 3rd tensor (crops) are ignored now.
  TF_LITE_ENSURE(context, NumInputs(node) >= 1 && NumInputs(node) <= 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                BatchToSpaceNDContext* op_context) {
  TfLiteIntArray* input_size = op_context->input->dims;
  const int* block_shape = GetTensorData<int32>(op_context->block_shape);

  BatchToSpaceNDContext op_context(context, node);
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input),
                    kInputDimensionNum);
  TF_LITE_ENSURE_EQ(context, op_context.params->num_spatial_dimensions,
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape),
                    kBlockSizeDimensionNum);
  TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0],
                    kSpatialDimensionNum);
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops),
                    kSpatialDimensionNum);
  TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);

  const TfLiteIntArray* input_size = op_context.input->dims;
  const int* block_shape = op_context.params->block_shape;

  // TODO(ycling): Add crops as part of calculation.
  // Number of batch must be multiple of (block_shape[0] * block_shape[1]).
  TF_LITE_ENSURE_EQ(context,
                    input_size->data[0] % (block_shape[0] * block_shape[1]), 0);
@@ -76,27 +76,48 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const int output_width = input_size->data[2] * block_shape[1];
  const int output_channel_size = input_size->data[3];

  TfLiteIntArray* output_size = TfLiteIntArrayCreate(kOutputDimensionNum);
  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
  output_size->data[0] = output_batch_size;
  output_size->data[1] = output_height;
  output_size->data[2] = output_width;
  output_size->data[3] = output_channel_size;

  return context->ResizeTensor(context, op_context.output, output_size);
  return context->ResizeTensor(context, op_context->output, output_size);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  BatchToSpaceNDContext op_context(context, node);
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input),
                    kInputDimensionNum);
  TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);

  if (!IsConstantTensor(op_context.block_shape) ||
      !IsConstantTensor(op_context.crops)) {
    SetTensorToDynamic(op_context.output);
    return kTfLiteOk;
  }
  return ResizeOutputTensor(context, &op_context);
}

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  BatchToSpaceNDContext op_context(context, node);

  int block_shape_dims_array[1] = {kSpatialDimensionNum};
  Dims<4> block_shape_dims = GetTensorDims(block_shape_dims_array, 1);
  // Resize the output tensor if the output tensor is dynamic.
  if (IsDynamicTensor(op_context.output)) {
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
    TfLiteTensorRealloc(op_context.output->bytes, op_context.output);
  }

#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar)                           \
  type::BatchToSpaceND(GetTensorData<scalar>(op_context.input),           \
                       GetTensorDims(op_context.input),                   \
                       op_context.params->block_shape, block_shape_dims,  \
                       GetTensorData<scalar>(op_context.output),          \
#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar)                            \
  type::BatchToSpaceND(GetTensorData<scalar>(op_context.input),            \
                       GetTensorDims(op_context.input),                    \
                       GetTensorData<int32_t>(op_context.block_shape),     \
                       GetTensorDims(op_context.block_shape),              \
                       GetTensorData<scalar>(op_context.output),           \
                       GetTensorDims(op_context.output))
  switch (op_context.input->type) {  // Already know in/out types are same.
    case kTfLiteFloat32:
@@ -26,36 +26,76 @@ using ::testing::ElementsAreArray;

class BatchToSpaceNDOpModel : public SingleOpModel {
 public:
  BatchToSpaceNDOpModel(std::initializer_list<int> input_shape,
                        std::initializer_list<int> block_shape,
                        std::initializer_list<int> before_crops,
                        std::initializer_list<int> after_crops) {
    input_ = AddInput(TensorType_FLOAT32);
    output_ = AddOutput(TensorType_FLOAT32);
    SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND,
                 BuiltinOptions_BatchToSpaceNDOptions,
                 CreateBatchToSpaceNDOptions(
                     builder_, builder_.CreateVector<int>(block_shape),
                     builder_.CreateVector<int>(before_crops),
                     builder_.CreateVector<int>(after_crops))
                     .Union());
    BuildInterpreter({input_shape});
  }

  void SetInput(std::initializer_list<float> data) {
    PopulateTensor<float>(input_, data);
  }

  void SetBlockShape(std::initializer_list<int> data) {
    PopulateTensor<int>(block_shape_, data);
  }

  void SetCrops(std::initializer_list<int> data) {
    PopulateTensor<int>(crops_, data);
  }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }

 private:
 protected:
  int input_;
  int block_shape_;
  int crops_;
  int output_;
};

TEST(BatchToSpaceNDOpTest, SimpleTest) {
  BatchToSpaceNDOpModel m({4, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0});
// Tests case where block_shape and crops are const tensors.
//
// Example usage is as follows:
//    BatchToSpaceNDOpConstModel m(input_shape, block_shape, crops);
//    m.SetInput(input_data);
//    m.Invoke();
class BatchToSpaceNDOpConstModel : public BatchToSpaceNDOpModel {
 public:
  BatchToSpaceNDOpConstModel(std::initializer_list<int> input_shape,
                             std::initializer_list<int> block_shape,
                             std::initializer_list<int> crops) {
    input_ = AddInput(TensorType_FLOAT32);
    block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2});
    crops_ = AddConstInput(TensorType_INT32, crops, {2, 2});
    output_ = AddOutput(TensorType_FLOAT32);

    SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND,
                 BuiltinOptions_BatchToSpaceNDOptions,
                 CreateBatchToSpaceNDOptions(builder_).Union());
    BuildInterpreter({input_shape});
  }
};

// Tests case where block_shape and crops are non-const tensors.
//
// Example usage is as follows:
//    BatchToSpaceNDOpDynamicModel m(input_shape);
//    m.SetInput(input_data);
//    m.SetBlockShape(block_shape);
//    m.SetCrops(crops);
//    m.Invoke();
class BatchToSpaceNDOpDynamicModel : public BatchToSpaceNDOpModel {
 public:
  BatchToSpaceNDOpDynamicModel(std::initializer_list<int> input_shape) {
    input_ = AddInput(TensorType_FLOAT32);
    block_shape_ = AddInput(TensorType_INT32);
    crops_ = AddInput(TensorType_INT32);
    output_ = AddOutput(TensorType_FLOAT32);

    SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND,
                 BuiltinOptions_BatchToSpaceNDOptions,
                 CreateBatchToSpaceNDOptions(builder_).Union());
    BuildInterpreter({input_shape, {2}, {2, 2}});
  }
};

TEST(BatchToSpaceNDOpTest, SimpleConstTest) {
  BatchToSpaceNDOpConstModel m({4, 2, 2, 1}, {2, 2}, {0, 0, 0, 0});
  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
@@ -63,8 +103,19 @@ TEST(BatchToSpaceNDOpTest, SimpleTest) {
                                               4, 8, 11, 15, 12, 16}));
}

TEST(BatchToSpaceNDOpTest, SimpleDynamicTest) {
  BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1});
  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
  m.SetBlockShape({2, 2});
  m.SetCrops({0, 0, 0, 0});
  m.Invoke();
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7,
                                               4, 8, 11, 15, 12, 16}));
}

TEST(BatchToSpaceNDOpTest, InvalidShapeTest) {
  EXPECT_DEATH(BatchToSpaceNDOpModel({3, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0}),
  EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, 0}),
               "Cannot allocate tensors");
}
@@ -536,21 +536,6 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
      break;
    }
    case BuiltinOperator_BATCH_TO_SPACE_ND: {
      auto* params = MallocPOD<TfLiteBatchToSpaceNDParams>();
      if (auto* schema_params =
              op->builtin_options_as_BatchToSpaceNDOptions()) {
        const auto& block_shape = schema_params->block_shape();
        FlatBufferIntVectorToArray(sizeof(params->block_shape), block_shape,
                                   params->block_shape, error_reporter);
        const auto& before_crops = schema_params->before_crops();
        FlatBufferIntVectorToArray(sizeof(params->before_crops), before_crops,
                                   params->before_crops, error_reporter);
        const auto& after_crops = schema_params->after_crops();
        FlatBufferIntVectorToArray(sizeof(params->after_crops), after_crops,
                                   params->after_crops, error_reporter);
        params->num_spatial_dimensions = block_shape->Length();
      }
      builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_TRANSPOSE: {
@@ -295,9 +295,6 @@ table SpaceToBatchNDOptions {
}

table BatchToSpaceNDOptions {
  block_shape:[int];
  before_crops:[int];
  after_crops:[int];
}

table SkipGramOptions {
@@ -2929,33 +2929,14 @@ flatbuffers::Offset<SpaceToBatchNDOptions> CreateSpaceToBatchNDOptions(

struct BatchToSpaceNDOptionsT : public flatbuffers::NativeTable {
  typedef BatchToSpaceNDOptions TableType;
  std::vector<int32_t> block_shape;
  std::vector<int32_t> before_crops;
  std::vector<int32_t> after_crops;
  BatchToSpaceNDOptionsT() {}
};

struct BatchToSpaceNDOptions FLATBUFFERS_FINAL_CLASS
    : private flatbuffers::Table {
  typedef BatchToSpaceNDOptionsT NativeTableType;
  enum { VT_BLOCK_SHAPE = 4, VT_BEFORE_CROPS = 6, VT_AFTER_CROPS = 8 };
  const flatbuffers::Vector<int32_t> *block_shape() const {
    return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_BLOCK_SHAPE);
  }
  const flatbuffers::Vector<int32_t> *before_crops() const {
    return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_BEFORE_CROPS);
  }
  const flatbuffers::Vector<int32_t> *after_crops() const {
    return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_AFTER_CROPS);
  }
  bool Verify(flatbuffers::Verifier &verifier) const {
    return VerifyTableStart(verifier) &&
           VerifyOffset(verifier, VT_BLOCK_SHAPE) &&
           verifier.Verify(block_shape()) &&
           VerifyOffset(verifier, VT_BEFORE_CROPS) &&
           verifier.Verify(before_crops()) &&
           VerifyOffset(verifier, VT_AFTER_CROPS) &&
           verifier.Verify(after_crops()) && verifier.EndTable();
    return VerifyTableStart(verifier) && verifier.EndTable();
  }
  BatchToSpaceNDOptionsT *UnPack(
      const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -2970,18 +2951,6 @@ struct BatchToSpaceNDOptions FLATBUFFERS_FINAL_CLASS
struct BatchToSpaceNDOptionsBuilder {
  flatbuffers::FlatBufferBuilder &fbb_;
  flatbuffers::uoffset_t start_;
  void add_block_shape(
      flatbuffers::Offset<flatbuffers::Vector<int32_t>> block_shape) {
    fbb_.AddOffset(BatchToSpaceNDOptions::VT_BLOCK_SHAPE, block_shape);
  }
  void add_before_crops(
      flatbuffers::Offset<flatbuffers::Vector<int32_t>> before_crops) {
    fbb_.AddOffset(BatchToSpaceNDOptions::VT_BEFORE_CROPS, before_crops);
  }
  void add_after_crops(
      flatbuffers::Offset<flatbuffers::Vector<int32_t>> after_crops) {
    fbb_.AddOffset(BatchToSpaceNDOptions::VT_AFTER_CROPS, after_crops);
  }
  explicit BatchToSpaceNDOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
      : fbb_(_fbb) {
    start_ = fbb_.StartTable();
@@ -2995,29 +2964,11 @@ struct BatchToSpaceNDOptionsBuilder {
};

inline flatbuffers::Offset<BatchToSpaceNDOptions> CreateBatchToSpaceNDOptions(
    flatbuffers::FlatBufferBuilder &_fbb,
    flatbuffers::Offset<flatbuffers::Vector<int32_t>> block_shape = 0,
    flatbuffers::Offset<flatbuffers::Vector<int32_t>> before_crops = 0,
    flatbuffers::Offset<flatbuffers::Vector<int32_t>> after_crops = 0) {
    flatbuffers::FlatBufferBuilder &_fbb) {
  BatchToSpaceNDOptionsBuilder builder_(_fbb);
  builder_.add_after_crops(after_crops);
  builder_.add_before_crops(before_crops);
  builder_.add_block_shape(block_shape);
  return builder_.Finish();
}

inline flatbuffers::Offset<BatchToSpaceNDOptions>
CreateBatchToSpaceNDOptionsDirect(
    flatbuffers::FlatBufferBuilder &_fbb,
    const std::vector<int32_t> *block_shape = nullptr,
    const std::vector<int32_t> *before_crops = nullptr,
    const std::vector<int32_t> *after_crops = nullptr) {
  return tflite::CreateBatchToSpaceNDOptions(
      _fbb, block_shape ? _fbb.CreateVector<int32_t>(*block_shape) : 0,
      before_crops ? _fbb.CreateVector<int32_t>(*before_crops) : 0,
      after_crops ? _fbb.CreateVector<int32_t>(*after_crops) : 0);
}

flatbuffers::Offset<BatchToSpaceNDOptions> CreateBatchToSpaceNDOptions(
    flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o,
    const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -5774,33 +5725,6 @@ inline void BatchToSpaceNDOptions::UnPackTo(
    const flatbuffers::resolver_function_t *_resolver) const {
  (void)_o;
  (void)_resolver;
  {
    auto _e = block_shape();
    if (_e) {
      _o->block_shape.resize(_e->size());
      for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) {
        _o->block_shape[_i] = _e->Get(_i);
      }
    }
  };
  {
    auto _e = before_crops();
    if (_e) {
      _o->before_crops.resize(_e->size());
      for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) {
        _o->before_crops[_i] = _e->Get(_i);
      }
    }
  };
  {
    auto _e = after_crops();
    if (_e) {
      _o->after_crops.resize(_e->size());
      for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) {
        _o->after_crops[_i] = _e->Get(_i);
      }
    }
  };
}

inline flatbuffers::Offset<BatchToSpaceNDOptions> BatchToSpaceNDOptions::Pack(
@@ -5820,14 +5744,7 @@ inline flatbuffers::Offset<BatchToSpaceNDOptions> CreateBatchToSpaceNDOptions(
    const flatbuffers::rehasher_function_t *__rehasher;
  } _va = {&_fbb, _o, _rehasher};
  (void)_va;
  auto _block_shape =
      _o->block_shape.size() ? _fbb.CreateVector(_o->block_shape) : 0;
  auto _before_crops =
      _o->before_crops.size() ? _fbb.CreateVector(_o->before_crops) : 0;
  auto _after_crops =
      _o->after_crops.size() ? _fbb.CreateVector(_o->after_crops) : 0;
  return tflite::CreateBatchToSpaceNDOptions(_fbb, _block_shape, _before_crops,
                                             _after_crops);
  return tflite::CreateBatchToSpaceNDOptions(_fbb);
}

inline SkipGramOptionsT *SkipGramOptions::UnPack(
@@ -94,7 +94,8 @@ KNOWN_BUGS = {
    r"softmax.*input_shape=\[1,3,4,3\]": "67749831",
    # SpaceToDepth only supports float32.
    r"space_to_depth.*(float16|int32|uint8|int64)": "68018134",
    # BatchToSpaceND doesn't support cropping.
    # BatchToSpaceND doesn't support cropping. This catches test cases with
    # const tensors as crops.
    r"batch_to_space_nd.*crops=\[\[1,1\],\[1,1\]\]": "70594634",
    # BatchToSpaceND only supports 4D tensors.
    r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733",
@@ -1361,6 +1362,8 @@ def make_batch_to_space_nd_tests(zip_path):
          "input_shape": [[12, 2, 2, 1]],
          "block_shape": [[1, 4], [2, 2], [3, 4]],
          "crops": [[[0, 0], [0, 0]], [[1, 1], [1, 1]]],
          "constant_block_shape": [True, False],
          "constant_crops": [True, False],
      },
      # Non-4D use case: 1 batch dimension, 3 spatial dimensions, 2 others.
      {
@@ -1368,23 +1371,47 @@ def make_batch_to_space_nd_tests(zip_path):
          "input_shape": [[8, 2, 2, 2, 1, 1]],
          "block_shape": [[2, 2, 2]],
          "crops": [[[0, 0], [0, 0], [0, 0]]],
          "constant_block_shape": [True, False],
          "constant_crops": [True, False],
      },
  ]

  def build_graph(parameters):
    """Build a batch_to_space graph given `parameters`."""
    input_tensor = tf.placeholder(
        dtype=parameters["dtype"],
        name="input",
        shape=parameters["input_shape"])
    out = tf.batch_to_space_nd(input_tensor, parameters["block_shape"],
                               parameters["crops"])
    return [input_tensor], [out]
    input_tensors = [input_tensor]

    # Get block_shape either as a const or as a placeholder (tensor).
    if parameters["constant_block_shape"]:
      block_shape = parameters["block_shape"]
    else:
      shape = [len(parameters["block_shape"])]
      block_shape = tf.placeholder(dtype=tf.int32, name="shape", shape=shape)
      input_tensors.append(block_shape)

    # Get crops either as a const or as a placeholder (tensor).
    if parameters["constant_crops"]:
      crops = parameters["crops"]
    else:
      shape = [len(parameters["crops"]), 2]
      crops = tf.placeholder(dtype=tf.int32, name="crops", shape=shape)
      input_tensors.append(crops)

    out = tf.batch_to_space_nd(input_tensor, block_shape, crops)
    return input_tensors, [out]

  def build_inputs(parameters, sess, inputs, outputs):
    input_values = create_tensor_data(parameters["dtype"],
                                      parameters["input_shape"])
    return [input_values], sess.run(
        outputs, feed_dict=dict(zip(inputs, [input_values])))
    values = [
        create_tensor_data(parameters["dtype"], parameters["input_shape"])
    ]
    if not parameters["constant_block_shape"]:
      values.append(np.array(parameters["block_shape"]))
    if not parameters["constant_crops"]:
      values.append(np.array(parameters["crops"]))
    return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
@@ -67,7 +67,11 @@ std::map<string, string> kBrokenTests = {
    // L2Norm only supports tensors with 4D or fewer.
    {R"(^\/l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"},

    // SpaceToBatch only supports 4D tensors.
    // BatchToSpaceND doesn't support cropping. This catches test cases with
    // non-const tensors as crops.
    {R"(^\/batch_to_space_nd.*crops=\[\[1,1\],\[1,1\]\])", "70594634"},

    // SpaceToBatchND only supports 4D tensors.
    {R"(^\/space_to_batch_nd.*input_shape=\[1,4,4,4,1,1\])", "70848787"},

    // L2Norm only works for dim=-1.
@@ -211,24 +211,11 @@ class BatchToSpaceND
  flatbuffers::Offset<TfLiteOptions> WriteOptions(
      const TocoOperator& op,
      flatbuffers::FlatBufferBuilder* builder) const override {
    auto block_shape = builder->CreateVector(op.block_shape);
    auto before_crops = builder->CreateVector(op.before_crops);
    auto after_crops = builder->CreateVector(op.after_crops);
    return ::tflite::CreateBatchToSpaceNDOptions(*builder, block_shape,
                                                 before_crops, after_crops);
    return ::tflite::CreateBatchToSpaceNDOptions(*builder);
  }

  void ReadOptions(const TfLiteOptions& options,
                   TocoOperator* op) const override {
    op->block_shape.insert(op->block_shape.end(),
                           options.block_shape()->begin(),
                           options.block_shape()->end());
    op->before_crops.insert(op->before_crops.end(),
                            options.before_crops()->begin(),
                            options.before_crops()->end());
    op->after_crops.insert(op->after_crops.end(),
                           options.after_crops()->begin(),
                           options.after_crops()->end());
  }
};
@@ -132,19 +132,6 @@ TEST_F(OperatorTest, BuiltinSpaceToBatchND) {
  EXPECT_EQ(op.after_paddings, output_toco_op->after_paddings);
}

TEST_F(OperatorTest, BuiltinBatchToSpaceND) {
  BatchToSpaceNDOperator op;
  op.block_shape = {2, 2};
  op.before_crops = {1, 2};
  op.after_crops = {3, 4};

  auto output_toco_op = SerializeAndDeserialize(
      GetOperator("BATCH_TO_SPACE_ND", OperatorType::kBatchToSpaceND), op);
  EXPECT_EQ(op.block_shape, output_toco_op->block_shape);
  EXPECT_EQ(op.before_crops, output_toco_op->before_crops);
  EXPECT_EQ(op.after_crops, output_toco_op->after_crops);
}

TEST_F(OperatorTest, BuiltinMean) {
  MeanOperator op;
  op.axis = {1, 2};