From 068526d1d7bd4876ad91bd82e217812c8f0e7779 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar@google.com>
Date: Thu, 30 Apr 2020 10:56:59 -0700
Subject: [PATCH] Don't propagate ref type if it wasn't originally done

"Dropping" ref type by way of identity op is common.

PiperOrigin-RevId: 309258627
Change-Id: Ia7cd10eb7acaeae5f2db5a5e9a152a8cb39bed58
---
 .../tensorflow/tests/shape_inference.mlir    | 11 ++++++++
 .../tensorflow/transforms/shape_inference.cc | 27 ++++++++++++++++++-
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
index 1c979b96a9a..caac814b870 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
@@ -390,4 +390,15 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
     }
     return
   }
+
+  // CHECK-LABEL: dont_update_for_ref
+  func @dont_update_for_ref() -> () {
+    // CHECK: () -> tensor<4x!tf.f32ref>
+    %11 = "tf.VariableV2"() {container = "", device = "", shape = #tf.shape<4>, shared_name = ""} : () -> tensor<4x!tf.f32ref>
+    // CHECK: (tensor<4x!tf.f32ref>) -> tensor<4xf32>
+    %12 = "tf.Identity"(%11) {device = ""} : (tensor<4x!tf.f32ref>) -> tensor<4xf32>
+    // CHECK: (tensor<4xf32>) -> tensor<4xf32>
+    %13 = "tf.Neg"(%12) {device = ""} : (tensor<4xf32>) -> tensor<4xf32>
+    return
+  }
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index 38a1464ffcc..ef49c90063e 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -160,6 +160,18 @@ bool InferShapeForPassThroughOps(OperandRange pass_through_operands,
     Type operand_type = std::get<0>(entry).getType();
     Value result = std::get<1>(entry);
     if (result.getType() == operand_type) continue;
+    // Pass through nodes may remove ref types, don't consider that as
+    // refinement.
+    // TODO(jpienaar): There could be refinement in addition to this, so
+    // refine this.
+    if (operand_type.cast<TensorType>()
+            .getElementType()
+            .isa<TF::TensorFlowRefType>() &&
+        !result.getType()
+             .cast<TensorType>()
+             .getElementType()
+             .isa<TF::TensorFlowRefType>())
+      continue;
     AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect,
                                        result.getType());
     result.setType(operand_type);
@@ -238,7 +250,20 @@ bool PassThroughOperandTypes(OperandRange operands, ResultRange results) {
   bool changed = false;
   for (auto entry : llvm::zip(operands, results)) {
     Type operand_type = std::get<0>(entry).getType();
-    if (operand_type == std::get<1>(entry).getType()) continue;
+    Type result_type = std::get<1>(entry).getType();
+    if (operand_type == result_type) continue;
+    // Pass through nodes may remove ref types, don't consider that as
+    // refinement.
+    // TODO(jpienaar): There could be refinement in addition to this, so
+    // refine this.
+    if (operand_type.cast<TensorType>()
+            .getElementType()
+            .isa<TF::TensorFlowRefType>() &&
+        !result_type.cast<TensorType>()
+             .getElementType()
+             .isa<TF::TensorFlowRefType>())
+      continue;
+
     std::get<1>(entry).setType(operand_type);
     changed = true;
   }
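
Note (not part of the committed patch): both hunks in shape_inference.cc add the
same guard, so the rule is easy to state on its own: a type mismatch between a
pass-through operand and its result does not count as a refinement opportunity
when the only difference is that the operand is a ref type and the result is
not. The standalone C++ sketch below models that rule under stated assumptions;
ElementType, DropsRefType, and InferPassThroughType are hypothetical stand-ins,
while the real code queries mlir::Type via
cast<TensorType>().getElementType().isa<TF::TensorFlowRefType>().

// ref_guard_sketch.cc -- minimal model of the guard added by this patch.
// Build/run: c++ -std=c++11 ref_guard_sketch.cc && ./a.out
#include <cassert>
#include <string>

// Stand-in for a tensor element type. `is_ref` models
// isa<TF::TensorFlowRefType>() (e.g. !tf.f32ref vs. f32).
struct ElementType {
  std::string base;  // e.g. "f32"
  bool is_ref;       // true for ref types such as !tf.f32ref
};

// The new check: the operand carries a ref type but the result does not,
// i.e. a pass-through op such as tf.Identity dropped the ref-ness.
bool DropsRefType(const ElementType& operand, const ElementType& result) {
  return operand.is_ref && !result.is_ref;
}

// One iteration of the pass-through loop: which type should the result get?
// Before this patch, any mismatch propagated the operand type to the result.
ElementType InferPassThroughType(const ElementType& operand,
                                 const ElementType& result) {
  if (operand.base == result.base && operand.is_ref == result.is_ref)
    return result;  // types already agree; nothing to refine
  if (DropsRefType(operand, result))
    return result;  // keep the non-ref result; don't re-apply the ref type
  return operand;   // ordinary pass-through refinement
}

int main() {
  const ElementType f32ref{"f32", true};
  const ElementType f32{"f32", false};
  // Mirrors the dont_update_for_ref test: tf.Identity with a
  // tensor<4x!tf.f32ref> operand keeps its tensor<4xf32> result type.
  assert(!InferPassThroughType(f32ref, f32).is_ref);
  // A plain f32 pass-through is unchanged, as before the patch.
  assert(InferPassThroughType(f32, f32).base == "f32");
  return 0;
}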