Eliminate TPUReplicatedInput and TPUReplicatedOutput when TPURewrite is complete

These operations are markers in the graph for each input/output of the TPU cluster.
They don't have a kernel and can't be executed: they are only present to drive the
rewrite and must be eliminated.

PiperOrigin-RevId: 268133032
This commit is contained in:
Mehdi Amini 2019-09-09 19:42:50 -07:00 committed by TensorFlower Gardener
parent 6b346a7b3f
commit d153f264c5
2 changed files with 31 additions and 0 deletions

View File

@ -418,3 +418,21 @@ func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// -----
// Tests that TPUReplicatedInput and TPUReplicatedOutput marker operations are
// eliminated (their operands forwarded to their users) once the TPU rewrite is complete.
func @main(%arg0 : tensor<0xf32>, %arg1 : tensor<0xf32>) -> tensor<0xf32> {
  // The replicated-input markers must be bypassed, so the execute op is
  // expected to consume the function arguments directly.
  // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%arg0, %arg1
  %0 = "tf.TPUReplicatedInput"(%arg0) {N = 1 : i64} : (tensor<0xf32>) -> tensor<0xf32>
  %1 = "tf.TPUReplicatedInput"(%arg1) {N = 1 : i64} : (tensor<0xf32>) -> tensor<0xf32>
  %2 = "tf_device.launch_func"(%0, %1) {device = "", _tpu_replicate = "cluster", func = @_func} : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
  // The replicated-output marker must be erased between the execute op and
  // the return, and the return must use the execute result itself.
  // CHECK-NOT: "tf.TPUReplicatedOutput"
  %3 = "tf.TPUReplicatedOutput"(%2) {num_replicas = 1 : i64} : (tensor<0xf32>) -> tensor<0xf32>
  // CHECK: return %[[EXECUTE_OUTPUT]]
  return %3 : tensor<0xf32>
}
// Body of the TPU cluster referenced by the launch_func in @main. It ignores
// both arguments and returns a constant.
func @_func(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
  %cst = "tf.Const"() {value = dense<3.000000e+00> : tensor<0xf32>} : () -> tensor<0xf32>
  return %cst : tensor<0xf32>
}

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
@ -22,6 +23,7 @@ limitations under the License.
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
@ -267,6 +269,17 @@ void TPURewritePass::runOnModule() {
Rewrite(op, &builder);
});
// Eliminate TPUReplicatedInput and TPUReplicatedOutput now that the rewrite
// is complete.
getModule().walk([&](Operation* op) {
auto op_name = op->getName().getStringRef();
if (op_name != "tf.TPUReplicatedInput" &&
op_name != "tf.TPUReplicatedOutput")
return;
op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
op->erase();
});
// TODO(b/139377366): Remove functions that are no longer needed.
}