[Grappler] Cleanup GrapplerItem constructors

PiperOrigin-RevId: 224226461
Author: Eugene Zhulenev, 2018-12-05 14:55:02 -08:00 (committed by TensorFlower Gardener)
parent 45cfe71266
commit 3f43965a44
10 changed files with 56 additions and 53 deletions
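
In short, the change removes GrapplerItem's copy-and-swap-graph constructors in favor of an explicit WithGraph() method that returns a copy of the item owning the moved-in graph. A minimal before/after sketch of the calling convention (the helper AttachOptimizedGraph and the variable names are illustrative only, not part of the commit):

    #include <utility>

    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/grappler/grappler_item.h"

    namespace tensorflow {
    namespace grappler {

    // Hypothetical helper: attach an already-rewritten graph to a copy of `item`.
    GrapplerItem AttachOptimizedGraph(const GrapplerItem& item, GraphDef&& output) {
      // Before this commit:
      //   GrapplerItem optimized(item, std::move(output));
      // After this commit:
      return item.WithGraph(std::move(output));
    }

    }  // namespace grappler
    }  // namespace tensorflow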

@ -30,20 +30,22 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef* graph_def) {
id = other.id;
feed = other.feed;
fetch = other.fetch;
init_ops = other.init_ops;
keep_ops = other.keep_ops;
expected_init_time = other.expected_init_time;
save_op = other.save_op;
restore_op = other.restore_op;
save_restore_loc_tensor = other.save_restore_loc_tensor;
queue_runners = other.queue_runners;
devices_ = other.devices_;
allowed_optimizations_ = other.allowed_optimizations_;
graph.Swap(graph_def);
GrapplerItem GrapplerItem::WithGraph(GraphDef&& graph_def) const {
GrapplerItem item;
item.id = id;
item.feed = feed;
item.fetch = fetch;
item.init_ops = init_ops;
item.keep_ops = keep_ops;
item.expected_init_time = expected_init_time;
item.save_op = save_op;
item.restore_op = restore_op;
item.save_restore_loc_tensor = save_restore_loc_tensor;
item.queue_runners = queue_runners;
item.devices_ = devices_;
item.allowed_optimizations_ = allowed_optimizations_;
item.graph.Swap(&graph_def);
return item;
}
std::vector<const NodeDef*> GrapplerItem::MainOpsFanin() const {

@ -35,12 +35,15 @@ namespace grappler {
// nodes, and potentially a set of nodes to feed.
struct GrapplerItem {
GrapplerItem() = default;
GrapplerItem(const GrapplerItem& other, GraphDef&& graph_def)
: GrapplerItem(other, &graph_def) {}
// Swaps *graph_def with an empty GraphDef.
GrapplerItem(const GrapplerItem& other, GraphDef* graph_def);
GrapplerItem(const GrapplerItem& other) = default;
GrapplerItem(GrapplerItem&& other) = default;
GrapplerItem& operator=(const GrapplerItem& other) = default;
GrapplerItem& operator=(GrapplerItem&& other) = default;
virtual ~GrapplerItem() = default;
// Create a copy of this GrapplerItem with graph swapped with the argument.
GrapplerItem WithGraph(GraphDef&& graph) const;
string id; // A unique id for this item
// Inputs
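
A short usage sketch of the method declared above: the original item keeps its own graph, and the argument is left holding an empty GraphDef after the swap (the function and variable names below are illustrative only):

    #include <utility>

    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/grappler/grappler_item.h"

    namespace tensorflow {
    namespace grappler {

    void WithGraphSketch(const GrapplerItem& item, GraphDef rewritten) {
      // Copies id, feed, fetch, init_ops, devices, etc. from `item` and swaps
      // `rewritten` into the copy's graph field.
      GrapplerItem optimized = item.WithGraph(std::move(rewritten));
      // `item.graph` is untouched; `rewritten` now holds an empty GraphDef.
    }

    }  // namespace grappler
    }  // namespace tensorflow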

@ -3561,8 +3561,7 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
// Set up helper data structures.
nodes_to_preserve_ = item.NodesToPreserve();
fetch_nodes_known_ = !item.fetch.empty();
*optimized_graph = item.graph;
GrapplerItem optimized_item(item, optimized_graph);
GrapplerItem optimized_item(item);
optimized_graph_ = &optimized_item.graph;
node_map_.reset(new NodeMap(optimized_graph_));
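
For optimizers that build the rewritten graph in place (as in the hunk above), the pattern is now to copy the whole item, graph included, and point the rewriting passes at the copy's graph. A rough sketch of that shape, assuming a hypothetical in-place pass and hand-back step (OptimizeSketch and RewriteInPlace are not real TensorFlow functions):

    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/grappler/grappler_item.h"
    #include "tensorflow/core/lib/core/status.h"

    namespace tensorflow {
    namespace grappler {

    // Hypothetical stand-in for an optimizer's in-place rewriting stages.
    void RewriteInPlace(GraphDef* graph) { /* mutate *graph */ }

    Status OptimizeSketch(const GrapplerItem& item, GraphDef* optimized_graph) {
      GrapplerItem optimized_item(item);       // copy of the item, graph included
      RewriteInPlace(&optimized_item.graph);   // passes mutate the copy's graph
      // One way to hand the result back to the caller; individual optimizers
      // may do this differently.
      optimized_graph->Swap(&optimized_item.graph);
      return Status::OK();
    }

    }  // namespace grappler
    }  // namespace tensorflow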

@ -127,7 +127,7 @@ TEST_F(ExperimentalImplementationSelectorTest, SwapImplementationEval) {
test::AsScalar<float>(4.0f));
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
const auto twice_boosted_tensor = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(twice_boosted_tensor[0],
test::AsScalar<float>(2.0f));
@ -223,7 +223,7 @@ TEST_F(ExperimentalImplementationSelectorTest, SwapImplementationWithGradient) {
test::AsScalar<float>(4.0f));
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
const auto twice_boosted_tensor = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(twice_boosted_tensor[0],
test::AsScalar<float>(2.0f));

@ -108,7 +108,7 @@ TEST_F(FunctionOptimizerTest, InlineFunction_SimpleFunction) {
item.fetch = {"z"};
item.feed.emplace_back("x", pi);
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
@ -184,7 +184,7 @@ TEST_F(FunctionOptimizerTest, InlineFunction_SkipErrorsIfGraphNotModified) {
item.fetch = {"z1"};
item.feed.emplace_back("x", pi);
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
@ -284,7 +284,7 @@ TEST_F(FunctionOptimizerTest, InlineFunction_FixedTypeFunction) {
item.fetch = {"z"};
item.feed.emplace_back("x", pi);
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
@ -368,7 +368,7 @@ TEST_F(FunctionOptimizerTest, InlineFunction_FunctionWithOutputMapping) {
item.fetch = {"z"};
item.feed.emplace_back("x", pi);
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
@ -418,7 +418,7 @@ TEST_F(FunctionOptimizerTest, InlineFunction_FunctionWithInputForwarding) {
item.feed.emplace_back("x4", test::AsScalar<float>(-1.0f));
item.feed.emplace_back("x3", test::AsScalar<int>(1234));
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
test::ExpectTensorEqual<float>(tensors_expected[1], tensors[1]);
@ -549,7 +549,7 @@ TEST_F(FunctionOptimizerTest, InlineFunction_FunctionWithNestedFunctionCall) {
item.feed.emplace_back("a", test::AsScalar<float>(2.0f));
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
@ -748,7 +748,7 @@ TEST_F(FunctionOptimizerTest, InlineIndirectFunctionSimpleFunction) {
item.feed.emplace_back("a", pi);
item.feed.emplace_back("b", pi);
GrapplerItem optimized(item, std::move(optimized_graph));
GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
auto tensors_expected = EvaluateFetchNodes(item);
auto tensors = EvaluateFetchNodes(optimized);
ASSERT_EQ(tensors_expected.size(), 1);
@ -876,7 +876,7 @@ TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithControlDependencies) {
EXPECT_EQ(tensors_expected[0].flat<float>()(0), 4.0); // mul
EXPECT_EQ(tensors_expected[1].flat<float>()(0), 3.0); // read variable
GrapplerItem optimized(item, std::move(optimized_graph));
GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
auto tensors = EvaluateFetchNodes(optimized);
ASSERT_EQ(tensors.size(), 2);
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
@ -1019,7 +1019,7 @@ TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithoutSideEffects) {
auto tensors_expected = EvaluateFetchNodes(item);
ASSERT_EQ(tensors_expected.size(), 1);
GrapplerItem optimized(item, std::move(optimized_graph));
GrapplerItem optimized = item.WithGraph(std::move(optimized_graph));
auto tensors = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
@ -1067,7 +1067,7 @@ TEST_F(FunctionOptimizerTest, SpecializeFunctionXTimesTwo) {
item.feed.emplace_back("x", pi);
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
@ -1131,7 +1131,7 @@ TEST_F(FunctionOptimizerTest, SpecializeIndirectFunctionXTimesTwo) {
item.feed.emplace_back("x", pi);
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
@ -1194,7 +1194,7 @@ TEST_F(FunctionOptimizerTest, SpecializeFunctionPushDownConstInput) {
item.feed.emplace_back("x", pi);
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
@ -1274,7 +1274,7 @@ TEST_F(FunctionOptimizerTest, SpecializeIndirectFunctionPushDownConstInput) {
item.feed.emplace_back("x", pi);
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
@ -1390,7 +1390,7 @@ TEST_F(FunctionOptimizerTest, SpecializeFunction_OncePerUniqueContext) {
item.feed = {{"xf", pi}, {"yf", pi}, {"xi", four}, {"yi", four}};
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
@ -1499,7 +1499,7 @@ TEST_F(FunctionOptimizerTest, SpecializeFunctionForUsedOutputTensors) {
item.feed = {{"xf", pi}, {"yf", pi}};
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
ASSERT_EQ(tensors_expected.size(), tensors.size());
@ -1660,7 +1660,7 @@ TEST_F(FunctionOptimizerTest, SpecializeIndirectFunctionForUsedOutputTensors) {
item.feed = {{"xf", pi}, {"yf", pi}};
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
ASSERT_EQ(tensors_expected.size(), tensors.size());

@ -1306,13 +1306,12 @@ Status RelaxAllocatorConstraints(GraphDef* optimized_graph) {
Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
*optimized_graph = item.graph;
GrapplerItem optimized_item(item);
RecomputationRewritingPass(optimization_level_,
recomputation_targets_name_scope_, optimized_graph,
item);
recomputation_targets_name_scope_,
&optimized_item.graph, item);
GrapplerItem optimized_item(item, optimized_graph);
std::unordered_set<string> skip_list;
// Bound the number of rewrite passes to avoid long processing times on graphs
// that simply won't fit in memory.

@ -279,7 +279,7 @@ TEST_F(MemoryOptimizerTest, SimpleSwapping) {
EXPECT_EQ("^swap_out_e_0", new_c.input(1));
// Run the optimizer a second time to ensure it's idempotent.
GrapplerItem item_copy(item, std::move(output));
GrapplerItem item_copy = item.WithGraph(std::move(output));
status = optimizer.Optimize(cluster.get(), item_copy, &output);
TF_EXPECT_OK(status);
@ -287,7 +287,7 @@ TEST_F(MemoryOptimizerTest, SimpleSwapping) {
item.fetch = {"e"};
item.init_ops = {init.name()};
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
#endif
@ -337,7 +337,7 @@ TEST_F(MemoryOptimizerTest, SwappingHeuristics) {
#if GOOGLE_CUDA
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
for (int i = 0; i < item.fetch.size(); ++i) {
test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
@ -386,7 +386,7 @@ TEST_F(MemoryOptimizerTest, UnswappableInputs) {
#if GOOGLE_CUDA
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
#endif
@ -474,7 +474,7 @@ TEST_F(RelaxAllocatorConstraintsTest, SameDevice) {
item.fetch = {"exp"};
item.init_ops = {"variable"};
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
@ -505,7 +505,7 @@ TEST_F(RelaxAllocatorConstraintsTest, DifferentDevice) {
item.fetch = {"exp"};
item.init_ops = {"variable"};
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
#endif
@ -598,7 +598,7 @@ TEST_F(RelaxAllocatorConstraintsTest, AssignNodeInFanout) {
item.fetch = {"assign0", "assign1"};
item.init_ops = {"exp1", "variable1"};
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
for (int i = 0; i < tensors_expected.size(); ++i) {
test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);

@ -440,7 +440,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
item.graph)
.ToProto();
GrapplerItem trimmed_item(item, std::move(trimmed_graph));
GrapplerItem trimmed_item = item.WithGraph(std::move(trimmed_graph));
VLOG(1) << absl::Substitute(
"Deleted $0 unreachable functions from the graph (library size = $1)",

@ -396,7 +396,7 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
item.feed.emplace_back("b", test::AsScalar<int>(4));
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
@ -502,7 +502,7 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryPruneFunctionBody) {
item.feed.emplace_back("b", test::AsScalar<float>(3.123f));
auto tensors_expected = EvaluateFetchNodes(item);
GrapplerItem optimized(item, std::move(output));
GrapplerItem optimized = item.WithGraph(std::move(output));
auto tensors = EvaluateFetchNodes(optimized);
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);

@ -665,7 +665,7 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
std::reverse(topo_sorted_graph.mutable_node()->begin(),
topo_sorted_graph.mutable_node()->end());
GrapplerItem topo_sorted_item(item, std::move(topo_sorted_graph));
GrapplerItem topo_sorted_item = item.WithGraph(std::move(topo_sorted_graph));
RemapperContext ctx(topo_sorted_item);
// Skip nodes that were invalidated by a remapper, e.g. do not process BiasAdd