Resubmit constant folding change without the 1024byte limit. It was causing tf.where to fail in tf2xla.
PiperOrigin-RevId: 337236352 Change-Id: I44b8a99c0e74f2d4814933e05149e8eab5b04aaa
This commit is contained in:
		
							parent
							
								
									2ae0d560ee
								
							
						
					
					
						commit
						ff2b597e36
					
				| @ -627,8 +627,28 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) { | |||||||
|   graph_optimizer_options.inline_with_single_device_body_placer = true; |   graph_optimizer_options.inline_with_single_device_body_placer = true; | ||||||
|   graph_optimizer_options.ignore_noinline = is_inside_mustcompile; |   graph_optimizer_options.ignore_noinline = is_inside_mustcompile; | ||||||
| 
 | 
 | ||||||
|  |   { | ||||||
|  |     GraphShapeInfo shape_info; | ||||||
|  |     InferShapes(graph.get(), /*arg_shapes=*/{}, | ||||||
|  |                 flib_runtime_->GetFunctionLibraryDefinition(), &shape_info) | ||||||
|  |         .IgnoreError(); | ||||||
|  |     auto node_name_index = graph->BuildNodeNameIndex(); | ||||||
|  |     std::unordered_map<string, std::vector<PartialTensorShape>> shape_map; | ||||||
|  |     for (const auto& node_shape_info : shape_info) { | ||||||
|  |       const string& node_name = node_shape_info.first; | ||||||
|  |       const std::vector<InferredShape>& output_shapes = node_shape_info.second; | ||||||
|  |       const auto& node_iter = node_name_index.find(node_name); | ||||||
|  |       if (node_iter != node_name_index.end()) { | ||||||
|  |         auto& partial_shapes = shape_map[node_name]; | ||||||
|  |         for (const auto& inferred_shape : output_shapes) { | ||||||
|  |           partial_shapes.push_back(inferred_shape.shape); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     graph_optimizer_options.shape_map = &shape_map; | ||||||
|     optimizer.Optimize(flib_runtime_, flib_runtime_->env(), |     optimizer.Optimize(flib_runtime_, flib_runtime_->env(), | ||||||
|                        /*device=*/nullptr, &graph, graph_optimizer_options); |                        /*device=*/nullptr, &graph, graph_optimizer_options); | ||||||
|  |   } | ||||||
| 
 | 
 | ||||||
|   // Run shape inference on the graph and optimize the graph again.
 |   // Run shape inference on the graph and optimize the graph again.
 | ||||||
|   GraphShapeInfo shape_info; |   GraphShapeInfo shape_info; | ||||||
|  | |||||||
| @ -219,6 +219,7 @@ bool IsConstantFoldable( | |||||||
|     const std::unordered_map<string, std::vector<PartialTensorShape>>* |     const std::unordered_map<string, std::vector<PartialTensorShape>>* | ||||||
|         shape_map, |         shape_map, | ||||||
|     const std::function<bool(const Node*)>& consider, |     const std::function<bool(const Node*)>& consider, | ||||||
|  |     int64 max_constant_size_in_bytes, | ||||||
|     std::unordered_map<const Node*, std::vector<Tensor>>* |     std::unordered_map<const Node*, std::vector<Tensor>>* | ||||||
|         shape_replacement_map) { |         shape_replacement_map) { | ||||||
|   if (n->IsConstant()) { |   if (n->IsConstant()) { | ||||||
| @ -233,6 +234,20 @@ bool IsConstantFoldable( | |||||||
|   if (consider && !consider(n)) { |   if (consider && !consider(n)) { | ||||||
|     return false; |     return false; | ||||||
|   } |   } | ||||||
|  |   if (shape_map != nullptr) { | ||||||
|  |     // We can skip the node if an output is known to be oversized.
 | ||||||
|  |     auto shape_it = shape_map->find(n->name()); | ||||||
|  |     if (shape_it != shape_map->end()) { | ||||||
|  |       for (int64 i = 0; i < shape_it->second.size(); ++i) { | ||||||
|  |         const auto& out_shape = shape_it->second[i]; | ||||||
|  |         if (out_shape.IsFullyDefined() && | ||||||
|  |             out_shape.num_elements() * DataTypeSize(n->output_type(i)) > | ||||||
|  |                 max_constant_size_in_bytes) { | ||||||
|  |           return false; | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|   if (n->IsControlFlow() || n->IsSend() || n->IsRecv()) { |   if (n->IsControlFlow() || n->IsSend() || n->IsRecv()) { | ||||||
|     return false; |     return false; | ||||||
|   } |   } | ||||||
| @ -280,6 +295,7 @@ void ConsiderConstantFoldableNode( | |||||||
|     std::unordered_map<const Node*, std::vector<Tensor>>* shape_replacement_map, |     std::unordered_map<const Node*, std::vector<Tensor>>* shape_replacement_map, | ||||||
|     bool* internal_node_inserted) { |     bool* internal_node_inserted) { | ||||||
|   if (IsConstantFoldable(n, opts.shape_map, opts.consider, |   if (IsConstantFoldable(n, opts.shape_map, opts.consider, | ||||||
|  |                          opts.max_constant_size_in_bytes, | ||||||
|                          shape_replacement_map)) { |                          shape_replacement_map)) { | ||||||
|     // A node is constant provided all of its non-control incoming Tensors come
 |     // A node is constant provided all of its non-control incoming Tensors come
 | ||||||
|     // from constant nodes, or it's a shape Op with statically known inputs in
 |     // from constant nodes, or it's a shape Op with statically known inputs in
 | ||||||
|  | |||||||
| @ -497,6 +497,7 @@ Status TpuCompileOpKernelCommon::OptimizeGraph( | |||||||
|   opts.set_do_function_inlining(true); |   opts.set_do_function_inlining(true); | ||||||
|   opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding); |   opts.set_do_constant_folding(!flags->tf_xla_disable_constant_folding); | ||||||
|   GraphOptimizer optimizer(opts); |   GraphOptimizer optimizer(opts); | ||||||
|  |   { | ||||||
|     // Performs a first function inlining pass before shape inference, since
 |     // Performs a first function inlining pass before shape inference, since
 | ||||||
|     // otherwise shape inference can't see inside functions and a comprehensive
 |     // otherwise shape inference can't see inside functions and a comprehensive
 | ||||||
|     // shape_map, including function ops, is needed to constant-propagate Shape
 |     // shape_map, including function ops, is needed to constant-propagate Shape
 | ||||||
| @ -505,18 +506,29 @@ Status TpuCompileOpKernelCommon::OptimizeGraph( | |||||||
|     optimizer_opts.inline_multi_device_functions = true; |     optimizer_opts.inline_multi_device_functions = true; | ||||||
|     optimizer_opts.inline_impl_selection_group_functions = true; |     optimizer_opts.inline_impl_selection_group_functions = true; | ||||||
|     optimizer_opts.inline_with_single_device_body_placer = true; |     optimizer_opts.inline_with_single_device_body_placer = true; | ||||||
|  |     // Infer shapes for each node in the computation. Shape inference can help
 | ||||||
|  |     // skip constant folding of large shapes.
 | ||||||
|  |     GraphShapeInfo shape_info; | ||||||
|  |     TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation( | ||||||
|  |         metadata, arg_shapes, graph->get(), flr, &shape_info)); | ||||||
|  |     // Converts the GraphShapeInfo into the form needed by the constant-folding
 | ||||||
|  |     // pass of the optimizer.
 | ||||||
|  |     std::unordered_map<string, std::vector<PartialTensorShape>> shape_map; | ||||||
|  |     TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation( | ||||||
|  |         metadata, arg_shapes, graph->get(), flr, &shape_info)); | ||||||
|  |     optimizer_opts.shape_map = &shape_map; | ||||||
|     optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts); |     optimizer.Optimize(flr, flr->env(), flr->device(), graph, optimizer_opts); | ||||||
|  |   } | ||||||
| 
 | 
 | ||||||
|  |   { | ||||||
|     // Infer shapes for each node in the computation.
 |     // Infer shapes for each node in the computation.
 | ||||||
|     GraphShapeInfo shape_info; |     GraphShapeInfo shape_info; | ||||||
|     TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation( |     TF_RETURN_IF_ERROR(RunShapeInferenceOnComputation( | ||||||
|         metadata, arg_shapes, graph->get(), flr, &shape_info)); |         metadata, arg_shapes, graph->get(), flr, &shape_info)); | ||||||
| 
 |  | ||||||
|   // Converts the GraphShapeInfo into the form needed by the constant-folding
 |  | ||||||
|   // pass of the optimizer.
 |  | ||||||
|     std::unordered_map<string, std::vector<PartialTensorShape>> shape_map; |     std::unordered_map<string, std::vector<PartialTensorShape>> shape_map; | ||||||
|     ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map); |     ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map); | ||||||
|     optimizer.Optimize(flr, flr->env(), flr->device(), graph, &shape_map); |     optimizer.Optimize(flr, flr->env(), flr->device(), graph, &shape_map); | ||||||
|  |   } | ||||||
| 
 | 
 | ||||||
|   TF_RETURN_IF_ERROR(RewriteTensorListWithConstElement(graph->get(), fld)); |   TF_RETURN_IF_ERROR(RewriteTensorListWithConstElement(graph->get(), fld)); | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user