Resubmit the constant folding change without the 1024-byte limit, which was causing tf.where to fail in tf2xla.
PiperOrigin-RevId: 337236352
Change-Id: I44b8a99c0e74f2d4814933e05149e8eab5b04aaa
This commit is contained in:
parent 2ae0d560ee
commit ff2b597e36
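The mechanism behind the change, in brief: shape inference results are handed to the constant folder, which then skips any node whose output is statically known to exceed the folding budget (max_constant_size_in_bytes), instead of applying a blanket 1024-byte cutoff. Below is a minimal standalone sketch of that size gate; plain std types stand in for TensorFlow's PartialTensorShape and DataTypeSize, and every name here is illustrative, not TensorFlow's API.

    #include <cstdint>
    #include <vector>

    // A dimension of -1 means "unknown", mirroring PartialTensorShape.
    using Shape = std::vector<int64_t>;

    bool IsFullyDefined(const Shape& shape) {
      for (int64_t d : shape) {
        if (d < 0) return false;
      }
      return true;
    }

    int64_t NumElements(const Shape& shape) {
      int64_t n = 1;  // an empty shape is a scalar: one element
      for (int64_t d : shape) n *= d;
      return n;
    }

    // Returns true if folding should be skipped because some output is
    // statically known to exceed max_constant_size_in_bytes.
    // dtype_sizes[i] is the element size in bytes of output i.
    bool OutputKnownOversized(const std::vector<Shape>& output_shapes,
                              const std::vector<int>& dtype_sizes,
                              int64_t max_constant_size_in_bytes) {
      for (size_t i = 0; i < output_shapes.size(); ++i) {
        const Shape& s = output_shapes[i];
        if (IsFullyDefined(s) &&
            NumElements(s) * dtype_sizes[i] > max_constant_size_in_bytes) {
          return true;
        }
      }
      return false;
    }

Outputs with any unknown dimension are never skipped: without a fully defined shape there is no sound upper bound on the constant's size.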
@@ -627,8 +627,28 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
  graph_optimizer_options.inline_with_single_device_body_placer = true;
  graph_optimizer_options.ignore_noinline = is_inside_mustcompile;

  {
    GraphShapeInfo shape_info;
    InferShapes(graph.get(), /*arg_shapes=*/{},
                flib_runtime_->GetFunctionLibraryDefinition(), &shape_info)
        .IgnoreError();
    auto node_name_index = graph->BuildNodeNameIndex();
    std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
    for (const auto& node_shape_info : shape_info) {
      const string& node_name = node_shape_info.first;
      const std::vector<InferredShape>& output_shapes = node_shape_info.second;
      const auto& node_iter = node_name_index.find(node_name);
      if (node_iter != node_name_index.end()) {
        auto& partial_shapes = shape_map[node_name];
        for (const auto& inferred_shape : output_shapes) {
          partial_shapes.push_back(inferred_shape.shape);
        }
      }
    }
    graph_optimizer_options.shape_map = &shape_map;
    optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
                       /*device=*/nullptr, &graph, graph_optimizer_options);
  }

  // Run shape inference on the graph and optimize the graph again.
  GraphShapeInfo shape_info;
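The loop in the hunk above converts GraphShapeInfo into the node-name-to-shapes map the optimizer consumes, keeping entries only for nodes still present in the graph. A standalone sketch of that conversion, with plain std containers standing in for GraphShapeInfo, InferredShape, and BuildNodeNameIndex() (all names illustrative):

    #include <cstdint>
    #include <map>
    #include <set>
    #include <string>
    #include <vector>

    // -1 marks an unknown dimension, mirroring PartialTensorShape.
    using Shape = std::vector<int64_t>;

    // Inference result: per node, one inferred shape per output.
    using InferenceResult = std::map<std::string, std::vector<Shape>>;

    // Keep only shapes for nodes that still exist in the graph --
    // the filtering the hunk above performs via BuildNodeNameIndex().
    std::map<std::string, std::vector<Shape>> BuildShapeMap(
        const InferenceResult& inferred,
        const std::set<std::string>& live_nodes) {
      std::map<std::string, std::vector<Shape>> shape_map;
      for (const auto& entry : inferred) {
        if (live_nodes.count(entry.first) == 0) continue;
        shape_map[entry.first] = entry.second;
      }
      return shape_map;
    }

Note also that the hunk calls InferShapes(...).IgnoreError(): inference is best-effort, and a missing or partial shape map only means fewer nodes can be size-gated during folding.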
@@ -219,6 +219,7 @@ bool IsConstantFoldable(
    const std::unordered_map<string, std::vector<PartialTensorShape>>*
        shape_map,
    const std::function<bool(const Node*)>& consider,
    int64 max_constant_size_in_bytes,
    std::unordered_map<const Node*, std::vector<Tensor>>*
        shape_replacement_map) {
  if (n->IsConstant()) {
@@ -233,6 +234,20 @@ bool IsConstantFoldable(
  if (consider && !consider(n)) {
    return false;
  }
  if (shape_map != nullptr) {
    // We can skip the node if an output is known to be oversized.
    auto shape_it = shape_map->find(n->name());
    if (shape_it != shape_map->end()) {
      for (int64 i = 0; i < shape_it->second.size(); ++i) {
        const auto& out_shape = shape_it->second[i];
        if (out_shape.IsFullyDefined() &&
            out_shape.num_elements() * DataTypeSize(n->output_type(i)) >
                max_constant_size_in_bytes) {
          return false;
        }
      }
    }
  }
  if (n->IsControlFlow() || n->IsSend() || n->IsRecv()) {
    return false;
  }
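To put numbers on the gate above (illustrative values, not taken from the commit): a fully defined float32 output of shape {1024, 1024} weighs 1,048,576 × 4 = 4,194,304 bytes. Under the removed 1024-byte limit such a node could never fold, which is the kind of case that broke tf.where in tf2xla; under a megabytes-scale max_constant_size_in_bytes it folds normally, and any output with an unknown dimension is never size-skipped because IsFullyDefined() is false.

    // Worked check of the arithmetic in the hunk above; the 10 MiB
    // budget is illustrative, not a claim about TensorFlow's default.
    #include <cassert>
    #include <cstdint>

    int main() {
      const int64_t num_elements = 1024 * 1024;         // 1,048,576
      const int64_t dtype_size = 4;                     // float32
      const int64_t bytes = num_elements * dtype_size;  // 4,194,304
      const int64_t old_limit = 1024;                   // removed limit
      const int64_t budget = 10 * 1024 * 1024;          // assumed 10 MiB
      assert(bytes > old_limit);  // skipped under the old 1024-byte cap
      assert(bytes <= budget);    // folds under a megabytes-scale budget
      return 0;
    }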
@@ -280,6 +295,7 @@ void ConsiderConstantFoldableNode(
    std::unordered_map<const Node*, std::vector<Tensor>>* shape_replacement_map,
    bool* internal_node_inserted) {
  if (IsConstantFoldable(n, opts.shape_map, opts.consider,
                         opts.max_constant_size_in_bytes,
                         shape_replacement_map)) {
    // A node is constant provided all of its non-control incoming Tensors come
    // from constant nodes, or it's a shape Op with statically known inputs in
@@ -497,6 +497,7 @@ Status TpuCompileOpKernelCommon::OptimizeGraph(
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding);
  GraphOptimizer optimizer(opts);
  {
    // Performs a first function inlining pass before shape inference, since
    // otherwise shape inference can't see inside functions and a comprehensive
    // shape_map, including function ops, is needed to constant-propagate Shape
@@ -505,18 +506,29 @@ Status TpuCompileOpKernelCommon::OptimizeGraph(
    optimizer_opts.inline_multi_device_functions = true;
    optimizer_opts.inline_impl_selection_group_functions = true;
    optimizer_opts.inline_with_single_device_body_placer = true;
    // Infer shapes for each node in the computation. Shape inference can help
    // skip constant folding of large shapes.
    GraphShapeInfo shape_info;
    TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation(
        metadata, arg_shapes, graph->get(), flr, &shape_info));
    // Converts the GraphShapeInfo into the form needed by the constant-folding
    // pass of the optimizer.
    std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
    ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
    optimizer_opts.shape_map = &shape_map;
    optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts);
  }

  {
    // Infer shapes for each node in the computation.
    GraphShapeInfo shape_info;
    TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation(
        metadata, arg_shapes, graph->get(), flr, &shape_info));

    // Converts the GraphShapeInfo into the form needed by the constant-folding
    // pass of the optimizer.
    std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
    ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
    optimizer.Optimize(flr, flr->env(), flr->device(), graph, &shape_map);
  }

  TF_RETURN_IF_ERROR(RewriteTensorListWithConstElement(graph->get(), fld));
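The net effect of the last hunk: shape inference and the map conversion move into the inlining block, the map is threaded through optimizer_opts.shape_map rather than passed as a separate final argument to Optimize(), and the second optimize pass (the removed block) disappears. A hypothetical sketch of the resulting single-pass flow; none of these names are TensorFlow's API, and the stubs only model the control flow.

    #include <cstdint>
    #include <map>
    #include <string>
    #include <vector>

    using Shape = std::vector<int64_t>;
    using ShapeMap = std::map<std::string, std::vector<Shape>>;

    struct Graph {
      std::vector<std::string> node_names;
    };

    struct OptimizerOptions {
      bool inline_functions = false;
      const ShapeMap* shape_map = nullptr;  // enables size-gated folding
    };

    // Stub: pretend every node has one fully known scalar output.
    ShapeMap InferShapeMap(const Graph& g) {
      ShapeMap m;
      for (const auto& name : g.node_names) m[name] = {Shape{}};
      return m;
    }

    // Stub: a real optimizer would inline, then constant-fold,
    // consulting opts.shape_map (when set) to skip oversized outputs.
    void Optimize(Graph* /*g*/, const OptimizerOptions& /*opts*/) {}

    void OptimizeOnce(Graph* g) {
      OptimizerOptions opts;
      opts.inline_functions = true;
      ShapeMap shape_map = InferShapeMap(*g);  // inference result
      opts.shape_map = &shape_map;             // threaded through options
      Optimize(g, opts);                       // one pass, not two
    }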