Skip computing nodes with known oversized outputs in constant folding.
The threshold on tensor size was applied after the value is computed, only when replacing the old nodes. However, that could already have caused OOM in large models. Changed compilation to XLA to limit TF constant folding to 1024 bytes, since it's only used for getting the shapes, and XLA internally also has constant folding. PiperOrigin-RevId: 337221696 Change-Id: I4cdca20d28141f34b2c85120298bffb89e6df85d
This commit is contained in:
parent
4b9cd3a42b
commit
c577eb1a3d
tensorflow
compiler/tf2xla
core
@ -584,6 +584,9 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
|
||||
opts.set_do_common_subexpression_elimination(false);
|
||||
opts.set_do_function_inlining(true);
|
||||
opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding);
|
||||
// Our constant folding is mainly for helping shape inference, so we do not
|
||||
// need to fold large constant values.
|
||||
opts.set_max_folded_constant_in_bytes(1024);
|
||||
GraphOptimizer optimizer(opts);
|
||||
// Do not constant fold nodes that output DT_VARIANT type tensors.
|
||||
// XLA does not support Const nodes of Variant type since it needs
|
||||
@ -627,8 +630,28 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
|
||||
graph_optimizer_options.inline_with_single_device_body_placer = true;
|
||||
graph_optimizer_options.ignore_noinline = is_inside_mustcompile;
|
||||
|
||||
optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
|
||||
/*device=*/nullptr, &graph, graph_optimizer_options);
|
||||
{
|
||||
GraphShapeInfo shape_info;
|
||||
InferShapes(graph.get(), /*arg_shapes=*/{},
|
||||
flib_runtime_->GetFunctionLibraryDefinition(), &shape_info)
|
||||
.IgnoreError();
|
||||
auto node_name_index = graph->BuildNodeNameIndex();
|
||||
std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
|
||||
for (const auto& node_shape_info : shape_info) {
|
||||
const string& node_name = node_shape_info.first;
|
||||
const std::vector<InferredShape>& output_shapes = node_shape_info.second;
|
||||
const auto& node_iter = node_name_index.find(node_name);
|
||||
if (node_iter != node_name_index.end()) {
|
||||
auto& partial_shapes = shape_map[node_name];
|
||||
for (const auto& inferred_shape : output_shapes) {
|
||||
partial_shapes.push_back(inferred_shape.shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
graph_optimizer_options.shape_map = &shape_map;
|
||||
optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
|
||||
/*device=*/nullptr, &graph, graph_optimizer_options);
|
||||
}
|
||||
|
||||
// Run shape inference on the graph and optimize the graph again.
|
||||
GraphShapeInfo shape_info;
|
||||
|
@ -219,6 +219,7 @@ bool IsConstantFoldable(
|
||||
const std::unordered_map<string, std::vector<PartialTensorShape>>*
|
||||
shape_map,
|
||||
const std::function<bool(const Node*)>& consider,
|
||||
int64 max_constant_size_in_bytes,
|
||||
std::unordered_map<const Node*, std::vector<Tensor>>*
|
||||
shape_replacement_map) {
|
||||
if (n->IsConstant()) {
|
||||
@ -233,6 +234,20 @@ bool IsConstantFoldable(
|
||||
if (consider && !consider(n)) {
|
||||
return false;
|
||||
}
|
||||
if (shape_map != nullptr) {
|
||||
// We can skip the node if an output is known to be oversized.
|
||||
auto shape_it = shape_map->find(n->name());
|
||||
if (shape_it != shape_map->end()) {
|
||||
for (int64 i = 0; i < shape_it->second.size(); ++i) {
|
||||
const auto& out_shape = shape_it->second[i];
|
||||
if (out_shape.IsFullyDefined() &&
|
||||
out_shape.num_elements() * DataTypeSize(n->output_type(i)) >
|
||||
max_constant_size_in_bytes) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (n->IsControlFlow() || n->IsSend() || n->IsRecv()) {
|
||||
return false;
|
||||
}
|
||||
@ -280,6 +295,7 @@ void ConsiderConstantFoldableNode(
|
||||
std::unordered_map<const Node*, std::vector<Tensor>>* shape_replacement_map,
|
||||
bool* internal_node_inserted) {
|
||||
if (IsConstantFoldable(n, opts.shape_map, opts.consider,
|
||||
opts.max_constant_size_in_bytes,
|
||||
shape_replacement_map)) {
|
||||
// A node is constant provided all of its non-control incoming Tensors come
|
||||
// from constant nodes, or it's a shape Op with statically known inputs in
|
||||
|
@ -496,27 +496,42 @@ Status TpuCompileOpKernelCommon::OptimizeGraph(
|
||||
opts.set_do_common_subexpression_elimination(false);
|
||||
opts.set_do_function_inlining(true);
|
||||
opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding);
|
||||
// Our constant folding is mainly for helping shape inference, so we do not
|
||||
// need to fold large constant values.
|
||||
opts.set_max_folded_constant_in_bytes(1024);
|
||||
GraphOptimizer optimizer(opts);
|
||||
// Performs a first function inlining pass before shape inference, since
|
||||
// otherwise shape inference can't see inside functions and a comprehensive
|
||||
// shape_map, including function ops, is needed to constant-propagate Shape
|
||||
// Ops below.
|
||||
GraphOptimizer::Options optimizer_opts;
|
||||
optimizer_opts.inline_multi_device_functions = true;
|
||||
optimizer_opts.inline_impl_selection_group_functions = true;
|
||||
optimizer_opts.inline_with_single_device_body_placer = true;
|
||||
optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts);
|
||||
{
|
||||
// Performs a first function inlining pass before shape inference, since
|
||||
// otherwise shape inference can't see inside functions and a comprehensive
|
||||
// shape_map, including function ops, is needed to constant-propagate Shape
|
||||
// Ops below.
|
||||
GraphOptimizer::Options optimizer_opts;
|
||||
optimizer_opts.inline_multi_device_functions = true;
|
||||
optimizer_opts.inline_impl_selection_group_functions = true;
|
||||
optimizer_opts.inline_with_single_device_body_placer = true;
|
||||
// Infer shapes for each node in the computation. Shape inference can help
|
||||
// skip constant folding of large shapes.
|
||||
GraphShapeInfo shape_info;
|
||||
TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation(
|
||||
metadata, arg_shapes, graph->get(), flr, &shape_info));
|
||||
// Converts the GraphShapeInfo into the form needed by the constant-folding
|
||||
// pass of the optimizer.
|
||||
std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
|
||||
TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation(
|
||||
metadata, arg_shapes, graph->get(), flr, &shape_info));
|
||||
optimizer_opts.shape_map = &shape_map;
|
||||
optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts);
|
||||
}
|
||||
|
||||
// Infer shapes for each node in the computation.
|
||||
GraphShapeInfo shape_info;
|
||||
TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation(
|
||||
metadata, arg_shapes, graph->get(), flr, &shape_info));
|
||||
|
||||
// Converts the GraphShapeInfo into the form needed by the constant-folding
|
||||
// pass of the optimizer.
|
||||
std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
|
||||
ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
|
||||
optimizer.Optimize(flr, flr->env(), flr->device(), graph, &shape_map);
|
||||
{
|
||||
// Infer shapes for each node in the computation.
|
||||
GraphShapeInfo shape_info;
|
||||
TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation(
|
||||
metadata, arg_shapes, graph->get(), flr, &shape_info));
|
||||
std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
|
||||
ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
|
||||
optimizer.Optimize(flr, flr->env(), flr->device(), graph, &shape_map);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(RewriteTensorListWithConstElement(graph->get(), fld));
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user