diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc
index 2ec73824f6c..50f034e8ba1 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc
@@ -62,101 +62,6 @@ bool GetCastCompatibleShape(llvm::ArrayRef<int64_t> a_shape,
   return true;
 }
 
-// Given two types `a` and `b`, returns a refined type which is cast compatible
-// with both `a` and `b` and is equal to or more precise than both of them. It
-// returns empty Type if the input types are not cast compatible.
-//
-// The two types are considered cast compatible if they have dynamically equal
-// shapes and element type. For element types that do not have subtypes, they
-// must be equal. However, for TensorFlow types such as Resource and Variant
-// that also have subtypes, we recursively check for subtype compatibility for
-// Resource types and assume all Variant types are cast compatible. If either
-// one of `a` or `b` has empty subtypes, they are considered cast compatible.
-//
-// The returned type is the same as or more precise than the input types. For
-// example, if `a` and `b` are cast compatible types tensor<2x?x?xf32> and
-// tensor<?x4x?xf32> respectively, the returned type is tensor<2x4x?xf32>.
-//
-// Provides an option to ignore ref types on 'a'. This is useful for TF ops
-// that might allow operands to either be the same as the result type or be a
-// ref type corresponding to it.
-mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
-                                 bool may_ignore_ref_type_a) {
-  // Fast path if everything is equal.
-  if (a == b) return b;
-
-  auto a_tt = a.dyn_cast<mlir::TensorType>();
-  auto b_tt = b.dyn_cast<mlir::TensorType>();
-
-  // If only one of a or b is a tensor type, they are incompatible.
-  if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;
-
-  // For non-tensor types, we do not need to worry about shape and can return
-  // early.
-  if (!a_tt && !b_tt) {
-    // Remove ref types.
-    if (may_ignore_ref_type_a) {
-      if (auto ref_type = a.dyn_cast<mlir::TF::TensorFlowRefType>()) {
-        a = ref_type.RemoveRef();
-        if (a == b) return a;
-      }
-    }
-    if (a.getTypeID() != b.getTypeID()) return nullptr;
-
-    // If either is not a type that contains subtypes then the types are not
-    // cast compatible.
-    auto a_wst = a.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
-    auto b_wst = b.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
-    if (!a_wst || !b_wst) return nullptr;
-
-    // For Variant types we are more permissive right now and accept all pairs
-    // of Variant types. If we were more constrained and checked compatibility
-    // of subtypes, we might reject valid graphs.
-    // TODO(prakalps): Variant doesn't have a subtype, we assign it
-    // one, so we should only assign it one when we know the subtype. Then we
-    // can be more constrained and check subtypes for cast compatibility as
-    // well.
-    if (a.isa<mlir::TF::VariantType>()) return a;
-
-    // For Resource types, we recursively check the subtypes for cast
-    // compatibility, if possible. Otherwise treat them as compatible.
-    auto a_wst_st = a_wst.GetSubtypes();
-    auto b_wst_st = b_wst.GetSubtypes();
-    if (a_wst_st.empty() || b_wst_st.empty()) return a;
-    if (a_wst_st.size() != b_wst_st.size()) return nullptr;
-    llvm::SmallVector<mlir::TensorType, 4> refined_subtypes;
-    for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) {
-      mlir::Type refined_st =
-          GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
-                                /*may_ignore_ref_type_a=*/false);
-      if (!refined_st) return nullptr;
-      refined_subtypes.push_back(refined_st.cast<mlir::TensorType>());
-    }
-
-    return mlir::TF::ResourceType::get(refined_subtypes, a.getContext());
-  }
-
-  // For tensor types, check compatibility of both element type and shape.
-  mlir::Type refined_element_ty = GetCastCompatibleType(
-      a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
-  if (!refined_element_ty) return nullptr;
-
-  if (!a_tt.hasRank() && !b_tt.hasRank()) {
-    return mlir::UnrankedTensorType::get(refined_element_ty);
-  }
-  if (!a_tt.hasRank()) {
-    return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty);
-  }
-  if (!b_tt.hasRank()) {
-    return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty);
-  }
-
-  llvm::SmallVector<int64_t, 8> refined_shape;
-  if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
-    return nullptr;
-
-  return mlir::RankedTensorType::get(refined_shape, refined_element_ty);
-}
 }  // namespace
 
 namespace mlir {
@@ -343,6 +248,102 @@ bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
   return true;
 }
 
+// Given two types `a` and `b`, returns a refined type which is cast compatible
+// with both `a` and `b` and is equal to or more precise than both of them. It
+// returns empty Type if the input types are not cast compatible.
+//
+// The two types are considered cast compatible if they have dynamically equal
+// shapes and element type. For element types that do not have subtypes, they
+// must be equal. However, for TensorFlow types such as Resource and Variant
+// that also have subtypes, we recursively check for subtype compatibility for
+// Resource types and assume all Variant types are cast compatible. If either
+// one of `a` or `b` has empty subtypes, they are considered cast compatible.
+//
+// The returned type is the same as or more precise than the input types. For
+// example, if `a` and `b` are cast compatible types tensor<2x?x?xf32> and
+// tensor<?x4x?xf32> respectively, the returned type is tensor<2x4x?xf32>.
+//
+// Provides an option to ignore ref types on 'a'. This is useful for TF ops
+// that might allow operands to either be the same as the result type or be a
+// ref type corresponding to it.
+mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
+                                 bool may_ignore_ref_type_a) {
+  // Fast path if everything is equal.
+  if (a == b) return b;
+
+  auto a_tt = a.dyn_cast<mlir::TensorType>();
+  auto b_tt = b.dyn_cast<mlir::TensorType>();
+
+  // If only one of a or b is a tensor type, they are incompatible.
+  if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;
+
+  // For non-tensor types, we do not need to worry about shape and can return
+  // early.
+  if (!a_tt && !b_tt) {
+    // Remove ref types.
+    if (may_ignore_ref_type_a) {
+      if (auto ref_type = a.dyn_cast<mlir::TF::TensorFlowRefType>()) {
+        a = ref_type.RemoveRef();
+        if (a == b) return a;
+      }
+    }
+    if (a.getTypeID() != b.getTypeID()) return nullptr;
+
+    // If either is not a type that contains subtypes then the types are not
+    // cast compatible.
+    auto a_wst = a.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
+    auto b_wst = b.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
+    if (!a_wst || !b_wst) return nullptr;
+
+    // For Variant types we are more permissive right now and accept all pairs
+    // of Variant types. If we were more constrained and checked compatibility
+    // of subtypes, we might reject valid graphs.
+    // TODO(prakalps): Variant doesn't have a subtype, we assign it
+    // one, so we should only assign it one when we know the subtype. Then we
+    // can be more constrained and check subtypes for cast compatibility as
+    // well.
+    if (a.isa<mlir::TF::VariantType>()) return a;
+
+    // For Resource types, we recursively check the subtypes for cast
+    // compatibility, if possible. Otherwise treat them as compatible.
+    auto a_wst_st = a_wst.GetSubtypes();
+    auto b_wst_st = b_wst.GetSubtypes();
+    if (a_wst_st.empty() || b_wst_st.empty()) return a;
+    if (a_wst_st.size() != b_wst_st.size()) return nullptr;
+    llvm::SmallVector<mlir::TensorType, 4> refined_subtypes;
+    for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) {
+      mlir::Type refined_st =
+          GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
+                                /*may_ignore_ref_type_a=*/false);
+      if (!refined_st) return nullptr;
+      refined_subtypes.push_back(refined_st.cast<mlir::TensorType>());
+    }
+
+    return mlir::TF::ResourceType::get(refined_subtypes, a.getContext());
+  }
+
+  // For tensor types, check compatibility of both element type and shape.
+  mlir::Type refined_element_ty = GetCastCompatibleType(
+      a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
+  if (!refined_element_ty) return nullptr;
+
+  if (!a_tt.hasRank() && !b_tt.hasRank()) {
+    return mlir::UnrankedTensorType::get(refined_element_ty);
+  }
+  if (!a_tt.hasRank()) {
+    return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty);
+  }
+  if (!b_tt.hasRank()) {
+    return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty);
+  }
+
+  llvm::SmallVector<int64_t, 8> refined_shape;
+  if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
+    return nullptr;
+
+  return mlir::RankedTensorType::get(refined_shape, refined_element_ty);
+}
+
 bool HasCompatibleElementTypes(Type lhs, Type rhs,
                                bool may_ignore_ref_type_lhs) {
   return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr;
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h
index f93f6b657da..60a86f32920 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h
@@ -272,6 +272,15 @@ class VariantType : public detail::TypeWithSubtypeImpl<VariantType> {
   static std::string getTypeName() { return "VariantType"; }
 };
 
+// Given two types `a` and `b`, returns a refined type which is cast compatible
+// with both `a` and `b` and is equal to or more precise than both of them. It
+// returns empty Type if the input types are not cast compatible.
+// Provides an option to ignore ref types on 'a'. This is useful for TF ops
+// that might allow operands to either be the same as the result type or be a
+// ref type corresponding to it.
+mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
+                                 bool may_ignore_ref_type_a);
+
 // Returns whether two arrays of Type are broadcast compatible.
 bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs);
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
index 458e7bb28dc..011c1fe7b6d 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir
@@ -987,5 +987,66 @@ func @test_unsupported_resource_op() -> tensor<*xi32> {
   return %1 : tensor<*xi32>
 }
 
+// -----
+
+// Test type refinement. If the resource has a single subtype, check that this
+// subtype gets used when hoisting the read. None of the result types will
+// change.
+// CHECK-LABEL: func @type_refinement_use_subtype
+func @type_refinement_use_subtype() -> tensor<*xi32> {
+
+  // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<4xi32>>>
+
+  // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
+  // CHECK-SAME: -> tensor<4xi32>
+  // CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster"
+  // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) : (tensor<4xi32>) -> tensor<*xi32>
+  // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]]
+  // CHECK-SAME: tensor<*xi32>, tensor<*xi32>
+  // CHECK: {cluster_attr = "cluster_attr"}
+  // CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>)
+  // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1)
+
+  %1 = "tf_device.cluster"() ( {
+    %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource<tensor<4xi32>>>) -> tensor<*xi32>
+    %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
+    "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource<tensor<4xi32>>>, tensor<*xi32>) -> ()
+    tf_device.return %3 : tensor<*xi32>
+  }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
+
+  // CHECK: return %[[CLUSTER_RES]]#0
+  // CHECK-SAME: tensor<*xi32>
+  return %1 : tensor<*xi32>
+}
+
+// If multiple types are used across reads and writes, check that the read uses
+// the most refined type. The first ReadVariableOp should refine the type from
+// *xi32 to ?xi32, and the assign should refine it further to 4xi32.
+// CHECK-LABEL: func @type_refinement_use_refined_type
+func @type_refinement_use_refined_type() -> tensor<4xi32> {
+
+  // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
+
+  // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
+  // CHECK-SAME: -> tensor<4xi32>
+  // CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster"
+  // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) : (tensor<4xi32>) -> tensor<4xi32>
+  // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]]
+  // CHECK-SAME: tensor<4xi32>, tensor<4xi32>
+  // CHECK: {cluster_attr = "cluster_attr"}
+  // CHECK-SAME: () -> (tensor<4xi32>, tensor<4xi32>)
+  // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1)
+
+  %1 = "tf_device.cluster"() ( {
+    %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<?xi32>
+    %3 = "tf.SomeComputation"(%2) : (tensor<?xi32>) -> (tensor<4xi32>)
+    "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<4xi32>) -> ()
+    tf_device.return %3 : tensor<4xi32>
+  }) {cluster_attr = "cluster_attr"} : () -> tensor<4xi32>
+
+  // CHECK: return %[[CLUSTER_RES]]#0
+  // CHECK-SAME: tensor<4xi32>
+  return %1 : tensor<4xi32>
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
index c2826a350ab..58c18c50b89 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc
@@ -147,6 +147,16 @@ bool IsResource(Value value) {
   return getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>();
 }
 
+// Get the type of the data contained in a resource. Returns null if there is
+// no single type in the resource.
+Type GetResourceSubtype(Value value) {
+  auto resource_type =
+      getElementTypeOrSelf(value.getType()).dyn_cast<TF::ResourceType>();
+  auto subtypes = resource_type.getSubtypes();
+  if (subtypes.size() == 1) return subtypes[0];
+  return nullptr;
+}
+
 // Performs store-load forwarding. This effectively removes
 // 1) Any resource loads after a store to that same resource is done
 // 2) Any resource stores except the last one.
@@ -276,6 +286,17 @@ class RegionResourceHoister {
           result_index(-1) {}
 
    bool IsResultIndexAssigned() { return result_index != -1; }
+
+    // Refine the resource type using the given type `type`.
+    void RefineType(Type type) {
+      if (!data_type) {
+        data_type = type;
+      } else {
+        data_type = TF::GetCastCompatibleType(data_type, type,
+                                              /*may_ignore_ref_type_a=*/false);
+        assert(data_type != nullptr && "Resource used with incompatible types");
+      }
+    }
   };
   llvm::MapVector<Value, ResourceInfo> resources_;
   llvm::SetVector<Value> written_resources_;
@@ -310,6 +331,7 @@ LogicalResult RegionResourceHoister::Analyze() {
   for (auto resource : all_resources) {
     ResourceInfo info;
+    info.data_type = GetResourceSubtype(resource);
     llvm::BitVector written_regions(op_->getNumRegions());
     bool unsupported_use = false;
     for (OpOperand& use : resource.getUses()) {
@@ -341,13 +363,13 @@ LogicalResult RegionResourceHoister::Analyze() {
 
     if (read && !info.is_read) {
      info.is_read = true;
-      info.data_type = read.value().getType();
+      info.RefineType(read.value().getType());
       info.read_attrs = user->getAttrDictionary();
     }
 
     if (write) {
       info.is_written = true;
-      info.data_type = write.value().getType();
+      info.RefineType(write.value().getType());
       info.write_attrs = user->getAttrDictionary();
       written_regions.set(user->getParentRegion()->getRegionNumber());
     }
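
The RefineType helper added by this patch joins the type observed at each read and write through TF::GetCastCompatibleType, so the hoisted read ends up with the most precise type seen anywhere in the region. The standalone C++ sketch below illustrates the dimension-wise join rule in isolation; it is not TF code, RefineShapes and kDynamic are hypothetical names, shapes are modeled as int64_t vectors with -1 standing in for MLIR's '?', and unranked tensors are not modeled.

// Standalone sketch of the shape join that GetCastCompatibleShape performs
// when GetCastCompatibleType merges two ranked tensor types.
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

constexpr int64_t kDynamic = -1;  // stands in for MLIR's dynamic dimension '?'

// Joins two shapes of equal rank: a static size wins over a dynamic one;
// two different static sizes make the shapes cast incompatible.
std::optional<std::vector<int64_t>> RefineShapes(
    const std::vector<int64_t>& a, const std::vector<int64_t>& b) {
  if (a.size() != b.size()) return std::nullopt;  // rank mismatch
  std::vector<int64_t> refined;
  refined.reserve(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] != kDynamic && b[i] != kDynamic && a[i] != b[i])
      return std::nullopt;  // conflicting static sizes
    refined.push_back(a[i] == kDynamic ? b[i] : a[i]);
  }
  return refined;
}

int main() {
  // tensor<2x?x?xf32> joined with tensor<?x4x?xf32> gives tensor<2x4x?xf32>,
  // matching the example in the GetCastCompatibleType comment.
  auto joined = RefineShapes({2, kDynamic, kDynamic}, {kDynamic, 4, kDynamic});
  for (int64_t d : *joined) std::cout << d << ' ';  // prints: 2 4 -1
  std::cout << '\n';

  // Chaining the join mirrors ResourceInfo::RefineType across successive
  // reads and writes: ?xi32 refined by 4xi32 yields 4xi32, as in the
  // @type_refinement_use_refined_type test above.
  auto chained = RefineShapes(*RefineShapes({kDynamic}, {kDynamic}), {4});
  std::cout << (*chained)[0] << '\n';  // prints: 4
}

The same join explains why RefineType can assert on success: the pass only ever combines types that a ReadVariableOp/AssignVariableOp pair has already used against one resource, so a null result indicates a genuinely ill-typed graph.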