diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc index d58b1cd74c4..cc7ce6b39dc 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc @@ -96,6 +96,9 @@ bool TensorFlowRefType::classof(Type type) { #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" >(); } +bool TensorFlowTypeWithSubtype::classof(Type type) { + return type.isa<ResourceType, VariantType>(); +} TensorFlowType TensorFlowRefType::get(Type type) { MLIRContext* ctx = type.getContext(); @@ -178,10 +181,6 @@ Type TensorFlowRefType::RemoveRef() { llvm_unreachable("unexpected tensorflow ref type kind"); } -bool TensorFlowTypeWithSubtype::classof(Type type) { - return type.isa<ResourceType, VariantType>(); -} - Type TensorFlowTypeWithSubtype::RemoveSubtypes() { MLIRContext* ctx = getContext(); if (isa<VariantType>()) return VariantType::get(ctx); @@ -189,18 +188,6 @@ Type TensorFlowTypeWithSubtype::RemoveSubtypes() { llvm_unreachable("unexpected tensorflow type with subtypes kind"); } -TensorFlowTypeWithSubtype TensorFlowTypeWithSubtype::clone( - ArrayRef<TensorType> new_subtypes) { - MLIRContext* ctx = getContext(); - if (isa<VariantType>()) - return VariantType::get(new_subtypes, ctx) - .cast<TensorFlowTypeWithSubtype>(); - if (isa<ResourceType>()) - return ResourceType::get(new_subtypes, ctx) - .cast<TensorFlowTypeWithSubtype>(); - llvm_unreachable("unexpected tensorflow type with subtypes kind"); -} - ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes() { if (auto variant_type = dyn_cast<VariantType>()) return variant_type.getSubtypes(); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h index 0ba3201b95d..57ff9dce272 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h @@ -246,9 +246,6 @@ class TensorFlowTypeWithSubtype : public TensorFlowType { // Converts a TypeWithSubtype type to the same type but without its subtypes. Type RemoveSubtypes(); - // Clone the current Type with new subtypes. - TensorFlowTypeWithSubtype clone(ArrayRef<TensorType> new_subtypes); - // Returns the subtypes. ArrayRef<TensorType> GetSubtypes(); }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 9b43e3dcd17..9a2f660c74a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -94,120 +94,20 @@ bool CanBeRefined(Type type) { return !shape_type.hasStaticShape(); } -// Compute a refined type between two types `lhs` and `rhs`, the result type -// is always more refined (i.e. has more static information) than `lhs` -// This method will actually merge the information contained in the -// types, it is capable of refining: -// tensor<!tf.variant<tensor<?x8xf32>>> -// and: -// tensor<!tf.variant<tensor<10x?xf32>>> -// into: -// tensor<!tf.variant<tensor<10x8xf32>>> -// -// In case of inconsistencies (rank disagreement for example), it returns `lhs`. -Type TypeMeet(Type lhs, Type rhs) { - DCOMMENT("RefineTypeWith : " << lhs << " : " << rhs); - if (lhs == rhs) return lhs; - - auto rhs_shape_type = rhs.dyn_cast<ShapedType>(); - if (!rhs_shape_type) return lhs; - auto lhs_shape_type = lhs.cast<ShapedType>(); - if (lhs_shape_type.hasRank() && rhs_shape_type.hasRank() && - lhs_shape_type.getRank() != rhs_shape_type.getRank()) { - DCOMMENT("Unexpected rank mismatch: " << lhs << " vs " << rhs); - return lhs; - } - - SmallVector<int64_t> shape; - bool refined_shape = false; - // Build the shape of the refined type, if lhs is unranked it - // will be directly the shape of the refined type, otherwise we merged by - // taking the most specialized. This combines `10x?x?` and `?x?x8` into - // `10x?x8`. - if (!lhs_shape_type.hasRank()) { - if (rhs_shape_type.hasRank()) { - shape.append(rhs_shape_type.getShape().begin(), - rhs_shape_type.getShape().end()); - refined_shape = true; - } - } else if (rhs_shape_type.hasRank()) { - for (auto shape_elts : llvm::enumerate( - llvm::zip(lhs_shape_type.getShape(), rhs_shape_type.getShape()))) { - if (ShapedType::isDynamic(std::get<0>(shape_elts.value())) && - !ShapedType::isDynamic(std::get<1>(shape_elts.value()))) { - shape.push_back(std::get<1>(shape_elts.value())); - refined_shape = true; - DCOMMENT("-> refining shape element #" << shape_elts.index()); - } else { - DCOMMENT("-> not refining shape element #" << shape_elts.index()); - shape.push_back(std::get<0>(shape_elts.value())); - } - } - } - - // Some tensor have an element type wrapping a subtensor, like resource and - // variants. In this case we may recurse on the wrapped subtype. - // `element_type` will contain the refined inferred element type for the - // returned type. - auto lhs_element_type = lhs_shape_type.getElementType(); - auto rhs_element_type_with_subtype = - rhs_shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>(); - // Look for resource or variant element type and ensure we refine the subtype. - // We only support a single subtype at the moment, we won't handle something - // like: - // tensor<!tf.variant<tensor<10xf32>, tensor<8xf32>> - if (rhs_element_type_with_subtype && - rhs_element_type_with_subtype.GetSubtypes().size() == 1) { - auto lhs_element_type_with_subtype = - lhs_element_type.dyn_cast<TF::TensorFlowTypeWithSubtype>(); - TensorType subtype; - if (!lhs_element_type_with_subtype) { - DCOMMENT( - "Unexpected inferred `TensorFlowTypeWithSubtype` when original " - "result isn't"); - } else if (lhs_element_type_with_subtype.GetSubtypes().size() > 1) { - DCOMMENT( - "Unexpected `TensorFlowTypeWithSubtype` original type with size>1"); - } else if (lhs_element_type_with_subtype.GetSubtypes().empty()) { - subtype = rhs_element_type_with_subtype.GetSubtypes().front(); - } else { - // Recurse on the subtypes in the variant/resource. Basically if the input - // were: - // tensor<!tf.variant<tensor<?x8xf32>>> - // and: - // tensor<!tf.variant<tensor<10x8xf32>>> - // we'll try here to refine tensor<?x8xf32> with tensor<10x8xf32>. - auto refined_subtype = - TypeMeet(lhs_element_type_with_subtype.GetSubtypes().front(), - rhs_element_type_with_subtype.GetSubtypes().front()) - .cast<TensorType>(); - if (refined_subtype != - lhs_element_type_with_subtype.GetSubtypes().front()) - subtype = refined_subtype; - } - // If we managed to refine the subtype, recreate the element type itself - // (i.e. the tf.variant or tf.resource). - if (subtype) { - lhs_element_type = lhs_element_type_with_subtype.clone({subtype}); - } - } - if (refined_shape || lhs_element_type != lhs_shape_type.getElementType()) { - Type new_type; - if (!lhs_shape_type.hasRank() && !rhs_shape_type.hasRank()) - new_type = UnrankedTensorType::get(lhs_element_type); - else - new_type = lhs_shape_type.clone(shape, lhs_element_type); - DCOMMENT("Refined to: " << new_type); - return new_type; - } - DCOMMENT("No refinement " << lhs); - return lhs; -} - // Returns whether `original_type` type can be refined with // `potential_refined_type` type. bool CanRefineTypeWith(Type original_type, Type potential_refined_type) { - return original_type != TypeMeet(original_type, potential_refined_type); + if (original_type == potential_refined_type || !CanBeRefined(original_type)) + return false; + + auto shape_type = potential_refined_type.dyn_cast<ShapedType>(); + if (!shape_type) return false; + if (shape_type.hasRank()) return true; + + auto element_type_with_subtype = + shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>(); + return element_type_with_subtype && + !element_type_with_subtype.GetSubtypes().empty(); } // Returns if the shape inference pass supports an op outside the TF dialect. @@ -799,7 +699,7 @@ bool ShapeInference::InferShapeForCall(CallOpInterface call_op) { FuncOp func = dyn_cast<FuncOp>(call_op.resolveCallable()); if (!func) return false; - DCOMMENT("Infer shape for call " << func.getName()); + LLVM_DEBUG(llvm::dbgs() << "Infer shape for call " << func.getName()); Operation* op = call_op.getOperation(); bool changed = false; // Map each of the results of the call to the returned type of the @@ -808,7 +708,7 @@ bool ShapeInference::InferShapeForCall(CallOpInterface call_op) { changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) || changed; } - DCOMMENT(" - call " << func.getName() << "changed ? " << changed << "\n"); + LLVM_DEBUG(llvm::dbgs() << " changed ? " << changed << "\n"); return changed; } @@ -885,7 +785,7 @@ bool ShapeInference::InferShapeForTensorListInitOps(Operation* op) { DCOMMENT("InferShapeForListInitOps " << op << " could not infer"); return false; } - DCOMMENT("InferShapeForListInitOps " << *op << " could be inferred " + DCOMMENT("InferShapeForListInitOps " << op << " could be inferred " << element_type); if (!element_type || !element_type.hasStaticShape()) return false; auto variant_type = VariantType::get(element_type, op->getContext());