Lite: Split Op Axis Validation Added

This commit is contained in:
ANSHUMAN TRIPATHY 2019-03-13 17:09:00 +05:30
parent 98134fb601
commit e7873f5718
2 changed files with 26 additions and 27 deletions

View File

@ -2228,22 +2228,22 @@ void Split(const SplitParams& params, const RuntimeShape& input_shape,
const Scalar* input_data, const RuntimeShape* const* output_shapes, const Scalar* input_data, const RuntimeShape* const* output_shapes,
Scalar* const* output_data) { Scalar* const* output_data) {
gemmlowp::ScopedProfilingLabel label("Split"); gemmlowp::ScopedProfilingLabel label("Split");
const int concat_dimensions = input_shape.DimensionsCount(); const int split_dimensions = input_shape.DimensionsCount();
int axis = params.axis < 0 ? params.axis + concat_dimensions : params.axis; int axis = params.axis < 0 ? params.axis + split_dimensions : params.axis;
int outputs_count = params.num_split; int outputs_count = params.num_split;
TFLITE_DCHECK_LT(axis, concat_dimensions); TFLITE_DCHECK_LT(axis, split_dimensions);
int64_t concat_size = 0; int64_t split_size = 0;
for (int i = 0; i < outputs_count; i++) { for (int i = 0; i < outputs_count; i++) {
TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), concat_dimensions); TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), split_dimensions);
for (int j = 0; j < concat_dimensions; j++) { for (int j = 0; j < split_dimensions; j++) {
if (j != axis) { if (j != axis) {
MatchingDim(*output_shapes[i], j, input_shape, j); MatchingDim(*output_shapes[i], j, input_shape, j);
} }
} }
concat_size += output_shapes[i]->Dims(axis); split_size += output_shapes[i]->Dims(axis);
} }
TFLITE_DCHECK_EQ(concat_size, input_shape.Dims(axis)); TFLITE_DCHECK_EQ(split_size, input_shape.Dims(axis));
int64_t outer_size = 1; int64_t outer_size = 1;
for (int i = 0; i < axis; ++i) { for (int i = 0; i < axis; ++i) {
outer_size *= input_shape.Dims(i); outer_size *= input_shape.Dims(i);
@ -2251,7 +2251,7 @@ void Split(const SplitParams& params, const RuntimeShape& input_shape,
// For all output arrays, // For all output arrays,
// FlatSize() = outer_size * Dims(axis) * base_inner_size; // FlatSize() = outer_size * Dims(axis) * base_inner_size;
int64_t base_inner_size = 1; int64_t base_inner_size = 1;
for (int i = axis + 1; i < concat_dimensions; ++i) { for (int i = axis + 1; i < split_dimensions; ++i) {
base_inner_size *= input_shape.Dims(i); base_inner_size *= input_shape.Dims(i);
} }

View File

@ -53,6 +53,9 @@ TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
axis_value += NumDimensions(input); axis_value += NumDimensions(input);
} }
TF_LITE_ENSURE(context, axis_value >= 0);
TF_LITE_ENSURE(context, axis_value < NumDimensions(input));
const int input_size = SizeOfDimension(input, axis_value); const int input_size = SizeOfDimension(input, axis_value);
TF_LITE_ENSURE_MSG(context, input_size % num_splits == 0, TF_LITE_ENSURE_MSG(context, input_size % num_splits == 0,
"Not an even split"); "Not an even split");
@ -111,24 +114,22 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
axis_value += NumDimensions(op_context.input); axis_value += NumDimensions(op_context.input);
} }
TF_LITE_ENSURE(context, axis_value >= 0);
TF_LITE_ENSURE(context, axis_value < NumDimensions(op_context.input));
// TODO(ahentz): Our usage of VectorOfTensors could be optimized by // TODO(ahentz): Our usage of VectorOfTensors could be optimized by
// calculating it in Prepare, unless we defer shape calculation. // calculating it in Prepare, unless we defer shape calculation.
// TODO(ahentz): We can improve the optimized_ops version to handle other // TODO(ahentz): We can improve the optimized_ops version to handle other
// cases too. // cases too.
#define TF_LITE_SPLIT(scalar) \ #define TF_LITE_SPLIT(scalar) \
VectorOfTensors<scalar> all_outputs(*context, *node->outputs); \ VectorOfTensors<scalar> all_outputs(*context, *node->outputs); \
tflite::SplitParams op_params; \ tflite::SplitParams op_params; \
op_params.num_split = NumOutputs(node); \ op_params.num_split = NumOutputs(node); \
op_params.axis = axis_value; \ op_params.axis = axis_value; \
if (axis_value == 0) { \ reference_ops::Split(op_params, GetTensorShape(op_context.input), \
optimized_ops::Split(op_params, GetTensorShape(op_context.input), \ GetTensorData<scalar>(op_context.input), \
GetTensorData<scalar>(op_context.input), \ all_outputs.shapes(), all_outputs.data());
all_outputs.shapes(), all_outputs.data()); \
} else { \
reference_ops::Split(op_params, GetTensorShape(op_context.input), \
GetTensorData<scalar>(op_context.input), \
all_outputs.shapes(), all_outputs.data()); \
}
switch (op_context.input->type) { switch (op_context.input->type) {
case kTfLiteFloat32: { case kTfLiteFloat32: {
TF_LITE_SPLIT(float); TF_LITE_SPLIT(float);
@ -151,10 +152,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break; break;
} }
default: default:
context->ReportError(context, context->ReportError(context, "Type %s currently not supported.",
"Only float32, uint8, int8, int16 and int32 are " TfLiteTypeGetName(op_context.input->type));
"currently supported, got %d.",
op_context.input->type);
return kTfLiteError; return kTfLiteError;
} }
#undef TF_LITE_SPLIT #undef TF_LITE_SPLIT