[MLIR] Fix resource data type inference to work correctly when a resource is used with multiple types

- If a resource is used in reads/writes with different types (some more and
  some less refined), use the most refined type seen as the resource type,
  as opposed to the last type seen. This type is used as the result type of
  the hoisted reads generated by the pass. (A standalone sketch of this rule
  follows the commit metadata below.)

PiperOrigin-RevId: 331192056
Change-Id: Ia5e443edcbf2e67294dcd756aea2064a3c811762
Rahul Joshi, 2020-09-11 11:49:49 -07:00 (committed by TensorFlower Gardener)
parent a08ba6a413
commit 6d56f6a3d2
4 changed files with 190 additions and 97 deletions
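To make the refinement rule concrete, below is a minimal standalone C++ analogue of the fold the pass performs (an illustration only, not the MLIR API). A shape is a vector of dimensions where -1 means dynamic (MLIR's convention), an unranked type is modeled as "no shape", and the types observed across all reads and writes of a resource are merged into the most refined one.

#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// nullopt == unranked; -1 == dynamic dimension.
using Shape = std::optional<std::vector<int64_t>>;

// Merges two cast-compatible shapes into the most refined one. Returns false
// if the shapes conflict (different ranks or different static sizes).
bool RefineShape(const Shape& a, const Shape& b, Shape* out) {
  if (!a) { *out = b; return true; }  // `a` unranked: `b` is at least as refined
  if (!b) { *out = a; return true; }  // `b` unranked: keep `a`
  if (a->size() != b->size()) return false;  // rank mismatch
  std::vector<int64_t> merged(a->size());
  for (size_t i = 0; i < a->size(); ++i) {
    int64_t da = (*a)[i], db = (*b)[i];
    if (da == -1) merged[i] = db;                   // take the static size
    else if (db == -1 || db == da) merged[i] = da;  // keep the static size
    else return false;                              // conflicting static sizes
  }
  *out = merged;
  return true;
}

int main() {
  // Types seen for one resource: tensor<*xi32>, tensor<?xi32>, tensor<4xi32>.
  const Shape seen[] = {std::nullopt, std::vector<int64_t>{-1},
                        std::vector<int64_t>{4}};
  Shape refined = std::nullopt;
  for (const Shape& s : seen) {
    Shape merged;
    bool ok = RefineShape(refined, s, &merged);
    assert(ok && "resource used with incompatible types");
    (void)ok;
    refined = merged;
  }
  // The most refined type wins: tensor<4xi32>.
  assert(refined.has_value() && (*refined)[0] == 4);
}

Because this merge is commutative and associative, the result no longer depends on the order in which the uses are visited, which is precisely the "last seen type" bug being fixed.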

View File

@@ -62,101 +62,6 @@ bool GetCastCompatibleShape(llvm::ArrayRef<int64_t> a_shape,
   return true;
 }
 
-// Given two types `a` and `b`, returns a refined type which is cast compatible
-// with both `a` and `b` and is equal to or more precise than both of them. It
-// returns empty Type if the input types are not cast compatible.
-//
-// The two types are considered cast compatible if they have dynamically equal
-// shapes and element type. For element types that do not have subtypes, they
-// must be equal. However, for TensorFlow types such as Resource and Variant
-// that also have subtypes, we recursively check subtype compatibility for
-// Resource types and assume all Variant types are cast compatible. If either
-// one of `a` or `b` has empty subtypes, they are considered cast compatible.
-//
-// The returned type is the same as or more precise than the input types. For
-// example, if `a` and `b` are cast-compatible types tensor<2x?x?xf32> and
-// tensor<?x4x?xf32> respectively, the returned type is tensor<2x4x?xf32>.
-//
-// Provides an option to ignore ref types on `a`. This is useful for TF ops
-// that might allow operands to either be the same as the result type or be a
-// ref type corresponding to it.
-mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
-                                 bool may_ignore_ref_type_a) {
-  // Fast path if everything is equal.
-  if (a == b) return b;
-
-  auto a_tt = a.dyn_cast<mlir::TensorType>();
-  auto b_tt = b.dyn_cast<mlir::TensorType>();
-
-  // If only one of a or b is a tensor type, they are incompatible.
-  if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;
-
-  // For non-tensor types, we do not need to worry about shape and can return
-  // early.
-  if (!a_tt && !b_tt) {
-    // Remove ref types.
-    if (may_ignore_ref_type_a) {
-      if (auto ref_type = a.dyn_cast<mlir::TF::TensorFlowRefType>()) {
-        a = ref_type.RemoveRef();
-        if (a == b) return a;
-      }
-    }
-
-    if (a.getTypeID() != b.getTypeID()) return nullptr;
-
-    // If either is not a type that contains subtypes then the types are not
-    // cast compatible.
-    auto a_wst = a.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
-    auto b_wst = b.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
-    if (!a_wst || !b_wst) return nullptr;
-
-    // For Variant types we are more permissive right now and accept all pairs
-    // of Variant types. If we were more constrained and checked compatibility
-    // of subtypes, we might reject valid graphs.
-    // TODO(prakalps): Variant doesn't have a subtype; we assign it one, so we
-    // should only assign it one when we know the subtype. Then we can be more
-    // constrained and check subtypes for cast compatibility as well.
-    if (a.isa<mlir::TF::VariantType>()) return a;
-
-    // For Resource types, we recursively check the subtypes for cast
-    // compatibility, if possible. Otherwise treat them as compatible.
-    auto a_wst_st = a_wst.GetSubtypes();
-    auto b_wst_st = b_wst.GetSubtypes();
-    if (a_wst_st.empty() || b_wst_st.empty()) return a;
-    if (a_wst_st.size() != b_wst_st.size()) return nullptr;
-    llvm::SmallVector<mlir::TensorType, 4> refined_subtypes;
-    for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) {
-      mlir::Type refined_st =
-          GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
-                                /*may_ignore_ref_type_a=*/false);
-      if (!refined_st) return nullptr;
-      refined_subtypes.push_back(refined_st.cast<mlir::TensorType>());
-    }
-
-    return mlir::TF::ResourceType::get(refined_subtypes, a.getContext());
-  }
-
-  // For tensor types, check compatibility of both element type and shape.
-  mlir::Type refined_element_ty = GetCastCompatibleType(
-      a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
-  if (!refined_element_ty) return nullptr;
-
-  if (!a_tt.hasRank() && !b_tt.hasRank()) {
-    return mlir::UnrankedTensorType::get(refined_element_ty);
-  }
-  if (!a_tt.hasRank()) {
-    return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty);
-  }
-  if (!b_tt.hasRank()) {
-    return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty);
-  }
-
-  llvm::SmallVector<int64_t, 8> refined_shape;
-  if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
-    return nullptr;
-
-  return mlir::RankedTensorType::get(refined_shape, refined_element_ty);
-}
-
 }  // namespace
 
 namespace mlir {
@@ -343,6 +248,102 @@ bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs) {
   return true;
 }
 
+// Given two types `a` and `b`, returns a refined type which is cast compatible
+// with both `a` and `b` and is equal to or more precise than both of them. It
+// returns empty Type if the input types are not cast compatible.
+//
+// The two types are considered cast compatible if they have dynamically equal
+// shapes and element type. For element types that do not have subtypes, they
+// must be equal. However, for TensorFlow types such as Resource and Variant
+// that also have subtypes, we recursively check subtype compatibility for
+// Resource types and assume all Variant types are cast compatible. If either
+// one of `a` or `b` has empty subtypes, they are considered cast compatible.
+//
+// The returned type is the same as or more precise than the input types. For
+// example, if `a` and `b` are cast-compatible types tensor<2x?x?xf32> and
+// tensor<?x4x?xf32> respectively, the returned type is tensor<2x4x?xf32>.
+//
+// Provides an option to ignore ref types on `a`. This is useful for TF ops
+// that might allow operands to either be the same as the result type or be a
+// ref type corresponding to it.
+mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
+                                 bool may_ignore_ref_type_a) {
+  // Fast path if everything is equal.
+  if (a == b) return b;
+
+  auto a_tt = a.dyn_cast<mlir::TensorType>();
+  auto b_tt = b.dyn_cast<mlir::TensorType>();
+
+  // If only one of a or b is a tensor type, they are incompatible.
+  if (static_cast<bool>(a_tt) ^ static_cast<bool>(b_tt)) return nullptr;
+
+  // For non-tensor types, we do not need to worry about shape and can return
+  // early.
+  if (!a_tt && !b_tt) {
+    // Remove ref types.
+    if (may_ignore_ref_type_a) {
+      if (auto ref_type = a.dyn_cast<mlir::TF::TensorFlowRefType>()) {
+        a = ref_type.RemoveRef();
+        if (a == b) return a;
+      }
+    }
+
+    if (a.getTypeID() != b.getTypeID()) return nullptr;
+
+    // If either is not a type that contains subtypes then the types are not
+    // cast compatible.
+    auto a_wst = a.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
+    auto b_wst = b.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
+    if (!a_wst || !b_wst) return nullptr;
+
+    // For Variant types we are more permissive right now and accept all pairs
+    // of Variant types. If we were more constrained and checked compatibility
+    // of subtypes, we might reject valid graphs.
+    // TODO(prakalps): Variant doesn't have a subtype; we assign it one, so we
+    // should only assign it one when we know the subtype. Then we can be more
+    // constrained and check subtypes for cast compatibility as well.
+    if (a.isa<mlir::TF::VariantType>()) return a;
+
+    // For Resource types, we recursively check the subtypes for cast
+    // compatibility, if possible. Otherwise treat them as compatible.
+    auto a_wst_st = a_wst.GetSubtypes();
+    auto b_wst_st = b_wst.GetSubtypes();
+    if (a_wst_st.empty() || b_wst_st.empty()) return a;
+    if (a_wst_st.size() != b_wst_st.size()) return nullptr;
+    llvm::SmallVector<mlir::TensorType, 4> refined_subtypes;
+    for (auto subtypes : llvm::zip(a_wst_st, b_wst_st)) {
+      mlir::Type refined_st =
+          GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes),
+                                /*may_ignore_ref_type_a=*/false);
+      if (!refined_st) return nullptr;
+      refined_subtypes.push_back(refined_st.cast<mlir::TensorType>());
+    }
+
+    return mlir::TF::ResourceType::get(refined_subtypes, a.getContext());
+  }
+
+  // For tensor types, check compatibility of both element type and shape.
+  mlir::Type refined_element_ty = GetCastCompatibleType(
+      a_tt.getElementType(), b_tt.getElementType(), may_ignore_ref_type_a);
+  if (!refined_element_ty) return nullptr;
+
+  if (!a_tt.hasRank() && !b_tt.hasRank()) {
+    return mlir::UnrankedTensorType::get(refined_element_ty);
+  }
+  if (!a_tt.hasRank()) {
+    return mlir::RankedTensorType::get(b_tt.getShape(), refined_element_ty);
+  }
+  if (!b_tt.hasRank()) {
+    return mlir::RankedTensorType::get(a_tt.getShape(), refined_element_ty);
+  }
+
+  llvm::SmallVector<int64_t, 8> refined_shape;
+  if (!GetCastCompatibleShape(a_tt.getShape(), b_tt.getShape(), &refined_shape))
+    return nullptr;
+
+  return mlir::RankedTensorType::get(refined_shape, refined_element_ty);
+}
+
 bool HasCompatibleElementTypes(Type lhs, Type rhs,
                                bool may_ignore_ref_type_lhs) {
   return GetCastCompatibleType(lhs, rhs, may_ignore_ref_type_lhs) != nullptr;
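
As a concrete illustration of the shape merging documented above, here is a hedged sketch that calls the function directly. The include paths and the circa-2020 signatures (FloatType::getF32, RankedTensorType::get, the mlir::TF namespace) are assumptions based on this diff, not verified against the tree; -1 is MLIR's encoding of a dynamic dimension.

#include <cassert>
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

void ShapeMergeExample() {
  mlir::MLIRContext context;
  mlir::Type f32 = mlir::FloatType::getF32(&context);

  // tensor<2x?x?xf32> and tensor<?x4x?xf32>, the example from the comment.
  mlir::Type a = mlir::RankedTensorType::get({2, -1, -1}, f32);
  mlir::Type b = mlir::RankedTensorType::get({-1, 4, -1}, f32);

  mlir::Type refined =
      mlir::TF::GetCastCompatibleType(a, b, /*may_ignore_ref_type_a=*/false);

  // Each dimension takes the more precise of the two: tensor<2x4x?xf32>.
  auto refined_tt = refined.cast<mlir::RankedTensorType>();
  assert(refined_tt.getShape()[0] == 2);   // static size from `a`
  assert(refined_tt.getShape()[1] == 4);   // static size from `b`
  assert(refined_tt.getShape()[2] == -1);  // still dynamic in both
}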

View File

@@ -272,6 +272,15 @@ class VariantType : public detail::TypeWithSubtypeImpl<VariantType> {
   static std::string getTypeName() { return "VariantType"; }
 };
 
+// Given two types `a` and `b`, returns a refined type which is cast compatible
+// with both `a` and `b` and is equal to or more precise than both of them. It
+// returns empty Type if the input types are not cast compatible.
+//
+// Provides an option to ignore ref types on `a`. This is useful for TF ops
+// that might allow operands to either be the same as the result type or be a
+// ref type corresponding to it.
+mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b,
+                                 bool may_ignore_ref_type_a);
+
 // Returns whether two arrays of Type are broadcast compatible.
 bool BroadcastCompatible(ArrayRef<Type> lhs, ArrayRef<Type> rhs);

View File

@@ -987,5 +987,66 @@ func @test_unsupported_resource_op() -> tensor<*xi32> {
   return %1 : tensor<*xi32>
 }
 
+// -----
+
+// Test type refinement. If the resource has a single subtype, check that this
+// type is used when hoisting the read. None of the result types change.
+
+// CHECK-LABEL: func @type_refinement_use_subtype
+func @type_refinement_use_subtype() -> tensor<*xi32> {
+
+  // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<4xi32>>>
+
+  // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
+  // CHECK-SAME: -> tensor<4xi32>
+  // CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster"
+  // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) : (tensor<4xi32>) -> tensor<*xi32>
+  // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]]
+  // CHECK-SAME: tensor<*xi32>, tensor<*xi32>
+  // CHECK: {cluster_attr = "cluster_attr"}
+  // CHECK-SAME: () -> (tensor<*xi32>, tensor<*xi32>)
+  // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1)
+  %1 = "tf_device.cluster"() ( {
+    %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource<tensor<4xi32>>>) -> tensor<*xi32>
+    %3 = "tf.SomeComputation"(%2) : (tensor<*xi32>) -> (tensor<*xi32>)
+    "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource<tensor<4xi32>>>, tensor<*xi32>) -> ()
+    tf_device.return %3 : tensor<*xi32>
+  }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32>
+
+  // CHECK: return %[[CLUSTER_RES]]#0
+  // CHECK-SAME: tensor<*xi32>
+  return %1 : tensor<*xi32>
+}
+
+// If multiple types are used across reads and writes, check that the hoisted
+// read uses the most refined type seen. The first ReadVariableOp refines the
+// type from tensor<*xi32> to tensor<?xi32>, and the AssignVariableOp refines
+// it further to tensor<4xi32>.
+
+// CHECK-LABEL: func @type_refinement_use_refined_type
+func @type_refinement_use_refined_type() -> tensor<4xi32> {
+
+  // CHECK: %[[RES_HANDLE:[0-9]*]] = "tf.VarHandleOp"
+  %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<*xi32>>>
+
+  // CHECK: %[[RES_READ_VAL:[0-9]*]] = "tf.ReadVariableOp"(%[[RES_HANDLE]])
+  // CHECK-SAME: -> tensor<4xi32>
+  // CHECK: %[[CLUSTER_RES:[0-9]*]]:2 = "tf_device.cluster"
+  // CHECK: %[[COMPUTE_RES:[0-9]*]] = "tf.SomeComputation"(%[[RES_READ_VAL]]) : (tensor<4xi32>) -> tensor<4xi32>
+  // CHECK: tf_device.return %[[COMPUTE_RES]], %[[COMPUTE_RES]]
+  // CHECK-SAME: tensor<4xi32>, tensor<4xi32>
+  // CHECK: {cluster_attr = "cluster_attr"}
+  // CHECK-SAME: () -> (tensor<4xi32>, tensor<4xi32>)
+  // CHECK: "tf.AssignVariableOp"(%[[RES_HANDLE]], %[[CLUSTER_RES]]#1)
+  %1 = "tf_device.cluster"() ( {
+    %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>) -> tensor<?xi32>
+    %3 = "tf.SomeComputation"(%2) : (tensor<?xi32>) -> (tensor<4xi32>)
+    "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource<tensor<*xi32>>>, tensor<4xi32>) -> ()
+    tf_device.return %3 : tensor<4xi32>
+  }) {cluster_attr = "cluster_attr"} : () -> tensor<4xi32>
+
+  // CHECK: return %[[CLUSTER_RES]]#0
+  // CHECK-SAME: tensor<4xi32>
+  return %1 : tensor<4xi32>
+}
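
The @type_refinement_use_refined_type test corresponds to the following fold, written out against the same assumed circa-2020 API as the earlier sketch: the resource subtype contributes tensor<*xi32>, the read contributes tensor<?xi32>, and the write contributes tensor<4xi32>; each step refines further, and the hoisted read is typed with the final result.

#include <cassert>
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

void RefinementFoldExample() {
  mlir::MLIRContext context;
  mlir::Type i32 = mlir::IntegerType::get(32, &context);

  const mlir::Type seen[] = {
      mlir::UnrankedTensorType::get(i32),      // resource subtype: tensor<*xi32>
      mlir::RankedTensorType::get({-1}, i32),  // ReadVariableOp:   tensor<?xi32>
      mlir::RankedTensorType::get({4}, i32),   // AssignVariableOp: tensor<4xi32>
  };

  mlir::Type data_type;  // starts null, like ResourceInfo::data_type
  for (mlir::Type t : seen) {
    data_type = data_type ? mlir::TF::GetCastCompatibleType(
                                data_type, t, /*may_ignore_ref_type_a=*/false)
                          : t;
    assert(data_type && "resource used with incompatible types");
  }

  // The most refined type wins: the hoisted read produces tensor<4xi32>.
  assert(data_type == mlir::RankedTensorType::get({4}, i32));
}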

View File

@@ -147,6 +147,16 @@ bool IsResource(Value value) {
   return getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>();
 }
 
+// Get the type of the data contained in a resource. Returns null if there is
+// no single type in the resource.
+Type GetResourceSubtype(Value value) {
+  auto resource_type =
+      getElementTypeOrSelf(value.getType()).dyn_cast<TF::ResourceType>();
+  auto subtypes = resource_type.getSubtypes();
+  if (subtypes.size() == 1) return subtypes[0];
+  return nullptr;
+}
+
 // Performs store-load forwarding. This effectively removes
 // 1) Any resource loads after a store to that same resource is done
 // 2) Any resource stores except the last one.
@@ -276,6 +286,17 @@ class RegionResourceHoister {
         result_index(-1) {}
 
     bool IsResultIndexAssigned() { return result_index != -1; }
 
+    // Refine the resource type using the given type `type`.
+    void RefineType(Type type) {
+      if (!data_type) {
+        data_type = type;
+      } else {
+        data_type = TF::GetCastCompatibleType(data_type, type,
+                                              /*may_ignore_ref_type_a=*/false);
+        assert(data_type != nullptr && "Resource used with incompatible types");
+      }
+    }
   };
 
   llvm::MapVector<Value, ResourceInfo> resources_;
   llvm::SetVector<Value> written_resources_;
@@ -310,6 +331,7 @@ LogicalResult RegionResourceHoister::Analyze() {
   for (auto resource : all_resources) {
     ResourceInfo info;
+    info.data_type = GetResourceSubtype(resource);
     llvm::BitVector written_regions(op_->getNumRegions());
     bool unsupported_use = false;
    for (OpOperand& use : resource.getUses()) {
@@ -341,13 +363,13 @@ LogicalResult RegionResourceHoister::Analyze() {
       if (read && !info.is_read) {
         info.is_read = true;
-        info.data_type = read.value().getType();
+        info.RefineType(read.value().getType());
         info.read_attrs = user->getAttrDictionary();
       }
 
       if (write) {
         info.is_written = true;
-        info.data_type = write.value().getType();
+        info.RefineType(write.value().getType());
         info.write_attrs = user->getAttrDictionary();
         written_regions.set(user->getParentRegion()->getRegionNumber());
       }
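
Finally, a standalone sketch (an illustration, not the MLIR API) of the failure mode the new assert in RefineType guards against: two uses whose static sizes disagree have no cast-compatible refinement, so GetCastCompatibleType returns null and the assertion fires.

#include <cassert>
#include <cstdint>
#include <optional>

// Merge two 1-D sizes where -1 means dynamic; nullopt means "incompatible".
std::optional<int64_t> MergeDim(int64_t a, int64_t b) {
  if (a == -1) return b;
  if (b == -1 || a == b) return a;
  return std::nullopt;  // e.g. tensor<4xi32> vs tensor<8xi32>
}

int main() {
  assert(MergeDim(-1, 4) == 4);         // tensor<?xi32> refined by tensor<4xi32>
  assert(!MergeDim(4, 8).has_value());  // incompatible: RefineType would assert
}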