From 56942dddcc6a16f94c8f69451cba6482bf350f6a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 18 Aug 2020 00:22:36 -0700 Subject: [PATCH] fixed dimension zero edge case. when one of the dimension is zero, where function should output empty output PiperOrigin-RevId: 327177867 Change-Id: Ib2848d1d02605a162534e0290ca20a262317f231 --- .../internal/reference/reference_ops.h | 4 ++++ tensorflow/lite/kernels/where.cc | 7 ++++++ tensorflow/lite/kernels/where_test.cc | 24 +++++++++++++++++++ tensorflow/lite/testing/op_tests/where.py | 5 ++++ 4 files changed, 40 insertions(+) diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h index b9434c5cfae..df771bcca27 100644 --- a/tensorflow/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h @@ -2384,6 +2384,10 @@ template void SelectTrueCoords(const RuntimeShape& input_condition_shape, const D* input_condition_data, T* output_data) { const size_t size = input_condition_shape.FlatSize(); + if (size == 0) { + // Dimension is zero, in which case we don't need to output. + return; + } const size_t cond_rank = input_condition_shape.DimensionsCount(); std::vector dims_to_count(cond_rank, 0); diff --git a/tensorflow/lite/kernels/where.cc b/tensorflow/lite/kernels/where.cc index a20efa8baaa..8eb09bf2798 100644 --- a/tensorflow/lite/kernels/where.cc +++ b/tensorflow/lite/kernels/where.cc @@ -90,6 +90,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { ResizeOutputTensor(context, cond_tensor, output)); } + TfLiteIntArray* dims = cond_tensor->dims; + if (dims->size == 0) { + // Scalar tensors are not supported. + TF_LITE_KERNEL_LOG(context, "Where op requires condition w/ rank > 0"); + return kTfLiteError; + } + reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor), GetTensorData(cond_tensor), GetTensorData(output)); diff --git a/tensorflow/lite/kernels/where_test.cc b/tensorflow/lite/kernels/where_test.cc index ba93bed6e74..4a77470e89f 100644 --- a/tensorflow/lite/kernels/where_test.cc +++ b/tensorflow/lite/kernels/where_test.cc @@ -51,6 +51,30 @@ class IntegerWhereOpModel : public BaseWhereOpModel { std::vector GetOutput() { return ExtractVector(output_); } }; +template +class ConstInputWhereOpModel : public SingleOpModel { + public: + ConstInputWhereOpModel(T1 constant_values, const TensorData& output) { + input_ = AddConstInput(GetTensorType(), {constant_values}, {}); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_WHERE, BuiltinOptions_WhereOptions, + CreateWhereOptions(builder_).Union()); + BuildInterpreter({{}}); + } + + int input() { return input_; } + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int output_; +}; + +TEST(WhereOpTest, ScalarValueFail) { + ConstInputWhereOpModel m(false, {TensorType_INT64, {}}); + EXPECT_EQ(m.InvokeUnchecked(), kTfLiteError); +} + TEST(WhereOpTest, SelectFromVectorNoResult) { IntegerWhereOpModel m({TensorType_BOOL, {3}}, {TensorType_INT64, {}}); m.PopulateTensor(m.input(), {false, false, false}); diff --git a/tensorflow/lite/testing/op_tests/where.py b/tensorflow/lite/testing/op_tests/where.py index 49802422e3f..90db8d56f25 100644 --- a/tensorflow/lite/testing/op_tests/where.py +++ b/tensorflow/lite/testing/op_tests/where.py @@ -33,6 +33,11 @@ def make_where_tests(options): "input_shape_set": [([1, 2, 3, 4], [1, 2, 3, 4]),], "use_where_v2": [False, True], }, + { + "input_dtype": [tf.float32, tf.int32], + "input_shape_set": [([], []),], + "use_where_v2": [], + }, ] def build_graph(parameters):