Internal change
PiperOrigin-RevId: 361245058 Change-Id: Ife308d1661b0eed6befe067046623eb037f1b05f
This commit is contained in:
parent
d0b5e4188c
commit
081dcedc8a
@ -96,6 +96,9 @@ bool TensorFlowRefType::classof(Type type) {
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
|
||||
>();
|
||||
}
|
||||
bool TensorFlowTypeWithSubtype::classof(Type type) {
|
||||
return type.isa<ResourceType, VariantType>();
|
||||
}
|
||||
|
||||
TensorFlowType TensorFlowRefType::get(Type type) {
|
||||
MLIRContext* ctx = type.getContext();
|
||||
@ -178,10 +181,6 @@ Type TensorFlowRefType::RemoveRef() {
|
||||
llvm_unreachable("unexpected tensorflow ref type kind");
|
||||
}
|
||||
|
||||
bool TensorFlowTypeWithSubtype::classof(Type type) {
|
||||
return type.isa<ResourceType, VariantType>();
|
||||
}
|
||||
|
||||
Type TensorFlowTypeWithSubtype::RemoveSubtypes() {
|
||||
MLIRContext* ctx = getContext();
|
||||
if (isa<VariantType>()) return VariantType::get(ctx);
|
||||
@ -189,18 +188,6 @@ Type TensorFlowTypeWithSubtype::RemoveSubtypes() {
|
||||
llvm_unreachable("unexpected tensorflow type with subtypes kind");
|
||||
}
|
||||
|
||||
TensorFlowTypeWithSubtype TensorFlowTypeWithSubtype::clone(
|
||||
ArrayRef<TensorType> new_subtypes) {
|
||||
MLIRContext* ctx = getContext();
|
||||
if (isa<VariantType>())
|
||||
return VariantType::get(new_subtypes, ctx)
|
||||
.cast<TensorFlowTypeWithSubtype>();
|
||||
if (isa<ResourceType>())
|
||||
return ResourceType::get(new_subtypes, ctx)
|
||||
.cast<TensorFlowTypeWithSubtype>();
|
||||
llvm_unreachable("unexpected tensorflow type with subtypes kind");
|
||||
}
|
||||
|
||||
ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes() {
|
||||
if (auto variant_type = dyn_cast<VariantType>())
|
||||
return variant_type.getSubtypes();
|
||||
|
@ -246,9 +246,6 @@ class TensorFlowTypeWithSubtype : public TensorFlowType {
|
||||
// Converts a TypeWithSubtype type to the same type but without its subtypes.
|
||||
Type RemoveSubtypes();
|
||||
|
||||
// Clone the current Type with new subtypes.
|
||||
TensorFlowTypeWithSubtype clone(ArrayRef<TensorType> new_subtypes);
|
||||
|
||||
// Returns the subtypes.
|
||||
ArrayRef<TensorType> GetSubtypes();
|
||||
};
|
||||
|
@ -94,120 +94,20 @@ bool CanBeRefined(Type type) {
|
||||
return !shape_type.hasStaticShape();
|
||||
}
|
||||
|
||||
// Compute a refined type between two types `lhs` and `rhs`, the result type
|
||||
// is always more refined (i.e. has more static information) than `lhs`
|
||||
// This method will actually merge the information contained in the
|
||||
// types, it is capable of refining:
|
||||
// tensor<!tf.variant<tensor<?x8xf32>>>
|
||||
// and:
|
||||
// tensor<!tf.variant<tensor<10x?xf32>>>
|
||||
// into:
|
||||
// tensor<!tf.variant<tensor<10x8xf32>>>
|
||||
//
|
||||
// In case of inconsistencies (rank disagreement for example), it returns `lhs`.
|
||||
Type TypeMeet(Type lhs, Type rhs) {
|
||||
DCOMMENT("RefineTypeWith : " << lhs << " : " << rhs);
|
||||
if (lhs == rhs) return lhs;
|
||||
|
||||
auto rhs_shape_type = rhs.dyn_cast<ShapedType>();
|
||||
if (!rhs_shape_type) return lhs;
|
||||
auto lhs_shape_type = lhs.cast<ShapedType>();
|
||||
if (lhs_shape_type.hasRank() && rhs_shape_type.hasRank() &&
|
||||
lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
|
||||
DCOMMENT("Unexpected rank mismatch: " << lhs << " vs " << rhs);
|
||||
return lhs;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> shape;
|
||||
bool refined_shape = false;
|
||||
// Build the shape of the refined type, if lhs is unranked it
|
||||
// will be directly the shape of the refined type, otherwise we merged by
|
||||
// taking the most specialized. This combines `10x?x?` and `?x?x8` into
|
||||
// `10x?x8`.
|
||||
if (!lhs_shape_type.hasRank()) {
|
||||
if (rhs_shape_type.hasRank()) {
|
||||
shape.append(rhs_shape_type.getShape().begin(),
|
||||
rhs_shape_type.getShape().end());
|
||||
refined_shape = true;
|
||||
}
|
||||
} else if (rhs_shape_type.hasRank()) {
|
||||
for (auto shape_elts : llvm::enumerate(
|
||||
llvm::zip(lhs_shape_type.getShape(), rhs_shape_type.getShape()))) {
|
||||
if (ShapedType::isDynamic(std::get<0>(shape_elts.value())) &&
|
||||
!ShapedType::isDynamic(std::get<1>(shape_elts.value()))) {
|
||||
shape.push_back(std::get<1>(shape_elts.value()));
|
||||
refined_shape = true;
|
||||
DCOMMENT("-> refining shape element #" << shape_elts.index());
|
||||
} else {
|
||||
DCOMMENT("-> not refining shape element #" << shape_elts.index());
|
||||
shape.push_back(std::get<0>(shape_elts.value()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Some tensor have an element type wrapping a subtensor, like resource and
|
||||
// variants. In this case we may recurse on the wrapped subtype.
|
||||
// `element_type` will contain the refined inferred element type for the
|
||||
// returned type.
|
||||
auto lhs_element_type = lhs_shape_type.getElementType();
|
||||
auto rhs_element_type_with_subtype =
|
||||
rhs_shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
|
||||
// Look for resource or variant element type and ensure we refine the subtype.
|
||||
// We only support a single subtype at the moment, we won't handle something
|
||||
// like:
|
||||
// tensor<!tf.variant<tensor<10xf32>, tensor<8xf32>>
|
||||
if (rhs_element_type_with_subtype &&
|
||||
rhs_element_type_with_subtype.GetSubtypes().size() == 1) {
|
||||
auto lhs_element_type_with_subtype =
|
||||
lhs_element_type.dyn_cast<TF::TensorFlowTypeWithSubtype>();
|
||||
TensorType subtype;
|
||||
if (!lhs_element_type_with_subtype) {
|
||||
DCOMMENT(
|
||||
"Unexpected inferred `TensorFlowTypeWithSubtype` when original "
|
||||
"result isn't");
|
||||
} else if (lhs_element_type_with_subtype.GetSubtypes().size() > 1) {
|
||||
DCOMMENT(
|
||||
"Unexpected `TensorFlowTypeWithSubtype` original type with size>1");
|
||||
} else if (lhs_element_type_with_subtype.GetSubtypes().empty()) {
|
||||
subtype = rhs_element_type_with_subtype.GetSubtypes().front();
|
||||
} else {
|
||||
// Recurse on the subtypes in the variant/resource. Basically if the input
|
||||
// were:
|
||||
// tensor<!tf.variant<tensor<?x8xf32>>>
|
||||
// and:
|
||||
// tensor<!tf.variant<tensor<10x8xf32>>>
|
||||
// we'll try here to refine tensor<?x8xf32> with tensor<10x8xf32>.
|
||||
auto refined_subtype =
|
||||
TypeMeet(lhs_element_type_with_subtype.GetSubtypes().front(),
|
||||
rhs_element_type_with_subtype.GetSubtypes().front())
|
||||
.cast<TensorType>();
|
||||
if (refined_subtype !=
|
||||
lhs_element_type_with_subtype.GetSubtypes().front())
|
||||
subtype = refined_subtype;
|
||||
}
|
||||
// If we managed to refine the subtype, recreate the element type itself
|
||||
// (i.e. the tf.variant or tf.resource).
|
||||
if (subtype) {
|
||||
lhs_element_type = lhs_element_type_with_subtype.clone({subtype});
|
||||
}
|
||||
}
|
||||
if (refined_shape || lhs_element_type != lhs_shape_type.getElementType()) {
|
||||
Type new_type;
|
||||
if (!lhs_shape_type.hasRank() && !rhs_shape_type.hasRank())
|
||||
new_type = UnrankedTensorType::get(lhs_element_type);
|
||||
else
|
||||
new_type = lhs_shape_type.clone(shape, lhs_element_type);
|
||||
DCOMMENT("Refined to: " << new_type);
|
||||
return new_type;
|
||||
}
|
||||
DCOMMENT("No refinement " << lhs);
|
||||
return lhs;
|
||||
}
|
||||
|
||||
// Returns whether `original_type` type can be refined with
|
||||
// `potential_refined_type` type.
|
||||
bool CanRefineTypeWith(Type original_type, Type potential_refined_type) {
|
||||
return original_type != TypeMeet(original_type, potential_refined_type);
|
||||
if (original_type == potential_refined_type || !CanBeRefined(original_type))
|
||||
return false;
|
||||
|
||||
auto shape_type = potential_refined_type.dyn_cast<ShapedType>();
|
||||
if (!shape_type) return false;
|
||||
if (shape_type.hasRank()) return true;
|
||||
|
||||
auto element_type_with_subtype =
|
||||
shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
|
||||
return element_type_with_subtype &&
|
||||
!element_type_with_subtype.GetSubtypes().empty();
|
||||
}
|
||||
|
||||
// Returns if the shape inference pass supports an op outside the TF dialect.
|
||||
@ -799,7 +699,7 @@ bool ShapeInference::InferShapeForCall(CallOpInterface call_op) {
|
||||
FuncOp func = dyn_cast<FuncOp>(call_op.resolveCallable());
|
||||
if (!func) return false;
|
||||
|
||||
DCOMMENT("Infer shape for call " << func.getName());
|
||||
LLVM_DEBUG(llvm::dbgs() << "Infer shape for call " << func.getName());
|
||||
Operation* op = call_op.getOperation();
|
||||
bool changed = false;
|
||||
// Map each of the results of the call to the returned type of the
|
||||
@ -808,7 +708,7 @@ bool ShapeInference::InferShapeForCall(CallOpInterface call_op) {
|
||||
changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
|
||||
changed;
|
||||
}
|
||||
DCOMMENT(" - call " << func.getName() << "changed ? " << changed << "\n");
|
||||
LLVM_DEBUG(llvm::dbgs() << " changed ? " << changed << "\n");
|
||||
|
||||
return changed;
|
||||
}
|
||||
@ -885,7 +785,7 @@ bool ShapeInference::InferShapeForTensorListInitOps(Operation* op) {
|
||||
DCOMMENT("InferShapeForListInitOps " << op << " could not infer");
|
||||
return false;
|
||||
}
|
||||
DCOMMENT("InferShapeForListInitOps " << *op << " could be inferred "
|
||||
DCOMMENT("InferShapeForListInitOps " << op << " could be inferred "
|
||||
<< element_type);
|
||||
if (!element_type || !element_type.hasStaticShape()) return false;
|
||||
auto variant_type = VariantType::get(element_type, op->getContext());
|
||||
|
Loading…
Reference in New Issue
Block a user