Revert constant folding to previous state.

PiperOrigin-RevId: 215946205
This commit is contained in:
Tong Shen 2018-10-05 12:17:31 -07:00 committed by TensorFlower Gardener
parent 0541a277d5
commit d016650ca7
5 changed files with 20 additions and 93 deletions

View File

@ -94,8 +94,9 @@ Status FunctionalizeControlFlowForFunction(
} }
}); });
const FunctionBody* body = flr->GetFunctionBody(handle); const FunctionBody* body = flr->GetFunctionBody(handle);
Graph* g = body->graph;
// Check if the graph has Switch or Merge node before optimizing the graph. // Check if the graph has Switch or Merge node.
bool has_switch_or_merge = false; bool has_switch_or_merge = false;
for (Node* n : body->graph->nodes()) { for (Node* n : body->graph->nodes()) {
if (n->type_string() == "Switch" || n->type_string() == "Merge") { if (n->type_string() == "Switch" || n->type_string() == "Merge") {
@ -108,58 +109,13 @@ Status FunctionalizeControlFlowForFunction(
// in function body. We still need to rewrite those functions and modify // in function body. We still need to rewrite those functions and modify
// corresponding nodes. // corresponding nodes.
// Call graph optimizer. The most important optimization we need is constant
// folding, which will replace ops like Shape/BroadcastGradientArgs with
// constant shape input. Without this optimization, those ops might become
// dynamic input for then/else body function and XLA will complain that input
// is not compile time constant. We enable function inlining as well, because
// otherwise we won't be able to infer shape for any node depending on
// function call nodes.
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("functionalize_control_flow_before_opt_", func_name),
*body->graph, fld);
}
// Optimizer accepts std::unique_ptr<Graph>* as input and might change
// underlying pointer, thus we create a new Graph and copy from body->graph.
std::unique_ptr<Graph> optimized_graph(new Graph(fld));
CopyGraph(*body->graph, optimized_graph.get());
OptimizerOptions opts;
opts.set_opt_level(OptimizerOptions::L0);
opts.set_do_function_inlining(true);
opts.set_do_constant_folding(true);
GraphOptimizer optimizer(opts);
auto cf_consider_fn = [](const Node* n) {
// Skip SymbolicGradient op when doing constant folding.
// Enabling SymbolicGradient op in constant folding requires
// flr->device() to be non-null, and here we have not constructed
// proper Device object yet (it will be constructed in XlaCompiler).
return n->type_string() != FunctionLibraryDefinition::kGradientOp;
};
optimizer.Optimize(flr, flr->env(),
/*device=*/nullptr, &optimized_graph,
/*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
cf_consider_fn);
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("functionalize_control_flow_after_opt_", func_name),
*optimized_graph, fld);
}
// Some inlined functions might have Switch/Merge nodes.
for (Node* n : optimized_graph->nodes()) {
if (n->type_string() == "Switch" || n->type_string() == "Merge") {
has_switch_or_merge = true;
break;
}
}
// If any node has associated functions, functionalize them first. // If any node has associated functions, functionalize them first.
// Gather nodes with associated functions first, because rewriting those nodes // Gather nodes with associated functions first, because rewriting those nodes
// might involve node deletion/addition. Avoid modifying nodes while iterating // might involve node deletion/addition. Avoid modifying nodes while iterating
// it. // it.
std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>> std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
nodes_to_associated_functions; nodes_to_associated_functions;
for (auto* n : optimized_graph->nodes()) { for (auto* n : g->nodes()) {
auto associated_functions = GetAssociatedFunctions(*n, flr); auto associated_functions = GetAssociatedFunctions(*n, flr);
if (!associated_functions.empty()) { if (!associated_functions.empty()) {
nodes_to_associated_functions.push_back({n, associated_functions}); nodes_to_associated_functions.push_back({n, associated_functions});
@ -215,7 +171,7 @@ Status FunctionalizeControlFlowForFunction(
// pointer. That's fine because in that case, associated_functions will // pointer. That's fine because in that case, associated_functions will
// only have one member and the loop will only run once. // only have one member and the loop will only run once.
TF_RETURN_IF_ERROR(RewriteAssociatedFunction( TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
optimized_graph.get(), n, fld, associated_function, new_name)); g, n, fld, associated_function, new_name));
} }
} }
} }
@ -227,21 +183,21 @@ Status FunctionalizeControlFlowForFunction(
if (VLOG_IS_ON(4)) { if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile( dump_graph::DumpGraphToFile(
absl::StrCat("functionalize_control_flow_before_fdef_", func_name), absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
*optimized_graph, fld); *g, fld);
} }
TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld)); TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld));
if (VLOG_IS_ON(4)) { if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile( dump_graph::DumpGraphToFile(
absl::StrCat("functionalize_control_flow_after_fdef_", func_name), absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g,
*optimized_graph, fld); fld);
} }
} }
if (*modified) { if (*modified) {
// Add rewritten FunctionDef into library. // Add rewritten FunctionDef into library.
FunctionDef functionalized_fdef; FunctionDef functionalized_fdef;
TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name, TF_RETURN_IF_ERROR(
&functionalized_fdef)); GraphToFunctionDef(*g, new_func_name, &functionalized_fdef));
if (func_name == new_func_name) { if (func_name == new_func_name) {
VLOG(2) << "Replacing function " << func_name; VLOG(2) << "Replacing function " << func_name;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(

View File

@ -466,23 +466,23 @@ Graph* GetConstantGraph(
bool ReplaceTensorWithConstant( bool ReplaceTensorWithConstant(
Graph* graph, Device* partition_device, NodeAndOutput tensor, Graph* graph, Device* partition_device, NodeAndOutput tensor,
const Tensor& constant, const gtl::FlatSet<Node*>& control_deps, const Tensor& constant, const gtl::FlatSet<Node*>& control_deps,
int64 max_constant_size_in_bytes, bool disable_memory_output_type_check, int64 max_constant_size_in_bytes,
const ConstantFoldNameGenerator& generate_new_name) { const ConstantFoldNameGenerator& generate_new_name) {
// Be conservative when replacing a tensor with a constant, when not // Be conservative when replacing a tensor with a constant, when not
// running on CPU. // running on CPU.
// 1) Do not replace another constant. // 1) Do not replace another constant.
// 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY // 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
// constraint, do not replace it. // constraint, do not replace it.
// 3) If the size of the constant in bytes is too large (> // 3) If the destination tensor is an int32 tensor, and has DEVICE_MEMORY
// constraint, do not replace it.
// 4) If the size of the constant in bytes is too large (>
// max_constant_in_bytes), do not replace it. This prevents the size of the // max_constant_in_bytes), do not replace it. This prevents the size of the
// Graph from growing too large. // Graph from growing too large.
// 4) If the constant op created does not have a kernel implementation // 5) If the constant op created does not have a kernel implementation
// for the device, do not use it. // for the device, do not use it.
// TODO(keveman): Consider adding a new constant op that has a kernel // TODO(keveman): Consider adding a new constant op that has a kernel
// implementation for all types, but with HostMemory constraint on it's // implementation for all types, but with HostMemory constraint on it's
// output. // output.
// 5) If the constant op for the device has different output memory type
// from the original op output memory type, do not replace it.
if (tensor.first->IsConstant()) { if (tensor.first->IsConstant()) {
return false; return false;
} }
@ -497,7 +497,8 @@ bool ReplaceTensorWithConstant(
return false; return false;
} }
bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32; bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32;
if (memory_type == HOST_MEMORY && !is_int32) { if ((memory_type == HOST_MEMORY && !is_int32) ||
(memory_type == DEVICE_MEMORY && is_int32)) {
return false; return false;
} }
} }
@ -535,25 +536,6 @@ bool ReplaceTensorWithConstant(
if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) { if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
return false; return false;
} }
if (!disable_memory_output_type_check) {
if (partition_device && device_type != DEVICE_CPU) {
MemoryType original_output_memory_type;
if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
&original_output_memory_type)
.ok()) {
return false;
}
MemoryType const_output_memory_type;
if (!MemoryTypeForOutput(device_type, graph, constant_node, 0,
&const_output_memory_type)
.ok()) {
return false;
}
if (original_output_memory_type != const_output_memory_type) {
return false;
}
}
}
for (auto edge : edges_to_remove) { for (auto edge : edges_to_remove) {
graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input()); graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
graph->RemoveEdge(edge); graph->RemoveEdge(edge);
@ -660,8 +642,7 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
constant_control_deps[tensors_to_replace[c].first]; constant_control_deps[tensors_to_replace[c].first];
if (ReplaceTensorWithConstant( if (ReplaceTensorWithConstant(
graph, partition_device, tensors_to_replace[c], outputs[c], graph, partition_device, tensors_to_replace[c], outputs[c],
control_deps, opts.max_constant_size_in_bytes, control_deps, opts.max_constant_size_in_bytes, generate_new_name)) {
opts.disable_memory_output_type_check, generate_new_name)) {
++num_nodes_replaced; ++num_nodes_replaced;
} }
} }

View File

@ -45,10 +45,6 @@ struct ConstantFoldingOptions {
// optimization. // optimization.
int64 max_constant_size_in_bytes = 10 * 1024 * 1024; int64 max_constant_size_in_bytes = 10 * 1024 * 1024;
// If disable_memory_output_type_check is true, we will disable output memory
// type check for constant node replacement.
bool disable_memory_output_type_check = false;
// A generator for the name suffix of constant folded nodes. A // A generator for the name suffix of constant folded nodes. A
// default id generator that monotonically increases is used if nullptr is // default id generator that monotonically increases is used if nullptr is
// passed. // passed.

View File

@ -39,8 +39,7 @@ void GraphOptimizer::Optimize(
const std::unordered_map<string, std::vector<PartialTensorShape>>* const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map, shape_map,
const std::function<bool(const Node*)>& cse_consider_fn, const std::function<bool(const Node*)>& cse_consider_fn,
const std::function<bool(const Node*)>& cf_consider_fn, const std::function<bool(const Node*)>& cf_consider_fn) {
bool cf_disable_memory_output_type_check) {
Graph* g = graph->get(); Graph* g = graph->get();
DumpGraph("Initial", g); DumpGraph("Initial", g);
@ -65,8 +64,6 @@ void GraphOptimizer::Optimize(
ConstantFoldingOptions cf_opts; ConstantFoldingOptions cf_opts;
cf_opts.shape_map = shape_map; cf_opts.shape_map = shape_map;
cf_opts.consider = cf_consider_fn; cf_opts.consider = cf_consider_fn;
cf_opts.disable_memory_output_type_check =
cf_disable_memory_output_type_check;
if (opts_.max_folded_constant_in_bytes() > 0) { if (opts_.max_folded_constant_in_bytes() > 0) {
cf_opts.max_constant_size_in_bytes = cf_opts.max_constant_size_in_bytes =
opts_.max_folded_constant_in_bytes(); opts_.max_folded_constant_in_bytes();

View File

@ -47,16 +47,13 @@ class GraphOptimizer {
// returns true will be considered for CSE. // returns true will be considered for CSE.
// If cf_consider_fn is not null then only nodes for which cf_consider_fn // If cf_consider_fn is not null then only nodes for which cf_consider_fn
// returns true will be considered for CF. // returns true will be considered for CF.
// If cf_disable_memory_output_type_check is true, CF will discard output
// memory type check for constant node replacement.
void Optimize( void Optimize(
FunctionLibraryRuntime* runtime, Env* env, Device* device, FunctionLibraryRuntime* runtime, Env* env, Device* device,
std::unique_ptr<Graph>* graph, std::unique_ptr<Graph>* graph,
const std::unordered_map<string, std::vector<PartialTensorShape>>* const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map, shape_map,
const std::function<bool(const Node*)>& cse_consider_fn = nullptr, const std::function<bool(const Node*)>& cse_consider_fn = nullptr,
const std::function<bool(const Node*)>& cf_consider_fn = nullptr, const std::function<bool(const Node*)>& cf_consider_fn = nullptr);
bool cf_disable_memory_output_type_check = false);
const OptimizerOptions& options() { return opts_; } const OptimizerOptions& options() { return opts_; }