[Grappler] Allow feeding non-zero index tensors to Grappler optimizer

PiperOrigin-RevId: 308665214
Change-Id: I3b449e1dd06f5134de358466227734a8b17673f4
This commit is contained in:
Eugene Zhulenev 2020-04-27 11:42:19 -07:00 committed by TensorFlower Gardener
parent c18a2ad398
commit f976ccd74d
9 changed files with 86 additions and 46 deletions

View File

@ -671,32 +671,55 @@ Status GraphExecutionState::OptimizeGraph(
if (!(options.callable_options.feed().empty() &&
options.callable_options.tensor_connection().empty())) {
std::unordered_set<string> feeds;
std::vector<SafeTensorId> feeds;
for (const string& feed : options.callable_options.feed()) {
TensorId id = ParseTensorName(feed);
if (id.second != 0) {
return errors::InvalidArgument("Unsupported feed: ", feed);
}
feeds.emplace(id.first);
feeds.emplace_back(ParseTensorName(feed));
}
for (const TensorConnection& tensor_connection :
options.callable_options.tensor_connection()) {
TensorId id = ParseTensorName(tensor_connection.to_tensor());
if (id.second != 0) {
return errors::InvalidArgument("Unsupported feed: ",
tensor_connection.to_tensor());
}
feeds.emplace(id.first);
feeds.emplace_back(ParseTensorName(tensor_connection.to_tensor()));
}
for (const Node* node : graph_->nodes()) {
if (feeds.find(node->name()) == feeds.end()) {
// For feeds with tensor index 0 we try to find the corresponding node in
// the graph to infer feed data type and shape.
std::unordered_set<std::string> feed_nodes;
// For feeds with tensor index larger than 0, we can't infer data type or
// shape from the graph. Currently we only support type and shape
// inference from a small set of node types: Placeholder, Const, etc...
for (const SafeTensorId& feed : feeds) {
if (feed.index() > 0) {
VLOG(3) << "Add undefined feed for: " << feed.ToString();
Tensor fake_input(DT_INVALID, {0});
item.feed.emplace_back(feed.ToString(), fake_input);
} else {
VLOG(3) << "Add node for feed inference: " << feed.ToString();
feed_nodes.insert(feed.node());
continue;
}
// Get the type and shape of the feed node.
}
// For feeds with tensor index == 0 we try to infer data type and tensor
// shape from the graph, by looking at the fed node attributes.
for (const Node* node : graph_->nodes()) {
if (feed_nodes.find(node->name()) == feed_nodes.end()) continue;
// Try to get the type and shape of the feed node.
PartialTensorShape partial_shape;
DataType type;
TF_RETURN_IF_ERROR(GetFeedShapeAndTypeFromAttribute(
node->def(), &partial_shape, &type));
Status st = GetFeedShapeAndTypeFromAttribute(node->def(),
&partial_shape, &type);
// Failed to get type and shape of the feed node.
if (!st.ok()) {
VLOG(3) << "Failed to infer feed node type and shape."
<< " Add undefined feed for: " << node->name();
Tensor fake_input(DT_INVALID, {0});
item.feed.emplace_back(node->name(), fake_input);
continue;
}
// If the shape of the placeholder is only partially known, we are free
// to set unknown dimensions of its shape to any value we desire. We
// choose 0 to minimize the memory impact. Note that this only matters
@ -717,6 +740,8 @@ Status GraphExecutionState::OptimizeGraph(
}
}
VLOG(3) << "Add feed for: " << node->name() << "; type: " << type
<< "; shape: " << shape;
Tensor fake_input(type, shape);
item.feed.emplace_back(node->name(), fake_input);
}

View File

@ -59,13 +59,13 @@ TEST_F(GraphMemoryTest, Basic) {
for (const auto& t : mem_usage.live_tensors) {
tensors.insert(strings::StrCat(t.node, ":", t.output_id));
}
// When the execution of the 'Square' node completes, TF can start executing
// 'Square_1' and release the memory used by 'x'. Since we can't be sure of
// When the execution of the 'Sign' node completes, TF can start executing
// 'Sign_1' and release the memory used by 'x'. Since we can't be sure of
// the order in which this takes place, in the worst case the 3 tensors are in
// memory.
std::set<string> expected;
expected.insert("Square:0");
expected.insert("Square_1:0");
expected.insert("Sign:0");
expected.insert("Sign_1:0");
expected.insert("x:0");
EXPECT_EQ(expected, tensors);
}
@ -91,7 +91,7 @@ TEST_F(GraphMemoryTest, UnknownBatchSize) {
}
std::set<string> expected;
expected.insert("Const/Const:0");
expected.insert("Square:0");
expected.insert("Sign:0");
expected.insert("x:0");
EXPECT_EQ(expected, tensors);
}
@ -114,8 +114,8 @@ TEST_F(GraphMemoryTest, MultiDevice) {
cpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
}
std::set<string> cpu_expected;
cpu_expected.insert("Recv_Square_1_0_on_/CPU_0:0");
cpu_expected.insert("Square:0");
cpu_expected.insert("Recv_Sign_1_0_on_/CPU_0:0");
cpu_expected.insert("Sign:0");
cpu_expected.insert("x:0");
cpu_expected.insert("AddN:0");
EXPECT_EQ(cpu_expected, cpu_tensors);
@ -128,7 +128,7 @@ TEST_F(GraphMemoryTest, MultiDevice) {
}
std::set<string> gpu_expected;
gpu_expected.insert("Recv_AddN_0_on_/GPU_0:0");
gpu_expected.insert("Square_1:0");
gpu_expected.insert("Sign_1:0");
gpu_expected.insert("AddN_1:0");
gpu_expected.insert("AddN_3:0");
EXPECT_EQ(gpu_expected, gpu_tensors);
@ -154,8 +154,8 @@ TEST_F(GraphMemoryTest, GpuSwapping) {
gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
}
std::set<string> gpu_expected;
gpu_expected.insert("Square:0");
gpu_expected.insert("Square_1:0");
gpu_expected.insert("Sign:0");
gpu_expected.insert("Sign_1:0");
gpu_expected.insert("AddN:0");
gpu_expected.insert("AddN_1:0");
gpu_expected.insert("AddN_2:0");

View File

@ -134,14 +134,14 @@ TEST_F(GraphViewTest, BasicGraph) {
EXPECT_EQ(input.node->name(), "AddN");
EXPECT_EQ(input.port_id, 0);
GraphView::OutputPort fanin = graph.GetRegularFanin(input);
EXPECT_EQ(fanin.node->name(), "Square");
EXPECT_EQ(fanin.node->name(), "Sign");
EXPECT_EQ(fanin.port_id, 0);
input = graph.GetInputPort("AddN", 1);
EXPECT_EQ(input.node->name(), "AddN");
EXPECT_EQ(input.port_id, 1);
fanin = graph.GetRegularFanin(input);
EXPECT_EQ(fanin.node->name(), "Square_1");
EXPECT_EQ(fanin.node->name(), "Sign_1");
EXPECT_EQ(fanin.port_id, 0);
GraphView::OutputPort output = graph.GetOutputPort("AddN", 0);
@ -169,7 +169,7 @@ TEST_F(GraphViewTest, BasicGraph) {
EXPECT_EQ(fanouts, expected_fanouts);
absl::flat_hash_set<string> fanins;
absl::flat_hash_set<string> expected_fanins = {"Square_1:0", "Square:0"};
absl::flat_hash_set<string> expected_fanins = {"Sign_1:0", "Sign:0"};
for (const auto& fi : graph.GetFanins(*add_node, false)) {
fanins.insert(absl::StrCat(fi.node->name(), ":", fi.port_id));
}

View File

@ -47,11 +47,11 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
std::vector<Output> this_stage;
for (int j = 0; j < width; j++) {
if (last_stage.size() == 1) {
Output unary_op = Square(
s.WithDevice(
device_names[use_multiple_devices ? j % device_names.size()
: 0]),
last_stage[0]);
Output unary_op =
Sign(s.WithDevice(
device_names[use_multiple_devices ? j % device_names.size()
: 0]),
last_stage[0]);
this_stage.push_back(unary_op);
} else {
Output combine =

View File

@ -2253,7 +2253,7 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
~FoldTransposeIntoMatMul() override = default;
bool IsSupported(const NodeDef* node) const override {
return IsAnyMatMul(*node);
return IsAnyMatMul(*node) && !IsInPreserveSet(*node);
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {

View File

@ -736,7 +736,7 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
auto identity = ops::Identity(s.WithOpName("identity"), matmul);
GrapplerItem item;
item.fetch = {"matmul"};
item.fetch = {"identity"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
@ -795,9 +795,10 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
Output trans_a = ops::ConjugateTranspose(s.WithOpName("trans_a"), a, perm);
Output trans_b = ops::ConjugateTranspose(s.WithOpName("trans_b"), b, perm);
Output matmul = ops::BatchMatMul(s.WithOpName("matmul"), trans_a, trans_b);
Output identity = ops::Identity(s.WithOpName("identity"), matmul);
GrapplerItem item;
item.fetch = {"matmul"};
item.fetch = {"identity"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
@ -808,7 +809,7 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
OptimizeTwice(&optimizer, &item, &output);
NodeMap node_map(&output);
EXPECT_EQ(output.node_size(), 11);
EXPECT_EQ(output.node_size(), 12);
const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
const string optimized_name = absl::StrCat(p, "_", "matmul");

View File

@ -2336,6 +2336,13 @@ bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
node_map_->NodeExists(axis_node_name)) {
return false;
}
// It's unsafe to add a control dependency on the feed node, because it might
// never have been executed otherwise.
if (feed_nodes_.find(NodeName(node->input(0))) != feed_nodes_.end()) {
return false;
}
// Create constant axis node.
Tensor axis_t(DT_INT32, TensorShape({}));
const int axis =

View File

@ -280,6 +280,13 @@ class FunctionOptimizerContext {
const GraphView& graph_view() const { return graph_view_; }
bool IsFeedNode(const string& node_name) const {
return absl::c_any_of(
item_->feed, [&](const std::pair<std::string, Tensor>& feed) {
return ParseTensorName(feed.first).node() == node_name;
});
}
bool IsFetchNode(const string& node_name) const {
return absl::c_any_of(item_->fetch, [&](const string& fetch) {
return ParseTensorName(fetch).node() == node_name;
@ -1445,9 +1452,9 @@ Status FunctionOptimizer::RunFunctionOptimizerPass(
// Do not specialize if function has custom gradient or marked nospecialize.
const string grad_func = ctx.function_library().FindGradient(func_name);
const bool no_specialize = !grad_func.empty() ||
MarkedNoSpecialize(*func) ||
MarkedForXlaCompilation(node);
const bool no_specialize =
!grad_func.empty() || ctx.IsFeedNode(node.name()) ||
MarkedNoSpecialize(*func) || MarkedForXlaCompilation(node);
if (specialization_worthy && !no_specialize) {
// TODO(ezhulenev): Specialize function call if input has a known shape.

View File

@ -111,8 +111,8 @@ TEST_F(StaticScheduleTest, BasicGraph) {
std::vector<std::string> ordered_node_names =
GetOrderedNodeNames(completion_times);
EXPECT_EQ(ordered_node_names,
(std::vector<std::string>{"Const/Const", "x", "Square", "Square_1",
"Square_2", "Square_3", "y"}));
(std::vector<std::string>{"Const/Const", "x", "Sign", "Sign_1",
"Sign_2", "Sign_3", "y"}));
}
TEST_F(StaticScheduleTest, BasicGraphWithCtrlDependencies) {
@ -192,8 +192,8 @@ TEST_F(StaticScheduleTest, RequiredTimes) {
std::vector<std::string> ordered_node_names =
GetOrderedNodeNames(required_times);
EXPECT_EQ(ordered_node_names,
(std::vector<std::string>{"Const/Const", "x", "Square", "Square_1",
"Square_2", "Square_3", "y"}));
(std::vector<std::string>{"Const/Const", "x", "Sign", "Sign_1",
"Sign_2", "Sign_3", "y"}));
}
} // namespace