diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 8966e5c2b60..6dd9cb39d25 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -127,14 +127,6 @@ typedef struct { } TfLiteSpaceToBatchNDParams; typedef struct { - // Number of spatial dimensions. - // For now only NHWC is supported, and the value should always be 2. - int num_spatial_dimensions; - // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. - // For now we will fix the maximum possible number of dimensions. - int block_shape[2]; - int before_crops[2]; - int after_crops[2]; } TfLiteBatchToSpaceNDParams; typedef struct { diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc index 0eed680fdcc..d84a77039bf 100644 --- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc @@ -35,12 +35,14 @@ enum KernelType { struct BatchToSpaceNDContext { BatchToSpaceNDContext(TfLiteContext* context, TfLiteNode* node) { - params = reinterpret_cast(node->builtin_data); input = GetInput(context, node, 0); + block_shape = GetInput(context, node, 1); + crops = GetInput(context, node, 2); output = GetOutput(context, node, 0); } - TfLiteBatchToSpaceNDParams* params; TfLiteTensor* input; + TfLiteTensor* block_shape; + TfLiteTensor* crops; TfLiteTensor* output; }; @@ -48,24 +50,22 @@ struct BatchToSpaceNDContext { // The 4D array need to have exactly 2 spatial dimensions. // TODO(ycling): Support arbitrary dimension in BatchToSpaceND. const int kInputDimensionNum = 4; -const int kOutputDimensionNum = 4; +const int kBlockSizeDimensionNum = 1; const int kSpatialDimensionNum = 2; -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - // The 2nd tensor (block_shape) and the 3rd tensor (crops) are ignored now. - TF_LITE_ENSURE(context, NumInputs(node) >= 1 && NumInputs(node) <= 3); - TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + BatchToSpaceNDContext* op_context) { + TfLiteIntArray* input_size = op_context->input->dims; + const int* block_shape = GetTensorData(op_context->block_shape); - BatchToSpaceNDContext op_context(context, node); - TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input), - kInputDimensionNum); - TF_LITE_ENSURE_EQ(context, op_context.params->num_spatial_dimensions, + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape), + kBlockSizeDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0], + kSpatialDimensionNum); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops), kSpatialDimensionNum); - TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); - - const TfLiteIntArray* input_size = op_context.input->dims; - const int* block_shape = op_context.params->block_shape; + // TODO(ycling): Add crops as part of calculation. // Number of batch must be multiple of (block_shape[0] * block_shape[1]). TF_LITE_ENSURE_EQ(context, input_size->data[0] % (block_shape[0] * block_shape[1]), 0); @@ -76,27 +76,48 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int output_width = input_size->data[2] * block_shape[1]; const int output_channel_size = input_size->data[3]; - TfLiteIntArray* output_size = TfLiteIntArrayCreate(kOutputDimensionNum); + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size); output_size->data[0] = output_batch_size; output_size->data[1] = output_height; output_size->data[2] = output_width; output_size->data[3] = output_channel_size; - return context->ResizeTensor(context, op_context.output, output_size); + return context->ResizeTensor(context, op_context->output, output_size); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + BatchToSpaceNDContext op_context(context, node); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input), + kInputDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + + if (!IsConstantTensor(op_context.block_shape) || + !IsConstantTensor(op_context.crops)) { + SetTensorToDynamic(op_context.output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, &op_context); } template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { BatchToSpaceNDContext op_context(context, node); - int block_shape_dims_array[1] = {kSpatialDimensionNum}; - Dims<4> block_shape_dims = GetTensorDims(block_shape_dims_array, 1); + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + TfLiteTensorRealloc(op_context.output->bytes, op_context.output); + } -#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \ - type::BatchToSpaceND(GetTensorData(op_context.input), \ - GetTensorDims(op_context.input), \ - op_context.params->block_shape, block_shape_dims, \ - GetTensorData(op_context.output), \ +#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \ + type::BatchToSpaceND(GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), \ + GetTensorData(op_context.block_shape), \ + GetTensorDims(op_context.block_shape), \ + GetTensorData(op_context.output), \ GetTensorDims(op_context.output)) switch (op_context.input->type) { // Already know in/out types are same. case kTfLiteFloat32: diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc index 3ec4efbebce..c9152bf9672 100644 --- a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc @@ -26,36 +26,76 @@ using ::testing::ElementsAreArray; class BatchToSpaceNDOpModel : public SingleOpModel { public: - BatchToSpaceNDOpModel(std::initializer_list input_shape, - std::initializer_list block_shape, - std::initializer_list before_crops, - std::initializer_list after_crops) { - input_ = AddInput(TensorType_FLOAT32); - output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND, - BuiltinOptions_BatchToSpaceNDOptions, - CreateBatchToSpaceNDOptions( - builder_, builder_.CreateVector(block_shape), - builder_.CreateVector(before_crops), - builder_.CreateVector(after_crops)) - .Union()); - BuildInterpreter({input_shape}); - } - void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } + void SetBlockShape(std::initializer_list data) { + PopulateTensor(block_shape_, data); + } + + void SetCrops(std::initializer_list data) { + PopulateTensor(crops_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } - private: + protected: int input_; + int block_shape_; + int crops_; int output_; }; -TEST(BatchToSpaceNDOpTest, SimpleTest) { - BatchToSpaceNDOpModel m({4, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0}); +// Tests case where block_shape and crops are const tensors. +// +// Example usage is as follows: +// BatchToSpaceNDOpConstModel m(input_shape, block_shape, crops); +// m.SetInput(input_data); +// m.Invoke(); +class BatchToSpaceNDOpConstModel : public BatchToSpaceNDOpModel { + public: + BatchToSpaceNDOpConstModel(std::initializer_list input_shape, + std::initializer_list block_shape, + std::initializer_list crops) { + input_ = AddInput(TensorType_FLOAT32); + block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2}); + crops_ = AddConstInput(TensorType_INT32, crops, {2, 2}); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND, + BuiltinOptions_BatchToSpaceNDOptions, + CreateBatchToSpaceNDOptions(builder_).Union()); + BuildInterpreter({input_shape}); + } +}; + +// Tests case where block_shape and crops are non-const tensors. +// +// Example usage is as follows: +// BatchToSpaceNDOpDynamicModel m(input_shape); +// m.SetInput(input_data); +// m.SetBlockShape(block_shape); +// m.SetPaddings(crops); +// m.Invoke(); +class BatchToSpaceNDOpDynamicModel : public BatchToSpaceNDOpModel { + public: + BatchToSpaceNDOpDynamicModel(std::initializer_list input_shape) { + input_ = AddInput(TensorType_FLOAT32); + block_shape_ = AddInput(TensorType_INT32); + crops_ = AddInput(TensorType_INT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND, + BuiltinOptions_BatchToSpaceNDOptions, + CreateBatchToSpaceNDOptions(builder_).Union()); + BuildInterpreter({input_shape, {2}, {2, 2}}); + } +}; + +TEST(BatchToSpaceNDOpTest, SimpleConstTest) { + BatchToSpaceNDOpConstModel m({4, 2, 2, 1}, {2, 2}, {0, 0, 0, 0}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.Invoke(); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); @@ -63,8 +103,19 @@ TEST(BatchToSpaceNDOpTest, SimpleTest) { 4, 8, 11, 15, 12, 16})); } +TEST(BatchToSpaceNDOpTest, SimpleDynamicTest) { + BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.SetBlockShape({2, 2}); + m.SetCrops({0, 0, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7, + 4, 8, 11, 15, 12, 16})); +} + TEST(BatchToSpaceNDOpTest, InvalidShapeTest) { - EXPECT_DEATH(BatchToSpaceNDOpModel({3, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0}), + EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, 0}), "Cannot allocate tensors"); } diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index ec4d6e3487a..64b8f550971 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -536,21 +536,6 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, break; } case BuiltinOperator_BATCH_TO_SPACE_ND: { - auto* params = MallocPOD(); - if (auto* schema_params = - op->builtin_options_as_BatchToSpaceNDOptions()) { - const auto& block_shape = schema_params->block_shape(); - FlatBufferIntVectorToArray(sizeof(params->block_shape), block_shape, - params->block_shape, error_reporter); - const auto& before_crops = schema_params->before_crops(); - FlatBufferIntVectorToArray(sizeof(params->before_crops), before_crops, - params->before_crops, error_reporter); - const auto& after_crops = schema_params->after_crops(); - FlatBufferIntVectorToArray(sizeof(params->after_crops), after_crops, - params->after_crops, error_reporter); - params->num_spatial_dimensions = block_shape->Length(); - } - builtin_data = reinterpret_cast(params); break; } case BuiltinOperator_TRANSPOSE: { diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 50709344eac..f6217567f82 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -295,9 +295,6 @@ table SpaceToBatchNDOptions { } table BatchToSpaceNDOptions { - block_shape:[int]; - before_crops:[int]; - after_crops:[int]; } table SkipGramOptions { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index f1ee925df23..ad758529a37 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -2929,33 +2929,14 @@ flatbuffers::Offset CreateSpaceToBatchNDOptions( struct BatchToSpaceNDOptionsT : public flatbuffers::NativeTable { typedef BatchToSpaceNDOptions TableType; - std::vector block_shape; - std::vector before_crops; - std::vector after_crops; BatchToSpaceNDOptionsT() {} }; struct BatchToSpaceNDOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef BatchToSpaceNDOptionsT NativeTableType; - enum { VT_BLOCK_SHAPE = 4, VT_BEFORE_CROPS = 6, VT_AFTER_CROPS = 8 }; - const flatbuffers::Vector *block_shape() const { - return GetPointer *>(VT_BLOCK_SHAPE); - } - const flatbuffers::Vector *before_crops() const { - return GetPointer *>(VT_BEFORE_CROPS); - } - const flatbuffers::Vector *after_crops() const { - return GetPointer *>(VT_AFTER_CROPS); - } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && - VerifyOffset(verifier, VT_BLOCK_SHAPE) && - verifier.Verify(block_shape()) && - VerifyOffset(verifier, VT_BEFORE_CROPS) && - verifier.Verify(before_crops()) && - VerifyOffset(verifier, VT_AFTER_CROPS) && - verifier.Verify(after_crops()) && verifier.EndTable(); + return VerifyTableStart(verifier) && verifier.EndTable(); } BatchToSpaceNDOptionsT *UnPack( const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -2970,18 +2951,6 @@ struct BatchToSpaceNDOptions FLATBUFFERS_FINAL_CLASS struct BatchToSpaceNDOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_block_shape( - flatbuffers::Offset> block_shape) { - fbb_.AddOffset(BatchToSpaceNDOptions::VT_BLOCK_SHAPE, block_shape); - } - void add_before_crops( - flatbuffers::Offset> before_crops) { - fbb_.AddOffset(BatchToSpaceNDOptions::VT_BEFORE_CROPS, before_crops); - } - void add_after_crops( - flatbuffers::Offset> after_crops) { - fbb_.AddOffset(BatchToSpaceNDOptions::VT_AFTER_CROPS, after_crops); - } explicit BatchToSpaceNDOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2995,29 +2964,11 @@ struct BatchToSpaceNDOptionsBuilder { }; inline flatbuffers::Offset CreateBatchToSpaceNDOptions( - flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset> block_shape = 0, - flatbuffers::Offset> before_crops = 0, - flatbuffers::Offset> after_crops = 0) { + flatbuffers::FlatBufferBuilder &_fbb) { BatchToSpaceNDOptionsBuilder builder_(_fbb); - builder_.add_after_crops(after_crops); - builder_.add_before_crops(before_crops); - builder_.add_block_shape(block_shape); return builder_.Finish(); } -inline flatbuffers::Offset -CreateBatchToSpaceNDOptionsDirect( - flatbuffers::FlatBufferBuilder &_fbb, - const std::vector *block_shape = nullptr, - const std::vector *before_crops = nullptr, - const std::vector *after_crops = nullptr) { - return tflite::CreateBatchToSpaceNDOptions( - _fbb, block_shape ? _fbb.CreateVector(*block_shape) : 0, - before_crops ? _fbb.CreateVector(*before_crops) : 0, - after_crops ? _fbb.CreateVector(*after_crops) : 0); -} - flatbuffers::Offset CreateBatchToSpaceNDOptions( flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -5774,33 +5725,6 @@ inline void BatchToSpaceNDOptions::UnPackTo( const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = block_shape(); - if (_e) { - _o->block_shape.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->block_shape[_i] = _e->Get(_i); - } - } - }; - { - auto _e = before_crops(); - if (_e) { - _o->before_crops.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->before_crops[_i] = _e->Get(_i); - } - } - }; - { - auto _e = after_crops(); - if (_e) { - _o->after_crops.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->after_crops[_i] = _e->Get(_i); - } - } - }; } inline flatbuffers::Offset BatchToSpaceNDOptions::Pack( @@ -5820,14 +5744,7 @@ inline flatbuffers::Offset CreateBatchToSpaceNDOptions( const flatbuffers::rehasher_function_t *__rehasher; } _va = {&_fbb, _o, _rehasher}; (void)_va; - auto _block_shape = - _o->block_shape.size() ? _fbb.CreateVector(_o->block_shape) : 0; - auto _before_crops = - _o->before_crops.size() ? _fbb.CreateVector(_o->before_crops) : 0; - auto _after_crops = - _o->after_crops.size() ? _fbb.CreateVector(_o->after_crops) : 0; - return tflite::CreateBatchToSpaceNDOptions(_fbb, _block_shape, _before_crops, - _after_crops); + return tflite::CreateBatchToSpaceNDOptions(_fbb); } inline SkipGramOptionsT *SkipGramOptions::UnPack( diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index bc9b23aeb42..4ae6ccb765d 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -94,7 +94,8 @@ KNOWN_BUGS = { r"softmax.*input_shape=\[1,3,4,3\]": "67749831", # SpaceToDepth only supports float32. r"space_to_depth.*(float16|int32|uint8|int64)": "68018134", - # BatchToSpaceND doesn't support cropping. + # BatchToSpaceND doesn't support cropping. This catches test cases with + # const tensors as crops. r"batch_to_space_nd.*crops=\[\[1,1\],\[1,1\]\]": "70594634", # BatchToSpaceND only supports 4D tensors. r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733", @@ -1361,6 +1362,8 @@ def make_batch_to_space_nd_tests(zip_path): "input_shape": [[12, 2, 2, 1]], "block_shape": [[1, 4], [2, 2], [3, 4]], "crops": [[[0, 0], [0, 0]], [[1, 1], [1, 1]]], + "constant_block_shape": [True, False], + "constant_crops": [True, False], }, # Non-4D use case: 1 bath dimension, 3 spatial dimensions, 2 others. { @@ -1368,23 +1371,47 @@ def make_batch_to_space_nd_tests(zip_path): "input_shape": [[8, 2, 2, 2, 1, 1]], "block_shape": [[2, 2, 2]], "crops": [[[0, 0], [0, 0], [0, 0]]], + "constant_block_shape": [True, False], + "constant_crops": [True, False], }, ] def build_graph(parameters): + """Build a batch_to_space graph given `parameters`.""" input_tensor = tf.placeholder( dtype=parameters["dtype"], name="input", shape=parameters["input_shape"]) - out = tf.batch_to_space_nd(input_tensor, parameters["block_shape"], - parameters["crops"]) - return [input_tensor], [out] + input_tensors = [input_tensor] + + # Get block_shape either as a const or as a placeholder (tensor). + if parameters["constant_block_shape"]: + block_shape = parameters["block_shape"] + else: + shape = [len(parameters["block_shape"])] + block_shape = tf.placeholder(dtype=tf.int32, name="shape", shape=shape) + input_tensors.append(block_shape) + + # Get crops either as a const or as a placeholder (tensor). + if parameters["constant_crops"]: + crops = parameters["crops"] + else: + shape = [len(parameters["crops"]), 2] + crops = tf.placeholder(dtype=tf.int32, name="crops", shape=shape) + input_tensors.append(crops) + + out = tf.batch_to_space_nd(input_tensor, block_shape, crops) + return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): - input_values = create_tensor_data(parameters["dtype"], - parameters["input_shape"]) - return [input_values], sess.run( - outputs, feed_dict=dict(zip(inputs, [input_values]))) + values = [ + create_tensor_data(parameters["dtype"], parameters["input_shape"]) + ] + if not parameters["constant_block_shape"]: + values.append(np.array(parameters["block_shape"])) + if not parameters["constant_crops"]: + values.append(np.array(parameters["crops"])) + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 2bbfe77a123..e6b782472a8 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -67,7 +67,11 @@ std::map kBrokenTests = { // L2Norm only supports tensors with 4D or fewer. {R"(^\/l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, - // SpaceToBatch only supports 4D tensors. + // BatchToSpaceND doesn't support cropping. This catches test cases with + // non-const tensors as crops. + {R"(^\/batch_to_space_nd.*crops=\[\[1,1\],\[1,1\]\])", "70594634"}, + + // SpaceToBatchND only supports 4D tensors. {R"(^\/space_to_batch_nd.*input_shape=\[1,4,4,4,1,1\])", "70848787"}, // L2Norm only works for dim=-1. diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 2d6bccce2b8..853a9f46c25 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -211,24 +211,11 @@ class BatchToSpaceND flatbuffers::Offset WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - auto block_shape = builder->CreateVector(op.block_shape); - auto before_crops = builder->CreateVector(op.before_crops); - auto after_crops = builder->CreateVector(op.after_crops); - return ::tflite::CreateBatchToSpaceNDOptions(*builder, block_shape, - before_crops, after_crops); + return ::tflite::CreateBatchToSpaceNDOptions(*builder); } void ReadOptions(const TfLiteOptions& options, TocoOperator* op) const override { - op->block_shape.insert(op->block_shape.end(), - options.block_shape()->begin(), - options.block_shape()->end()); - op->before_crops.insert(op->before_crops.end(), - options.before_crops()->begin(), - options.before_crops()->end()); - op->after_crops.insert(op->after_crops.end(), - options.after_crops()->begin(), - options.after_crops()->end()); } }; diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 78af3a767d3..4df90710950 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -132,19 +132,6 @@ TEST_F(OperatorTest, BuiltinSpaceToBatchND) { EXPECT_EQ(op.after_paddings, output_toco_op->after_paddings); } -TEST_F(OperatorTest, BuiltinBatchToSpaceND) { - BatchToSpaceNDOperator op; - op.block_shape = {2, 2}; - op.before_crops = {1, 2}; - op.after_crops = {3, 4}; - - auto output_toco_op = SerializeAndDeserialize( - GetOperator("BATCH_TO_SPACE_ND", OperatorType::kBatchToSpaceND), op); - EXPECT_EQ(op.block_shape, output_toco_op->block_shape); - EXPECT_EQ(op.before_crops, output_toco_op->before_crops); - EXPECT_EQ(op.after_crops, output_toco_op->after_crops); -} - TEST_F(OperatorTest, BuiltinMean) { MeanOperator op; op.axis = {1, 2};