Update CompileGraphToXlaHlo to populate target/control ret nodes.
This is in preparation of updating graph pruning to always prune imported function graphs. PiperOrigin-RevId: 335944889 Change-Id: I3f6156aa08384883eee6227210f8fc8f1b7cc575
This commit is contained in:
parent
d460bb83cb
commit
c0da1d4092
tensorflow/compiler
jit
mlir/tensorflow/utils
tf2xla
@ -303,8 +303,12 @@ Status XlaCompilationCache::CompileSingleOp(
|
||||
}
|
||||
|
||||
GraphDebugInfo debug_info;
|
||||
std::vector<std::string> control_rets;
|
||||
if (result_dtypes.empty()) {
|
||||
control_rets.push_back(node_def.name());
|
||||
}
|
||||
return CompileGraphToXlaHlo(
|
||||
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
|
||||
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), control_rets,
|
||||
options.device_type.type_string(), compile_options.use_tuple_arg,
|
||||
*options.flib_def, debug_info, options.shape_representation_fn, result);
|
||||
#endif
|
||||
|
@ -534,8 +534,9 @@ Status CompileGraphToXlaHlo(
|
||||
|
||||
Status CompileGraphToXlaHlo(
|
||||
const Graph& graph, llvm::ArrayRef<XlaArgument> args,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
||||
llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
|
||||
bool use_tuple_args, const FunctionLibraryDefinition& flib_def,
|
||||
const GraphDebugInfo& debug_info,
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompilationResult* compilation_result,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
@ -544,6 +545,7 @@ Status CompileGraphToXlaHlo(
|
||||
RegisterDialects(context.getDialectRegistry());
|
||||
GraphImportConfig config;
|
||||
config.graph_as_function = true;
|
||||
config.control_outputs = control_rets;
|
||||
// Disable shape inference during import as some TensorFlow op fails during
|
||||
// shape inference with dynamic shaped operands. This in turn causes the
|
||||
// import to fail. Shape inference during import is going to be removed and
|
||||
|
@ -116,11 +116,11 @@ Status CompileGraphToXlaHlo(
|
||||
|
||||
// Compiles a TensorFlow Graph into XLA HLO, generates all accompanying metadata
|
||||
// and stores them in CompilationResult.
|
||||
// TODO(lyandy): Allow populating of targets/control outputs.
|
||||
Status CompileGraphToXlaHlo(
|
||||
const Graph& graph, llvm::ArrayRef<XlaArgument> args,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
||||
llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
|
||||
bool use_tuple_args, const FunctionLibraryDefinition& flib_def,
|
||||
const GraphDebugInfo& debug_info,
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompilationResult* compilation_result,
|
||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||
|
@ -743,9 +743,13 @@ Status XlaCompiler::CompileFunction(
|
||||
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
|
||||
VLOG(1) << "Using MLIR bridge";
|
||||
GraphDebugInfo debug_info;
|
||||
std::vector<std::string> control_rets;
|
||||
for (const auto* control_ret_node : fbody->control_ret_nodes) {
|
||||
control_rets.push_back(control_ret_node->name());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
|
||||
std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
|
||||
options_.device_type.type_string(), options.use_tuple_arg,
|
||||
control_rets, options_.device_type.type_string(), options.use_tuple_arg,
|
||||
*options_.flib_def, debug_info, options_.shape_representation_fn,
|
||||
result));
|
||||
} else {
|
||||
|
Loading…
Reference in New Issue
Block a user