Support scalar condition in Select op

PiperOrigin-RevId: 319477642
Change-Id: Ia9866838388d72cb1637cf0f29ae92ec9e486296
This commit is contained in:
Thai Nguyen 2020-07-02 23:52:12 -07:00 committed by TensorFlower Gardener
parent d31913d1e2
commit 64a509a9d2
3 changed files with 55 additions and 15 deletions

View File

@ -249,8 +249,9 @@ inline void ReluX(const tflite::ActivationParams& params,
const T min_value = params.quantized_activation_min;
for (int i = 0; i < flat_size; ++i) {
const T val = input_data[i];
const T clamped =
val > max_value ? max_value : val < min_value ? min_value : val;
const T clamped = val > max_value ? max_value
: val < min_value ? min_value
: val;
output_data[i] = clamped;
}
}
@ -1345,7 +1346,6 @@ inline void LogSoftmax(const SoftmaxParams& params,
}
}
inline void Dequantize(const RuntimeShape& input_shape,
const Eigen::half* input_data,
const RuntimeShape& output_shape, float* output_data) {
@ -2309,11 +2309,16 @@ void RankOneSelect(const RuntimeShape& input_condition_shape,
const RuntimeShape& input_y_shape, const T* input_y_data,
const RuntimeShape& output_shape, T* output_data) {
const int64_t outer_size = input_condition_shape.FlatSize();
TFLITE_DCHECK_EQ(
MatchingDim(input_x_shape, 0, input_y_shape, 0, output_shape, 0),
outer_size);
const int64_t inner_size =
MatchingFlatSizeSkipDim(input_x_shape, 0, input_y_shape, output_shape);
int64_t inner_size;
if (input_condition_shape.DimensionsCount() == 0) {
inner_size = MatchingFlatSize(input_x_shape, input_y_shape, output_shape);
} else {
TFLITE_DCHECK_EQ(
MatchingDim(input_x_shape, 0, input_y_shape, 0, output_shape, 0),
outer_size);
inner_size =
MatchingFlatSizeSkipDim(input_x_shape, 0, input_y_shape, output_shape);
}
int64_t offset = 0;
for (int64_t i = 0; i < outer_size; i++) {
@ -2604,7 +2609,6 @@ void ReverseSequence(const TS* seq_lengths, const int seq_dim,
}
}
template <typename T>
inline void SegmentSum(const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& segment_ids_shape,

View File

@ -38,13 +38,15 @@ enum KernelType {
struct OpData {
bool requires_broadcast;
bool has_rank_one_input_condition;
// True if input condition is scalar or input condition has rank one and
// matches the first dimension of other inputs.
bool has_low_rank_input_condition;
};
void* SelectInit(TfLiteContext* context, const char* buffer, size_t length) {
auto* data = new OpData;
data->requires_broadcast = false;
data->has_rank_one_input_condition = false;
data->has_low_rank_input_condition = false;
return data;
}
@ -76,10 +78,13 @@ TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {
if (!same_shape) {
switch (kernel_type) {
case kVersionOne: {
data->has_rank_one_input_condition =
bool is_input_condition_scalar = NumDimensions(input_condition) == 0;
bool has_rank_one_input_condition =
NumDimensions(input_condition) == 1 &&
SizeOfDimension(input_condition, 0) == SizeOfDimension(input_x, 0);
TF_LITE_ENSURE(context, data->has_rank_one_input_condition);
data->has_low_rank_input_condition =
is_input_condition_scalar || has_rank_one_input_condition;
TF_LITE_ENSURE(context, data->has_low_rank_input_condition);
output_size = TfLiteIntArrayCopy(input_x->dims);
@ -151,7 +156,7 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError; \
}
if (data->has_rank_one_input_condition) {
if (data->has_low_rank_input_condition) {
TF_LITE_SWITCH(input_x->type, RankOneSelect);
} else if (data->requires_broadcast) {
TF_LITE_SWITCH(input_x->type, BroadcastSelect4DSlow);
@ -170,7 +175,8 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
// true or the value of 'y' if false. There are valid condition input sizes:
//
// 1. Either the same shape (in which case the select is elementwise), or
// 2. condition must be Rank 1 and match over the first dimension.
// 2. condition must be Rank 1 and match over the first dimension, or
// 3. condition is scalar
TfLiteRegistration* Register_SELECT() {
static TfLiteRegistration r = {select::SelectInit, select::SelectFree,
select::SelectPrepare<select::kVersionOne>,

View File

@ -174,6 +174,36 @@ TEST(SelectOpTest, RankOneSelectInt32) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 2, 1}));
}
TEST(SelectOpTest, ScalarFalseConditionInt32) {
if (SingleOpModel::GetForceUseNnapi()) {
return;
}
SelectOpModel model({}, {2, 1, 2, 1}, {2, 1, 2, 1}, TensorType_INT32);
model.PopulateTensor<bool>(model.input1(), {false});
model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 4});
model.PopulateTensor<int32_t>(model.input3(), {5, 6, 7, 8});
model.Invoke();
EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({5, 6, 7, 8}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 2, 1}));
}
TEST(SelectOpTest, ScalarTrueConditionInt32) {
if (SingleOpModel::GetForceUseNnapi()) {
return;
}
SelectOpModel model({}, {2, 1, 2, 1}, {2, 1, 2, 1}, TensorType_INT32);
model.PopulateTensor<bool>(model.input1(), {true});
model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 4});
model.PopulateTensor<int32_t>(model.input3(), {5, 6, 7, 8});
model.Invoke();
EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({1, 2, 3, 4}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 2, 1}));
}
TEST(SelectOpTest, RankZeroSelectInt32) {
SelectOpModel model({1}, {1, 2, 2, 1}, {1, 2, 2, 1}, TensorType_INT32);