[Grappler] Allow feeding non-zero index tensors to Grappler optimizer
PiperOrigin-RevId: 308665214
Change-Id: I3b449e1dd06f5134de358466227734a8b17673f4
This commit is contained in:
parent c18a2ad398
commit f976ccd74d
@@ -671,32 +671,55 @@ Status GraphExecutionState::OptimizeGraph(
   if (!(options.callable_options.feed().empty() &&
         options.callable_options.tensor_connection().empty())) {
-    std::unordered_set<string> feeds;
+    std::vector<SafeTensorId> feeds;
+
     for (const string& feed : options.callable_options.feed()) {
-      TensorId id = ParseTensorName(feed);
-      if (id.second != 0) {
-        return errors::InvalidArgument("Unsupported feed: ", feed);
-      }
-      feeds.emplace(id.first);
+      feeds.emplace_back(ParseTensorName(feed));
     }
     for (const TensorConnection& tensor_connection :
          options.callable_options.tensor_connection()) {
-      TensorId id = ParseTensorName(tensor_connection.to_tensor());
-      if (id.second != 0) {
-        return errors::InvalidArgument("Unsupported feed: ",
-                                       tensor_connection.to_tensor());
-      }
-      feeds.emplace(id.first);
+      feeds.emplace_back(ParseTensorName(tensor_connection.to_tensor()));
     }
-    for (const Node* node : graph_->nodes()) {
-      if (feeds.find(node->name()) == feeds.end()) {
+
+    // For feeds with tensor index 0 we try to find the corresponding node in
+    // the graph to infer feed data type and shape.
+    std::unordered_set<std::string> feed_nodes;
+
+    // For feeds with tensor index larger than 0, we can't infer data type or
+    // shape from the graph. Currently we only support type and shape
+    // inference from a small set of node types: Placeholder, Const, etc...
+    for (const SafeTensorId& feed : feeds) {
+      if (feed.index() > 0) {
+        VLOG(3) << "Add undefined feed for: " << feed.ToString();
+        Tensor fake_input(DT_INVALID, {0});
+        item.feed.emplace_back(feed.ToString(), fake_input);
+      } else {
+        VLOG(3) << "Add node for feed inference: " << feed.ToString();
+        feed_nodes.insert(feed.node());
         continue;
       }
-      // Get the type and shape of the feed node.
+    }
+
+    // For feeds with tensor index == 0 we try to infer data type and tensor
+    // shape from the graph, by looking at the fed node attributes.
+    for (const Node* node : graph_->nodes()) {
+      if (feed_nodes.find(node->name()) == feed_nodes.end()) continue;
+
+      // Try to get the type and shape of the feed node.
       PartialTensorShape partial_shape;
       DataType type;
-      TF_RETURN_IF_ERROR(GetFeedShapeAndTypeFromAttribute(
-          node->def(), &partial_shape, &type));
+      Status st = GetFeedShapeAndTypeFromAttribute(node->def(),
+                                                   &partial_shape, &type);
+
+      // Failed to get type and shape of the feed node.
+      if (!st.ok()) {
+        VLOG(3) << "Failed to infer feed node type and shape."
+                << " Add undefined feed for: " << node->name();
+        Tensor fake_input(DT_INVALID, {0});
+        item.feed.emplace_back(node->name(), fake_input);
+        continue;
+      }
+
       // If the shape of the placeholder is only partially known, we are free
       // to set unknown dimensions of its shape to any value we desire. We
       // choose 0 to minimize the memory impact. Note that this only matters
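
The structural change above replaces bare node names (std::unordered_set<string>) with full tensor IDs (std::vector<SafeTensorId>), so the output index survives parsing. A minimal sketch of the distinction, using the tensor_id.h helpers; the feed name here is hypothetical:

    // "queue_dequeue:1" names the *second* output of a node. Before this
    // change, any feed with a non-zero index was rejected with
    // InvalidArgument("Unsupported feed: ...").
    SafeTensorId feed(ParseTensorName("queue_dequeue:1"));
    CHECK_EQ(feed.node(), "queue_dequeue");
    CHECK_EQ(feed.index(), 1);
    // Because feed.index() > 0, OptimizeGraph now records an "undefined"
    // feed backed by a DT_INVALID scalar instead of failing; downstream
    // Grappler passes must treat its type and shape as unknown.
    Tensor fake_input(DT_INVALID, {0});
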
@@ -717,6 +740,8 @@ Status GraphExecutionState::OptimizeGraph(
         }
       }
 
+      VLOG(3) << "Add feed for: " << node->name() << "; type: " << type
+              << "; shape: " << shape;
       Tensor fake_input(type, shape);
       item.feed.emplace_back(node->name(), fake_input);
     }
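
From the session API side, the observable effect is that a callable may now feed a non-zero output slot. A hedged usage sketch (the node names are hypothetical):

    CallableOptions callable_options;
    callable_options.add_feed("my_op:1");  // second output of a multi-output node
    callable_options.add_fetch("loss:0");
    // Previously, GraphExecutionState::OptimizeGraph() returned
    // InvalidArgument("Unsupported feed: my_op:1") for this configuration,
    // and Grappler optimization was effectively skipped; now the feed
    // reaches Grappler as an undefined (DT_INVALID) input.
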
Note: this and the following test-expectation hunks are mechanical renames. The shared test-graph generator (see the CreateGraphDef hunk further down) now builds its unary stages from Sign instead of Square, so every Square/Square_N node name in these expectations becomes Sign/Sign_N.

@@ -59,13 +59,13 @@ TEST_F(GraphMemoryTest, Basic) {
   for (const auto& t : mem_usage.live_tensors) {
     tensors.insert(strings::StrCat(t.node, ":", t.output_id));
   }
-  // When the execution of the 'Square' node completes, TF can start executing
-  // 'Square_1' and release the memory used by 'x'. Since we can't be sure of
+  // When the execution of the 'Sign' node completes, TF can start executing
+  // 'Sign_1' and release the memory used by 'x'. Since we can't be sure of
   // the order in which this takes place, in the worst case the 3 tensors are in
   // memory.
   std::set<string> expected;
-  expected.insert("Square:0");
-  expected.insert("Square_1:0");
+  expected.insert("Sign:0");
+  expected.insert("Sign_1:0");
   expected.insert("x:0");
   EXPECT_EQ(expected, tensors);
 }
@@ -91,7 +91,7 @@ TEST_F(GraphMemoryTest, UnknownBatchSize) {
   }
   std::set<string> expected;
   expected.insert("Const/Const:0");
-  expected.insert("Square:0");
+  expected.insert("Sign:0");
   expected.insert("x:0");
   EXPECT_EQ(expected, tensors);
 }
@@ -114,8 +114,8 @@ TEST_F(GraphMemoryTest, MultiDevice) {
     cpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
   }
   std::set<string> cpu_expected;
-  cpu_expected.insert("Recv_Square_1_0_on_/CPU_0:0");
-  cpu_expected.insert("Square:0");
+  cpu_expected.insert("Recv_Sign_1_0_on_/CPU_0:0");
+  cpu_expected.insert("Sign:0");
   cpu_expected.insert("x:0");
   cpu_expected.insert("AddN:0");
   EXPECT_EQ(cpu_expected, cpu_tensors);
@@ -128,7 +128,7 @@ TEST_F(GraphMemoryTest, MultiDevice) {
   }
   std::set<string> gpu_expected;
   gpu_expected.insert("Recv_AddN_0_on_/GPU_0:0");
-  gpu_expected.insert("Square_1:0");
+  gpu_expected.insert("Sign_1:0");
   gpu_expected.insert("AddN_1:0");
   gpu_expected.insert("AddN_3:0");
   EXPECT_EQ(gpu_expected, gpu_tensors);
@@ -154,8 +154,8 @@ TEST_F(GraphMemoryTest, GpuSwapping) {
     gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
   }
   std::set<string> gpu_expected;
-  gpu_expected.insert("Square:0");
-  gpu_expected.insert("Square_1:0");
+  gpu_expected.insert("Sign:0");
+  gpu_expected.insert("Sign_1:0");
   gpu_expected.insert("AddN:0");
   gpu_expected.insert("AddN_1:0");
   gpu_expected.insert("AddN_2:0");
@@ -134,14 +134,14 @@ TEST_F(GraphViewTest, BasicGraph) {
   EXPECT_EQ(input.node->name(), "AddN");
   EXPECT_EQ(input.port_id, 0);
   GraphView::OutputPort fanin = graph.GetRegularFanin(input);
-  EXPECT_EQ(fanin.node->name(), "Square");
+  EXPECT_EQ(fanin.node->name(), "Sign");
   EXPECT_EQ(fanin.port_id, 0);
 
   input = graph.GetInputPort("AddN", 1);
   EXPECT_EQ(input.node->name(), "AddN");
   EXPECT_EQ(input.port_id, 1);
   fanin = graph.GetRegularFanin(input);
-  EXPECT_EQ(fanin.node->name(), "Square_1");
+  EXPECT_EQ(fanin.node->name(), "Sign_1");
   EXPECT_EQ(fanin.port_id, 0);
 
   GraphView::OutputPort output = graph.GetOutputPort("AddN", 0);
@@ -169,7 +169,7 @@ TEST_F(GraphViewTest, BasicGraph) {
   EXPECT_EQ(fanouts, expected_fanouts);
 
   absl::flat_hash_set<string> fanins;
-  absl::flat_hash_set<string> expected_fanins = {"Square_1:0", "Square:0"};
+  absl::flat_hash_set<string> expected_fanins = {"Sign_1:0", "Sign:0"};
   for (const auto& fi : graph.GetFanins(*add_node, false)) {
     fanins.insert(absl::StrCat(fi.node->name(), ":", fi.port_id));
   }
@@ -47,11 +47,11 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
     std::vector<Output> this_stage;
     for (int j = 0; j < width; j++) {
       if (last_stage.size() == 1) {
-        Output unary_op = Square(
-            s.WithDevice(
-                device_names[use_multiple_devices ? j % device_names.size()
-                                                  : 0]),
-            last_stage[0]);
+        Output unary_op =
+            Sign(s.WithDevice(
+                     device_names[use_multiple_devices ? j % device_names.size()
+                                                       : 0]),
+                 last_stage[0]);
         this_stage.push_back(unary_op);
       } else {
         Output combine =
@@ -2253,7 +2253,7 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
   ~FoldTransposeIntoMatMul() override = default;
 
   bool IsSupported(const NodeDef* node) const override {
-    return IsAnyMatMul(*node);
+    return IsAnyMatMul(*node) && !IsInPreserveSet(*node);
  }
 
   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
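
IsInPreserveSet is the stock stage helper that consults the optimizer context's nodes_to_preserve set, which holds fetch (and other must-keep) nodes. Roughly what the call checks, as a sketch only (the real helper lives on the shared optimizer-stage base):

    bool IsInPreserveSet(const NodeDef& node) const {
      return ctx().nodes_to_preserve->find(node.name()) !=
             ctx().nodes_to_preserve->end();
    }

Because fetched nodes land in the preserve set, the tests below stop fetching "matmul" directly and fetch a trailing Identity instead; otherwise the guard just added would disable the very rewrite under test. The extra Identity node is also why the node-count expectation moves from 11 to 12.
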
@@ -736,7 +736,7 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
   auto identity = ops::Identity(s.WithOpName("identity"), matmul);
 
   GrapplerItem item;
-  item.fetch = {"matmul"};
+  item.fetch = {"identity"};
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
 
   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
@@ -795,9 +795,10 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
   Output trans_a = ops::ConjugateTranspose(s.WithOpName("trans_a"), a, perm);
   Output trans_b = ops::ConjugateTranspose(s.WithOpName("trans_b"), b, perm);
   Output matmul = ops::BatchMatMul(s.WithOpName("matmul"), trans_a, trans_b);
+  Output identity = ops::Identity(s.WithOpName("identity"), matmul);
 
   GrapplerItem item;
-  item.fetch = {"matmul"};
+  item.fetch = {"identity"};
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
 
   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
@@ -808,7 +809,7 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
   OptimizeTwice(&optimizer, &item, &output);
 
   NodeMap node_map(&output);
-  EXPECT_EQ(output.node_size(), 11);
+  EXPECT_EQ(output.node_size(), 12);
 
   const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
   const string optimized_name = absl::StrCat(p, "_", "matmul");
@@ -2336,6 +2336,13 @@ bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
       node_map_->NodeExists(axis_node_name)) {
     return false;
   }
+
+  // It's unsafe to add a control dependency on the feed node, because it
+  // might have never been executed otherwise.
+  if (feed_nodes_.find(NodeName(node->input(0))) != feed_nodes_.end()) {
+    return false;
+  }
+
   // Create constant axis node.
   Tensor axis_t(DT_INT32, TensorShape({}));
   const int axis =
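
For context: SimplifyPack rewrites a single-input Pack into ExpandDims with a synthesized constant axis node, and anchors that constant behind a control dependency on Pack's input to preserve ordering. A rough sketch of the hazard the new guard avoids (names hypothetical):

    // Before:  y = Pack(x)        // axis carried as a node attribute
    // After:   axis = Const(...)  // with control input "^x"
    //          y    = ExpandDims(x, axis)
    //
    // If "x" is fed, the runtime substitutes the fed value for x's output
    // and the node "x" itself may never execute, so the "^x" control edge
    // might never be satisfied. Hence: bail out when Pack's input is fed.
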
@@ -280,6 +280,13 @@ class FunctionOptimizerContext {
 
   const GraphView& graph_view() const { return graph_view_; }
 
+  bool IsFeedNode(const string& node_name) const {
+    return absl::c_any_of(
+        item_->feed, [&](const std::pair<std::string, Tensor>& feed) {
+          return ParseTensorName(feed.first).node() == node_name;
+        });
+  }
+
   bool IsFetchNode(const string& node_name) const {
     return absl::c_any_of(item_->fetch, [&](const string& fetch) {
       return ParseTensorName(fetch).node() == node_name;
@@ -1445,9 +1452,9 @@ Status FunctionOptimizer::RunFunctionOptimizerPass(
 
   // Do not specialize if function has custom gradient or marked nospecialize.
   const string grad_func = ctx.function_library().FindGradient(func_name);
-  const bool no_specialize = !grad_func.empty() ||
-                             MarkedNoSpecialize(*func) ||
-                             MarkedForXlaCompilation(node);
+  const bool no_specialize =
+      !grad_func.empty() || ctx.IsFeedNode(node.name()) ||
+      MarkedNoSpecialize(*func) || MarkedForXlaCompilation(node);
 
   if (specialization_worthy && !no_specialize) {
     // TODO(ezhulenev): Specialize function call if input has a known shape.
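
The rationale for gating specialization on feeds, as a hedged note: specialization rewrites the function call site (new specialized function name, and possibly pruned inputs or outputs), which could invalidate a feed that refers to the node, or to one of its output slots, by its original name. A minimal stand-alone sketch of the new helper in use:

    // Hypothetical check mirroring the no_specialize predicate above.
    if (ctx.IsFeedNode(node.name())) {
      VLOG(2) << "Skip specialization of fed function call: " << node.name();
    }
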
@@ -111,8 +111,8 @@ TEST_F(StaticScheduleTest, BasicGraph) {
   std::vector<std::string> ordered_node_names =
       GetOrderedNodeNames(completion_times);
   EXPECT_EQ(ordered_node_names,
-            (std::vector<std::string>{"Const/Const", "x", "Square", "Square_1",
-                                      "Square_2", "Square_3", "y"}));
+            (std::vector<std::string>{"Const/Const", "x", "Sign", "Sign_1",
+                                      "Sign_2", "Sign_3", "y"}));
 }
 
 TEST_F(StaticScheduleTest, BasicGraphWithCtrlDependencies) {
@@ -192,8 +192,8 @@ TEST_F(StaticScheduleTest, RequiredTimes) {
   std::vector<std::string> ordered_node_names =
       GetOrderedNodeNames(required_times);
   EXPECT_EQ(ordered_node_names,
-            (std::vector<std::string>{"Const/Const", "x", "Square", "Square_1",
-                                      "Square_2", "Square_3", "y"}));
+            (std::vector<std::string>{"Const/Const", "x", "Sign", "Sign_1",
+                                      "Sign_2", "Sign_3", "y"}));
 }
 
 }  // namespace