Support scalar condition in Select op
PiperOrigin-RevId: 319477642 Change-Id: Ia9866838388d72cb1637cf0f29ae92ec9e486296
This commit is contained in:
parent
d31913d1e2
commit
64a509a9d2
@ -249,8 +249,9 @@ inline void ReluX(const tflite::ActivationParams& params,
|
||||
const T min_value = params.quantized_activation_min;
|
||||
for (int i = 0; i < flat_size; ++i) {
|
||||
const T val = input_data[i];
|
||||
const T clamped =
|
||||
val > max_value ? max_value : val < min_value ? min_value : val;
|
||||
const T clamped = val > max_value ? max_value
|
||||
: val < min_value ? min_value
|
||||
: val;
|
||||
output_data[i] = clamped;
|
||||
}
|
||||
}
|
||||
@ -1345,7 +1346,6 @@ inline void LogSoftmax(const SoftmaxParams& params,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
inline void Dequantize(const RuntimeShape& input_shape,
|
||||
const Eigen::half* input_data,
|
||||
const RuntimeShape& output_shape, float* output_data) {
|
||||
@ -2309,11 +2309,16 @@ void RankOneSelect(const RuntimeShape& input_condition_shape,
|
||||
const RuntimeShape& input_y_shape, const T* input_y_data,
|
||||
const RuntimeShape& output_shape, T* output_data) {
|
||||
const int64_t outer_size = input_condition_shape.FlatSize();
|
||||
TFLITE_DCHECK_EQ(
|
||||
MatchingDim(input_x_shape, 0, input_y_shape, 0, output_shape, 0),
|
||||
outer_size);
|
||||
const int64_t inner_size =
|
||||
MatchingFlatSizeSkipDim(input_x_shape, 0, input_y_shape, output_shape);
|
||||
int64_t inner_size;
|
||||
if (input_condition_shape.DimensionsCount() == 0) {
|
||||
inner_size = MatchingFlatSize(input_x_shape, input_y_shape, output_shape);
|
||||
} else {
|
||||
TFLITE_DCHECK_EQ(
|
||||
MatchingDim(input_x_shape, 0, input_y_shape, 0, output_shape, 0),
|
||||
outer_size);
|
||||
inner_size =
|
||||
MatchingFlatSizeSkipDim(input_x_shape, 0, input_y_shape, output_shape);
|
||||
}
|
||||
|
||||
int64_t offset = 0;
|
||||
for (int64_t i = 0; i < outer_size; i++) {
|
||||
@ -2604,7 +2609,6 @@ void ReverseSequence(const TS* seq_lengths, const int seq_dim,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
inline void SegmentSum(const RuntimeShape& input_shape, const T* input_data,
|
||||
const RuntimeShape& segment_ids_shape,
|
||||
|
@ -38,13 +38,15 @@ enum KernelType {
|
||||
|
||||
struct OpData {
|
||||
bool requires_broadcast;
|
||||
bool has_rank_one_input_condition;
|
||||
// True if input condition is scalar or input condition has rank one and
|
||||
// matches the first dimension of other inputs.
|
||||
bool has_low_rank_input_condition;
|
||||
};
|
||||
|
||||
void* SelectInit(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
auto* data = new OpData;
|
||||
data->requires_broadcast = false;
|
||||
data->has_rank_one_input_condition = false;
|
||||
data->has_low_rank_input_condition = false;
|
||||
return data;
|
||||
}
|
||||
|
||||
@ -76,10 +78,13 @@ TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
if (!same_shape) {
|
||||
switch (kernel_type) {
|
||||
case kVersionOne: {
|
||||
data->has_rank_one_input_condition =
|
||||
bool is_input_condition_scalar = NumDimensions(input_condition) == 0;
|
||||
bool has_rank_one_input_condition =
|
||||
NumDimensions(input_condition) == 1 &&
|
||||
SizeOfDimension(input_condition, 0) == SizeOfDimension(input_x, 0);
|
||||
TF_LITE_ENSURE(context, data->has_rank_one_input_condition);
|
||||
data->has_low_rank_input_condition =
|
||||
is_input_condition_scalar || has_rank_one_input_condition;
|
||||
TF_LITE_ENSURE(context, data->has_low_rank_input_condition);
|
||||
|
||||
output_size = TfLiteIntArrayCopy(input_x->dims);
|
||||
|
||||
@ -151,7 +156,7 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return kTfLiteError; \
|
||||
}
|
||||
|
||||
if (data->has_rank_one_input_condition) {
|
||||
if (data->has_low_rank_input_condition) {
|
||||
TF_LITE_SWITCH(input_x->type, RankOneSelect);
|
||||
} else if (data->requires_broadcast) {
|
||||
TF_LITE_SWITCH(input_x->type, BroadcastSelect4DSlow);
|
||||
@ -170,7 +175,8 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
|
||||
// true or the value of 'y' if false. There are valid condition input sizes:
|
||||
//
|
||||
// 1. Either the same shape (in which case the select is elementwise), or
|
||||
// 2. condition must be Rank 1 and match over the first dimension.
|
||||
// 2. condition must be Rank 1 and match over the first dimension, or
|
||||
// 3. condition is scalar
|
||||
TfLiteRegistration* Register_SELECT() {
|
||||
static TfLiteRegistration r = {select::SelectInit, select::SelectFree,
|
||||
select::SelectPrepare<select::kVersionOne>,
|
||||
|
@ -174,6 +174,36 @@ TEST(SelectOpTest, RankOneSelectInt32) {
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 2, 1}));
|
||||
}
|
||||
|
||||
TEST(SelectOpTest, ScalarFalseConditionInt32) {
|
||||
if (SingleOpModel::GetForceUseNnapi()) {
|
||||
return;
|
||||
}
|
||||
SelectOpModel model({}, {2, 1, 2, 1}, {2, 1, 2, 1}, TensorType_INT32);
|
||||
|
||||
model.PopulateTensor<bool>(model.input1(), {false});
|
||||
model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 4});
|
||||
model.PopulateTensor<int32_t>(model.input3(), {5, 6, 7, 8});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({5, 6, 7, 8}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 2, 1}));
|
||||
}
|
||||
|
||||
TEST(SelectOpTest, ScalarTrueConditionInt32) {
|
||||
if (SingleOpModel::GetForceUseNnapi()) {
|
||||
return;
|
||||
}
|
||||
SelectOpModel model({}, {2, 1, 2, 1}, {2, 1, 2, 1}, TensorType_INT32);
|
||||
|
||||
model.PopulateTensor<bool>(model.input1(), {true});
|
||||
model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 4});
|
||||
model.PopulateTensor<int32_t>(model.input3(), {5, 6, 7, 8});
|
||||
model.Invoke();
|
||||
|
||||
EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({1, 2, 3, 4}));
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 2, 1}));
|
||||
}
|
||||
|
||||
TEST(SelectOpTest, RankZeroSelectInt32) {
|
||||
SelectOpModel model({1}, {1, 2, 2, 1}, {1, 2, 2, 1}, TensorType_INT32);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user