Internal change

PiperOrigin-RevId: 361245058
Change-Id: Ife308d1661b0eed6befe067046623eb037f1b05f
A. Unique TensorFlower 2021-03-05 16:07:37 -08:00 committed by TensorFlower Gardener
parent d0b5e4188c
commit 081dcedc8a
3 changed files with 17 additions and 133 deletions


@@ -96,6 +96,9 @@ bool TensorFlowRefType::classof(Type type) {
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"
>();
}
bool TensorFlowTypeWithSubtype::classof(Type type) {
return type.isa<ResourceType, VariantType>();
}
TensorFlowType TensorFlowRefType::get(Type type) {
MLIRContext* ctx = type.getContext();
@@ -178,10 +181,6 @@ Type TensorFlowRefType::RemoveRef() {
llvm_unreachable("unexpected tensorflow ref type kind");
}
bool TensorFlowTypeWithSubtype::classof(Type type) {
return type.isa<ResourceType, VariantType>();
}
Type TensorFlowTypeWithSubtype::RemoveSubtypes() {
MLIRContext* ctx = getContext();
if (isa<VariantType>()) return VariantType::get(ctx);
@@ -189,18 +188,6 @@ Type TensorFlowTypeWithSubtype::RemoveSubtypes() {
llvm_unreachable("unexpected tensorflow type with subtypes kind");
}
TensorFlowTypeWithSubtype TensorFlowTypeWithSubtype::clone(
ArrayRef<TensorType> new_subtypes) {
MLIRContext* ctx = getContext();
if (isa<VariantType>())
return VariantType::get(new_subtypes, ctx)
.cast<TensorFlowTypeWithSubtype>();
if (isa<ResourceType>())
return ResourceType::get(new_subtypes, ctx)
.cast<TensorFlowTypeWithSubtype>();
llvm_unreachable("unexpected tensorflow type with subtypes kind");
}
ArrayRef<TensorType> TensorFlowTypeWithSubtype::GetSubtypes() {
if (auto variant_type = dyn_cast<VariantType>())
return variant_type.getSubtypes();


@@ -246,9 +246,6 @@ class TensorFlowTypeWithSubtype : public TensorFlowType {
// Converts a TypeWithSubtype type to the same type but without its subtypes.
Type RemoveSubtypes();
// Clone the current Type with new subtypes.
TensorFlowTypeWithSubtype clone(ArrayRef<TensorType> new_subtypes);
// Returns the subtypes.
ArrayRef<TensorType> GetSubtypes();
};
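As context for the interface declared above, here is a minimal, hypothetical usage sketch; it is not part of this change and assumes the tf_types.h header is included and the TF dialect types are registered:

// Hypothetical helper, for illustration only: via the classof shown in the
// first file above, isa/dyn_cast treat both !tf.resource and !tf.variant as
// TensorFlowTypeWithSubtype, so any carried tensor subtypes can be inspected
// uniformly.
bool HasKnownSubtype(mlir::Type type) {
  auto with_subtype = type.dyn_cast<mlir::TF::TensorFlowTypeWithSubtype>();
  if (!with_subtype) return false;  // Only resource/variant types carry subtypes.
  // GetSubtypes() exposes the wrapped tensor types; RemoveSubtypes() would
  // drop them, yielding a plain !tf.resource or !tf.variant.
  return !with_subtype.GetSubtypes().empty();
}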


@@ -94,120 +94,20 @@ bool CanBeRefined(Type type) {
return !shape_type.hasStaticShape();
}
// Computes a refined type between two types `lhs` and `rhs`; the result type
// is always at least as refined (i.e. has as much or more static information)
// as `lhs`. This method merges the information contained in both types; it is
// capable of refining:
// tensor<!tf.variant<tensor<?x8xf32>>>
// and:
// tensor<!tf.variant<tensor<10x?xf32>>>
// into:
// tensor<!tf.variant<tensor<10x8xf32>>>
//
// In case of inconsistencies (rank disagreement for example), it returns `lhs`.
Type TypeMeet(Type lhs, Type rhs) {
DCOMMENT("RefineTypeWith : " << lhs << " : " << rhs);
if (lhs == rhs) return lhs;
auto rhs_shape_type = rhs.dyn_cast<ShapedType>();
if (!rhs_shape_type) return lhs;
auto lhs_shape_type = lhs.cast<ShapedType>();
if (lhs_shape_type.hasRank() && rhs_shape_type.hasRank() &&
lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
DCOMMENT("Unexpected rank mismatch: " << lhs << " vs " << rhs);
return lhs;
}
SmallVector<int64_t> shape;
bool refined_shape = false;
// Build the shape of the refined type: if `lhs` is unranked, the shape of
// `rhs` is used directly for the refined type; otherwise we merge the two
// shapes by taking the most specialized dimension from each. This combines
// `10x?x?` and `?x?x8` into `10x?x8`.
if (!lhs_shape_type.hasRank()) {
if (rhs_shape_type.hasRank()) {
shape.append(rhs_shape_type.getShape().begin(),
rhs_shape_type.getShape().end());
refined_shape = true;
}
} else if (rhs_shape_type.hasRank()) {
for (auto shape_elts : llvm::enumerate(
llvm::zip(lhs_shape_type.getShape(), rhs_shape_type.getShape()))) {
if (ShapedType::isDynamic(std::get<0>(shape_elts.value())) &&
!ShapedType::isDynamic(std::get<1>(shape_elts.value()))) {
shape.push_back(std::get<1>(shape_elts.value()));
refined_shape = true;
DCOMMENT("-> refining shape element #" << shape_elts.index());
} else {
DCOMMENT("-> not refining shape element #" << shape_elts.index());
shape.push_back(std::get<0>(shape_elts.value()));
}
}
}
// Some tensors have an element type wrapping a subtensor, like resources and
// variants. In this case we may recurse on the wrapped subtype.
// `lhs_element_type` will contain the refined inferred element type for the
// returned type.
auto lhs_element_type = lhs_shape_type.getElementType();
auto rhs_element_type_with_subtype =
rhs_shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
// Look for a resource or variant element type and ensure we refine the
// subtype. We only support a single subtype at the moment; we won't handle
// something like:
// tensor<!tf.variant<tensor<10xf32>, tensor<8xf32>>>
if (rhs_element_type_with_subtype &&
rhs_element_type_with_subtype.GetSubtypes().size() == 1) {
auto lhs_element_type_with_subtype =
lhs_element_type.dyn_cast<TF::TensorFlowTypeWithSubtype>();
TensorType subtype;
if (!lhs_element_type_with_subtype) {
DCOMMENT(
"Unexpected inferred `TensorFlowTypeWithSubtype` when original "
"result isn't");
} else if (lhs_element_type_with_subtype.GetSubtypes().size() > 1) {
DCOMMENT(
"Unexpected `TensorFlowTypeWithSubtype` original type with size>1");
} else if (lhs_element_type_with_subtype.GetSubtypes().empty()) {
subtype = rhs_element_type_with_subtype.GetSubtypes().front();
} else {
// Recurse on the subtypes in the variant/resource. For example, if the
// inputs were:
// tensor<!tf.variant<tensor<?x8xf32>>>
// and:
// tensor<!tf.variant<tensor<10x8xf32>>>
// we'll try here to refine tensor<?x8xf32> with tensor<10x8xf32>.
auto refined_subtype =
TypeMeet(lhs_element_type_with_subtype.GetSubtypes().front(),
rhs_element_type_with_subtype.GetSubtypes().front())
.cast<TensorType>();
if (refined_subtype !=
lhs_element_type_with_subtype.GetSubtypes().front())
subtype = refined_subtype;
}
// If we managed to refine the subtype, recreate the element type itself
// (i.e. the tf.variant or tf.resource).
if (subtype) {
lhs_element_type = lhs_element_type_with_subtype.clone({subtype});
}
}
if (refined_shape || lhs_element_type != lhs_shape_type.getElementType()) {
Type new_type;
if (!lhs_shape_type.hasRank() && !rhs_shape_type.hasRank())
new_type = UnrankedTensorType::get(lhs_element_type);
else
new_type = lhs_shape_type.clone(shape, lhs_element_type);
DCOMMENT("Refined to: " << new_type);
return new_type;
}
DCOMMENT("No refinement " << lhs);
return lhs;
}
// Returns whether `original_type` can be refined with
// `potential_refined_type`.
bool CanRefineTypeWith(Type original_type, Type potential_refined_type) {
return original_type != TypeMeet(original_type, potential_refined_type);
if (original_type == potential_refined_type || !CanBeRefined(original_type))
return false;
auto shape_type = potential_refined_type.dyn_cast<ShapedType>();
if (!shape_type) return false;
if (shape_type.hasRank()) return true;
auto element_type_with_subtype =
shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
return element_type_with_subtype &&
!element_type_with_subtype.GetSubtypes().empty();
}
// Returns whether the shape inference pass supports an op outside the TF dialect.
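As an aside, the shape-merging behaviour documented in the TypeMeet comments above can be illustrated with a small, self-contained sketch. The helper below is hypothetical (it is not part of TensorFlow or of this change) and models a dynamic dimension as -1, in the spirit of ShapedType::isDynamic:

#include <cstdint>
#include <vector>

// Hypothetical illustration of the shape "meet": take the static dimension
// wherever one side is dynamic (-1); on a rank mismatch, fall back to `lhs`,
// mirroring the behaviour described in the TypeMeet comments above.
std::vector<int64_t> MeetShapes(const std::vector<int64_t>& lhs,
                                const std::vector<int64_t>& rhs) {
  if (lhs.size() != rhs.size()) return lhs;
  std::vector<int64_t> merged;
  merged.reserve(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i) {
    // Prefer the static dimension; keep lhs when both sides are static.
    merged.push_back(lhs[i] == -1 ? rhs[i] : lhs[i]);
  }
  return merged;
}

// MeetShapes({10, -1, -1}, {-1, -1, 8}) yields {10, -1, 8}; that is,
// 10x?x? and ?x?x8 combine into 10x?x8, as in the comment above.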
@@ -799,7 +699,7 @@ bool ShapeInference::InferShapeForCall(CallOpInterface call_op) {
FuncOp func = dyn_cast<FuncOp>(call_op.resolveCallable());
if (!func) return false;
DCOMMENT("Infer shape for call " << func.getName());
LLVM_DEBUG(llvm::dbgs() << "Infer shape for call " << func.getName());
Operation* op = call_op.getOperation();
bool changed = false;
// Map each of the results of the call to the returned type of the
@@ -808,7 +708,7 @@ bool ShapeInference::InferShapeForCall(CallOpInterface call_op) {
changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
changed;
}
DCOMMENT(" - call " << func.getName() << "changed ? " << changed << "\n");
LLVM_DEBUG(llvm::dbgs() << " changed ? " << changed << "\n");
return changed;
}
@@ -885,7 +785,7 @@ bool ShapeInference::InferShapeForTensorListInitOps(Operation* op) {
DCOMMENT("InferShapeForListInitOps " << op << " could not infer");
return false;
}
DCOMMENT("InferShapeForListInitOps " << *op << " could be inferred "
DCOMMENT("InferShapeForListInitOps " << op << " could be inferred "
<< element_type);
if (!element_type || !element_type.hasStaticShape()) return false;
auto variant_type = VariantType::get(element_type, op->getContext());