Revert constant folding to previous state.

PiperOrigin-RevId: 215946205
Tong Shen 2018-10-05 12:17:31 -07:00 committed by TensorFlower Gardener
parent 0541a277d5
commit d016650ca7
5 changed files with 20 additions and 93 deletions


@@ -94,8 +94,9 @@ Status FunctionalizeControlFlowForFunction(
}
});
const FunctionBody* body = flr->GetFunctionBody(handle);
Graph* g = body->graph;
// Check if the graph has Switch or Merge node before optimizing the graph.
// Check if the graph has Switch or Merge node.
bool has_switch_or_merge = false;
for (Node* n : body->graph->nodes()) {
if (n->type_string() == "Switch" || n->type_string() == "Merge") {
@@ -108,58 +109,13 @@ Status FunctionalizeControlFlowForFunction(
// in function body. We still need to rewrite those functions and modify
// corresponding nodes.
// Call graph optimizer. The most important optimization we need is constant
// folding, which will replace ops like Shape/BroadcastGradientArgs with
// constant shape input. Without this optimization, those ops might become
// dynamic input for then/else body function and XLA will complain that input
// is not compile time constant. We enable function inlining as well, because
// otherwise we won't be able to infer shape for any node depending on
// function call nodes.
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("functionalize_control_flow_before_opt_", func_name),
*body->graph, fld);
}
// Optimizer accepts std::unique_ptr<Graph>* as input and might change
// underlying pointer, thus we create a new Graph and copy from body->graph.
std::unique_ptr<Graph> optimized_graph(new Graph(fld));
CopyGraph(*body->graph, optimized_graph.get());
OptimizerOptions opts;
opts.set_opt_level(OptimizerOptions::L0);
opts.set_do_function_inlining(true);
opts.set_do_constant_folding(true);
GraphOptimizer optimizer(opts);
auto cf_consider_fn = [](const Node* n) {
// Skip SymbolicGradient op when doing constant folding.
// Enabling SymbolicGradient op in constant folding requires
// flr->device() to be non-null, and here we have not constructed
// proper Device object yet (it will be constructed in XlaCompiler).
return n->type_string() != FunctionLibraryDefinition::kGradientOp;
};
optimizer.Optimize(flr, flr->env(),
/*device=*/nullptr, &optimized_graph,
/*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
cf_consider_fn);
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("functionalize_control_flow_after_opt_", func_name),
*optimized_graph, fld);
}
// Some inlined functions might have Switch/Merge nodes.
for (Node* n : optimized_graph->nodes()) {
if (n->type_string() == "Switch" || n->type_string() == "Merge") {
has_switch_or_merge = true;
break;
}
}
// If any node has associated functions, functionalize them first.
// Gather nodes with associated functions first, because rewriting those nodes
// might involve node deletion/addition. Avoid modifying nodes while iterating
// over them.
std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
nodes_to_associated_functions;
for (auto* n : optimized_graph->nodes()) {
for (auto* n : g->nodes()) {
auto associated_functions = GetAssociatedFunctions(*n, flr);
if (!associated_functions.empty()) {
nodes_to_associated_functions.push_back({n, associated_functions});
@@ -215,7 +171,7 @@ Status FunctionalizeControlFlowForFunction(
// pointer. That's fine because in that case, associated_functions will
// only have one member and the loop will only run once.
TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
optimized_graph.get(), n, fld, associated_function, new_name));
g, n, fld, associated_function, new_name));
}
}
}
@@ -227,21 +183,21 @@ Status FunctionalizeControlFlowForFunction(
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
*optimized_graph, fld);
*g, fld);
}
TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld));
TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld));
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("functionalize_control_flow_after_fdef_", func_name),
*optimized_graph, fld);
absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g,
fld);
}
}
if (*modified) {
// Add rewritten FunctionDef into library.
FunctionDef functionalized_fdef;
TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name,
&functionalized_fdef));
TF_RETURN_IF_ERROR(
GraphToFunctionDef(*g, new_func_name, &functionalized_fdef));
if (func_name == new_func_name) {
VLOG(2) << "Replacing function " << func_name;
TF_RETURN_IF_ERROR(

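For orientation, the flow these hunks restore in FunctionalizeControlFlowForFunction amounts to: take the function body graph directly (no optimizer pass), scan it for Switch/Merge nodes, functionalize, then convert the graph back into a FunctionDef. The sketch below is a simplified illustration only; the name FunctionalizeBodySketch, the trimmed parameter list, and the unconditional AddFunctionDef at the end are our assumptions, and the associated-function rewriting and graph dumps of the real function are elided.

// Simplified sketch of the post-revert flow; not the actual
// FunctionalizeControlFlowForFunction.
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

Status FunctionalizeBodySketch(FunctionLibraryRuntime* flr,
                               FunctionLibraryDefinition* fld,
                               FunctionLibraryRuntime::Handle handle,
                               const string& new_func_name, bool* modified) {
  // Work on the function body graph directly; the revert drops the
  // constant-folding/inlining pass that used to run here.
  const FunctionBody* body = flr->GetFunctionBody(handle);
  Graph* g = body->graph;

  // Only functionalize bodies that actually contain control-flow ops.
  bool has_switch_or_merge = false;
  for (Node* n : g->nodes()) {
    if (n->type_string() == "Switch" || n->type_string() == "Merge") {
      has_switch_or_merge = true;
      break;
    }
  }

  if (has_switch_or_merge) {
    TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld));
    *modified = true;
  }

  if (*modified) {
    // Convert the rewritten graph back into a FunctionDef for the library.
    FunctionDef functionalized_fdef;
    TF_RETURN_IF_ERROR(
        GraphToFunctionDef(*g, new_func_name, &functionalized_fdef));
    // The real code replaces the existing function when the name is unchanged;
    // adding under a new name is shown here for brevity.
    TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
  }
  return Status::OK();
}

}  // namespace tensorflow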

@@ -466,23 +466,23 @@ Graph* GetConstantGraph(
bool ReplaceTensorWithConstant(
Graph* graph, Device* partition_device, NodeAndOutput tensor,
const Tensor& constant, const gtl::FlatSet<Node*>& control_deps,
int64 max_constant_size_in_bytes, bool disable_memory_output_type_check,
int64 max_constant_size_in_bytes,
const ConstantFoldNameGenerator& generate_new_name) {
// Be conservative when replacing a tensor with a constant, when not
// running on CPU.
// 1) Do not replace another constant.
// 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
// constraint, do not replace it.
// 3) If the size of the constant in bytes is too large (>
// 3) If the destination tensor is an int32 tensor, and has DEVICE_MEMORY
// constraint, do not replace it.
// 4) If the size of the constant in bytes is too large (>
// max_constant_in_bytes), do not replace it. This prevents the size of the
// Graph from growing too large.
// 4) If the constant op created does not have a kernel implementation
// 5) If the constant op created does not have a kernel implementation
// for the device, do not use it.
// TODO(keveman): Consider adding a new constant op that has a kernel
// implementation for all types, but with HostMemory constraint on its
// output.
// 5) If the constant op for the device has different output memory type
// from the original op output memory type, do not replace it.
if (tensor.first->IsConstant()) {
return false;
}
@@ -497,7 +497,8 @@ bool ReplaceTensorWithConstant(
return false;
}
bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32;
if (memory_type == HOST_MEMORY && !is_int32) {
if ((memory_type == HOST_MEMORY && !is_int32) ||
(memory_type == DEVICE_MEMORY && is_int32)) {
return false;
}
}
@@ -535,25 +536,6 @@ bool ReplaceTensorWithConstant(
if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
return false;
}
if (!disable_memory_output_type_check) {
if (partition_device && device_type != DEVICE_CPU) {
MemoryType original_output_memory_type;
if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
&original_output_memory_type)
.ok()) {
return false;
}
MemoryType const_output_memory_type;
if (!MemoryTypeForOutput(device_type, graph, constant_node, 0,
&const_output_memory_type)
.ok()) {
return false;
}
if (original_output_memory_type != const_output_memory_type) {
return false;
}
}
}
for (auto edge : edges_to_remove) {
graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
graph->RemoveEdge(edge);
@@ -660,8 +642,7 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
constant_control_deps[tensors_to_replace[c].first];
if (ReplaceTensorWithConstant(
graph, partition_device, tensors_to_replace[c], outputs[c],
control_deps, opts.max_constant_size_in_bytes,
opts.disable_memory_output_type_check, generate_new_name)) {
control_deps, opts.max_constant_size_in_bytes, generate_new_name)) {
++num_nodes_replaced;
}
}

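The placement rules restored above (items 2 and 3 of the comment, and the corresponding check in ReplaceTensorWithConstant) reduce to a small predicate on the output's memory type and dtype. A minimal illustration follows; the helper name OutputPlacementAllowsConstant is ours, not part of the change.

#include "tensorflow/core/framework/types.h"

namespace tensorflow {

// Illustrative only: mirrors the restored check in ReplaceTensorWithConstant.
// On non-CPU devices, skip replacement when
//   - a non-int32 output is pinned to HOST_MEMORY, or
//   - an int32 output is pinned to DEVICE_MEMORY.
bool OutputPlacementAllowsConstant(MemoryType memory_type, DataType dtype) {
  const bool is_int32 = (dtype == DT_INT32);
  return !((memory_type == HOST_MEMORY && !is_int32) ||
           (memory_type == DEVICE_MEMORY && is_int32));
}

}  // namespace tensorflow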

@@ -45,10 +45,6 @@ struct ConstantFoldingOptions {
// optimization.
int64 max_constant_size_in_bytes = 10 * 1024 * 1024;
// If disable_memory_output_type_check is true, we will disable output memory
// type check for constant node replacement.
bool disable_memory_output_type_check = false;
// A generator for the name suffix of constant folded nodes. A
// default id generator that monotonically increases is used if nullptr is
// passed.

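With disable_memory_output_type_check removed, ConstantFoldingOptions is back to the fields that remain in this header: consider, shape_map, max_constant_size_in_bytes, and generate_new_name. A hedged sketch of populating it follows; the helper name MakeFoldingOptions and the chosen values are placeholders, not part of the change.

#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"

namespace tensorflow {

ConstantFoldingOptions MakeFoldingOptions(
    const std::unordered_map<string, std::vector<PartialTensorShape>>*
        shape_map) {
  ConstantFoldingOptions cf_opts;
  // Optional shape information consulted while folding shape-related ops.
  cf_opts.shape_map = shape_map;
  // nullptr means every node is a candidate for folding.
  cf_opts.consider = nullptr;
  // Cap the size of any Const node produced by folding (header default: 10MB).
  cf_opts.max_constant_size_in_bytes = 10 * 1024 * 1024;
  // generate_new_name is left unset, so the default monotonically increasing
  // id suffix described in the header is used.
  return cf_opts;
}

}  // namespace tensorflow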

@@ -39,8 +39,7 @@ void GraphOptimizer::Optimize(
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
const std::function<bool(const Node*)>& cse_consider_fn,
const std::function<bool(const Node*)>& cf_consider_fn,
bool cf_disable_memory_output_type_check) {
const std::function<bool(const Node*)>& cf_consider_fn) {
Graph* g = graph->get();
DumpGraph("Initial", g);
@@ -65,8 +64,6 @@ void GraphOptimizer::Optimize(
ConstantFoldingOptions cf_opts;
cf_opts.shape_map = shape_map;
cf_opts.consider = cf_consider_fn;
cf_opts.disable_memory_output_type_check =
cf_disable_memory_output_type_check;
if (opts_.max_folded_constant_in_bytes() > 0) {
cf_opts.max_constant_size_in_bytes =
opts_.max_folded_constant_in_bytes();


@@ -47,16 +47,13 @@ class GraphOptimizer {
// returns true will be considered for CSE.
// If cf_consider_fn is not null then only nodes for which cf_consider_fn
// returns true will be considered for CF.
// If cf_disable_memory_output_type_check is true, CF will discard output
// memory type check for constant node replacement.
void Optimize(
FunctionLibraryRuntime* runtime, Env* env, Device* device,
std::unique_ptr<Graph>* graph,
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
const std::function<bool(const Node*)>& cse_consider_fn = nullptr,
const std::function<bool(const Node*)>& cf_consider_fn = nullptr,
bool cf_disable_memory_output_type_check = false);
const std::function<bool(const Node*)>& cf_consider_fn = nullptr);
const OptimizerOptions& options() { return opts_; }
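After the revert, callers use the seven-parameter Optimize shown above; there is no cf_disable_memory_output_type_check argument. A hedged usage sketch, modeled on the call the removed tf2xla block used to make (the wrapper name RunOptimizerSketch and its parameters are ours):

#include <memory>

#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

void RunOptimizerSketch(FunctionLibraryRuntime* flr,
                        std::unique_ptr<Graph>* graph) {
  // Conservative opt level with inlining and constant folding enabled.
  OptimizerOptions opts;
  opts.set_opt_level(OptimizerOptions::L0);
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(true);
  GraphOptimizer optimizer(opts);

  // Skip SymbolicGradient nodes during constant folding, as the removed
  // tf2xla code did.
  auto cf_consider_fn = [](const Node* n) {
    return n->type_string() != FunctionLibraryDefinition::kGradientOp;
  };

  // Reverted signature: runtime, env, device, graph, shape_map,
  // cse_consider_fn, cf_consider_fn.
  optimizer.Optimize(flr, flr->env(),
                     /*device=*/nullptr, graph,
                     /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
                     cf_consider_fn);
}

}  // namespace tensorflow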