Resubmit constant folding change without the 1024-byte limit. It was causing tf.where to fail in tf2xla.

PiperOrigin-RevId: 337236352
Change-Id: I44b8a99c0e74f2d4814933e05149e8eab5b04aaa
Author: Yuanzhong Xu
Date: 2020-10-14 21:43:47 -07:00
Committed by: TensorFlower Gardener
Parent: 2ae0d560ee
Commit: ff2b597e36
3 changed files with 69 additions and 21 deletions
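
For context on the commit message: tf.where's output shape depends on how many input elements are true, so its size cannot be proven ahead of time. The pre-fold check this change relies on (see the constant_folding.cc diff below) rejects a node only when an output's size is statically known to exceed the byte budget. A minimal standalone sketch of that predicate, with hypothetical stand-ins for PartialTensorShape and DataTypeSize:

#include <cstdint>
#include <vector>

// Hypothetical stand-in for one inferred output: num_elements is -1 when the
// shape is not fully defined (as for data-dependent ops like Where, whose
// first dimension is the number of true elements).
struct InferredOutput {
  std::int64_t num_elements;  // -1 if unknown
  std::int64_t dtype_size;    // bytes per element, e.g. 4 for float
};

// Mirrors the spirit of the new pre-check in IsConstantFoldable: reject a
// node only when some output is provably larger than the byte budget.
bool OutputsFitInBudget(const std::vector<InferredOutput>& outputs,
                        std::int64_t max_constant_size_in_bytes) {
  for (const InferredOutput& out : outputs) {
    if (out.num_elements >= 0 &&
        out.num_elements * out.dtype_size > max_constant_size_in_bytes) {
      return false;  // statically known to be oversized; skip folding
    }
  }
  // Unknown (Where-like) outputs pass; their folded values can still be
  // size-checked after folding.
  return true;
}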


@@ -627,8 +627,28 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
  graph_optimizer_options.inline_with_single_device_body_placer = true;
  graph_optimizer_options.ignore_noinline = is_inside_mustcompile;
  {
    GraphShapeInfo shape_info;
    InferShapes(graph.get(), /*arg_shapes=*/{},
                flib_runtime_->GetFunctionLibraryDefinition(), &shape_info)
        .IgnoreError();
    auto node_name_index = graph->BuildNodeNameIndex();
    std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
    for (const auto& node_shape_info : shape_info) {
      const string& node_name = node_shape_info.first;
      const std::vector<InferredShape>& output_shapes = node_shape_info.second;
      const auto& node_iter = node_name_index.find(node_name);
      if (node_iter != node_name_index.end()) {
        auto& partial_shapes = shape_map[node_name];
        for (const auto& inferred_shape : output_shapes) {
          partial_shapes.push_back(inferred_shape.shape);
        }
      }
    }
    graph_optimizer_options.shape_map = &shape_map;
    optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
                       /*device=*/nullptr, &graph, graph_optimizer_options);
  }

  // Run shape inference on the graph and optimize the graph again.
  GraphShapeInfo shape_info;
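
To make the shape_map built above concrete: it maps each node name to one (possibly partial) shape per output. A small illustrative sketch, with plain dimension vectors standing in for PartialTensorShape (-1 = unknown dimension):

#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // One entry per node name; one dimension vector per output; -1 = unknown.
  std::unordered_map<std::string, std::vector<std::vector<std::int64_t>>>
      shape_map;
  shape_map["matmul"] = {{-1, 128}};    // [batch, 128], batch unknown
  shape_map["shape_of_input"] = {{2}};  // fully defined, 2 elements
  // The folder can bound the output size of "shape_of_input" but not of
  // "matmul", so only the former could ever be rejected as oversized.
  return 0;
}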


@@ -219,6 +219,7 @@ bool IsConstantFoldable(
    const std::unordered_map<string, std::vector<PartialTensorShape>>*
        shape_map,
    const std::function<bool(const Node*)>& consider,
    int64 max_constant_size_in_bytes,
    std::unordered_map<const Node*, std::vector<Tensor>>*
        shape_replacement_map) {
  if (n->IsConstant()) {
@@ -233,6 +234,20 @@
  if (consider && !consider(n)) {
    return false;
  }
  if (shape_map != nullptr) {
    // We can skip the node if an output is known to be oversized.
    auto shape_it = shape_map->find(n->name());
    if (shape_it != shape_map->end()) {
      for (int64 i = 0; i < shape_it->second.size(); ++i) {
        const auto& out_shape = shape_it->second[i];
        if (out_shape.IsFullyDefined() &&
            out_shape.num_elements() * DataTypeSize(n->output_type(i)) >
                max_constant_size_in_bytes) {
          return false;
        }
      }
    }
  }
  if (n->IsControlFlow() || n->IsSend() || n->IsRecv()) {
    return false;
  }
@@ -280,6 +295,7 @@ void ConsiderConstantFoldableNode(
    std::unordered_map<const Node*, std::vector<Tensor>>* shape_replacement_map,
    bool* internal_node_inserted) {
  if (IsConstantFoldable(n, opts.shape_map, opts.consider,
                         opts.max_constant_size_in_bytes,
                         shape_replacement_map)) {
    // A node is constant provided all of its non-control incoming Tensors come
    // from constant nodes, or it's a shape Op with statically known inputs in
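
For reference, a hedged sketch of how a caller wires the plumbing above together, using ConstantFoldingOptions and ConstantFold from tensorflow/core/common_runtime/constant_folding.h as they existed around this commit; the 10 MiB budget is the assumed default for max_constant_size_in_bytes:

#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/constant_folding.h"

// Sketch: pass an inferred shape map and a byte budget to the folding pass,
// so IsConstantFoldable can reject provably oversized nodes up front.
// `graph`, `flr`, `env`, and `device` are assumed to come from the caller.
tensorflow::Status FoldWithShapeMap(
    tensorflow::Graph* graph, tensorflow::FunctionLibraryRuntime* flr,
    tensorflow::Env* env, const tensorflow::Device* device,
    const std::unordered_map<std::string,
                             std::vector<tensorflow::PartialTensorShape>>*
        shape_map) {
  tensorflow::ConstantFoldingOptions opts;
  opts.shape_map = shape_map;  // enables the pre-fold size check
  opts.max_constant_size_in_bytes = 10 * 1024 * 1024;  // assumed default
  bool was_mutated = false;
  return tensorflow::ConstantFold(opts, flr, env, device, graph,
                                  &was_mutated);
}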


@@ -497,6 +497,7 @@ Status TpuCompileOpKernelCommon::OptimizeGraph(
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding);
  GraphOptimizer optimizer(opts);
  {
    // Performs a first function inlining pass before shape inference, since
    // otherwise shape inference can't see inside functions and a comprehensive
    // shape_map, including function ops, is needed to constant-propagate Shape
@@ -505,18 +506,29 @@
    optimizer_opts.inline_multi_device_functions = true;
    optimizer_opts.inline_impl_selection_group_functions = true;
    optimizer_opts.inline_with_single_device_body_placer = true;
    // Infer shapes for each node in the computation. Shape inference can help
    // skip constant folding of large shapes.
    GraphShapeInfo shape_info;
    TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation(
        metadata, arg_shapes, graph->get(), flr, &shape_info));
    // Converts the GraphShapeInfo into the form needed by the constant-folding
    // pass of the optimizer.
    std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
    ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
    optimizer_opts.shape_map = &shape_map;
    optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts);
  }
  {
    // Infer shapes for each node in the computation.
    GraphShapeInfo shape_info;
    TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation(
        metadata, arg_shapes, graph->get(), flr, &shape_info));
    // Converts the GraphShapeInfo into the form needed by the constant-folding
    // pass of the optimizer.
    std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
    ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map);
    optimizer.Optimize(flr, flr->env(), flr->device(), graph, &shape_map);
  }
  TF_RETURN_IF_ERROR(RewriteTensorListWithConstElement(graph->get(), fld));
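
A worked instance of the size test these passes share, assuming a fully defined float [2048, 2048] output and the default 10 MiB max_constant_size_in_bytes:

#include <cstdint>
#include <iostream>

int main() {
  const std::int64_t num_elements = 2048 * 2048;         // fully defined shape
  const std::int64_t dtype_size = 4;                     // sizeof(float)
  const std::int64_t budget = 10 * 1024 * 1024;          // assumed 10 MiB
  const std::int64_t bytes = num_elements * dtype_size;  // 16,777,216 bytes
  std::cout << (bytes > budget ? "skip folding (provably oversized)"
                               : "eligible for folding")
            << "\n";  // prints: skip folding (provably oversized)
  return 0;
}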