[TFRT:Compiler] TensorDeviceCopyConversionPass should fold the tf.Identity op iff the arg device is same as the op device
PiperOrigin-RevId: 344156865 Change-Id: I8b5d4f991f15a37cd8d843783b37f439238a292e
This commit is contained in:
parent
aa22defe05
commit
8637177e25
@ -1,16 +1,47 @@
|
||||
// RUN: tf-opt -tf-tensor-device-copy %s | FileCheck %s --dump-input=fail

// Both the MatMul and the Identity are unplaced (empty device string), so the
// Identity is a no-op copy and must be folded away.
// CHECK-LABEL: func @fold_identity_test
func @fold_identity_test(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %0 = tf_executor.graph {
    // CHECK: tf.MatMul
    %outputs, %control = tf_executor.island wraps "tf.MatMul"(%arg0, %arg1) {device = "", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
    // CHECK-NOT: tf.Identity
    %outputs_0, %control_1 = tf_executor.island wraps "tf.Identity"(%outputs) {device = ""} : (tensor<2x2xf32>) -> tensor<2x2xf32>
    tf_executor.fetch %outputs_0 : tensor<2x2xf32>
  }
  return %0 : tensor<2x2xf32>
}

// The Identity moves the tensor from GPU to CPU, i.e. it performs a real
// device copy, so the pass must keep it.
// CHECK-LABEL: func @keep_identity_test
func @keep_identity_test(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %0 = tf_executor.graph {
    // CHECK: tf.MatMul
    %outputs, %control = tf_executor.island wraps "tf.MatMul"(%arg0, %arg1) {device = "/device:GPU:0", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
    // CHECK: tf.Identity
    %outputs_0, %control_1 = tf_executor.island wraps "tf.Identity"(%outputs) {device = "/device:CPU:0"} : (tensor<2x2xf32>) -> tensor<2x2xf32>
    tf_executor.fetch %outputs_0 : tensor<2x2xf32>
  }
  return %0 : tensor<2x2xf32>
}

// Block arguments of a while-loop region are not function arguments, so the
// Identity of %arg3 must be kept; the Identity of the enclosing function's
// %arg1 (same, empty, device) is folded.
// CHECK: func @while_loop_test(%[[ARG_0:.*]]: tensor<i32>, %[[ARG_1:.*]]: tensor<i32>, %arg2: tensor<*xf32>)
func @while_loop_test(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<*xf32>) {
  // CHECK-NEXT: tf.WhileRegion
  %0:2 = "tf.WhileRegion"(%arg0, %arg2) ( {
  // CHECK-NEXT: bb0(%[[ARG_3:.*]]: tensor<i32>, %[[ARG_4:.*]]: tensor<*xf32>)
  ^bb0(%arg3: tensor<i32>, %arg4: tensor<*xf32>):
    // CHECK-NEXT: %[[RESULT_1:.*]] = "tf.Identity"(%[[ARG_3]])
    %1 = "tf.Identity"(%arg3) : (tensor<i32>) -> tensor<i32>
    %2 = "tf.Identity"(%arg1) : (tensor<i32>) -> tensor<i32>
    // CHECK-NEXT: %[[RESULT_2:.*]] = "tf.NotEqual"(%[[RESULT_1]], %[[ARG_1]])
    %3 = "tf.NotEqual"(%1, %2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "tf.Yield"(%3) : (tensor<i1>) -> ()
  }, {
  ^bb0(%arg3: tensor<i32>, %arg4: tensor<*xf32>):
    %cst = constant dense<1> : tensor<i32>
    %1 = "tf.Sub"(%arg3, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "tf.Yield"(%1, %arg4) : (tensor<i32>, tensor<*xf32>) -> ()
  }) {is_stateless = true} : (tensor<i32>, tensor<*xf32>) -> (tensor<i32>, tensor<*xf32>)
  return
}
|
||||
|
@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This pass folds the tf.Identity op if the operation has the same device as
|
||||
// its operand.
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
@ -29,40 +32,44 @@ namespace mlir {
|
||||
namespace TF {
|
||||
namespace {
|
||||
|
||||
// Deletes the op and forwards the arguments.
|
||||
template <typename TF_Op>
|
||||
class PassThroughConversion : public mlir::OpConversionPattern<TF_Op> {
|
||||
public:
|
||||
explicit PassThroughConversion(MLIRContext *context)
|
||||
: mlir::OpConversionPattern<TF_Op>(context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
TF_Op op, ArrayRef<mlir::Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override { // NOLINT
|
||||
// Just forward the arguments to results.
|
||||
rewriter.replaceOp(op, operands);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
constexpr const char *kDeviceAttr = "device";
|
||||
constexpr const char *kTFDeviceAttr = "tf.device";
|
||||
|
||||
class TensorDeviceCopyConversionPass
|
||||
: public PassWrapper<TensorDeviceCopyConversionPass, FunctionPass> {
|
||||
public:
|
||||
void runOnFunction() override {
|
||||
mlir::OwningRewritePatternList patterns;
|
||||
mlir::ConversionTarget target(getContext());
|
||||
FuncOp func_op = getFunction();
|
||||
StringAttr empty_string = StringAttr::get("", func_op.getContext());
|
||||
func_op.walk([&](TF::IdentityOp op) {
|
||||
StringAttr arg_device = empty_string;
|
||||
mlir::Value arg = op.getOperand();
|
||||
if (BlockArgument block_arg = arg.dyn_cast<BlockArgument>()) {
|
||||
// Skip the folding logic if the block argument is not from the function
|
||||
// arguments. This can happen when the argument is from a while loop.
|
||||
if (block_arg.getParentRegion() != &func_op.getRegion()) {
|
||||
return WalkResult::advance();
|
||||
}
|
||||
if (StringAttr attr = func_op.getArgAttrOfType<StringAttr>(
|
||||
block_arg.getArgNumber(), kTFDeviceAttr)) {
|
||||
arg_device = attr;
|
||||
}
|
||||
} else if (StringAttr attr =
|
||||
arg.getDefiningOp()->getAttrOfType<StringAttr>(
|
||||
kDeviceAttr)) {
|
||||
arg_device = attr;
|
||||
}
|
||||
|
||||
// TODO(tfrt-devs): when device placer is introduced in the lowering pass,
|
||||
// we need to check if Identity op and it's previous op are placed on the
|
||||
// same device. If not, we don't fold Identity op since it's used for tensor
|
||||
// copying between devices.
|
||||
patterns.insert<PassThroughConversion<TF::IdentityOp>,
|
||||
PassThroughConversion<TF::IdentityNOp>>(&getContext());
|
||||
StringAttr op_device = op.getAttrOfType<StringAttr>(kDeviceAttr);
|
||||
if (!op_device) op_device = empty_string;
|
||||
// Skip the folding logic if the argument's device is different from the
|
||||
// operation's device.
|
||||
if (op_device != arg_device) return WalkResult::advance();
|
||||
|
||||
if (failed(applyPartialConversion(getFunction(), target,
|
||||
std::move(patterns)))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
op.replaceAllUsesWith(op.getOperand());
|
||||
op.erase();
|
||||
return WalkResult::advance();
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@ -76,7 +83,7 @@ CreateTensorDeviceCopyConversionPass() {
|
||||
// Registers the pass under -tf-tensor-device-copy for tf-opt.
static mlir::PassRegistration<TensorDeviceCopyConversionPass>
    tensor_device_copy_pass(
        "tf-tensor-device-copy",
        "Fold the tf.Identity op if the op has the same device as its operand");
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
Loading…
x
Reference in New Issue
Block a user