Lite: Split Op Axis Validation Added

Author: ANSHUMAN TRIPATHY  2019-03-13 17:09:00 +05:30
Parent: 98134fb601
Commit: e7873f5718
2 changed files with 26 additions and 27 deletions


@@ -2228,22 +2228,22 @@ void Split(const SplitParams& params, const RuntimeShape& input_shape,
            const Scalar* input_data, const RuntimeShape* const* output_shapes,
            Scalar* const* output_data) {
   gemmlowp::ScopedProfilingLabel label("Split");
-  const int concat_dimensions = input_shape.DimensionsCount();
-  int axis = params.axis < 0 ? params.axis + concat_dimensions : params.axis;
+  const int split_dimensions = input_shape.DimensionsCount();
+  int axis = params.axis < 0 ? params.axis + split_dimensions : params.axis;
   int outputs_count = params.num_split;
-  TFLITE_DCHECK_LT(axis, concat_dimensions);
+  TFLITE_DCHECK_LT(axis, split_dimensions);
-  int64_t concat_size = 0;
+  int64_t split_size = 0;
   for (int i = 0; i < outputs_count; i++) {
-    TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), concat_dimensions);
-    for (int j = 0; j < concat_dimensions; j++) {
+    TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), split_dimensions);
+    for (int j = 0; j < split_dimensions; j++) {
       if (j != axis) {
         MatchingDim(*output_shapes[i], j, input_shape, j);
       }
     }
-    concat_size += output_shapes[i]->Dims(axis);
+    split_size += output_shapes[i]->Dims(axis);
   }
-  TFLITE_DCHECK_EQ(concat_size, input_shape.Dims(axis));
+  TFLITE_DCHECK_EQ(split_size, input_shape.Dims(axis));
   int64_t outer_size = 1;
   for (int i = 0; i < axis; ++i) {
     outer_size *= input_shape.Dims(i);
@@ -2251,7 +2251,7 @@ void Split(const SplitParams& params, const RuntimeShape& input_shape,
   // For all output arrays,
   // FlatSize() = outer_size * Dims(axis) * base_inner_size;
   int64_t base_inner_size = 1;
-  for (int i = axis + 1; i < concat_dimensions; ++i) {
+  for (int i = axis + 1; i < split_dimensions; ++i) {
     base_inner_size *= input_shape.Dims(i);
   }
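
For context: the Split implementation above walks outer_size blocks of the input and copies Dims(axis) * base_inner_size contiguous elements into each output in turn. Below is a minimal standalone sketch (not part of this commit) of that copy loop, using plain std::vector shapes instead of RuntimeShape; the SplitSketch name and signature are illustrative only.

#include <cstdint>
#include <cstring>
#include <vector>

// Illustrative only: splits `input` (dimensions `dims`) along `axis` into
// `outputs`, where outputs[i] receives split_sizes[i] slices of that axis.
void SplitSketch(const std::vector<int>& dims, const float* input, int axis,
                 const std::vector<int>& split_sizes,
                 const std::vector<float*>& outputs) {
  int64_t outer_size = 1;
  for (int i = 0; i < axis; ++i) outer_size *= dims[i];
  int64_t base_inner_size = 1;
  for (int i = axis + 1; i < static_cast<int>(dims.size()); ++i) {
    base_inner_size *= dims[i];
  }
  const float* in_ptr = input;
  for (int64_t k = 0; k < outer_size; ++k) {
    for (size_t i = 0; i < outputs.size(); ++i) {
      const int64_t copy_size =
          static_cast<int64_t>(split_sizes[i]) * base_inner_size;
      std::memcpy(outputs[i] + k * copy_size, in_ptr,
                  copy_size * sizeof(float));
      in_ptr += copy_size;
    }
  }
}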


@@ -53,6 +53,9 @@ TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
     axis_value += NumDimensions(input);
   }
   TF_LITE_ENSURE(context, axis_value >= 0);
+  TF_LITE_ENSURE(context, axis_value < NumDimensions(input));
   const int input_size = SizeOfDimension(input, axis_value);
   TF_LITE_ENSURE_MSG(context, input_size % num_splits == 0,
                      "Not an even split");
@@ -111,24 +114,22 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     axis_value += NumDimensions(op_context.input);
   }
   TF_LITE_ENSURE(context, axis_value >= 0);
+  TF_LITE_ENSURE(context, axis_value < NumDimensions(op_context.input));
   // TODO(ahentz): Our usage of VectorOfTensors could be optimized by
   // calculating it in Prepare, unless we defer shape calculation.
-  // TODO(ahentz): We can improve the optimized_ops version to handle other
-  // cases too.
-#define TF_LITE_SPLIT(scalar)                                                \
-  VectorOfTensors<scalar> all_outputs(*context, *node->outputs);             \
-  tflite::SplitParams op_params;                                             \
-  op_params.num_split = NumOutputs(node);                                    \
-  op_params.axis = axis_value;                                               \
-  if (axis_value == 0) {                                                     \
-    optimized_ops::Split(op_params, GetTensorShape(op_context.input),        \
-                         GetTensorData<scalar>(op_context.input),            \
-                         all_outputs.shapes(), all_outputs.data());          \
-  } else {                                                                   \
-    reference_ops::Split(op_params, GetTensorShape(op_context.input),        \
-                         GetTensorData<scalar>(op_context.input),            \
-                         all_outputs.shapes(), all_outputs.data());          \
-  }
+#define TF_LITE_SPLIT(scalar)                                              \
+  VectorOfTensors<scalar> all_outputs(*context, *node->outputs);           \
+  tflite::SplitParams op_params;                                           \
+  op_params.num_split = NumOutputs(node);                                  \
+  op_params.axis = axis_value;                                             \
+  reference_ops::Split(op_params, GetTensorShape(op_context.input),        \
+                       GetTensorData<scalar>(op_context.input),            \
+                       all_outputs.shapes(), all_outputs.data());
   switch (op_context.input->type) {
     case kTfLiteFloat32: {
       TF_LITE_SPLIT(float);
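
After this change, TF_LITE_SPLIT(float) in the kTfLiteFloat32 case above expands (roughly, ignoring line continuations) to an unconditional reference_ops::Split call inside Eval; the former axis_value == 0 fast path through optimized_ops::Split is gone.

// Approximate expansion of TF_LITE_SPLIT(float) after this commit:
VectorOfTensors<float> all_outputs(*context, *node->outputs);
tflite::SplitParams op_params;
op_params.num_split = NumOutputs(node);
op_params.axis = axis_value;
reference_ops::Split(op_params, GetTensorShape(op_context.input),
                     GetTensorData<float>(op_context.input),
                     all_outputs.shapes(), all_outputs.data());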
@@ -151,10 +152,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       break;
     }
     default:
-      context->ReportError(context,
-                           "Only float32, uint8, int8, int16 and int32 are "
-                           "currently supported, got %d.",
-                           op_context.input->type);
+      context->ReportError(context, "Type %s currently not supported.",
+                           TfLiteTypeGetName(op_context.input->type));
       return kTfLiteError;
   }
 #undef TF_LITE_SPLIT
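
The reworked default case above derives the message from TfLiteTypeGetName instead of hard-coding the list of supported types. A small sketch of the same reporting pattern, assuming a hypothetical ReportUnsupportedType helper (context->ReportError and TfLiteTypeGetName are the existing TFLite C API; the header path may differ between TF versions):

#include "tensorflow/lite/c/c_api_internal.h"  // TfLiteContext, TfLiteType, TfLiteTypeGetName

// Hypothetical helper showing the error pattern used above; for example,
// a kTfLiteBool input would report "Type BOOL currently not supported.".
TfLiteStatus ReportUnsupportedType(TfLiteContext* context, TfLiteType type) {
  context->ReportError(context, "Type %s currently not supported.",
                       TfLiteTypeGetName(type));
  return kTfLiteError;
}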