[MLIR][NFC] Adopt variadic isa<> on MLIR Types and Attributes
- Also adopt variadic llvm::isa<> in more places.

PiperOrigin-RevId: 320206113
Change-Id: Ia03a1503f699fb6be6dff02e90b6630d6d894b19
parent 4e1aa305a1
commit e2f8269f17
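
The change itself is mechanical: wherever several single-type isa<> checks on the same value were ORed together, they are folded into one variadic call. Below is a minimal sketch of the before/after shape, not code from the patch; the function names are illustrative and the header paths assume the MLIR tree of this era.

#include "mlir/IR/StandardTypes.h"  // FloatType, IntegerType (pre-BuiltinTypes split)
#include "mlir/IR/Types.h"

// Before: one isa<> call per candidate type, joined with ||.
bool isIntOrFloatOld(mlir::Type type) {
  return type.isa<mlir::IntegerType>() || type.isa<mlir::FloatType>();
}

// After: the variadic form checks all listed classes in a single call and
// returns true if the type is an instance of any of them.
bool isIntOrFloatNew(mlir::Type type) {
  return type.isa<mlir::IntegerType, mlir::FloatType>();
}

// The same pattern applies to mlir::Attribute::isa<...>() and to
// llvm::isa<OpA, OpB>(op) over operations, as the hunks below show.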
@@ -443,8 +443,7 @@ StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
   if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
     TF_ASSIGN_OR_RETURN(value,
                         ConvertFloatBuffer(shaped_type, float_type, buffer));
-  } else if (elem_type.isa<mlir::IntegerType>() ||
-             elem_type.isa<QuantizedType>()) {
+  } else if (elem_type.isa<mlir::IntegerType, QuantizedType>()) {
     TF_ASSIGN_OR_RETURN(value,
                         ConvertIntBuffer(shaped_type, elem_type, buffer));
   } else if (elem_type.isa<mlir::TF::StringType>()) {
@@ -456,8 +455,7 @@ StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
       refs.push_back({ref.data(), ref.size()});

     value = mlir::DenseStringElementsAttr::get(shaped_type, refs);
-  } else if (elem_type.isa<mlir::ComplexType>() ||
-             elem_type.isa<mlir::TF::TensorFlowType>()) {
+  } else if (elem_type.isa<mlir::ComplexType, mlir::TF::TensorFlowType>()) {
     auto dialect = elem_type.getContext()->getRegisteredDialect("tf");
     tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
     std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
@@ -694,8 +694,7 @@ void QuantizationDriver::SetupAllStates() {
   fn_.walk([&](Operation *op) {
     if (op->isKnownTerminator() ||
         op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
-        llvm::isa<quant::DequantizeCastOp>(op) ||
-        llvm::isa<quant::QuantizeCastOp>(op))
+        llvm::isa<quant::DequantizeCastOp, quant::QuantizeCastOp>(op))
       return;
     work_list_.push_back(op);

@@ -386,8 +386,7 @@ struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {

     Operation* def = pre_quantized.getDefiningOp();
     if (!def) return failure();
-    if (llvm::isa<FixedOutputRangeInterface>(def) ||
-        llvm::isa<SameScalesOpInterface>(def) ||
+    if (llvm::isa<FixedOutputRangeInterface, SameScalesOpInterface>(def) ||
         def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
       return failure();
     }
@@ -560,7 +560,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
       return failure();
     ShapedType filter_type = filter_cst.getType();

-    if (llvm::isa<AddOp>(binary_op) || llvm::isa<SubOp>(binary_op)) {
+    if (llvm::isa<AddOp, SubOp>(binary_op)) {
       auto padding = fc_op.template getAttrOfType<StringAttr>("padding");
       if (padding && padding.getValue() != "VALID") return failure();

@@ -606,7 +606,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
           rewriter.create<ConstOp>(fc_op.getLoc(), new_bias_type, new_bias);
       fc_op.setOperand(0, binary_op->getOperand(0));
       fc_op.setOperand(2, new_bias_op);
-    } else if (llvm::isa<MulOp>(binary_op) || llvm::isa<DivOp>(binary_op)) {
+    } else if (llvm::isa<MulOp, DivOp>(binary_op)) {
       // The fusion of mul/div is actually applying the following
       // transformation:
       // w * (x ' c) + b => (w ' c) x + b
@@ -168,8 +168,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
           var_handle.resource(),
           GetOrCreateIdForVarHandle(var_handle, &next_unique_id,
                                     &var_handle_name_id_map));
-    } else if (llvm::isa<TF::IdentityNOp>(op) ||
-               llvm::isa<TF::IdentityOp>(op)) {
+    } else if (llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op)) {
       for (auto operand_and_result :
            llvm::zip(op->getOperands(), op->getResults())) {
         forward_input_to_output(std::get<0>(operand_and_result),
@@ -333,7 +332,7 @@ bool OpIsDeclaration(Operation* op,
                      const ResourceAliasAnalysis& alias_analysis) {
   // TODO(yuanzx): Add other types of resources.
   return llvm::isa<TF::VarHandleOp>(op) ||
-         ((llvm::isa<TF::IdentityNOp>(op) || llvm::isa<TF::IdentityOp>(op)) &&
+         (llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op) &&
          !FindAccessedResources(op, alias_analysis).empty());
 }

@@ -71,7 +71,7 @@ struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
     // Allow inlining into tf.island regions if the incoming region has a single
     // block.
     return llvm::isa<tf_executor::IslandOp>(dest->getParentOp()) &&
-           std::next(src->begin()) == src->end();
+           llvm::hasSingleElement(*src);
   }
 };

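Besides the variadic isa<> adoption, the inliner hunk above replaces a manual iterator comparison with llvm::hasSingleElement from llvm/ADT/STLExtras.h. A small sketch of the equivalence, using a plain std::list as a stand-in for the region's block list:

#include <iterator>
#include <list>

#include "llvm/ADT/STLExtras.h"  // llvm::hasSingleElement

// Manual form used before: assumes the range is non-empty (a region being
// inlined always has at least one block).
bool hasOneBlockManually(const std::list<int> &blocks) {
  return std::next(blocks.begin()) == blocks.end();
}

// Replacement: same answer on non-empty ranges, and safely returns false on
// an empty one.
bool hasOneBlock(const std::list<int> &blocks) {
  return llvm::hasSingleElement(blocks);
}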
@@ -1168,8 +1168,7 @@ void ConstOp::build(OpBuilder &builder, OperationState &result,
   ShapedType type;
   if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
     return ConstOp::build(builder, result, elem_attr);
-  } else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
-             value.isa<IntegerAttr>()) {
+  } else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) {
     // All TensorFlow types must be tensor types. In the build() method,
     // we want to provide more flexibility by allowing attributes of scalar
     // types. But we need to wrap it up with ElementsAttr to construct
@@ -356,7 +356,7 @@ LogicalResult VerifyExportedFunc(FuncOp func) {
 LogicalResult TensorFlowSavedModelDialect::verifyOperationAttribute(
     Operation *op, NamedAttribute named_attr) {
   if (named_attr.first == "tf_saved_model.exported_names") {
-    if (!isa<FuncOp>(op) && !isa<GlobalTensorOp>(op)) {
+    if (!isa<FuncOp, GlobalTensorOp>(op)) {
       return op->emitError() << "'tf_saved_model.exported_names' must be on a "
                                 "'func' or 'tf_saved_model.global_tensor' op";
     }
@@ -90,8 +90,7 @@ class TensorFlowType : public Type {

   // Returns true if the specified type is a valid TensorFlow element type.
   static inline bool IsValidTFElementType(Type type) {
-    return type.isa<ComplexType>() || type.isa<FloatType>() ||
-           type.isa<IntegerType>() || type.isa<TensorFlowType>();
+    return type.isa<ComplexType, FloatType, IntegerType, TensorFlowType>();
   }

   // Returns true if this is a valid TensorFlow tensor type.
@@ -375,7 +375,7 @@ LogicalResult FindResourceArgUseInfo(
         info.data_type = assign.value().getType();
         continue;
       }
-      if (isa<TF::StackPushV2Op>(user) || isa<TF::StackPopV2Op>(user)) {
+      if (isa<TF::StackPushV2Op, TF::StackPopV2Op>(user)) {
         // Stacks will be handled by a separate pass.
         do_not_touch = true;
         break;
@@ -205,9 +205,9 @@ GetSubtypes(Type type) {
 // Returns whether type can be further refined.
 bool CanBeRefined(Type type) {
   auto shape_type = type.dyn_cast<ShapedType>();
-  return shape_type && (!shape_type.hasStaticShape() ||
-                        shape_type.getElementType().isa<TF::ResourceType>() ||
-                        shape_type.getElementType().isa<TF::VariantType>());
+  return shape_type &&
+         (!shape_type.hasStaticShape() ||
+          shape_type.getElementType().isa<TF::ResourceType, TF::VariantType>());
 }

 // Infers the shape from a (Stateful)PartionedCall operation by looking up the
@@ -712,8 +712,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
   // The shape function of these ops sometimes does not propagate subtypes
   // (handle shapes) for resource and variant types. We use a simple passthrough
   // to make sure they are preserved in the output.
-  if (isa<TF::IdentityOp>(op) || isa<TF::IdentityNOp>(op) ||
-      isa<TF::ZerosLikeOp>(op) || isa<TF::WhileOp>(op)) {
+  if (isa<TF::IdentityOp, TF::IdentityNOp, TF::ZerosLikeOp, TF::WhileOp>(op)) {
     return RefineTypeForPassThroughOperands(op, op->getOperands(),
                                             op->getResults());
   }
@@ -729,7 +728,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {

   // Handle call operations by looking up callee and infering return shape as
   // needed.
-  if (isa<PartitionedCallOp>(op) || isa<StatefulPartitionedCallOp>(op))
+  if (isa<PartitionedCallOp, StatefulPartitionedCallOp>(op))
     return InferShapeForCall(op);

   // tf.Cast are only inferred if they have at least one user in the TF dialect
@@ -889,8 +888,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
   };
   auto new_element_type = shaped_type.getElementType();
   // Populate the handle shapes for a resource/variant.
-  if (new_element_type.isa<TF::ResourceType>() ||
-      new_element_type.isa<TF::VariantType>()) {
+  if (new_element_type.isa<TF::ResourceType, TF::VariantType>()) {
     auto handle_shapes_types = c.output_handle_shapes_and_types(output);
     if (handle_shapes_types) {
       SmallVector<TensorType, 1> subtypes;
@@ -488,7 +488,7 @@ LogicalResult DecomposeStackOpsInternal(
     llvm::StringMap<PartitionedCallStackOpsInfo>*
         decomposed_partitioned_call_callees) {
   for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
-    if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
+    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
       // Removes identity nodes in the block. The device computation does not
       // need such nodes to carry information.
       op.replaceAllUsesWith(op.getOperands());
@@ -809,7 +809,7 @@ LogicalResult DecomposeTensorArrayOps(
     llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
         decomposed_partitioned_call_callees) {
   for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
-    if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
+    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
       op.replaceAllUsesWith(op.getOperands());
       op.erase();
     } else if (auto ta = llvm::dyn_cast<TF::TensorArrayV3Op>(&op)) {
@@ -495,8 +495,7 @@ void TPUClusterFormation::runOnFunction() {

   // Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
   auto remove_result = getFunction().walk([&](Operation* op) {
-    if (!llvm::isa<TF::TPUReplicatedInputOp>(op) &&
-        !llvm::isa<TF::TPUReplicatedOutputOp>(op))
+    if (!llvm::isa<TF::TPUReplicatedInputOp, TF::TPUReplicatedOutputOp>(op))
       return WalkResult::advance();

     // Forward operand to result. When `num_replicas` attribute is 1, no
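One subtlety in the TPUClusterFormation hunk above: the variadic form ORs its checks, so the original conjunction of negations collapses into a single negated call by De Morgan's law. A sketch using LLVM IR instruction classes purely as stand-ins for the TF ops:

#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"  // LoadInst, StoreInst
#include "llvm/Support/Casting.h"

// "Not a load AND not a store" ...
bool skipOld(const llvm::Instruction &inst) {
  return !llvm::isa<llvm::LoadInst>(inst) && !llvm::isa<llvm::StoreInst>(inst);
}

// ... equals "not (a load OR a store)", which the variadic form states once.
bool skipNew(const llvm::Instruction &inst) {
  return !llvm::isa<llvm::LoadInst, llvm::StoreInst>(inst);
}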
@@ -576,9 +576,8 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
   // Adds nodes for operations.
   for (Operation& inst : graph_op.GetBody()) {
     for (auto type : inst.getResultTypes())
-      if (!type.isa<mlir::TensorType>() &&
-          !type.isa<mlir::tf_executor::ControlType>() &&
-          !type.isa<mlir::tf_executor::TokenType>())
+      if (!type.isa<mlir::TensorType, mlir::tf_executor::ControlType,
+                    mlir::tf_executor::TokenType>())
         return errors::InvalidArgument(
             "Values must be of tensor type, TensorFlow control type, or "
             "TensorFlow token type. Found ",