Add a pipeline to run the TPU Bridge in V1 compat mode

PiperOrigin-RevId: 294368788
Change-Id: I17153394602edd147cd7ce13f41e02570e4b33c8
Author: Mehdi Amini
Committed: 2020-02-10 21:22:43 -08:00 by TensorFlower Gardener
Parent: 901e6af539
Commit: 1c53dd246b

5 changed files with 98 additions and 23 deletions


@@ -0,0 +1,29 @@
// RUN: tf-opt %s -tf-tpu-bridge-v1 | FileCheck %s --dump-input=fail
module attributes {tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"], tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 296 : i32}} {
  func @main() {
    // CHECK: std.constant
    // CHECK: TPUCompile
    // CHECK: TPUExecute
    tf_executor.graph {
      %outputs, %control = tf_executor.island wraps "std.constant"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
      %outputs_0, %control_1 = tf_executor.island wraps "std.constant"() {value = dense<3.000000e+00> : tensor<f32>} : () -> tensor<f32>
      %control_2 = tf_executor.island wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], host_compute_core = [], name = "TPUReplicateMetadata", num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_tpu = true} : () -> ()
      %outputs_3, %control_4 = tf_executor.island wraps "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "x", shape = "tfshape$dim { }"} : () -> tensor<0xf32>
      %outputs_5, %control_6 = tf_executor.island wraps "tf.TPUReplicatedInput"(%outputs_3) {N = 1 : i64, T = "tfdtype$DT_FLOAT", device = "", name = "input0"} : (tensor<0xf32>) -> tensor<0xf32>
      %outputs_7, %control_8 = tf_executor.island wraps "tf.Identity"(%outputs_5) {T = "tfdtype$DT_FLOAT", _tpu_input_identity = true, _tpu_replicate = "cluster", device = "", name = "replicated_input_0"} : (tensor<0xf32>) -> tensor<0xf32>
      %outputs_9, %control_10 = tf_executor.island wraps "tf.Mul"(%outputs_7, %outputs) {T = "tfdtype$DT_FLOAT", _tpu_replicate = "cluster", device = "", name = "mul"} : (tensor<0xf32>, tensor<f32>) -> tensor<0xf32>
      %outputs_11, %control_12 = tf_executor.island wraps "tf.Placeholder"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "y", shape = "tfshape$dim { }"} : () -> tensor<0xf32>
      %outputs_13, %control_14 = tf_executor.island wraps "tf.TPUReplicatedInput"(%outputs_11) {N = 1 : i64, T = "tfdtype$DT_FLOAT", device = "", name = "input1"} : (tensor<0xf32>) -> tensor<0xf32>
      %outputs_15, %control_16 = tf_executor.island wraps "tf.Identity"(%outputs_13) {T = "tfdtype$DT_FLOAT", _tpu_input_identity = true, _tpu_replicate = "cluster", device = "", name = "replicated_input_1"} : (tensor<0xf32>) -> tensor<0xf32>
      %outputs_17, %control_18 = tf_executor.island wraps "tf.Mul"(%outputs_15, %outputs_0) {T = "tfdtype$DT_FLOAT", _tpu_replicate = "cluster", device = "", name = "mul_1"} : (tensor<0xf32>, tensor<f32>) -> tensor<0xf32>
      %outputs_19, %control_20 = tf_executor.island wraps "tf.AddV2"(%outputs_9, %outputs_17) {T = "tfdtype$DT_FLOAT", _tpu_replicate = "cluster", device = "", name = "add"} : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
      %outputs_21, %control_22 = tf_executor.island wraps "tf.Identity"(%outputs_19) {T = "tfdtype$DT_FLOAT", _tpu_output_identity = true, _tpu_replicate = "cluster", device = "/device:TPU_REPLICATED_CORE:0", name = "Identity"} : (tensor<0xf32>) -> tensor<0xf32>
      %outputs_23, %control_24 = tf_executor.island wraps "tf.TPUReplicatedOutput"(%outputs_21) {T = "tfdtype$DT_FLOAT", device = "", name = "output0", num_replicas = 1 : i64} : (tensor<0xf32>) -> tensor<0xf32>
      %outputs_25, %control_26 = tf_executor.island wraps "tf.Identity"(%outputs_23) {T = "tfdtype$DT_FLOAT", device = "", name = "output_0_shard_0"} : (tensor<0xf32>) -> tensor<0xf32>
      %control_27 = tf_executor.island(%control_2, %control_26) wraps "tf.NoOp"() : () -> ()
      tf_executor.fetch %control_27 : !tf_executor.control
    }
    return
  }
}
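The RUN line above exercises the new pipeline through tf-opt, but the same bridge can be driven programmatically. Below is a minimal sketch, not part of this commit, that parses a module like the one above and runs the V1 compat entry point added by this change; the file name "input.mlir" and the error-handling details are assumptions, using the MLIR APIs of this era (mlir/Parser.h, OwningModuleRef).

#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/Parser.h"  // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
#include "tensorflow/core/lib/core/errors.h"

// Hypothetical driver, shown only to illustrate the new entry point.
tensorflow::Status RunV1BridgeOnFile(mlir::MLIRContext &context) {
  // Parse the test module from disk; "input.mlir" is a placeholder path.
  mlir::OwningModuleRef module = mlir::parseSourceFile("input.mlir", &context);
  if (!module) return tensorflow::errors::InvalidArgument("failed to parse");
  // TPUBridgeV1Compat (added in this commit) runs the V1 pipeline and
  // converts MLIR diagnostics into a tensorflow::Status.
  return mlir::TFTPU::TPUBridgeV1Compat(module.get(), /*enable_logging=*/false);
}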


@@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "mlir/Transforms/Passes.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
@@ -25,8 +26,39 @@ limitations under the License.
namespace mlir {
namespace TFTPU {

namespace {

// Adds the passes needed to lower the module back to the tf_executor dialect
// (one island per op) so that the result can be exported to a TensorFlow
// Graph.
void AddGraphExportLoweringPasses(OpPassManager &pm) {
  pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
  pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
  pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateToIslandPass());
  pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
}
void CreateTPUBridge(OpPassManager &pm) {
tensorflow::Status RunTPUBridge(
    ModuleOp module, bool enable_logging,
    llvm::function_ref<void(OpPassManager &pm)> pipeline_builder) {
  PassManager bridge(module.getContext());
  // Add logger to bridge passmanager.
  if (enable_logging)
    bridge.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>());

  // Populate a passmanager with the list of passes that implement the bridge.
  pipeline_builder(bridge);

  // Add set of passes to lower back to graph (from tf_executor).
  AddGraphExportLoweringPasses(bridge);

  // Run the bridge on the module; in case of failure, the `diag_handler`
  // converts MLIR errors emitted to the MLIRContext into a tensorflow::Status.
  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
  LogicalResult result = bridge.run(module);
  (void)result;
  return diag_handler.ConsumeStatus();
}
} // namespace
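Because RunTPUBridge takes the pipeline as an llvm::function_ref, both public entry points below can share it, and a lambda works just as well as a named builder. A sketch only; `CreateMyCustomPipeline` is a hypothetical builder, not part of this commit:

// Sketch: reusing RunTPUBridge with an ad-hoc pipeline via a lambda.
tensorflow::Status MyCustomBridge(ModuleOp module, bool enable_logging) {
  return RunTPUBridge(module, enable_logging, [](OpPassManager &pm) {
    CreateMyCustomPipeline(pm);  // hypothetical
  });
}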
void CreateTPUBridgePipeline(OpPassManager &pm) {
// Run island coarsening before shape inference to allow more exact shape
// inference using constant folding within islands.
pm.addNestedPass<FuncOp>(tf_executor::CreateTFExecutorIslandCoarseningPass());
@@ -55,28 +87,26 @@ void CreateTPUBridge(OpPassManager &pm) {
  pm.addNestedPass<FuncOp>(CreateTPUDynamicLayoutPass());
  pm.addNestedPass<FuncOp>(CreateTPUMergeVariablesWithExecutePass());
  pm.addPass(CreateTPUVariableReformattingPass());
  pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
  pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
  pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateToIslandPass());
  pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
}
void CreateTPUBridgePipelineV1(OpPassManager &pm) {
  // For V1 compatibility, we process a module where the graph does not have
  // feeds and fetches. We first extract the TPU computation into a submodule,
  // where it will be in a function with arguments and returned values, much
  // more like a TF v2 module. We can then run the usual pipeline on this
  // nested module. Afterward, we inline it back into the parent module and
  // delete the nested one.
  pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandCoarseningPass());
  pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandOutliningPass());
  OpPassManager &nested_module = pm.nest<ModuleOp>();
  CreateTPUBridgePipeline(nested_module);
  pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandInliningPass());
}
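The pm.nest<ModuleOp>() call above means the regular pipeline only runs on modules nested inside the top-level one, i.e. the outlined TPU computation. A minimal sketch of driving this builder directly, assuming `module` was produced by the TF V1 importer:

// Sketch: run the V1 pipeline without the TPUBridgeV1Compat wrapper below.
mlir::PassManager pm(module.getContext());
mlir::TFTPU::CreateTPUBridgePipelineV1(pm);
if (mlir::failed(pm.run(module))) {
  // Failures are reported as MLIR diagnostics; the wrapper below converts
  // them into a tensorflow::Status instead.
}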
tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging) {
  PassManager bridge(module.getContext());
  // Add logger to bridge passmanager.
  if (enable_logging)
    bridge.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>());
  // Populate a passmanager with the list of passes that implement the bridge.
  CreateTPUBridge(bridge);
  // Run the bridge on the module; in case of failure, the `diag_handler`
  // converts MLIR errors emitted to the MLIRContext into a tensorflow::Status.
  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
  LogicalResult result = bridge.run(module);
  (void)result;
  return diag_handler.ConsumeStatus();
  return RunTPUBridge(module, enable_logging, CreateTPUBridgePipeline);
}

tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool enable_logging) {
  return RunTPUBridge(module, enable_logging, CreateTPUBridgePipelineV1);
}
} // namespace TFTPU


@@ -27,6 +27,12 @@ namespace TFTPU {
// tensorflow::BridgeLogger.
tensorflow::Status TPUBridge(ModuleOp module, bool enable_logging);
// Run all the passes involved in transforming the graph before execution so
// that it is suitable for targeting TPUs. When enable_logging is true, enables
// tensorflow::BridgeLogger.
// This variant of `TPUBridge` is intended for TensorFlow V1 compatibility.
tensorflow::Status TPUBridgeV1Compat(ModuleOp module, bool enable_logging);
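A hypothetical call site for this new entry point, propagating the status with TensorFlow's usual macro; `module` and `enable_logging` are assumed to be in scope:

// Sketch only: run the V1 compat bridge and bail out on failure.
TF_RETURN_IF_ERROR(mlir::TFTPU::TPUBridgeV1Compat(module, enable_logging));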
} // namespace TFTPU
namespace TF {


@@ -27,6 +27,13 @@ mlir::PassPipelineRegistration<> tpu_pipeline(
"tf-tpu-bridge",
"Run all the passes involved in transforming the graph before execution so "
"that it is suitable for targeting TPUs.",
mlir::TFTPU::CreateTPUBridge);
mlir::TFTPU::CreateTPUBridgePipeline);
// Registers an existing pipeline builder function.
mlir::PassPipelineRegistration<> tpu_pipeline_v1(
"tf-tpu-bridge-v1",
"Run all the passes involved in transforming a TensorFlow V1 graph before "
"execution so that it is suitable for targeting TPUs.",
mlir::TFTPU::CreateTPUBridgePipelineV1);
} // anonymous namespace
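These global registrations are standard MLIR plumbing: constructing a PassPipelineRegistration<> at static-initialization time exposes the pipeline as a tf-opt flag, which is what the test's RUN line (-tf-tpu-bridge-v1) relies on. A sketch of registering a further pipeline the same way; the flag name and lambda body are hypothetical:

// Sketch only: any builder with this signature can be registered as a flag.
mlir::PassPipelineRegistration<> my_pipeline(
    "my-hypothetical-flag",
    "One-line description shown by tf-opt's help text.",
    [](mlir::OpPassManager &pm) { /* add passes here */ });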


@@ -194,8 +194,11 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateTPUMergeVariablesWithExecutePass();
std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUVariableReformattingPass();
// Populates the supplied passmanager with the passes required to run the
// bridge.
// NOLINTNEXTLINE - MLIR contract is pass by mutable reference.
void CreateTPUBridge(OpPassManager& pm);
void CreateTPUBridgePipeline(OpPassManager& pm);
// Populates the supplied passmanager with the passes required to run the
// bridge in V1 mode.
void CreateTPUBridgePipelineV1(OpPassManager& pm);
} // namespace TFTPU