Lite: Split Op Axis Validation Added
This commit is contained in:
parent
98134fb601
commit
e7873f5718
@ -2228,22 +2228,22 @@ void Split(const SplitParams& params, const RuntimeShape& input_shape,
|
||||
const Scalar* input_data, const RuntimeShape* const* output_shapes,
|
||||
Scalar* const* output_data) {
|
||||
gemmlowp::ScopedProfilingLabel label("Split");
|
||||
const int concat_dimensions = input_shape.DimensionsCount();
|
||||
int axis = params.axis < 0 ? params.axis + concat_dimensions : params.axis;
|
||||
const int split_dimensions = input_shape.DimensionsCount();
|
||||
int axis = params.axis < 0 ? params.axis + split_dimensions : params.axis;
|
||||
int outputs_count = params.num_split;
|
||||
TFLITE_DCHECK_LT(axis, concat_dimensions);
|
||||
TFLITE_DCHECK_LT(axis, split_dimensions);
|
||||
|
||||
int64_t concat_size = 0;
|
||||
int64_t split_size = 0;
|
||||
for (int i = 0; i < outputs_count; i++) {
|
||||
TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), concat_dimensions);
|
||||
for (int j = 0; j < concat_dimensions; j++) {
|
||||
TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), split_dimensions);
|
||||
for (int j = 0; j < split_dimensions; j++) {
|
||||
if (j != axis) {
|
||||
MatchingDim(*output_shapes[i], j, input_shape, j);
|
||||
}
|
||||
}
|
||||
concat_size += output_shapes[i]->Dims(axis);
|
||||
split_size += output_shapes[i]->Dims(axis);
|
||||
}
|
||||
TFLITE_DCHECK_EQ(concat_size, input_shape.Dims(axis));
|
||||
TFLITE_DCHECK_EQ(split_size, input_shape.Dims(axis));
|
||||
int64_t outer_size = 1;
|
||||
for (int i = 0; i < axis; ++i) {
|
||||
outer_size *= input_shape.Dims(i);
|
||||
@ -2251,7 +2251,7 @@ void Split(const SplitParams& params, const RuntimeShape& input_shape,
|
||||
// For all output arrays,
|
||||
// FlatSize() = outer_size * Dims(axis) * base_inner_size;
|
||||
int64_t base_inner_size = 1;
|
||||
for (int i = axis + 1; i < concat_dimensions; ++i) {
|
||||
for (int i = axis + 1; i < split_dimensions; ++i) {
|
||||
base_inner_size *= input_shape.Dims(i);
|
||||
}
|
||||
|
||||
|
@ -53,6 +53,9 @@ TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
|
||||
axis_value += NumDimensions(input);
|
||||
}
|
||||
|
||||
TF_LITE_ENSURE(context, axis_value >= 0);
|
||||
TF_LITE_ENSURE(context, axis_value < NumDimensions(input));
|
||||
|
||||
const int input_size = SizeOfDimension(input, axis_value);
|
||||
TF_LITE_ENSURE_MSG(context, input_size % num_splits == 0,
|
||||
"Not an even split");
|
||||
@ -111,24 +114,22 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
axis_value += NumDimensions(op_context.input);
|
||||
}
|
||||
|
||||
TF_LITE_ENSURE(context, axis_value >= 0);
|
||||
TF_LITE_ENSURE(context, axis_value < NumDimensions(op_context.input));
|
||||
|
||||
// TODO(ahentz): Our usage of VectorOfTensors could be optimized by
|
||||
// calculating it in Prepare, unless we defer shape calculation.
|
||||
// TODO(ahentz): We can improve the optimized_ops version to handle other
|
||||
// cases too.
|
||||
#define TF_LITE_SPLIT(scalar) \
|
||||
VectorOfTensors<scalar> all_outputs(*context, *node->outputs); \
|
||||
tflite::SplitParams op_params; \
|
||||
op_params.num_split = NumOutputs(node); \
|
||||
op_params.axis = axis_value; \
|
||||
if (axis_value == 0) { \
|
||||
optimized_ops::Split(op_params, GetTensorShape(op_context.input), \
|
||||
GetTensorData<scalar>(op_context.input), \
|
||||
all_outputs.shapes(), all_outputs.data()); \
|
||||
} else { \
|
||||
reference_ops::Split(op_params, GetTensorShape(op_context.input), \
|
||||
GetTensorData<scalar>(op_context.input), \
|
||||
all_outputs.shapes(), all_outputs.data()); \
|
||||
}
|
||||
#define TF_LITE_SPLIT(scalar) \
|
||||
VectorOfTensors<scalar> all_outputs(*context, *node->outputs); \
|
||||
tflite::SplitParams op_params; \
|
||||
op_params.num_split = NumOutputs(node); \
|
||||
op_params.axis = axis_value; \
|
||||
reference_ops::Split(op_params, GetTensorShape(op_context.input), \
|
||||
GetTensorData<scalar>(op_context.input), \
|
||||
all_outputs.shapes(), all_outputs.data());
|
||||
|
||||
switch (op_context.input->type) {
|
||||
case kTfLiteFloat32: {
|
||||
TF_LITE_SPLIT(float);
|
||||
@ -151,10 +152,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
break;
|
||||
}
|
||||
default:
|
||||
context->ReportError(context,
|
||||
"Only float32, uint8, int8, int16 and int32 are "
|
||||
"currently supported, got %d.",
|
||||
op_context.input->type);
|
||||
context->ReportError(context, "Type %s currently not supported.",
|
||||
TfLiteTypeGetName(op_context.input->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
#undef TF_LITE_SPLIT
|
||||
|
Loading…
x
Reference in New Issue
Block a user