Don't propagate ref type if it wasn't originally done
"Dropping" ref type by way of identity op is common. PiperOrigin-RevId: 309258627 Change-Id: Ia7cd10eb7acaeae5f2db5a5e9a152a8cb39bed58
This commit is contained in:
parent
b3fa00bf32
commit
068526d1d7
tensorflow/compiler/mlir/tensorflow
@ -390,4 +390,15 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: dont_update_for_ref
|
||||
func @dont_update_for_ref() -> () {
|
||||
// CHECK: () -> tensor<4x!tf.f32ref>
|
||||
%11 = "tf.VariableV2"() {container = "", device = "", shape = #tf.shape<4>, shared_name = ""} : () -> tensor<4x!tf.f32ref>
|
||||
// CHECK: (tensor<4x!tf.f32ref>) -> tensor<4xf32>
|
||||
%12 = "tf.Identity"(%11) {device = ""} : (tensor<4x!tf.f32ref>) -> tensor<4xf32>
|
||||
// CHECK: (tensor<4xf32>) -> tensor<4xf32>
|
||||
%13 = "tf.Neg"(%12) {device = ""} : (tensor<4xf32>) -> tensor<4xf32>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -160,6 +160,18 @@ bool InferShapeForPassThroughOps(OperandRange pass_through_operands,
|
||||
Type operand_type = std::get<0>(entry).getType();
|
||||
Value result = std::get<1>(entry);
|
||||
if (result.getType() == operand_type) continue;
|
||||
// Pass through nodes may remove ref types, don't consider that as
|
||||
// refinement.
|
||||
// TODO(jpienaar): There could be refinement in addition to this, so
|
||||
// refine this.
|
||||
if (operand_type.cast<TensorType>()
|
||||
.getElementType()
|
||||
.isa<TF::TensorFlowRefType>() &&
|
||||
!result.getType()
|
||||
.cast<TensorType>()
|
||||
.getElementType()
|
||||
.isa<TF::TensorFlowRefType>())
|
||||
continue;
|
||||
AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect,
|
||||
result.getType());
|
||||
result.setType(operand_type);
|
||||
@ -238,7 +250,20 @@ bool PassThroughOperandTypes(OperandRange operands, ResultRange results) {
|
||||
bool changed = false;
|
||||
for (auto entry : llvm::zip(operands, results)) {
|
||||
Type operand_type = std::get<0>(entry).getType();
|
||||
if (operand_type == std::get<1>(entry).getType()) continue;
|
||||
Type result_type = std::get<1>(entry).getType();
|
||||
if (operand_type == result_type) continue;
|
||||
// Pass through nodes may remove ref types, don't consider that as
|
||||
// refinement.
|
||||
// TODO(jpienaar): There could be refinement in addition to this, so
|
||||
// refine this.
|
||||
if (operand_type.cast<TensorType>()
|
||||
.getElementType()
|
||||
.isa<TF::TensorFlowRefType>() &&
|
||||
!result_type.cast<TensorType>()
|
||||
.getElementType()
|
||||
.isa<TF::TensorFlowRefType>())
|
||||
continue;
|
||||
|
||||
std::get<1>(entry).setType(operand_type);
|
||||
changed = true;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user