Register MLIR bridge pass as a FunctionOptimizationPass with FunctionOptimizationPassRegistry.
This reworks the MLIR bridge pass from an GraphOptimizationPass to a FunctionOptimizationPass for the function library runtime. Control ret node names are also passed into and updated from the registry pass run (MLIR to Graph). PiperOrigin-RevId: 293703916 Change-Id: I644765f4c5294962450cacb9c748582a579989c5
This commit is contained in:
parent
36abbb6956
commit
d1577971d7
@ -625,8 +625,8 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:device_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"//tensorflow/core:session_options",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@llvm-project//llvm:support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -639,6 +639,7 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":mlir_bridge_pass",
|
||||
"//tensorflow/core:core_cpu",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
|
||||
@ -86,36 +87,54 @@ static void DumpModule(mlir::ModuleOp module, llvm::StringRef file_prefix) {
|
||||
// and attached to a "compile" operation, whose result is fed to an "execute"
|
||||
// operation. The kernel for these operations is responsible to lower the
|
||||
// encapsulated graph to a particular device.
|
||||
Status MlirBridgePass::Run(const GraphOptimizationPassOptions& options) {
|
||||
if (!options.session_options->config.experimental().enable_mlir_bridge()) {
|
||||
Status MlirBridgePass::Run(const DeviceSet& device_set,
|
||||
const ConfigProto& config_proto,
|
||||
std::unique_ptr<Graph>* graph,
|
||||
FunctionLibraryDefinition* flib_def,
|
||||
std::vector<std::string>* control_ret_node_names,
|
||||
bool* control_rets_updated) {
|
||||
if (!config_proto.experimental().enable_mlir_bridge()) {
|
||||
VLOG(1) << "Skipping MLIR Bridge Pass, session flag not enabled";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
VLOG(1) << "Running MLIR Bridge Pass";
|
||||
|
||||
GraphDebugInfo debug_info;
|
||||
mlir::MLIRContext context;
|
||||
GraphImportConfig specs;
|
||||
specs.graph_as_function = true;
|
||||
GraphImportConfig import_config;
|
||||
import_config.graph_as_function = true;
|
||||
import_config.control_outputs = *control_ret_node_names;
|
||||
|
||||
GraphExportConfig confs;
|
||||
confs.graph_as_function = true;
|
||||
TF_ASSIGN_OR_RETURN(auto module,
|
||||
ConvertGraphToMlir(**options.graph, debug_info,
|
||||
*options.flib_def, specs, &context));
|
||||
TF_ASSIGN_OR_RETURN(auto module_ref,
|
||||
ConvertGraphToMlir(**graph, debug_info, *flib_def,
|
||||
import_config, &context));
|
||||
|
||||
AddDevicesToOp(*module, options.device_set);
|
||||
AddDevicesToOp(*module_ref, &device_set);
|
||||
|
||||
if (VLOG_IS_ON(1)) DumpModule(*module, "mlir_bridge_before_");
|
||||
if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_before_");
|
||||
|
||||
// Run the bridge now
|
||||
TF_RETURN_IF_ERROR(
|
||||
mlir::TFTPU::TPUBridge(*module, /*enable_logging=*/VLOG_IS_ON(1)));
|
||||
mlir::TFTPU::TPUBridge(*module_ref, /*enable_logging=*/VLOG_IS_ON(1)));
|
||||
|
||||
if (VLOG_IS_ON(1)) DumpModule(*module, "mlir_bridge_after_");
|
||||
if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_after_");
|
||||
|
||||
GraphExportConfig export_config;
|
||||
export_config.graph_as_function = true;
|
||||
absl::flat_hash_set<Node*> control_ret_nodes;
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
ConvertMlirToGraph(*module, confs, options.graph, options.flib_def),
|
||||
ConvertMlirToGraph(*module_ref, export_config, graph, flib_def,
|
||||
&control_ret_nodes),
|
||||
"Error converting MLIR module back to graph");
|
||||
|
||||
control_ret_node_names->clear();
|
||||
control_ret_node_names->reserve(control_ret_nodes.size());
|
||||
for (const auto* node : control_ret_nodes)
|
||||
control_ret_node_names->push_back(node->name());
|
||||
|
||||
*control_rets_updated = true;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -16,16 +16,19 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// This pass uses MLIR to implement all the conversion steps to target XLA from
|
||||
// a TensorFlow Graph. It is meant to expose a very limited set of
|
||||
// functionalities during the bring-up of MLIR-based bridge.
|
||||
class MlirBridgePass : public GraphOptimizationPass {
|
||||
class MlirBridgePass : public FunctionOptimizationPass {
|
||||
public:
|
||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||
Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
|
||||
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
|
||||
std::vector<std::string>* control_ret_node_names,
|
||||
bool* control_rets_updated) override;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -13,11 +13,14 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h"
|
||||
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
|
||||
MlirBridgePass);
|
||||
static function_optimization_registration::FunctionOptimizationPassRegistration
|
||||
register_mlir_bridge_pass(std::make_unique<MlirBridgePass>());
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user