Revert constant folding to previous state.
PiperOrigin-RevId: 215946205
commit d016650ca7 (parent 0541a277d5)
@@ -94,8 +94,9 @@ Status FunctionalizeControlFlowForFunction(
     }
   });
   const FunctionBody* body = flr->GetFunctionBody(handle);
+  Graph* g = body->graph;
 
-  // Check if the graph has Switch or Merge node before optimizing the graph.
+  // Check if the graph has Switch or Merge node.
   bool has_switch_or_merge = false;
   for (Node* n : body->graph->nodes()) {
     if (n->type_string() == "Switch" || n->type_string() == "Merge") {
@@ -108,58 +109,13 @@ Status FunctionalizeControlFlowForFunction(
   // in function body. We still need to rewrite those functions and modify
   // corresponding nodes.
 
-  // Call graph optimizer. The most important optimization we need is constant
-  // folding, which will replace ops like Shape/BroadcastGradientArgs with
-  // constant shape input. Without this optimization, those ops might become
-  // dynamic input for then/else body function and XLA will complain that input
-  // is not compile time constant. We enable function inlining as well, because
-  // otherwise we won't be able to infer shape for any node depending on
-  // function call nodes.
-  if (VLOG_IS_ON(4)) {
-    dump_graph::DumpGraphToFile(
-        absl::StrCat("functionalize_control_flow_before_opt_", func_name),
-        *body->graph, fld);
-  }
-  // Optimizer accepts std::unique_ptr<Graph>* as input and might change
-  // underlying pointer, thus we create a new Graph and copy from body->graph.
-  std::unique_ptr<Graph> optimized_graph(new Graph(fld));
-  CopyGraph(*body->graph, optimized_graph.get());
-  OptimizerOptions opts;
-  opts.set_opt_level(OptimizerOptions::L0);
-  opts.set_do_function_inlining(true);
-  opts.set_do_constant_folding(true);
-  GraphOptimizer optimizer(opts);
-  auto cf_consider_fn = [](const Node* n) {
-    // Skip SymbolicGradient op when doing constant folding.
-    // Enabling SymbolicGradient op in constant folding requires
-    // flr->device() to be non-null, and here we have not constructed
-    // proper Device object yet (it will be constructed in XlaCompiler).
-    return n->type_string() != FunctionLibraryDefinition::kGradientOp;
-  };
-  optimizer.Optimize(flr, flr->env(),
-                     /*device=*/nullptr, &optimized_graph,
-                     /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
-                     cf_consider_fn);
-  if (VLOG_IS_ON(4)) {
-    dump_graph::DumpGraphToFile(
-        absl::StrCat("functionalize_control_flow_after_opt_", func_name),
-        *optimized_graph, fld);
-  }
-  // Some inlined functions might have Switch/Merge nodes.
-  for (Node* n : optimized_graph->nodes()) {
-    if (n->type_string() == "Switch" || n->type_string() == "Merge") {
-      has_switch_or_merge = true;
-      break;
-    }
-  }
-
   // If any node has associated functions, functionalize them first.
   // Gather nodes with associated functions first, because rewriting those nodes
   // might involve node deletion/addition. Avoid modifying nodes while iterating
   // it.
   std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
       nodes_to_associated_functions;
-  for (auto* n : optimized_graph->nodes()) {
+  for (auto* n : g->nodes()) {
     auto associated_functions = GetAssociatedFunctions(*n, flr);
     if (!associated_functions.empty()) {
       nodes_to_associated_functions.push_back({n, associated_functions});
@@ -215,7 +171,7 @@ Status FunctionalizeControlFlowForFunction(
         // pointer. That's fine because in that case, associated_functions will
         // only have one member and the loop will only run once.
         TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
-            optimized_graph.get(), n, fld, associated_function, new_name));
+            g, n, fld, associated_function, new_name));
       }
     }
   }
@@ -227,21 +183,21 @@ Status FunctionalizeControlFlowForFunction(
     if (VLOG_IS_ON(4)) {
       dump_graph::DumpGraphToFile(
           absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
-          *optimized_graph, fld);
+          *g, fld);
     }
-    TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld));
+    TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld));
     if (VLOG_IS_ON(4)) {
       dump_graph::DumpGraphToFile(
-          absl::StrCat("functionalize_control_flow_after_fdef_", func_name),
-          *optimized_graph, fld);
+          absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g,
+          fld);
     }
   }
 
   if (*modified) {
     // Add rewritten FunctionDef into library.
     FunctionDef functionalized_fdef;
-    TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name,
-                                          &functionalized_fdef));
+    TF_RETURN_IF_ERROR(
+        GraphToFunctionDef(*g, new_func_name, &functionalized_fdef));
     if (func_name == new_func_name) {
       VLOG(2) << "Replacing function " << func_name;
       TF_RETURN_IF_ERROR(
@@ -466,23 +466,23 @@ Graph* GetConstantGraph(
 bool ReplaceTensorWithConstant(
     Graph* graph, Device* partition_device, NodeAndOutput tensor,
     const Tensor& constant, const gtl::FlatSet<Node*>& control_deps,
-    int64 max_constant_size_in_bytes, bool disable_memory_output_type_check,
+    int64 max_constant_size_in_bytes,
     const ConstantFoldNameGenerator& generate_new_name) {
   // Be conservative when replacing a tensor with a constant, when not
   // running on CPU.
   // 1) Do not replace another constant.
   // 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
   // constraint, do not replace it.
-  // 3) If the size of the constant in bytes is too large (>
+  // 3) If the destination tensor is an int32 tensor, and has DEVICE_MEMORY
+  // constraint, do not replace it.
+  // 4) If the size of the constant in bytes is too large (>
   // max_constant_in_bytes), do not replace it. This prevents the size of the
   // Graph from growing too large.
-  // 4) If the constant op created does not have a kernel implementation
+  // 5) If the constant op created does not have a kernel implementation
   // for the device, do not use it.
   // TODO(keveman): Consider adding a new constant op that has a kernel
   // implementation for all types, but with HostMemory constraint on it's
   // output.
-  // 5) If the constant op for the device has different output memory type
-  // from the original op output memory type, do not replace it.
   if (tensor.first->IsConstant()) {
     return false;
   }
@@ -497,7 +497,8 @@ bool ReplaceTensorWithConstant(
       return false;
     }
     bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32;
-    if (memory_type == HOST_MEMORY && !is_int32) {
+    if ((memory_type == HOST_MEMORY && !is_int32) ||
+        (memory_type == DEVICE_MEMORY && is_int32)) {
       return false;
     }
   }
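For orientation, a minimal sketch of the rule this hunk restores; the free-standing helper and its name are hypothetical, not TensorFlow code. Per the restored comments above, replacement is skipped when a non-int32 output carries a HOST_MEMORY constraint or an int32 output carries a DEVICE_MEMORY constraint:

  // Sketch only: hypothetical helper mirroring the restored check above.
  bool WouldChangeMemoryType(MemoryType memory_type, bool is_int32) {
    // memory_type is the original output's constraint; skip the fold when it
    // disagrees with what a replacement Const would use for this dtype.
    return (memory_type == HOST_MEMORY && !is_int32) ||
           (memory_type == DEVICE_MEMORY && is_int32);
  }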
@@ -535,25 +536,6 @@ bool ReplaceTensorWithConstant(
   if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
     return false;
   }
-  if (!disable_memory_output_type_check) {
-    if (partition_device && device_type != DEVICE_CPU) {
-      MemoryType original_output_memory_type;
-      if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
-                               &original_output_memory_type)
-               .ok()) {
-        return false;
-      }
-      MemoryType const_output_memory_type;
-      if (!MemoryTypeForOutput(device_type, graph, constant_node, 0,
-                               &const_output_memory_type)
-               .ok()) {
-        return false;
-      }
-      if (original_output_memory_type != const_output_memory_type) {
-        return false;
-      }
-    }
-  }
   for (auto edge : edges_to_remove) {
     graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
     graph->RemoveEdge(edge);
@@ -660,8 +642,7 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
         constant_control_deps[tensors_to_replace[c].first];
     if (ReplaceTensorWithConstant(
             graph, partition_device, tensors_to_replace[c], outputs[c],
-            control_deps, opts.max_constant_size_in_bytes,
-            opts.disable_memory_output_type_check, generate_new_name)) {
+            control_deps, opts.max_constant_size_in_bytes, generate_new_name)) {
       ++num_nodes_replaced;
     }
   }
@@ -45,10 +45,6 @@ struct ConstantFoldingOptions {
   // optimization.
   int64 max_constant_size_in_bytes = 10 * 1024 * 1024;
 
-  // If disable_memory_output_type_check is true, we will disable output memory
-  // type check for constant node replacement.
-  bool disable_memory_output_type_check = false;
-
   // A generator for the name suffix of constant folded nodes. A
   // default id generator that monotonically increases is used if nullptr is
   // passed.
@@ -39,8 +39,7 @@ void GraphOptimizer::Optimize(
     const std::unordered_map<string, std::vector<PartialTensorShape>>*
         shape_map,
     const std::function<bool(const Node*)>& cse_consider_fn,
-    const std::function<bool(const Node*)>& cf_consider_fn,
-    bool cf_disable_memory_output_type_check) {
+    const std::function<bool(const Node*)>& cf_consider_fn) {
   Graph* g = graph->get();
   DumpGraph("Initial", g);
 
@@ -65,8 +64,6 @@ void GraphOptimizer::Optimize(
     ConstantFoldingOptions cf_opts;
     cf_opts.shape_map = shape_map;
     cf_opts.consider = cf_consider_fn;
-    cf_opts.disable_memory_output_type_check =
-        cf_disable_memory_output_type_check;
     if (opts_.max_folded_constant_in_bytes() > 0) {
       cf_opts.max_constant_size_in_bytes =
           opts_.max_folded_constant_in_bytes();
@@ -47,16 +47,13 @@ class GraphOptimizer {
   // returns true will be considered for CSE.
   // If cf_consider_fn is not null then only nodes for which cf_consider_fn
   // returns true will be considered for CF.
-  // If cf_disable_memory_output_type_check is true, CF will discard output
-  // memory type check for constant node replacement.
   void Optimize(
       FunctionLibraryRuntime* runtime, Env* env, Device* device,
       std::unique_ptr<Graph>* graph,
       const std::unordered_map<string, std::vector<PartialTensorShape>>*
           shape_map,
       const std::function<bool(const Node*)>& cse_consider_fn = nullptr,
-      const std::function<bool(const Node*)>& cf_consider_fn = nullptr,
-      bool cf_disable_memory_output_type_check = false);
+      const std::function<bool(const Node*)>& cf_consider_fn = nullptr);
 
   const OptimizerOptions& options() { return opts_; }
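After this revert, GraphOptimizer::Optimize takes at most the two consider callbacks and no cf_disable_memory_output_type_check flag. A minimal caller sketch against the reverted signature, mirroring the invocation removed from FunctionalizeControlFlowForFunction above (flr and graph stand for a caller's FunctionLibraryRuntime* and std::unique_ptr<Graph>; the surrounding setup is assumed and not part of this diff):

  OptimizerOptions opts;
  opts.set_opt_level(OptimizerOptions::L0);
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(true);
  GraphOptimizer optimizer(opts);
  // Reverted signature: no trailing cf_disable_memory_output_type_check bool.
  optimizer.Optimize(flr, flr->env(), /*device=*/nullptr, &graph,
                     /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
                     /*cf_consider_fn=*/nullptr);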