[TFRT:Compiler] TensorDeviceCopyConversionPass should fold the tf.Identity op iff the arg device is same as the op device

PiperOrigin-RevId: 344156865
Change-Id: I8b5d4f991f15a37cd8d843783b37f439238a292e
This commit is contained in:
Dong Lin 2020-11-24 16:59:25 -08:00 committed by TensorFlower Gardener
parent aa22defe05
commit 8637177e25
2 changed files with 78 additions and 40 deletions

View File

@@ -1,16 +1,47 @@
// RUN: tf-opt -tf-tensor-device-copy %s | FileCheck %s --dump-input=fail
// CHECK-LABEL: func @fold_identity
// CHECK-SAME: ([[arg0:%.*]]: tensor<2x2xf32>, [[arg1:%.*]]: tensor<2x2xf32>
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32}} {
func @fold_identity(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = tf_executor.graph {
// CHECK: tf.MatMul
%outputs, %control = tf_executor.island wraps "tf.MatMul"(%arg0, %arg1) {device = "", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NOT: tf.Identity
%outputs_0, %control_1 = tf_executor.island wraps "tf.Identity"(%outputs) {device = ""} : (tensor<2x2xf32>) -> tensor<2x2xf32>
tf_executor.fetch %outputs_0 : tensor<2x2xf32>
}
return %0 : tensor<2x2xf32>
// CHECK-LABEL: func @fold_identity_test
// Verifies that a tf.Identity whose device attribute ("") matches its
// operand's device (the MatMul is also on "") is folded away by the
// -tf-tensor-device-copy pass: the MatMul survives, the Identity must not.
func @fold_identity_test(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = tf_executor.graph {
// CHECK: tf.MatMul
%outputs, %control = tf_executor.island wraps "tf.MatMul"(%arg0, %arg1) {device = "", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// The Identity's operand and the Identity itself share the empty device, so
// the pass should remove it and fetch the MatMul result directly.
// CHECK-NOT: tf.Identity
%outputs_0, %control_1 = tf_executor.island wraps "tf.Identity"(%outputs) {device = ""} : (tensor<2x2xf32>) -> tensor<2x2xf32>
tf_executor.fetch %outputs_0 : tensor<2x2xf32>
}
return %0 : tensor<2x2xf32>
}
// CHECK-LABEL: func @keep_identity_test
// Verifies that a tf.Identity placed on a different device (CPU) than its
// operand (a GPU MatMul) is NOT folded: such an Identity represents a
// cross-device tensor copy and must be preserved.
func @keep_identity_test(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = tf_executor.graph {
// CHECK: tf.MatMul
%outputs, %control = tf_executor.island wraps "tf.MatMul"(%arg0, %arg1) {device = "/device:GPU:0", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// Operand device (GPU:0) != op device (CPU:0), so the Identity must remain.
// CHECK: tf.Identity
%outputs_0, %control_1 = tf_executor.island wraps "tf.Identity"(%outputs) {device = "/device:CPU:0"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
tf_executor.fetch %outputs_0 : tensor<2x2xf32>
}
return %0 : tensor<2x2xf32>
}
// Verifies folding behavior inside a tf.WhileRegion: an Identity of a
// region block argument (%arg3, not a function argument) is kept, while an
// Identity of the enclosing function's argument (%arg1, with no tf.device
// attribute, i.e. same empty device) is folded — the NotEqual below is
// expected to read %arg1 directly.
// CHECK: func @while_loop_test(%[[ARG_0:.*]]: tensor<i32>, %[[ARG_1:.*]]: tensor<i32>, %arg2: tensor<*xf32>)
func @while_loop_test(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<*xf32>) {
// CHECK-NEXT: tf.WhileRegion
%0:2 = "tf.WhileRegion"(%arg0, %arg2) ( {
// CHECK-NEXT: bb0(%[[ARG_3:.*]]: tensor<i32>, %[[ARG_4:.*]]: tensor<*xf32>)
^bb0(%arg3: tensor<i32>, %arg4: tensor<*xf32>):
// %arg3 belongs to the while region, not the function, so this Identity
// is skipped by the pass and survives.
// CHECK-NEXT: %[[RESULT_1:.*]] = "tf.Identity"(%[[ARG_3]])
%1 = "tf.Identity"(%arg3) : (tensor<i32>) -> tensor<i32>
%2 = "tf.Identity"(%arg1) : (tensor<i32>) -> tensor<i32>
// The second Identity (of function arg %arg1) is folded, so NotEqual uses
// %[[ARG_1]] directly.
// CHECK-NEXT: %[[RESULT_2:.*]] = "tf.NotEqual"(%[[RESULT_1]], %[[ARG_1]])
%3 = "tf.NotEqual"(%1, %2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf.Yield"(%3) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<i32>, %arg4: tensor<*xf32>):
%cst = constant dense<1> : tensor<i32>
%1 = "tf.Sub"(%arg3, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"tf.Yield"(%1, %arg4) : (tensor<i32>, tensor<*xf32>) -> ()
}) {is_stateless = true} : (tensor<i32>, tensor<*xf32>) -> (tensor<i32>, tensor<*xf32>)
return
}

View File

@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This pass folds the tf.Identity op if the operation has the same device as
// its operand.
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -29,40 +32,44 @@ namespace mlir {
namespace TF {
namespace {
// Deletes the op and forwards the arguments: a DialectConversion pattern
// that replaces all results of the matched op with its operands, erasing the
// op from the IR.
// NOTE(review): this diff view appears to show this class as removed by the
// commit (the new pass uses a walk instead of conversion patterns) — confirm
// against the actual post-commit file.
template <typename TF_Op>
class PassThroughConversion : public mlir::OpConversionPattern<TF_Op> {
public:
explicit PassThroughConversion(MLIRContext *context)
: mlir::OpConversionPattern<TF_Op>(context) {}
// Unconditionally matches; rewrites op -> its operands so that uses of the
// op's results now read the inputs directly.
LogicalResult matchAndRewrite(
TF_Op op, ArrayRef<mlir::Value> operands,
ConversionPatternRewriter &rewriter) const override { // NOLINT
// Just forward the arguments to results.
rewriter.replaceOp(op, operands);
return success();
}
};
// Attribute names used to determine placement: "device" on a TF op, and
// "tf.device" as a function-argument attribute.
constexpr const char *kDeviceAttr = "device";
constexpr const char *kTFDeviceAttr = "tf.device";
// Folds tf.Identity ops whose device matches the device of their operand.
// An Identity whose device differs from its operand's device represents a
// tensor copy between devices and is kept.
//
// NOTE(review): the diff rendering had interleaved the removed lines of the
// old pattern-based implementation (OwningRewritePatternList / ConversionTarget
// declarations, patterns.insert, applyPartialConversion) into this body,
// which made it invalid; this is the reconstructed added version.
class TensorDeviceCopyConversionPass
    : public PassWrapper<TensorDeviceCopyConversionPass, FunctionPass> {
 public:
  void runOnFunction() override {
    FuncOp func_op = getFunction();
    // A missing device attribute is treated as the empty device so that two
    // unplaced ops compare equal.
    StringAttr empty_string = StringAttr::get("", func_op.getContext());
    func_op.walk([&](TF::IdentityOp op) {
      // Determine the device of the Identity's operand: either the
      // "tf.device" arg attribute (for function arguments) or the "device"
      // attribute of the defining op.
      StringAttr arg_device = empty_string;
      mlir::Value arg = op.getOperand();
      if (BlockArgument block_arg = arg.dyn_cast<BlockArgument>()) {
        // Skip the folding logic if the block argument is not from the
        // function arguments. This can happen when the argument is from a
        // while loop.
        if (block_arg.getParentRegion() != &func_op.getRegion()) {
          return WalkResult::advance();
        }
        if (StringAttr attr = func_op.getArgAttrOfType<StringAttr>(
                block_arg.getArgNumber(), kTFDeviceAttr)) {
          arg_device = attr;
        }
      } else if (StringAttr attr =
                     arg.getDefiningOp()->getAttrOfType<StringAttr>(
                         kDeviceAttr)) {
        arg_device = attr;
      }

      // TODO(tfrt-devs): when device placer is introduced in the lowering
      // pass, we need to check if Identity op and its previous op are placed
      // on the same device. If not, we don't fold Identity op since it's
      // used for tensor copying between devices.
      StringAttr op_device = op.getAttrOfType<StringAttr>(kDeviceAttr);
      if (!op_device) op_device = empty_string;
      // Skip the folding logic if the argument's device is different from the
      // operation's device.
      if (op_device != arg_device) return WalkResult::advance();

      // Same device on both sides: the Identity is a no-op, so forward its
      // operand to all users and erase it.
      op.replaceAllUsesWith(op.getOperand());
      op.erase();
      return WalkResult::advance();
    });
  }
};
@@ -76,7 +83,7 @@ CreateTensorDeviceCopyConversionPass() {
// Registers the pass under -tf-tensor-device-copy so tf-opt can run it.
// NOTE(review): the diff rendering had left both the old description string
// (already terminated with `");`) and the new one in place, producing a
// malformed statement; only the updated description is kept here.
static mlir::PassRegistration<TensorDeviceCopyConversionPass>
    tensor_device_copy_pass(
        "tf-tensor-device-copy",
        "Fold the tf.Identity op if the op has the same device as its operand");
} // namespace TF
} // namespace mlir