Don't propagate ref type if it wasn't originally done

"Dropping" ref type by way of identity op is common.

PiperOrigin-RevId: 309258627
Change-Id: Ia7cd10eb7acaeae5f2db5a5e9a152a8cb39bed58
This commit is contained in:
Jacques Pienaar 2020-04-30 10:56:59 -07:00 committed by TensorFlower Gardener
parent b3fa00bf32
commit 068526d1d7
2 changed files with 37 additions and 1 deletions
tensorflow/compiler/mlir/tensorflow

View File

@ -390,4 +390,15 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
}
return
}
// CHECK-LABEL: dont_update_for_ref
func @dont_update_for_ref() -> () {
// CHECK: () -> tensor<4x!tf.f32ref>
%11 = "tf.VariableV2"() {container = "", device = "", shape = #tf.shape<4>, shared_name = ""} : () -> tensor<4x!tf.f32ref>
// CHECK: (tensor<4x!tf.f32ref>) -> tensor<4xf32>
%12 = "tf.Identity"(%11) {device = ""} : (tensor<4x!tf.f32ref>) -> tensor<4xf32>
// CHECK: (tensor<4xf32>) -> tensor<4xf32>
%13 = "tf.Neg"(%12) {device = ""} : (tensor<4xf32>) -> tensor<4xf32>
return
}
}

View File

@ -160,6 +160,18 @@ bool InferShapeForPassThroughOps(OperandRange pass_through_operands,
Type operand_type = std::get<0>(entry).getType();
Value result = std::get<1>(entry);
if (result.getType() == operand_type) continue;
// Pass through nodes may remove ref types, don't consider that as
// refinement.
// TODO(jpienaar): There could be refinement in addition to this, so
// refine this.
if (operand_type.cast<TensorType>()
.getElementType()
.isa<TF::TensorFlowRefType>() &&
!result.getType()
.cast<TensorType>()
.getElementType()
.isa<TF::TensorFlowRefType>())
continue;
AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect,
result.getType());
result.setType(operand_type);
@ -238,7 +250,20 @@ bool PassThroughOperandTypes(OperandRange operands, ResultRange results) {
bool changed = false;
for (auto entry : llvm::zip(operands, results)) {
Type operand_type = std::get<0>(entry).getType();
if (operand_type == std::get<1>(entry).getType()) continue;
Type result_type = std::get<1>(entry).getType();
if (operand_type == result_type) continue;
// Pass through nodes may remove ref types, don't consider that as
// refinement.
// TODO(jpienaar): There could be refinement in addition to this, so
// refine this.
if (operand_type.cast<TensorType>()
.getElementType()
.isa<TF::TensorFlowRefType>() &&
!result_type.cast<TensorType>()
.getElementType()
.isa<TF::TensorFlowRefType>())
continue;
std::get<1>(entry).setType(operand_type);
changed = true;
}