fixed dimension zero edge case. when one of the dimension is zero, where function should output empty output

PiperOrigin-RevId: 327177867
Change-Id: Ib2848d1d02605a162534e0290ca20a262317f231
This commit is contained in:
A. Unique TensorFlower 2020-08-18 00:22:36 -07:00 committed by TensorFlower Gardener
parent 0e23f5d4ab
commit 56942dddcc
4 changed files with 40 additions and 0 deletions

View File

@ -2384,6 +2384,10 @@ template <typename D, typename T>
void SelectTrueCoords(const RuntimeShape& input_condition_shape,
const D* input_condition_data, T* output_data) {
const size_t size = input_condition_shape.FlatSize();
if (size == 0) {
// Dimension is zero, in which case we don't need to output.
return;
}
const size_t cond_rank = input_condition_shape.DimensionsCount();
std::vector<int> dims_to_count(cond_rank, 0);

View File

@ -90,6 +90,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
ResizeOutputTensor(context, cond_tensor, output));
}
TfLiteIntArray* dims = cond_tensor->dims;
if (dims->size == 0) {
// Scalar tensors are not supported.
TF_LITE_KERNEL_LOG(context, "Where op requires condition w/ rank > 0");
return kTfLiteError;
}
reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
GetTensorData<bool>(cond_tensor),
GetTensorData<int64_t>(output));

View File

@ -51,6 +51,30 @@ class IntegerWhereOpModel : public BaseWhereOpModel {
std::vector<int64_t> GetOutput() { return ExtractVector<int64_t>(output_); }
};
template <typename T1>
class ConstInputWhereOpModel : public SingleOpModel {
public:
ConstInputWhereOpModel(T1 constant_values, const TensorData& output) {
input_ = AddConstInput(GetTensorType<T1>(), {constant_values}, {});
output_ = AddOutput(output);
SetBuiltinOp(BuiltinOperator_WHERE, BuiltinOptions_WhereOptions,
CreateWhereOptions(builder_).Union());
BuildInterpreter({{}});
}
int input() { return input_; }
std::vector<int64_t> GetOutput() { return ExtractVector<int64_t>(output_); }
protected:
int input_;
int output_;
};
TEST(WhereOpTest, ScalarValueFail) {
ConstInputWhereOpModel<bool> m(false, {TensorType_INT64, {}});
EXPECT_EQ(m.InvokeUnchecked(), kTfLiteError);
}
TEST(WhereOpTest, SelectFromVectorNoResult) {
IntegerWhereOpModel m({TensorType_BOOL, {3}}, {TensorType_INT64, {}});
m.PopulateTensor<bool>(m.input(), {false, false, false});

View File

@ -33,6 +33,11 @@ def make_where_tests(options):
"input_shape_set": [([1, 2, 3, 4], [1, 2, 3, 4]),],
"use_where_v2": [False, True],
},
{
"input_dtype": [tf.float32, tf.int32],
"input_shape_set": [([], []),],
"use_where_v2": [],
},
]
def build_graph(parameters):