[Grappler] Migrate FrameView to use utils::GraphView/utils::MutableGraphView.
PiperOrigin-RevId: 251659253
This commit is contained in:
parent
cd09510f92
commit
d15c612f77
tensorflow/core/grappler
@ -498,6 +498,7 @@ cc_library(
|
||||
"//tensorflow/core/grappler:devices",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/clusters:cluster",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
"//tensorflow/core/grappler/costs:virtual_placer",
|
||||
@ -703,6 +704,7 @@ cc_library(
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:mutable_graph_view",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
"//tensorflow/core/grappler/utils:frame",
|
||||
"//tensorflow/core/grappler/utils:traversal",
|
||||
@ -724,6 +726,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||
"//tensorflow/core/grappler/utils:graph_view",
|
||||
"//tensorflow/core/grappler/utils:grappler_test",
|
||||
],
|
||||
)
|
||||
@ -883,6 +886,7 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
"//tensorflow/core/grappler/utils:frame",
|
||||
],
|
||||
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
|
||||
|
||||
#include <deque>
|
||||
#include <unordered_set>
|
||||
|
||||
@ -28,7 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/devices.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/utils/frame.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
|
@ -17,8 +17,10 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/utils/frame.h"
|
||||
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
|
||||
|
||||
|
@ -14,12 +14,14 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
|
||||
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/utils/graph_view.h"
|
||||
#include "tensorflow/core/grappler/utils/grappler_test.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
@ -104,26 +106,42 @@ TEST_F(LoopOptimizerTest, Basic) {
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
{ // Original graph.
|
||||
GraphView view(&graph);
|
||||
Status status;
|
||||
utils::GraphView view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
FrameView frames;
|
||||
TF_EXPECT_OK(frames.InferFromGraphView(view));
|
||||
|
||||
EXPECT_EQ(frames.num_frames(), 1);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).back(), 0);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("VariantAdd")).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("VariantAdd")).back(), 0);
|
||||
const auto* invariant_add_node = view.GetNode("InvariantAdd");
|
||||
ASSERT_NE(invariant_add_node, nullptr);
|
||||
const auto* invariant_add_node_def = invariant_add_node->node();
|
||||
ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
|
||||
const auto* variant_add_node = view.GetNode("VariantAdd");
|
||||
ASSERT_NE(variant_add_node, nullptr);
|
||||
const auto* variant_add_node_def = variant_add_node->node();
|
||||
ASSERT_EQ(frames.Frames(*variant_add_node_def).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*variant_add_node_def).back(), 0);
|
||||
}
|
||||
|
||||
{ // Optimized graph.
|
||||
GraphView view(&output);
|
||||
Status status;
|
||||
utils::GraphView view(&output, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
FrameView frames;
|
||||
TF_EXPECT_OK(frames.InferFromGraphView(view));
|
||||
|
||||
EXPECT_EQ(frames.num_frames(), 1);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).size(), 0);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("VariantAdd")).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("VariantAdd")).back(), 0);
|
||||
const auto* invariant_add_node = view.GetNode("InvariantAdd");
|
||||
ASSERT_NE(invariant_add_node, nullptr);
|
||||
const auto* invariant_add_node_def = invariant_add_node->node();
|
||||
ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 0);
|
||||
const auto* variant_add_node = view.GetNode("VariantAdd");
|
||||
ASSERT_NE(variant_add_node, nullptr);
|
||||
const auto* variant_add_node_def = variant_add_node->node();
|
||||
ASSERT_EQ(frames.Frames(*variant_add_node_def).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*variant_add_node_def).back(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
@ -155,25 +173,41 @@ TEST_F(LoopOptimizerTest, Const) {
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
{ // Original graph.
|
||||
GraphView view(&graph);
|
||||
Status status;
|
||||
utils::GraphView view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
FrameView frames;
|
||||
TF_EXPECT_OK(frames.InferFromGraphView(view));
|
||||
|
||||
EXPECT_EQ(frames.num_frames(), 1);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).back(), 0);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("Const")).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("Const")).back(), 0);
|
||||
const auto* invariant_add_node = view.GetNode("InvariantAdd");
|
||||
ASSERT_NE(invariant_add_node, nullptr);
|
||||
const auto* invariant_add_node_def = invariant_add_node->node();
|
||||
ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
|
||||
const auto* const_node = view.GetNode("Const");
|
||||
ASSERT_NE(const_node, nullptr);
|
||||
const auto* const_node_node_def = const_node->node();
|
||||
ASSERT_EQ(frames.Frames(*const_node_node_def).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*const_node_node_def).back(), 0);
|
||||
}
|
||||
|
||||
{ // Optimized graph.
|
||||
GraphView view(&output);
|
||||
Status status;
|
||||
utils::GraphView view(&output, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
FrameView frames;
|
||||
TF_EXPECT_OK(frames.InferFromGraphView(view));
|
||||
|
||||
EXPECT_EQ(frames.num_frames(), 1);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).size(), 0);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("Const")).size(), 0);
|
||||
const auto* invariant_add_node = view.GetNode("InvariantAdd");
|
||||
ASSERT_NE(invariant_add_node, nullptr);
|
||||
const auto* invariant_add_node_def = invariant_add_node->node();
|
||||
ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 0);
|
||||
const auto* const_node = view.GetNode("Const");
|
||||
ASSERT_NE(const_node, nullptr);
|
||||
const auto* const_node_node_def = const_node->node();
|
||||
ASSERT_EQ(frames.Frames(*const_node_node_def).size(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
@ -206,23 +240,33 @@ TEST_F(LoopOptimizerTest, ControlOutput) {
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
{ // Original graph.
|
||||
GraphView view(&graph);
|
||||
Status status;
|
||||
utils::GraphView view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
FrameView frames;
|
||||
TF_EXPECT_OK(frames.InferFromGraphView(view));
|
||||
|
||||
EXPECT_EQ(frames.num_frames(), 1);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).back(), 0);
|
||||
const auto* invariant_add_node = view.GetNode("InvariantAdd");
|
||||
ASSERT_NE(invariant_add_node, nullptr);
|
||||
const auto* invariant_add_node_def = invariant_add_node->node();
|
||||
ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
|
||||
}
|
||||
|
||||
{ // Optimized graph.
|
||||
GraphView view(&output);
|
||||
Status status;
|
||||
utils::GraphView view(&output, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
FrameView frames;
|
||||
TF_EXPECT_OK(frames.InferFromGraphView(view));
|
||||
|
||||
EXPECT_EQ(frames.num_frames(), 1);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).back(), 0);
|
||||
const auto* invariant_add_node = view.GetNode("InvariantAdd");
|
||||
ASSERT_NE(invariant_add_node, nullptr);
|
||||
const auto* invariant_add_node_def = invariant_add_node->node();
|
||||
ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
@ -270,30 +314,52 @@ TEST_F(LoopOptimizerTest, NestedLoop1) {
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
{ // Original graph.
|
||||
GraphView view(&graph);
|
||||
Status status;
|
||||
utils::GraphView view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
FrameView frames;
|
||||
TF_EXPECT_OK(frames.InferFromGraphView(view));
|
||||
|
||||
EXPECT_EQ(frames.num_frames(), 2);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).back(), 1);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("VariantAdd2")).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("VariantAdd2")).back(), 1);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).back(), 0);
|
||||
const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
|
||||
ASSERT_NE(invariant_add_2_node, nullptr);
|
||||
const auto* invariant_add_2_node_def = invariant_add_2_node->node();
|
||||
ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 1);
|
||||
const auto* variant_add_2_node = view.GetNode("VariantAdd2");
|
||||
ASSERT_NE(variant_add_2_node, nullptr);
|
||||
const auto* variant_add_2_node_def = variant_add_2_node->node();
|
||||
ASSERT_EQ(frames.Frames(*variant_add_2_node_def).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*variant_add_2_node_def).back(), 1);
|
||||
const auto* invariant_add_node = view.GetNode("InvariantAdd");
|
||||
ASSERT_NE(invariant_add_node, nullptr);
|
||||
const auto* invariant_add_node_def = invariant_add_node->node();
|
||||
ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
|
||||
}
|
||||
|
||||
{ // Optimized graph.
|
||||
GraphView view(&output);
|
||||
Status status;
|
||||
utils::GraphView view(&output, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
FrameView frames;
|
||||
TF_EXPECT_OK(frames.InferFromGraphView(view));
|
||||
|
||||
EXPECT_EQ(frames.num_frames(), 2);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).back(), 0);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("VariantAdd2")).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("VariantAdd2")).back(), 1);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).size(), 0);
|
||||
const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
|
||||
ASSERT_NE(invariant_add_2_node, nullptr);
|
||||
const auto* invariant_add_2_node_def = invariant_add_2_node->node();
|
||||
ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 0);
|
||||
const auto* variant_add_2_node = view.GetNode("VariantAdd2");
|
||||
ASSERT_NE(variant_add_2_node, nullptr);
|
||||
const auto* variant_add_2_node_def = variant_add_2_node->node();
|
||||
ASSERT_EQ(frames.Frames(*variant_add_2_node_def).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*variant_add_2_node_def).back(), 1);
|
||||
const auto* invariant_add_node = view.GetNode("InvariantAdd");
|
||||
ASSERT_NE(invariant_add_node, nullptr);
|
||||
const auto* invariant_add_node_def = invariant_add_node->node();
|
||||
ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
@ -341,26 +407,42 @@ TEST_F(LoopOptimizerTest, NestedLoop2) {
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
{ // Original graph.
|
||||
GraphView view(&graph);
|
||||
Status status;
|
||||
utils::GraphView view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
FrameView frames;
|
||||
TF_EXPECT_OK(frames.InferFromGraphView(view));
|
||||
|
||||
EXPECT_EQ(frames.num_frames(), 2);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).back(), 1);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("VariantAdd2")).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("VariantAdd2")).back(), 1);
|
||||
const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
|
||||
ASSERT_NE(invariant_add_2_node, nullptr);
|
||||
const auto* invariant_add_2_node_def = invariant_add_2_node->node();
|
||||
ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 1);
|
||||
const auto* variant_add_2_node = view.GetNode("VariantAdd2");
|
||||
ASSERT_NE(variant_add_2_node, nullptr);
|
||||
const auto* variant_add_2_node_def = variant_add_2_node->node();
|
||||
ASSERT_EQ(frames.Frames(*variant_add_2_node_def).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*variant_add_2_node_def).back(), 1);
|
||||
}
|
||||
|
||||
{ // Optimized graph.
|
||||
GraphView view(&output);
|
||||
Status status;
|
||||
utils::GraphView view(&output, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
FrameView frames;
|
||||
TF_EXPECT_OK(frames.InferFromGraphView(view));
|
||||
|
||||
EXPECT_EQ(frames.num_frames(), 2);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).size(), 0);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("VariantAdd2")).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("VariantAdd2")).back(), 1);
|
||||
const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
|
||||
ASSERT_NE(invariant_add_2_node, nullptr);
|
||||
const auto* invariant_add_2_node_def = invariant_add_2_node->node();
|
||||
ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 0);
|
||||
const auto* variant_add_2_node = view.GetNode("VariantAdd2");
|
||||
ASSERT_NE(variant_add_2_node, nullptr);
|
||||
const auto* variant_add_2_node_def = variant_add_2_node->node();
|
||||
ASSERT_EQ(frames.Frames(*variant_add_2_node_def).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*variant_add_2_node_def).back(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
@ -408,27 +490,43 @@ TEST_F(LoopOptimizerTest, NestedLoopConst1) {
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
{ // Original graph.
|
||||
GraphView view(&graph);
|
||||
Status status;
|
||||
utils::GraphView view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
FrameView frames;
|
||||
TF_EXPECT_OK(frames.InferFromGraphView(view));
|
||||
|
||||
EXPECT_EQ(frames.num_frames(), 2);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).back(), 1);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("Const2")).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("Const2")).back(), 1);
|
||||
const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
|
||||
ASSERT_NE(invariant_add_2_node, nullptr);
|
||||
const auto* invariant_add_2_node_def = invariant_add_2_node->node();
|
||||
ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 1);
|
||||
const auto* const_2_node = view.GetNode("Const2");
|
||||
ASSERT_NE(const_2_node, nullptr);
|
||||
const auto* const_2_node_def = const_2_node->node();
|
||||
ASSERT_EQ(frames.Frames(*const_2_node_def).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*const_2_node_def).back(), 1);
|
||||
}
|
||||
|
||||
{ // Optimized graph.
|
||||
GraphView view(&output);
|
||||
Status status;
|
||||
utils::GraphView view(&output, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
FrameView frames;
|
||||
TF_EXPECT_OK(frames.InferFromGraphView(view));
|
||||
|
||||
EXPECT_EQ(frames.num_frames(), 2);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).back(), 0);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("Const2")).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("Const2")).back(), 0);
|
||||
const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
|
||||
ASSERT_NE(invariant_add_2_node, nullptr);
|
||||
const auto* invariant_add_2_node_def = invariant_add_2_node->node();
|
||||
ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 0);
|
||||
const auto* const_2_node = view.GetNode("Const2");
|
||||
ASSERT_NE(const_2_node, nullptr);
|
||||
const auto* const_2_node_def = const_2_node->node();
|
||||
ASSERT_EQ(frames.Frames(*const_2_node_def).size(), 1);
|
||||
EXPECT_EQ(frames.Frames(*const_2_node_def).back(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
@ -476,25 +574,41 @@ TEST_F(LoopOptimizerTest, NestedLoopConst2) {
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
{ // Original graph.
|
||||
GraphView view(&graph);
|
||||
Status status;
|
||||
utils::GraphView view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
FrameView frames;
|
||||
TF_EXPECT_OK(frames.InferFromGraphView(view));
|
||||
|
||||
EXPECT_EQ(frames.num_frames(), 2);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).back(), 1);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("Const2")).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*view.GetNode("Const2")).back(), 1);
|
||||
const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
|
||||
ASSERT_NE(invariant_add_2_node, nullptr);
|
||||
const auto* invariant_add_2_node_def = invariant_add_2_node->node();
|
||||
ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 1);
|
||||
const auto* const_2_node = view.GetNode("Const2");
|
||||
ASSERT_NE(const_2_node, nullptr);
|
||||
const auto* const_2_node_def = const_2_node->node();
|
||||
ASSERT_EQ(frames.Frames(*const_2_node_def).size(), 2);
|
||||
EXPECT_EQ(frames.Frames(*const_2_node_def).back(), 1);
|
||||
}
|
||||
|
||||
{ // Optimized graph.
|
||||
GraphView view(&output);
|
||||
Status status;
|
||||
utils::GraphView view(&output, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
FrameView frames;
|
||||
TF_EXPECT_OK(frames.InferFromGraphView(view));
|
||||
|
||||
EXPECT_EQ(frames.num_frames(), 2);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).size(), 0);
|
||||
ASSERT_EQ(frames.Frames(*view.GetNode("Const2")).size(), 0);
|
||||
const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
|
||||
ASSERT_NE(invariant_add_2_node, nullptr);
|
||||
const auto* invariant_add_2_node_def = invariant_add_2_node->node();
|
||||
ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 0);
|
||||
const auto* const_2_node = view.GetNode("Const2");
|
||||
ASSERT_NE(const_2_node, nullptr);
|
||||
const auto* const_2_node_def = const_2_node->node();
|
||||
ASSERT_EQ(frames.Frames(*const_2_node_def).size(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -78,10 +78,10 @@ cc_library(
|
||||
hdrs = ["frame.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":graph_view",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:graph_view",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
],
|
||||
@ -93,6 +93,8 @@ tf_cc_test(
|
||||
srcs = ["frame_test.cc"],
|
||||
deps = [
|
||||
":frame",
|
||||
":graph_view",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
|
@ -14,10 +14,11 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/utils/frame.h"
|
||||
|
||||
#include <deque>
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/graph_view.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
@ -26,101 +27,134 @@ namespace grappler {
|
||||
|
||||
namespace {} // namespace
|
||||
|
||||
Status FrameView::InferFromGraphView(const GraphView& graph_view) {
|
||||
template <typename GraphViewT>
|
||||
inline Status FrameView::InferFromGraphViewT(const GraphViewT& graph_view) {
|
||||
if (is_inferred_) {
|
||||
return errors::Internal("FrameView was already inferred from the graph");
|
||||
}
|
||||
is_inferred_ = true;
|
||||
|
||||
std::deque<const NodeDef*> ready_nodes;
|
||||
std::deque<int> ready_node_indices;
|
||||
|
||||
// All nodes without inputs are automatically added to the ready queue.
|
||||
for (const NodeDef& node : graph_view.graph()->node()) {
|
||||
if (node.input_size() == 0) {
|
||||
ready_nodes.push_back(&node);
|
||||
node_to_frames_[&node] = node_has_no_frames_;
|
||||
for (const auto& node : graph_view.GetNodes()) {
|
||||
if (node.NumRegularFanins() + node.NumControllingFanins() == 0) {
|
||||
ready_node_indices.push_back(node.node_index());
|
||||
node_to_frames_[node.node()] = node_has_no_frames_;
|
||||
}
|
||||
}
|
||||
|
||||
const auto* graph = graph_view.graph();
|
||||
|
||||
// We assign unique int id to each frame, and use this map to track what
|
||||
// frames we've already seen in the graph.
|
||||
absl::flat_hash_map<string, int> frame_name_to_id;
|
||||
|
||||
while (!ready_nodes.empty()) {
|
||||
const NodeDef* ready_node = ready_nodes.front();
|
||||
auto process_fanout = [this, graph](
|
||||
absl::flat_hash_map<string, int>* frame_name_to_id,
|
||||
std::deque<int>* ready_node_indices,
|
||||
const NodeDef* ready_node, int fanout_node_index) {
|
||||
const NodeDef* fanout_node = &graph->node(fanout_node_index);
|
||||
if (!node_to_frames_.contains(fanout_node)) {
|
||||
// If we have never seen this node before, we add all frames from the
|
||||
// incoming node (and pop/push frames if coming from Exit/Enter nodes).
|
||||
std::vector<int> frame_ids = node_to_frames_[ready_node];
|
||||
|
||||
absl::flat_hash_set<GraphView::InputPort> fanouts =
|
||||
graph_view.GetFanouts(*ready_node, /*include_controlled_nodes=*/true);
|
||||
if (IsExit(*ready_node)) {
|
||||
frame_ids.pop_back();
|
||||
}
|
||||
|
||||
for (const GraphView::InputPort& fanout : fanouts) {
|
||||
if (node_to_frames_.count(fanout.node) < 1) {
|
||||
// If we have never seen this node before, we add all frames from the
|
||||
// incoming node (and pop/push frames if coming from Exit/Enter nodes).
|
||||
std::vector<int> frame_ids = node_to_frames_[ready_node];
|
||||
if (IsEnter(*fanout_node)) {
|
||||
const AttrValue* frame_name_attr =
|
||||
AttrSlice(*fanout_node).Find("frame_name");
|
||||
|
||||
if (IsExit(*ready_node)) {
|
||||
frame_ids.pop_back();
|
||||
}
|
||||
|
||||
if (IsEnter(*fanout.node)) {
|
||||
const AttrValue* frame_name_attr =
|
||||
AttrSlice(*fanout.node).Find("frame_name");
|
||||
|
||||
if (!frame_name_attr) {
|
||||
return errors::InvalidArgument(
|
||||
"Missing frame name for the Enter node: ",
|
||||
SummarizeNodeDef(*fanout.node));
|
||||
}
|
||||
|
||||
absl::string_view frame_name = frame_name_attr->s();
|
||||
int frame_id;
|
||||
|
||||
if (frame_name_to_id.count(frame_name)) {
|
||||
frame_id = frame_name_to_id[frame_name];
|
||||
} else {
|
||||
frame_id = static_cast<int>(frame_name_to_id.size());
|
||||
frame_name_to_id[frame_name] = frame_id;
|
||||
}
|
||||
|
||||
frame_ids.push_back(frame_id);
|
||||
}
|
||||
|
||||
ready_nodes.push_back(fanout.node);
|
||||
node_to_frames_[fanout.node] = std::move(frame_ids);
|
||||
|
||||
} else {
|
||||
// If we've already seen this node before, we need to make sure that
|
||||
// graph is correct and same nodes doesn't have incoming edges with
|
||||
// conflicting frames (all inputs must be produces in the same frame).
|
||||
|
||||
std::vector<int> frame_ids_fanout = node_to_frames_[fanout.node];
|
||||
std::vector<int> frame_ids_node = node_to_frames_[ready_node];
|
||||
|
||||
if (IsEnter(*fanout.node)) {
|
||||
frame_ids_fanout.pop_back();
|
||||
}
|
||||
if (IsExit(*ready_node)) {
|
||||
frame_ids_node.pop_back();
|
||||
}
|
||||
|
||||
if (frame_ids_node != frame_ids_fanout) {
|
||||
if (!frame_name_attr) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid graph: Frame ids for node ", ready_node->name(),
|
||||
" does not match frame ids for it's fanout ",
|
||||
fanout.node->name());
|
||||
"Missing frame name for the Enter node: ",
|
||||
SummarizeNodeDef(*fanout_node));
|
||||
}
|
||||
|
||||
const string& frame_name = frame_name_attr->s();
|
||||
int frame_id;
|
||||
|
||||
if (frame_name_to_id->contains(frame_name)) {
|
||||
frame_id = (*frame_name_to_id)[frame_name];
|
||||
} else {
|
||||
frame_id = static_cast<int>(frame_name_to_id->size());
|
||||
(*frame_name_to_id)[frame_name] = frame_id;
|
||||
}
|
||||
|
||||
frame_ids.push_back(frame_id);
|
||||
}
|
||||
|
||||
ready_node_indices->push_back(fanout_node_index);
|
||||
node_to_frames_[fanout_node] = std::move(frame_ids);
|
||||
|
||||
} else {
|
||||
// If we've already seen this node before, we need to make sure that graph
|
||||
// is correct and same nodes doesn't have incoming edges with conflicting
|
||||
// frames (all inputs must be produces in the same frame).
|
||||
|
||||
std::vector<int> frame_ids_fanout = node_to_frames_[fanout_node];
|
||||
std::vector<int> frame_ids_node = node_to_frames_[ready_node];
|
||||
|
||||
if (IsEnter(*fanout_node)) {
|
||||
frame_ids_fanout.pop_back();
|
||||
}
|
||||
if (IsExit(*ready_node)) {
|
||||
frame_ids_node.pop_back();
|
||||
}
|
||||
|
||||
if (frame_ids_node != frame_ids_fanout) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid graph: Frame ids for node ", ready_node->name(),
|
||||
" does not match frame ids for it's fanout ", fanout_node->name());
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
while (!ready_node_indices.empty()) {
|
||||
const int ready_node_index = ready_node_indices.front();
|
||||
ready_node_indices.pop_front();
|
||||
const auto* ready_node_view = graph_view.GetNode(ready_node_index);
|
||||
const NodeDef* ready_node_def = ready_node_view->node();
|
||||
|
||||
for (const auto& regular_fanouts_port_i :
|
||||
ready_node_view->GetRegularFanouts()) {
|
||||
for (const auto& regular_fanout : regular_fanouts_port_i) {
|
||||
TF_RETURN_IF_ERROR(process_fanout(&frame_name_to_id,
|
||||
&ready_node_indices, ready_node_def,
|
||||
regular_fanout.node_index()));
|
||||
}
|
||||
}
|
||||
|
||||
ready_nodes.pop_front();
|
||||
for (const auto& controlled_fanout :
|
||||
ready_node_view->GetControlledFanouts()) {
|
||||
TF_RETURN_IF_ERROR(process_fanout(&frame_name_to_id, &ready_node_indices,
|
||||
ready_node_def,
|
||||
controlled_fanout.node_index()));
|
||||
}
|
||||
}
|
||||
|
||||
num_frames_ = static_cast<int>(frame_name_to_id.size());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FrameView::InferFromGraphView(const utils::GraphView& graph_view) {
|
||||
return InferFromGraphViewT(graph_view);
|
||||
}
|
||||
|
||||
Status FrameView::InferFromGraphView(
|
||||
const utils::MutableGraphView& graph_view) {
|
||||
return InferFromGraphViewT(graph_view);
|
||||
}
|
||||
|
||||
Status FrameView::InferFromGraph(const GraphDef& graph) {
|
||||
return InferFromGraphView(GraphView(&graph));
|
||||
Status status;
|
||||
utils::GraphView graph_view(&graph, &status);
|
||||
TF_RETURN_IF_ERROR(status);
|
||||
return InferFromGraphViewT(graph_view);
|
||||
}
|
||||
|
||||
const std::vector<int>& FrameView::Frames(const NodeDef& node) const {
|
||||
|
@ -16,10 +16,9 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_
|
||||
|
||||
#include <unordered_map>
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/grappler/graph_view.h"
|
||||
#include "tensorflow/core/grappler/utils/graph_view.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -40,7 +39,10 @@ class FrameView {
|
||||
|
||||
// Infers nodes execution frames from the GraphView. Returns an error if
|
||||
// called multiple times.
|
||||
Status InferFromGraphView(const GraphView& graph_view);
|
||||
Status InferFromGraphView(const utils::GraphView& graph_view);
|
||||
// Infers nodes execution frames from the MutableGraphView. Returns an error
|
||||
// if called multiple times.
|
||||
Status InferFromGraphView(const utils::MutableGraphView& graph_view);
|
||||
// Infers nodes execution by constructing temporary GraphView and passing it
|
||||
// to InferFromGraphView.
|
||||
Status InferFromGraph(const GraphDef& graph);
|
||||
@ -56,6 +58,9 @@ class FrameView {
|
||||
bool is_inferred() const { return is_inferred_; }
|
||||
|
||||
private:
|
||||
template <typename GraphViewT>
|
||||
inline Status InferFromGraphViewT(const GraphViewT& graph_view);
|
||||
|
||||
bool is_inferred_; // true if it was inferred from the graph
|
||||
int num_frames_; // number of frames present in a graph
|
||||
absl::flat_hash_map<const NodeDef*, std::vector<int>> node_to_frames_;
|
||||
|
@ -14,8 +14,11 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/utils/frame.h"
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/utils/graph_view.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
@ -23,19 +26,23 @@ namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
using GraphTypes =
|
||||
::testing::Types<GraphDef, utils::GraphView, utils::MutableGraphView>;
|
||||
|
||||
template <typename T>
|
||||
class FrameViewTest : public ::testing::Test {
|
||||
protected:
|
||||
static NodeDef CreateNode(const string& name,
|
||||
const std::vector<string>& inputs) {
|
||||
NodeDef CreateNode(const string& name, const std::vector<string>& inputs) {
|
||||
return CreateNode(name, "", "", inputs);
|
||||
}
|
||||
static NodeDef CreateNode(const string& name, const string& op,
|
||||
const std::vector<string>& inputs) {
|
||||
|
||||
NodeDef CreateNode(const string& name, const string& op,
|
||||
const std::vector<string>& inputs) {
|
||||
return CreateNode(name, op, "", inputs);
|
||||
}
|
||||
static NodeDef CreateNode(const string& name, const string& op,
|
||||
const string& frame,
|
||||
const std::vector<string>& inputs) {
|
||||
|
||||
NodeDef CreateNode(const string& name, const string& op, const string& frame,
|
||||
const std::vector<string>& inputs) {
|
||||
NodeDef node;
|
||||
node.set_name(name);
|
||||
if (!op.empty()) {
|
||||
@ -53,30 +60,56 @@ class FrameViewTest : public ::testing::Test {
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(FrameViewTest, NestedLoop) {
|
||||
TYPED_TEST_SUITE(FrameViewTest, GraphTypes);
|
||||
|
||||
template <typename T>
|
||||
void InferFromGraph(FrameView* frame_view, GraphDef* graph, bool valid) {
|
||||
Status status;
|
||||
T graph_view(graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
status = frame_view->InferFromGraphView(graph_view);
|
||||
if (valid) {
|
||||
TF_ASSERT_OK(status);
|
||||
} else {
|
||||
ASSERT_FALSE(status.ok());
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void InferFromGraph<GraphDef>(FrameView* frame_view, GraphDef* graph,
|
||||
bool valid) {
|
||||
Status status = frame_view->InferFromGraph(*graph);
|
||||
if (valid) {
|
||||
TF_ASSERT_OK(status);
|
||||
} else {
|
||||
ASSERT_FALSE(status.ok());
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST(FrameViewTest, NestedLoop) {
|
||||
GraphDef graph;
|
||||
// Create a two-level nested loop
|
||||
*graph.add_node() = CreateNode("0", {});
|
||||
*graph.add_node() = CreateNode("1", "Enter", "while/context1", {"0"});
|
||||
*graph.add_node() = CreateNode("2", {"1"});
|
||||
*graph.add_node() = CreateNode("3", "Merge", {"2", "14"});
|
||||
*graph.add_node() = CreateNode("4", {"3"});
|
||||
*graph.add_node() = CreateNode("5", "Switch", {"4"});
|
||||
*graph.add_node() = CreateNode("6", {"5"});
|
||||
*graph.add_node() = CreateNode("7", "Enter", "while/context2", {"6"});
|
||||
*graph.add_node() = CreateNode("8", {"7"});
|
||||
*graph.add_node() = CreateNode("9", "Merge", {"8", "12"});
|
||||
*graph.add_node() = CreateNode("10", {"9"});
|
||||
*graph.add_node() = CreateNode("11", "Switch", {"10"});
|
||||
*graph.add_node() = CreateNode("12", "NextIteration", {"11"});
|
||||
*graph.add_node() = CreateNode("13", "Exit", {"11"});
|
||||
*graph.add_node() = CreateNode("14", "NextIteration", {"13"});
|
||||
*graph.add_node() = CreateNode("15", {"5"});
|
||||
*graph.add_node() = CreateNode("16", "Exit", {"15"});
|
||||
*graph.add_node() = CreateNode("17", {"16"});
|
||||
*graph.add_node() = this->CreateNode("0", {});
|
||||
*graph.add_node() = this->CreateNode("1", "Enter", "while/context1", {"0"});
|
||||
*graph.add_node() = this->CreateNode("2", {"1"});
|
||||
*graph.add_node() = this->CreateNode("3", "Merge", {"2", "14"});
|
||||
*graph.add_node() = this->CreateNode("4", {"3"});
|
||||
*graph.add_node() = this->CreateNode("5", "Switch", {"4"});
|
||||
*graph.add_node() = this->CreateNode("6", {"5"});
|
||||
*graph.add_node() = this->CreateNode("7", "Enter", "while/context2", {"6"});
|
||||
*graph.add_node() = this->CreateNode("8", {"7"});
|
||||
*graph.add_node() = this->CreateNode("9", "Merge", {"8", "12"});
|
||||
*graph.add_node() = this->CreateNode("10", {"9"});
|
||||
*graph.add_node() = this->CreateNode("11", "Switch", {"10"});
|
||||
*graph.add_node() = this->CreateNode("12", "NextIteration", {"11"});
|
||||
*graph.add_node() = this->CreateNode("13", "Exit", {"11"});
|
||||
*graph.add_node() = this->CreateNode("14", "NextIteration", {"13"});
|
||||
*graph.add_node() = this->CreateNode("15", {"5"});
|
||||
*graph.add_node() = this->CreateNode("16", "Exit", {"15"});
|
||||
*graph.add_node() = this->CreateNode("17", {"16"});
|
||||
|
||||
FrameView frame_view;
|
||||
ASSERT_TRUE(frame_view.InferFromGraph(graph).ok());
|
||||
InferFromGraph<TypeParam>(&frame_view, &graph, /*valid=*/true);
|
||||
|
||||
std::unordered_map<string, std::vector<int>> expected = {
|
||||
{"0", {}}, {"1", {0}}, {"2", {0}}, {"3", {0}},
|
||||
@ -93,15 +126,16 @@ TEST_F(FrameViewTest, NestedLoop) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FrameViewTest, MultipleInputsToEnter) {
|
||||
TYPED_TEST(FrameViewTest, MultipleInputsToEnter) {
|
||||
GraphDef graph;
|
||||
*graph.add_node() = CreateNode("0", {});
|
||||
*graph.add_node() = CreateNode("1", {});
|
||||
*graph.add_node() = CreateNode("2", "Enter", "while/context", {"0", "1"});
|
||||
*graph.add_node() = CreateNode("3", "Exit", {"2"});
|
||||
*graph.add_node() = this->CreateNode("0", {});
|
||||
*graph.add_node() = this->CreateNode("1", {});
|
||||
*graph.add_node() =
|
||||
this->CreateNode("2", "Enter", "while/context", {"0", "1"});
|
||||
*graph.add_node() = this->CreateNode("3", "Exit", {"2"});
|
||||
|
||||
FrameView frame_view;
|
||||
ASSERT_TRUE(frame_view.InferFromGraph(graph).ok());
|
||||
InferFromGraph<TypeParam>(&frame_view, &graph, /*valid=*/true);
|
||||
|
||||
std::unordered_map<string, std::vector<int>> expected = {
|
||||
{"0", {}}, {"1", {}}, {"2", {0}}, {"3", {0}}};
|
||||
@ -114,16 +148,16 @@ TEST_F(FrameViewTest, MultipleInputsToEnter) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FrameViewTest, ExitOutput) {
|
||||
TYPED_TEST(FrameViewTest, ExitOutput) {
|
||||
GraphDef graph;
|
||||
*graph.add_node() = CreateNode("0", {});
|
||||
*graph.add_node() = CreateNode("1", "Enter", "while/context", {"0"});
|
||||
*graph.add_node() = CreateNode("2", "Exit", {"1"});
|
||||
*graph.add_node() = CreateNode("3", {});
|
||||
*graph.add_node() = CreateNode("4", {"2", "3"});
|
||||
*graph.add_node() = this->CreateNode("0", {});
|
||||
*graph.add_node() = this->CreateNode("1", "Enter", "while/context", {"0"});
|
||||
*graph.add_node() = this->CreateNode("2", "Exit", {"1"});
|
||||
*graph.add_node() = this->CreateNode("3", {});
|
||||
*graph.add_node() = this->CreateNode("4", {"2", "3"});
|
||||
|
||||
FrameView frame_view;
|
||||
ASSERT_TRUE(frame_view.InferFromGraph(graph).ok());
|
||||
InferFromGraph<TypeParam>(&frame_view, &graph, /*valid=*/true);
|
||||
|
||||
std::unordered_map<string, std::vector<int>> expected = {
|
||||
{"0", {}}, {"1", {0}}, {"2", {0}}, {"3", {}}, {"4", {}}};
|
||||
@ -136,21 +170,21 @@ TEST_F(FrameViewTest, ExitOutput) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FrameViewTest, MultipleEnterNodes) {
|
||||
TYPED_TEST(FrameViewTest, MultipleEnterNodes) {
|
||||
GraphDef graph;
|
||||
*graph.add_node() = CreateNode("0", {});
|
||||
*graph.add_node() = CreateNode("1", "Enter", "while/context", {"0"});
|
||||
*graph.add_node() = CreateNode("2", {"1"});
|
||||
*graph.add_node() = CreateNode("5", {});
|
||||
*graph.add_node() = CreateNode("4", "Enter", "while/context", {"5"});
|
||||
*graph.add_node() = CreateNode("3", {"4", "2"});
|
||||
*graph.add_node() = CreateNode("6", "Merge", {"3", "8"});
|
||||
*graph.add_node() = CreateNode("7", "Switch", {"6"});
|
||||
*graph.add_node() = CreateNode("8", "NextIteration", {"7"});
|
||||
*graph.add_node() = CreateNode("9", "Exit", {"7"});
|
||||
*graph.add_node() = this->CreateNode("0", {});
|
||||
*graph.add_node() = this->CreateNode("1", "Enter", "while/context", {"0"});
|
||||
*graph.add_node() = this->CreateNode("2", {"1"});
|
||||
*graph.add_node() = this->CreateNode("5", {});
|
||||
*graph.add_node() = this->CreateNode("4", "Enter", "while/context", {"5"});
|
||||
*graph.add_node() = this->CreateNode("3", {"4", "2"});
|
||||
*graph.add_node() = this->CreateNode("6", "Merge", {"3", "8"});
|
||||
*graph.add_node() = this->CreateNode("7", "Switch", {"6"});
|
||||
*graph.add_node() = this->CreateNode("8", "NextIteration", {"7"});
|
||||
*graph.add_node() = this->CreateNode("9", "Exit", {"7"});
|
||||
|
||||
FrameView frame_view;
|
||||
ASSERT_TRUE(frame_view.InferFromGraph(graph).ok());
|
||||
InferFromGraph<TypeParam>(&frame_view, &graph, /*valid=*/true);
|
||||
|
||||
std::unordered_map<string, std::vector<int>> expected = {
|
||||
{"0", {}}, {"1", {0}}, {"2", {0}}, {"3", {0}}, {"4", {0}},
|
||||
@ -164,15 +198,15 @@ TEST_F(FrameViewTest, MultipleEnterNodes) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FrameViewTest, ConflictingFrames) {
|
||||
TYPED_TEST(FrameViewTest, ConflictingFrames) {
|
||||
GraphDef graph;
|
||||
*graph.add_node() = CreateNode("0", {});
|
||||
*graph.add_node() = CreateNode("1", "Enter", "while/context1", {"0"});
|
||||
*graph.add_node() = CreateNode("2", "Enter", "while/context2", {"1"});
|
||||
*graph.add_node() = CreateNode("3", {"1", "2"});
|
||||
*graph.add_node() = this->CreateNode("0", {});
|
||||
*graph.add_node() = this->CreateNode("1", "Enter", "while/context1", {"0"});
|
||||
*graph.add_node() = this->CreateNode("2", "Enter", "while/context2", {"1"});
|
||||
*graph.add_node() = this->CreateNode("3", {"1", "2"});
|
||||
|
||||
FrameView frame_view;
|
||||
ASSERT_FALSE(frame_view.InferFromGraph(graph).ok());
|
||||
InferFromGraph<TypeParam>(&frame_view, &graph, /*valid=*/false);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user