Revert constant folding to previous state.
PiperOrigin-RevId: 215946205
This commit is contained in:
parent
0541a277d5
commit
d016650ca7
@ -94,8 +94,9 @@ Status FunctionalizeControlFlowForFunction(
|
||||
}
|
||||
});
|
||||
const FunctionBody* body = flr->GetFunctionBody(handle);
|
||||
Graph* g = body->graph;
|
||||
|
||||
// Check if the graph has Switch or Merge node before optimizing the graph.
|
||||
// Check if the graph has Switch or Merge node.
|
||||
bool has_switch_or_merge = false;
|
||||
for (Node* n : body->graph->nodes()) {
|
||||
if (n->type_string() == "Switch" || n->type_string() == "Merge") {
|
||||
@ -108,58 +109,13 @@ Status FunctionalizeControlFlowForFunction(
|
||||
// in function body. We still need to rewrite those functions and modify
|
||||
// corresponding nodes.
|
||||
|
||||
// Call graph optimizer. The most important optimization we need is constant
|
||||
// folding, which will replace ops like Shape/BroadcastGradientArgs with
|
||||
// constant shape input. Without this optimization, those ops might become
|
||||
// dynamic input for then/else body function and XLA will complain that input
|
||||
// is not compile time constant. We enable function inlining as well, because
|
||||
// otherwise we won't be able to infer shape for any node depending on
|
||||
// function call nodes.
|
||||
if (VLOG_IS_ON(4)) {
|
||||
dump_graph::DumpGraphToFile(
|
||||
absl::StrCat("functionalize_control_flow_before_opt_", func_name),
|
||||
*body->graph, fld);
|
||||
}
|
||||
// Optimizer accepts std::unique_ptr<Graph>* as input and might change
|
||||
// underlying pointer, thus we create a new Graph and copy from body->graph.
|
||||
std::unique_ptr<Graph> optimized_graph(new Graph(fld));
|
||||
CopyGraph(*body->graph, optimized_graph.get());
|
||||
OptimizerOptions opts;
|
||||
opts.set_opt_level(OptimizerOptions::L0);
|
||||
opts.set_do_function_inlining(true);
|
||||
opts.set_do_constant_folding(true);
|
||||
GraphOptimizer optimizer(opts);
|
||||
auto cf_consider_fn = [](const Node* n) {
|
||||
// Skip SymbolicGradient op when doing constant folding.
|
||||
// Enabling SymbolicGradient op in constant folding requires
|
||||
// flr->device() to be non-null, and here we have not constructed
|
||||
// proper Device object yet (it will be constructed in XlaCompiler).
|
||||
return n->type_string() != FunctionLibraryDefinition::kGradientOp;
|
||||
};
|
||||
optimizer.Optimize(flr, flr->env(),
|
||||
/*device=*/nullptr, &optimized_graph,
|
||||
/*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
|
||||
cf_consider_fn);
|
||||
if (VLOG_IS_ON(4)) {
|
||||
dump_graph::DumpGraphToFile(
|
||||
absl::StrCat("functionalize_control_flow_after_opt_", func_name),
|
||||
*optimized_graph, fld);
|
||||
}
|
||||
// Some inlined functions might have Switch/Merge nodes.
|
||||
for (Node* n : optimized_graph->nodes()) {
|
||||
if (n->type_string() == "Switch" || n->type_string() == "Merge") {
|
||||
has_switch_or_merge = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// If any node has associated functions, functionalize them first.
|
||||
// Gather nodes with associated functions first, because rewriting those nodes
|
||||
// might involve node deletion/addition. Avoid modifying nodes while iterating
|
||||
// it.
|
||||
std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
|
||||
nodes_to_associated_functions;
|
||||
for (auto* n : optimized_graph->nodes()) {
|
||||
for (auto* n : g->nodes()) {
|
||||
auto associated_functions = GetAssociatedFunctions(*n, flr);
|
||||
if (!associated_functions.empty()) {
|
||||
nodes_to_associated_functions.push_back({n, associated_functions});
|
||||
@ -215,7 +171,7 @@ Status FunctionalizeControlFlowForFunction(
|
||||
// pointer. That's fine because in that case, associated_functions will
|
||||
// only have one member and the loop will only run once.
|
||||
TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
|
||||
optimized_graph.get(), n, fld, associated_function, new_name));
|
||||
g, n, fld, associated_function, new_name));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -227,21 +183,21 @@ Status FunctionalizeControlFlowForFunction(
|
||||
if (VLOG_IS_ON(4)) {
|
||||
dump_graph::DumpGraphToFile(
|
||||
absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
|
||||
*optimized_graph, fld);
|
||||
*g, fld);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld));
|
||||
TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld));
|
||||
if (VLOG_IS_ON(4)) {
|
||||
dump_graph::DumpGraphToFile(
|
||||
absl::StrCat("functionalize_control_flow_after_fdef_", func_name),
|
||||
*optimized_graph, fld);
|
||||
absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g,
|
||||
fld);
|
||||
}
|
||||
}
|
||||
|
||||
if (*modified) {
|
||||
// Add rewritten FunctionDef into library.
|
||||
FunctionDef functionalized_fdef;
|
||||
TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name,
|
||||
&functionalized_fdef));
|
||||
TF_RETURN_IF_ERROR(
|
||||
GraphToFunctionDef(*g, new_func_name, &functionalized_fdef));
|
||||
if (func_name == new_func_name) {
|
||||
VLOG(2) << "Replacing function " << func_name;
|
||||
TF_RETURN_IF_ERROR(
|
||||
|
@ -466,23 +466,23 @@ Graph* GetConstantGraph(
|
||||
bool ReplaceTensorWithConstant(
|
||||
Graph* graph, Device* partition_device, NodeAndOutput tensor,
|
||||
const Tensor& constant, const gtl::FlatSet<Node*>& control_deps,
|
||||
int64 max_constant_size_in_bytes, bool disable_memory_output_type_check,
|
||||
int64 max_constant_size_in_bytes,
|
||||
const ConstantFoldNameGenerator& generate_new_name) {
|
||||
// Be conservative when replacing a tensor with a constant, when not
|
||||
// running on CPU.
|
||||
// 1) Do not replace another constant.
|
||||
// 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
|
||||
// constraint, do not replace it.
|
||||
// 3) If the size of the constant in bytes is too large (>
|
||||
// 3) If the destination tensor is an int32 tensor, and has DEVICE_MEMORY
|
||||
// constraint, do not replace it.
|
||||
// 4) If the size of the constant in bytes is too large (>
|
||||
// max_constant_in_bytes), do not replace it. This prevents the size of the
|
||||
// Graph from growing too large.
|
||||
// 4) If the constant op created does not have a kernel implementation
|
||||
// 5) If the constant op created does not have a kernel implementation
|
||||
// for the device, do not use it.
|
||||
// TODO(keveman): Consider adding a new constant op that has a kernel
|
||||
// implementation for all types, but with HostMemory constraint on it's
|
||||
// output.
|
||||
// 5) If the constant op for the device has different output memory type
|
||||
// from the original op output memory type, do not replace it.
|
||||
if (tensor.first->IsConstant()) {
|
||||
return false;
|
||||
}
|
||||
@ -497,7 +497,8 @@ bool ReplaceTensorWithConstant(
|
||||
return false;
|
||||
}
|
||||
bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32;
|
||||
if (memory_type == HOST_MEMORY && !is_int32) {
|
||||
if ((memory_type == HOST_MEMORY && !is_int32) ||
|
||||
(memory_type == DEVICE_MEMORY && is_int32)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -535,25 +536,6 @@ bool ReplaceTensorWithConstant(
|
||||
if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
|
||||
return false;
|
||||
}
|
||||
if (!disable_memory_output_type_check) {
|
||||
if (partition_device && device_type != DEVICE_CPU) {
|
||||
MemoryType original_output_memory_type;
|
||||
if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
|
||||
&original_output_memory_type)
|
||||
.ok()) {
|
||||
return false;
|
||||
}
|
||||
MemoryType const_output_memory_type;
|
||||
if (!MemoryTypeForOutput(device_type, graph, constant_node, 0,
|
||||
&const_output_memory_type)
|
||||
.ok()) {
|
||||
return false;
|
||||
}
|
||||
if (original_output_memory_type != const_output_memory_type) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto edge : edges_to_remove) {
|
||||
graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
|
||||
graph->RemoveEdge(edge);
|
||||
@ -660,8 +642,7 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
|
||||
constant_control_deps[tensors_to_replace[c].first];
|
||||
if (ReplaceTensorWithConstant(
|
||||
graph, partition_device, tensors_to_replace[c], outputs[c],
|
||||
control_deps, opts.max_constant_size_in_bytes,
|
||||
opts.disable_memory_output_type_check, generate_new_name)) {
|
||||
control_deps, opts.max_constant_size_in_bytes, generate_new_name)) {
|
||||
++num_nodes_replaced;
|
||||
}
|
||||
}
|
||||
|
@ -45,10 +45,6 @@ struct ConstantFoldingOptions {
|
||||
// optimization.
|
||||
int64 max_constant_size_in_bytes = 10 * 1024 * 1024;
|
||||
|
||||
// If disable_memory_output_type_check is true, we will disable output memory
|
||||
// type check for constant node replacement.
|
||||
bool disable_memory_output_type_check = false;
|
||||
|
||||
// A generator for the name suffix of constant folded nodes. A
|
||||
// default id generator that monotonically increases is used if nullptr is
|
||||
// passed.
|
||||
|
@ -39,8 +39,7 @@ void GraphOptimizer::Optimize(
|
||||
const std::unordered_map<string, std::vector<PartialTensorShape>>*
|
||||
shape_map,
|
||||
const std::function<bool(const Node*)>& cse_consider_fn,
|
||||
const std::function<bool(const Node*)>& cf_consider_fn,
|
||||
bool cf_disable_memory_output_type_check) {
|
||||
const std::function<bool(const Node*)>& cf_consider_fn) {
|
||||
Graph* g = graph->get();
|
||||
DumpGraph("Initial", g);
|
||||
|
||||
@ -65,8 +64,6 @@ void GraphOptimizer::Optimize(
|
||||
ConstantFoldingOptions cf_opts;
|
||||
cf_opts.shape_map = shape_map;
|
||||
cf_opts.consider = cf_consider_fn;
|
||||
cf_opts.disable_memory_output_type_check =
|
||||
cf_disable_memory_output_type_check;
|
||||
if (opts_.max_folded_constant_in_bytes() > 0) {
|
||||
cf_opts.max_constant_size_in_bytes =
|
||||
opts_.max_folded_constant_in_bytes();
|
||||
|
@ -47,16 +47,13 @@ class GraphOptimizer {
|
||||
// returns true will be considered for CSE.
|
||||
// If cf_consider_fn is not null then only nodes for which cf_consider_fn
|
||||
// returns true will be considered for CF.
|
||||
// If cf_disable_memory_output_type_check is true, CF will discard output
|
||||
// memory type check for constant node replacement.
|
||||
void Optimize(
|
||||
FunctionLibraryRuntime* runtime, Env* env, Device* device,
|
||||
std::unique_ptr<Graph>* graph,
|
||||
const std::unordered_map<string, std::vector<PartialTensorShape>>*
|
||||
shape_map,
|
||||
const std::function<bool(const Node*)>& cse_consider_fn = nullptr,
|
||||
const std::function<bool(const Node*)>& cf_consider_fn = nullptr,
|
||||
bool cf_disable_memory_output_type_check = false);
|
||||
const std::function<bool(const Node*)>& cf_consider_fn = nullptr);
|
||||
|
||||
const OptimizerOptions& options() { return opts_; }
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user