Fix batch value matching for the empty vector + clean up few typos.

PiperOrigin-RevId: 332545377
Change-Id: I81106152a42c7ae5c5d930faee45a574c909f7f7
This commit is contained in:
A. Unique TensorFlower 2020-09-18 16:03:49 -07:00 committed by TensorFlower Gardener
parent 0462de5b54
commit 6bdae6145a
4 changed files with 28 additions and 3 deletions
tensorflow/lite/delegates/gpu

View File

@ -501,6 +501,7 @@ absl::Status ConnectTwoNodes(GraphFloat32* graph, const Node* from_node,
}
bool IsBatchMatchesForAllValues(const GraphFloat32& model) {
if (model.values().empty()) return true;
const int32_t b = model.values()[0]->tensor.shape.b;
for (auto value : model.values()) {
if (value->tensor.shape.b != b) {

View File

@ -257,7 +257,7 @@ absl::Status AddOutput(GraphFloat32* graph, const Node* from_node,
absl::Status ConnectTwoNodes(GraphFloat32* graph, const Node* from_node,
const Node* to_node, Value** output);
// @return true if all tensors have same batch value.
// @return true if all tensors have same batch value or if model has no values.
bool IsBatchMatchesForAllValues(const GraphFloat32& model);
} // namespace gpu

View File

@ -510,6 +510,29 @@ TEST(Model, InsertNodeAfter) {
EXPECT_THAT(graph.nodes(), ElementsAre(node1, new_node1, node2, new_node2));
}
TEST(BatchMatchingTest, EmptyGraph) {
GraphFloat32 graph;
ASSERT_TRUE(IsBatchMatchesForAllValues(graph));
}
TEST(BatchMatchingTest, AllMatch) {
GraphFloat32 graph;
Value* a = graph.NewValue();
Value* b = graph.NewValue();
a->tensor.shape = BHWC(1, 1, 1, 1);
b->tensor.shape = BHWC(1, 1, 1, 1);
ASSERT_TRUE(IsBatchMatchesForAllValues(graph));
}
TEST(BatchMatchingTest, NotAllMatch) {
GraphFloat32 graph;
Value* a = graph.NewValue();
Value* b = graph.NewValue();
a->tensor.shape = BHWC(1, 1, 1, 1);
b->tensor.shape = BHWC(2, 1, 1, 1);
ASSERT_FALSE(IsBatchMatchesForAllValues(graph));
}
} // namespace
} // namespace gpu
} // namespace tflite

View File

@ -225,12 +225,13 @@ std::unique_ptr<NodeShader> NewElementwiseNodeShader(
OperationType operation_type) {
switch (operation_type) {
case OperationType::ABS:
case OperationType::COPY:
case OperationType::COS:
case OperationType::COPY:
case OperationType::ELU:
case OperationType::EXP:
case OperationType::LOG:
case OperationType::HARD_SWISH:
case OperationType::LOG:
case OperationType::NEG:
case OperationType::RSQRT:
case OperationType::SIGMOID:
case OperationType::SIN: