Update CompileGraphToXlaHlo to populate target/control ret nodes.
This is in preparation of updating graph pruning to always prune imported function graphs. PiperOrigin-RevId: 335944889 Change-Id: I3f6156aa08384883eee6227210f8fc8f1b7cc575
This commit is contained in:
parent
d460bb83cb
commit
c0da1d4092
@ -303,8 +303,12 @@ Status XlaCompilationCache::CompileSingleOp(
|
|||||||
}
|
}
|
||||||
|
|
||||||
GraphDebugInfo debug_info;
|
GraphDebugInfo debug_info;
|
||||||
|
std::vector<std::string> control_rets;
|
||||||
|
if (result_dtypes.empty()) {
|
||||||
|
control_rets.push_back(node_def.name());
|
||||||
|
}
|
||||||
return CompileGraphToXlaHlo(
|
return CompileGraphToXlaHlo(
|
||||||
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
|
*graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), control_rets,
|
||||||
options.device_type.type_string(), compile_options.use_tuple_arg,
|
options.device_type.type_string(), compile_options.use_tuple_arg,
|
||||||
*options.flib_def, debug_info, options.shape_representation_fn, result);
|
*options.flib_def, debug_info, options.shape_representation_fn, result);
|
||||||
#endif
|
#endif
|
||||||
|
@ -534,8 +534,9 @@ Status CompileGraphToXlaHlo(
|
|||||||
|
|
||||||
Status CompileGraphToXlaHlo(
|
Status CompileGraphToXlaHlo(
|
||||||
const Graph& graph, llvm::ArrayRef<XlaArgument> args,
|
const Graph& graph, llvm::ArrayRef<XlaArgument> args,
|
||||||
llvm::StringRef device_type, bool use_tuple_args,
|
llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
|
||||||
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
bool use_tuple_args, const FunctionLibraryDefinition& flib_def,
|
||||||
|
const GraphDebugInfo& debug_info,
|
||||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||||
XlaCompilationResult* compilation_result,
|
XlaCompilationResult* compilation_result,
|
||||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||||
@ -544,6 +545,7 @@ Status CompileGraphToXlaHlo(
|
|||||||
RegisterDialects(context.getDialectRegistry());
|
RegisterDialects(context.getDialectRegistry());
|
||||||
GraphImportConfig config;
|
GraphImportConfig config;
|
||||||
config.graph_as_function = true;
|
config.graph_as_function = true;
|
||||||
|
config.control_outputs = control_rets;
|
||||||
// Disable shape inference during import as some TensorFlow op fails during
|
// Disable shape inference during import as some TensorFlow op fails during
|
||||||
// shape inference with dynamic shaped operands. This in turn causes the
|
// shape inference with dynamic shaped operands. This in turn causes the
|
||||||
// import to fail. Shape inference during import is going to be removed and
|
// import to fail. Shape inference during import is going to be removed and
|
||||||
|
@ -116,11 +116,11 @@ Status CompileGraphToXlaHlo(
|
|||||||
|
|
||||||
// Compiles a TensorFlow Graph into XLA HLO, generates all accompanying metadata
|
// Compiles a TensorFlow Graph into XLA HLO, generates all accompanying metadata
|
||||||
// and stores them in CompilationResult.
|
// and stores them in CompilationResult.
|
||||||
// TODO(lyandy): Allow populating of targets/control outputs.
|
|
||||||
Status CompileGraphToXlaHlo(
|
Status CompileGraphToXlaHlo(
|
||||||
const Graph& graph, llvm::ArrayRef<XlaArgument> args,
|
const Graph& graph, llvm::ArrayRef<XlaArgument> args,
|
||||||
llvm::StringRef device_type, bool use_tuple_args,
|
llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
|
||||||
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
bool use_tuple_args, const FunctionLibraryDefinition& flib_def,
|
||||||
|
const GraphDebugInfo& debug_info,
|
||||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||||
XlaCompilationResult* compilation_result,
|
XlaCompilationResult* compilation_result,
|
||||||
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
|
||||||
|
@ -743,9 +743,13 @@ Status XlaCompiler::CompileFunction(
|
|||||||
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
|
if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
|
||||||
VLOG(1) << "Using MLIR bridge";
|
VLOG(1) << "Using MLIR bridge";
|
||||||
GraphDebugInfo debug_info;
|
GraphDebugInfo debug_info;
|
||||||
|
std::vector<std::string> control_rets;
|
||||||
|
for (const auto* control_ret_node : fbody->control_ret_nodes) {
|
||||||
|
control_rets.push_back(control_ret_node->name());
|
||||||
|
}
|
||||||
TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
|
TF_RETURN_IF_ERROR(CompileGraphToXlaHlo(
|
||||||
std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
|
std::move(*graph), mlir::SpanToArrayRef<XlaCompiler::Argument>(args),
|
||||||
options_.device_type.type_string(), options.use_tuple_arg,
|
control_rets, options_.device_type.type_string(), options.use_tuple_arg,
|
||||||
*options_.flib_def, debug_info, options_.shape_representation_fn,
|
*options_.flib_def, debug_info, options_.shape_representation_fn,
|
||||||
result));
|
result));
|
||||||
} else {
|
} else {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user