add support for start == limit case for range.
PiperOrigin-RevId: 327587859 Change-Id: I8b8edc23acbd5dba5e6cad95792259623f8342f1
This commit is contained in:
parent
2303ed4bdb
commit
fce766941e
@ -41,8 +41,8 @@ template <typename T>
|
||||
TfLiteStatus GetSize(TfLiteContext* context, T start, T limit, T delta,
|
||||
int* size) {
|
||||
TF_LITE_ENSURE(context, !std::equal_to<T>()(delta, 0));
|
||||
TF_LITE_ENSURE(context,
|
||||
(start > limit && delta < 0) || (start < limit && delta > 0));
|
||||
TF_LITE_ENSURE(
|
||||
context, (start >= limit && delta < 0) || (start <= limit && delta > 0));
|
||||
*size =
|
||||
(std::is_integral<T>::value
|
||||
? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))
|
||||
|
||||
@ -112,5 +112,14 @@ TEST(RangeOpModel, FloatNegativeDelta) {
|
||||
EXPECT_THAT(model.GetOutput(), ElementsAre(10, 7, 4));
|
||||
}
|
||||
|
||||
TEST(RangeOpModel, EmptyOutput) {
|
||||
RangeOpModel<int32_t> model(TensorType_INT32);
|
||||
model.PopulateTensor<int32_t>(model.start(), {0});
|
||||
model.PopulateTensor<int32_t>(model.limit(), {0});
|
||||
model.PopulateTensor<int32_t>(model.delta(), {1});
|
||||
model.Invoke();
|
||||
EXPECT_THAT(model.GetOutputShape(), ElementsAre(0));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
@ -29,7 +29,7 @@ def make_range_tests(options):
|
||||
|
||||
test_parameters = [{
|
||||
"dtype": [tf.int32, tf.float32],
|
||||
"offset": [10, 100, 1000],
|
||||
"offset": [10, 100, 1000, 0],
|
||||
"delta": [1, 2, 3, 4, -1, -2, -3, -4],
|
||||
}]
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user