fixed dimension zero edge case. when one of the dimension is zero, where function should output empty output
PiperOrigin-RevId: 327177867 Change-Id: Ib2848d1d02605a162534e0290ca20a262317f231
This commit is contained in:
parent
0e23f5d4ab
commit
56942dddcc
@ -2384,6 +2384,10 @@ template <typename D, typename T>
|
||||
void SelectTrueCoords(const RuntimeShape& input_condition_shape,
|
||||
const D* input_condition_data, T* output_data) {
|
||||
const size_t size = input_condition_shape.FlatSize();
|
||||
if (size == 0) {
|
||||
// Dimension is zero, in which case we don't need to output.
|
||||
return;
|
||||
}
|
||||
const size_t cond_rank = input_condition_shape.DimensionsCount();
|
||||
|
||||
std::vector<int> dims_to_count(cond_rank, 0);
|
||||
|
@ -90,6 +90,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
ResizeOutputTensor(context, cond_tensor, output));
|
||||
}
|
||||
|
||||
TfLiteIntArray* dims = cond_tensor->dims;
|
||||
if (dims->size == 0) {
|
||||
// Scalar tensors are not supported.
|
||||
TF_LITE_KERNEL_LOG(context, "Where op requires condition w/ rank > 0");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
reference_ops::SelectTrueCoords(GetTensorShape(cond_tensor),
|
||||
GetTensorData<bool>(cond_tensor),
|
||||
GetTensorData<int64_t>(output));
|
||||
|
@ -51,6 +51,30 @@ class IntegerWhereOpModel : public BaseWhereOpModel {
|
||||
std::vector<int64_t> GetOutput() { return ExtractVector<int64_t>(output_); }
|
||||
};
|
||||
|
||||
template <typename T1>
|
||||
class ConstInputWhereOpModel : public SingleOpModel {
|
||||
public:
|
||||
ConstInputWhereOpModel(T1 constant_values, const TensorData& output) {
|
||||
input_ = AddConstInput(GetTensorType<T1>(), {constant_values}, {});
|
||||
output_ = AddOutput(output);
|
||||
SetBuiltinOp(BuiltinOperator_WHERE, BuiltinOptions_WhereOptions,
|
||||
CreateWhereOptions(builder_).Union());
|
||||
BuildInterpreter({{}});
|
||||
}
|
||||
|
||||
int input() { return input_; }
|
||||
std::vector<int64_t> GetOutput() { return ExtractVector<int64_t>(output_); }
|
||||
|
||||
protected:
|
||||
int input_;
|
||||
int output_;
|
||||
};
|
||||
|
||||
TEST(WhereOpTest, ScalarValueFail) {
|
||||
ConstInputWhereOpModel<bool> m(false, {TensorType_INT64, {}});
|
||||
EXPECT_EQ(m.InvokeUnchecked(), kTfLiteError);
|
||||
}
|
||||
|
||||
TEST(WhereOpTest, SelectFromVectorNoResult) {
|
||||
IntegerWhereOpModel m({TensorType_BOOL, {3}}, {TensorType_INT64, {}});
|
||||
m.PopulateTensor<bool>(m.input(), {false, false, false});
|
||||
|
@ -33,6 +33,11 @@ def make_where_tests(options):
|
||||
"input_shape_set": [([1, 2, 3, 4], [1, 2, 3, 4]),],
|
||||
"use_where_v2": [False, True],
|
||||
},
|
||||
{
|
||||
"input_dtype": [tf.float32, tf.int32],
|
||||
"input_shape_set": [([], []),],
|
||||
"use_where_v2": [],
|
||||
},
|
||||
]
|
||||
|
||||
def build_graph(parameters):
|
||||
|
Loading…
x
Reference in New Issue
Block a user