Register MLIR bridge pass as a FunctionOptimizationPass with FunctionOptimizationPassRegistry.
This reworks the MLIR bridge pass from an GraphOptimizationPass to a FunctionOptimizationPass for the function library runtime. Control ret node names are also passed into and updated from the registry pass run (MLIR to Graph). PiperOrigin-RevId: 293703916 Change-Id: I644765f4c5294962450cacb9c748582a579989c5
This commit is contained in:
parent
36abbb6956
commit
d1577971d7
@ -625,8 +625,8 @@ cc_library(
|
|||||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:device_util",
|
"//tensorflow/compiler/mlir/tensorflow:device_util",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||||
"//tensorflow/core:core_cpu_lib",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:session_options",
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
"@llvm-project//llvm:support",
|
"@llvm-project//llvm:support",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
@ -639,6 +639,7 @@ cc_library(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":mlir_bridge_pass",
|
":mlir_bridge_pass",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
|
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
|
||||||
@ -86,36 +87,54 @@ static void DumpModule(mlir::ModuleOp module, llvm::StringRef file_prefix) {
|
|||||||
// and attached to a "compile" operation, whose result is fed to an "execute"
|
// and attached to a "compile" operation, whose result is fed to an "execute"
|
||||||
// operation. The kernel for these operations is responsible to lower the
|
// operation. The kernel for these operations is responsible to lower the
|
||||||
// encapsulated graph to a particular device.
|
// encapsulated graph to a particular device.
|
||||||
Status MlirBridgePass::Run(const GraphOptimizationPassOptions& options) {
|
Status MlirBridgePass::Run(const DeviceSet& device_set,
|
||||||
if (!options.session_options->config.experimental().enable_mlir_bridge()) {
|
const ConfigProto& config_proto,
|
||||||
|
std::unique_ptr<Graph>* graph,
|
||||||
|
FunctionLibraryDefinition* flib_def,
|
||||||
|
std::vector<std::string>* control_ret_node_names,
|
||||||
|
bool* control_rets_updated) {
|
||||||
|
if (!config_proto.experimental().enable_mlir_bridge()) {
|
||||||
VLOG(1) << "Skipping MLIR Bridge Pass, session flag not enabled";
|
VLOG(1) << "Skipping MLIR Bridge Pass, session flag not enabled";
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
VLOG(1) << "Running MLIR Bridge Pass";
|
||||||
|
|
||||||
GraphDebugInfo debug_info;
|
GraphDebugInfo debug_info;
|
||||||
mlir::MLIRContext context;
|
mlir::MLIRContext context;
|
||||||
GraphImportConfig specs;
|
GraphImportConfig import_config;
|
||||||
specs.graph_as_function = true;
|
import_config.graph_as_function = true;
|
||||||
|
import_config.control_outputs = *control_ret_node_names;
|
||||||
|
|
||||||
GraphExportConfig confs;
|
TF_ASSIGN_OR_RETURN(auto module_ref,
|
||||||
confs.graph_as_function = true;
|
ConvertGraphToMlir(**graph, debug_info, *flib_def,
|
||||||
TF_ASSIGN_OR_RETURN(auto module,
|
import_config, &context));
|
||||||
ConvertGraphToMlir(**options.graph, debug_info,
|
|
||||||
*options.flib_def, specs, &context));
|
|
||||||
|
|
||||||
AddDevicesToOp(*module, options.device_set);
|
AddDevicesToOp(*module_ref, &device_set);
|
||||||
|
|
||||||
if (VLOG_IS_ON(1)) DumpModule(*module, "mlir_bridge_before_");
|
if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_before_");
|
||||||
|
|
||||||
// Run the bridge now
|
// Run the bridge now
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
mlir::TFTPU::TPUBridge(*module, /*enable_logging=*/VLOG_IS_ON(1)));
|
mlir::TFTPU::TPUBridge(*module_ref, /*enable_logging=*/VLOG_IS_ON(1)));
|
||||||
|
|
||||||
if (VLOG_IS_ON(1)) DumpModule(*module, "mlir_bridge_after_");
|
if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_after_");
|
||||||
|
|
||||||
|
GraphExportConfig export_config;
|
||||||
|
export_config.graph_as_function = true;
|
||||||
|
absl::flat_hash_set<Node*> control_ret_nodes;
|
||||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||||
ConvertMlirToGraph(*module, confs, options.graph, options.flib_def),
|
ConvertMlirToGraph(*module_ref, export_config, graph, flib_def,
|
||||||
|
&control_ret_nodes),
|
||||||
"Error converting MLIR module back to graph");
|
"Error converting MLIR module back to graph");
|
||||||
|
|
||||||
|
control_ret_node_names->clear();
|
||||||
|
control_ret_node_names->reserve(control_ret_nodes.size());
|
||||||
|
for (const auto* node : control_ret_nodes)
|
||||||
|
control_ret_node_names->push_back(node->name());
|
||||||
|
|
||||||
|
*control_rets_updated = true;
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,16 +16,19 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
|
#ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
|
||||||
#define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
|
#define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
// This pass uses MLIR to implement all the conversion steps to target XLA from
|
// This pass uses MLIR to implement all the conversion steps to target XLA from
|
||||||
// a TensorFlow Graph. It is meant to expose a very limited set of
|
// a TensorFlow Graph. It is meant to expose a very limited set of
|
||||||
// functionalities during the bring-up of MLIR-based bridge.
|
// functionalities during the bring-up of MLIR-based bridge.
|
||||||
class MlirBridgePass : public GraphOptimizationPass {
|
class MlirBridgePass : public FunctionOptimizationPass {
|
||||||
public:
|
public:
|
||||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
|
||||||
|
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
|
||||||
|
std::vector<std::string>* control_ret_node_names,
|
||||||
|
bool* control_rets_updated) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -13,11 +13,14 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
#include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h"
|
#include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h"
|
||||||
|
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
|
static function_optimization_registration::FunctionOptimizationPassRegistration
|
||||||
MlirBridgePass);
|
register_mlir_bridge_pass(std::make_unique<MlirBridgePass>());
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
x
Reference in New Issue
Block a user