add support for start == limit case for range.

PiperOrigin-RevId: 327587859
Change-Id: I8b8edc23acbd5dba5e6cad95792259623f8342f1
This commit is contained in:
Renjie Liu 2020-08-20 01:39:36 -07:00 committed by TensorFlower Gardener
parent 2303ed4bdb
commit fce766941e
3 changed files with 12 additions and 3 deletions

View File

@ -41,8 +41,8 @@ template <typename T>
TfLiteStatus GetSize(TfLiteContext* context, T start, T limit, T delta,
int* size) {
TF_LITE_ENSURE(context, !std::equal_to<T>()(delta, 0));
TF_LITE_ENSURE(context,
(start > limit && delta < 0) || (start < limit && delta > 0));
TF_LITE_ENSURE(
context, (start >= limit && delta < 0) || (start <= limit && delta > 0));
*size =
(std::is_integral<T>::value
? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))

View File

@ -112,5 +112,14 @@ TEST(RangeOpModel, FloatNegativeDelta) {
EXPECT_THAT(model.GetOutput(), ElementsAre(10, 7, 4));
}
TEST(RangeOpModel, EmptyOutput) {
RangeOpModel<int32_t> model(TensorType_INT32);
model.PopulateTensor<int32_t>(model.start(), {0});
model.PopulateTensor<int32_t>(model.limit(), {0});
model.PopulateTensor<int32_t>(model.delta(), {1});
model.Invoke();
EXPECT_THAT(model.GetOutputShape(), ElementsAre(0));
}
} // namespace
} // namespace tflite

View File

@ -29,7 +29,7 @@ def make_range_tests(options):
test_parameters = [{
"dtype": [tf.int32, tf.float32],
"offset": [10, 100, 1000],
"offset": [10, 100, 1000, 0],
"delta": [1, 2, 3, 4, -1, -2, -3, -4],
}]