Update CompileGraphToXlaHlo to populate target/control ret nodes.

This is in preparation of updating graph pruning to always prune imported function graphs.

PiperOrigin-RevId: 335944889
Change-Id: I3f6156aa08384883eee6227210f8fc8f1b7cc575
This commit is contained in:
Andy Ly 2020-10-07 13:47:41 -07:00 committed by TensorFlower Gardener
parent d460bb83cb
commit c0da1d4092
4 changed files with 17 additions and 7 deletions

View File

@ -303,8 +303,12 @@ Status XlaCompilationCache::CompileSingleOp(
}
GraphDebugInfo debug_info;
std::vector<std::string> control_rets;
if (result_dtypes.empty()) {
control_rets.push_back(node_def.name());
}
return CompileGraphToXlaHlo(
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), control_rets,
options.device_type.type_string(), compile_options.use_tuple_arg,
*options.flib_def, debug_info, options.shape_representation_fn, result);
#endif

View File

@ -534,8 +534,9 @@ Status CompileGraphToXlaHlo(
Status CompileGraphToXlaHlo(
const Graph& graph, llvm::ArrayRef<XlaArgument> args,
llvm::StringRef device_type, bool use_tuple_args,
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
bool use_tuple_args, const FunctionLibraryDefinition& flib_def,
const GraphDebugInfo& debug_info,
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
XlaCompilationResult* compilation_result,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
@ -544,6 +545,7 @@ Status CompileGraphToXlaHlo(
RegisterDialects(context.getDialectRegistry());
GraphImportConfig config;
config.graph_as_function = true;
config.control_outputs = control_rets;
// Disable shape inference during import as some TensorFlow op fails during
// shape inference with dynamic shaped operands. This in turn causes the
// import to fail. Shape inference during import is going to be removed and

View File

@ -116,11 +116,11 @@ Status CompileGraphToXlaHlo(
// Compiles a TensorFlow Graph into XLA HLO, generates all accompanying metadata
// and stores them in CompilationResult.
// TODO(lyandy): Allow populating of targets/control outputs.
Status CompileGraphToXlaHlo(
const Graph& graph, llvm::ArrayRef<XlaArgument> args,
llvm::StringRef device_type, bool use_tuple_args,
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
bool use_tuple_args, const FunctionLibraryDefinition& flib_def,
const GraphDebugInfo& debug_info,
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
XlaCompilationResult* compilation_result,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>

View File

@ -743,9 +743,13 @@ Status XlaCompiler::CompileFunction(
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
VLOG(1) << "Using MLIR bridge";
GraphDebugInfo debug_info;
std::vector<std::string> control_rets;
for (const auto* control_ret_node : fbody->control_ret_nodes) {
control_rets.push_back(control_ret_node->name());
}
TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
options_.device_type.type_string(), options.use_tuple_arg,
control_rets, options_.device_type.type_string(), options.use_tuple_arg,
*options_.flib_def, debug_info, options_.shape_representation_fn,
result));
} else {