Add and register MLIR bridge v1 compat pass.
This pass is similar to the MLIR bridge pass but is specific to TensorFlow V1 Graphs. In addition, this pass will only trigger on non function graphs as they are already handled by the MLIR bridge pass. PiperOrigin-RevId: 294521035 Change-Id: If80bb613b51ff9fd8b2dad61b4357a55452a4093
This commit is contained in:
parent
5d68ad19c6
commit
11d3a2d7f2
@ -105,7 +105,6 @@ Status MlirBridgePass::Run(const DeviceSet& device_set,
|
||||
GraphImportConfig import_config;
|
||||
import_config.graph_as_function = true;
|
||||
import_config.control_outputs = *control_ret_node_names;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto module_ref,
|
||||
ConvertGraphToMlir(**graph, debug_info, *flib_def,
|
||||
import_config, &context));
|
||||
@ -138,4 +137,44 @@ Status MlirBridgePass::Run(const DeviceSet& device_set,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MlirBridgeV1CompatPass::Run(
|
||||
const GraphOptimizationPassOptions& options) {
|
||||
// Skip function graphs as MlirBridgePass will be used instead.
|
||||
if (options.is_function_graph) return Status::OK();
|
||||
|
||||
if (!options.session_options->config.experimental().enable_mlir_bridge()) {
|
||||
VLOG(1) << "Skipping MLIR Bridge V1 Compat Pass, session flag not enabled";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
VLOG(1) << "Running MLIR Bridge V1 Compat Pass";
|
||||
|
||||
GraphDebugInfo debug_info;
|
||||
mlir::MLIRContext context;
|
||||
GraphImportConfig import_config;
|
||||
import_config.upgrade_legacy = true;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto module_ref,
|
||||
ConvertGraphToMlir(**options.graph, debug_info, *options.flib_def,
|
||||
import_config, &context));
|
||||
|
||||
AddDevicesToOp(*module_ref, options.device_set);
|
||||
|
||||
if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_v1_compat_before_");
|
||||
|
||||
// Run the bridge now
|
||||
TF_RETURN_IF_ERROR(mlir::TFTPU::TPUBridgeV1Compat(
|
||||
*module_ref, /*enable_logging=*/VLOG_IS_ON(1)));
|
||||
|
||||
if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_v1_compat_after_");
|
||||
|
||||
GraphExportConfig export_config;
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
ConvertMlirToGraph(*module_ref, export_config, options.graph,
|
||||
options.flib_def),
|
||||
"Error converting MLIR module back to graph");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -17,11 +17,12 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
|
||||
|
||||
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// This pass uses MLIR to implement all the conversion steps to target XLA from
|
||||
// a TensorFlow Graph. It is meant to expose a very limited set of
|
||||
// a TensorFlow Function Graph. It is meant to expose a very limited set of
|
||||
// functionalities during the bring-up of MLIR-based bridge.
|
||||
class MlirBridgePass : public FunctionOptimizationPass {
|
||||
public:
|
||||
@ -31,6 +32,14 @@ class MlirBridgePass : public FunctionOptimizationPass {
|
||||
bool* control_rets_updated) override;
|
||||
};
|
||||
|
||||
// This pass uses MLIR to implement all the conversion steps to target XLA from
|
||||
// a TensorFlow V1 Graph. It is meant to expose a very limited set of
|
||||
// functionalities during the bring-up of MLIR-based bridge.
|
||||
class MlirBridgeV1CompatPass : public GraphOptimizationPass {
|
||||
public:
|
||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
|
||||
|
@ -17,10 +17,14 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h"
|
||||
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
static function_optimization_registration::FunctionOptimizationPassRegistration
|
||||
register_mlir_bridge_pass(std::make_unique<MlirBridgePass>());
|
||||
|
||||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
|
||||
MlirBridgeV1CompatPass);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
x
Reference in New Issue
Block a user