diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index 944c676d0a9..ebe96d5dbd6 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -671,32 +671,55 @@ Status GraphExecutionState::OptimizeGraph( if (!(options.callable_options.feed().empty() && options.callable_options.tensor_connection().empty())) { - std::unordered_set feeds; + std::vector feeds; + for (const string& feed : options.callable_options.feed()) { - TensorId id = ParseTensorName(feed); - if (id.second != 0) { - return errors::InvalidArgument("Unsupported feed: ", feed); - } - feeds.emplace(id.first); + feeds.emplace_back(ParseTensorName(feed)); } for (const TensorConnection& tensor_connection : options.callable_options.tensor_connection()) { - TensorId id = ParseTensorName(tensor_connection.to_tensor()); - if (id.second != 0) { - return errors::InvalidArgument("Unsupported feed: ", - tensor_connection.to_tensor()); - } - feeds.emplace(id.first); + feeds.emplace_back(ParseTensorName(tensor_connection.to_tensor())); } - for (const Node* node : graph_->nodes()) { - if (feeds.find(node->name()) == feeds.end()) { + + // For feeds with tensor index 0 we try to find the corresponding node in + // the graph to infer feed data type and shape. + std::unordered_set feed_nodes; + + // For feeds with tensor index larger than 0, we can't infer data type or + // shape from the graph. Currently we only support type and shape + // inference from a small set of node types: Placeholder, Const, etc... 
+ for (const SafeTensorId& feed : feeds) { + if (feed.index() > 0) { + VLOG(3) << "Add undefined feed for: " << feed.ToString(); + Tensor fake_input(DT_INVALID, {0}); + item.feed.emplace_back(feed.ToString(), fake_input); + } else { + VLOG(3) << "Add node for feed inference: " << feed.ToString(); + feed_nodes.insert(feed.node()); continue; } - // Get the type and shape of the feed node. + } + + // For feeds with tensor index == 0 we try to infer data type and tensor + // shape from the graph, by looking at the fed node attributes. + for (const Node* node : graph_->nodes()) { + if (feed_nodes.find(node->name()) == feed_nodes.end()) continue; + + // Try to get the type and shape of the feed node. PartialTensorShape partial_shape; DataType type; - TF_RETURN_IF_ERROR(GetFeedShapeAndTypeFromAttribute( - node->def(), &partial_shape, &type)); + Status st = GetFeedShapeAndTypeFromAttribute(node->def(), + &partial_shape, &type); + + // Failed to get type and shape of the feed node. + if (!st.ok()) { + VLOG(3) << "Failed to infer feed node type and shape." + << " Add undefined feed for: " << node->name(); + Tensor fake_input(DT_INVALID, {0}); + item.feed.emplace_back(node->name(), fake_input); + continue; + } + // If the shape of the placeholder is only partially known, we are free // to set unknown dimensions of its shape to any value we desire. We // choose 0 to minimize the memory impact. 
Note that this only matters @@ -717,6 +740,8 @@ Status GraphExecutionState::OptimizeGraph( } } + VLOG(3) << "Add feed for: " << node->name() << "; type: " << type + << "; shape: " << shape; Tensor fake_input(type, shape); item.feed.emplace_back(node->name(), fake_input); } diff --git a/tensorflow/core/grappler/costs/graph_memory_test.cc b/tensorflow/core/grappler/costs/graph_memory_test.cc index 95170ba49b7..bcb20098575 100644 --- a/tensorflow/core/grappler/costs/graph_memory_test.cc +++ b/tensorflow/core/grappler/costs/graph_memory_test.cc @@ -59,13 +59,13 @@ TEST_F(GraphMemoryTest, Basic) { for (const auto& t : mem_usage.live_tensors) { tensors.insert(strings::StrCat(t.node, ":", t.output_id)); } - // When the execution of the 'Square' node completes, TF can start executing - // 'Square_1' and release the memory used by 'x'. Since we can't be sure of + // When the execution of the 'Sign' node completes, TF can start executing + // 'Sign_1' and release the memory used by 'x'. Since we can't be sure of // the order in which this takes place, in the worst case the 3 tensors are in // memory. 
std::set expected; - expected.insert("Square:0"); - expected.insert("Square_1:0"); + expected.insert("Sign:0"); + expected.insert("Sign_1:0"); expected.insert("x:0"); EXPECT_EQ(expected, tensors); } @@ -91,7 +91,7 @@ TEST_F(GraphMemoryTest, UnknownBatchSize) { } std::set expected; expected.insert("Const/Const:0"); - expected.insert("Square:0"); + expected.insert("Sign:0"); expected.insert("x:0"); EXPECT_EQ(expected, tensors); } @@ -114,8 +114,8 @@ TEST_F(GraphMemoryTest, MultiDevice) { cpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id)); } std::set cpu_expected; - cpu_expected.insert("Recv_Square_1_0_on_/CPU_0:0"); - cpu_expected.insert("Square:0"); + cpu_expected.insert("Recv_Sign_1_0_on_/CPU_0:0"); + cpu_expected.insert("Sign:0"); cpu_expected.insert("x:0"); cpu_expected.insert("AddN:0"); EXPECT_EQ(cpu_expected, cpu_tensors); @@ -128,7 +128,7 @@ TEST_F(GraphMemoryTest, MultiDevice) { } std::set gpu_expected; gpu_expected.insert("Recv_AddN_0_on_/GPU_0:0"); - gpu_expected.insert("Square_1:0"); + gpu_expected.insert("Sign_1:0"); gpu_expected.insert("AddN_1:0"); gpu_expected.insert("AddN_3:0"); EXPECT_EQ(gpu_expected, gpu_tensors); @@ -154,8 +154,8 @@ TEST_F(GraphMemoryTest, GpuSwapping) { gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id)); } std::set gpu_expected; - gpu_expected.insert("Square:0"); - gpu_expected.insert("Square_1:0"); + gpu_expected.insert("Sign:0"); + gpu_expected.insert("Sign_1:0"); gpu_expected.insert("AddN:0"); gpu_expected.insert("AddN_1:0"); gpu_expected.insert("AddN_2:0"); diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc index 7be98dc43b4..fd59b7a167a 100644 --- a/tensorflow/core/grappler/graph_view_test.cc +++ b/tensorflow/core/grappler/graph_view_test.cc @@ -134,14 +134,14 @@ TEST_F(GraphViewTest, BasicGraph) { EXPECT_EQ(input.node->name(), "AddN"); EXPECT_EQ(input.port_id, 0); GraphView::OutputPort fanin = graph.GetRegularFanin(input); - 
EXPECT_EQ(fanin.node->name(), "Square"); + EXPECT_EQ(fanin.node->name(), "Sign"); EXPECT_EQ(fanin.port_id, 0); input = graph.GetInputPort("AddN", 1); EXPECT_EQ(input.node->name(), "AddN"); EXPECT_EQ(input.port_id, 1); fanin = graph.GetRegularFanin(input); - EXPECT_EQ(fanin.node->name(), "Square_1"); + EXPECT_EQ(fanin.node->name(), "Sign_1"); EXPECT_EQ(fanin.port_id, 0); GraphView::OutputPort output = graph.GetOutputPort("AddN", 0); @@ -169,7 +169,7 @@ TEST_F(GraphViewTest, BasicGraph) { EXPECT_EQ(fanouts, expected_fanouts); absl::flat_hash_set fanins; - absl::flat_hash_set expected_fanins = {"Square_1:0", "Square:0"}; + absl::flat_hash_set expected_fanins = {"Sign_1:0", "Sign:0"}; for (const auto& fi : graph.GetFanins(*add_node, false)) { fanins.insert(absl::StrCat(fi.node->name(), ":", fi.port_id)); } diff --git a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc index ec54bd5c759..9ce0284369a 100644 --- a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc +++ b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc @@ -47,11 +47,11 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size, std::vector this_stage; for (int j = 0; j < width; j++) { if (last_stage.size() == 1) { - Output unary_op = Square( - s.WithDevice( - device_names[use_multiple_devices ? j % device_names.size() - : 0]), - last_stage[0]); + Output unary_op = + Sign(s.WithDevice( + device_names[use_multiple_devices ? 
j % device_names.size() + : 0]), + last_stage[0]); this_stage.push_back(unary_op); } else { Output combine = diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 8be53aa08e3..520346b0166 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2253,7 +2253,7 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage { ~FoldTransposeIntoMatMul() override = default; bool IsSupported(const NodeDef* node) const override { - return IsAnyMatMul(*node); + return IsAnyMatMul(*node) && !IsInPreserveSet(*node); } Status TrySimplify(NodeDef* node, string* simplified_node_name) override { diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 477c284d44c..8b403b17841 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -736,7 +736,7 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) { auto identity = ops::Identity(s.WithOpName("identity"), matmul); GrapplerItem item; - item.fetch = {"matmul"}; + item.fetch = {"identity"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); @@ -795,9 +795,10 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) { Output trans_a = ops::ConjugateTranspose(s.WithOpName("trans_a"), a, perm); Output trans_b = ops::ConjugateTranspose(s.WithOpName("trans_b"), b, perm); Output matmul = ops::BatchMatMul(s.WithOpName("matmul"), trans_a, trans_b); + Output identity = ops::Identity(s.WithOpName("identity"), matmul); GrapplerItem item; - item.fetch = {"matmul"}; + item.fetch = {"identity"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); @@ -808,7 +809,7 
@@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) { OptimizeTwice(&optimizer, &item, &output); NodeMap node_map(&output); - EXPECT_EQ(output.node_size(), 11); + EXPECT_EQ(output.node_size(), 12); const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul"; const string optimized_name = absl::StrCat(p, "_", "matmul"); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index b4ebc888c30..a0ec3714070 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -2336,6 +2336,13 @@ bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) { node_map_->NodeExists(axis_node_name)) { return false; } + + // It's unsafe to add a control dependency on the feed node, because it might + have been never executed otherwise. + if (feed_nodes_.find(NodeName(node->input(0))) != feed_nodes_.end()) { + return false; + } + // Create constant axis node. 
Tensor axis_t(DT_INT32, TensorShape({})); const int axis = diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc index 840d0d17068..b0478525d39 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc @@ -280,6 +280,13 @@ class FunctionOptimizerContext { const GraphView& graph_view() const { return graph_view_; } + bool IsFeedNode(const string& node_name) const { + return absl::c_any_of( + item_->feed, [&](const std::pair& feed) { + return ParseTensorName(feed.first).node() == node_name; + }); + } + bool IsFetchNode(const string& node_name) const { return absl::c_any_of(item_->fetch, [&](const string& fetch) { return ParseTensorName(fetch).node() == node_name; @@ -1445,9 +1452,9 @@ Status FunctionOptimizer::RunFunctionOptimizerPass( // Do not specialize if function has custom gradient or marked nospecialize. const string grad_func = ctx.function_library().FindGradient(func_name); - const bool no_specialize = !grad_func.empty() || - MarkedNoSpecialize(*func) || - MarkedForXlaCompilation(node); + const bool no_specialize = + !grad_func.empty() || ctx.IsFeedNode(node.name()) || + MarkedNoSpecialize(*func) || MarkedForXlaCompilation(node); if (specialization_worthy && !no_specialize) { // TODO(ezhulenev): Specialize function call if input has a known shape. 
diff --git a/tensorflow/core/grappler/optimizers/static_schedule_test.cc b/tensorflow/core/grappler/optimizers/static_schedule_test.cc index d766cfdeee3..e8ee7db16a6 100644 --- a/tensorflow/core/grappler/optimizers/static_schedule_test.cc +++ b/tensorflow/core/grappler/optimizers/static_schedule_test.cc @@ -111,8 +111,8 @@ TEST_F(StaticScheduleTest, BasicGraph) { std::vector ordered_node_names = GetOrderedNodeNames(completion_times); EXPECT_EQ(ordered_node_names, - (std::vector{"Const/Const", "x", "Square", "Square_1", - "Square_2", "Square_3", "y"})); + (std::vector{"Const/Const", "x", "Sign", "Sign_1", + "Sign_2", "Sign_3", "y"})); } TEST_F(StaticScheduleTest, BasicGraphWithCtrlDependencies) { @@ -192,8 +192,8 @@ TEST_F(StaticScheduleTest, RequiredTimes) { std::vector ordered_node_names = GetOrderedNodeNames(required_times); EXPECT_EQ(ordered_node_names, - (std::vector{"Const/Const", "x", "Square", "Square_1", - "Square_2", "Square_3", "y"})); + (std::vector{"Const/Const", "x", "Sign", "Sign_1", + "Sign_2", "Sign_3", "y"})); } } // namespace