Add and register MLIR bridge v1 compat pass.

This pass is similar to the MLIR bridge pass but is specific to TensorFlow V1 graphs. In addition, this pass triggers only on non-function graphs, since function graphs are already handled by the MLIR bridge pass.

PiperOrigin-RevId: 294521035
Change-Id: If80bb613b51ff9fd8b2dad61b4357a55452a4093
This commit is contained in:
Andy Ly 2020-02-11 14:27:56 -08:00 committed by TensorFlower Gardener
parent 5d68ad19c6
commit 11d3a2d7f2
3 changed files with 54 additions and 2 deletions

View File

@ -105,7 +105,6 @@ Status MlirBridgePass::Run(const DeviceSet& device_set,
GraphImportConfig import_config;
import_config.graph_as_function = true;
import_config.control_outputs = *control_ret_node_names;
TF_ASSIGN_OR_RETURN(auto module_ref,
ConvertGraphToMlir(**graph, debug_info, *flib_def,
import_config, &context));
@ -138,4 +137,44 @@ Status MlirBridgePass::Run(const DeviceSet& device_set,
return Status::OK();
}
// Runs the MLIR TPU bridge (V1 compat flavor) over a TensorFlow V1 graph.
//
// Function graphs are skipped because MlirBridgePass, registered as a
// FunctionOptimizationPass, already handles them. The graph is imported into
// an MLIR module, transformed by the V1 compat bridge pipeline, and exported
// back into the graph in place.
//
// Returns OK if the pass was skipped or the bridge ran successfully; otherwise
// propagates the import/bridge/export error.
Status MlirBridgeV1CompatPass::Run(
    const GraphOptimizationPassOptions& options) {
  // Skip function graphs as MlirBridgePass will be used instead.
  if (options.is_function_graph) return Status::OK();

  // `session_options` is a non-owned pointer that may be null when the
  // optimization pipeline runs outside of a session; guard before
  // dereferencing it to read the experimental flag.
  if (options.session_options == nullptr ||
      !options.session_options->config.experimental().enable_mlir_bridge()) {
    VLOG(1) << "Skipping MLIR Bridge V1 Compat Pass, session flag not enabled";
    return Status::OK();
  }

  VLOG(1) << "Running MLIR Bridge V1 Compat Pass";

  GraphDebugInfo debug_info;
  mlir::MLIRContext context;
  GraphImportConfig import_config;
  // V1 graphs predate some modern graph invariants; request the legacy
  // upgrade path during import.
  import_config.upgrade_legacy = true;
  TF_ASSIGN_OR_RETURN(
      auto module_ref,
      ConvertGraphToMlir(**options.graph, debug_info, *options.flib_def,
                         import_config, &context));

  // Record the available devices on the module so the bridge can target them.
  AddDevicesToOp(*module_ref, options.device_set);

  if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_v1_compat_before_");

  // Run the bridge now
  TF_RETURN_IF_ERROR(mlir::TFTPU::TPUBridgeV1Compat(
      *module_ref, /*enable_logging=*/VLOG_IS_ON(1)));

  if (VLOG_IS_ON(1)) DumpModule(*module_ref, "mlir_bridge_v1_compat_after_");

  // Export the transformed module back into the original Graph in place.
  GraphExportConfig export_config;
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      ConvertMlirToGraph(*module_ref, export_config, options.graph,
                         options.flib_def),
      "Error converting MLIR module back to graph");

  return Status::OK();
}
} // namespace tensorflow

View File

@ -17,11 +17,12 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
// This pass uses MLIR to implement all the conversion steps to target XLA from
// a TensorFlow Graph. It is meant to expose a very limited set of
// a TensorFlow Function Graph. It is meant to expose a very limited set of
// functionalities during the bring-up of MLIR-based bridge.
class MlirBridgePass : public FunctionOptimizationPass {
public:
@ -31,6 +32,14 @@ class MlirBridgePass : public FunctionOptimizationPass {
bool* control_rets_updated) override;
};
// This pass uses MLIR to implement all the conversion steps to target XLA from
// a TensorFlow V1 Graph. It is meant to expose a very limited set of
// functionalities during the bring-up of MLIR-based bridge.
class MlirBridgeV1CompatPass : public GraphOptimizationPass {
 public:
  // Converts the graph in `options` to an MLIR module, runs the TPU bridge
  // V1 compat pipeline on it, and converts the result back to the graph in
  // place. Skips function graphs, which MlirBridgePass handles instead.
  Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_

View File

@ -17,10 +17,14 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h"
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
// Registers MlirBridgePass with the function optimization pipeline, so it
// runs on function graphs.
static function_optimization_registration::FunctionOptimizationPassRegistration
    register_mlir_bridge_pass(std::make_unique<MlirBridgePass>());

// Registers MlirBridgeV1CompatPass to run early (PRE_PLACEMENT, priority 0)
// on non-function (V1 session) graphs.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
                      MlirBridgeV1CompatPass);
} // namespace tensorflow