[MLIR][NFC] Adopt variadic isa<> on MLIR Types and Attributes

- Also adopt variadic llvm::isa<> in more places

PiperOrigin-RevId: 320206113
Change-Id: Ia03a1503f699fb6be6dff02e90b6630d6d894b19
This commit is contained in:
Rahul Joshi 2020-07-08 09:46:21 -07:00 committed by TensorFlower Gardener
parent 4e1aa305a1
commit e2f8269f17
15 changed files with 24 additions and 35 deletions

View File

@ -443,8 +443,7 @@ StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) { if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
TF_ASSIGN_OR_RETURN(value, TF_ASSIGN_OR_RETURN(value,
ConvertFloatBuffer(shaped_type, float_type, buffer)); ConvertFloatBuffer(shaped_type, float_type, buffer));
} else if (elem_type.isa<mlir::IntegerType>() || } else if (elem_type.isa<mlir::IntegerType, QuantizedType>()) {
elem_type.isa<QuantizedType>()) {
TF_ASSIGN_OR_RETURN(value, TF_ASSIGN_OR_RETURN(value,
ConvertIntBuffer(shaped_type, elem_type, buffer)); ConvertIntBuffer(shaped_type, elem_type, buffer));
} else if (elem_type.isa<mlir::TF::StringType>()) { } else if (elem_type.isa<mlir::TF::StringType>()) {
@ -456,8 +455,7 @@ StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
refs.push_back({ref.data(), ref.size()}); refs.push_back({ref.data(), ref.size()});
value = mlir::DenseStringElementsAttr::get(shaped_type, refs); value = mlir::DenseStringElementsAttr::get(shaped_type, refs);
} else if (elem_type.isa<mlir::ComplexType>() || } else if (elem_type.isa<mlir::ComplexType, mlir::TF::TensorFlowType>()) {
elem_type.isa<mlir::TF::TensorFlowType>()) {
auto dialect = elem_type.getContext()->getRegisteredDialect("tf"); auto dialect = elem_type.getContext()->getRegisteredDialect("tf");
tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer); tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
std::string mangled = tensorflow::mangling_util::MangleTensor(repr); std::string mangled = tensorflow::mangling_util::MangleTensor(repr);

View File

@ -694,8 +694,7 @@ void QuantizationDriver::SetupAllStates() {
fn_.walk([&](Operation *op) { fn_.walk([&](Operation *op) {
if (op->isKnownTerminator() || if (op->isKnownTerminator() ||
op->hasTrait<OpTrait::quant::NoQuantizableResult>() || op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
llvm::isa<quant::DequantizeCastOp>(op) || llvm::isa<quant::DequantizeCastOp, quant::QuantizeCastOp>(op))
llvm::isa<quant::QuantizeCastOp>(op))
return; return;
work_list_.push_back(op); work_list_.push_back(op);

View File

@ -386,8 +386,7 @@ struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
Operation* def = pre_quantized.getDefiningOp(); Operation* def = pre_quantized.getDefiningOp();
if (!def) return failure(); if (!def) return failure();
if (llvm::isa<FixedOutputRangeInterface>(def) || if (llvm::isa<FixedOutputRangeInterface, SameScalesOpInterface>(def) ||
llvm::isa<SameScalesOpInterface>(def) ||
def->hasTrait<OpTrait::quant::NoQuantizableResult>()) { def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
return failure(); return failure();
} }

View File

@ -560,7 +560,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
return failure(); return failure();
ShapedType filter_type = filter_cst.getType(); ShapedType filter_type = filter_cst.getType();
if (llvm::isa<AddOp>(binary_op) || llvm::isa<SubOp>(binary_op)) { if (llvm::isa<AddOp, SubOp>(binary_op)) {
auto padding = fc_op.template getAttrOfType<StringAttr>("padding"); auto padding = fc_op.template getAttrOfType<StringAttr>("padding");
if (padding && padding.getValue() != "VALID") return failure(); if (padding && padding.getValue() != "VALID") return failure();
@ -606,7 +606,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
rewriter.create<ConstOp>(fc_op.getLoc(), new_bias_type, new_bias); rewriter.create<ConstOp>(fc_op.getLoc(), new_bias_type, new_bias);
fc_op.setOperand(0, binary_op->getOperand(0)); fc_op.setOperand(0, binary_op->getOperand(0));
fc_op.setOperand(2, new_bias_op); fc_op.setOperand(2, new_bias_op);
} else if (llvm::isa<MulOp>(binary_op) || llvm::isa<DivOp>(binary_op)) { } else if (llvm::isa<MulOp, DivOp>(binary_op)) {
// The fusion of mul/div is actually applying the following // The fusion of mul/div is actually applying the following
// transformation: // transformation:
// w * (x ' c) + b => (w ' c) x + b // w * (x ' c) + b => (w ' c) x + b

View File

@ -168,8 +168,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
var_handle.resource(), var_handle.resource(),
GetOrCreateIdForVarHandle(var_handle, &next_unique_id, GetOrCreateIdForVarHandle(var_handle, &next_unique_id,
&var_handle_name_id_map)); &var_handle_name_id_map));
} else if (llvm::isa<TF::IdentityNOp>(op) || } else if (llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op)) {
llvm::isa<TF::IdentityOp>(op)) {
for (auto operand_and_result : for (auto operand_and_result :
llvm::zip(op->getOperands(), op->getResults())) { llvm::zip(op->getOperands(), op->getResults())) {
forward_input_to_output(std::get<0>(operand_and_result), forward_input_to_output(std::get<0>(operand_and_result),
@ -333,7 +332,7 @@ bool OpIsDeclaration(Operation* op,
const ResourceAliasAnalysis& alias_analysis) { const ResourceAliasAnalysis& alias_analysis) {
// TODO(yuanzx): Add other types of resources. // TODO(yuanzx): Add other types of resources.
return llvm::isa<TF::VarHandleOp>(op) || return llvm::isa<TF::VarHandleOp>(op) ||
((llvm::isa<TF::IdentityNOp>(op) || llvm::isa<TF::IdentityOp>(op)) && (llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op) &&
!FindAccessedResources(op, alias_analysis).empty()); !FindAccessedResources(op, alias_analysis).empty());
} }

View File

@ -71,7 +71,7 @@ struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
// Allow inlining into tf.island regions if the incoming region has a single // Allow inlining into tf.island regions if the incoming region has a single
// block. // block.
return llvm::isa<tf_executor::IslandOp>(dest->getParentOp()) && return llvm::isa<tf_executor::IslandOp>(dest->getParentOp()) &&
std::next(src->begin()) == src->end(); llvm::hasSingleElement(*src);
} }
}; };

View File

@ -1168,8 +1168,7 @@ void ConstOp::build(OpBuilder &builder, OperationState &result,
ShapedType type; ShapedType type;
if (auto elem_attr = value.dyn_cast<ElementsAttr>()) { if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
return ConstOp::build(builder, result, elem_attr); return ConstOp::build(builder, result, elem_attr);
} else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() || } else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) {
value.isa<IntegerAttr>()) {
// All TensorFlow types must be tensor types. In the build() method, // All TensorFlow types must be tensor types. In the build() method,
// we want to provide more flexibility by allowing attributes of scalar // we want to provide more flexibility by allowing attributes of scalar
// types. But we need to wrap it up with ElementsAttr to construct // types. But we need to wrap it up with ElementsAttr to construct

View File

@ -356,7 +356,7 @@ LogicalResult VerifyExportedFunc(FuncOp func) {
LogicalResult TensorFlowSavedModelDialect::verifyOperationAttribute( LogicalResult TensorFlowSavedModelDialect::verifyOperationAttribute(
Operation *op, NamedAttribute named_attr) { Operation *op, NamedAttribute named_attr) {
if (named_attr.first == "tf_saved_model.exported_names") { if (named_attr.first == "tf_saved_model.exported_names") {
if (!isa<FuncOp>(op) && !isa<GlobalTensorOp>(op)) { if (!isa<FuncOp, GlobalTensorOp>(op)) {
return op->emitError() << "'tf_saved_model.exported_names' must be on a " return op->emitError() << "'tf_saved_model.exported_names' must be on a "
"'func' or 'tf_saved_model.global_tensor' op"; "'func' or 'tf_saved_model.global_tensor' op";
} }

View File

@ -90,8 +90,7 @@ class TensorFlowType : public Type {
// Returns true if the specified type is a valid TensorFlow element type. // Returns true if the specified type is a valid TensorFlow element type.
static inline bool IsValidTFElementType(Type type) { static inline bool IsValidTFElementType(Type type) {
return type.isa<ComplexType>() || type.isa<FloatType>() || return type.isa<ComplexType, FloatType, IntegerType, TensorFlowType>();
type.isa<IntegerType>() || type.isa<TensorFlowType>();
} }
// Returns true if this is a valid TensorFlow tensor type. // Returns true if this is a valid TensorFlow tensor type.

View File

@ -375,7 +375,7 @@ LogicalResult FindResourceArgUseInfo(
info.data_type = assign.value().getType(); info.data_type = assign.value().getType();
continue; continue;
} }
if (isa<TF::StackPushV2Op>(user) || isa<TF::StackPopV2Op>(user)) { if (isa<TF::StackPushV2Op, TF::StackPopV2Op>(user)) {
// Stacks will be handled by a separate pass. // Stacks will be handled by a separate pass.
do_not_touch = true; do_not_touch = true;
break; break;

View File

@ -205,9 +205,9 @@ GetSubtypes(Type type) {
// Returns whether type can be further refined. // Returns whether type can be further refined.
bool CanBeRefined(Type type) { bool CanBeRefined(Type type) {
auto shape_type = type.dyn_cast<ShapedType>(); auto shape_type = type.dyn_cast<ShapedType>();
return shape_type && (!shape_type.hasStaticShape() || return shape_type &&
shape_type.getElementType().isa<TF::ResourceType>() || (!shape_type.hasStaticShape() ||
shape_type.getElementType().isa<TF::VariantType>()); shape_type.getElementType().isa<TF::ResourceType, TF::VariantType>());
} }
// Infers the shape from a (Stateful)PartionedCall operation by looking up the // Infers the shape from a (Stateful)PartionedCall operation by looking up the
@ -712,8 +712,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
// The shape function of these ops sometimes does not propagate subtypes // The shape function of these ops sometimes does not propagate subtypes
// (handle shapes) for resource and variant types. We use a simple passthrough // (handle shapes) for resource and variant types. We use a simple passthrough
// to make sure they are preserved in the output. // to make sure they are preserved in the output.
if (isa<TF::IdentityOp>(op) || isa<TF::IdentityNOp>(op) || if (isa<TF::IdentityOp, TF::IdentityNOp, TF::ZerosLikeOp, TF::WhileOp>(op)) {
isa<TF::ZerosLikeOp>(op) || isa<TF::WhileOp>(op)) {
return RefineTypeForPassThroughOperands(op, op->getOperands(), return RefineTypeForPassThroughOperands(op, op->getOperands(),
op->getResults()); op->getResults());
} }
@ -729,7 +728,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
// Handle call operations by looking up callee and infering return shape as // Handle call operations by looking up callee and infering return shape as
// needed. // needed.
if (isa<PartitionedCallOp>(op) || isa<StatefulPartitionedCallOp>(op)) if (isa<PartitionedCallOp, StatefulPartitionedCallOp>(op))
return InferShapeForCall(op); return InferShapeForCall(op);
// tf.Cast are only inferred if they have at least one user in the TF dialect // tf.Cast are only inferred if they have at least one user in the TF dialect
@ -889,8 +888,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
}; };
auto new_element_type = shaped_type.getElementType(); auto new_element_type = shaped_type.getElementType();
// Populate the handle shapes for a resource/variant. // Populate the handle shapes for a resource/variant.
if (new_element_type.isa<TF::ResourceType>() || if (new_element_type.isa<TF::ResourceType, TF::VariantType>()) {
new_element_type.isa<TF::VariantType>()) {
auto handle_shapes_types = c.output_handle_shapes_and_types(output); auto handle_shapes_types = c.output_handle_shapes_and_types(output);
if (handle_shapes_types) { if (handle_shapes_types) {
SmallVector<TensorType, 1> subtypes; SmallVector<TensorType, 1> subtypes;

View File

@ -488,7 +488,7 @@ LogicalResult DecomposeStackOpsInternal(
llvm::StringMap<PartitionedCallStackOpsInfo>* llvm::StringMap<PartitionedCallStackOpsInfo>*
decomposed_partitioned_call_callees) { decomposed_partitioned_call_callees) {
for (auto& op : llvm::make_early_inc_range(block->getOperations())) { for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) { if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
// Removes identity nodes in the block. The device computation does not // Removes identity nodes in the block. The device computation does not
// need such nodes to carry information. // need such nodes to carry information.
op.replaceAllUsesWith(op.getOperands()); op.replaceAllUsesWith(op.getOperands());

View File

@ -809,7 +809,7 @@ LogicalResult DecomposeTensorArrayOps(
llvm::StringMap<PartitionedCallTensorArrayOpsInfo>* llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
decomposed_partitioned_call_callees) { decomposed_partitioned_call_callees) {
for (auto& op : llvm::make_early_inc_range(block->getOperations())) { for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) { if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
op.replaceAllUsesWith(op.getOperands()); op.replaceAllUsesWith(op.getOperands());
op.erase(); op.erase();
} else if (auto ta = llvm::dyn_cast<TF::TensorArrayV3Op>(&op)) { } else if (auto ta = llvm::dyn_cast<TF::TensorArrayV3Op>(&op)) {

View File

@ -495,8 +495,7 @@ void TPUClusterFormation::runOnFunction() {
// Remove TPUReplicatedInput and TPUReplicatedOutput nodes. // Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
auto remove_result = getFunction().walk([&](Operation* op) { auto remove_result = getFunction().walk([&](Operation* op) {
if (!llvm::isa<TF::TPUReplicatedInputOp>(op) && if (!llvm::isa<TF::TPUReplicatedInputOp, TF::TPUReplicatedOutputOp>(op))
!llvm::isa<TF::TPUReplicatedOutputOp>(op))
return WalkResult::advance(); return WalkResult::advance();
// Forward operand to result. When `num_replicas` attribute is 1, no // Forward operand to result. When `num_replicas` attribute is 1, no

View File

@ -576,9 +576,8 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
// Adds nodes for operations. // Adds nodes for operations.
for (Operation& inst : graph_op.GetBody()) { for (Operation& inst : graph_op.GetBody()) {
for (auto type : inst.getResultTypes()) for (auto type : inst.getResultTypes())
if (!type.isa<mlir::TensorType>() && if (!type.isa<mlir::TensorType, mlir::tf_executor::ControlType,
!type.isa<mlir::tf_executor::ControlType>() && mlir::tf_executor::TokenType>())
!type.isa<mlir::tf_executor::TokenType>())
return errors::InvalidArgument( return errors::InvalidArgument(
"Values must be of tensor type, TensorFlow control type, or " "Values must be of tensor type, TensorFlow control type, or "
"TensorFlow token type. Found ", "TensorFlow token type. Found ",