Fix batch value matching for the empty vector + clean up few typos.
PiperOrigin-RevId: 332545377 Change-Id: I81106152a42c7ae5c5d930faee45a574c909f7f7
This commit is contained in:
parent
0462de5b54
commit
6bdae6145a
tensorflow/lite/delegates/gpu
@ -501,6 +501,7 @@ absl::Status ConnectTwoNodes(GraphFloat32* graph, const Node* from_node,
|
||||
}
|
||||
|
||||
bool IsBatchMatchesForAllValues(const GraphFloat32& model) {
|
||||
if (model.values().empty()) return true;
|
||||
const int32_t b = model.values()[0]->tensor.shape.b;
|
||||
for (auto value : model.values()) {
|
||||
if (value->tensor.shape.b != b) {
|
||||
|
@ -257,7 +257,7 @@ absl::Status AddOutput(GraphFloat32* graph, const Node* from_node,
|
||||
absl::Status ConnectTwoNodes(GraphFloat32* graph, const Node* from_node,
|
||||
const Node* to_node, Value** output);
|
||||
|
||||
// @return true if all tensors have same batch value.
|
||||
// @return true if all tensors have same batch value or if model has no values.
|
||||
bool IsBatchMatchesForAllValues(const GraphFloat32& model);
|
||||
|
||||
} // namespace gpu
|
||||
|
@ -510,6 +510,29 @@ TEST(Model, InsertNodeAfter) {
|
||||
EXPECT_THAT(graph.nodes(), ElementsAre(node1, new_node1, node2, new_node2));
|
||||
}
|
||||
|
||||
TEST(BatchMatchingTest, EmptyGraph) {
|
||||
GraphFloat32 graph;
|
||||
ASSERT_TRUE(IsBatchMatchesForAllValues(graph));
|
||||
}
|
||||
|
||||
TEST(BatchMatchingTest, AllMatch) {
|
||||
GraphFloat32 graph;
|
||||
Value* a = graph.NewValue();
|
||||
Value* b = graph.NewValue();
|
||||
a->tensor.shape = BHWC(1, 1, 1, 1);
|
||||
b->tensor.shape = BHWC(1, 1, 1, 1);
|
||||
ASSERT_TRUE(IsBatchMatchesForAllValues(graph));
|
||||
}
|
||||
|
||||
TEST(BatchMatchingTest, NotAllMatch) {
|
||||
GraphFloat32 graph;
|
||||
Value* a = graph.NewValue();
|
||||
Value* b = graph.NewValue();
|
||||
a->tensor.shape = BHWC(1, 1, 1, 1);
|
||||
b->tensor.shape = BHWC(2, 1, 1, 1);
|
||||
ASSERT_FALSE(IsBatchMatchesForAllValues(graph));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
||||
|
@ -225,12 +225,13 @@ std::unique_ptr<NodeShader> NewElementwiseNodeShader(
|
||||
OperationType operation_type) {
|
||||
switch (operation_type) {
|
||||
case OperationType::ABS:
|
||||
case OperationType::COPY:
|
||||
case OperationType::COS:
|
||||
case OperationType::COPY:
|
||||
case OperationType::ELU:
|
||||
case OperationType::EXP:
|
||||
case OperationType::LOG:
|
||||
case OperationType::HARD_SWISH:
|
||||
case OperationType::LOG:
|
||||
case OperationType::NEG:
|
||||
case OperationType::RSQRT:
|
||||
case OperationType::SIGMOID:
|
||||
case OperationType::SIN:
|
||||
|
Loading…
Reference in New Issue
Block a user