[Grappler] Before calling Grappler from the TensorFlow runtime, check that feeds and fetches are valid and only refer to nodes that exist in the graph.
Move a few things around to fail faster and add comments to improve readability. Added a couple of missing tests in grappler_test_test.cc. PiperOrigin-RevId: 332030347 Change-Id: If2d8c1dbb0a7e700db16886c3c6c39bba2de4e0a
This commit is contained in:
parent
6904987bee
commit
ad04b28f5f
tensorflow/core
common_runtime
grappler
@ -1666,6 +1666,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
@ -634,16 +635,19 @@ Status GraphExecutionState::InitBaseGraph(std::unique_ptr<Graph>&& new_graph) {
|
||||
Status GraphExecutionState::OptimizeGraph(
|
||||
const BuildGraphOptions& options, std::unique_ptr<Graph>* optimized_graph,
|
||||
std::unique_ptr<FunctionLibraryDefinition>* optimized_flib) {
|
||||
#ifndef IS_MOBILE_PLATFORM
|
||||
#ifdef IS_MOBILE_PLATFORM
|
||||
return errors::InvalidArgument("Mobile platforms not supported");
|
||||
#else
|
||||
if (session_options_->config.graph_options().place_pruned_graph()) {
|
||||
return errors::InvalidArgument("Can't optimize a pruned graph");
|
||||
}
|
||||
|
||||
if (grappler::MetaOptimizerEnabled(session_options_->config)) {
|
||||
// Here we build the GrapplerItem before calling the optimizer.
|
||||
grappler::GrapplerItem item;
|
||||
item.id = "tf_graph";
|
||||
graph_->ToGraphDef(&item.graph);
|
||||
|
||||
// Add devices to the GrapplerItem
|
||||
// It's ok to skip invalid device annotations in Grappler.
|
||||
for (const Device* d : device_set_->devices()) {
|
||||
Status added_device = item.AddDevice(d->name());
|
||||
@ -652,11 +656,7 @@ Status GraphExecutionState::OptimizeGraph(
|
||||
VLOG(3) << "Grappler available devices: "
|
||||
<< absl::StrJoin(item.devices(), ", ");
|
||||
|
||||
// TODO(b/114748242): Add a unit test to test this bug fix.
|
||||
if (flib_def_) {
|
||||
*item.graph.mutable_library() = flib_def_->ToProto();
|
||||
}
|
||||
|
||||
// Add fetches to the GrapplerItem.
|
||||
item.fetch.insert(item.fetch.end(),
|
||||
options.callable_options.fetch().begin(),
|
||||
options.callable_options.fetch().end());
|
||||
@ -669,6 +669,8 @@ Status GraphExecutionState::OptimizeGraph(
|
||||
item.fetch.push_back(tensor_connection.from_tensor());
|
||||
}
|
||||
|
||||
// Add feeds to the GrapplerItem if we know them.
|
||||
absl::flat_hash_set<absl::string_view> node_names;
|
||||
if (!(options.callable_options.feed().empty() &&
|
||||
options.callable_options.tensor_connection().empty())) {
|
||||
std::vector<SafeTensorId> feeds;
|
||||
@ -683,7 +685,7 @@ Status GraphExecutionState::OptimizeGraph(
|
||||
|
||||
// For feeds with tensor index 0 we try to find the corresponding node in
|
||||
// the graph to infer feed data type and shape.
|
||||
std::unordered_set<std::string> feed_nodes;
|
||||
absl::flat_hash_set<absl::string_view> feed_nodes;
|
||||
|
||||
// For feeds with tensor index larger than 0, we can't infer data type or
|
||||
// shape from the graph. Currently we only support type and shape
|
||||
@ -702,7 +704,9 @@ Status GraphExecutionState::OptimizeGraph(
|
||||
|
||||
// For feeds with tensor index == 0 we try to infer data type and tensor
|
||||
// shape from the graph, by looking at the fed node attributes.
|
||||
node_names.reserve(graph_->num_nodes());
|
||||
for (const Node* node : graph_->nodes()) {
|
||||
node_names.insert(node->name());
|
||||
if (feed_nodes.find(node->name()) == feed_nodes.end()) continue;
|
||||
|
||||
// Try to get the type and shape of the feed node.
|
||||
@ -747,6 +751,39 @@ Status GraphExecutionState::OptimizeGraph(
|
||||
}
|
||||
}
|
||||
|
||||
// Validate that the feeds and fetches are valid.
|
||||
if (node_names.empty()) {
|
||||
// Collect all node names in the graph if we didn't already.
|
||||
node_names.reserve(graph_->num_nodes());
|
||||
for (const Node* node : graph_->nodes()) {
|
||||
node_names.insert(node->name());
|
||||
}
|
||||
}
|
||||
for (const auto& feed : item.feed) {
|
||||
SafeTensorId tensor_id = ParseTensorName(feed.first);
|
||||
if (node_names.find(tensor_id.node()) == node_names.end()) {
|
||||
return errors::InvalidArgument("Invalid feed, no such node in graph: ",
|
||||
feed.first);
|
||||
}
|
||||
}
|
||||
for (const auto& fetch : item.fetch) {
|
||||
SafeTensorId tensor_id = ParseTensorName(fetch);
|
||||
if (node_names.find(tensor_id.node()) == node_names.end()) {
|
||||
return errors::InvalidArgument("Invalid fetch, no such node in graph: ",
|
||||
fetch);
|
||||
}
|
||||
}
|
||||
|
||||
// Convert Graph to GraphDef and add it to the GrapplerItem.
|
||||
graph_->ToGraphDef(&item.graph);
|
||||
// TODO(b/114748242): Add a unit test to test this bug fix.
|
||||
if (flib_def_) {
|
||||
*item.graph.mutable_library() = flib_def_->ToProto();
|
||||
}
|
||||
|
||||
// Construct a virtual cluster and find the cpu_device, which the
|
||||
// ConstantFolding optimizer will use for partial evaluation of the graph.
|
||||
grappler::VirtualCluster cluster(device_set_);
|
||||
Device* cpu_device = nullptr;
|
||||
for (const auto& device : device_set_->devices()) {
|
||||
if (device->parsed_name().id == 0 &&
|
||||
@ -755,7 +792,8 @@ Status GraphExecutionState::OptimizeGraph(
|
||||
cpu_device = device;
|
||||
}
|
||||
}
|
||||
grappler::VirtualCluster cluster(device_set_);
|
||||
|
||||
// Now we can run the MetaOptimizer on the constructed GrapplerItem.
|
||||
GraphDef new_graph;
|
||||
TF_RETURN_IF_ERROR(
|
||||
grappler::RunMetaOptimizer(std::move(item), session_options_->config,
|
||||
@ -778,9 +816,9 @@ Status GraphExecutionState::OptimizeGraph(
|
||||
TF_RETURN_IF_ERROR((*optimized_flib)->AddFunctionDef(fdef));
|
||||
}
|
||||
}
|
||||
|
||||
optimized_graph->reset(new Graph(OpRegistry::Global()));
|
||||
|
||||
// Convert the optimized GraphDef back to a Graph.
|
||||
GraphConstructorOptions opts;
|
||||
opts.allow_internal_ops = true;
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, std::move(new_graph),
|
||||
@ -796,8 +834,6 @@ Status GraphExecutionState::OptimizeGraph(
|
||||
} else {
|
||||
return errors::InvalidArgument("Meta Optimizer disabled");
|
||||
}
|
||||
#else
|
||||
return errors::InvalidArgument("Mobile platforms not supported");
|
||||
#endif // IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
||||
|
@ -421,8 +421,7 @@ TEST_F(ScopedAllocatorOptimizerTest, UnaryExecute) {
|
||||
SetShapes(&graph_def);
|
||||
std::vector<Tensor> outputs;
|
||||
ExecuteGraph(graph_def,
|
||||
/*output_names=*/{"r1:0", "r2:0", "scoped_allocator_1_2_Abs:0"},
|
||||
&outputs);
|
||||
/*output_names=*/{"r1:0", "r2:0"}, &outputs);
|
||||
// a + b == 2, -2, 3, 3
|
||||
// b + c == -4, -4, 3, 2
|
||||
ValidateValues(outputs, /*expected=*/{{2, 2, 3, 3}, {4, 4, 3, 2}});
|
||||
|
@ -68,27 +68,44 @@ void CompareGraphNodes(protobuf::RepeatedPtrField<NodeDef>* want,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SetAllOptimizers(RewriterConfig* cfg, RewriterConfig::Toggle value) {
|
||||
cfg->set_arithmetic_optimization(value);
|
||||
cfg->set_auto_mixed_precision(value);
|
||||
cfg->set_auto_mixed_precision_mkl(value);
|
||||
cfg->set_common_subgraph_elimination(value);
|
||||
cfg->set_constant_folding(value);
|
||||
cfg->set_debug_stripper(value);
|
||||
cfg->set_dependency_optimization(value);
|
||||
cfg->set_function_optimization(value);
|
||||
cfg->set_implementation_selector(value);
|
||||
cfg->set_layout_optimizer(value);
|
||||
cfg->set_loop_optimization(value);
|
||||
cfg->set_pin_to_host_optimization(value);
|
||||
cfg->set_remapping(value);
|
||||
cfg->set_scoped_allocator_optimization(value);
|
||||
cfg->set_shape_optimization(value);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
GrapplerTest::GrapplerTest() {
|
||||
// Turn off all the automatic optimizations to ensure that we run the graph
|
||||
// exactly as it is given to us. This ensures that we can compare the results
|
||||
// before and after manual optimization, without any of the automatic
|
||||
// optimizations interfering in the comparison.
|
||||
RewriterConfig* cfg =
|
||||
options_.config.mutable_graph_options()->mutable_rewrite_options();
|
||||
// TODO(rmlarsen): Add utility to generate config w/ all optimizers turned
|
||||
// off.
|
||||
cfg->set_arithmetic_optimization(RewriterConfig::OFF);
|
||||
cfg->set_constant_folding(RewriterConfig::OFF);
|
||||
cfg->set_debug_stripper(RewriterConfig::OFF);
|
||||
cfg->set_dependency_optimization(RewriterConfig::OFF);
|
||||
cfg->set_function_optimization(RewriterConfig::OFF);
|
||||
cfg->set_implementation_selector(RewriterConfig::OFF);
|
||||
cfg->set_layout_optimizer(RewriterConfig::OFF);
|
||||
cfg->set_loop_optimization(RewriterConfig::OFF);
|
||||
cfg->set_pin_to_host_optimization(RewriterConfig::OFF);
|
||||
cfg->set_remapping(RewriterConfig::OFF);
|
||||
// exactly as it is given to us. This ensures that we can compare the
|
||||
// results before and after manual optimization, without any of the
|
||||
// automatic optimizations interfering in the comparison.
|
||||
DisableAllOptimizers();
|
||||
}
|
||||
|
||||
void GrapplerTest::DisableAllOptimizers() {
|
||||
SetAllOptimizers(
|
||||
options_.config.mutable_graph_options()->mutable_rewrite_options(),
|
||||
RewriterConfig::OFF);
|
||||
}
|
||||
|
||||
void GrapplerTest::EnableAllOptimizers() {
|
||||
SetAllOptimizers(
|
||||
options_.config.mutable_graph_options()->mutable_rewrite_options(),
|
||||
RewriterConfig::ON);
|
||||
}
|
||||
|
||||
std::vector<Tensor> GrapplerTest::EvaluateNodes(
|
||||
|
@ -37,6 +37,9 @@ class GrapplerTest : public ::testing::Test {
|
||||
GrapplerTest();
|
||||
|
||||
protected:
|
||||
void DisableAllOptimizers();
|
||||
void EnableAllOptimizers();
|
||||
|
||||
std::vector<Tensor> EvaluateNodes(
|
||||
const GraphDef& graph, const std::vector<string>& node_names) const;
|
||||
|
||||
@ -51,6 +54,8 @@ class GrapplerTest : public ::testing::Test {
|
||||
const std::vector<std::pair<string, AttrValue>>& attributes,
|
||||
GraphDef* graph) const;
|
||||
|
||||
void DisableAllOptimizers(RewriterConfig* cfg);
|
||||
|
||||
// Checks if two graphs are equal. Both graphs must have the same set of nodes
|
||||
// with the same inputs and attributes. Nodes can be in different order.
|
||||
//
|
||||
|
@ -95,6 +95,35 @@ TEST_F(GrapplerTestTest, CountOpNodes) {
|
||||
EXPECT_EQ(0, CountOpNodes(graph, "Transpose"));
|
||||
}
|
||||
|
||||
TEST_F(GrapplerTestTest, EvaluateNodes) {
|
||||
EnableAllOptimizers();
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output a = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
|
||||
Output b = ops::Const(s.WithOpName("d"), {3.0f, 4.0f}, {1, 2});
|
||||
Output mul = ops::Mul(s.WithOpName("mul"), a, b);
|
||||
GrapplerItem item;
|
||||
item.fetch = {"mul"};
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
auto tensors = EvaluateNodes(item.graph, item.fetch);
|
||||
ASSERT_EQ(tensors.size(), 1);
|
||||
EXPECT_EQ(tensors[0].flat<float>()(0), 3.0f);
|
||||
EXPECT_EQ(tensors[0].flat<float>()(1), 8.0f);
|
||||
}
|
||||
|
||||
TEST_F(GrapplerTestTest, EvaluateNodesInvalidFetch) {
|
||||
EnableAllOptimizers();
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output a = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
|
||||
Output b = ops::Const(s.WithOpName("d"), {3.0f, 4.0f}, {1, 2});
|
||||
Output mul = ops::Mul(s.WithOpName("mul"), a, b);
|
||||
GrapplerItem item;
|
||||
item.fetch = {"no_such_node"};
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
EXPECT_DEATH(EvaluateNodes(item.graph, item.fetch),
|
||||
"Invalid argument: Tensor no_such_node:0, specified in either "
|
||||
"feed_devices or fetch_devices was not found in the Graph");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user