Eliminate TPUReplicatedInput and TPUReplicatedOutput when TPURewrite is complete
These operations are markers in the graph for each input/output of the TPU cluster. They have no kernel and cannot be executed: they are present only to drive the rewrite and must be eliminated once it is complete.

PiperOrigin-RevId: 268133032
This commit is contained in:
parent
6b346a7b3f
commit
d153f264c5
@ -418,3 +418,21 @@ func @tpu0_func(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.B"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that TPUReplicatedInput and TPUReplicatedOutput marker operations are eliminated by the rewrite, with their operands forwarded to their users.
|
||||
|
||||
// Test entry point: routes both arguments through single-replica
// TPUReplicatedInput markers into a launch_func tagged with
// _tpu_replicate = "cluster", then returns the launch result through a
// TPUReplicatedOutput marker.
func @main(%arg0 : tensor<0xf32>, %arg1 : tensor<0xf32>) -> tensor<0xf32> {
|
||||
// Expectation: the generated tf.TPUExecute takes %arg0 and %arg1 directly as
// its leading operands, i.e. no TPUReplicatedInput value sits in between.
// NOTE(review): EXECUTE_OUTPUT is captured here but no later directive in this
// function references it, so the returned value is not verified - confirm
// whether a check on the return was intended.
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%arg0, %arg1
|
||||
// Input markers, one per function argument (N = 1).
%0 = "tf.TPUReplicatedInput"(%arg0) {N = 1 : i64} : (tensor<0xf32>) -> tensor<0xf32>
|
||||
%1 = "tf.TPUReplicatedInput"(%arg1) {N = 1 : i64} : (tensor<0xf32>) -> tensor<0xf32>
|
||||
// The cluster launch that consumes the two marked inputs; its body is @_func.
%2 = "tf_device.launch_func"(%0, %1) {device = "", _tpu_replicate = "cluster", func = @_func} : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
|
||||
// Output marker wrapping the launch result (num_replicas = 1).
%3 = "tf.TPUReplicatedOutput"(%2) {num_replicas = 1 : i64} : (tensor<0xf32>) -> tensor<0xf32>
|
||||
return %3 : tensor<0xf32>
|
||||
}
|
||||
// Function referenced by the launch_func in @main via func = @_func. It does
// not use either argument and returns a constant of the result type.
func @_func(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
|
||||
%0 = "tf.Const"() {value = dense<3.000000e+00> : tensor<0xf32>} : () -> tensor<0xf32>
|
||||
return %0 : tensor<0xf32>
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
@ -22,6 +23,7 @@ limitations under the License.
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
@ -267,6 +269,17 @@ void TPURewritePass::runOnModule() {
|
||||
Rewrite(op, &builder);
|
||||
});
|
||||
|
||||
// Eliminate TPUReplicatedInput and TPUReplicatedOutput now that the rewrite
|
||||
// is complete.
|
||||
getModule().walk([&](Operation* op) {
|
||||
auto op_name = op->getName().getStringRef();
|
||||
if (op_name != "tf.TPUReplicatedInput" &&
|
||||
op_name != "tf.TPUReplicatedOutput")
|
||||
return;
|
||||
op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
|
||||
op->erase();
|
||||
});
|
||||
|
||||
// TODO(b/139377366): Remove functions that are no longer needed.
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user