[Grappler] Allow feeding non-zero index tensors to Grappler optimizer
PiperOrigin-RevId: 308665214
Change-Id: I3b449e1dd06f5134de358466227734a8b17673f4
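For context, a feed like "split:1" names output index 1 of node "split"; before this change, GraphExecutionState::OptimizeGraph rejected any feed whose output index was non-zero. A minimal sketch of the newly accepted case (node names are hypothetical; CallableOptions is the real proto from tensorflow/core/protobuf/config.pb.h):

    // Hypothetical example: feeding the second output of a node. Previously
    // OptimizeGraph() returned InvalidArgument("Unsupported feed: split:1");
    // with this change the feed is forwarded to Grappler as an "undefined"
    // (DT_INVALID) input whose type and shape the optimizer must not rely on.
    tensorflow::CallableOptions opts;
    opts.add_feed("split:1");   // output index 1: no longer rejected
    opts.add_fetch("out:0");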
Commit f976ccd74d (parent c18a2ad398)
@@ -671,32 +671,55 @@ Status GraphExecutionState::OptimizeGraph(
 
   if (!(options.callable_options.feed().empty() &&
         options.callable_options.tensor_connection().empty())) {
-    std::unordered_set<string> feeds;
+    std::vector<SafeTensorId> feeds;
+
     for (const string& feed : options.callable_options.feed()) {
-      TensorId id = ParseTensorName(feed);
-      if (id.second != 0) {
-        return errors::InvalidArgument("Unsupported feed: ", feed);
-      }
-      feeds.emplace(id.first);
+      feeds.emplace_back(ParseTensorName(feed));
     }
     for (const TensorConnection& tensor_connection :
          options.callable_options.tensor_connection()) {
-      TensorId id = ParseTensorName(tensor_connection.to_tensor());
-      if (id.second != 0) {
-        return errors::InvalidArgument("Unsupported feed: ",
-                                       tensor_connection.to_tensor());
-      }
-      feeds.emplace(id.first);
+      feeds.emplace_back(ParseTensorName(tensor_connection.to_tensor()));
     }
-    for (const Node* node : graph_->nodes()) {
-      if (feeds.find(node->name()) == feeds.end()) {
+
+    // For feeds with tensor index 0 we try to find the corresponding node in
+    // the graph to infer feed data type and shape.
+    std::unordered_set<std::string> feed_nodes;
+
+    // For feeds with tensor index larger than 0, we can't infer data type or
+    // shape from the graph. Currently we only support type and shape
+    // inference from a small set of node types: Placeholder, Const, etc...
+    for (const SafeTensorId& feed : feeds) {
+      if (feed.index() > 0) {
+        VLOG(3) << "Add undefined feed for: " << feed.ToString();
+        Tensor fake_input(DT_INVALID, {0});
+        item.feed.emplace_back(feed.ToString(), fake_input);
+      } else {
+        VLOG(3) << "Add node for feed inference: " << feed.ToString();
+        feed_nodes.insert(feed.node());
         continue;
       }
-      // Get the type and shape of the feed node.
+    }
+
+    // For feeds with tensor index == 0 we try to infer data type and tensor
+    // shape from the graph, by looking at the fed node attributes.
+    for (const Node* node : graph_->nodes()) {
+      if (feed_nodes.find(node->name()) == feed_nodes.end()) continue;
+
+      // Try to get the type and shape of the feed node.
       PartialTensorShape partial_shape;
       DataType type;
-      TF_RETURN_IF_ERROR(GetFeedShapeAndTypeFromAttribute(
-          node->def(), &partial_shape, &type));
+      Status st = GetFeedShapeAndTypeFromAttribute(node->def(),
+                                                   &partial_shape, &type);
+
+      // Failed to get type and shape of the feed node.
+      if (!st.ok()) {
+        VLOG(3) << "Failed to infer feed node type and shape."
+                << " Add undefined feed for: " << node->name();
+        Tensor fake_input(DT_INVALID, {0});
+        item.feed.emplace_back(node->name(), fake_input);
+        continue;
+      }
+
       // If the shape of the placeholder is only partially known, we are free
       // to set unknown dimensions of its shape to any value we desire. We
       // choose 0 to minimize the memory impact. Note that this only matters

@@ -717,6 +740,8 @@ Status GraphExecutionState::OptimizeGraph(
       }
     }
 
+    VLOG(3) << "Add feed for: " << node->name() << "; type: " << type
+            << "; shape: " << shape;
     Tensor fake_input(type, shape);
     item.feed.emplace_back(node->name(), fake_input);
   }
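The hunk above leans on the tensor-id helpers from tensorflow/core/graph/tensor_id.h; a brief sketch of their behavior (values illustrative):

    #include "tensorflow/core/graph/tensor_id.h"

    // ParseTensorName splits "name:k" into (node, index); a bare "name"
    // gets index 0, and "^name" (a control input) gets index -1.
    tensorflow::TensorId id = tensorflow::ParseTensorName("split:1");
    // id.node() == "split", id.index() == 1

    // SafeTensorId stores its own copy of the node name, so it can outlive
    // the string it was parsed from -- which is why the vector above holds
    // SafeTensorId rather than the string_view-backed TensorId.
    tensorflow::SafeTensorId safe(id);
    // safe.ToString() == "split:1"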
@@ -59,13 +59,13 @@ TEST_F(GraphMemoryTest, Basic) {
   for (const auto& t : mem_usage.live_tensors) {
     tensors.insert(strings::StrCat(t.node, ":", t.output_id));
   }
-  // When the execution of the 'Square' node completes, TF can start executing
-  // 'Square_1' and release the memory used by 'x'. Since we can't be sure of
+  // When the execution of the 'Sign' node completes, TF can start executing
+  // 'Sign_1' and release the memory used by 'x'. Since we can't be sure of
   // the order in which this takes place, in the worst case the 3 tensors are in
   // memory.
   std::set<string> expected;
-  expected.insert("Square:0");
-  expected.insert("Square_1:0");
+  expected.insert("Sign:0");
+  expected.insert("Sign_1:0");
   expected.insert("x:0");
   EXPECT_EQ(expected, tensors);
 }

@@ -91,7 +91,7 @@ TEST_F(GraphMemoryTest, UnknownBatchSize) {
   }
   std::set<string> expected;
   expected.insert("Const/Const:0");
-  expected.insert("Square:0");
+  expected.insert("Sign:0");
   expected.insert("x:0");
   EXPECT_EQ(expected, tensors);
 }

@@ -114,8 +114,8 @@ TEST_F(GraphMemoryTest, MultiDevice) {
     cpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
   }
   std::set<string> cpu_expected;
-  cpu_expected.insert("Recv_Square_1_0_on_/CPU_0:0");
-  cpu_expected.insert("Square:0");
+  cpu_expected.insert("Recv_Sign_1_0_on_/CPU_0:0");
+  cpu_expected.insert("Sign:0");
   cpu_expected.insert("x:0");
   cpu_expected.insert("AddN:0");
   EXPECT_EQ(cpu_expected, cpu_tensors);

@@ -128,7 +128,7 @@ TEST_F(GraphMemoryTest, MultiDevice) {
   }
   std::set<string> gpu_expected;
   gpu_expected.insert("Recv_AddN_0_on_/GPU_0:0");
-  gpu_expected.insert("Square_1:0");
+  gpu_expected.insert("Sign_1:0");
   gpu_expected.insert("AddN_1:0");
   gpu_expected.insert("AddN_3:0");
   EXPECT_EQ(gpu_expected, gpu_tensors);

@@ -154,8 +154,8 @@ TEST_F(GraphMemoryTest, GpuSwapping) {
     gpu_tensors.insert(strings::StrCat(t.node, ":", t.output_id));
   }
   std::set<string> gpu_expected;
-  gpu_expected.insert("Square:0");
-  gpu_expected.insert("Square_1:0");
+  gpu_expected.insert("Sign:0");
+  gpu_expected.insert("Sign_1:0");
   gpu_expected.insert("AddN:0");
   gpu_expected.insert("AddN_1:0");
   gpu_expected.insert("AddN_2:0");

@@ -134,14 +134,14 @@ TEST_F(GraphViewTest, BasicGraph) {
   EXPECT_EQ(input.node->name(), "AddN");
   EXPECT_EQ(input.port_id, 0);
   GraphView::OutputPort fanin = graph.GetRegularFanin(input);
-  EXPECT_EQ(fanin.node->name(), "Square");
+  EXPECT_EQ(fanin.node->name(), "Sign");
   EXPECT_EQ(fanin.port_id, 0);
 
   input = graph.GetInputPort("AddN", 1);
   EXPECT_EQ(input.node->name(), "AddN");
   EXPECT_EQ(input.port_id, 1);
   fanin = graph.GetRegularFanin(input);
-  EXPECT_EQ(fanin.node->name(), "Square_1");
+  EXPECT_EQ(fanin.node->name(), "Sign_1");
   EXPECT_EQ(fanin.port_id, 0);
 
   GraphView::OutputPort output = graph.GetOutputPort("AddN", 0);

@@ -169,7 +169,7 @@ TEST_F(GraphViewTest, BasicGraph) {
   EXPECT_EQ(fanouts, expected_fanouts);
 
   absl::flat_hash_set<string> fanins;
-  absl::flat_hash_set<string> expected_fanins = {"Square_1:0", "Square:0"};
+  absl::flat_hash_set<string> expected_fanins = {"Sign_1:0", "Sign:0"};
   for (const auto& fi : graph.GetFanins(*add_node, false)) {
     fanins.insert(absl::StrCat(fi.node->name(), ":", fi.port_id));
   }

@@ -47,8 +47,8 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
     std::vector<Output> this_stage;
     for (int j = 0; j < width; j++) {
       if (last_stage.size() == 1) {
-        Output unary_op = Square(
-            s.WithDevice(
+        Output unary_op =
+            Sign(s.WithDevice(
                 device_names[use_multiple_devices ? j % device_names.size()
                                                   : 0]),
             last_stage[0]);

@@ -2253,7 +2253,7 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
   ~FoldTransposeIntoMatMul() override = default;
 
   bool IsSupported(const NodeDef* node) const override {
-    return IsAnyMatMul(*node);
+    return IsAnyMatMul(*node) && !IsInPreserveSet(*node);
   }
 
   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
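IsInPreserveSet comes from the shared ArithmeticOptimizerStage machinery; feed and fetch nodes land in Grappler's preserve set, so this guard stops the stage from rewriting a MatMul the caller still needs to address directly. That is also why the two tests below switch their fetch from "matmul" to "identity": fetching the matmul itself would now place it in the preserve set and suppress the very rewrite under test.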
@@ -736,7 +736,7 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
   auto identity = ops::Identity(s.WithOpName("identity"), matmul);
 
   GrapplerItem item;
-  item.fetch = {"matmul"};
+  item.fetch = {"identity"};
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
 
   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

@@ -795,9 +795,10 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
   Output trans_a = ops::ConjugateTranspose(s.WithOpName("trans_a"), a, perm);
   Output trans_b = ops::ConjugateTranspose(s.WithOpName("trans_b"), b, perm);
   Output matmul = ops::BatchMatMul(s.WithOpName("matmul"), trans_a, trans_b);
+  Output identity = ops::Identity(s.WithOpName("identity"), matmul);
 
   GrapplerItem item;
-  item.fetch = {"matmul"};
+  item.fetch = {"identity"};
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
 
   auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

@@ -808,7 +809,7 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
   OptimizeTwice(&optimizer, &item, &output);
 
   NodeMap node_map(&output);
-  EXPECT_EQ(output.node_size(), 11);
+  EXPECT_EQ(output.node_size(), 12);
 
   const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
   const string optimized_name = absl::StrCat(p, "_", "matmul");
@@ -2336,6 +2336,13 @@ bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
       node_map_->NodeExists(axis_node_name)) {
     return false;
   }
+
+  // It's unsafe to add a control dependency on the feed node, because it
+  // might have never been executed otherwise.
+  if (feed_nodes_.find(NodeName(node->input(0))) != feed_nodes_.end()) {
+    return false;
+  }
+
   // Create constant axis node.
   Tensor axis_t(DT_INT32, TensorShape({}));
   const int axis =
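To see why the guard matters: later in SimplifyPack the axis is materialized as a new Const node that is anchored behind the Pack node's first input with a control edge. A rough sketch of that anchoring (assumed shape, mirroring the feed_nodes_ check; names illustrative):

    #include <string>
    #include "absl/strings/str_cat.h"
    #include "tensorflow/core/framework/node_def.pb.h"

    // The new constant gets a "^name" control input so it is scheduled after
    // the Pack node's first input. A fed node normally never executes (its
    // value is injected by the caller), so a control edge like this would
    // force it to run and change observable behavior -- hence the early
    // return above.
    void AnchorAxisNode(tensorflow::NodeDef* axis_node,
                        const std::string& input0_node) {
      axis_node->add_input(absl::StrCat("^", input0_node));
    }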
@@ -280,6 +280,13 @@ class FunctionOptimizerContext {
 
   const GraphView& graph_view() const { return graph_view_; }
 
+  bool IsFeedNode(const string& node_name) const {
+    return absl::c_any_of(
+        item_->feed, [&](const std::pair<std::string, Tensor>& feed) {
+          return ParseTensorName(feed.first).node() == node_name;
+        });
+  }
+
   bool IsFetchNode(const string& node_name) const {
     return absl::c_any_of(item_->fetch, [&](const string& fetch) {
       return ParseTensorName(fetch).node() == node_name;
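IsFeedNode mirrors the existing IsFetchNode helper just below it, scanning the GrapplerItem's feed list instead of its fetch list. The next hunk wires it into the no_specialize predicate, so that function call nodes whose outputs are fed are never specialized: the fed value replaces whatever a specialized function body would have computed.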
@@ -1445,9 +1452,9 @@ Status FunctionOptimizer::RunFunctionOptimizerPass(
 
   // Do not specialize if function has custom gradient or marked nospecialize.
   const string grad_func = ctx.function_library().FindGradient(func_name);
-  const bool no_specialize = !grad_func.empty() ||
-                             MarkedNoSpecialize(*func) ||
-                             MarkedForXlaCompilation(node);
+  const bool no_specialize =
+      !grad_func.empty() || ctx.IsFeedNode(node.name()) ||
+      MarkedNoSpecialize(*func) || MarkedForXlaCompilation(node);
 
   if (specialization_worthy && !no_specialize) {
     // TODO(ezhulenev): Specialize function call if input has a known shape.

@@ -111,8 +111,8 @@ TEST_F(StaticScheduleTest, BasicGraph) {
   std::vector<std::string> ordered_node_names =
       GetOrderedNodeNames(completion_times);
   EXPECT_EQ(ordered_node_names,
-            (std::vector<std::string>{"Const/Const", "x", "Square", "Square_1",
-                                      "Square_2", "Square_3", "y"}));
+            (std::vector<std::string>{"Const/Const", "x", "Sign", "Sign_1",
+                                      "Sign_2", "Sign_3", "y"}));
 }
 
 TEST_F(StaticScheduleTest, BasicGraphWithCtrlDependencies) {

@@ -192,8 +192,8 @@ TEST_F(StaticScheduleTest, RequiredTimes) {
   std::vector<std::string> ordered_node_names =
       GetOrderedNodeNames(required_times);
   EXPECT_EQ(ordered_node_names,
-            (std::vector<std::string>{"Const/Const", "x", "Square", "Square_1",
-                                      "Square_2", "Square_3", "y"}));
+            (std::vector<std::string>{"Const/Const", "x", "Sign", "Sign_1",
+                                      "Sign_2", "Sign_3", "y"}));
 }
 
 } // namespace