Register MLIR bridge pass as a FunctionOptimizationPass with FunctionOptimizationPassRegistry.

This reworks the MLIR bridge pass from a GraphOptimizationPass to a FunctionOptimizationPass for the function library runtime. Control ret node names are also passed into the pass and updated from its run (MLIR to Graph conversion).

PiperOrigin-RevId: 293703916
Change-Id: I644765f4c5294962450cacb9c748582a579989c5
This commit is contained in:
Andy Ly 2020-02-06 16:27:13 -08:00 committed by TensorFlower Gardener
parent 36abbb6956
commit d1577971d7
4 changed files with 47 additions and 21 deletions

View File

@ -625,8 +625,8 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:device_util",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:session_options",
"//tensorflow/core:core_cpu",
"@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//llvm:support",
],
alwayslink = 1,
@ -639,6 +639,7 @@ cc_library(
],
deps = [
":mlir_bridge_pass",
"//tensorflow/core:core_cpu",
],
alwayslink = 1,
)

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <string>
#include "absl/container/flat_hash_set.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_os_ostream.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
@ -86,36 +87,54 @@ static void DumpModule(mlir::ModuleOp module, llvm::StringRef file_prefix) {
// and attached to a "compile" operation, whose result is fed to an "execute"
// operation. The kernel for these operations is responsible to lower the
// encapsulated graph to a particular device.
// Runs the MLIR TPU bridge over `graph` as a FunctionOptimizationPass.
//
// The graph (plus its function library) is imported to an MLIR module,
// the TPU bridge pipeline is run on the module, and the result is
// exported back to `graph`/`flib_def`. Control return node names are
// imported as MLIR control outputs and rewritten from the exported
// graph's control-ret nodes on the way out.
//
// Args:
//   device_set: devices made visible to the MLIR module (attached to the
//       module op before running the bridge).
//   config_proto: session config; the pass is a no-op unless
//       experimental.enable_mlir_bridge is set.
//   graph: in/out graph to transform; replaced by the exported graph.
//   flib_def: in/out function library accompanying the graph.
//   control_ret_node_names: in/out control return node names; rewritten
//       from the converted graph's control-ret nodes.
//   control_rets_updated: set to true when control_ret_node_names was
//       rewritten (always, when the pass actually runs).
Status MlirBridgePass::Run(const DeviceSet& device_set,
                           const ConfigProto& config_proto,
                           std::unique_ptr<Graph>* graph,
                           FunctionLibraryDefinition* flib_def,
                           std::vector<std::string>* control_ret_node_names,
                           bool* control_rets_updated) {
  if (!config_proto.experimental().enable_mlir_bridge()) {
    VLOG(1) << "Skipping MLIR Bridge Pass, session flag not enabled";
    return Status::OK();
  }

  VLOG(1) << "Running MLIR Bridge Pass";
  GraphDebugInfo debug_info;
  mlir::MLIRContext context;
  GraphImportConfig import_config;
  // Import the graph as a function so control outputs can be threaded
  // through the conversion and recovered after export.
  import_config.graph_as_function = true;
  import_config.control_outputs = *control_ret_node_names;
  TF_ASSIGN_OR_RETURN(auto module_ref,
                      ConvertGraphToMlir(**graph, debug_info, *flib_def,
                                         import_config, &context));

  // Attach the visible devices to the module so the bridge can target them.
  AddDevicesToOp(*module_ref, &device_set);

  if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_before_");

  // Run the bridge now
  TF_RETURN_IF_ERROR(
      mlir::TFTPU::TPUBridge(*module_ref, /*enable_logging=*/VLOG_IS_ON(1)));

  if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_after_");

  GraphExportConfig export_config;
  export_config.graph_as_function = true;
  absl::flat_hash_set<Node*> control_ret_nodes;
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      ConvertMlirToGraph(*module_ref, export_config, graph, flib_def,
                         &control_ret_nodes),
      "Error converting MLIR module back to graph");

  // Propagate the (possibly renamed) control-ret node names back to the
  // caller so the function library runtime can update its control rets.
  control_ret_node_names->clear();
  control_ret_node_names->reserve(control_ret_nodes.size());
  for (const auto* node : control_ret_nodes)
    control_ret_node_names->push_back(node->name());
  *control_rets_updated = true;

  return Status::OK();
}

View File

@ -16,16 +16,19 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
#define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
namespace tensorflow {
// This pass uses MLIR to implement all the conversion steps to target XLA from
// a TensorFlow Graph. It is meant to expose a very limited set of
// functionalities during the bring-up of MLIR-based bridge.
class MlirBridgePass : public GraphOptimizationPass {
class MlirBridgePass : public FunctionOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
std::vector<std::string>* control_ret_node_names,
bool* control_rets_updated) override;
};
} // namespace tensorflow

View File

@ -13,11 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
namespace tensorflow {
// Register the MLIR bridge as a FunctionOptimizationPass (replaces the
// former PRE_PLACEMENT GraphOptimizationPass registration).
static function_optimization_registration::FunctionOptimizationPassRegistration
    register_mlir_bridge_pass(std::make_unique<MlirBridgePass>());
} // namespace tensorflow