NFC: Remove usages of Value::operator* and Value::operator-> now that Value is properly value-typed.

These were temporary methods used to simplify the transition.

PiperOrigin-RevId: 287902391
Change-Id: I249fc049e4f1b762a23939b8e8b79110103df4aa
Author: River Riddle (2020-01-02 15:10:44 -08:00), committed by TensorFlower Gardener
Parent: b4fd6a5963
Commit: 043abbdf86
89 changed files with 758 additions and 772 deletions
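
Every hunk below is the same mechanical rewrite: MLIR's Value previously exposed pointer-like operator* and operator-> so that code written against the old pointer representation kept compiling, and now that Value is value-typed those shims are gone and accessors are called with `.` instead of `->`. A minimal sketch of the pattern (illustrative only, using accessors that appear throughout this diff):

    // Before: Value exposed operator->, so call sites dereferenced it
    // like a pointer.
    Type type = value->getType();
    Operation *def = value->getDefiningOp();

    // After: Value is a proper value type; the same accessors are
    // invoked directly.
    Type type = value.getType();
    Operation *def = value.getDefiningOp();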


@@ -217,7 +217,7 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
   // min/max stats is just for comments, so ignore it.
   if (!tensor.quantization || IsQuantized(tensor)) return nullptr;
   // If the result isn't float and unquantizable, the min/max is ignored.
-  if (!res->getType()
+  if (!res.getType()
           .cast<mlir::ShapedType>()
           .getElementType()
           .isa<mlir::FloatType>()) {


@@ -233,17 +233,17 @@ static bool IsConst(Operation* op) {
 template <typename T>
 static bool HasValidTFLiteType(Value value, T& error_handler) {
   // None type is allowed to represent unspecified operands.
-  if (value->getType().isa<NoneType>()) return true;
-  auto type = value->getType().dyn_cast<TensorType>();
+  if (value.getType().isa<NoneType>()) return true;
+  auto type = value.getType().dyn_cast<TensorType>();
   if (!type) {
-    if (auto op = value->getDefiningOp()) {
+    if (auto op = value.getDefiningOp()) {
       error_handler.emitError()
           << '\'' << op << "' should produce value of tensor type instead of "
-          << value->getType();
+          << value.getType();
       return false;
     }
-    error_handler.emitError("expected tensor type, got ") << value->getType();
+    error_handler.emitError("expected tensor type, got ") << value.getType();
     return false;
   }
@@ -282,7 +282,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
     for (auto arg : bb.getArguments()) {
       if (!HasValidTFLiteType(arg, fn))
-        return fn.emitError("invalid TFLite type: ") << arg->getType(), false;
+        return fn.emitError("invalid TFLite type: ") << arg.getType(), false;
     }
     // Verify that all operations except the terminator have exactly one
@@ -292,7 +292,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
       for (auto result : inst.getResults()) {
         if (!HasValidTFLiteType(result, inst))
-          return fn.emitError("invalid TFLite type: ") << result->getType(),
+          return fn.emitError("invalid TFLite type: ") << result.getType(),
                  false;
       }
     }
@@ -504,7 +504,7 @@ Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
 Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
     Value value, const std::string& name, unsigned buffer_idx) {
-  auto type = value->getType().cast<TensorType>();
+  auto type = value.getType().cast<TensorType>();
   // TFLite requires tensor shape only for the inputs and constants.
   // However, we output all known shapes for better round-tripping
@@ -516,7 +516,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
     if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range))
       return mlir::emitError(
-          value->getLoc(),
+          value.getLoc(),
           "result shape dimensions out of 32 bit int type range");
     return mlir::success();
@@ -528,7 +528,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
     if (mlir::failed(check_shape(shape_ref))) return llvm::None;
     shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
-  } else if (auto* inst = value->getDefiningOp()) {
+  } else if (auto* inst = value.getDefiningOp()) {
     if (IsConst(inst)) {
       // Const op can have a result of dynamic shaped type (e.g. due to constant
       // folding), but we can still derive the shape of a constant tensor for
@@ -571,7 +571,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
   // marked as a stateful. If so, set the tensor's is_variable as true
   // This is v1 ref variable semantics in the TFLite runtime.
   bool is_variable = false;
-  for (auto& use : value->getUses()) {
+  for (auto& use : value.getUses()) {
     is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber());
     if (is_variable) {
       break;
@@ -923,7 +923,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
   // on failure.
   auto build_tensor_and_buffer = [&](Value value, const std::string& name) {
     // NoneType represents optional and may be skipped here.
-    if (value->getType().isa<NoneType>()) {
+    if (value.getType().isa<NoneType>()) {
       return true;
     }
@@ -936,7 +936,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
     // make the Buffer empty apart from setting the buffer_idx=0 in the Tensor.
     // This does not seem to affect runtime behavior for RNN/LSTM, but would be
     // good for reducing memory footprint.
-    if (auto* inst = value->getDefiningOp()) {
+    if (auto* inst = value.getDefiningOp()) {
       auto buffer_or = BuildBuffer(inst);
       if (!buffer_or) return false;
       buffers_.push_back(*buffer_or);
@@ -976,7 +976,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
     std::vector<int32_t> operands;
     operands.reserve(inst.getNumOperands());
     for (auto operand : inst.getOperands()) {
-      if (operand->getType().isa<NoneType>())
+      if (operand.getType().isa<NoneType>())
        operands.push_back(kTfLiteOptionalTensor);
      else
        operands.push_back(tensor_index_map.lookup(operand));


@@ -304,11 +304,11 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
 void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs,
                           Value rhs) {
   auto result_type =
-      OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
+      OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
   if (!result_type)
     emitError(result.location)
-        << "non-broadcastable operands: " << lhs->getType() << " and "
-        << rhs->getType();
+        << "non-broadcastable operands: " << lhs.getType() << " and "
+        << rhs.getType();
   result.addOperands({lhs, rhs});
   // Comparison binary ops always return i1 tensor.
   if (auto shaped_type = result_type.dyn_cast<RankedTensorType>()) {
@@ -324,12 +324,12 @@ void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result,
                                   Value lhs, Value rhs,
                                   StringAttr fused_activation_function) {
   auto result_type =
-      OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
+      OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
   if (!result_type)
     emitError(result.location)
-        << "non-broadcastable operands: " << lhs->getType() << " and "
-        << rhs->getType();
+        << "non-broadcastable operands: " << lhs.getType() << " and "
+        << rhs.getType();
   result.addOperands({lhs, rhs});
   result.addAttribute("fused_activation_function", fused_activation_function);
@@ -358,7 +358,7 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
 namespace {
 int64_t GetConcatenationOpAxis(ConcatenationOp op) {
-  auto output_type = op.output()->getType().cast<RankedTensorType>();
+  auto output_type = op.output().getType().cast<RankedTensorType>();
   int64_t axis = op.axis().getSExtValue();
   if (axis < 0) axis += output_type.getRank();
   return axis;
@@ -452,7 +452,7 @@ LogicalResult VerifyConcatenationOpTypes(Operation *op,
 }
 LogicalResult Verify(ConcatenationOp op) {
-  auto output_type = op.output()->getType().dyn_cast<RankedTensorType>();
+  auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
   // If the output type is unranked, there is nothing else to be verified.
   if (!output_type) return success();
@@ -463,7 +463,7 @@ LogicalResult Verify(ConcatenationOp op) {
   SmallVector<TensorType, 4> operand_types;
   for (Value operand : op.values())
-    operand_types.push_back(operand->getType().cast<TensorType>());
+    operand_types.push_back(operand.getType().cast<TensorType>());
   return VerifyConcatenationOpTypes(op.getOperation(), output_type,
                                     operand_types, axis);
@@ -520,7 +520,7 @@ DenseElementsAttr ConstFoldConcatenateOpDense(ArrayRef<Attribute> operands,
 OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
   if (fused_activation_function() == "NONE") {
-    if (auto output_type = output()->getType().dyn_cast<RankedTensorType>()) {
+    if (auto output_type = output().getType().dyn_cast<RankedTensorType>()) {
       const int64_t axis = GetConcatenationOpAxis(*this);
       if (IsConcatenationOpConstFoldable(*this, operands, output_type, axis))
         return ConstFoldConcatenateOpDense(operands, output_type, axis);
@@ -530,7 +530,7 @@ OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
   // Remove all empty values.
   SmallVector<Value, 4> non_empty_values;
   for (Value value : this->values()) {
-    const auto shaped_type = value->getType().cast<ShapedType>();
+    const auto shaped_type = value.getType().cast<ShapedType>();
     if (shaped_type.hasStaticShape() && shaped_type.getNumElements() == 0) {
       continue;
     }
@@ -559,8 +559,8 @@ OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 LogicalResult Verify(FullyConnectedOp op) {
-  ShapedType input_type = op.input()->getType().cast<ShapedType>();
-  ShapedType filter_type = op.filter()->getType().cast<ShapedType>();
+  ShapedType input_type = op.input().getType().cast<ShapedType>();
+  ShapedType filter_type = op.filter().getType().cast<ShapedType>();
   if (filter_type.hasRank() && filter_type.getRank() != 2) {
     return op.emitOpError("expect 2d filter, got ") << filter_type;
   }
@@ -582,7 +582,7 @@ LogicalResult Verify(FullyConnectedOp op) {
   // format.
   if (op.weights_format() == "DEFAULT") {
     ShapedType output_type =
-        (*op.output().begin())->getType().cast<ShapedType>();
+        (*op.output().begin()).getType().cast<ShapedType>();
     if (!output_type.hasStaticShape()) {
       return mlir::success();
     }
@@ -610,8 +610,8 @@ LogicalResult Verify(FullyConnectedOp op) {
 static void BuildGatherOp(Builder *builder, OperationState &result,
                           Value params, Value indices, IntegerAttr axis) {
-  auto params_type = params->getType().cast<TensorType>();
-  auto indices_type = indices->getType().cast<TensorType>();
+  auto params_type = params.getType().cast<TensorType>();
+  auto indices_type = indices.getType().cast<TensorType>();
   // If params/indices is unranked, then output is unranked.
   if (!params_type.hasRank() || !indices_type.hasRank())
@@ -705,7 +705,7 @@ static LogicalResult Verify(PackOp op) {
     return op.emitOpError("input count should match 'values_count' attribute");
   Value operand0 = op.getOperand(0);
-  auto input_type = operand0->getType().cast<ShapedType>();
+  auto input_type = operand0.getType().cast<ShapedType>();
   // Check axis bounds.
   if (input_type.hasRank()) {
@@ -718,7 +718,7 @@ static LogicalResult Verify(PackOp op) {
   // Make sure all inputs have the same shape and element type.
   // TODO(rahulsp): Simplify once b/135032064 is fixed.
   for (Value operand : op.getOperands()) {
-    auto other_type = operand->getType().cast<ShapedType>();
+    auto other_type = operand.getType().cast<ShapedType>();
     if (input_type != other_type)
       return op.emitOpError("operands should be of the same type. got ")
             << input_type << ", " << other_type;
@@ -732,9 +732,9 @@ static LogicalResult Verify(PackOp op) {
 //===----------------------------------------------------------------------===//
 static LogicalResult Verify(PReluOp op) {
-  auto input_type = op.input()->getType().cast<ShapedType>();
-  auto alpha_type = op.alpha()->getType().cast<ShapedType>();
-  auto output_type = op.output()->getType().cast<ShapedType>();
+  auto input_type = op.input().getType().cast<ShapedType>();
+  auto alpha_type = op.alpha().getType().cast<ShapedType>();
+  auto output_type = op.output().getType().cast<ShapedType>();
   if (input_type.hasStaticShape() && alpha_type.hasStaticShape()) {
     if (input_type.getRank() != alpha_type.getRank() + 1) {
@@ -783,13 +783,13 @@ struct RemoveAdjacentReshape : public RewritePattern {
   PatternMatchResult match(Operation *op) const override {
     auto thisOp = cast<ReshapeOp>(op);
-    auto prevOp = thisOp.getOperand(0)->getDefiningOp();
+    auto prevOp = thisOp.getOperand(0).getDefiningOp();
     return isa_and_nonnull<ReshapeOp>(prevOp) ? matchSuccess() : matchFailure();
   }
   void rewrite(Operation *op, PatternRewriter &rewriter) const override {
     auto thisOp = cast<ReshapeOp>(op);
-    auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0)->getDefiningOp());
+    auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0).getDefiningOp());
     // Replace
     //   %1 = "tfl.reshape"(%0, %shape0)
@@ -807,7 +807,7 @@ struct RemoveAdjacentReshape : public RewritePattern {
 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
   // Remove identity reshape with both static result and input shape.
   auto result_type = getType().cast<ShapedType>();
-  auto input_type = getOperand(0)->getType().cast<ShapedType>();
+  auto input_type = getOperand(0).getType().cast<ShapedType>();
   if (result_type.hasStaticShape() && result_type == input_type) {
     return getOperand(0);
   }
@@ -865,7 +865,7 @@ struct RemoveRedundantUnpackPack : public RewritePattern {
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
     TFL::PackOp pack_op = cast<TFL::PackOp>(op);
-    Operation *first_input = pack_op.getOperand(0)->getDefiningOp();
+    Operation *first_input = pack_op.getOperand(0).getDefiningOp();
     if (!first_input) return matchFailure();
     auto input_unpack_op = dyn_cast_or_null<TFL::UnpackOp>(first_input);
     if (!input_unpack_op) return matchFailure();
@@ -905,9 +905,9 @@ void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 //===----------------------------------------------------------------------===//
 static LogicalResult Verify(SliceOp op) {
-  auto input_type = op.input()->getType().cast<ShapedType>();
-  auto begin_type = op.begin()->getType().cast<ShapedType>();
-  auto size_type = op.size()->getType().cast<ShapedType>();
+  auto input_type = op.input().getType().cast<ShapedType>();
+  auto begin_type = op.begin().getType().cast<ShapedType>();
+  auto size_type = op.size().getType().cast<ShapedType>();
   if (input_type.hasStaticShape() && begin_type.hasStaticShape() &&
       size_type.hasStaticShape()) {
     if (input_type.getRank() != begin_type.getNumElements()) {
@@ -995,7 +995,7 @@ static void BuildTopKOp(Builder *builder, OperationState &result, Value input,
   // TODO(jpienaar): This should use a helper function.
   const_k = cst.getValue<IntegerAttr>({}).getValue().getSExtValue();
-  auto val_type = input->getType().cast<TensorType>();
+  auto val_type = input.getType().cast<TensorType>();
   // If value is unranked, then so is results.
   if (!val_type.hasRank())
     return TFL::TopKV2Op::build(
@@ -1035,7 +1035,7 @@ struct DropFakeQuant : public RewritePattern {
     // If all the users of this op have valid "minmax" attributes, it is matched
     // and can be removed.
     auto fakeQuantOp = cast<FakeQuantOp>(op);
-    for (auto *operand : fakeQuantOp.getResult()->getUsers())
+    for (auto *operand : fakeQuantOp.getResult().getUsers())
       if (!HasValidMinMaxAttribute(operand)) return matchFailure();
     return matchSuccess();
@@ -1102,7 +1102,7 @@ static LogicalResult VerifySplitOpOutputTypes(
   for (int64_t i = 0; i < num_splits; ++i) {
     auto expected_output_type = get_expected_output_type(i);
     Value output = op->getResult(i);
-    auto output_type = output->getType().dyn_cast<RankedTensorType>();
+    auto output_type = output.getType().dyn_cast<RankedTensorType>();
     if (!output_type || output_type != expected_output_type)
       return op->emitOpError()
              << "output #" << i << " should be " << expected_output_type;
@@ -1121,7 +1121,7 @@ static LogicalResult Verify(SplitOp op) {
   if (!split_dim_opt) return success();
   // If 'input' is not a ranked tensor, there are no other checks.
-  auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
+  auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
   if (!input_type) return success();
   int64_t split_dim = split_dim_opt.getValue();
@@ -1157,7 +1157,7 @@ static LogicalResult Verify(SplitVOp op) {
   if (!split_dim_opt) return success();
   // If 'input' is not a ranked tensor, there are no other checks.
-  auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
+  auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
   if (!input_type) return success();
   int64_t split_dim = split_dim_opt.getValue();
@@ -1177,8 +1177,7 @@ static LogicalResult Verify(SplitVOp op) {
     return success();
   if (size_splits_attr.getNumElements() != num_splits) {
-    auto size_splits_type =
-        op.size_splits()->getType().cast<RankedTensorType>();
+    auto size_splits_type = op.size_splits().getType().cast<RankedTensorType>();
     RankedTensorType expected_size_splits_type =
         RankedTensorType::get({num_splits}, size_splits_type.getElementType());
     return op.emitOpError("'size_splits' should be ")
@@ -1414,7 +1413,7 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
   }
   // Also fold if `input` has a known rank.
-  auto input_type = input()->getType().cast<ShapedType>();
+  auto input_type = input().getType().cast<ShapedType>();
   // Do not fold if rank is zero because the TFLite converter doesn't
   // distinguish between unranked input and scalar input due to b/138865275.
   // TODO(b/138865275): Remove `input_type.getRank() != 0` in the following
@@ -1445,18 +1444,18 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
 static void BuildSelectV2Op(Builder *builder, OperationState &result,
                             Value cond, Value x, Value y) {
   auto operand_type =
-      OpTrait::util::getBroadcastedType(x->getType(), y->getType());
+      OpTrait::util::getBroadcastedType(x.getType(), y.getType());
   if (!operand_type)
-    emitError(result.location) << "non-broadcastable operands: " << x->getType()
-                               << " and " << y->getType();
+    emitError(result.location) << "non-broadcastable operands: " << x.getType()
+                               << " and " << y.getType();
   bool has_static_cond_shape = false;
   bool has_static_operand_shape = false;
   ArrayRef<int64_t> cond_shape;
   ArrayRef<int64_t> operand_shape;
-  if (auto shaped_type = cond->getType().dyn_cast<ShapedType>()) {
+  if (auto shaped_type = cond.getType().dyn_cast<ShapedType>()) {
     if (shaped_type.hasStaticShape()) {
       has_static_cond_shape = true;
       cond_shape = shaped_type.getShape();
@@ -1474,12 +1473,12 @@ static void BuildSelectV2Op(Builder *builder, OperationState &result,
       !OpTrait::util::getBroadcastedShape(cond_shape, operand_shape,
                                           broadcastedShape)) {
     emitError(result.location) << "non-broadcastable operands: " << operand_type
-                               << " and " << cond->getType();
+                               << " and " << cond.getType();
   }
   result.addOperands({cond, x, y});
-  auto elementType = x->getType().dyn_cast<ShapedType>().getElementType();
+  auto elementType = x.getType().dyn_cast<ShapedType>().getElementType();
   if (has_static_cond_shape && has_static_operand_shape) {
     result.types.push_back(
         RankedTensorType::get(broadcastedShape, elementType));
@@ -1571,9 +1570,8 @@ OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 static LogicalResult Verify(TransposeConvOp op) {
-  ShapedType output_type = op.output()->getType().cast<ShapedType>();
-  ShapedType output_shape_type =
-      op.output_shape()->getType().cast<ShapedType>();
+  ShapedType output_type = op.output().getType().cast<ShapedType>();
+  ShapedType output_shape_type = op.output_shape().getType().cast<ShapedType>();
   if (output_type.hasRank() && output_shape_type.hasStaticShape()) {
     if (output_type.getRank() != output_shape_type.getDimSize(0)) {
       return op.emitOpError(llvm::formatv(
@@ -1679,9 +1677,9 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
 }
 static LogicalResult Verify(TransposeOp op) {
-  auto input_type = op.x()->getType().cast<ShapedType>();
-  auto perm_type = op.perm()->getType().cast<ShapedType>();
-  auto output_type = op.y()->getType().cast<ShapedType>();
+  auto input_type = op.x().getType().cast<ShapedType>();
+  auto perm_type = op.perm().getType().cast<ShapedType>();
+  auto output_type = op.y().getType().cast<ShapedType>();
   if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
     if (perm_type.getNumElements() != input_type.getRank()) {
       return op.emitOpError(


@@ -135,7 +135,7 @@ def TFL_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TFL_Int32Or64]>;
 //===----------------------------------------------------------------------===//
 class TFL_OperandIsUnrankedPred<int n> :
-  CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">;
+  CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;
 // TODO: Some of these could be generalized and/or moved to more general
 // location.
@@ -144,38 +144,38 @@ class TFL_OperandHasRank<int n, int m> :
   PredOpTrait<"operand " # n # " is " # m # "-D",
     Or<[TFL_OperandIsUnrankedPred<n>,
         CPred<"$_op.getOperand(" # n #
-              ")->getType().cast<ShapedType>().getRank() == " # m>]>>;
+              ").getType().cast<ShapedType>().getRank() == " # m>]>>;
 // Returns true if the n-th operand is ranked and has rank dim.
 class TFL_OperandHasKnownRank<int n, int dim> : And<[
-  CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
-  CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() == "
+  CPred<"$_op.getOperand(" # n # ").getType().isa<RankedTensorType>()">,
+  CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() == "
        # dim>]>;
 // True if operand n is ranked and has a rank > dim.
 class TFL_OperandIsRankedAndHasDimPred<int n, int dim> : And<[
-  CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
-  CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() > "
+  CPred<"$_op.getOperand(" # n # ").getType().isa<RankedTensorType>()">,
+  CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() > "
        # dim>]>;
 class TFL_OperandDimEquals<int n, int dim, int size> : And<[
   TFL_OperandIsRankedAndHasDimPred<n, dim>,
-  CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>()"
+  CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>()"
        ".getShape()[" # dim # " ] == " # size>]>;
 // Returns true if the n-th operand has unknown rank or at least rank m.
 class TFL_OperandHasAtleastRank<int n, int m> :
   PredOpTrait<"operand " # n # " is " # m # "-D",
-    Or<[CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">,
+    Or<[CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">,
         CPred<"$_op.getOperand(" # n #
-              ")->getType().cast<ShapedType>().getRank() >= " # m>]>>;
+              ").getType().cast<ShapedType>().getRank() >= " # m>]>>;
 class TFL_OperandRankEquals1DimOfOperand<int x, int y> :
   PredOpTrait<"operand " # x # "'s rank equals operand " # y # "'s size",
     CPred<"$_op.getOperand(" # x #
-          ")->getType().cast<ShapedType>().getRank() == "
+          ").getType().cast<ShapedType>().getRank() == "
           "$_op.getOperand(" # y #
-          ")->getType().cast<ShapedType>().getShape()[0]">>;
+          ").getType().cast<ShapedType>().getShape()[0]">>;
 class TFL_Operand0DOr1ElementTensor<int x> :
   PredOpTrait<"operand #" # x # " is an 0-d tensor or 1-d tensor w/ 1 element",
@@ -195,7 +195,7 @@ class TFL_OperandHasRankLessThan<int n, int m> :
   PredOpTrait<"operand " # n # " is maximum " # m # "-D",
     Or<[TFL_OperandIsUnrankedPred<n>,
         CPred<"$_op.getOperand(" # n #
-              ")->getType().cast<ShapedType>().getRank() <= " # m>]>>;
+              ").getType().cast<ShapedType>().getRank() <= " # m>]>>;
 // This is a quantization-aware version of TCresVTEtIsSameAsOp
 class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
@@ -227,7 +227,7 @@ def TFL_BroadcastableBinaryBuilder : OpBuilder<
   "Builder *builder, OperationState &result, Value lhs, Value rhs",
   [{
     auto resultType =
-        OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
+        OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
     if (!resultType)
       mlir::emitError(result.location, "non-broadcastable operands");
     result.addOperands({lhs, rhs});
@@ -471,7 +471,7 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
   let hasOptions = 1;
   DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
-    return getResult()->getType().cast<TensorType>().getElementType().
+    return getResult().getType().cast<TensorType>().getElementType().
         cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
             tflite::TensorType_INT32;
   }]>;
@@ -500,7 +500,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
   let hasOptions = 1;
   DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
-    return getResult()->getType().cast<TensorType>().getElementType().
+    return getResult().getType().cast<TensorType>().getElementType().
        cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
            tflite::TensorType_INT32;
  }]>;
@@ -1996,7 +1996,7 @@ def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect]> {
   let results = (outs AnyTensor:$output);
   DerivedTypeAttr out_type = DerivedTypeAttr<[{
-    return getResult()->getType().cast<TensorType>().getElementType();
+    return getResult().getType().cast<TensorType>().getElementType();
   }]>;
   let hasOptions = 1;
@@ -2083,7 +2083,7 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
   let builders = [OpBuilder<"Builder *builder, OperationState &result, "
                   "Value condition, Value x, Value y",
   [{
-    auto resultType = x->getType();
+    auto resultType = x.getType();
     result.addOperands({condition, x, y});
     result.types.push_back(resultType);
   }]>];
@@ -2733,7 +2733,7 @@ in the unique output `y`. In other words:
   );
   DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{
-    return getResult(1)->getType().cast<TensorType>().getElementType().
+    return getResult(1).getType().cast<TensorType>().getElementType().
        cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
            tflite::TensorType_INT32;
  }]>;


@@ -78,8 +78,8 @@ class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
 bool IsQuantizableResult(Operation *op, int index) {
   if (index < 0 || index >= op->getNumResults()) return false;
   Value res = op->getResult(index);
-  return res->getType().isa<ShapedType>() &&
-         res->getType().cast<ShapedType>().getElementType().isa<FloatType>();
+  return res.getType().isa<ShapedType>() &&
+         res.getType().cast<ShapedType>().getElementType().isa<FloatType>();
 }
 // A method to retrieve the name for the given op.
@@ -123,7 +123,7 @@ void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value res,
                                                  IntegerAttr axis) {
   auto stats_op = b.create<quant::StatisticsOp>(b.getUnknownLoc(), res,
                                                 layer_stats, axis_stats, axis);
-  res->replaceAllUsesWith(stats_op);
+  res.replaceAllUsesWith(stats_op);
   stats_op.getOperation()->replaceUsesOfWith(stats_op, res);
 }


@@ -146,14 +146,14 @@ class QuantizationDriver {
   // Adds all the users of index-th result of op to the work list.
   void AddUserToList(Operation *op, int index) {
-    for (auto *user : op->getResult(index)->getUsers()) {
+    for (auto *user : op->getResult(index).getUsers()) {
       work_list_.push_back(user);
     }
   }
   // Adds the defining op of index-th operand of op to the work list.
   void AddOperandToList(Operation *op, int index) {
-    if (auto *inst = op->getOperand(index)->getDefiningOp()) {
+    if (auto *inst = op->getOperand(index).getDefiningOp()) {
       work_list_.push_back(inst);
     }
   }
@@ -248,7 +248,7 @@ class QuantizationDriver {
       return;
     }
     QuantParams params =
-        quant::QuantizedType::getQuantizedElementType(in->getType());
+        quant::QuantizedType::getQuantizedElementType(in.getType());
     bool immutable = !EmptyParams(params);
     int next_state_index = states_.size();
     states_.push_back({params, immutable});
@@ -338,7 +338,7 @@ bool QuantizationDriver::IsQuantized(Operation *op) {
 int QuantizationDriver::InitializeState(Operation *op, int index, Value val,
                                         bool as_result) {
   QuantParams params =
-      quant::QuantizedType::getQuantizedElementType(val->getType());
+      quant::QuantizedType::getQuantizedElementType(val.getType());
   bool immutable = !EmptyParams(params);
   int next_state_index = states_.size();
   states_.push_back({params, immutable});
@@ -447,13 +447,13 @@ void QuantizationDriver::QuantizeOpResult(Operation *op, int index,
 }
 void QuantizationDriver::QuantizeArg(BlockArgument arg, QuantParams params) {
-  builder_.setInsertionPointToStart(arg->getOwner());
+  builder_.setInsertionPointToStart(arg.getOwner());
   QuantizeValue(arg, params, builder_.getUnknownLoc());
 }
 void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
                                        Location loc) {
-  Type expressed_type = value->getType();
+  Type expressed_type = value.getType();
   Type new_type = params.castFromExpressedType(expressed_type);
   // This value isn't an expressed type (float), skip.
   if (!new_type) return;
@@ -465,7 +465,7 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
                                          quantize.output());
   // `original_result` has a use to `quantize`, so this will replace that use
   // by the result of `dequantize`. Remember to reset that use afterwards
-  value->replaceAllUsesWith(dequantize);
+  value.replaceAllUsesWith(dequantize);
   quantize.getOperation()->replaceUsesOfWith(dequantize, value);
 }
@@ -475,7 +475,7 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
   builder_.setInsertionPointAfter(op);
   Value value = op->getResult(index);
   if (state->pos == RequantizeState::ON_OUTPUT) {
-    Operation *user = value->getUses().begin().getUser();
+    Operation *user = value.getUses().begin().getUser();
     if (llvm::isa<TFL::QuantizeOp>(user)) {
       // The requantize op is inserted between `quantize` and `dequantize` ops.
       value = user->getResult(0);
@@ -488,12 +488,12 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
 void QuantizationDriver::RequantizeArg(BlockArgument arg,
                                        RequantizeState *state) {
   Value value = arg;
-  builder_.setInsertionPointToStart(arg->getOwner());
-  if (value->hasOneUse()) {
-    auto user = value->use_begin().getUser();
+  builder_.setInsertionPointToStart(arg.getOwner());
+  if (value.hasOneUse()) {
+    auto user = value.use_begin().getUser();
     if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
       value = q.output();
-      builder_.setInsertionPoint(arg->getOwner(), ++Block::iterator(user));
+      builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user));
     }
   }
   RequantizeValue(value, state, builder_.getUnknownLoc());
@@ -503,13 +503,13 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
                                          Location loc) {
   Type new_type;
   if (state->pos == RequantizeState::ON_INPUT) {
-    Type expressed_type = value->getType();
+    Type expressed_type = value.getType();
     // The value needs to be requantized. A Quantize op will be created to use
     // it as the operand and replace its uses.
     new_type = state->params.castFromExpressedType(expressed_type);
   } else {
     Type expressed_type =
-        quant::QuantizedType::castToExpressedType(value->getType());
+        quant::QuantizedType::castToExpressedType(value.getType());
     if (!expressed_type) return;
     // The value needs to be requantized. A Quantize op will be created to use
@@ -522,7 +522,7 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
   TypeAttr type_attr = TypeAttr::get(new_type);
   auto requantize_op =
       builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
-  value->replaceAllUsesWith(requantize_op);
+  value.replaceAllUsesWith(requantize_op);
   requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
 }
@@ -603,7 +603,7 @@ void QuantizationDriver::PreprocessConstantOps() {
     Value value = cst.getResult();
     SmallVector<std::pair<Operation *, int>, 4> bias_users;
     bool used_as_weight = false;
-    for (auto &use : value->getUses()) {
+    for (auto &use : value.getUses()) {
       auto spec = GetQuantSpec(use.getOwner());
       auto biases = spec->biases_params;
       Operation *user = use.getOwner();
@@ -649,8 +649,8 @@ void QuantizationDriver::SetupAllStates() {
     args_.push_back(arg);
     Value value = arg;
     // If the argument is quantized, it should only has one user.
-    if (arg->hasOneUse()) {
-      auto user = value->use_begin().getUser();
+    if (arg.hasOneUse()) {
+      auto user = value.use_begin().getUser();
       if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
         value = q.output();
       }
@@ -666,7 +666,7 @@ void QuantizationDriver::SetupAllStates() {
     for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
       auto operand = op->getOperand(i);
-      if (auto *inst = operand->getDefiningOp()) {
+      if (auto *inst = operand.getDefiningOp()) {
         // If the operand comes from a tfl.dequantize op, we use the quantized
         // input of this tfl.dequantize op to set the state.
         if (auto dq = llvm::dyn_cast<TFL::DequantizeOp>(inst)) {
@@ -681,8 +681,8 @@ void QuantizationDriver::SetupAllStates() {
       // If the result has been quantized, it should only be used by a
       // tfl.quantize op. For this case, we uses the quantized result to
      // create the state and mark it immutable.
-      if (result->hasOneUse()) {
-        auto user = result->use_begin().getUser();
+      if (result.hasOneUse()) {
+        auto user = result.use_begin().getUser();
         if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
           result = q.output();
         }


@@ -70,7 +70,7 @@ class FixedResultUniformScale {
   QuantizedType GetResultQuantizedType(int index) {
     auto op = this->getOperation();
     auto result_type =
-        op->getResult(index)->getType().template cast<TensorType>();
+        op->getResult(index).getType().template cast<TensorType>();
     Builder builder(op->getContext());
     IntegerType storage_type = builder.getIntegerType(BitWidth);
     const double scale = static_cast<double>(ScaleMantissa) *


@@ -367,7 +367,7 @@ ElementsAttr Quantize(Attribute real_value, Type tensor_type) {
 static bool PreferResultScale(Operation* op) {
   int float_operands = 0;
   for (auto operand : op->getOperands()) {
-    if (auto operand_type = operand->getType().dyn_cast<ShapedType>()) {
+    if (auto operand_type = operand.getType().dyn_cast<ShapedType>()) {
       if (operand_type.getElementType().isa<FloatType>()) {
         if (float_operands++ > 1) return true;
       }
@@ -400,22 +400,22 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
     quant::StatisticsOp stats_op = all_stats_ops.back();
     all_stats_ops.pop_back();
-    if (auto def = stats_op.arg()->getDefiningOp()) {
+    if (auto def = stats_op.arg().getDefiningOp()) {
       if (IsStatsRedundant(def, op_quant_spec_getter)) {
         redundant_stats_ops.insert(stats_op);
       }
     }
-    for (auto user : stats_op.getResult()->getUsers()) {
+    for (auto user : stats_op.getResult().getUsers()) {
       // We don't propagate this parameter down if it has multiple operands.
       // We want to use the result parameter scales instead.
       if (user->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
          !PreferResultScale(user)) {
        for (Value res : user->getResults()) {
-          if (res->hasOneUse()) {
+          if (res.hasOneUse()) {
            if (auto next_stats = llvm::dyn_cast<quant::StatisticsOp>(
-                    *res->getUsers().begin())) {
+                    *res.getUsers().begin())) {
             // quantization parameters can be propagated to next_stats
             redundant_stats_ops.insert(next_stats);
             // add next_stats to the work list so propagation can
@@ -440,12 +440,12 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
     quant::StatisticsOp stats_op = all_stats_ops.back();
     all_stats_ops.pop_back();
-    if (auto def = stats_op.arg()->getDefiningOp()) {
+    if (auto def = stats_op.arg().getDefiningOp()) {
       if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
           PreferResultScale(def)) {
         for (auto input : def->getOperands()) {
           if (auto next_stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(
-                  input->getDefiningOp())) {
+                  input.getDefiningOp())) {
             redundant_stats_ops.insert(next_stats);
             all_stats_ops.push_back(next_stats);
           }
@@ -458,7 +458,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
   for (auto it : redundant_stats_ops) {
     if (!llvm::isa<quant::StatisticsOp>(it)) return true;
     auto stats_op = llvm::cast<quant::StatisticsOp>(it);
-    stats_op.getResult()->replaceAllUsesWith(stats_op.arg());
+    stats_op.getResult().replaceAllUsesWith(stats_op.arg());
     stats_op.erase();
   }


@@ -116,7 +116,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
     auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg(),
                                 TypeAttr::get(result_type));
     auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
-    op.getResult()->replaceAllUsesWith(dq);
+    op.getResult().replaceAllUsesWith(dq);
     q.getOperation()->replaceUsesOfWith(dq, op.arg());
     op.erase();
@@ -162,7 +162,7 @@ struct QuantizationPattern : public RewritePattern {
       return matchFailure();
     }
     Value quantized_value = op->getResult(0);
-    for (Operation* quantized_op : quantized_value->getUsers()) {
+    for (Operation* quantized_op : quantized_value.getUsers()) {
       // If it is requantize op, we shouldn't rewrite this op.
       if (llvm::isa<Q>(quantized_op) || llvm::isa<DQ>(quantized_op)) {
         return matchFailure();
@@ -179,14 +179,14 @@ struct QuantizationPattern : public RewritePattern {
       SmallVector<Value, 4> inputs;
       inputs.reserve(quantized_op->getNumOperands());
       for (auto operand : quantized_op->getOperands()) {
-        Type operand_type = operand->getType();
+        Type operand_type = operand.getType();
         if (operand_type.isa<NoneType>()) {
           inputs.push_back(operand);
           continue;
         }
-        auto ele_type = operand->getType().cast<TensorType>().getElementType();
-        if (auto op_inst = dyn_cast_or_null<DQ>(operand->getDefiningOp())) {
+        auto ele_type = operand.getType().cast<TensorType>().getElementType();
+        if (auto op_inst = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
          inputs.push_back(op_inst.input());
        } else if (ele_type.isa<IntegerType>()) {
          // If the operand is an integer tensor, then it doesn't require the
@@ -207,7 +207,7 @@ struct QuantizationPattern : public RewritePattern {
       for (auto enumerated_result :
            llvm::enumerate(quantized_op->getResults())) {
         Value result = enumerated_result.value();
-        Type result_type = result->getType();
+        Type result_type = result.getType();
         // Add this to the test coverage once we create test ops with none type
         // results.
         if (result_type.isa<NoneType>()) {
@@ -216,20 +216,20 @@ struct QuantizationPattern : public RewritePattern {
           continue;
         }
         Type result_ele_type =
-            result->getType().cast<TensorType>().getElementType();
+            result.getType().cast<TensorType>().getElementType();
         // If the user is the Quantize op, it must be the only user.
-        if (result->hasOneUse() && llvm::isa<Q>(*result->user_begin())) {
-          auto user = llvm::cast<Q>(*result->user_begin());
+        if (result.hasOneUse() && llvm::isa<Q>(*result.user_begin())) {
+          auto user = llvm::cast<Q>(*result.user_begin());
           outputs_replaced.insert({user.output(), enumerated_result.index()});
           output_types.push_back(user.getType());
         } else if (result_ele_type.template isa<IntegerType>()) {
           // If the result is an integer tensor, then it doesn't require the
           // D op in the pattern.
           outputs_replaced.insert({result, enumerated_result.index()});
-          output_types.push_back(result->getType());
+          output_types.push_back(result.getType());
         } else if (static_cast<const ConcretTy*>(this)->AllowHybridResult()) {
           outputs_replaced.insert({result, enumerated_result.index()});
-          output_types.push_back(result->getType());
+          output_types.push_back(result.getType());
         } else {
           return matchFailure();
         }
@@ -241,7 +241,7 @@ struct QuantizationPattern : public RewritePattern {
                              output_types, quantized_op->getAttrs());
       Operation* new_op = rewriter.createOperation(new_state);
       for (auto output : outputs_replaced) {
-        output.getFirst()->replaceAllUsesWith(
+        output.getFirst().replaceAllUsesWith(
             new_op->getResult(output.getSecond()));
       }
@ -252,7 +252,7 @@ struct QuantizationPattern : public RewritePattern {
// For constant operands, the floating-point constant is duplicated in // For constant operands, the floating-point constant is duplicated in
// case it is quantized. // case it is quantized.
for (int i = 0, e = new_op->getNumOperands(); i != e; ++i) { for (int i = 0, e = new_op->getNumOperands(); i != e; ++i) {
auto def = new_op->getOperand(i)->getDefiningOp(); auto def = new_op->getOperand(i).getDefiningOp();
if (auto q = llvm::dyn_cast_or_null<Q>(def)) { if (auto q = llvm::dyn_cast_or_null<Q>(def)) {
DenseFPElementsAttr attr; DenseFPElementsAttr attr;
if (!matchPattern(q.input(), m_Constant(&attr))) { if (!matchPattern(q.input(), m_Constant(&attr))) {
@ -265,7 +265,7 @@ struct QuantizationPattern : public RewritePattern {
for (int i = 0, e = new_op->getNumResults(); i != e; ++i) { for (int i = 0, e = new_op->getNumResults(); i != e; ++i) {
if (!quantized_op->getResult(i) if (!quantized_op->getResult(i)
->getType() .getType()
.cast<ShapedType>() .cast<ShapedType>()
.getElementType() .getElementType()
.isa<FloatType>()) { .isa<FloatType>()) {
@ -283,13 +283,13 @@ struct QuantizationPattern : public RewritePattern {
// Find the Dequantize/Dequantize users of the new op results, and // Find the Dequantize/Dequantize users of the new op results, and
// replace the usage. Then all the floating-point ops are connected. // replace the usage. Then all the floating-point ops are connected.
// N.B. the return op will use this floating-point result. // N.B. the return op will use this floating-point result.
for (auto user : new_op->getResult(i)->getUsers()) { for (auto user : new_op->getResult(i).getUsers()) {
// Skip the Requantize op, and we know it has a single user. // Skip the Requantize op, and we know it has a single user.
if (llvm::isa<Q>(user)) { if (llvm::isa<Q>(user)) {
user = *user->getResult(0)->getUsers().begin(); user = *user->getResult(0).getUsers().begin();
} }
if (auto dequantize = llvm::dyn_cast<DQ>(user)) { if (auto dequantize = llvm::dyn_cast<DQ>(user)) {
dequantize.getResult()->replaceAllUsesWith( dequantize.getResult().replaceAllUsesWith(
quantized_op->getResult(i)); quantized_op->getResult(i));
} }
} }
@ -316,7 +316,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
PatternMatchResult matchAndRewrite(Q op, PatternMatchResult matchAndRewrite(Q op,
PatternRewriter& rewriter) const override { PatternRewriter& rewriter) const override {
Type output_type = op.output()->getType(); Type output_type = op.output().getType();
auto qtype = QType::getQuantizedElementType(output_type); auto qtype = QType::getQuantizedElementType(output_type);
if (!qtype || qtype.isSigned()) return this->matchFailure(); if (!qtype || qtype.isSigned()) return this->matchFailure();
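
The hunks above are representative of the entire change: mlir::Value is now a proper value type, so handles to operands and results are copied and accessed with '.', while Operation* keeps pointer syntax. A minimal sketch of the new API surface, not taken from this commit and assuming op is a non-null Operation* with at least one result:

    // Sketch: '.' access on Value, '->' access on Operation*.
    mlir::Type FirstResultType(mlir::Operation* op) {
      mlir::Value v = op->getResult(0);  // Operation* still uses '->'
      if (v.hasOneUse())                 // was: v->hasOneUse()
        (void)*v.user_begin();           // was: v->user_begin()
      return v.getType();                // was: v->getType()
    }

Because Value is now cheap to copy, it is also passed by value rather than by pointer throughout the files below.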


@@ -103,7 +103,7 @@ static int PrintFunctionResultMapping(const std::string &result,
     i = 0;
     for (auto output : *subgraph->outputs()) {
       print_buffer(*subgraph, i, output, [&](int i) {
-        return terminator ? terminator->getOperand(i)->getLoc() : unknown_loc;
+        return terminator ? terminator->getOperand(i).getLoc() : unknown_loc;
       });
     }
   }


@@ -205,7 +205,7 @@ struct OphintCompositeOp {
       Operation* current_identity_op = operand.ops.begin()->second;
       Value input = current_identity_op->getOperand(0);
       RankedTensorType input_type =
-          input->getType().cast<RankedTensorType>();
+          input.getType().cast<RankedTensorType>();
       // The Reshape will be {1, (original_shape)}
       SmallVector<int64_t, 4> reshape_op_shape;
       reshape_op_shape.push_back(1);
@@ -242,13 +242,13 @@ struct OphintCompositeOp {
       }
       // Find the first op that consumes the last value of the aggregated
       // inputs.
-      Operation* first_use = *(packed_input_consumers.back()->user_begin());
+      Operation* first_use = *(packed_input_consumers.back().user_begin());
       // The pack reshape will be {N, (original_shape)}
       SmallVector<int64_t, 4> pack_shape;
       pack_shape.push_back(pack_input_operands.size());
       RankedTensorType type = operand.ops.at(0)
                                   ->getResult(0)
-                                  ->getType()
+                                  .getType()
                                   .cast<RankedTensorType>();
       for (const auto& dim : type.getShape()) {
         pack_shape.push_back(dim);
@@ -290,7 +290,7 @@ struct OphintCompositeOp {
         const int output_numer = operand.ops.size();
         Value first_output = operand.ops.at(0)->getOperand(0);
         RankedTensorType first_output_type =
-            first_output->getType().cast<RankedTensorType>();
+            first_output.getType().cast<RankedTensorType>();
         // The aggregated output shape will be {N, original_shape}.
         SmallVector<int64_t, 4> shape;
         shape.push_back(output_numer);
@@ -302,10 +302,10 @@ struct OphintCompositeOp {
       } else if (operand.aggregation == kStrategyLast) {
         Value last_output =
             operand.ops.at(operand.ops.size() - 1)->getOperand(0);
-        aggregated_output_types[kv.first] = last_output->getType();
+        aggregated_output_types[kv.first] = last_output.getType();
       } else {
         Value first_output = operand.ops.at(0)->getOperand(0);
-        aggregated_output_types[kv.first] = first_output->getType();
+        aggregated_output_types[kv.first] = first_output.getType();
       }
     }
     return aggregated_output_types;
@@ -329,7 +329,7 @@ struct OphintCompositeOp {
       Operation* first_output = operand.ops.at(0);
       Location insert_loc = first_output->getLoc();
       SmallVector<Type, 4> unpack_output_types(
-          output_number, first_output->getOperand(0)->getType());
+          output_number, first_output->getOperand(0).getType());
       builder->setInsertionPoint(first_output);
       Operation* unpack_op = builder->create<TFL::UnpackOp>(
@@ -404,7 +404,7 @@ void PreprocessTopoSortGraph(
     // should only count as one.
     llvm::DenseSet<Operation*> input_ops;
     for (int i = 0; i < op.getNumOperands(); ++i) {
-      Operation* input_op = op.getOperand(i)->getDefiningOp();
+      Operation* input_op = op.getOperand(i).getDefiningOp();
       if (input_op) input_ops.insert(input_op);
     }
     if (input_ops.empty()) {
@@ -515,7 +515,7 @@ Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type,
   SmallVector<int, 4> input_indexes;
   for (const auto& kv : inputs) {
     Value input = kv.second;
-    input_types.push_back(input->getType());
+    input_types.push_back(input.getType());
     input_values.push_back(input);
     input_indexes.push_back(kv.first);
   }
@@ -589,7 +589,7 @@ llvm::DenseSet<Operation*> BfsForReachableOps(ArrayRef<Operation*> input_ops) {
   std::queue<Operation*> ops_queue;
   for (auto& input_op : input_ops) {
     for (Value value : input_op->getOperands()) {
-      Operation* op = value->getDefiningOp();
+      Operation* op = value.getDefiningOp();
       if (op != nullptr) ops_queue.push(op);
     }
   }
@@ -599,7 +599,7 @@ llvm::DenseSet<Operation*> BfsForReachableOps(ArrayRef<Operation*> input_ops) {
     ops_queue.pop();
     reachable_ops.insert(current_op);
     for (Value value : current_op->getOperands()) {
-      Operation* upstream_op = value->getDefiningOp();
+      Operation* upstream_op = value.getDefiningOp();
       // Not visited, put it into the queue.
       if (upstream_op != nullptr &&
           !llvm::is_contained(reachable_ops, upstream_op)) {
@@ -642,7 +642,7 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
       aggregated_inputs, aggregated_output_types, builder, module_op);
   for (const auto& kv : aggregated_inputs) {
-    Operation* op = kv.second->getDefiningOp();
+    Operation* op = kv.second.getDefiningOp();
     if (op == nullptr) return failure();
     op->moveBefore(fused_op);
   }


@@ -103,7 +103,7 @@ LogicalResult BuildUnidirectionalSequenceRnnOp(FuncOp composite_func_op,
   Value hidden_state = call_op.getOperand(4);
   // Build Output.
-  auto output_type = call_op.getResult(0)->getType();
+  auto output_type = call_op.getResult(0).getType();
   // Currently, ophinted RNN only supports time_major = True.
   const bool time_major = true;
@@ -170,11 +170,11 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op,
     for (int i = 0; i < call_op.getNumResults() - 1; ++i) {
       // This one should not be used.
       Value unused_output = call_op.getResult(i);
-      if (!unused_output->use_empty()) return failure();
+      if (!unused_output.use_empty()) return failure();
     }
   }
   output_types.push_back(
-      call_op.getResult(call_op.getNumResults() - 1)->getType());
+      call_op.getResult(call_op.getNumResults() - 1).getType());
   // Prepare attributes.
   SmallVector<NamedAttribute, 4> attributes;
@@ -207,10 +207,10 @@ LogicalResult ConvertTfLiteFusedOpIfAvailable(StringRef func_name,
         composite_func_op, call_op, builder, &fused_op);
     if (failed(build_fused_op_result)) return build_fused_op_result;
     Value call_output = call_op.getResult(call_op.getNumResults() - 1);
-    if (call_output->getType() != fused_op->getResult(0)->getType()) {
+    if (call_output.getType() != fused_op->getResult(0).getType()) {
       return failure();
     }
-    call_output->replaceAllUsesWith(fused_op->getResult(0));
+    call_output.replaceAllUsesWith(fused_op->getResult(0));
   } else {  // If we support more fused op, we should add the conversion here.
     return failure();
   }
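
One detail in the hunk above is easy to miss: the rewrite only fires when the fused op's result type matches the ophinted call's result type exactly, because replaceAllUsesWith rewires uses without inserting casts. A hedged sketch of that guard as a reusable helper (the helper name is hypothetical, not part of this file):

    // Hypothetical helper: rewire all uses of `from` to `to`, but only
    // when the types agree, since replaceAllUsesWith inserts no casts.
    static mlir::LogicalResult ReplaceIfSameType(mlir::Value from, mlir::Value to) {
      if (from.getType() != to.getType()) return mlir::failure();
      from.replaceAllUsesWith(to);  // every consumer now reads `to`
      return mlir::success();
    }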


@@ -39,7 +39,7 @@ def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
 // Use the tensor type information from $0 and convert min $1, max $2 and
 // numBits $3 and narrowRange $4 to a QuantizedType.
 def ConvertToQuantTypeFromAttrs : NativeCodeCall<
-    "GetQuantizedTypeAttr($_builder, $0->getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">;
+    "GetQuantizedTypeAttr($_builder, $0.getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">;
 // Converts an integer attribute $0 to 32-bit with builder.
 def convertIntAttrTo32Bit : NativeCodeCall<
@@ -50,7 +50,7 @@ def ExtractSingleElementAsInteger : NativeCodeCall<
   "ExtractSingleElementAsInteger($_self.cast<ElementsAttr>())">;
 // Checks whether the given operation has static shapes and same shapes of all inputs.
-def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0->getDefiningOp())">;
+def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0.getDefiningOp())">;
 def HasSameStaticShapes : Constraint<HasSameStaticShapesPred, "op must have static same input shapes">;
 def HasNotSameStaticShapes : Constraint<Neg<HasSameStaticShapesPred>, "op must have not static same input shapes">;
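
The TableGen edits track the C++ ones because a $0 bound in a CPred is spliced into generated C++ as a captured mlir::Value. Roughly, and only as a sketch of the generated shape rather than the exact output of mlir-tblgen, the predicate above now expands to something like:

    // Hypothetical expansion of HasSameStaticShapesPred: `$0` arrives as an
    // mlir::Value, so member access uses '.'; getDefiningOp() may return
    // nullptr when the operand is a block argument.
    static bool HasSameStaticShapesPredImpl(mlir::Value operand) {
      return HasSameStaticShapes(operand.getDefiningOp());
    }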


@@ -72,7 +72,7 @@ bool HasSameStaticShapes(Operation* op) {
   int index = 0;
   ArrayRef<int64_t> shape;
   for (Value value : values) {
-    auto shaped_type = value->getType().dyn_cast<ShapedType>();
+    auto shaped_type = value.getType().dyn_cast<ShapedType>();
     if (!shaped_type && !shaped_type.hasStaticShape()) {
       return false;
     }
@@ -122,7 +122,7 @@ PatternMatchResult ConvertTFConcatOp::matchAndRewrite(
   auto tf_concat_op = cast<TF::ConcatOp>(op);
   auto values = tf_concat_op.values();
-  auto output_type = tf_concat_op.output()->getType();
+  auto output_type = tf_concat_op.output().getType();
   // Extract axis attribute from constant concat_dims tensor
   ElementsAttr axis;
   if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis)))
@@ -141,7 +141,7 @@ PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite(
   auto tf_concat_op = cast<TF::ConcatV2Op>(op);
   auto values = tf_concat_op.values();
-  auto output_type = tf_concat_op.output()->getType();
+  auto output_type = tf_concat_op.output().getType();
   // Extract axis attribute from constant axis tensor
   ElementsAttr axis;
   if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis)))
@@ -167,7 +167,7 @@ PatternMatchResult ConvertTFMatMulOp::matchAndRewrite(
   if (tf_matmul_op.transpose_a()) return matchFailure();
   if (!tf_matmul_op.transpose_b()) return matchFailure();
-  Type output_type = tf_matmul_op.getResult()->getType();
+  Type output_type = tf_matmul_op.getResult().getType();
   // TODO(jpienaar): Follow up post shuffle discussion.
   auto no_input = rewriter.create<ConstantOp>(
       op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
@@ -184,7 +184,7 @@ PatternMatchResult ConvertTFPackOp::matchAndRewrite(
   auto tf_pack_op = cast<TF::PackOp>(op);
   SmallVector<Value, 4> values(tf_pack_op.values());
-  auto output_type = tf_pack_op.output()->getType();
+  auto output_type = tf_pack_op.output().getType();
   auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N());
   // Axis can be negative.
   auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis().getSExtValue());
@@ -201,7 +201,7 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
   auto input = tf_reshape_op.tensor();
   auto shape = tf_reshape_op.shape();
-  ShapedType shape_type = shape->getType().cast<ShapedType>();
+  ShapedType shape_type = shape.getType().cast<ShapedType>();
   // The tfl reshape's #2 operand needs to i32 tensor type, so we have to cast.
   if (!shape_type.getElementType().isInteger(32)) {
     auto new_shape = shape_type.getShape();
@@ -213,7 +213,7 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
                 rewriter.getBoolAttr(false))
             .y();
   }
-  rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output()->getType(),
+  rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output().getType(),
                                          input, shape);
   return matchSuccess();
 }
@@ -222,7 +222,7 @@ PatternMatchResult ConvertTFSplitOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tf_split_op = cast<TF::SplitOp>(op);
-  auto output_types = functional::map([](Value v) { return v->getType(); },
+  auto output_types = functional::map([](Value v) { return v.getType(); },
                                       tf_split_op.output());
   // Number of splits cannot be negative.
   auto num_split = rewriter.getI32IntegerAttr(tf_split_op.num_split());
@@ -237,7 +237,7 @@ PatternMatchResult ConvertTFSplitVOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tf_splitv_op = cast<TF::SplitVOp>(op);
-  auto output_types = functional::map([](Value v) { return v->getType(); },
+  auto output_types = functional::map([](Value v) { return v.getType(); },
                                       tf_splitv_op.output());
   // Number of splits cannot be negative.
   auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.num_split());
@@ -254,7 +254,7 @@ Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
   DenseIntElementsAttr dense_elem_attr;
   SmallVector<int32_t, 8> padded_val;
-  auto ranked_attr_type = attribute->getType().dyn_cast<RankedTensorType>();
+  auto ranked_attr_type = attribute.getType().dyn_cast<RankedTensorType>();
   if (!ranked_attr_type ||
       !matchPattern(attribute, m_Constant(&dense_elem_attr))) {
     // If the input attribute is neither ranked type nor constant, we
@@ -280,14 +280,14 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tf_strided_slice_op = cast<TF::StridedSliceOp>(op);
   auto ranked_input_type =
-      tf_strided_slice_op.input()->getType().dyn_cast<RankedTensorType>();
+      tf_strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
   if (!ranked_input_type) {
     // If input is not a ranked tensor, we can't deduce the padding dimensions
     // from it, so we just do a plain conversion here.
     rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
-        op, tf_strided_slice_op.output()->getType(),
-        tf_strided_slice_op.input(), tf_strided_slice_op.begin(),
-        tf_strided_slice_op.end(), tf_strided_slice_op.strides(),
+        op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
+        tf_strided_slice_op.begin(), tf_strided_slice_op.end(),
+        tf_strided_slice_op.strides(),
         rewriter.getI32IntegerAttr(
             tf_strided_slice_op.begin_mask().getSExtValue()),
         rewriter.getI32IntegerAttr(
@@ -318,7 +318,7 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite(
   Value padded_strides = PadStridedSliceAttributeArray(
       op, rewriter, tf_strided_slice_op.strides(), strides_pad_val, nullptr);
   rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
-      op, tf_strided_slice_op.output()->getType(), tf_strided_slice_op.input(),
+      op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
       padded_begin, padded_end, padded_strides,
       rewriter.getI32IntegerAttr(begin_mask),
      rewriter.getI32IntegerAttr(end_mask),
@@ -336,7 +336,7 @@ PatternMatchResult ConvertTFUnpackOp::matchAndRewrite(
   auto tf_unpack_op = cast<TF::UnpackOp>(op);
   auto input = tf_unpack_op.value();
-  auto output_types = functional::map([](Value v) { return v->getType(); },
+  auto output_types = functional::map([](Value v) { return v.getType(); },
                                       tf_unpack_op.output());
   auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num());
   // Axis can be negative.
@@ -360,7 +360,7 @@ bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) {
   if (tf_matrix_diag_v2_or_v3_op.getNumOperands() != 5) return false;
   auto input = tf_matrix_diag_v2_or_v3_op.diagonal();
-  auto output_type = tf_matrix_diag_v2_or_v3_op.output()->getType();
+  auto output_type = tf_matrix_diag_v2_or_v3_op.output().getType();
   // Extract k constant tensor and check value = 0.
   ElementsAttr k;
@@ -500,7 +500,7 @@ PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite(
   auto status_or_const_op = CreateConstOpWithSingleValue(
       &rewriter, op->getLoc(),
-      tf_reciprocal_op.x()->getType().cast<ShapedType>(), 1);
+      tf_reciprocal_op.x().getType().cast<ShapedType>(), 1);
   if (!status_or_const_op.ok()) {
     return matchFailure();
   }
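
Several hunks in this file map a lambda over a multi-result op to collect result types; only the lambda body changes, since the callback already took Value by value. A hand-rolled equivalent of those functional::map calls, as a sketch with illustrative names:

    // Sketch: gather one Type per result Value of a multi-result op.
    llvm::SmallVector<mlir::Type, 4> output_types;
    for (mlir::Value v : tf_split_op.output())
      output_types.push_back(v.getType());  // was: v->getType()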


@@ -71,7 +71,7 @@ struct LoadQuantizationRecipe : public FunctionPass<LoadQuantizationRecipe> {
 void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
   Type expressed_type =
-      lstm.input()->getType().cast<ShapedType>().getElementType();
+      lstm.input().getType().cast<ShapedType>().getElementType();
   Type int8_storage_type = builder->getIntegerType(8);
   Type int16_storage_type = builder->getIntegerType(16);
   auto flag = quant::QuantizationFlags::FlagValue::Signed;
@@ -88,8 +88,8 @@ void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
   auto any_int16 = quant::AnyQuantizedType::get(
       flag, int16_storage_type, expressed_type, int16_min, int16_max);
-  int8 = any_int8.castFromExpressedType(lstm.input()->getType());
-  int16 = any_int16.castFromExpressedType(lstm.input()->getType());
+  int8 = any_int8.castFromExpressedType(lstm.input().getType());
+  int16 = any_int16.castFromExpressedType(lstm.input().getType());
 }

 Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value in,


@@ -196,13 +196,13 @@ struct ConvertTensorListSetItem : public ConversionPattern {
     // Calculate `index` + 1, which is used to generate the start position for
     // the second slice op.
     auto suffix_start =
-        rewriter.create<TF::AddOp>(loc, index->getType(), index,
+        rewriter.create<TF::AddOp>(loc, index.getType(), index,
                                    CreateI32SplatConst(loc, &rewriter, {}, 1));
     auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
         loc, RankedTensorType::get({1}, shape_dtype), item_rank, scalar_zero);
     // Create two slice ops.
-    Type element_type = input->getType().cast<TensorType>().getElementType();
+    Type element_type = input.getType().cast<TensorType>().getElementType();
     UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type);
     Value scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1);
     TF::SliceOp slice1 =
@@ -225,7 +225,7 @@ struct ConvertTensorListSetItem : public ConversionPattern {
     // Concatenate three parts together to generate the final result.
     rewriter.replaceOpWithNewOp<TF::ConcatOp>(
-        op, input->getType(), scalar_zero,
+        op, input.getType(), scalar_zero,
         ArrayRef<Value>({slice1, expanded_item, slice2}));
     return matchSuccess();
   }
@@ -264,7 +264,7 @@ struct ConvertTensorListInitOp : public ConversionPattern {
     }
     Value element_shape = operands[0];
-    Type shape_dtype = getElementTypeOrSelf(element_shape->getType());
+    Type shape_dtype = getElementTypeOrSelf(element_shape.getType());
     DenseIntElementsAttr dense_elem_attr;
     if (matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
@@ -297,11 +297,10 @@ struct ConvertTensorListInitOp : public ConversionPattern {
         new_element_shape_values.push_back(dim_value);
       }
-      auto attr =
-          DenseIntElementsAttr::get(element_shape->getType().cast<ShapedType>(),
-                                    new_element_shape_values);
+      auto attr = DenseIntElementsAttr::get(
+          element_shape.getType().cast<ShapedType>(), new_element_shape_values);
       auto new_element_shape = rewriter.create<ConstantOp>(
-          op.getLoc(), element_shape->getType(), attr);
+          op.getLoc(), element_shape.getType(), attr);
       element_shape = new_element_shape;
     }
@@ -355,7 +354,7 @@ struct ConvertTensorListReserve
   Value GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value> operands,
                        PatternRewriter *rewriter) const override {
     Value scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0);
-    Type shape_dtype = getElementTypeOrSelf(op.element_shape()->getType());
+    Type shape_dtype = getElementTypeOrSelf(op.element_shape().getType());
     Value num_elements = operands[1];
     return rewriter->create<TF::ExpandDimsOp>(
         op.getLoc(), RankedTensorType::get({1}, shape_dtype), num_elements,
@@ -392,14 +391,14 @@ struct ConvertTensorListPushBack : public ConversionPattern {
     // Expand the shape of the item so that it will have rank same as the input
     // tensor and it is compatible for the Concat Op.
     Type expanded_item_type =
-        PrependLeadingDimIfRanked(1, item->getType(), &rewriter);
+        PrependLeadingDimIfRanked(1, item.getType(), &rewriter);
     Value scalar_zero = CreateI32SplatConst(op->getLoc(), &rewriter, {}, 0);
     auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
         op->getLoc(), expanded_item_type, item, scalar_zero);
     Type elem_type = getElementTypeOrSelf(item);
     auto handle_dtype =
-        getElementTypeOrSelf(push_back_op.output_handle()->getType())
+        getElementTypeOrSelf(push_back_op.output_handle().getType())
             .cast<TF::VariantType>();
     Type result_type =
         GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
@@ -446,7 +445,7 @@ struct ConvertTensorListResize : public ConversionPattern {
     // Infer result type of this op based on TF's shape inference result.
     Type elem_type = getElementTypeOrSelf(input_handle);
     auto handle_dtype =
-        getElementTypeOrSelf(resize_op.output_handle()->getType())
+        getElementTypeOrSelf(resize_op.output_handle().getType())
             .cast<TF::VariantType>();
     Type result_type =
         GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
@@ -463,8 +462,8 @@ struct ConvertTensorListResize : public ConversionPattern {
     auto input_shape = rewriter.create<TF::ShapeOp>(
         loc, RankedTensorType::get({-1}, shape_dtype), input_handle);
-    Type branch_args_type[] = {input_handle->getType(), input_shape.getType(),
-                               size_diff.getType(), size->getType()};
+    Type branch_args_type[] = {input_handle.getType(), input_shape.getType(),
+                               size_diff.getType(), size.getType()};
     Type branch_result_type[] = {result_type};
     auto func_type = FunctionType::get(branch_args_type, branch_result_type,
                                        rewriter.getContext());
@@ -524,7 +523,7 @@ struct ConvertTensorListResize : public ConversionPattern {
         loc, RankedTensorType::get({-1}, shape_dtype), input_shape, slice_start,
         slice_size);
     auto extended_part = rewriter->create<TF::TensorListReserveOp>(
-        loc, resize_op.output_handle()->getType(), elem_shape, size_diff);
+        loc, resize_op.output_handle().getType(), elem_shape, size_diff);
     // `ConcatOp` expects non-variant-typed input. Insert a
     // `TensorListStackOp` here to convert type from variant to non-variant.
     // Note that we are using the same `result_type` for both the
@@ -627,7 +626,7 @@ struct ConvertTensorListStack : public ConversionPattern {
     // trivial Reshape op (that doesn't actually change the input's shape) and
     // also populate the shape info to the op result. The shape of the
     // tensorlist is inferred from `num_elements` and `element_shape`.
-    auto ranked_type = element_shape->getType().dyn_cast<RankedTensorType>();
+    auto ranked_type = element_shape.getType().dyn_cast<RankedTensorType>();
     DenseIntElementsAttr dense_elem_attr;
     if ((ranked_type && ranked_type.getRank() == 0) ||
         !matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
@@ -659,7 +658,7 @@ struct ConvertIdentity : public ConversionPattern {
       ConversionPatternRewriter &rewriter) const override {
     auto op = llvm::cast<TF::IdentityOp>(operation);
     Value input = operands[0];
-    rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input->getType(), operands,
+    rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), operands,
                                                 op.getAttrs());
     return matchSuccess();
   }
@@ -687,7 +686,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
       Type arg_type = func_type.getInput(i);
       if (getElementTypeOrSelf(arg_type).isa<TF::VariantType>()) {
         arg_type = UnrankedTensorType::get(
-            getElementTypeOrSelf(op.getOperand(i)->getType()));
+            getElementTypeOrSelf(op.getOperand(i).getType()));
       }
       updated_argument_types.push_back(arg_type);
     }
@@ -703,7 +702,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
         // from the corresponding input operand. This is correct because while
         // body's inputs and results have the same type.
         result_type = UnrankedTensorType::get(
-            getElementTypeOrSelf(op.getOperand(i)->getType()));
+            getElementTypeOrSelf(op.getOperand(i).getType()));
       }
       updated_result_types.push_back(result_type);
     }
@@ -717,7 +716,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
     // Change the argument type for the first block.
     Block &body_first_bb = func.front();
     for (int i = 0; i < body_first_bb.getNumArguments(); ++i) {
-      body_first_bb.getArgument(i)->setType(updated_argument_types[i]);
+      body_first_bb.getArgument(i).setType(updated_argument_types[i]);
     }
   }
   return success();
@@ -735,12 +734,12 @@ struct ConvertWhile : public ConversionPattern {
     llvm::SmallVector<Type, 8> result_types;
     result_types.reserve(op.getNumOperands());
     for (int i = 0, e = operands.size(); i != e; ++i) {
-      Type result_ty = op.getResult(i)->getType();
+      Type result_ty = op.getResult(i).getType();
       // If we notice the result type is a DT_VARIANT, we change the
       // corresponding result type to unranked tensor type.
       if (getElementTypeOrSelf(result_ty).isa<TF::VariantType>()) {
-        Type element_ty = getElementTypeOrSelf(operands[i]->getType());
+        Type element_ty = getElementTypeOrSelf(operands[i].getType());
         result_ty = UnrankedTensorType::get(element_ty);
       }
       result_types.push_back(result_ty);
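
BlockArgument made the same pointer-to-value transition as Value, which is why UpdateFunctionTypes above can call setType through '.'. In isolation, retyping a function's entry-block arguments looks roughly like this, assuming new_types has one entry per argument:

    // Sketch: retype entry-block arguments in place.
    mlir::Block& entry = func.front();
    for (unsigned i = 0, e = entry.getNumArguments(); i != e; ++i)
      entry.getArgument(i).setType(new_types[i]);  // was: ->setType(...)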


@@ -51,15 +51,15 @@ namespace TFL {
 namespace {
 bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
-  if (sq_op->getType().cast<ShapedType>().getRank() - 1 ==
+  if (sq_op.getType().cast<ShapedType>().getRank() - 1 ==
           *axis.getValues<int>().begin() ||
       *axis.getValues<int>().begin() == -1) {
     return true;
   }
-  if (sq_op->getType().cast<ShapedType>().getRank() != axis.getNumElements()) {
+  if (sq_op.getType().cast<ShapedType>().getRank() != axis.getNumElements()) {
     return false;
   }
-  auto shape = sq_op->getType().cast<ShapedType>();
+  auto shape = sq_op.getType().cast<ShapedType>();
   SmallVector<int, 4> elems{axis.getValues<int>().begin(),
                             axis.getValues<int>().end()};
   for (int i = 0; i < shape.getRank(); ++i) {
@@ -143,7 +143,7 @@ ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
 // Returns shape of a ranked tensor.
 // Precondition: output_val's is ranked tensor.
 DenseElementsAttr GetShape(Value output_val) {
-  auto output_type = output_val->getType().cast<RankedTensorType>();
+  auto output_type = output_val.getType().cast<RankedTensorType>();
   auto shape_vector = output_type.getShape();
   std::vector<int32_t> shape(shape_vector.size());
   for (int i = 0; i < shape_vector.size(); ++i) {
@@ -152,7 +152,7 @@ DenseElementsAttr GetShape(Value output_val) {
   return mlir::DenseElementsAttr::get(
       RankedTensorType::get(
           {static_cast<int>(shape.size())},
-          mlir::IntegerType::get(32, output_val->getContext())),
+          mlir::IntegerType::get(32, output_val.getContext())),
       llvm::makeArrayRef(shape));
 }
@@ -173,13 +173,13 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
     // Fully Connected.
     auto fc_op =
-        dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs()->getDefiningOp());
+        dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs().getDefiningOp());
     if (!fc_op) return matchFailure();
     Value filter = fc_op.filter();
     Value bias = fc_op.bias();
     ElementsAttr bias_value;
-    const bool is_none_bias = bias->getType().isa<NoneType>();
+    const bool is_none_bias = bias.getType().isa<NoneType>();
     if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value)))
       return matchFailure();
     if (fc_op.fused_activation_function() != "NONE") return matchFailure();
@@ -213,7 +213,7 @@ struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {
   PatternMatchResult matchAndRewrite(TFL::ReluOp relu_op,
                                      PatternRewriter &rewriter) const override {
-    Operation *input = relu_op.getOperand()->getDefiningOp();
+    Operation *input = relu_op.getOperand().getDefiningOp();
     if (!isa_and_nonnull<FullyConnectedOp>(input)) return matchFailure();
     auto fully_connected_op = cast<FullyConnectedOp>(input);
     if (fully_connected_op.fused_activation_function() != "NONE")
@@ -247,13 +247,13 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
     // Fully Connected.
     auto fc_op =
-        dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs()->getDefiningOp());
+        dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs().getDefiningOp());
     if (!fc_op) return matchFailure();
     Value filter = fc_op.filter();
     Value bias = fc_op.bias();
     ElementsAttr cst_tmp;
     if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure();
-    if (!bias->getType().isa<NoneType>() &&
+    if (!bias.getType().isa<NoneType>() &&
         !matchPattern(bias, m_Constant(&cst_tmp)))
       return matchFailure();
     if (fc_op.fused_activation_function().equals("None")) return matchFailure();
@@ -262,7 +262,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
     // filter input. We only support broadcasting the operand along the depth
     // dimension, when the operand's depth is 1.
     Value new_const_val = constant_val;
-    if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter->getType())) {
+    if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter.getType())) {
       auto original_shape = cst.getType().getShape();
       llvm::SmallVector<int64_t, 4> normalized_shape(original_shape.begin(),
                                                      original_shape.end());
@@ -270,7 +270,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
       auto new_cst = cst.reshape(RankedTensorType::get(
           normalized_shape, cst.getType().getElementType()));
       Type new_type = new_cst.getType();
-      if (!IsBroadcastableElementsAttrAndType(new_type, filter->getType())) {
+      if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
         return matchFailure();
       }
       auto new_op =
@@ -285,7 +285,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
     auto new_filter =
         rewriter.create<TF::MulOp>(loc, filter, new_const_val).z();
     // If bias isn't None, it needs to be multiplied as well.
-    if (!bias->getType().isa<NoneType>()) {
+    if (!bias.getType().isa<NoneType>()) {
       bias = rewriter.create<TF::MulOp>(loc, bias, constant_val).z();
     }
@@ -311,7 +311,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
   PatternMatchResult matchAndRewrite(AffineOpType fc_op,
                                      PatternRewriter &rewriter) const override {
     // Binary op.
-    Operation *binary_op = fc_op.input()->getDefiningOp();
+    Operation *binary_op = fc_op.input().getDefiningOp();
     if (!binary_op || binary_op->getNumOperands() != 2)
       return this->matchFailure();
     // We only handle the cases the RHS is a scalar.
@@ -330,15 +330,15 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
     DenseFPElementsAttr filter_cst, bias_cst;
     if (!matchPattern(filter, m_Constant(&filter_cst))) {
       // The filter maybe quantized, then we should set it to the real constant.
-      auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter->getDefiningOp());
+      auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter.getDefiningOp());
       if (!dq) return this->matchFailure();
-      auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input()->getDefiningOp());
+      auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input().getDefiningOp());
       if (!q || !matchPattern(q.input(), m_Constant(&filter_cst))) {
         return this->matchFailure();
       }
       filter = q.input();
     }
-    if (!bias->getType().isa<NoneType>() &&
+    if (!bias.getType().isa<NoneType>() &&
         !matchPattern(bias, m_Constant(&bias_cst)))
       return this->matchFailure();
     ShapedType filter_type = filter_cst.getType();
@@ -362,7 +362,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
     // The new bias should be a 1-D tensor with length equals to the bias
     // dimension of the weight.
     SmallVector<APFloat, 4> new_bias_values;
-    if (bias->getType().isa<NoneType>()) {  // none bias, a list of zeros
+    if (bias.getType().isa<NoneType>()) {  // none bias, a list of zeros
       new_bias_values.resize(bias_size, APFloat(0.0));
     } else if (bias_cst.getNumElements() == 1) {  // scalar bias, broadcast it
       new_bias_values.resize(bias_size, *bias_cst.float_value_begin());
@@ -401,12 +401,12 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
     // We recreate the constant op in case it is shared by the other ops. This
     // might increase the model size.
     auto new_filter_op = rewriter.create<ConstOp>(
-        fc_op.getLoc(), filter->getType(), new_filter);
+        fc_op.getLoc(), filter.getType(), new_filter);
     fc_op.setOperand(0, binary_op->getOperand(0));
     if (fc_op.filter() != filter) {
       // This filter goes through quantize and dequantize ops. Then we just
       // need to update the weight to the quantize op.
-      filter->replaceAllUsesWith(new_filter_op);
+      filter.replaceAllUsesWith(new_filter_op);
     } else {
       // This filter doesn't go through quantize and dequantize ops, Then
       // we update the weight of the affine op directly.
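
All of the fusion patterns above share one idiom: hop from an operand to its producer with getDefiningOp() and filter by op kind. Since getDefiningOp() returns nullptr for block arguments, the null-tolerant cast keeps the match safe at function boundaries. The idiom in isolation, as a sketch:

    // Sketch: match "this operand is produced by a TFL::FullyConnectedOp".
    auto fc_op = llvm::dyn_cast_or_null<TFL::FullyConnectedOp>(
        add_op.lhs().getDefiningOp());
    if (!fc_op) return matchFailure();  // producer absent or a different op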


@@ -55,7 +55,7 @@ foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
 defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;
 // Checks if the value has only one user.
-def HasOneUse : Constraint<CPred<"$0->hasOneUse()">>;
+def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
 // If we see a binary op (add, sub) op adding a constant value to a convolution
 // op with constant bias, we can fuse the binary op into the convolution op by
@@ -161,7 +161,7 @@ def EqualOperands : Constraint<CPred<"$0 == $1">>;
 // Checks if the operand has rank == n
 class OperandHasRank<int n> : Constraint<
-    CPred<"$0->getType().cast<ShapedType>().getRank() == " # n>>;
+    CPred<"$0.getType().cast<ShapedType>().getRank() == " # n>>;
 // Matching HardSwish
 def : Pat<
@@ -256,7 +256,7 @@ foreach L2NormalizePairs = [[TFL_MulOp, TFL_RsqrtOp], [TFL_DivOp, TFL_SqrtOp]]
 in defm : L2NormalizePatterns<L2NormalizePairs[0], L2NormalizePairs[1]>;
 def AreBroadcastableTypes : Constraint<CPred<
-    "TFL::IsBroadcastableElementsAttrAndType($0->getType(), $1->getType())">>;
+    "TFL::IsBroadcastableElementsAttrAndType($0.getType(), $1.getType())">>;
 // Pattern for skipping Tile if it is mainly for broadcasting and the
 // Op is already supporting broadcasting.


@@ -71,29 +71,29 @@ void RemoveQuantizationAdaptorOps(FuncOp func) {
     auto remove_quantize_op = [&](QuantizeOp quantize_op) {
       auto quantize_output = quantize_op.output();
-      auto quantize_type = quantize_output->getType();
+      auto quantize_type = quantize_output.getType();
       input_types.push_back(quantize_type);
       auto new_arg = bb.addArgument(quantize_type);
-      quantize_output->replaceAllUsesWith(new_arg);
+      quantize_output.replaceAllUsesWith(new_arg);
       quantize_op.erase();
-      arg->dropAllUses();
+      arg.dropAllUses();
       bb.eraseArgument(0);
     };
     // This is looking for a pattern: arg -> tfl.quantize
-    if (arg->hasOneUse() && llvm::isa<QuantizeOp>(*arg->user_begin())) {
-      auto quantize_op = llvm::cast<QuantizeOp>(*arg->user_begin());
+    if (arg.hasOneUse() && llvm::isa<QuantizeOp>(*arg.user_begin())) {
+      auto quantize_op = llvm::cast<QuantizeOp>(*arg.user_begin());
       remove_quantize_op(quantize_op);
       continue;
     }
     // Make a copy of current argument and append it to the end of the list if
     // the pattern isn't found.
-    Type arg_type = arg->getType();
+    Type arg_type = arg.getType();
     input_types.push_back(arg_type);
     auto new_arg = bb.addArgument(arg_type);
-    arg->replaceAllUsesWith(new_arg);
-    arg->dropAllUses();
+    arg.replaceAllUsesWith(new_arg);
+    arg.dropAllUses();
     bb.eraseArgument(0);
   }
@@ -103,15 +103,15 @@ void RemoveQuantizationAdaptorOps(FuncOp func) {
   output_types.reserve(num_return_operands);
   for (int i = 0; i != num_return_operands; ++i) {
     auto returned_value = terminator->getOperand(i);
-    Operation* returned_op = returned_value->getDefiningOp();
+    Operation* returned_op = returned_value.getDefiningOp();
     if (returned_op && llvm::isa<DequantizeOp>(returned_op)) {
       auto dequantize_op = llvm::cast<DequantizeOp>(returned_op);
       Value dequantized_result = dequantize_op.input();
-      output_types.push_back(dequantized_result->getType());
+      output_types.push_back(dequantized_result.getType());
       terminator->setOperand(i, dequantized_result);
       returned_op->erase();
     } else {
-      output_types.push_back(returned_value->getType());
+      output_types.push_back(returned_value.getType());
     }
   }
   auto new_func_type = builder.getFunctionType(input_types, output_types);
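
The argument-rewiring sequence above has a fixed order: the replacement argument must exist before uses move to it, and the old argument must be use-free before it can be erased. As a sketch:

    // Sketch: replace block argument 0 with a freshly appended argument.
    auto new_arg = bb.addArgument(arg.getType());  // append the replacement
    arg.replaceAllUsesWith(new_arg);               // rewire every consumer
    arg.dropAllUses();                             // ensure no stale uses remain
    bb.eraseArgument(0);                           // old argument is now dead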


@@ -135,10 +135,10 @@ def : Pat<(TF_ReshapeOp
 // Casts result type of $1 to a quantized type by using the quantization
 // parameters from the type in $0.
 class UpdateShapeWithAxis<int i> : NativeCodeCall<
-    "CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1->getType(), " # i # ")">;
+    "CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1.getType(), " # i # ")">;
 class UsedBy<string op> : Constraint<
-    CPred<"llvm::isa<mlir::TFL::" # op # "Op>(*$0->getUsers().begin())">>;
+    CPred<"llvm::isa<mlir::TFL::" # op # "Op>(*$0.getUsers().begin())">>;
 // When the op is passing-through, the output types of the quantized ops need
 // to be updated as well. Since the quantize op manages its own type by the


@@ -153,7 +153,7 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
                                      params);
       auto dq_op =
           builder.create<TFL::DequantizeOp>(loc, input_type, q_op.output());
-      arg->replaceAllUsesWith(dq_op.output());
+      arg.replaceAllUsesWith(dq_op.output());
       q_op.setOperand(arg);
     }
   }
@@ -161,8 +161,8 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
   for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
     BlockArgument arg = func.getArgument(i);
-    auto* arg_block = arg->getOwner();
-    add_quantize_op(arg->getLoc(), arg->getType(), arg_block,
+    auto* arg_block = arg.getOwner();
+    add_quantize_op(arg.getLoc(), arg.getType(), arg_block,
                     std::next(arg_block->begin(), i), arg, i);
   }


@ -115,7 +115,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
// We don't want to insert quantize/dequantize if the quantize op exists. // We don't want to insert quantize/dequantize if the quantize op exists.
auto res = tf_op.outputs(); auto res = tf_op.outputs();
if (!res->hasOneUse() || isa<QuantizeOp>(*res->user_begin())) if (!res.hasOneUse() || isa<QuantizeOp>(*res.user_begin()))
return this->matchFailure(); return this->matchFailure();
// Extract the min/max constant values from the operands. We also consider // Extract the min/max constant values from the operands. We also consider
@ -123,9 +123,9 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
// constants and the tf.FakeQuantWithMinMaxVarsOp. // constants and the tf.FakeQuantWithMinMaxVarsOp.
Value min = tf_op.min(), max = tf_op.max(); Value min = tf_op.min(), max = tf_op.max();
DenseFPElementsAttr min_value, max_value; DenseFPElementsAttr min_value, max_value;
if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min->getDefiningOp())) if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp()))
min = id1.input(); min = id1.input();
if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max->getDefiningOp())) if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp()))
max = id2.input(); max = id2.input();
if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure(); if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure();
if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure(); if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure();
@ -133,7 +133,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
int quant_dim = -1; int quant_dim = -1;
if (PerAxis) { if (PerAxis) {
// This is a special case where the quant_dim is the last dimension. // This is a special case where the quant_dim is the last dimension.
quant_dim = res->getType().template cast<ShapedType>().getRank() - 1; quant_dim = res.getType().template cast<ShapedType>().getRank() - 1;
} }
// Use the min/max from the operands and the num_bits and narrow_range // Use the min/max from the operands and the num_bits and narrow_range
// attribute to create the quantization parameter for the new quantize op. // attribute to create the quantization parameter for the new quantize op.
@ -155,7 +155,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
tf_op.getLoc(), qtype.getValue(), value, qtype); tf_op.getLoc(), qtype.getValue(), value, qtype);
auto dequantize = rewriter.create<TFL::DequantizeOp>( auto dequantize = rewriter.create<TFL::DequantizeOp>(
tf_op.getLoc(), res_type, quantize.output()); tf_op.getLoc(), res_type, quantize.output());
value->replaceAllUsesWith(dequantize); value.replaceAllUsesWith(dequantize);
quantize.getOperation()->replaceUsesOfWith(dequantize, value); quantize.getOperation()->replaceUsesOfWith(dequantize, value);
return this->matchSuccess(); return this->matchSuccess();
@ -240,7 +240,7 @@ struct ConvertTFConvOp : public RewritePattern {
// that we can extract info from the shape (e.g., for constructing bias // that we can extract info from the shape (e.g., for constructing bias
// tensor, for setting depth_multiplier attribute, etc.). // tensor, for setting depth_multiplier attribute, etc.).
auto filter_type = auto filter_type =
tf_op.filter()->getType().template dyn_cast<RankedTensorType>(); tf_op.filter().getType().template dyn_cast<RankedTensorType>();
if (filter_type && filter_type.getRank() == 4) if (filter_type && filter_type.getRank() == 4)
return matchSuccess(std::move(state)); return matchSuccess(std::move(state));
@ -262,7 +262,7 @@ struct ConvertTFConvOp : public RewritePattern {
// Get a splat zero tensor with the expected dimension for the bias tensor // Get a splat zero tensor with the expected dimension for the bias tensor
auto filter = tf_op.filter(); auto filter = tf_op.filter();
auto filter_type = filter->getType().template cast<RankedTensorType>(); auto filter_type = filter.getType().template cast<RankedTensorType>();
auto elem_type = filter_type.getElementType(); auto elem_type = filter_type.getElementType();
auto bias_dim = static_cast<const ConcreteType *>(this)->getBiasDim( auto bias_dim = static_cast<const ConcreteType *>(this)->getBiasDim(
filter_type.getShape()); filter_type.getShape());
@ -323,7 +323,7 @@ class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
auto perm_op = rewriter.create<TF::ConstOp>(loc, perm_type, perm_attr); auto perm_op = rewriter.create<TF::ConstOp>(loc, perm_type, perm_attr);
// Create tensor type for the transpose result. // Create tensor type for the transpose result.
auto filter_type = filter->getType().cast<RankedTensorType>(); auto filter_type = filter.getType().cast<RankedTensorType>();
auto result_shape = functional::map( auto result_shape = functional::map(
[filter_type](int64_t dim) { return filter_type.getDimSize(dim); }, [filter_type](int64_t dim) { return filter_type.getDimSize(dim); },
perm); perm);
@ -356,7 +356,7 @@ class ConvertTFDepthwiseConv2dNative
// have a corresponding 'depth_multiplier' attribute; the multiplier is the // have a corresponding 'depth_multiplier' attribute; the multiplier is the
// fourth dimension in the 4-D filter tensor. We query the multiplier from // fourth dimension in the 4-D filter tensor. We query the multiplier from
// tf.DepthwiseConv2dNative and set it as the attribute value accordingly. // tf.DepthwiseConv2dNative and set it as the attribute value accordingly.
auto multiplier = filter->getType().cast<RankedTensorType>().getDimSize(3); auto multiplier = filter.getType().cast<RankedTensorType>().getDimSize(3);
filter = legalizeFilter(rewriter, loc, filter); filter = legalizeFilter(rewriter, loc, filter);
return rewriter.create<TFL::DepthwiseConv2DOp>( return rewriter.create<TFL::DepthwiseConv2DOp>(
@ -380,7 +380,7 @@ class ConvertTFDepthwiseConv2dNative
/// RankedTensorType. /// RankedTensorType.
Value legalizeFilter(PatternRewriter &rewriter, Location loc, Value legalizeFilter(PatternRewriter &rewriter, Location loc,
Value filter) const { Value filter) const {
auto filter_type = filter->getType().cast<RankedTensorType>(); auto filter_type = filter.getType().cast<RankedTensorType>();
auto filterShape = filter_type.getShape(); auto filterShape = filter_type.getShape();
SmallVector<int64_t, 4> result_shape = {1, filterShape[0], filterShape[1], SmallVector<int64_t, 4> result_shape = {1, filterShape[0], filterShape[1],
filterShape[2] * filterShape[3]}; filterShape[2] * filterShape[3]};
@ -432,11 +432,11 @@ struct ConvertTFStridedSlice : public RewritePattern {
// Insert a new reshape op. // Insert a new reshape op.
Value original_input = strided_slice_op.input(); Value original_input = strided_slice_op.input();
RankedTensorType original_input_type = RankedTensorType original_input_type =
original_input->getType().cast<RankedTensorType>(); original_input.getType().cast<RankedTensorType>();
const ArrayRef<int64_t> &original_input_shape = const ArrayRef<int64_t> &original_input_shape =
original_input_type.getShape(); original_input_type.getShape();
RankedTensorType begin_type = RankedTensorType begin_type =
strided_slice_op.begin()->getType().cast<RankedTensorType>(); strided_slice_op.begin().getType().cast<RankedTensorType>();
const int dim_size = begin_type.getShape()[0]; const int dim_size = begin_type.getShape()[0];
SmallVector<int64_t, 4> new_shape; SmallVector<int64_t, 4> new_shape;
int mask = 1; int mask = 1;
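
Nearly every hunk in this file repeats one lookup: pull a RankedTensorType off a Value before reading its shape. A minimal sketch of the defensive form (assuming the pre-2020 header layout where RankedTensorType lives in StandardTypes.h; the helper name is hypothetical):

#include "llvm/ADT/Optional.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"

// dyn_cast tolerates unranked or non-tensor types; plain cast<> is reserved
// for paths where an earlier match step already guaranteed a ranked type.
llvm::Optional<int64_t> GetRankIfRanked(mlir::Value value) {
  if (auto ranked = value.getType().dyn_cast<mlir::RankedTensorType>())
    return ranked.getRank();
  return llvm::None;
}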


@ -83,7 +83,7 @@ LogicalResult DuplicateValueIfNeeded(Operation* op,
// We can only clone the constant op at this point. // We can only clone the constant op at this point.
// Since all ops have been legalized to tflite ops, we only care about // Since all ops have been legalized to tflite ops, we only care about
// ConstOp or QConstOp or mlir constant op. // ConstOp or QConstOp or mlir constant op.
Operation* input_op = operand->getDefiningOp(); Operation* input_op = operand.getDefiningOp();
if (input_op == nullptr) return failure(); if (input_op == nullptr) return failure();
Attribute attr; Attribute attr;
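
The null check above is the standard way to filter out block arguments: getDefiningOp() returns null exactly when a value is not an op result. A hedged sketch (hypothetical helper, using the same m_Constant matcher this codebase uses elsewhere):

#include "mlir/IR/Matchers.h"
#include "mlir/IR/Value.h"

// True iff `operand` is produced by an op that matches as a constant
// attribute; block arguments have no defining op and never match.
bool IsDefinedByConstant(mlir::Value operand) {
  if (!operand.getDefiningOp()) return false;  // block argument
  mlir::Attribute attr;
  return mlir::matchPattern(operand, mlir::m_Constant(&attr));
}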


@ -83,7 +83,7 @@ TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
template <typename BatchMatMulOpType> template <typename BatchMatMulOpType>
std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput( std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
Value value, int batch_size, Location loc, PatternRewriter& rewriter) { Value value, int batch_size, Location loc, PatternRewriter& rewriter) {
RankedTensorType tensorType = value->getType().cast<RankedTensorType>(); RankedTensorType tensorType = value.getType().cast<RankedTensorType>();
Type element_type = tensorType.getElementType(); Type element_type = tensorType.getElementType();
int rank = tensorType.getShape().size(); int rank = tensorType.getShape().size();
@ -127,7 +127,7 @@ std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
template <typename BatchMatMulOpType> template <typename BatchMatMulOpType>
TF::TransposeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createTransposeOp( TF::TransposeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createTransposeOp(
Value value, Location loc, PatternRewriter& rewriter) { Value value, Location loc, PatternRewriter& rewriter) {
auto value_type = value->getType().cast<RankedTensorType>(); auto value_type = value.getType().cast<RankedTensorType>();
auto shape = value_type.getShape(); auto shape = value_type.getShape();
int dims = shape.size(); int dims = shape.size();
@ -197,17 +197,17 @@ PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
Value input_lhs = op.x(); Value input_lhs = op.x();
Value input_rhs = op.y(); Value input_rhs = op.y();
if (!input_lhs->getType().isa<RankedTensorType>()) { if (!input_lhs.getType().isa<RankedTensorType>()) {
// LHS must be a ranked tensor type // LHS must be a ranked tensor type
return this->matchFailure(); return this->matchFailure();
} }
if (!input_rhs->getType().isa<RankedTensorType>()) { if (!input_rhs.getType().isa<RankedTensorType>()) {
// RHS must be a ranked tensor type // RHS must be a ranked tensor type
return this->matchFailure(); return this->matchFailure();
} }
auto lhs_type = input_lhs->getType().cast<RankedTensorType>(); auto lhs_type = input_lhs.getType().cast<RankedTensorType>();
auto rhs_type = input_rhs->getType().cast<RankedTensorType>(); auto rhs_type = input_rhs.getType().cast<RankedTensorType>();
auto element_type = lhs_type.getElementType(); auto element_type = lhs_type.getElementType();
@ -233,7 +233,7 @@ PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
if (op.adj_x()) { if (op.adj_x()) {
input_lhs = createTransposeOp(input_lhs, loc, rewriter); input_lhs = createTransposeOp(input_lhs, loc, rewriter);
lhs_type = input_lhs->getType().cast<RankedTensorType>(); lhs_type = input_lhs.getType().cast<RankedTensorType>();
lhs_shape = lhs_type.getShape(); lhs_shape = lhs_type.getShape();
} }
@ -241,7 +241,7 @@ PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
if (op.adj_y()) { if (op.adj_y()) {
input_rhs = createTransposeOp(input_rhs, loc, rewriter); input_rhs = createTransposeOp(input_rhs, loc, rewriter);
rhs_type = input_rhs->getType().cast<RankedTensorType>(); rhs_type = input_rhs.getType().cast<RankedTensorType>();
rhs_shape = rhs_type.getShape(); rhs_shape = rhs_type.getShape();
} }


@ -88,7 +88,7 @@ Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
} }
ArrayRef<int64_t> GetRankedTensorShape(Value value) { ArrayRef<int64_t> GetRankedTensorShape(Value value) {
return value->getType().cast<RankedTensorType>().getShape(); return value.getType().cast<RankedTensorType>().getShape();
} }
Value SliceRankedTensor(OpBuilder* builder, Value input, Value SliceRankedTensor(OpBuilder* builder, Value input,
@ -120,7 +120,7 @@ Value SliceRankedTensor(OpBuilder* builder, Value input,
location, location,
RankedTensorType::get( RankedTensorType::get(
size_values, size_values,
input->getType().cast<RankedTensorType>().getElementType()), input.getType().cast<RankedTensorType>().getElementType()),
input, slice_i2c_begin, slice_i2c_size); input, slice_i2c_begin, slice_i2c_size);
} }
@ -327,8 +327,7 @@ void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() {
SmallVector<int64_t, 2> output_shape{1, -1}; SmallVector<int64_t, 2> output_shape{1, -1};
auto input_types = fused_func_op_.getType().getInputs(); auto input_types = fused_func_op_.getType().getInputs();
auto output_type = mlir::RankedTensorType::get( auto output_type = mlir::RankedTensorType::get(
output_shape, output_shape, input_.getType().cast<RankedTensorType>().getElementType());
input_->getType().cast<RankedTensorType>().getElementType());
fused_func_op_.setType(mlir::FunctionType::get(input_types, output_type, fused_func_op_.setType(mlir::FunctionType::get(input_types, output_type,
fused_func_op_.getContext())); fused_func_op_.getContext()));
} }
@ -351,8 +350,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
// Create the fused LSTM op. // Create the fused LSTM op.
SmallVector<int64_t, 2> output_shape = {1, n_output_}; SmallVector<int64_t, 2> output_shape = {1, n_output_};
auto result_type = mlir::RankedTensorType::get( auto result_type = mlir::RankedTensorType::get(
output_shape, output_shape, input_.getType().cast<RankedTensorType>().getElementType());
input_->getType().cast<RankedTensorType>().getElementType());
lstm_ = builder_.create<mlir::TFL::LSTMOp>( lstm_ = builder_.create<mlir::TFL::LSTMOp>(
fused_func_op_.getLoc(), result_type, input_, input2input_, input2forget_, fused_func_op_.getLoc(), result_type, input_, input2input_, input2forget_,
input2cell_, input2output_, rec2input_, rec2forget_, rec2cell_, input2cell_, input2output_, rec2input_, rec2forget_, rec2cell_,
@ -371,7 +369,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
SmallVector<int64_t, 2> func_output_shape = {1, -1}; SmallVector<int64_t, 2> func_output_shape = {1, -1};
auto func_result_type = mlir::RankedTensorType::get( auto func_result_type = mlir::RankedTensorType::get(
func_output_shape, func_output_shape,
input_->getType().cast<RankedTensorType>().getElementType()); input_.getType().cast<RankedTensorType>().getElementType());
auto tensor_cast = builder_.create<mlir::TensorCastOp>( auto tensor_cast = builder_.create<mlir::TensorCastOp>(
fused_func_op_.getLoc(), lstm_.getResult(), func_result_type); fused_func_op_.getLoc(), lstm_.getResult(), func_result_type);
@ -426,7 +424,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() {
bias_ = fused_func_op_.getArgument(2); bias_ = fused_func_op_.getArgument(2);
weight_ = fused_func_op_.getArgument(1); weight_ = fused_func_op_.getArgument(1);
weight_type_ = weight_->getType().cast<RankedTensorType>(); weight_type_ = weight_.getType().cast<RankedTensorType>();
if (weight_type_.getRank() != 2) { if (weight_type_.getRank() != 2) {
return fused_func_op_.emitError() << "The weight tensor was not of rank 2"; return fused_func_op_.emitError() << "The weight tensor was not of rank 2";
@ -440,7 +438,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() {
n_cell_ = weight_type_.getDimSize(1) / num_gates_; n_cell_ = weight_type_.getDimSize(1) / num_gates_;
projection_ = fused_func_op_.getArgument(3); projection_ = fused_func_op_.getArgument(3);
projection_type_ = projection_->getType().cast<RankedTensorType>(); projection_type_ = projection_.getType().cast<RankedTensorType>();
if (projection_type_.getRank() != 2) { if (projection_type_.getRank() != 2) {
n_output_ = n_cell_; n_output_ = n_cell_;
} else { } else {
@ -467,8 +465,7 @@ LogicalResult ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::Initialize() {
} }
layer_norm_scale_ = fused_func_op_.getArgument(4); layer_norm_scale_ = fused_func_op_.getArgument(4);
layer_norm_scale_type_ = layer_norm_scale_type_ = layer_norm_scale_.getType().cast<RankedTensorType>();
layer_norm_scale_->getType().cast<RankedTensorType>();
if (layer_norm_scale_type_.getRank() != 1) { if (layer_norm_scale_type_.getRank() != 1) {
return fused_func_op_.emitError() return fused_func_op_.emitError()
<< "The layer_norm_scale tensor was not of rank 1"; << "The layer_norm_scale tensor was not of rank 1";


@ -128,22 +128,20 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) {
auto transpose_op = fused_lstm_func_.getBody().front().begin(); auto transpose_op = fused_lstm_func_.getBody().front().begin();
transpose_op++; transpose_op++;
EXPECT_EQ(transpose_op->getOperand(0) EXPECT_EQ(
->getType() transpose_op->getOperand(0).getType().cast<RankedTensorType>().getDimSize(
.cast<RankedTensorType>() 0),
.getDimSize(0), 3);
3);
EXPECT_EQ(transpose_op->getOperand(0) EXPECT_EQ(
->getType() transpose_op->getOperand(0).getType().cast<RankedTensorType>().getDimSize(
.cast<RankedTensorType>() 1),
.getDimSize(1), 12);
12);
EXPECT_EQ( EXPECT_EQ(
transpose_op->getResult(0)->getType().cast<RankedTensorType>().getDimSize( transpose_op->getResult(0).getType().cast<RankedTensorType>().getDimSize(
0), 0),
12); 12);
EXPECT_EQ( EXPECT_EQ(
transpose_op->getResult(0)->getType().cast<RankedTensorType>().getDimSize( transpose_op->getResult(0).getType().cast<RankedTensorType>().getDimSize(
1), 1),
3); 3);
@ -156,12 +154,12 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) {
EXPECT_EQ(it->getNumOperands(), 24); EXPECT_EQ(it->getNumOperands(), 24);
EXPECT_EQ(it->getNumResults(), 1); EXPECT_EQ(it->getNumResults(), 1);
// cifg = false, so input2input is not None. // cifg = false, so input2input is not None.
EXPECT_FALSE(it->getOperand(1)->getType().isa<NoneType>()); EXPECT_FALSE(it->getOperand(1).getType().isa<NoneType>());
// input layer norm is None // input layer norm is None
EXPECT_TRUE(it->getOperand(20)->getType().isa<NoneType>()); EXPECT_TRUE(it->getOperand(20).getType().isa<NoneType>());
// proj_bias is F32 // proj_bias is F32
EXPECT_TRUE(it->getOperand(17) EXPECT_TRUE(it->getOperand(17)
->getType() .getType()
.cast<RankedTensorType>() .cast<RankedTensorType>()
.getElementType() .getElementType()
.isF32()); .isF32());
@ -169,7 +167,7 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) {
// output gate bias is 0 since it is out of bounds of the bias tensor, so // output gate bias is 0 since it is out of bounds of the bias tensor, so
// we set its value as a const tensor of specified size and value 0. // we set its value as a const tensor of specified size and value 0.
EXPECT_TRUE( EXPECT_TRUE(
mlir::cast<mlir::ConstantOp>(it->getOpOperand(15).get()->getDefiningOp()) mlir::cast<mlir::ConstantOp>(it->getOpOperand(15).get().getDefiningOp())
.getValue() .getValue()
.cast<ElementsAttr>() .cast<ElementsAttr>()
.getValue<FloatAttr>(0) .getValue<FloatAttr>(0)
@ -209,7 +207,7 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimpleToFusedLSTMCoupleInputForget) {
EXPECT_EQ(it->getNumOperands(), 24); EXPECT_EQ(it->getNumOperands(), 24);
EXPECT_EQ(it->getNumResults(), 1); EXPECT_EQ(it->getNumResults(), 1);
// cifg = true, so input2input is None. // cifg = true, so input2input is None.
EXPECT_TRUE(it->getOperand(1)->getType().isa<NoneType>()); EXPECT_TRUE(it->getOperand(1).getType().isa<NoneType>());
} }
TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) { TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) {
@ -235,15 +233,15 @@ TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) {
EXPECT_EQ(it->getNumOperands(), 24); EXPECT_EQ(it->getNumOperands(), 24);
EXPECT_EQ(it->getNumResults(), 1); EXPECT_EQ(it->getNumResults(), 1);
// cifg = false, so input2input is not None. // cifg = false, so input2input is not None.
EXPECT_FALSE(it->getOperand(1)->getType().isa<NoneType>()); EXPECT_FALSE(it->getOperand(1).getType().isa<NoneType>());
// input layer norm // input layer norm
EXPECT_FALSE(it->getOperand(20)->getType().isa<NoneType>()); EXPECT_FALSE(it->getOperand(20).getType().isa<NoneType>());
EXPECT_EQ( EXPECT_EQ(
it->getOperand(20)->getType().cast<RankedTensorType>().getShape().size(), it->getOperand(20).getType().cast<RankedTensorType>().getShape().size(),
1); 1);
EXPECT_EQ( EXPECT_EQ(it->getOperand(20).getType().cast<RankedTensorType>().getDimSize(0),
it->getOperand(20)->getType().cast<RankedTensorType>().getDimSize(0), 3); 3);
EXPECT_EQ(fused_ln_lstm_func_.getType().getNumResults(), 1); EXPECT_EQ(fused_ln_lstm_func_.getType().getNumResults(), 1);
auto output_types = fused_ln_lstm_func_.getType().getResults(); auto output_types = fused_ln_lstm_func_.getType().getResults();


@ -52,7 +52,7 @@ bool TFIntListIsAllOnes(const ArrayAttr &attr);
// Returns true iff the given value is a float tensor whose element type // Returns true iff the given value is a float tensor whose element type
// is "DT_FLOAT". // is "DT_FLOAT".
inline bool TFTypeIsFloatTensor(Value value) { inline bool TFTypeIsFloatTensor(Value value) {
auto tensorType = value->getType().dyn_cast<TensorType>(); auto tensorType = value.getType().dyn_cast<TensorType>();
if (!tensorType) return false; if (!tensorType) return false;
return tensorType.getElementType().isa<FloatType>(); return tensorType.getElementType().isa<FloatType>();
} }


@ -149,17 +149,17 @@ std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) {
return op->getName().getStringRef(); return op->getName().getStringRef();
} }
auto val = op_or_val.dyn_cast<mlir::Value>(); auto val = op_or_val.dyn_cast<mlir::Value>();
auto name_from_loc = GetNameFromLoc(val->getLoc()); auto name_from_loc = GetNameFromLoc(val.getLoc());
if (!name_from_loc.empty()) return name_from_loc; if (!name_from_loc.empty()) return name_from_loc;
// If the location is none of the expected types, then simply use name // If the location is none of the expected types, then simply use name
// generated using the op type. Follow TF convention and append the result // generated using the op type. Follow TF convention and append the result
// index unless 0. // index unless 0.
if (auto result = val->dyn_cast<mlir::OpResult>()) { if (auto result = val.dyn_cast<mlir::OpResult>()) {
if (result->getResultNumber() > 0) if (result.getResultNumber() > 0)
return llvm::formatv("{0}:{1}", return llvm::formatv("{0}:{1}",
result->getOwner()->getName().getStringRef(), result.getOwner()->getName().getStringRef(),
result->getResultNumber()); result.getResultNumber());
return result->getOwner()->getName().getStringRef(); return result.getOwner()->getName().getStringRef();
} }
return ""; return "";
} }
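
The naming rule above is compact enough to restate as a sketch: result #0 borrows its owner op's name, later results get a ":N" suffix, following the TF convention the comment mentions. (Hypothetical helper, not the mapper itself.)

#include <string>
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"

// Derive a TF-style name for an op result: "op_name" for result 0,
// "op_name:N" otherwise.
std::string NameForResult(mlir::OpResult result) {
  llvm::StringRef op_name = result.getOwner()->getName().getStringRef();
  if (result.getResultNumber() == 0) return op_name.str();
  return llvm::formatv("{0}:{1}", op_name, result.getResultNumber()).str();
}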


@ -84,17 +84,17 @@ int64_t FindPassthroughArgumentForReturnValue(int64_t return_index,
FuncOp func_op) { FuncOp func_op) {
auto value = auto value =
func_op.getBody().front().getTerminator()->getOperand(return_index); func_op.getBody().front().getTerminator()->getOperand(return_index);
assert(mlir::getElementTypeOrSelf(value->getType()).isa<TF::ResourceType>()); assert(mlir::getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>());
int64_t arg_index = -1; int64_t arg_index = -1;
auto try_parse_arg_index = [&arg_index](Value v) { auto try_parse_arg_index = [&arg_index](Value v) {
auto resource_arg = v->dyn_cast<BlockArgument>(); auto resource_arg = v.dyn_cast<BlockArgument>();
if (resource_arg) arg_index = resource_arg->getArgNumber(); if (resource_arg) arg_index = resource_arg.getArgNumber();
return arg_index; return arg_index;
}; };
while (try_parse_arg_index(value) == -1) { while (try_parse_arg_index(value) == -1) {
auto op = value->getDefiningOp(); auto op = value.getDefiningOp();
assert(op); assert(op);
int64_t res_num = value->cast<OpResult>()->getResultNumber(); int64_t res_num = value.cast<OpResult>().getResultNumber();
if (auto graph = llvm::dyn_cast<tf_executor::GraphOp>(op)) { if (auto graph = llvm::dyn_cast<tf_executor::GraphOp>(op)) {
value = graph.GetFetch().getOperand(res_num); value = graph.GetFetch().getOperand(res_num);
} else if (auto island = llvm::dyn_cast<tf_executor::IslandOp>(op)) { } else if (auto island = llvm::dyn_cast<tf_executor::IslandOp>(op)) {
@ -126,13 +126,13 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
// Before having that, we assume function arguments do not alias each other. // Before having that, we assume function arguments do not alias each other.
int64_t next_unique_id = 0; int64_t next_unique_id = 0;
for (auto arg : func_op.getArguments()) { for (auto arg : func_op.getArguments()) {
if (!mlir::getElementTypeOrSelf(arg->getType()).isa<TF::ResourceType>()) if (!mlir::getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>())
continue; continue;
resource_value_to_ids_[arg].insert(next_unique_id++); resource_value_to_ids_[arg].insert(next_unique_id++);
} }
llvm::StringMap<int64_t> var_handle_name_id_map; llvm::StringMap<int64_t> var_handle_name_id_map;
auto forward_input_to_output = [&](Value operand, Value result) { auto forward_input_to_output = [&](Value operand, Value result) {
if (!mlir::getElementTypeOrSelf(result->getType()).isa<TF::ResourceType>()) if (!mlir::getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>())
return; return;
auto& result_ids = resource_value_to_ids_[result]; auto& result_ids = resource_value_to_ids_[result];
auto operand_it = resource_value_to_ids_.find(operand); auto operand_it = resource_value_to_ids_.find(operand);
@ -161,8 +161,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
// analysis. Inside that block, we can still treat its block arguments as // analysis. Inside that block, we can still treat its block arguments as
// different resources. // different resources.
for (auto arg : replicate.GetBody().getArguments()) { for (auto arg : replicate.GetBody().getArguments()) {
if (mlir::getElementTypeOrSelf(arg->getType()) if (mlir::getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>()) {
.isa<TF::ResourceType>()) {
resource_value_to_ids_[arg].insert(next_unique_id++); resource_value_to_ids_[arg].insert(next_unique_id++);
} }
} }
@ -171,7 +170,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
// If a result is a passthrough of the body input, use the corresponding // If a result is a passthrough of the body input, use the corresponding
// operand's resource IDs. // operand's resource IDs.
for (auto result : llvm::enumerate(while_op.getResults())) { for (auto result : llvm::enumerate(while_op.getResults())) {
if (!mlir::getElementTypeOrSelf(result.value()->getType()) if (!mlir::getElementTypeOrSelf(result.value().getType())
.isa<TF::ResourceType>()) { .isa<TF::ResourceType>()) {
continue; continue;
} }
@ -192,7 +191,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
// If a result is a passthrough of both branches' inputs, merge the // If a result is a passthrough of both branches' inputs, merge the
// resource IDs of corresponding operands for the two inputs. // resource IDs of corresponding operands for the two inputs.
for (auto result : llvm::enumerate(if_op.getResults())) { for (auto result : llvm::enumerate(if_op.getResults())) {
if (!mlir::getElementTypeOrSelf(result.value()->getType()) if (!mlir::getElementTypeOrSelf(result.value().getType())
.isa<TF::ResourceType>()) { .isa<TF::ResourceType>()) {
continue; continue;
} }
@ -211,7 +210,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
} }
} else { } else {
for (auto result : op->getResults()) { for (auto result : op->getResults()) {
if (!mlir::getElementTypeOrSelf(result->getType()) if (!mlir::getElementTypeOrSelf(result.getType())
.isa<TF::ResourceType>()) .isa<TF::ResourceType>())
continue; continue;
resource_value_to_ids_[result].insert(kUnknownResourceId); resource_value_to_ids_[result].insert(kUnknownResourceId);
@ -253,14 +252,14 @@ llvm::SmallDenseSet<int64_t, 8> FindAccessedResources(
llvm::SmallDenseSet<int64_t, 8> resources; llvm::SmallDenseSet<int64_t, 8> resources;
for (auto operand : op->getOperands()) { for (auto operand : op->getOperands()) {
if (!mlir::getElementTypeOrSelf(operand->getType()).isa<TF::ResourceType>()) if (!mlir::getElementTypeOrSelf(operand.getType()).isa<TF::ResourceType>())
continue; continue;
if (alias_analysis.IsUnknownResource(operand)) return UnknownResourceSet(); if (alias_analysis.IsUnknownResource(operand)) return UnknownResourceSet();
const auto& ids = alias_analysis.GetResourceUniqueIds(operand); const auto& ids = alias_analysis.GetResourceUniqueIds(operand);
resources.insert(ids.begin(), ids.end()); resources.insert(ids.begin(), ids.end());
} }
for (auto result : op->getResults()) { for (auto result : op->getResults()) {
if (!mlir::getElementTypeOrSelf(result->getType()).isa<TF::ResourceType>()) if (!mlir::getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>())
continue; continue;
if (alias_analysis.IsUnknownResource(result)) return UnknownResourceSet(); if (alias_analysis.IsUnknownResource(result)) return UnknownResourceSet();
const auto& ids = alias_analysis.GetResourceUniqueIds(result); const auto& ids = alias_analysis.GetResourceUniqueIds(result);
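
Every loop in this analysis filters values the same way before touching the alias map. A sketch of that filter, assuming tf.resource is defined in this revision's tf_types.h (the include path is a guess based on the file layout of the time):

#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"  // assumed path

// A value participates in resource alias analysis only when its element
// type (or the type itself, for non-shaped types) is a tf.resource.
bool IsResourceValue(mlir::Value value) {
  return mlir::getElementTypeOrSelf(value.getType())
      .isa<mlir::TF::ResourceType>();
}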


@ -184,11 +184,11 @@ void Print(ReplicateOp op, OpAsmPrinter* p) {
*p << '('; *p << '(';
Block& block = op.body().front(); Block& block = op.body().front();
interleaveComma(block.getArguments(), *p, [&](BlockArgument arg) { interleaveComma(block.getArguments(), *p, [&](BlockArgument arg) {
const int block_arg_num = arg->getArgNumber(); const int block_arg_num = arg.getArgNumber();
*p << '['; *p << '[';
p->printOperands(std::next(op.operand_begin(), block_arg_num * n), p->printOperands(std::next(op.operand_begin(), block_arg_num * n),
std::next(op.operand_begin(), (block_arg_num + 1) * n)); std::next(op.operand_begin(), (block_arg_num + 1) * n));
*p << "] as " << *arg << ": " << arg->getType(); *p << "] as " << arg << ": " << arg.getType();
}); });
*p << ')'; *p << ')';
} }
@ -229,13 +229,13 @@ LogicalResult Verify(ReplicateOp op) {
// Check replicated input types match block argument types. // Check replicated input types match block argument types.
for (auto block_arg : block.getArguments()) { for (auto block_arg : block.getArguments()) {
Type block_arg_type = block_arg->getType(); Type block_arg_type = block_arg.getType();
for (int i = n * block_arg->getArgNumber(), e = i + n; i < e; ++i) for (int i = n * block_arg.getArgNumber(), e = i + n; i < e; ++i)
if (failed(VerifyCompatibleTypes(block_arg_type, if (failed(VerifyCompatibleTypes(block_arg_type,
op.getOperand(i)->getType()))) op.getOperand(i).getType())))
return op.emitOpError() return op.emitOpError()
<< "incompatible types for operand " << i << "incompatible types for operand " << i
<< " and block argument " << block_arg->getArgNumber(); << " and block argument " << block_arg.getArgNumber();
} }
Operation& terminator = block.back(); Operation& terminator = block.back();
@ -282,7 +282,7 @@ void BuildReplicateOp(
DCHECK_EQ(llvm::size(replicated_input.first), n); DCHECK_EQ(llvm::size(replicated_input.first), n);
for (auto input : replicated_input.first) { for (auto input : replicated_input.first) {
DCHECK(succeeded( DCHECK(succeeded(
VerifyCompatibleTypes(input->getType(), replicated_input.second))); VerifyCompatibleTypes(input.getType(), replicated_input.second)));
state->addOperands(input); state->addOperands(input);
} }
block.addArgument(replicated_input.second); block.addArgument(replicated_input.second);
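
The verifier above maps block argument i to the operand slice [i*n, (i+1)*n). A hedged sketch of that check, with mlir::verifyCompatibleShape standing in for the TF-side VerifyCompatibleTypes helper used in this file:

#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"

// Each block argument must be shape-compatible with all n operands in its
// replicated slice.
mlir::LogicalResult VerifyReplicatedOperands(mlir::Operation *op, int n) {
  mlir::Block &block = op->getRegion(0).front();
  for (mlir::BlockArgument arg : block.getArguments()) {
    mlir::Type arg_type = arg.getType();
    for (int i = n * arg.getArgNumber(), e = i + n; i < e; ++i)
      if (failed(mlir::verifyCompatibleShape(arg_type,
                                             op->getOperand(i).getType())))
        return op->emitOpError()
               << "incompatible types for operand " << i
               << " and block argument " << arg.getArgNumber();
  }
  return mlir::success();
}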


@ -167,7 +167,7 @@ namespace {
LogicalResult VerifyControlOperandsAfterAllData(Operation *op) { LogicalResult VerifyControlOperandsAfterAllData(Operation *op) {
bool found_control = false; bool found_control = false;
for (int operand_idx : llvm::seq<int>(0, op->getNumOperands())) { for (int operand_idx : llvm::seq<int>(0, op->getNumOperands())) {
if (op->getOperand(operand_idx)->getType().isa<ControlType>()) { if (op->getOperand(operand_idx).getType().isa<ControlType>()) {
found_control = true; found_control = true;
continue; continue;
} }
@ -218,7 +218,7 @@ LogicalResult Verify(GraphOp graph) {
for (int i : llvm::seq<int>(0, fetch.getNumOperands())) { for (int i : llvm::seq<int>(0, fetch.getNumOperands())) {
Value operand = fetch.getOperand(i); Value operand = fetch.getOperand(i);
// Break out of the loop at the first control operand encountered. // Break out of the loop at the first control operand encountered.
if (operand->getType().isa<ControlType>()) { if (operand.getType().isa<ControlType>()) {
if (i != graph.getNumResults()) if (i != graph.getNumResults())
return fetch.emitOpError() return fetch.emitOpError()
<< "operand #" << i << "operand #" << i
@ -228,7 +228,7 @@ LogicalResult Verify(GraphOp graph) {
if (i >= graph.getNumResults()) if (i >= graph.getNumResults())
return fetch.emitOpError() return fetch.emitOpError()
<< "operand #" << i << " does not have a graph results to bind"; << "operand #" << i << " does not have a graph results to bind";
if (graph.getResult(i)->getType() != operand->getType()) if (graph.getResult(i).getType() != operand.getType())
return fetch.emitOpError() return fetch.emitOpError()
<< "operand #" << i << " type mismatch graph results"; << "operand #" << i << " type mismatch graph results";
} }
@ -331,8 +331,8 @@ LogicalResult Verify(IslandOp island) {
<< "has " << yield.getNumOperands() << "has " << yield.getNumOperands()
<< " operand, but island returns " << result_count; << " operand, but island returns " << result_count;
for (int operand_idx : llvm::seq<int>(0, yield.getNumOperands())) { for (int operand_idx : llvm::seq<int>(0, yield.getNumOperands())) {
if (island.getResult(operand_idx)->getType() != if (island.getResult(operand_idx).getType() !=
yield.getOperand(operand_idx)->getType()) yield.getOperand(operand_idx).getType())
return yield.emitOpError() return yield.emitOpError()
<< "operand #" << operand_idx << " type mismatch island results"; << "operand #" << operand_idx << " type mismatch island results";
} }
@ -340,7 +340,7 @@ LogicalResult Verify(IslandOp island) {
// Check that there aren't any control results other than the last one. // Check that there aren't any control results other than the last one.
Type control_type = ControlType::get(island.getContext()); Type control_type = ControlType::get(island.getContext());
for (int operand_idx : llvm::seq<int>(0, island.getNumResults() - 1)) { for (int operand_idx : llvm::seq<int>(0, island.getNumResults() - 1)) {
if (island.getResult(operand_idx)->getType() == control_type) if (island.getResult(operand_idx).getType() == control_type)
return yield.emitOpError() return yield.emitOpError()
<< "unexpected control type for operand #" << operand_idx; << "unexpected control type for operand #" << operand_idx;
} }
@ -503,12 +503,12 @@ ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) {
void Print(SwitchOp switch_op, OpAsmPrinter &p) { void Print(SwitchOp switch_op, OpAsmPrinter &p) {
p << switch_op.getOperationName() << ' '; p << switch_op.getOperationName() << ' ';
p.printOperands(switch_op.getOperands()); p.printOperands(switch_op.getOperands());
Type data_operand_ty = switch_op.data()->getType(); Type data_operand_ty = switch_op.data().getType();
// If the types aren't perfectly matching, print the functional type syntax // If the types aren't perfectly matching, print the functional type syntax
// else print the shorter single type. // else print the shorter single type.
p << " : "; p << " : ";
if (switch_op.trueOutput()->getType() != data_operand_ty || if (switch_op.trueOutput().getType() != data_operand_ty ||
switch_op.falseOutput()->getType() != data_operand_ty) { switch_op.falseOutput().getType() != data_operand_ty) {
p.printFunctionalType(switch_op.getOperation()); p.printFunctionalType(switch_op.getOperation());
} else { } else {
p << switch_op.getType(0); p << switch_op.getType(0);
@ -535,12 +535,12 @@ LogicalResult Verify(SwitchNOp switchn) {
<< "expect `num_outs` (" << num_outs.getInt() << ") results but got " << "expect `num_outs` (" << num_outs.getInt() << ") results but got "
<< (switchn.getNumResults() - 1); << (switchn.getNumResults() - 1);
auto operand0_type = switchn.getOperand(0)->getType(); auto operand0_type = switchn.getOperand(0).getType();
for (Value result : switchn.outputs()) for (Value result : switchn.outputs())
if (operand0_type != result->getType()) if (operand0_type != result.getType())
return switchn.emitOpError() return switchn.emitOpError()
<< "type mismatch between data operand and result: " << "type mismatch between data operand and result: "
<< operand0_type << " vs " << result->getType(); << operand0_type << " vs " << result.getType();
return success(); return success();
} }
@ -616,12 +616,12 @@ LogicalResult Verify(MergeOp merge) {
if (!merge.getNumOperands()) if (!merge.getNumOperands())
return merge.emitOpError() << "expects at least one operand"; return merge.emitOpError() << "expects at least one operand";
Type data_type = merge.getOperand(0)->getType(); Type data_type = merge.getOperand(0).getType();
if (data_type.isa<ControlType>()) if (data_type.isa<ControlType>())
return merge.emitOpError() << "expects a non-control input"; return merge.emitOpError() << "expects a non-control input";
// Check that each operand can be individually broadcasted to the output type. // Check that each operand can be individually broadcasted to the output type.
Type output_type = merge.output()->getType(); Type output_type = merge.output().getType();
TensorType output_tensor_ty = output_type.dyn_cast<TensorType>(); TensorType output_tensor_ty = output_type.dyn_cast<TensorType>();
if (!output_tensor_ty) { if (!output_tensor_ty) {
return merge.emitOpError() return merge.emitOpError()
@ -666,7 +666,7 @@ void Print(MergeOp merge, OpAsmPrinter &p) {
bool use_short_form = true; bool use_short_form = true;
int num_data_operands = 0; int num_data_operands = 0;
Type output_type = merge.output()->getType(); Type output_type = merge.output().getType();
for (Type operand_type : merge.getOperandTypes()) { for (Type operand_type : merge.getOperandTypes()) {
if (operand_type.isa<ControlType>()) break; if (operand_type.isa<ControlType>()) break;
num_data_operands++; num_data_operands++;
@ -750,7 +750,7 @@ void Print(EnterOp enter, OpAsmPrinter &p) {
// If the types aren't perfectly matching, print the functional type syntax // If the types aren't perfectly matching, print the functional type syntax
// else print the shorter single type. // else print the shorter single type.
p << " : "; p << " : ";
if (enter.data()->getType() != enter.output()->getType()) { if (enter.data().getType() != enter.output().getType()) {
p.printFunctionalType(enter.getOperation()); p.printFunctionalType(enter.getOperation());
} else { } else {
p << enter.getType(0); p << enter.getType(0);
@ -825,9 +825,9 @@ namespace {
LogicalResult Verify(NextIterationSourceOp source) { LogicalResult Verify(NextIterationSourceOp source) {
Value token = source.token(); Value token = source.token();
if (!token->hasOneUse()) if (!token.hasOneUse())
return source.emitOpError() << "expects a single user for produced token"; return source.emitOpError() << "expects a single user for produced token";
if (!isa<NextIterationSinkOp>(*token->user_begin())) if (!isa<NextIterationSinkOp>(*token.user_begin()))
return source.emitOpError() << "token should be consumed by a sink op"; return source.emitOpError() << "token should be consumed by a sink op";
return success(); return success();
} }
@ -859,7 +859,7 @@ namespace {
LogicalResult Verify(NextIterationSinkOp sink) { LogicalResult Verify(NextIterationSinkOp sink) {
Value token = sink.token(); Value token = sink.token();
Operation *definingOp = token->getDefiningOp(); Operation *definingOp = token.getDefiningOp();
if (!definingOp) if (!definingOp)
return sink.emitOpError() << "expects a token directly produced by a " return sink.emitOpError() << "expects a token directly produced by a "
"tf_executor.NextIteration.Source op: "; "tf_executor.NextIteration.Source op: ";
@ -867,11 +867,11 @@ LogicalResult Verify(NextIterationSinkOp sink) {
if (!source) if (!source)
return sink.emitOpError() << "expects a token produced by a " return sink.emitOpError() << "expects a token produced by a "
"tf_executor.NextIteration.Source op: "; "tf_executor.NextIteration.Source op: ";
if (source.output()->getType() != sink.input()->getType()) if (source.output().getType() != sink.input().getType())
return sink.emitOpError() return sink.emitOpError()
<< "input type " << sink.input()->getType() << "input type " << sink.input().getType()
<< " mismatch the tf_executor.NextIteration.Source output type: " << " mismatch the tf_executor.NextIteration.Source output type: "
<< source.output()->getType(); << source.output().getType();
return success(); return success();
} }
@ -880,7 +880,7 @@ void Print(NextIterationSinkOp next_iteration, OpAsmPrinter &p) {
p.printOperand(next_iteration.getOperand(0)); p.printOperand(next_iteration.getOperand(0));
p << "] "; p << "] ";
p.printOperands(llvm::drop_begin(next_iteration.getOperands(), 1)); p.printOperands(llvm::drop_begin(next_iteration.getOperands(), 1));
p << " : " << next_iteration.getOperand(1)->getType(); p << " : " << next_iteration.getOperand(1).getType();
p.printOptionalAttrDict(next_iteration.getAttrs()); p.printOptionalAttrDict(next_iteration.getAttrs());
} }
@ -980,11 +980,11 @@ void Print(LoopCondOp loop_cond, OpAsmPrinter &p) {
p.printOperands(loop_cond.getOperands()); p.printOperands(loop_cond.getOperands());
// If the types aren't matching (broadcast), print the functional type syntax. // If the types aren't matching (broadcast), print the functional type syntax.
if (loop_cond.input()->getType() != loop_cond.output()->getType()) { if (loop_cond.input().getType() != loop_cond.output().getType()) {
p << " : "; p << " : ";
p.printFunctionalType(loop_cond.getOperation()); p.printFunctionalType(loop_cond.getOperation());
} else { } else {
p << " : " << loop_cond.input()->getType(); p << " : " << loop_cond.input().getType();
} }
p.printOptionalAttrDict(loop_cond.getAttrs()); p.printOptionalAttrDict(loop_cond.getAttrs());
@ -1090,15 +1090,15 @@ struct HoistInnerOpsSingleIslandGraph : public OpRewritePattern<GraphOp> {
llvm::SmallVector<Value, 8> new_rets; llvm::SmallVector<Value, 8> new_rets;
for (Value operand : fetch_op.fetches()) { for (Value operand : fetch_op.fetches()) {
// Control results should not be propagated out. // Control results should not be propagated out.
if (operand->getType().isa<ControlType>()) break; if (operand.getType().isa<ControlType>()) break;
if (operand->getDefiningOp() != island_op) { if (operand.getDefiningOp() != island_op) {
// Operand is not from island, simply propagate it out. // Operand is not from island, simply propagate it out.
new_rets.push_back(operand); new_rets.push_back(operand);
} else { } else {
// Lookup yield operand in island for inner op result. // Lookup yield operand in island for inner op result.
auto result = operand->cast<OpResult>(); auto result = operand.cast<OpResult>();
new_rets.push_back(yield_op.getOperand(result->getResultNumber())); new_rets.push_back(yield_op.getOperand(result.getResultNumber()));
} }
} }
@ -1138,7 +1138,7 @@ struct DropEmptyIslandNoOperandNoDataResult
!HasSingleOpInBlock<YieldOp>(&op.GetBody())) !HasSingleOpInBlock<YieldOp>(&op.GetBody()))
return matchFailure(); return matchFailure();
for (auto &use : llvm::make_early_inc_range(op.control()->getUses())) for (auto &use : llvm::make_early_inc_range(op.control().getUses()))
use.getOwner()->eraseOperand(use.getOperandNumber()); use.getOwner()->eraseOperand(use.getOperandNumber());
rewriter.eraseOp(op); rewriter.eraseOp(op);
@ -1158,7 +1158,7 @@ struct DropEmptyIslandNoOperandOneDataResult
PatternMatchResult matchAndRewrite(IslandOp op, PatternMatchResult matchAndRewrite(IslandOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
if (op.getNumOperands() != 0 || op.getNumResults() != 2 || if (op.getNumOperands() != 0 || op.getNumResults() != 2 ||
!op.control()->use_empty() || !op.control().use_empty() ||
!HasSingleOpInBlock<YieldOp>(&op.GetBody())) !HasSingleOpInBlock<YieldOp>(&op.GetBody()))
return matchFailure(); return matchFailure();
@ -1193,7 +1193,7 @@ struct DropEmptyControlTrigger : public OpRewritePattern<ControlTriggerOp> {
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
if (op.getNumOperands() != 0) return matchFailure(); if (op.getNumOperands() != 0) return matchFailure();
for (auto &use : llvm::make_early_inc_range(op.control()->getUses())) for (auto &use : llvm::make_early_inc_range(op.control().getUses()))
use.getOwner()->eraseOperand(use.getOperandNumber()); use.getOwner()->eraseOperand(use.getOperandNumber());
rewriter.eraseOp(op); rewriter.eraseOp(op);
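
Both canonicalization patterns above drop a control result the same way; the early-incremented range matters because erasing an operand invalidates the use list being walked. A minimal sketch of that idiom (hypothetical helper):

#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"

// Erase every use of `value` from its owning operations.
void DropAllUses(mlir::Value value) {
  for (mlir::OpOperand &use : llvm::make_early_inc_range(value.getUses()))
    use.getOwner()->eraseOperand(use.getOperandNumber());
}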


@ -460,7 +460,7 @@ def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source",
let extraClassDeclaration = [{ let extraClassDeclaration = [{
NextIterationSinkOp GetSink() { NextIterationSinkOp GetSink() {
return cast<NextIterationSinkOp>(*token()->user_begin()); return cast<NextIterationSinkOp>(*token().user_begin());
} }
}]; }];


@ -302,7 +302,7 @@ class WithBroadcastableBinOpBuilder {
"Builder *builder, OperationState &result, Value x, Value y", "Builder *builder, OperationState &result, Value x, Value y",
[{ [{
auto resultType = auto resultType =
OpTrait::util::getBroadcastedType(x->getType(), y->getType()); OpTrait::util::getBroadcastedType(x.getType(), y.getType());
if (!resultType) if (!resultType)
mlir::emitError(result.location, "non-broadcastable operands"); mlir::emitError(result.location, "non-broadcastable operands");
return build(builder, result, resultType, x, y); return build(builder, result, resultType, x, y);
@ -317,14 +317,14 @@ class WithBroadcastableCmpOpBuilder {
"Builder *builder, OperationState &result, Value x, Value y", "Builder *builder, OperationState &result, Value x, Value y",
[{ [{
Type resultType; Type resultType;
if (x->getType().isa<UnrankedTensorType>() || if (x.getType().isa<UnrankedTensorType>() ||
y->getType().isa<UnrankedTensorType>()) { y.getType().isa<UnrankedTensorType>()) {
resultType = UnrankedTensorType::get(builder->getI1Type()); resultType = UnrankedTensorType::get(builder->getI1Type());
} else { } else {
SmallVector<int64_t, 4> resultShape; SmallVector<int64_t, 4> resultShape;
if (!OpTrait::util::getBroadcastedShape( if (!OpTrait::util::getBroadcastedShape(
x->getType().cast<ShapedType>().getShape(), x.getType().cast<ShapedType>().getShape(),
y->getType().cast<ShapedType>().getShape(), resultShape)) { y.getType().cast<ShapedType>().getShape(), resultShape)) {
mlir::emitError(result.location, mlir::emitError(result.location,
"operands have no broadcastable shapes"); "operands have no broadcastable shapes");
} }
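
The comparison builder above computes an i1 result type in two regimes. A hedged sketch of the same inference outside TableGen (hypothetical helper; a null Type signals non-broadcastable shapes, where the builder emits an error instead):

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"

// Unranked operands give an unranked i1 tensor; otherwise the result shape
// is the broadcast of the two operand shapes.
mlir::Type InferCmpResultType(mlir::Builder *builder, mlir::Value x,
                              mlir::Value y) {
  if (x.getType().isa<mlir::UnrankedTensorType>() ||
      y.getType().isa<mlir::UnrankedTensorType>())
    return mlir::UnrankedTensorType::get(builder->getI1Type());
  llvm::SmallVector<int64_t, 4> shape;
  if (!mlir::OpTrait::util::getBroadcastedShape(
          x.getType().cast<mlir::ShapedType>().getShape(),
          y.getType().cast<mlir::ShapedType>().getShape(), shape))
    return mlir::Type();  // not broadcastable
  return mlir::RankedTensorType::get(shape, builder->getI1Type());
}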


@ -77,7 +77,7 @@ static RankedTensorType GetRankedTensorTypeForOperand(Value operand) {
if (matchPattern(operand, m_Constant(&attr))) { if (matchPattern(operand, m_Constant(&attr))) {
return attr.getType().dyn_cast<RankedTensorType>(); return attr.getType().dyn_cast<RankedTensorType>();
} }
return operand->getType().dyn_cast<RankedTensorType>(); return operand.getType().dyn_cast<RankedTensorType>();
} }
// Returns true if the given `value` is of ranked float tensor type with the // Returns true if the given `value` is of ranked float tensor type with the
@ -161,7 +161,7 @@ static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
static Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x, static Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x,
Value y, BoolAttr incompatible_shape_error) { Value y, BoolAttr incompatible_shape_error) {
auto result_type = auto result_type =
OpTrait::util::getBroadcastedType(x->getType(), y->getType()); OpTrait::util::getBroadcastedType(x.getType(), y.getType());
if (!result_type) { if (!result_type) {
if (incompatible_shape_error.getValue()) { if (incompatible_shape_error.getValue()) {
mlir::emitError(loc, "non-broadcastable operands"); mlir::emitError(loc, "non-broadcastable operands");
@ -187,7 +187,7 @@ static int64_t GetDimForAxis(int64_t axis, int64_t rank) {
// inference functions. // inference functions.
static Type InferReductionOpType(Value input, Value reduction_indices, static Type InferReductionOpType(Value input, Value reduction_indices,
BoolAttr keep_dims, Builder *builder) { BoolAttr keep_dims, Builder *builder) {
Type input_ty = input->getType(); Type input_ty = input.getType();
Type element_ty = getElementTypeOrSelf(input_ty); Type element_ty = getElementTypeOrSelf(input_ty);
// Output type is unranked if input type is not ranked. // Output type is unranked if input type is not ranked.
@ -330,12 +330,12 @@ void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
// Verifies a reduction op's `input` and reduction `dims`. // Verifies a reduction op's `input` and reduction `dims`.
static LogicalResult VerifyReductionInputAndDims(Value input, Value dims, static LogicalResult VerifyReductionInputAndDims(Value input, Value dims,
Location loc) { Location loc) {
auto dims_type = dims->getType().dyn_cast<RankedTensorType>(); auto dims_type = dims.getType().dyn_cast<RankedTensorType>();
if (!dims_type) return success(); if (!dims_type) return success();
if (dims_type.getRank() > 1) if (dims_type.getRank() > 1)
return emitError(loc, "dimensions can only be 0D or 1D tensor"); return emitError(loc, "dimensions can only be 0D or 1D tensor");
auto input_type = input->getType().dyn_cast<RankedTensorType>(); auto input_type = input.getType().dyn_cast<RankedTensorType>();
if (!input_type) return success(); if (!input_type) return success();
int64_t rank = input_type.getRank(); int64_t rank = input_type.getRank();
@ -441,9 +441,8 @@ static LogicalResult Verify(BiasAddOp op) {
if (!IsOfRankOrUnranked(op.bias(), 1)) if (!IsOfRankOrUnranked(op.bias(), 1))
return op.emitOpError("requires bias operand to have rank exactly one"); return op.emitOpError("requires bias operand to have rank exactly one");
RankedTensorType value_ty = RankedTensorType value_ty = op.value().getType().dyn_cast<RankedTensorType>();
op.value()->getType().dyn_cast<RankedTensorType>(); RankedTensorType bias_ty = op.bias().getType().dyn_cast<RankedTensorType>();
RankedTensorType bias_ty = op.bias()->getType().dyn_cast<RankedTensorType>();
if (!bias_ty || !value_ty) return success(); if (!bias_ty || !value_ty) return success();
// TODO(hinsu): Leverage tensor_format.h utility in TensorFlow to compute // TODO(hinsu): Leverage tensor_format.h utility in TensorFlow to compute
@ -552,7 +551,7 @@ static LogicalResult Verify(ConcatOffsetOp op) {
<< "requires sizes of shapes and offsets to be the same, got sizes " << "requires sizes of shapes and offsets to be the same, got sizes "
<< op.shape().size() << " and " << op.offset().size(); << op.shape().size() << " and " << op.offset().size();
auto ranked_dim = op.concat_dim()->getType().dyn_cast<RankedTensorType>(); auto ranked_dim = op.concat_dim().getType().dyn_cast<RankedTensorType>();
if (ranked_dim && ranked_dim.getRank() != 0) if (ranked_dim && ranked_dim.getRank() != 0)
return op.emitOpError() return op.emitOpError()
<< "requires concat_dim to be a scalar, got tensor of rank " << "requires concat_dim to be a scalar, got tensor of rank "
@ -565,11 +564,11 @@ static LogicalResult Verify(ConcatOffsetOp op) {
Value offset = std::get<1>(shape_offset_idx.value()); Value offset = std::get<1>(shape_offset_idx.value());
const size_t idx = shape_offset_idx.index(); const size_t idx = shape_offset_idx.index();
if (failed(verifyCompatibleShape(shape->getType(), offset->getType()))) if (failed(verifyCompatibleShape(shape.getType(), offset.getType())))
return op.emitOpError() << "requires operand and result " << idx return op.emitOpError() << "requires operand and result " << idx
<< " to have compatible shapes"; << " to have compatible shapes";
auto ranked_shape = shape->getType().dyn_cast<RankedTensorType>(); auto ranked_shape = shape.getType().dyn_cast<RankedTensorType>();
if (!ranked_shape) continue; if (!ranked_shape) continue;
if (ranked_shape.getRank() != 1) if (ranked_shape.getRank() != 1)
@ -786,7 +785,7 @@ static LogicalResult Verify(OpT op) {
} }
int64_t input_channels = -1; int64_t input_channels = -1;
if (auto ty = op.input()->getType().template dyn_cast<RankedTensorType>()) { if (auto ty = op.input().getType().template dyn_cast<RankedTensorType>()) {
std::string data_format = op.data_format().str(); std::string data_format = op.data_format().str();
tensorflow::TensorFormat format; tensorflow::TensorFormat format;
auto is_valid = FormatFromString(data_format, &format); auto is_valid = FormatFromString(data_format, &format);
@ -796,7 +795,7 @@ static LogicalResult Verify(OpT op) {
} }
int64_t filter_channels = -1; int64_t filter_channels = -1;
if (auto ty = op.filter()->getType().template dyn_cast<RankedTensorType>()) { if (auto ty = op.filter().getType().template dyn_cast<RankedTensorType>()) {
int idx = tensorflow::GetFilterTensorInputChannelsDimIndex( int idx = tensorflow::GetFilterTensorInputChannelsDimIndex(
num_dims, tensorflow::FORMAT_HWIO); num_dims, tensorflow::FORMAT_HWIO);
filter_channels = ty.getDimSize(idx); filter_channels = ty.getDimSize(idx);
@ -876,8 +875,8 @@ static LogicalResult Verify(DynamicStitchOp op) {
} }
Value data = std::get<1>(it); Value data = std::get<1>(it);
RankedTensorType index_ty = index->getType().dyn_cast<RankedTensorType>(); RankedTensorType index_ty = index.getType().dyn_cast<RankedTensorType>();
RankedTensorType data_ty = data->getType().dyn_cast<RankedTensorType>(); RankedTensorType data_ty = data.getType().dyn_cast<RankedTensorType>();
if (!index_ty || !data_ty) continue; if (!index_ty || !data_ty) continue;
int64_t index_rank = index_ty.getRank(); int64_t index_rank = index_ty.getRank();
@ -993,10 +992,10 @@ void EqualOp::build(Builder *builder, OperationState &result, Value x, Value y,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
Type InferExpandDimsOpType(Value input, Value dim) { Type InferExpandDimsOpType(Value input, Value dim) {
Type element_ty = input->getType().cast<TensorType>().getElementType(); Type element_ty = input.getType().cast<TensorType>().getElementType();
auto unranked_ty = UnrankedTensorType::get(element_ty); auto unranked_ty = UnrankedTensorType::get(element_ty);
auto input_ty = input->getType().dyn_cast<RankedTensorType>(); auto input_ty = input.getType().dyn_cast<RankedTensorType>();
if (!input_ty) return unranked_ty; if (!input_ty) return unranked_ty;
DenseIntElementsAttr dim_attr; DenseIntElementsAttr dim_attr;
@ -1076,14 +1075,14 @@ static LogicalResult Verify(FakeQuantWithMinMaxVarsPerChannelOp op) {
Value inputs = op.inputs(); Value inputs = op.inputs();
if (!HasRankAtLeast(inputs, 1) || if (!HasRankAtLeast(inputs, 1) ||
inputs->getType().isa<UnrankedTensorType>()) { inputs.getType().isa<UnrankedTensorType>()) {
return op.emitError("requires inputs to be at least 1d float tensor"); return op.emitError("requires inputs to be at least 1d float tensor");
} }
auto inputsType = inputs->getType().cast<ShapedType>(); auto inputsType = inputs.getType().cast<ShapedType>();
int depth = inputsType.getDimSize(inputsType.getRank() - 1); int depth = inputsType.getDimSize(inputsType.getRank() - 1);
if (op.min()->getType().cast<ShapedType>().getDimSize(0) != depth || if (op.min().getType().cast<ShapedType>().getDimSize(0) != depth ||
op.max()->getType().cast<ShapedType>().getDimSize(0) != depth) { op.max().getType().cast<ShapedType>().getDimSize(0) != depth) {
return op.emitOpError( return op.emitOpError(
"requires min and max to have same size as last dimension of inputs"); "requires min and max to have same size as last dimension of inputs");
} }
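
The per-channel constraint being verified above is easy to state on its own: min and max must carry one entry per channel, and the channel count is the size of the last input dimension. A hedged sketch (hypothetical helper):

#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"

// True iff min and max each carry one entry per channel of `inputs`.
bool MinMaxMatchLastDim(mlir::Value inputs, mlir::Value min, mlir::Value max) {
  auto inputs_ty = inputs.getType().dyn_cast<mlir::RankedTensorType>();
  if (!inputs_ty || inputs_ty.getRank() < 1) return false;
  int64_t depth = inputs_ty.getDimSize(inputs_ty.getRank() - 1);
  return min.getType().cast<mlir::ShapedType>().getDimSize(0) == depth &&
         max.getType().cast<mlir::ShapedType>().getDimSize(0) == depth;
}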
@ -1139,7 +1138,7 @@ static LogicalResult Verify(FusedBatchNormOp op) {
static LogicalResult Verify(GatherV2Op op) { static LogicalResult Verify(GatherV2Op op) {
int64_t batch_dims = op.batch_dims().getSExtValue(); int64_t batch_dims = op.batch_dims().getSExtValue();
if (auto ty = op.indices()->getType().dyn_cast<RankedTensorType>()) { if (auto ty = op.indices().getType().dyn_cast<RankedTensorType>()) {
int64_t rank = ty.getRank(); int64_t rank = ty.getRank();
if (batch_dims > rank || batch_dims < -rank) if (batch_dims > rank || batch_dims < -rank)
return op.emitOpError() return op.emitOpError()
@ -1154,7 +1153,7 @@ static LogicalResult Verify(GatherV2Op op) {
DenseIntElementsAttr axis_attr; DenseIntElementsAttr axis_attr;
if (matchPattern(op.axis(), m_Constant(&axis_attr))) { if (matchPattern(op.axis(), m_Constant(&axis_attr))) {
int64_t axis = (*axis_attr.begin()).getSExtValue(); int64_t axis = (*axis_attr.begin()).getSExtValue();
if (auto ty = op.params()->getType().dyn_cast<RankedTensorType>()) { if (auto ty = op.params().getType().dyn_cast<RankedTensorType>()) {
int64_t rank = ty.getRank(); int64_t rank = ty.getRank();
if (axis >= rank || axis < -rank) if (axis >= rank || axis < -rank)
return op.emitOpError() << "axis (" << axis << ") must be in range [" return op.emitOpError() << "axis (" << axis << ") must be in range ["
@ -1197,7 +1196,7 @@ static LogicalResult Verify(IfOp op) {
" inputs"); " inputs");
for (unsigned i = 0; i < expectedNumInputs; ++i) { for (unsigned i = 0; i < expectedNumInputs; ++i) {
auto operandType = op.getOperand(i + 1)->getType().cast<TensorType>(); auto operandType = op.getOperand(i + 1).getType().cast<TensorType>();
auto thenInputType = thenFuncType.getInput(i).cast<TensorType>(); auto thenInputType = thenFuncType.getInput(i).cast<TensorType>();
if (!AreCastCompatible(operandType, thenInputType)) if (!AreCastCompatible(operandType, thenInputType))
return op.emitError( return op.emitError(
@ -1228,7 +1227,7 @@ static LogicalResult Verify(IfOp op) {
" results"); " results");
for (unsigned i = 0; i < expectedNumResults; ++i) { for (unsigned i = 0; i < expectedNumResults; ++i) {
auto resultType = op.getResult(i)->getType().cast<TensorType>(); auto resultType = op.getResult(i).getType().cast<TensorType>();
auto thenResultType = thenFuncType.getResult(i).cast<TensorType>(); auto thenResultType = thenFuncType.getResult(i).cast<TensorType>();
if (!AreCastCompatible(thenResultType, resultType)) if (!AreCastCompatible(thenResultType, resultType))
return op.emitError( return op.emitError(
@ -1364,7 +1363,7 @@ void NotEqualOp::build(Builder *builder, OperationState &result, Value x,
static LogicalResult Verify(OneHotOp op) { static LogicalResult Verify(OneHotOp op) {
int64_t axis = op.axis().getSExtValue(); int64_t axis = op.axis().getSExtValue();
auto indices_ty = op.indices()->getType().dyn_cast<RankedTensorType>(); auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
if (indices_ty && if (indices_ty &&
!(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) { !(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) {
return op.emitOpError() return op.emitOpError()
@ -1403,11 +1402,11 @@ static LogicalResult Verify(OneHotOp op) {
static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value, static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value,
Value off_value, IntegerAttr axis) { Value off_value, IntegerAttr axis) {
int64_t axis_val = axis.getInt(); int64_t axis_val = axis.getInt();
Type element_ty = on_value->getType().cast<TensorType>().getElementType(); Type element_ty = on_value.getType().cast<TensorType>().getElementType();
auto unranked_ty = UnrankedTensorType::get(element_ty); auto unranked_ty = UnrankedTensorType::get(element_ty);
if (axis_val < -1) return unranked_ty; if (axis_val < -1) return unranked_ty;
auto indices_ty = indices->getType().dyn_cast<RankedTensorType>(); auto indices_ty = indices.getType().dyn_cast<RankedTensorType>();
if (!indices_ty) return unranked_ty; if (!indices_ty) return unranked_ty;
auto shape = llvm::to_vector<2>(indices_ty.getShape()); auto shape = llvm::to_vector<2>(indices_ty.getShape());
@ -1446,7 +1445,7 @@ static LogicalResult Verify(PackOp op) {
int64_t inputs_rank = -1; int64_t inputs_rank = -1;
for (Value value : values) { for (Value value : values) {
if (auto ty = value->getType().dyn_cast<RankedTensorType>()) { if (auto ty = value.getType().dyn_cast<RankedTensorType>()) {
// Exit early as input types are verified to be compatible so all ranked // Exit early as input types are verified to be compatible so all ranked
// tensors have the same rank. // tensors have the same rank.
inputs_rank = ty.getRank(); inputs_rank = ty.getRank();
@ -1548,8 +1547,8 @@ static LogicalResult Verify(RandomUniformOp op) {
void RangeOp::build(Builder *builder, OperationState &result, Value start, void RangeOp::build(Builder *builder, OperationState &result, Value start,
Value limit, Value delta) { Value limit, Value delta) {
assert(start->getType() == limit->getType()); assert(start.getType() == limit.getType());
assert(start->getType() == delta->getType()); assert(start.getType() == delta.getType());
DenseIntElementsAttr start_val; DenseIntElementsAttr start_val;
DenseIntElementsAttr limit_val; DenseIntElementsAttr limit_val;
DenseIntElementsAttr delta_val; DenseIntElementsAttr delta_val;
@ -1563,13 +1562,13 @@ void RangeOp::build(Builder *builder, OperationState &result, Value start,
builder, result, builder, result,
RankedTensorType::get( RankedTensorType::get(
size.getSExtValue(), size.getSExtValue(),
start->getType().cast<TensorType>().getElementType()), start.getType().cast<TensorType>().getElementType()),
start, limit, delta); start, limit, delta);
} }
return RangeOp::build( return RangeOp::build(
builder, result, builder, result,
RankedTensorType::get( RankedTensorType::get(
{-1}, start->getType().cast<TensorType>().getElementType()), {-1}, start.getType().cast<TensorType>().getElementType()),
start, limit, delta); start, limit, delta);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1598,17 +1597,17 @@ void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
// TODO(b/128020684): Verify the rank of the output and change to use // TODO(b/128020684): Verify the rank of the output and change to use
// m_Constant. // m_Constant.
static LogicalResult Verify(ReshapeOp op) { static LogicalResult Verify(ReshapeOp op) {
auto shapeType = op.shape()->getType().cast<TensorType>(); auto shapeType = op.shape().getType().cast<TensorType>();
if (!shapeType.hasRank()) return success(); if (!shapeType.hasRank()) return success();
if (shapeType.getRank() != 1) if (shapeType.getRank() != 1)
return op.emitOpError("shape must be 1D tensor"); return op.emitOpError("shape must be 1D tensor");
auto rankByShape = shapeType.getShape()[0]; auto rankByShape = shapeType.getShape()[0];
auto typeOfTensor = op.tensor()->getType().cast<TensorType>(); auto typeOfTensor = op.tensor().getType().cast<TensorType>();
// No compile time verification for unknown sized shape. // No compile time verification for unknown sized shape.
if (rankByShape == -1 || !typeOfTensor.hasStaticShape()) return success(); if (rankByShape == -1 || !typeOfTensor.hasStaticShape()) return success();
// Check values if constant shape. No compile time verification for // Check values if constant shape. No compile time verification for
// non-constant shape. // non-constant shape.
auto *shapeOp = op.shape()->getDefiningOp(); auto *shapeOp = op.shape().getDefiningOp();
if (!shapeOp) return success(); if (!shapeOp) return success();
Attribute shapeCst; Attribute shapeCst;
if (auto shapeStdOp = dyn_cast<ConstantOp>(shapeOp)) { if (auto shapeStdOp = dyn_cast<ConstantOp>(shapeOp)) {
@ -1662,7 +1661,7 @@ static LogicalResult Verify(ReshapeOp op) {
void ReshapeOp::build(Builder *builder, OperationState &result, Value tensor, void ReshapeOp::build(Builder *builder, OperationState &result, Value tensor,
Value shape) { Value shape) {
auto ttype = tensor->getType().cast<ShapedType>(); auto ttype = tensor.getType().cast<ShapedType>();
auto etype = ttype.getElementType(); auto etype = ttype.getElementType();
auto unranked = [builder, etype, &result, shape, tensor]() { auto unranked = [builder, etype, &result, shape, tensor]() {
@ -1723,14 +1722,14 @@ void ReshapeOp::build(Builder *builder, OperationState &result, Value tensor,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static Type InferSelectV2OpType(Value condition, Value e, Value t) { static Type InferSelectV2OpType(Value condition, Value e, Value t) {
Type element_ty = e->getType().cast<TensorType>().getElementType(); Type element_ty = e.getType().cast<TensorType>().getElementType();
auto unranked_ty = UnrankedTensorType::get(element_ty); auto unranked_ty = UnrankedTensorType::get(element_ty);
Type broadcasted_ty = Type broadcasted_ty =
OpTrait::util::getBroadcastedType(e->getType(), t->getType()); OpTrait::util::getBroadcastedType(e.getType(), t.getType());
if (!broadcasted_ty) return unranked_ty; if (!broadcasted_ty) return unranked_ty;
auto cond_ranked_ty = condition->getType().dyn_cast<RankedTensorType>(); auto cond_ranked_ty = condition.getType().dyn_cast<RankedTensorType>();
auto broadcasted_ranked_ty = broadcasted_ty.dyn_cast<RankedTensorType>(); auto broadcasted_ranked_ty = broadcasted_ty.dyn_cast<RankedTensorType>();
if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty; if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty;
@ -1791,7 +1790,7 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
} // anonymous namespace } // anonymous namespace
static LogicalResult Verify(ShapeOp op) { static LogicalResult Verify(ShapeOp op) {
return VerifyShapeOperandAndResult(op, op.input()->getType(), op.getType()); return VerifyShapeOperandAndResult(op, op.input().getType(), op.getType());
} }
// Converts shape of the given type to attribute if it is of ranked tensor type. // Converts shape of the given type to attribute if it is of ranked tensor type.
@ -1816,12 +1815,12 @@ static Attribute ConvertShapeToAttr(Type input_ty, int out_width) {
OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) { OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
int width = int width =
getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth(); getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
return ConvertShapeToAttr(getOperand()->getType(), width); return ConvertShapeToAttr(getOperand().getType(), width);
} }
void ShapeOp::build(Builder *builder, OperationState &result, Value input, void ShapeOp::build(Builder *builder, OperationState &result, Value input,
BoolAttr use32Bit) { BoolAttr use32Bit) {
auto rankedTensorType = input->getType().dyn_cast<RankedTensorType>(); auto rankedTensorType = input.getType().dyn_cast<RankedTensorType>();
int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1; int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1;
auto out_type = use32Bit.getValue() ? builder->getIntegerType(32) auto out_type = use32Bit.getValue() ? builder->getIntegerType(32)
: builder->getIntegerType(64); : builder->getIntegerType(64);
@ -1846,7 +1845,7 @@ static LogicalResult Verify(ShapeNOp op) {
for (auto i : llvm::seq<uint64_t>(0, num_tensors)) { for (auto i : llvm::seq<uint64_t>(0, num_tensors)) {
auto verification = VerifyShapeOperandAndResult( auto verification = VerifyShapeOperandAndResult(
op, op.getOperand(i)->getType(), op.getResult(i)->getType(), i); op, op.getOperand(i).getType(), op.getResult(i).getType(), i);
if (failed(verification)) return verification; if (failed(verification)) return verification;
} }
@ -1919,7 +1918,7 @@ static LogicalResult Verify(SliceOp op) {
" same number of elements"; " same number of elements";
} }
auto input_ty = op.input()->getType().dyn_cast<RankedTensorType>(); auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) { if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) {
return op.emitOpError() << "requires number of elements in begin and size" return op.emitOpError() << "requires number of elements in begin and size"
"are equal to input rank"; "are equal to input rank";
@ -1973,7 +1972,7 @@ static LogicalResult Verify(SoftmaxOp op) {
// //
static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) { static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) {
auto broadcasted_ty = OpTrait::util::getBroadcastedType( auto broadcasted_ty = OpTrait::util::getBroadcastedType(
op.features()->getType(), op.labels()->getType()) op.features().getType(), op.labels().getType())
.dyn_cast_or_null<ShapedType>(); .dyn_cast_or_null<ShapedType>();
if (!broadcasted_ty || if (!broadcasted_ty ||
(broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2)) (broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2))
@ -1994,8 +1993,8 @@ static LogicalResult Verify(SparseSoftmaxCrossEntropyWithLogitsOp op) {
if (!IsOfRankOrUnranked(op.labels(), 1)) { if (!IsOfRankOrUnranked(op.labels(), 1)) {
return op.emitOpError("requires labels operand of rank one"); return op.emitOpError("requires labels operand of rank one");
} }
auto features_ty = op.features()->getType().dyn_cast<RankedTensorType>(); auto features_ty = op.features().getType().dyn_cast<RankedTensorType>();
auto labels_ty = op.labels()->getType().dyn_cast<RankedTensorType>(); auto labels_ty = op.labels().getType().dyn_cast<RankedTensorType>();
if (features_ty && labels_ty) { if (features_ty && labels_ty) {
int64_t features_batches = features_ty.getDimSize(0); int64_t features_batches = features_ty.getDimSize(0);
int64_t labels_batches = labels_ty.getDimSize(0); int64_t labels_batches = labels_ty.getDimSize(0);
@ -2020,7 +2019,7 @@ LogicalResult VerifySplitInputAndSplitDim(Op op, Optional<int64_t> *dim_index) {
*dim_index = llvm::None; *dim_index = llvm::None;
Value split_dim = op.split_dim(); Value split_dim = op.split_dim();
if (auto split_dim_type = split_dim->getType().dyn_cast<RankedTensorType>()) if (auto split_dim_type = split_dim.getType().dyn_cast<RankedTensorType>())
if (split_dim_type.getRank() != 0) if (split_dim_type.getRank() != 0)
return op.emitOpError( return op.emitOpError(
"split dimension should be an integer scalar tensor"); "split dimension should be an integer scalar tensor");
@ -2028,7 +2027,7 @@ LogicalResult VerifySplitInputAndSplitDim(Op op, Optional<int64_t> *dim_index) {
// We can perform further verification if the input tensor to be split has // We can perform further verification if the input tensor to be split has
// known rank and the split dimension tensor is a constant. // known rank and the split dimension tensor is a constant.
auto input_type = op.value()->getType().template dyn_cast<RankedTensorType>(); auto input_type = op.value().getType().template dyn_cast<RankedTensorType>();
if (!input_type) return success(); if (!input_type) return success();
int64_t input_rank = input_type.getRank(); int64_t input_rank = input_type.getRank();
@ -2057,7 +2056,7 @@ static LogicalResult Verify(SplitOp op) {
if (!dim_index) return success(); if (!dim_index) return success();
int64_t input_dim_size = int64_t input_dim_size =
op.value()->getType().cast<RankedTensorType>().getDimSize(*dim_index); op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
if (input_dim_size == ShapedType::kDynamicSize) return success(); if (input_dim_size == ShapedType::kDynamicSize) return success();
if (input_dim_size % op.getNumResults() != 0) if (input_dim_size % op.getNumResults() != 0)
@ -2073,7 +2072,7 @@ static LogicalResult Verify(SplitOp op) {
static LogicalResult Verify(SplitVOp op) { static LogicalResult Verify(SplitVOp op) {
auto split_sizes_type = auto split_sizes_type =
op.size_splits()->getType().dyn_cast<RankedTensorType>(); op.size_splits().getType().dyn_cast<RankedTensorType>();
if (!split_sizes_type) return success(); if (!split_sizes_type) return success();
if (split_sizes_type.getRank() != 1 || if (split_sizes_type.getRank() != 1 ||
@ -2086,7 +2085,7 @@ static LogicalResult Verify(SplitVOp op) {
if (!dim_index) return success(); if (!dim_index) return success();
int64_t input_dim_size = int64_t input_dim_size =
op.value()->getType().cast<RankedTensorType>().getDimSize(*dim_index); op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
if (input_dim_size == ShapedType::kDynamicSize) return success(); if (input_dim_size == ShapedType::kDynamicSize) return success();
// If split sizes come from a constant, they must sum to the dimension size // If split sizes come from a constant, they must sum to the dimension size
@ -2178,7 +2177,7 @@ static LogicalResult VerifyStridedSliceBase(OpTy op) {
int64_t expected_size = -1; int64_t expected_size = -1;
for (Value val : {op.begin(), op.end(), op.strides()}) { for (Value val : {op.begin(), op.end(), op.strides()}) {
auto operand_ty = val->getType().dyn_cast<ShapedType>(); auto operand_ty = val.getType().dyn_cast<ShapedType>();
if (!operand_ty || !operand_ty.hasStaticShape()) { if (!operand_ty || !operand_ty.hasStaticShape()) {
// TensorFlow constant ops may have non-static shape because the shape is // TensorFlow constant ops may have non-static shape because the shape is
// not propagated during constant folding. If the defining op for this // not propagated during constant folding. If the defining op for this
@ -2336,7 +2335,7 @@ bool StridedSliceOp::GetSlicedBoundRanges(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult Verify(StridedSliceGradOp op) { static LogicalResult Verify(StridedSliceGradOp op) {
auto shape_type = op.shape()->getType().dyn_cast<RankedTensorType>(); auto shape_type = op.shape().getType().dyn_cast<RankedTensorType>();
if (shape_type && shape_type.getRank() != 1) if (shape_type && shape_type.getRank() != 1)
return op.emitOpError("'shape' operand must be 1D tensor, but got ") return op.emitOpError("'shape' operand must be 1D tensor, but got ")
<< shape_type.getRank() << "D tensor"; << shape_type.getRank() << "D tensor";
@ -2433,8 +2432,8 @@ static LogicalResult Verify(TensorScatterUpdateOp op) {
return op.emitOpError( return op.emitOpError(
"requires updates operand to have at least 1 dimension"); "requires updates operand to have at least 1 dimension");
auto tensor_ty = op.tensor()->getType().dyn_cast<RankedTensorType>(); auto tensor_ty = op.tensor().getType().dyn_cast<RankedTensorType>();
auto indices_ty = op.indices()->getType().dyn_cast<RankedTensorType>(); auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
if (!tensor_ty || !indices_ty) return success(); if (!tensor_ty || !indices_ty) return success();
int64_t num_index_dims = indices_ty.getShape().back(); int64_t num_index_dims = indices_ty.getShape().back();
@ -2478,7 +2477,7 @@ static LogicalResult Verify(TransposeOp op) {
// TODO(jpienaar): perm could be optional too. // TODO(jpienaar): perm could be optional too.
void TransposeOp::build(Builder *builder, OperationState &result, Value x, void TransposeOp::build(Builder *builder, OperationState &result, Value x,
Value perm) { Value perm) {
auto x_type = x->getType().cast<TensorType>(); auto x_type = x.getType().cast<TensorType>();
// If value is unranked, then so is results. // If value is unranked, then so is results.
if (!x_type.hasRank()) if (!x_type.hasRank())
return TransposeOp::build(builder, result, return TransposeOp::build(builder, result,
@ -2509,7 +2508,7 @@ void TransposeOp::build(Builder *builder, OperationState &result, Value x,
} }
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) { OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
auto const_perm = dyn_cast_or_null<TF::ConstOp>(perm()->getDefiningOp()); auto const_perm = dyn_cast_or_null<TF::ConstOp>(perm().getDefiningOp());
if (!const_perm) { if (!const_perm) {
return {}; return {};
@ -2541,7 +2540,7 @@ void TruncateDivOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult Verify(UnpackOp op) { static LogicalResult Verify(UnpackOp op) {
auto value_type = op.value()->getType().dyn_cast<RankedTensorType>(); auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
if (!value_type) return success(); if (!value_type) return success();
int64_t value_rank = value_type.getRank(); int64_t value_rank = value_type.getRank();
@ -2569,9 +2568,9 @@ static LogicalResult VerifyUnsortedSegmentReduction(Op op) {
if (!HasRankAtMost(op.num_segments(), 0)) if (!HasRankAtMost(op.num_segments(), 0))
return op.emitOpError("number of segments should be a 0-D tensor"); return op.emitOpError("number of segments should be a 0-D tensor");
auto data_type = op.data()->getType().template dyn_cast<RankedTensorType>(); auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
auto segment_ids_type = auto segment_ids_type =
op.segment_ids()->getType().template dyn_cast<RankedTensorType>(); op.segment_ids().getType().template dyn_cast<RankedTensorType>();
if (data_type && segment_ids_type) { if (data_type && segment_ids_type) {
if (data_type.getRank() < segment_ids_type.getRank()) if (data_type.getRank() < segment_ids_type.getRank())
return op.emitOpError( return op.emitOpError(
@ -2609,7 +2608,7 @@ static LogicalResult VerifyUnsortedSegmentReduction(Op op) {
static LogicalResult Verify(VariableShapeOp op) { static LogicalResult Verify(VariableShapeOp op) {
auto resource_operand_type = op.input() auto resource_operand_type = op.input()
->getType() .getType()
.cast<TensorType>() .cast<TensorType>()
.getElementType() .getElementType()
.cast<TF::ResourceType>(); .cast<TF::ResourceType>();
@ -2763,7 +2762,7 @@ struct TFInlinerInterface : public DialectInlinerInterface {
Operation *materializeCallConversion(OpBuilder &builder, Value input, Operation *materializeCallConversion(OpBuilder &builder, Value input,
Type result_type, Type result_type,
Location conversion_loc) const final { Location conversion_loc) const final {
if (!result_type.isa<TensorType>() || !input->getType().isa<TensorType>()) if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
return nullptr; return nullptr;
return builder.create<TF::CastOp>(conversion_loc, result_type, input, return builder.create<TF::CastOp>(conversion_loc, result_type, input,
/*truncate=*/builder.getBoolAttr(false)); /*truncate=*/builder.getBoolAttr(false));
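
Every hunk in this file is the same mechanical substitution: `Value` is now value-typed, so its members (`getType`, `getDefiningOp`, `getUsers`, and so on) are reached with `.` instead of `->`. A minimal sketch of the new idiom, assuming the MLIR headers at this revision (the helper below is illustrative, not part of this change):

#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"

// Returns the rank of `v` if it has a ranked tensor type, or -1 otherwise.
static int64_t GetRankOrMinusOne(mlir::Value v) {
  // Type queries are member calls on the value itself; there is no
  // pointer-like handle left to dereference.
  if (auto ranked_ty = v.getType().dyn_cast<mlir::RankedTensorType>())
    return ranked_ty.getRank();
  return -1;
}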

View File

@ -57,7 +57,7 @@ class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
// Returns data type of the result handle. Returned type contains type of // Returns data type of the result handle. Returned type contains type of
// the TensorList element as a subtype. // the TensorList element as a subtype.
VariantType handle_dtype() { VariantType handle_dtype() {
return getElementTypeOrSelf(handle()->getType()).cast<TF::VariantType>(); return getElementTypeOrSelf(handle().getType()).cast<TF::VariantType>();
} }
}]; }];
} }
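
The substitution reaches into ODS `extraClassDeclaration` blocks too, since those bodies are emitted verbatim as C++. For reference, a sketch of the element-type query behind `handle_dtype`, assuming this revision's `mlir/IR/TypeUtilities.h` (the helper name is invented):

#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"

// Returns the element type of `v` for shaped types, or `v`'s type as-is.
static mlir::Type ElementTypeOf(mlir::Value v) {
  return mlir::getElementTypeOrSelf(v.getType());
}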

View File

@ -47,7 +47,7 @@ class OperandsSameAsResultsTypeOrRef
LogicalResult shapeMatch = impl::verifySameOperandsAndResultShape(op); LogicalResult shapeMatch = impl::verifySameOperandsAndResultShape(op);
if (failed(shapeMatch)) return shapeMatch; if (failed(shapeMatch)) return shapeMatch;
auto type = getElementTypeOrSelf(op->getResult(0)->getType()); auto type = getElementTypeOrSelf(op->getResult(0).getType());
// Verify that the first result type is same as the rest of the results. // Verify that the first result type is same as the rest of the results.
// We skip the comparison against itself. // We skip the comparison against itself.

View File

@ -23,7 +23,7 @@ def SingleResultAndOperandHaveSameElementType : Constraint<
CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>; CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>;
def SingleResultAndOperandHaveSameType : Constraint< def SingleResultAndOperandHaveSameType : Constraint<
CPred<"$0->getType() == $1->getType()">>; CPred<"$0.getType() == $1.getType()">>;
def IsRank2Tensor : Type<HasAnyRankOfPred<[2]>, "Rank 2 tensor">; def IsRank2Tensor : Type<HasAnyRankOfPred<[2]>, "Rank 2 tensor">;

View File

@ -70,9 +70,9 @@ StringRef GetDevice(Operation* op) {
bool CanMergeIntoCluster(const Cluster& c, Operation* to_merge) { bool CanMergeIntoCluster(const Cluster& c, Operation* to_merge) {
return llvm::all_of(to_merge->getOperands(), [&](Value operand) { return llvm::all_of(to_merge->getOperands(), [&](Value operand) {
// Block arguments. // Block arguments.
if (operand->isa<BlockArgument>()) return true; if (operand.isa<BlockArgument>()) return true;
Operation* defining_op = operand->getDefiningOp(); Operation* defining_op = operand.getDefiningOp();
// Operand produced by other islands. // Operand produced by other islands.
if (defining_op->getBlock() != c.ops.front()->getBlock()) return true; if (defining_op->getBlock() != c.ops.front()->getBlock()) return true;
@ -100,7 +100,7 @@ void ReplaceLiveOutExternalUses(llvm::ArrayRef<Value> live_outs,
Region* launch_op_region = &launch_op.body(); Region* launch_op_region = &launch_op.body();
for (const auto& p : llvm::zip(live_outs, launch_op.getResults())) { for (const auto& p : llvm::zip(live_outs, launch_op.getResults())) {
Value from = std::get<0>(p); Value from = std::get<0>(p);
for (auto& use : from->getUses()) { for (auto& use : from.getUses()) {
if (launch_op_region->isAncestor(use.getOwner()->getParentRegion())) if (launch_op_region->isAncestor(use.getOwner()->getParentRegion()))
continue; continue;
use.set(std::get<1>(p)); use.set(std::get<1>(p));
@ -116,7 +116,7 @@ void GetLiveOuts(Region* region, llvm::SmallVectorImpl<Value>* live_outs) {
for (Value v : op.getResults()) { for (Value v : op.getResults()) {
// A value is live-out if any of its users are not inside value producer's // A value is live-out if any of its users are not inside value producer's
// region. // region.
bool is_live_out = llvm::any_of(v->getUsers(), [&](Operation* user) { bool is_live_out = llvm::any_of(v.getUsers(), [&](Operation* user) {
return !region->isAncestor(user->getParentRegion()); return !region->isAncestor(user->getParentRegion());
}); });
@ -158,7 +158,7 @@ void BuildLaunchForCluster(const Cluster& c, OpBuilder* builder) {
llvm::SmallVector<Type, 4> live_out_types; llvm::SmallVector<Type, 4> live_out_types;
live_out_types.reserve(live_outs.size()); live_out_types.reserve(live_outs.size());
for (Value v : live_outs) { for (Value v : live_outs) {
live_out_types.emplace_back(v->getType()); live_out_types.emplace_back(v.getType());
} }
tf_device::LaunchOp launch_op = builder->create<tf_device::LaunchOp>( tf_device::LaunchOp launch_op = builder->create<tf_device::LaunchOp>(

View File

@ -56,7 +56,7 @@ FuncOp BuildFunction(StringRef device, llvm::ArrayRef<Value> live_ins,
OpBuilder* builder) { OpBuilder* builder) {
llvm::SmallVector<Type, 4> operand_types; llvm::SmallVector<Type, 4> operand_types;
operand_types.reserve(live_ins.size()); operand_types.reserve(live_ins.size());
for (Value v : live_ins) operand_types.emplace_back(v->getType()); for (Value v : live_ins) operand_types.emplace_back(v.getType());
llvm::SmallVector<Type, 4> result_types(launch_op.getResultTypes()); llvm::SmallVector<Type, 4> result_types(launch_op.getResultTypes());

View File

@ -25,7 +25,7 @@ def CreateTFReadVariableOp: NativeCodeCall<
"$_builder.create<TF::ReadVariableOp>(" "$_builder.create<TF::ReadVariableOp>("
" $0.getLoc()," " $0.getLoc(),"
" UnrankedTensorType::get(" " UnrankedTensorType::get("
" $1->getType().cast<TensorType>().getElementType())," " $1.getType().cast<TensorType>().getElementType()),"
" $2)" " $2)"
>; >;
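
In DRR files the `$N` placeholders bind `mlir::Value`s, so `NativeCodeCall` bodies follow the same migration. Roughly, the call above expands to C++ of this shape (a sketch with invented names and a `Location` passed in directly; the real tblgen output differs):

#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

// `elem_source` and `resource` stand in for the pattern's bound $1 and $2.
static mlir::TF::ReadVariableOp CreateReadVariable(mlir::OpBuilder &builder,
                                                   mlir::Location loc,
                                                   mlir::Value elem_source,
                                                   mlir::Value resource) {
  auto elem_ty =
      elem_source.getType().cast<mlir::TensorType>().getElementType();
  return builder.create<mlir::TF::ReadVariableOp>(
      loc, mlir::UnrankedTensorType::get(elem_ty), resource);
}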

View File

@ -71,7 +71,7 @@ llvm::Optional<IslandOp> GetOperandCandidateToMergeWith(IslandOp island) {
// Check island control operands. // Check island control operands.
for (Value input : island.controlInputs()) { for (Value input : island.controlInputs()) {
Operation* def = input->getDefiningOp(); Operation* def = input.getDefiningOp();
DCHECK_EQ(def->getParentOp(), graph_op); DCHECK_EQ(def->getParentOp(), graph_op);
if (!candidate || candidate->isBeforeInBlock(def)) candidate = def; if (!candidate || candidate->isBeforeInBlock(def)) candidate = def;
} }
@ -79,7 +79,7 @@ llvm::Optional<IslandOp> GetOperandCandidateToMergeWith(IslandOp island) {
// Check island data operands. // Check island data operands.
island.walk([graph_op, &candidate](Operation* op) { island.walk([graph_op, &candidate](Operation* op) {
for (Value input : op->getOperands()) { for (Value input : op->getOperands()) {
Operation* def = input->getDefiningOp(); Operation* def = input.getDefiningOp();
if (!def || def->getParentOp() != graph_op) continue; if (!def || def->getParentOp() != graph_op) continue;
if (!candidate || candidate->isBeforeInBlock(def)) candidate = def; if (!candidate || candidate->isBeforeInBlock(def)) candidate = def;
} }
@ -99,7 +99,7 @@ llvm::Optional<IslandOp> GetResultCandidateToMergeWith(IslandOp island) {
Operation* candidate = nullptr; Operation* candidate = nullptr;
// Check island control results. // Check island control results.
for (Operation* user : island.control()->getUsers()) { for (Operation* user : island.control().getUsers()) {
DCHECK_EQ(user->getParentOp(), graph_op); DCHECK_EQ(user->getParentOp(), graph_op);
if (!candidate || user->isBeforeInBlock(candidate)) candidate = user; if (!candidate || user->isBeforeInBlock(candidate)) candidate = user;
} }
@ -107,7 +107,7 @@ llvm::Optional<IslandOp> GetResultCandidateToMergeWith(IslandOp island) {
// Check island data results. // Check island data results.
Block& graph_body = llvm::cast<GraphOp>(graph_op).GetBody(); Block& graph_body = llvm::cast<GraphOp>(graph_op).GetBody();
for (Value result : island.outputs()) { for (Value result : island.outputs()) {
for (Operation* user : result->getUsers()) { for (Operation* user : result.getUsers()) {
Operation* def = graph_body.findAncestorOpInBlock(*user); Operation* def = graph_body.findAncestorOpInBlock(*user);
DCHECK_NE(def, nullptr); DCHECK_NE(def, nullptr);
if (!candidate || def->isBeforeInBlock(candidate)) candidate = def; if (!candidate || def->isBeforeInBlock(candidate)) candidate = def;
@ -147,7 +147,7 @@ llvm::SmallVector<IslandResult, 8> GetNewIslandResultsAndForwardResults(
bool result_captured = false; bool result_captured = false;
Value inner_op_result = std::get<0>(ret_vals); Value inner_op_result = std::get<0>(ret_vals);
Value island_result = std::get<1>(ret_vals); Value island_result = std::get<1>(ret_vals);
for (auto& use : llvm::make_early_inc_range(island_result->getUses())) { for (auto& use : llvm::make_early_inc_range(island_result.getUses())) {
if (child_body.findAncestorOpInBlock(*use.getOwner())) { if (child_body.findAncestorOpInBlock(*use.getOwner())) {
// Forward result from inner op. // Forward result from inner op.
use.set(inner_op_result); use.set(inner_op_result);
@ -162,7 +162,7 @@ llvm::SmallVector<IslandResult, 8> GetNewIslandResultsAndForwardResults(
llvm::zip(child.GetYield().getOperands(), child.outputs())) { llvm::zip(child.GetYield().getOperands(), child.outputs())) {
Value inner_op_result = std::get<0>(ret_vals); Value inner_op_result = std::get<0>(ret_vals);
Value island_result = std::get<1>(ret_vals); Value island_result = std::get<1>(ret_vals);
if (!island_result->use_empty()) { if (!island_result.use_empty()) {
results.emplace_back(inner_op_result, island_result); results.emplace_back(inner_op_result, island_result);
} }
} }
@ -178,7 +178,7 @@ IslandOp CreateNewIsland(IslandOp parent, IslandOp child,
// Collect types from results. // Collect types from results.
llvm::SmallVector<Type, 8> result_types; llvm::SmallVector<Type, 8> result_types;
for (const auto& result : results) for (const auto& result : results)
result_types.push_back(result.inner_op_result->getType()); result_types.push_back(result.inner_op_result.getType());
// IslandOps always have a control result. // IslandOps always have a control result.
result_types.push_back(ControlType::get(parent.getContext())); result_types.push_back(ControlType::get(parent.getContext()));
@ -201,7 +201,7 @@ YieldOp CreateNewIslandYieldOp(IslandOp new_island,
const auto& old_result = std::get<0>(ret_vals); const auto& old_result = std::get<0>(ret_vals);
// Replace original island result with new island result. // Replace original island result with new island result.
old_result.island_result->replaceAllUsesWith(std::get<1>(ret_vals)); old_result.island_result.replaceAllUsesWith(std::get<1>(ret_vals));
// Add associated inner op result to operands of the YieldOp. // Add associated inner op result to operands of the YieldOp.
yield_operands.push_back(old_result.inner_op_result); yield_operands.push_back(old_result.inner_op_result);
@ -249,8 +249,8 @@ void MergeIslands(IslandOp parent, IslandOp child, IslandType insert_position) {
MoveInnerOpsToNewIsland(parent, child, new_yield_op.getOperation()); MoveInnerOpsToNewIsland(parent, child, new_yield_op.getOperation());
// Update control inputs to point to the new merged island. // Update control inputs to point to the new merged island.
child.control()->replaceAllUsesWith(new_island.control()); child.control().replaceAllUsesWith(new_island.control());
parent.control()->replaceAllUsesWith(new_island.control()); parent.control().replaceAllUsesWith(new_island.control());
// Remove merged islands. // Remove merged islands.
child.erase(); child.erase();
@ -291,11 +291,11 @@ void InsertDummyIslandForFetch(FetchOp fetch) {
llvm::SmallVector<Type, 4> data_types; llvm::SmallVector<Type, 4> data_types;
llvm::SmallVector<Value, 4> control_fetches; llvm::SmallVector<Value, 4> control_fetches;
for (auto value : fetch.fetches()) { for (auto value : fetch.fetches()) {
if (value->getType().isa<ControlType>()) { if (value.getType().isa<ControlType>()) {
control_fetches.push_back(value); control_fetches.push_back(value);
} else { } else {
data_fetches.push_back(value); data_fetches.push_back(value);
data_types.push_back(value->getType()); data_types.push_back(value.getType());
} }
} }
auto island = OpBuilder(fetch).create<IslandOp>( auto island = OpBuilder(fetch).create<IslandOp>(
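
Use-list manipulation follows the same rule: `replaceAllUsesWith`, `getUses`, and `use_empty` are plain member calls on `Value` now. A minimal sketch, assuming this revision's `mlir/IR/Value.h` (illustrative helper, not part of this change):

#include "mlir/IR/Value.h"

// Redirects every use of `from` to `to`; returns true if anything changed.
static bool RedirectUses(mlir::Value from, mlir::Value to) {
  if (from.use_empty()) return false;
  from.replaceAllUsesWith(to);
  return true;
}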

View File

@ -66,12 +66,12 @@ class SwitchFoldPass : public mlir::FunctionPass<SwitchFoldPass> {
// Returns the defining op for a value looking through islands. // Returns the defining op for a value looking through islands.
static Operation* GetDefiningOp(Value val) { static Operation* GetDefiningOp(Value val) {
Operation* op = val->getDefiningOp(); Operation* op = val.getDefiningOp();
auto island_op = dyn_cast<tf_executor::IslandOp>(op); auto island_op = dyn_cast<tf_executor::IslandOp>(op);
if (!island_op) return op; if (!island_op) return op;
auto yield_op = island_op.GetYield(); auto yield_op = island_op.GetYield();
auto index = val->cast<mlir::OpResult>()->getResultNumber(); auto index = val.cast<mlir::OpResult>().getResultNumber();
return yield_op.getOperand(index)->getDefiningOp(); return yield_op.getOperand(index).getDefiningOp();
} }
// Returns either the value or input to an IdentityOp. // Returns either the value or input to an IdentityOp.
@ -114,7 +114,7 @@ class DeadQueue {
// feeding into the Merge then we could have a null value here. // feeding into the Merge then we could have a null value here.
count = 0; count = 0;
for (auto operand : op->getOperands()) { for (auto operand : op->getOperands()) {
if (operand && !operand->getType().isa<tf_executor::ControlType>()) if (operand && !operand.getType().isa<tf_executor::ControlType>())
++count; ++count;
} }
} }
@ -125,8 +125,8 @@ class DeadQueue {
// Enqueue users of a value. // Enqueue users of a value.
void EnqueueUsers(Value val) { void EnqueueUsers(Value val) {
for (auto user : val->getUsers()) { for (auto user : val.getUsers()) {
Enqueue(user, val->getType().isa<tf_executor::ControlType>()); Enqueue(user, val.getType().isa<tf_executor::ControlType>());
} }
} }
@ -189,7 +189,7 @@ static void MatchSwitchFoldOps(tf_executor::SwitchOp switch_op,
bool taken = pred.getSplatValue<bool>(); bool taken = pred.getSplatValue<bool>();
Value dead = taken ? switch_op.falseOutput() : switch_op.trueOutput(); Value dead = taken ? switch_op.falseOutput() : switch_op.trueOutput();
Value live = !taken ? switch_op.falseOutput() : switch_op.trueOutput(); Value live = !taken ? switch_op.falseOutput() : switch_op.trueOutput();
live->replaceAllUsesWith(switch_op.data()); live.replaceAllUsesWith(switch_op.data());
queue->EnqueueUsers(dead); queue->EnqueueUsers(dead);
// Delete switch op. // Delete switch op.
@ -218,7 +218,7 @@ static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) {
Value operand = e.value(); Value operand = e.value();
if (!operand) continue; if (!operand) continue;
// Skip control operands. // Skip control operands.
if (operand->getType().isa<tf_executor::ControlType>()) break; if (operand.getType().isa<tf_executor::ControlType>()) break;
if (val != nullptr) { if (val != nullptr) {
return merge->emitOpError("multiple valid inputs post switch folding"); return merge->emitOpError("multiple valid inputs post switch folding");
} }
@ -226,26 +226,26 @@ static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) {
index = e.index(); index = e.index();
} }
assert(val != nullptr && "merge node should have been deleted"); assert(val != nullptr && "merge node should have been deleted");
merge_op.output()->replaceAllUsesWith(val); merge_op.output().replaceAllUsesWith(val);
// Build and insert value_index only if needed. // Build and insert value_index only if needed.
if (!merge_op.value_index()->use_empty()) { if (!merge_op.value_index().use_empty()) {
merge_op.value_index()->replaceAllUsesWith( merge_op.value_index().replaceAllUsesWith(
build_index(merge->getLoc(), index)); build_index(merge->getLoc(), index));
} }
// Propagate control dependencies if used. // Propagate control dependencies if used.
if (!merge_op.control()->use_empty()) { if (!merge_op.control().use_empty()) {
// Change control dependencies from the merge to being on the parent of // Change control dependencies from the merge to being on the parent of
// the value being propagated. // the value being propagated.
auto def_op = val->getDefiningOp(); auto def_op = val.getDefiningOp();
#ifndef NDEBUG #ifndef NDEBUG
auto exec_dialect = auto exec_dialect =
function.getContext()->getRegisteredDialect("tf_executor"); function.getContext()->getRegisteredDialect("tf_executor");
assert(def_op->getDialect() == exec_dialect && assert(def_op->getDialect() == exec_dialect &&
"unable to forward control dependencies"); "unable to forward control dependencies");
#endif #endif
merge_op.control()->replaceAllUsesWith( merge_op.control().replaceAllUsesWith(
def_op->getResult(def_op->getNumResults() - 1)); def_op->getResult(def_op->getNumResults() - 1));
} }
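
The one non-mechanical line in this file is the `OpResult` query: a result-valued `Value` is classified with `.cast<mlir::OpResult>()`, and the dot carries through to `getResultNumber()`. In isolation, that lookup reads as below (illustrative helper, assuming this revision's API):

#include "mlir/IR/Value.h"

// Returns which result of its defining op `v` is, or -1 for block arguments.
static int GetResultIndex(mlir::Value v) {
  if (auto result = v.dyn_cast<mlir::OpResult>())
    return static_cast<int>(result.getResultNumber());
  return -1;
}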

View File

@ -53,7 +53,7 @@ static Value LowerCondition(Location loc, Value value, OpBuilder* builder) {
// FIXME: This is almost all wrong, but is a placeholder to unblock the one // FIXME: This is almost all wrong, but is a placeholder to unblock the one
// testcase; later patches will build on this once I build the right infra to // testcase; later patches will build on this once I build the right infra to
// support it. // support it.
TensorType type = value->getType().cast<TensorType>(); TensorType type = value.getType().cast<TensorType>();
if (!type.hasRank() || type.getRank() != 0 || if (!type.hasRank() || type.getRank() != 0 ||
!type.getElementType().isInteger(1)) { !type.getElementType().isInteger(1)) {
return emitError(loc, "only supports zero-D bool tensors now"), nullptr; return emitError(loc, "only supports zero-D bool tensors now"), nullptr;
@ -79,7 +79,7 @@ static Operation* CallFn(Location loc, const std::function<Value(int)>& get_arg,
for (int i = 0; i < num_operands; ++i) { for (int i = 0; i < num_operands; ++i) {
Value val = get_arg(i); Value val = get_arg(i);
Type expected = fn_type.getInput(i); Type expected = fn_type.getInput(i);
if (val->getType() != expected) { if (val.getType() != expected) {
val = val =
builder->create<TF::CastOp>(loc, expected, val, builder->create<TF::CastOp>(loc, expected, val,
/*Truncate=*/builder->getBoolAttr(false)); /*Truncate=*/builder->getBoolAttr(false));
@ -102,8 +102,8 @@ static llvm::SmallVector<Value, 4> PrepareValsForJump(
result.reserve(num_vals); result.reserve(num_vals);
for (int i = 0; i < num_vals; ++i) { for (int i = 0; i < num_vals; ++i) {
Value val = get_val(i); Value val = get_val(i);
Type expected = block->getArgument(i)->getType(); Type expected = block->getArgument(i).getType();
if (val->getType() != expected) { if (val.getType() != expected) {
val = val =
builder->create<TF::CastOp>(loc, expected, val, builder->create<TF::CastOp>(loc, expected, val,
/*Truncate=*/builder->getBoolAttr(false)); /*Truncate=*/builder->getBoolAttr(false));
@ -137,12 +137,12 @@ static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
Value arg = block->getArgument(i); Value arg = block->getArgument(i);
Value result = op->getResult(i); Value result = op->getResult(i);
if (arg->getType() != result->getType()) { if (arg.getType() != result.getType()) {
arg = arg =
builder->create<TF::CastOp>(loc, result->getType(), arg, builder->create<TF::CastOp>(loc, result.getType(), arg,
/*Truncate=*/builder->getBoolAttr(false)); /*Truncate=*/builder->getBoolAttr(false));
} }
result->replaceAllUsesWith(arg); result.replaceAllUsesWith(arg);
} }
} }
@ -174,7 +174,7 @@ static LogicalResult LowerIfOp(IfOp op) {
// Add the block arguments to the merge point, and replace all uses of the // Add the block arguments to the merge point, and replace all uses of the
// original operation results with them. // original operation results with them.
for (Value value : op_inst->getResults()) for (Value value : op_inst->getResults())
merge_block->addArgument(value->getType()); merge_block->addArgument(value.getType());
ReplaceOpResultWithBlockArgs(loc, op_inst, merge_block, &builder); ReplaceOpResultWithBlockArgs(loc, op_inst, merge_block, &builder);
// Get arguments to the branches after dropping the condition which is the // Get arguments to the branches after dropping the condition which is the
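
This lowering repeatedly guards block-argument forwarding with a non-truncating `tf.Cast` whenever types disagree. Distilled, the guard used in `CallFn`, `PrepareValsForJump`, and `ReplaceOpResultWithBlockArgs` looks like this (a sketch with an invented helper name):

#include "mlir/IR/Builders.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

// Returns `val` cast to `expected` via tf.Cast, or `val` itself if the
// types already match.
static mlir::Value CastToExpected(mlir::OpBuilder &builder, mlir::Location loc,
                                  mlir::Value val, mlir::Type expected) {
  if (val.getType() == expected) return val;
  return builder.create<mlir::TF::CastOp>(
      loc, expected, val, /*Truncate=*/builder.getBoolAttr(false));
}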

View File

@ -39,7 +39,7 @@ void PruneGraph(GraphOp graph) {
// Visit an op's operands if they are the output of an Operation in the same graph. // Visit an op's operands if they are the output of an Operation in the same graph.
auto visit_op = [&](Operation* op) { auto visit_op = [&](Operation* op) {
for (Value operand : op->getOperands()) { for (Value operand : op->getOperands()) {
Operation* def = operand->getDefiningOp(); Operation* def = operand.getDefiningOp();
if (def && def->getParentOp() == graph && if (def && def->getParentOp() == graph &&
reachable_ops.insert(def).second) { reachable_ops.insert(def).second) {
// Op has not been visited, add to queue to visit later. // Op has not been visited, add to queue to visit later.

View File

@ -55,7 +55,7 @@ void InlineGlobalTensorsPass::runOnModule() {
// Replace the arg with a tf.Const op in the function body. // Replace the arg with a tf.Const op in the function body.
auto const_op = builder.create<TF::ConstOp>(global_tensor.getLoc(), auto const_op = builder.create<TF::ConstOp>(global_tensor.getLoc(),
global_tensor.value()); global_tensor.value());
func.getArgument(i)->replaceAllUsesWith(const_op.getResult()); func.getArgument(i).replaceAllUsesWith(const_op.getResult());
args_to_erase.push_back(i); args_to_erase.push_back(i);
} }
func.eraseArguments(args_to_erase); func.eraseArguments(args_to_erase);

View File

@ -196,7 +196,7 @@ class LowerDynamicStitchOp : public OpRewritePattern<TF::DynamicStitchOp> {
if (!matchPattern(index, m_Constant(&index_attr))) return matchFailure(); if (!matchPattern(index, m_Constant(&index_attr))) return matchFailure();
indices.push_back(index_attr); indices.push_back(index_attr);
RankedTensorType data_ty = data->getType().dyn_cast<RankedTensorType>(); RankedTensorType data_ty = data.getType().dyn_cast<RankedTensorType>();
if (!data_ty || !data_ty.hasStaticShape()) return matchFailure(); if (!data_ty || !data_ty.hasStaticShape()) return matchFailure();
} }
@ -270,7 +270,7 @@ class LowerPackOp : public OpRewritePattern<TF::PackOp> {
// If input type is different than the previous input type, infer the // If input type is different than the previous input type, infer the
// output type. Otherwise, use the already inferred output type from the // output type. Otherwise, use the already inferred output type from the
// previous iteration. // previous iteration.
Type input_ty = input->getType(); Type input_ty = input.getType();
if (input_ty != prev_input_ty) { if (input_ty != prev_input_ty) {
inferred_ty = InferExpandDimsType(input_ty, axis, &rewriter); inferred_ty = InferExpandDimsType(input_ty, axis, &rewriter);
prev_input_ty = input_ty; prev_input_ty = input_ty;

View File

@ -37,7 +37,7 @@ class GetI64ScalarElementsAttr<int value> :
def GetBiasAddGradReductionIndices : NativeCodeCall< def GetBiasAddGradReductionIndices : NativeCodeCall<
"GetBiasAddGradReductionIndices(" "GetBiasAddGradReductionIndices("
"$0->getType().cast<RankedTensorType>().getRank(), $1, &$_builder)">; "$0.getType().cast<RankedTensorType>().getRank(), $1, &$_builder)">;
def LowerBiasAddGradOp : def LowerBiasAddGradOp :
Pat<(TF_BiasAddGradOp AnyRankedTensor:$out_backprop, $data_format), Pat<(TF_BiasAddGradOp AnyRankedTensor:$out_backprop, $data_format),
@ -82,12 +82,12 @@ def LowerSoftmaxCrossEntropyWithLogitsOp : Pattern<
// dimension should be known. // dimension should be known.
class GetDimSizeOfType<int dim> : NativeCodeCall< class GetDimSizeOfType<int dim> : NativeCodeCall<
"GetScalarOfType(getElementTypeOrSelf($1), " "GetScalarOfType(getElementTypeOrSelf($1), "
"$0->getType().cast<RankedTensorType>().getDimSize(" # dim # "))">; "$0.getType().cast<RankedTensorType>().getDimSize(" # dim # "))">;
// Same as the above with i32 element type. // Same as the above with i32 element type.
class GetDimSizeAsI32<int dim> : NativeCodeCall< class GetDimSizeAsI32<int dim> : NativeCodeCall<
"GetScalarOfType($_builder.getIntegerType(32), " "GetScalarOfType($_builder.getIntegerType(32), "
"$0->getType().cast<RankedTensorType>().getDimSize(" # dim # "))">; "$0.getType().cast<RankedTensorType>().getDimSize(" # dim # "))">;
// Sparse version of SoftmaxCrossEntropyWithLogits is lowered to dense by // Sparse version of SoftmaxCrossEntropyWithLogits is lowered to dense by
// expanding the sparse labels using: // expanding the sparse labels using:
@ -160,7 +160,7 @@ def LowerFillOp : Pat<(TF_FillOp $dims, $value),
def GetAllAxes : NativeCodeCall< def GetAllAxes : NativeCodeCall<
"GetI64ElementsAttrForSeq(" "GetI64ElementsAttrForSeq("
"0, $0->getType().cast<RankedTensorType>().getRank(), &$_builder)">; "0, $0.getType().cast<RankedTensorType>().getRank(), &$_builder)">;
// L2Loss is lowered using the formula, // L2Loss is lowered using the formula,
// L2Loss(input) = Sum(input * input) / 2 // L2Loss(input) = Sum(input * input) / 2
@ -220,7 +220,7 @@ def LowerTanhGradOp :
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def CreateTFShapeOp : NativeCodeCall< def CreateTFShapeOp : NativeCodeCall<
"$_builder.create<TF::ShapeOp>($0->getLoc(), $1, $2)">; "$_builder.create<TF::ShapeOp>($0.getLoc(), $1, $2)">;
// TODO(hinsu): Support inputs of TensorList types. // TODO(hinsu): Support inputs of TensorList types.
def LowerZerosLikeOp : def LowerZerosLikeOp :

View File

@ -79,7 +79,7 @@ void MaterializePassthroughOpPass::runOnFunction() {
Block &block = body.front(); Block &block = body.front();
for (const auto &arg_mapping : for (const auto &arg_mapping :
llvm::zip(block.getArguments(), op->getOperands())) { llvm::zip(block.getArguments(), op->getOperands())) {
std::get<0>(arg_mapping)->replaceAllUsesWith(std::get<1>(arg_mapping)); std::get<0>(arg_mapping).replaceAllUsesWith(std::get<1>(arg_mapping));
} }
op->getBlock()->getOperations().splice(op->getIterator(), op->getBlock()->getOperations().splice(op->getIterator(),
block.getOperations(), block.begin(), block.getOperations(), block.begin(),
@ -87,7 +87,7 @@ void MaterializePassthroughOpPass::runOnFunction() {
Operation &return_op = block.front(); Operation &return_op = block.front();
for (auto ret_mapping : for (auto ret_mapping :
llvm::zip(op->getResults(), return_op.getOperands())) { llvm::zip(op->getResults(), return_op.getOperands())) {
std::get<0>(ret_mapping)->replaceAllUsesWith(std::get<1>(ret_mapping)); std::get<0>(ret_mapping).replaceAllUsesWith(std::get<1>(ret_mapping));
} }
op->erase(); op->erase();
}); });

View File

@ -21,7 +21,7 @@ def BroadcastableElements :
Constraint<CPred<"TFL::IsBroadcastableElementsAttrs($0, $1)">>; Constraint<CPred<"TFL::IsBroadcastableElementsAttrs($0, $1)">>;
def F32ElementsAttr : ElementsAttrBase< def F32ElementsAttr : ElementsAttrBase<
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">; CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
def DefinedByConv2D : Constraint<CPred<"llvm::isa_and_nonnull<mlir::TF::Conv2DOp>($0->getDefiningOp())">>; def DefinedByConv2D : Constraint<CPred<"llvm::isa_and_nonnull<mlir::TF::Conv2DOp>($0.getDefiningOp())">>;
// If we see a Conv2D op followed by Mul, then multiply the filter // If we see a Conv2D op followed by Mul, then multiply the filter
// with the value in Mul. // with the value in Mul.

View File

@ -54,7 +54,7 @@ bool IsReadOnlyVariableOp(Operation* op) { return isa<TF::ReadVariableOp>(op); }
void RewriteReadOnlyVariableOpToTensorOp(Operation* op, Value tensor_value) { void RewriteReadOnlyVariableOpToTensorOp(Operation* op, Value tensor_value) {
auto read_variable = cast<TF::ReadVariableOp>(op); auto read_variable = cast<TF::ReadVariableOp>(op);
read_variable.value()->replaceAllUsesWith(tensor_value); read_variable.value().replaceAllUsesWith(tensor_value);
} }
bool IsFreezable(GlobalTensorOp global_tensor, bool IsFreezable(GlobalTensorOp global_tensor,
@ -74,7 +74,7 @@ bool IsFreezable(GlobalTensorOp global_tensor,
// or control flow, we fail to prove it is freezable even though we could. // or control flow, we fail to prove it is freezable even though we could.
for (auto& global_tensor_use : global_tensor_uses) { for (auto& global_tensor_use : global_tensor_uses) {
auto arg = global_tensor_use.func.getArgument(global_tensor_use.arg_index); auto arg = global_tensor_use.func.getArgument(global_tensor_use.arg_index);
for (auto user : arg->getUsers()) { for (auto user : arg.getUsers()) {
if (!IsReadOnlyVariableOp(user)) { if (!IsReadOnlyVariableOp(user)) {
return false; return false;
} }
@ -130,12 +130,12 @@ void FreezeGlobalTensors(ModuleOp module,
auto func = global_tensor_use.func; auto func = global_tensor_use.func;
auto arg_index = global_tensor_use.arg_index; auto arg_index = global_tensor_use.arg_index;
Value arg = func.getArgument(arg_index); Value arg = func.getArgument(arg_index);
for (Operation* user : llvm::make_early_inc_range(arg->getUsers())) { for (Operation* user : llvm::make_early_inc_range(arg.getUsers())) {
RewriteReadOnlyVariableOpToTensorOp(user, arg); RewriteReadOnlyVariableOpToTensorOp(user, arg);
user->erase(); user->erase();
} }
Type new_type = global_tensor.value().Attribute::getType(); Type new_type = global_tensor.value().Attribute::getType();
arg->setType(new_type); arg.setType(new_type);
auto old_ftype = func.getType(); auto old_ftype = func.getType();
auto input_types = old_ftype.getInputs().vec(); auto input_types = old_ftype.getInputs().vec();
input_types[arg_index] = new_type; input_types[arg_index] = new_type;
@ -168,7 +168,7 @@ void EraseUnusedBoundInputs(ModuleOp module) {
SmallVector<unsigned, 4> args_to_erase; SmallVector<unsigned, 4> args_to_erase;
for (int i = 0, e = func.getNumArguments(); i < e; i++) { for (int i = 0, e = func.getNumArguments(); i < e; i++) {
if (func.getArgAttr(i, "tf_saved_model.bound_input") && if (func.getArgAttr(i, "tf_saved_model.bound_input") &&
func.getArgument(i)->use_empty()) { func.getArgument(i).use_empty()) {
args_to_erase.push_back(i); args_to_erase.push_back(i);
} }
} }
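
Argument freezing leans on the same value-typed API: `use_empty`, `getUsers`, and `setType` are member calls on the `BlockArgument` itself. A condensed sketch of the unused-argument scan, assuming this revision's `FuncOp` (illustrative; the real pass also checks the `tf_saved_model.bound_input` attribute):

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Function.h"

// Collects the indices of function arguments that have no remaining uses.
static void CollectUnusedArgs(mlir::FuncOp func,
                              llvm::SmallVectorImpl<unsigned> *args_to_erase) {
  for (unsigned i = 0, e = func.getNumArguments(); i != e; ++i)
    if (func.getArgument(i).use_empty()) args_to_erase->push_back(i);
}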

View File

@ -100,7 +100,7 @@ void RaiseTFControlFlow::rewriteOps() {
// aren't necessary any more since the order within a block encodes the // aren't necessary any more since the order within a block encodes the
// same information. // same information.
for (auto &operand : op.getOpOperands()) { for (auto &operand : op.getOpOperands()) {
if (!operand.get()->getType().isa<TFControlType>()) if (!operand.get().getType().isa<TFControlType>())
result.operands.push_back(operand.get()); result.operands.push_back(operand.get());
// Drop all operands from the old operation, eliminating any // Drop all operands from the old operation, eliminating any
@ -111,13 +111,13 @@ void RaiseTFControlFlow::rewriteOps() {
// Add a result type for each non-control result we find. // Add a result type for each non-control result we find.
bool sawControlResult = false; bool sawControlResult = false;
for (auto opResult : op.getResults()) { for (auto opResult : op.getResults()) {
if (opResult->getType().isa<TFControlType>()) { if (opResult.getType().isa<TFControlType>()) {
sawControlResult = true; sawControlResult = true;
} else { } else {
// We assume all control results are at the end of the result list. // We assume all control results are at the end of the result list.
assert(!sawControlResult && "all control results must be last"); assert(!sawControlResult && "all control results must be last");
(void)sawControlResult; (void)sawControlResult;
result.types.push_back(opResult->getType()); result.types.push_back(opResult.getType());
} }
} }
@ -129,7 +129,7 @@ void RaiseTFControlFlow::rewriteOps() {
// We know that all the control results are last, so we can just rewrite // We know that all the control results are last, so we can just rewrite
// the first results. // the first results.
for (unsigned i = 0, e = result.types.size(); i != e; ++i) for (unsigned i = 0, e = result.types.size(); i != e; ++i)
op.getResult(i)->replaceAllUsesWith(replacement->getResult(i)); op.getResult(i).replaceAllUsesWith(replacement->getResult(i));
} }
} }

View File

@ -74,16 +74,16 @@ void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas,
Value input = shape_op.input(); Value input = shape_op.input();
// If ShapeOp operand is replicate tensor block argument, replace with the // If ShapeOp operand is replicate tensor block argument, replace with the
// associated first replica operand. // associated first replica operand.
if (auto block_arg = input->dyn_cast<BlockArgument>()) { if (auto block_arg = input.dyn_cast<BlockArgument>()) {
if (block_arg->getOwner() != replicate_block) return; if (block_arg.getOwner() != replicate_block) return;
shape_op.setOperand( shape_op.setOperand(
replicate_op.getOperand(num_replicas * block_arg->getArgNumber())); replicate_op.getOperand(num_replicas * block_arg.getArgNumber()));
return; return;
} }
Operation* input_def = input->getDefiningOp(); Operation* input_def = input.getDefiningOp();
// If ShapeOp operand is a ReadVariableOp result where the ReadVariableOp // If ShapeOp operand is a ReadVariableOp result where the ReadVariableOp
// operand is a replicate resource block argument, replace ShapeOp with // operand is a replicate resource block argument, replace ShapeOp with
@ -96,13 +96,13 @@ void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas,
// shape has not changed in replicate prior to read. Currently after both // shape has not changed in replicate prior to read. Currently after both
// ResourceOpLiftingPass and TPURewritePass, there should not be any updates // ResourceOpLiftingPass and TPURewritePass, there should not be any updates
// to resources prior to their respective ReadVariableOp. // to resources prior to their respective ReadVariableOp.
if (auto block_arg = read_var_op.resource()->dyn_cast<BlockArgument>()) { if (auto block_arg = read_var_op.resource().dyn_cast<BlockArgument>()) {
if (block_arg->getOwner() != replicate_block) return; if (block_arg.getOwner() != replicate_block) return;
OpBuilder builder(shape_op); OpBuilder builder(shape_op);
auto new_shape_op = builder.create<TF::VariableShapeOp>( auto new_shape_op = builder.create<TF::VariableShapeOp>(
shape_op.getLoc(), shape_op.getType(), shape_op.getLoc(), shape_op.getType(),
replicate_op.getOperand(num_replicas * block_arg->getArgNumber())); replicate_op.getOperand(num_replicas * block_arg.getArgNumber()));
shape_op.replaceAllUsesWith(new_shape_op.getOperation()); shape_op.replaceAllUsesWith(new_shape_op.getOperation());
shape_op.erase(); shape_op.erase();
} }
@ -112,7 +112,7 @@ void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas,
bool IsOpReplicateInvariant(Region* replicate_region, Operation* op) { bool IsOpReplicateInvariant(Region* replicate_region, Operation* op) {
auto result = op->walk([&](Operation* inner_op) { auto result = op->walk([&](Operation* inner_op) {
for (Value operand : inner_op->getOperands()) { for (Value operand : inner_op->getOperands()) {
Region* parent_region = operand->getParentRegion(); Region* parent_region = operand.getParentRegion();
if (!parent_region || !parent_region->isProperAncestor(replicate_region)) if (!parent_region || !parent_region->isProperAncestor(replicate_region))
return WalkResult::interrupt(); return WalkResult::interrupt();
} }

View File

@ -83,7 +83,7 @@ llvm::SmallVector<tf_executor::IslandOp, 8> ExpandReplicateIntoReplicas(
mapping.clear(); mapping.clear();
for (auto& block_arg : replicate_op.GetBody().getArguments()) for (auto& block_arg : replicate_op.GetBody().getArguments())
mapping.map(block_arg, replicate_op.getOperand( mapping.map(block_arg, replicate_op.getOperand(
block_arg->getArgNumber() * num_replicas + i)); block_arg.getArgNumber() * num_replicas + i));
// Copy over replicate region into replica island. // Copy over replicate region into replica island.
replicate_op.body().cloneInto(&replica.body(), mapping); replicate_op.body().cloneInto(&replica.body(), mapping);
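
`BlockArgument` became value-typed in the same transition, so `getArgNumber` and `getOwner` are dot calls as well. The replica-operand lookup above, factored out (a sketch with an invented name; assumes this revision's tf_device dialect header path):

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"

// Maps a replicate block argument to the operand feeding replica `i`.
static mlir::Value OperandForReplica(mlir::tf_device::ReplicateOp replicate_op,
                                     mlir::BlockArgument arg,
                                     int num_replicas, int i) {
  return replicate_op.getOperand(arg.getArgNumber() * num_replicas + i);
}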

View File

@@ -127,16 +127,16 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
   OpBuilder builder(func_op);
   // Function arguments.
   for (auto arg : func_op.getArguments()) {
-    if (!mlir::getElementTypeOrSelf(arg->getType()).isa<TF::ResourceType>()) {
+    if (!mlir::getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>()) {
      continue;
    }
    auto device_attr = func_op.getArgAttrOfType<mlir::StringAttr>(
-        arg->getArgNumber(), kFuncDeviceAttr);
+        arg.getArgNumber(), kFuncDeviceAttr);
    if (!device_attr || device_attr.getValue() == "") {
      // If device_attr does not exist, try to construct it from any recorded
      // assignment.
      if (auto device = result->DeviceForResource(arg)) {
-        func_op.setArgAttr(arg->getArgNumber(), kFuncDeviceAttr,
+        func_op.setArgAttr(arg.getArgNumber(), kFuncDeviceAttr,
                           builder.getStringAttr(*device));
      }
      continue;
@@ -160,7 +160,7 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
    }
    if (auto identity = llvm::dyn_cast<TF::IdentityOp>(op)) {
      // Try to construct IdentityOp's attribute from recorded assignment.
-      if (!mlir::getElementTypeOrSelf(identity.output()->getType())
+      if (!mlir::getElementTypeOrSelf(identity.output().getType())
              .isa<TF::ResourceType>()) {
        return WalkResult::advance();
      }
@@ -176,7 +176,7 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
    // Propagate and record output device assignment for other ops based on
    // existing recording. E.g., IdentityN.
    for (auto output : op->getResults()) {
-      if (!mlir::getElementTypeOrSelf(output->getType())
+      if (!mlir::getElementTypeOrSelf(output.getType())
              .isa<TF::ResourceType>()) {
        continue;
      }
@@ -212,7 +212,7 @@ void ResourceDeviceInference::runOnModule() {
      for (auto operand_and_argument :
           llvm::zip(caller_operands, callee.getArguments())) {
        if (!mlir::getElementTypeOrSelf(
-                std::get<0>(operand_and_argument)->getType())
+                std::get<0>(operand_and_argument).getType())
                .isa<TF::ResourceType>()) {
          continue;
        }
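The resource test repeated across these hunks, written out once as a hedged sketch (the helper name is hypothetical, not part of the pass):

    // A value carries a resource if its element type (or the type itself,
    // when the type is not shaped) is the TF dialect's ResourceType.
    static bool IsResourceValue(mlir::Value v) {
      return mlir::getElementTypeOrSelf(v.getType())
          .isa<mlir::TF::ResourceType>();
    }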
@@ -100,7 +100,7 @@ void ForwardStoreToLoad(tf_device::LaunchOp launch_op) {
       // Use stored value in last_store to replace all uses of current resource
       // load's result, then erase this resource load.
-      read_variable_op.value()->replaceAllUsesWith(last_store.value());
+      read_variable_op.value().replaceAllUsesWith(last_store.value());
       read_variable_op.erase();
       continue;
     }
@@ -130,7 +130,7 @@ void HoistResourceLoads(tf_device::LaunchOp launch_op) {
     Value resource = read_variable_op.resource();
     // Skip resources created inside of launch_op.
-    if (resource->getParentRegion() == &launch_op.body()) continue;
+    if (resource.getParentRegion() == &launch_op.body()) continue;
     auto p = resource_to_read_ops.insert({resource, read_variable_op});
     if (p.second) {
@@ -167,7 +167,7 @@ bool AppendResourceStoreValueToReturn(tf_device::LaunchOp launch_op) {
     if (!resource) continue;
     // Skip resources created inside of launch_op.
-    if (resource->getParentRegion() == &launch_op.body()) continue;
+    if (resource.getParentRegion() == &launch_op.body()) continue;
     // TODO(ycao): Prevent same value from being returned multiple times.
     // TODO(ycao): Do not return resource store value if it is defined outside
@@ -207,7 +207,7 @@ void SinkResourceStores(tf_device::LaunchOp launch_op, OpBuilder* builder) {
   // Replace uses of old launch_op results with those of new_launch_op.
   for (auto p : llvm::zip(launch_op.getResults(), new_launch_op.getResults())) {
-    std::get<0>(p)->replaceAllUsesWith(std::get<1>(p));
+    std::get<0>(p).replaceAllUsesWith(std::get<1>(p));
   }
   // Create a mapping from operands of new_return_op operands to new_launch_op
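The zip-and-replace idiom above, in isolation (a sketch; `old_op` and `new_op` are placeholder `Operation*` with matching result counts):

    // Redirect every use of old_op's results to the positionally
    // corresponding result of new_op.
    for (auto pair : llvm::zip(old_op->getResults(), new_op->getResults()))
      std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair));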
@@ -68,24 +68,24 @@ Optional<llvm::SmallVector<mlir::Type, 4>> InferShapeForFunctionReturnType(
   // Manually fold tf.Cast that precedes the return instruction and only differs
   // in shape refinement level.
   for (OpOperand& arg_op : return_op.getOperation()->getOpOperands()) {
-    Operation* arg_defining_op = arg_op.get()->getDefiningOp();
+    Operation* arg_defining_op = arg_op.get().getDefiningOp();
     if (auto cast_op = dyn_cast_or_null<CastOp>(arg_defining_op)) {
       // Shape inference should not change the element type.
       if (cast_op.SrcT() != cast_op.DstT()) continue;
       // We only refine the result shape if the result a dynamic shape, the
       // input has static shape, and the two shapes are compatible.
       auto has_static_shape = [](const Value value) {
-        auto shaped_type = value->getType().dyn_cast<ShapedType>();
+        auto shaped_type = value.getType().dyn_cast<ShapedType>();
         return shaped_type && shaped_type.hasStaticShape();
       };
       Value input = cast_op.x();
       Value result = cast_op.y();
       if (!has_static_shape(input) || has_static_shape(result) ||
-          failed(verifyCompatibleShape(input->getType(), result->getType())))
+          failed(verifyCompatibleShape(input.getType(), result.getType())))
         continue;
       arg_op.set(cast_op.x());
-      if (cast_op.y()->use_empty()) cast_op.erase();
+      if (cast_op.y().use_empty()) cast_op.erase();
     }
   }
@@ -111,7 +111,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
   // This is necessary to avoid reprocessing the tf.Cast that are inserted at
   // the end of this function.
   if (isa<CastOp>(op) &&
-      llvm::all_of(op->getResult(0)->getUsers(), [&](Operation* user) {
+      llvm::all_of(op->getResult(0).getUsers(), [&](Operation* user) {
        return user->getDialect() != tf_dialect;
      })) {
    LLVM_DEBUG(llvm::dbgs() << "Skipping inference for tf.Cast with no TF "
@@ -178,7 +178,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
      }
    }
-    Type operand_type = operand->getType();
+    Type operand_type = operand.getType();
    if (auto ranked_type = operand_type.dyn_cast<RankedTensorType>()) {
      // Convert the MLIR shape indices (int64_t) to TensorFlow indices (int64).
      ArrayRef<int64_t> shape = ranked_type.getShape();
@@ -215,7 +215,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
   for (int output : llvm::seq<int>(0, c.num_outputs())) {
     // Skip already statically shaped results.
     Value result = op->getResult(output);
-    auto shaped_type = result->getType().dyn_cast<ShapedType>();
+    auto shaped_type = result.getType().dyn_cast<ShapedType>();
     if (!shaped_type || shaped_type.hasStaticShape()) continue;
     tensorflow::shape_inference::ShapeHandle shape_handle = c.output(output);
@@ -235,18 +235,18 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
     auto get_cast_op = [&]() {
       if (!cast_op)
         cast_op =
-            builder.create<TF::CastOp>(op->getLoc(), result->getType(), result,
+            builder.create<TF::CastOp>(op->getLoc(), result.getType(), result,
                                        /*truncate=*/builder.getBoolAttr(false));
       return cast_op;
     };
-    for (OpOperand& use : llvm::make_early_inc_range(result->getUses())) {
+    for (OpOperand& use : llvm::make_early_inc_range(result.getUses())) {
       if (use.getOwner()->getDialect() != tf_dialect) use.set(get_cast_op());
     }
-    if (result->getType() == new_type) continue;
+    if (result.getType() == new_type) continue;
     // Finally we inferred the shape and replace the type for this result.
-    result->setType(new_type);
+    result.setType(new_type);
     changed = true;
   }
   if (changed)
@@ -284,7 +284,7 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
                                       func.getContext()));
   for (auto arg_and_idx : llvm::enumerate(func.getArguments())) {
-    arg_and_idx.value()->setType(input_types[arg_and_idx.index()]);
+    arg_and_idx.value().setType(input_types[arg_and_idx.index()]);
   }
   auto res =
@@ -307,7 +307,7 @@ LogicalResult PropagateShapeToIfWhileOpFunctions(
   llvm::SmallVector<Type, 4> input_types;
   input_types.reserve(std::distance(op.input().begin(), op.input().end()));
   for (Value v : op.input()) {
-    input_types.push_back(v->getType());
+    input_types.push_back(v.getType());
   }
   ModuleOp module = op.template getParentOfType<ModuleOp>();
@@ -414,7 +414,7 @@ LogicalResult InferShapeForFunction(FuncOp func,
     auto new_arg_type = mlir::RankedTensorType::get(shape, element_type);
     if (new_arg_type != func_type.getInput(i)) {
       // If the new type is more detailed, trigger shape inference.
-      func.getArgument(i)->setType(new_arg_type);
+      func.getArgument(i).setType(new_arg_type);
       needs_refinement = true;
     }
     new_arg_types.push_back(new_arg_type);
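The cast-insertion step above, condensed from the hunk into a standalone sketch (names as in the pass; `result` keeps its pre-refinement type until the final `setType`):

    // Keep non-TF users type-stable: route them through a lazily created
    // tf.Cast back to the old type, then refine the result in place.
    TF::CastOp cast_op;
    auto get_cast_op = [&]() {
      if (!cast_op)
        cast_op = builder.create<TF::CastOp>(
            op->getLoc(), result.getType(), result,
            /*truncate=*/builder.getBoolAttr(false));
      return cast_op;
    };
    for (OpOperand& use : llvm::make_early_inc_range(result.getUses()))
      if (use.getOwner()->getDialect() != tf_dialect) use.set(get_cast_op());
    result.setType(new_type);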
@@ -52,8 +52,7 @@ class ExecutorConstantSinking
       Region &body = launch.body();
       visitUsedValuesDefinedAbove(body, [&](OpOperand *use) {
         Value constant = use->get();
-        auto const_op =
-            dyn_cast_or_null<TF::ConstOp>(constant->getDefiningOp());
+        auto const_op = dyn_cast_or_null<TF::ConstOp>(constant.getDefiningOp());
         if (!const_op) return;
         // We found a constant, try to insert it in the map and re-use its
@@ -62,13 +61,13 @@ class ExecutorConstantSinking
         if (!map_entry.second) {
           // This constant has already been cloned into the region, reuse it.
           use->set(map_entry.first->getSecond().getResult());
-          LLVM_DEBUG(llvm::dbgs() << "Re-use sunk constant " << *use->get()
-                                  << "\n in " << *use->get() << "\n");
-          if (constant->use_empty()) const_op.erase();
+          LLVM_DEBUG(llvm::dbgs() << "Re-use sunk constant " << use->get()
+                                  << "\n in " << use->get() << "\n");
+          if (constant.use_empty()) const_op.erase();
           return;
         }
-        if (constant->hasOneUse()) {
-          LLVM_DEBUG(llvm::dbgs() << "Moved constant " << *constant << "\n");
+        if (constant.hasOneUse()) {
+          LLVM_DEBUG(llvm::dbgs() << "Moved constant " << constant << "\n");
           const_op.getOperation()->moveBefore(&body.begin()->front());
           return;
         }
@@ -76,8 +75,8 @@ class ExecutorConstantSinking
         body.begin()->getOperations().insert(body.begin()->begin(),
                                              map_entry.first->getSecond());
         use->set(map_entry.first->getSecond().getResult());
-        LLVM_DEBUG(llvm::dbgs() << "Sunk cloned constant " << *use->get()
-                                << "\n in " << *use->get() << "\n");
+        LLVM_DEBUG(llvm::dbgs() << "Sunk cloned constant " << use->get()
+                                << "\n in " << use->get() << "\n");
       });
     });
   }
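Worth noting in these hunks: a value-typed `Value` streams into debug output directly, so the explicit dereference disappears along with the arrow. A sketch (assuming the usual LLVM_DEBUG setup with a DEBUG_TYPE defined):

    // Before the change: llvm::dbgs() << *use->get()
    // After: the Value itself is printable.
    LLVM_DEBUG(llvm::dbgs() << "Sunk constant " << use->get() << "\n");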
@@ -141,7 +141,7 @@ bool ShouldMoveOpAfterCluster(
     const llvm::SmallSetVector<Operation*, 8>& preceding_users) {
   auto result = op->walk([&](Operation* op) {
     for (Value operand : op->getOperands()) {
-      Operation* def = operand->getDefiningOp();
+      Operation* def = operand.getDefiningOp();
       // Operands may not have a defining op (BlockArgument) or is from a
       // different block.
       if (!def || def->getBlock() != block) continue;
@@ -185,7 +185,7 @@ llvm::SmallVector<Value, 8> CollectClusterResults(
   for (Operation* op : cluster_ops) {
     for (Value result : op->getResults()) {
-      for (Operation* user : result->getUsers()) {
+      for (Operation* user : result.getUsers()) {
        // Check if user is not an op in the cluster.
        if (cluster_ops.count(block->findAncestorOpInBlock(*user)) == 0) {
          results.push_back(result);
@@ -206,7 +206,7 @@ tf_device::LaunchOp CreateLaunchOpForCluster(Operation* last_cluster_op,
   OpBuilder builder(last_cluster_op);
   llvm::SmallVector<Type, 8> result_types;
-  for (Value result : results) result_types.push_back(result->getType());
+  for (Value result : results) result_types.push_back(result.getType());
   // An empty string placeholder is used for the device as that will be later
   // populated with the device of the associated TPUReplicateMetadata op.
@@ -246,7 +246,7 @@ void UpdateLaunchOpResultExternalUses(tf_device::LaunchOp launch_op,
   for (auto ret_vals : llvm::zip(results, launch_op.getResults())) {
     Value old_ret = std::get<0>(ret_vals);
     Value new_ret = std::get<1>(ret_vals);
-    for (auto& use : old_ret->getUses())
+    for (auto& use : old_ret.getUses())
       if (!launch_op_block.findAncestorOpInBlock(*use.getOwner()))
         use.set(new_ret);
   }
@@ -307,7 +307,7 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op,
   llvm::SmallSetVector<Operation*, 8> unique_replicated_input_ops;
   mlir::visitUsedValuesDefinedAbove(
       launch_op.body(), launch_op.body(), [&](mlir::OpOperand* operand) {
-        Operation* def = operand->get()->getDefiningOp();
+        Operation* def = operand->get().getDefiningOp();
         if (def && llvm::isa<TF::TPUReplicatedInputOp>(def))
           unique_replicated_input_ops.insert(def);
       });
@@ -339,7 +339,7 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op,
   for (auto result_and_idx : llvm::enumerate(launch_op.getResults())) {
     Value result = result_and_idx.value();
     int idx = result_and_idx.index();
-    for (auto& use : result->getUses()) {
+    for (auto& use : result.getUses()) {
       Operation* def = use.getOwner();
       if (!def || !llvm::isa<TF::TPUReplicatedOutputOp>(def))
         return launch_op.emitError()
@@ -470,7 +470,7 @@ void TPUClusterFormation::runOnFunction() {
     // `tf_device.replicate` is created and replicated (1) operands/results are
     // untouched.
     if (op->getNumOperands() == 1 && op->getNumResults() == 1)
-      op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
+      op->getResult(0).replaceAllUsesWith(op->getOperand(0));
     // Leftover TPUReplicatedInput/TPUReplicatedOutput that are not of
     // `num_replicas` to 1.
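A subtlety these hunks make visible: `OpOperand` is still passed by pointer, so `operand->get()` keeps its arrow, while the `Value` it yields is used with `.`. A sketch of the mixed spelling (illustrative callback):

    // OpOperand* stays pointer-like; the wrapped Value is value-typed.
    auto visit = [](mlir::OpOperand* operand) {
      mlir::Value v = operand->get();                // '->' on the pointer
      if (mlir::Operation* def = v.getDefiningOp())  // '.' on the Value
        (void)def;
    };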
@@ -60,9 +60,9 @@ llvm::SmallDenseMap<int32_t, int32_t> GetRemappedReplicatedInputIndices(
   llvm::SmallDenseMap<int32_t, int32_t> remapped_indices;
   for (auto operand_and_idx : llvm::enumerate(launch_func.getOperands()))
-    if (auto block_arg = operand_and_idx.value()->dyn_cast<BlockArgument>())
-      if (block_arg->getOwner() == replicate_block)
-        remapped_indices[block_arg->getArgNumber()] = operand_and_idx.index();
+    if (auto block_arg = operand_and_idx.value().dyn_cast<BlockArgument>())
+      if (block_arg.getOwner() == replicate_block)
+        remapped_indices[block_arg.getArgNumber()] = operand_and_idx.index();
   return remapped_indices;
 }
@@ -135,23 +135,23 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
   // Find inputs that are variable reads.
   for (auto operand : llvm::enumerate(execute->getOpOperands())) {
     infos.new_operand_values.push_back(operand.value().get());
-    if (!operand.value().get()->getDefiningOp()) continue;
+    if (!operand.value().get().getDefiningOp()) continue;
     auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(
-        operand.value().get()->getDefiningOp());
+        operand.value().get().getDefiningOp());
     if (!read_op) continue;
     auto resource = read_op.resource();
     if (check_device) {
-      if (auto resource_op = resource->getDefiningOp()) {
+      if (auto resource_op = resource.getDefiningOp()) {
         auto resource_attr = resource_op->getAttr(kDeviceAttr);
         // Check device matching for the node defining the resource.
         if (!resource_attr || resource_attr != device_attr) continue;
       } else {
-        auto resource_arg = resource->dyn_cast<BlockArgument>();
+        auto resource_arg = resource.dyn_cast<BlockArgument>();
         assert(resource_arg);
         // Check device matching for the argument defining the resource.
         auto resource_attr = func.getArgAttrOfType<mlir::StringAttr>(
-            resource_arg->getArgNumber(), kFuncDeviceAttr);
+            resource_arg.getArgNumber(), kFuncDeviceAttr);
         if (!resource_attr || resource_attr != device_attr) continue;
       }
     }
@@ -222,9 +222,8 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
   llvm::SmallVector<bool, 8> output_fused(execute->getNumResults(), false);
   for (int i = 0; i < execute->getNumResults(); ++i) {
     auto result = execute->getResult(i);
-    if (!result->hasOneUse()) continue;
-    auto assign_op =
-        llvm::dyn_cast<TF::AssignVariableOp>(*result->user_begin());
+    if (!result.hasOneUse()) continue;
+    auto assign_op = llvm::dyn_cast<TF::AssignVariableOp>(*result.user_begin());
     if (!assign_op) continue;
     auto resource = assign_op.resource();
     auto it = infos.per_resource_info.find(resource);
@@ -330,7 +329,7 @@ void MergeForOneTPUExecute(Operation* execute, bool check_device,
   // Replace the uses.
   for (int i = 0; i < infos.old_to_new_output_mapping.size(); ++i) {
     if (infos.old_to_new_output_mapping[i] < 0) continue;
-    execute->getResult(i)->replaceAllUsesWith(
+    execute->getResult(i).replaceAllUsesWith(
         merged_execute.getResult(infos.old_to_new_output_mapping[i]));
   }
   // Remove the assign ops.
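The single-use fusion test above, as an isolated sketch (`result` is any Value; `user_begin()` dereferences to the using `Operation*`):

    // Fuse only when the result has exactly one user and that user is an
    // AssignVariableOp.
    if (result.hasOneUse())
      if (auto assign =
              llvm::dyn_cast<TF::AssignVariableOp>(*result.user_begin()))
        (void)assign;  // candidate for fusion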
@@ -457,7 +457,7 @@ LogicalResult Rewrite(
   // the other ops that are intended to consume the compile result.
   Block* block = launch_func.getOperation()->getBlock();
   for (auto compile_result_op : block->getOps<TF::TPUCompilationResultOp>())
-    compile_result_op.output()->replaceAllUsesWith(compile_op->getResult(0));
+    compile_result_op.output().replaceAllUsesWith(compile_op->getResult(0));
   BuildTPUCompileSucceededAssertOp(compile_op, builder);
@@ -97,11 +97,11 @@ void BreakUpIslands::runOnOperation() {
       dups.clear();
       for (Value input : edges) {
-        dups.insert(input->getDefiningOp());
+        dups.insert(input.getDefiningOp());
       }
       // Insert new control edges removing duplicates.
       for (Value value : llvm::reverse(edge.second)) {
-        if (dups.insert(value->getDefiningOp()).second) edges.push_back(value);
+        if (dups.insert(value.getDefiningOp()).second) edges.push_back(value);
       }
       state.addOperands(edges);
       Operation* new_op = builder.createOperation(state);
@@ -160,7 +160,7 @@ IslandSourcesAndSinks FindSourcesAndSinksInIsland(
     for (auto predecessor : predecessors) result.sinks.erase(predecessor);
     bool has_in_island_operands = false;
     for (auto operand : sub_op.getOperands()) {
-      auto defining_op = operand->getDefiningOp();
+      auto defining_op = operand.getDefiningOp();
       if (!defining_op || defining_op->getParentOp() != island) continue;
       // Remove operands from sinks.
       result.sinks.erase(defining_op);
@@ -190,16 +190,16 @@ void BreakUpIslands::BreakUpIsland(
   // the island that defines that fetched value.
   for (auto fetch : op.GetYield().fetches()) {
     // Ok, because there is no op to add control to (eg: function args).
-    if (!fetch->getDefiningOp()) continue;
-    if (fetch->getDefiningOp()->getParentOp() == op) {
+    if (!fetch.getDefiningOp()) continue;
+    if (fetch.getDefiningOp()->getParentOp() == op) {
       // OK, because it is the same island.
     } else if (auto island_op = llvm::dyn_cast<tf_executor::IslandOp>(
-                   fetch->getDefiningOp())) {
+                   fetch.getDefiningOp())) {
       island_control_inputs.push_back(island_op.control());
     } else {
       // TODO(parkers): Any defining op that has a control output can be handled
       // just like an island.
-      fetch->getDefiningOp()->emitError("Fetching non-island as dependency.");
+      fetch.getDefiningOp()->emitError("Fetching non-island as dependency.");
       return signalPassFailure();
     }
   }
@@ -255,11 +255,11 @@ void BreakUpIslands::BreakUpIsland(
     sink_island_controls.push_back(island.control());
   }
   assert(sink_island_controls.size() == 1);
-  op.control()->replaceAllUsesWith(sink_island_controls[0]);
+  op.control().replaceAllUsesWith(sink_island_controls[0]);
   // All existing outputs need to add a control flow edge from
   // sink_island_controls[0].
   for (Value out : op.outputs()) {
-    for (auto& use : out->getUses()) {
+    for (auto& use : out.getUses()) {
       Operation* owner = use.getOwner();
       if (auto island_op =
               llvm::dyn_cast<tf_executor::IslandOp>(owner->getParentOp())) {
@@ -275,7 +275,7 @@ void BreakUpIslands::BreakUpIsland(
     }
   }
   for (auto item : llvm::zip(op.outputs(), op.GetYield().fetches()))
-    std::get<0>(item)->replaceAllUsesWith(std::get<1>(item));
+    std::get<0>(item).replaceAllUsesWith(std::get<1>(item));
   op.erase();
 }
@@ -70,7 +70,7 @@ tf_executor::IslandOp ControlToExecutorDialectConversion::CreateIslandForOp(
   // Create a new region for the tf_executor.island body
   SmallVector<Value, 8> operands;
   for (Value operand : op->getOperands())
-    if (operand->getType().isa<tf_executor::ControlType>())
+    if (operand.getType().isa<tf_executor::ControlType>())
       operands.push_back(operand);
   SmallVector<Type, 8> types;
   for (Type result_type : op->getResultTypes())
@@ -155,7 +155,7 @@ void ControlToExecutorDialectConversion::runOnFunction() {
           loc, types, operands, ArrayRef<NamedAttribute>{});
     } else if (op.getName().getStringRef() == "_tf.NextIteration.source") {
       replacement = builder.create<tf_executor::NextIterationSourceOp>(
-          loc, op.getResult(0)->getType());
+          loc, op.getResult(0).getType());
       // Record a mapping of the name to the nextiteration.source so that when
       // we convert the sink we can get the token.
       StringAttr frame = op.getAttrOfType<StringAttr>("name");
@@ -164,9 +164,9 @@ void ControlToExecutorDialectConversion::runOnFunction() {
           cast<tf_executor::NextIterationSourceOp>(replacement);
       // Replace the results here since the _tf source does not produce a token
       // there isn't a mapping for the new result #1.
-      op.getResult(0)->replaceAllUsesWith(replacement->getResult(0));
+      op.getResult(0).replaceAllUsesWith(replacement->getResult(0));
       for (int i : llvm::seq<int>(1, op.getNumResults()))
-        op.getResult(i)->replaceAllUsesWith(replacement->getResult(i + 1));
+        op.getResult(i).replaceAllUsesWith(replacement->getResult(i + 1));
       replacement->setAttrs(op.getAttrList());
       op.erase();
       continue;
@@ -202,7 +202,7 @@ void ControlToExecutorDialectConversion::runOnFunction() {
     // Only the non-control operands are carried over, the island is handling
     // the control input.
     for (Value operand : op.getOperands())
-      if (!operand->getType().isa<tf_executor::ControlType>())
+      if (!operand.getType().isa<tf_executor::ControlType>())
        result.operands.push_back(operand);
    // Add a result type for each non-control result we find
@@ -232,7 +232,7 @@ void ControlToExecutorDialectConversion::runOnFunction() {
     if (!isa<tf_executor::IslandOp>(replacement))
       replacement->setAttrs(op.getAttrList());
     for (int i : llvm::seq<int>(0, op.getNumResults()))
-      op.getResult(i)->replaceAllUsesWith(replacement->getResult(i));
+      op.getResult(i).replaceAllUsesWith(replacement->getResult(i));
     op.erase();
   }
 }
@@ -79,7 +79,7 @@ void ExecutorToControlDialectConversion::runOnFunction() {
       for (auto ops_and_ret_vals :
            llvm::zip(graph.getResults(), fetch.getOperands()))
         std::get<0>(ops_and_ret_vals)
-            ->replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
+            .replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
       op.erase();
       continue;
     }
@@ -106,7 +106,7 @@ void ExecutorToControlDialectConversion::runOnFunction() {
       for (auto ops_and_ret_vals :
            llvm::zip(island.getResults(), wrapped_op.getOperands()))
         std::get<0>(ops_and_ret_vals)
-            ->replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
+            .replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
       break;
     }
     // Add a leading _ off the name.
@@ -141,7 +141,7 @@ void ExecutorToControlDialectConversion::runOnFunction() {
       for (auto ops_and_ret_vals :
            llvm::zip(wrapped_op.getResults(), replacement->getResults()))
         std::get<0>(ops_and_ret_vals)
-            ->replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
+            .replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
       ctl_sequence = replacement->getResult(replacement->getNumResults() - 1);
     }
@@ -151,13 +151,13 @@ void ExecutorToControlDialectConversion::runOnFunction() {
       // been rewritten from ops in island. Last op rewritten must logically
       // carry // all the island control inputs, we can simply use it to
       // replace all uses of island's control output.
-      island.control()->replaceAllUsesWith(ctl_sequence);
+      island.control().replaceAllUsesWith(ctl_sequence);
     } else if (island.getNumOperands() > 0) {
       // Getting here means island had an effectively empty body and there is
       // just one control input. In this case, island's control output should
       // be replaced with the control input.
       assert(island.getNumOperands() == 1);
-      island.control()->replaceAllUsesWith(island.getOperand(0));
+      island.control().replaceAllUsesWith(island.getOperand(0));
     }
     op.erase();
@@ -192,7 +192,7 @@ void ExecutorToControlDialectConversion::runOnFunction() {
     // dialect.
     auto non_null_operands = llvm::make_filter_range(
         op.getOperands(),
-        [](Value v) { return !v->getType().isa<tf_executor::TokenType>(); });
+        [](Value v) { return !v.getType().isa<tf_executor::TokenType>(); });
     state.operands.append(non_null_operands.begin(), non_null_operands.end());
     for (Type result_type : op.getResultTypes()) {
       // Filter out TokenType, they don't exist in the control dialect.
@@ -212,14 +212,14 @@ void ExecutorToControlDialectConversion::runOnFunction() {
     if (auto next_iteration =
             dyn_cast<tf_executor::NextIterationSourceOp>(op)) {
-      next_iteration.output()->replaceAllUsesWith(replacement->getResult(0));
-      next_iteration.token()->dropAllUses();
-      next_iteration.control()->replaceAllUsesWith(replacement->getResult(1));
+      next_iteration.output().replaceAllUsesWith(replacement->getResult(0));
+      next_iteration.token().dropAllUses();
+      next_iteration.control().replaceAllUsesWith(replacement->getResult(1));
     } else {
       for (auto ops_and_ret_vals :
            llvm::zip(op.getResults(), replacement->getResults()))
         std::get<0>(ops_and_ret_vals)
-            ->replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
+            .replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
     }
     op.erase();
   }
@@ -236,7 +236,7 @@ std::string Exporter::UniqueName(Operation* op) {
 StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
     BlockArgument arg, unsigned index, llvm::StringRef name) {
-  auto func = arg->getParentRegion()->getParentOfType<mlir::FuncOp>();
+  auto func = arg.getParentRegion()->getParentOfType<mlir::FuncOp>();
   auto node_def = absl::make_unique<NodeDef>();
   if (!name.empty())
@@ -248,7 +248,7 @@ StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
   DataType dtype;
   TF_RETURN_IF_ERROR(ConvertToDataType(
-      arg->getType().cast<mlir::TensorType>().getElementType(), &dtype));
+      arg.getType().cast<mlir::TensorType>().getElementType(), &dtype));
   AttrValue type_attr;
   type_attr.set_type(dtype);
   (*node_def->mutable_attr())["T"] = type_attr;
@@ -286,7 +286,7 @@ StatusOr<std::unique_ptr<NodeDef>> Exporter::GetReturnNode(
   auto inst_op = inst->getOperand(index);
   DataType dtype;
   TF_RETURN_IF_ERROR(ConvertToDataType(
-      inst_op->getType().cast<mlir::TensorType>().getElementType(), &dtype));
+      inst_op.getType().cast<mlir::TensorType>().getElementType(), &dtype));
   AttrValue type_attr;
   type_attr.set_type(dtype);
   (*node_def->mutable_attr())["T"] = type_attr;
@@ -298,8 +298,8 @@ StatusOr<std::unique_ptr<NodeDef>> Exporter::GetReturnNode(
 Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node,
                                      unsigned dst_index) {
-  if (auto input_result = src->dyn_cast<mlir::OpResult>()) {
-    auto* input_inst = input_result->getOwner();
+  if (auto input_result = src.dyn_cast<mlir::OpResult>()) {
+    auto* input_inst = input_result.getOwner();
     // replaces the input node by the sink one if it is an NextIteration source:
     auto it = source_to_sink_.find(input_inst);
     if (it != source_to_sink_.end()) {
@@ -308,16 +308,16 @@ Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node,
     auto node_it = nodes_.find(input_inst);
     TF_RET_CHECK(node_it != nodes_.end())
         << "Use of OpResult encountered before def!";
-    if (input_result->getType().isa<mlir::TFControlFlow::TFControlType>()) {
+    if (input_result.getType().isa<mlir::TFControlFlow::TFControlType>()) {
       graph_->AddControlEdge(node_it->second, dst_node);
     } else {
-      graph_->AddEdge(node_it->second, input_result->getResultNumber(),
-                      dst_node, dst_index);
+      graph_->AddEdge(node_it->second, input_result.getResultNumber(), dst_node,
+                      dst_index);
     }
     return Status::OK();
   }
-  auto input_arg = src->cast<BlockArgument>();
+  auto input_arg = src.cast<BlockArgument>();
   auto input_node_it = args_.find(input_arg);
   TF_RET_CHECK(input_node_it != args_.end())
       << "Use of BlockArgument encounted before def!";
@@ -366,7 +366,7 @@ Status Exporter::AddInstructionNode(Operation* inst) {
 }
 bool IsEntryFunctionArg(BlockArgument arg) {
-  return arg->getParentRegion()->getParentOfType<mlir::FuncOp>().getName() ==
+  return arg.getParentRegion()->getParentOfType<mlir::FuncOp>().getName() ==
          "main";
 }
@@ -387,21 +387,21 @@ Status Exporter::AddArgumentNode(BlockArgument arg, unsigned index,
   // is an input node. We recover the original input node and skip adding the
   // argument node. The new input node will be handled as normal in the
   // following steps.
-  if (!arg->hasOneUse()) {
+  if (!arg.hasOneUse()) {
     return errors::FailedPrecondition(
         "Arg in 'main' should only have one user.");
   }
-  auto* input = *arg->user_begin();
+  auto* input = *arg.user_begin();
   auto input_name = input->getName().getStringRef();
   input_name.consume_back(".input");
-  mlir::OpBuilder builder(arg->getOwner());
+  mlir::OpBuilder builder(arg.getOwner());
   auto loc = mlir::NameLoc::get(builder.getIdentifier(UniqueName(input)),
                                 builder.getContext());
   OperationState state(loc, input_name.str());
   state.attributes.append(input->getAttrs().begin(), input->getAttrs().end());
   for (auto op : input->getOperands()) {
     // Skip the argument in the new operation.
-    if (op->isa<BlockArgument>()) continue;
+    if (op.isa<BlockArgument>()) continue;
     state.operands.push_back(op);
   }
   state.types.append(input->getResultTypes().begin(),
@@ -419,7 +419,7 @@ Status Exporter::AddArgumentNode(BlockArgument arg, unsigned index,
       << inst << " to an empty string.";
   mapped_name.assign(input_mapped_name);
   for (int index : llvm::seq<int>(0, input->getNumResults())) {
-    input->getResult(index)->replaceAllUsesWith(inst->getResult(index));
+    input->getResult(index).replaceAllUsesWith(inst->getResult(index));
   }
   input->dropAllReferences();
   input->erase();
@@ -524,7 +524,7 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
       // the main graph did not have its _Retval nodes lifted into the functions
       // returns.
       if (!graph_as_function) {
-        auto defining_op = it.value()->getDefiningOp();
+        auto defining_op = it.value().getDefiningOp();
         auto& mapped_name = exporter.op_to_name_[defining_op];
         DCHECK(mapped_name.empty())
             << "Convert() attempted to change the op_to_name_ mapping for "
@@ -541,7 +541,7 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
       // Only assign user of argument the input name if the main graph did not
       // have its _Arg nodes lifted into the functions arguments.
       if (!graph_as_function) {
-        auto first_user = *it.value()->user_begin();
+        auto first_user = *it.value().user_begin();
         auto& mapped_name = exporter.op_to_name_[first_user];
         DCHECK(mapped_name.empty())
             << "Convert() attempted to change the op_to_name_ mapping for "
@@ -556,7 +556,7 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
   for (auto it : llvm::enumerate(block.getArguments())) {
     int index = it.index();
     auto arg = it.value();
-    mlir::Type type = arg->getType();
+    mlir::Type type = arg.getType();
     if (!type.isa<mlir::TensorType>()) {
       return errors::InvalidArgument(
           "FuncOps arguments must have tensor types. Found ",
@@ -132,7 +132,7 @@ StatusOr<std::unique_ptr<NodeDef>> ConvertTFDialectOpToNodeDef(
   mlir::OperationState result(inst->getLoc(),
                               inst->getName().getStringRef().drop_front());
   for (mlir::Value operand : inst->getOperands())
-    if (!operand->getType().isa<mlir::TFControlFlow::TFControlType>())
+    if (!operand.getType().isa<mlir::TFControlFlow::TFControlType>())
       result.operands.push_back(operand);
   // Add a result type for each non-control result we find
@@ -1192,9 +1192,9 @@ Status ImporterBase::ConvertFunctionArgAndRets(
       // Collect mapping of OutputTensor to associated block arg.
       arg_nodes_to_values.try_emplace({arg_node.node, arg_node.index}, arg_def);
-      island->getResult(0)->replaceAllUsesWith(arg_def);
+      island->getResult(0).replaceAllUsesWith(arg_def);
       // Erase control outputs from feed.
-      auto control_uses = island->getResult(1)->getUses();
+      auto control_uses = island->getResult(1).getUses();
       for (auto& control_use : llvm::make_early_inc_range(control_uses))
         control_use.getOwner()->eraseOperand(control_use.getOperandNumber());
@@ -1389,7 +1389,7 @@ mlir::Operation* ImporterBase::createOperation(
                                  builder_.getBlock()->begin());
     auto source_op =
         builder_at_begin.create<mlir::tf_executor::NextIterationSourceOp>(
-            loc, operands[0]->getType(), result.attributes);
+            loc, operands[0].getType(), result.attributes);
     return builder_.create<mlir::tf_executor::NextIterationSinkOp>(
         loc, source_op.token(), operands, result.attributes);
   }
@@ -1654,7 +1654,7 @@ Status ImporterBase::AddBackedges() {
 Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
                                  int dst_input) {
   // Get the NextIteration.Source operation from the token operand of the sink.
-  mlir::Operation* source = sink->getOperand(0)->getDefiningOp();
+  mlir::Operation* source = sink->getOperand(0).getDefiningOp();
   // Adds the "source" to the operands of the dst by creating a new dst
   // operation.
@@ -1680,7 +1680,7 @@ Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
   // result of the new operation, and deletes the old operation.
   for (unsigned i = 0, e = dst->getNumResults(); i != e; ++i) {
     auto new_output = new_dst->getResult(i);
-    dst->getResult(i)->replaceAllUsesWith(new_output);
+    dst->getResult(i).replaceAllUsesWith(new_output);
   }
   dst->dropAllReferences();
   dst->erase();
@@ -222,12 +222,12 @@ static bool IsRefTypeControlOp(mlir::Operation* op) {
   auto op_name = op_name_or_status.ConsumeValueOrDie();
   if (op_name.equals("NextIteration"))
-    return mlir::getElementTypeOrSelf(op->getOperand(0)->getType())
+    return mlir::getElementTypeOrSelf(op->getOperand(0).getType())
         .isa<mlir::TF::TensorFlowRefType>();
   if (op_name.equals("Enter") || op_name.equals("Exit") ||
       op_name.equals("Switch") || op_name.equals("Merge")) {
-    return getElementTypeOrSelf(op->getResult(0)->getType())
+    return getElementTypeOrSelf(op->getResult(0).getType())
        .isa<mlir::TF::TensorFlowRefType>();
  }
  return false;
@@ -407,7 +407,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
     }
     case HloOpcode::kWhile: {
       auto op = func_builder->create<mlir::xla_hlo::WhileOp>(
-          loc, operands[0]->getType(), operands[0]);
+          loc, operands[0].getType(), operands[0]);
       TF_RETURN_IF_ERROR(
           ImportComputation(instruction->while_condition(), &op.cond()));
       TF_RETURN_IF_ERROR(
@@ -175,7 +175,7 @@ void ConstOp::build(Builder* builder, OperationState& result, Attribute value) {
 //===----------------------------------------------------------------------===//
 OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
-  const auto output_type = getResult()->getType().cast<ShapedType>();
+  const auto output_type = getResult().getType().cast<ShapedType>();
   const auto output_size = output_type.getNumElements();
   const auto dimension = iota_dimension().getSExtValue();
   const auto max_dim_size = output_type.getDimSize(dimension);
@@ -204,15 +204,14 @@ OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 void AbsOp::build(Builder* builder, OperationState& result, Value operand) {
-  auto shaped_type = operand->getType().cast<ShapedType>();
+  auto shaped_type = operand.getType().cast<ShapedType>();
   Type new_type;
   if (!shaped_type.getElementType().isa<ComplexType>()) {
-    new_type = operand->getType();
+    new_type = operand.getType();
   } else if (shaped_type.hasRank()) {
-    new_type =
-        RankedTensorType::get(shaped_type.getShape(), operand->getType());
+    new_type = RankedTensorType::get(shaped_type.getShape(), operand.getType());
   } else {
-    new_type = UnrankedTensorType::get(operand->getType());
+    new_type = UnrankedTensorType::get(operand.getType());
   }
   return AbsOp::build(builder, result, new_type, operand);
@@ -225,7 +224,7 @@ void AbsOp::build(Builder* builder, OperationState& result, Value operand) {
 void ConvertOp::build(Builder* builder, OperationState& result, Value operand,
                       Type result_element_ty) {
   Type result_ty;
-  Type operand_ty = operand->getType();
+  Type operand_ty = operand.getType();
   if (auto ranked_ty = operand_ty.dyn_cast<RankedTensorType>()) {
     result_ty = RankedTensorType::get(ranked_ty.getShape(), result_element_ty);
   } else {
@@ -235,7 +234,7 @@ void ConvertOp::build(Builder* builder, OperationState& result, Value operand,
 }
 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
-  if (getOperand()->getType() == getResult()->getType()) return getOperand();
+  if (getOperand().getType() == getResult().getType()) return getOperand();
   // If the operand is constant, we can do the conversion now.
   if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
@@ -252,7 +251,7 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
 static LogicalResult Verify(GetTupleElementOp op) {
   auto indexVal = op.index().getZExtValue();
-  auto operandType = op.getOperand()->getType().cast<TupleType>();
+  auto operandType = op.getOperand().getType().cast<TupleType>();
   if (indexVal >= operandType.size()) {
     return op.emitOpError(
         llvm::formatv("index {0} is out of bounds of operand with size {1}",
@@ -269,7 +268,7 @@ static LogicalResult Verify(GetTupleElementOp op) {
 OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
   if (auto tupleOp =
-          dyn_cast_or_null<xla_hlo::TupleOp>(getOperand()->getDefiningOp())) {
+          dyn_cast_or_null<xla_hlo::TupleOp>(getOperand().getDefiningOp())) {
     return tupleOp.getOperand(index().getLimitedValue());
   }
@@ -305,9 +304,9 @@ static LogicalResult Verify(BroadcastOp op) {
         "broadcast_sizes has rank {0} instead of rank 1", sizesRank));
   }
-  auto resultType = op.getResult()->getType().cast<RankedTensorType>();
+  auto resultType = op.getResult().getType().cast<RankedTensorType>();
   auto resultRank = resultType.getRank();
-  auto operandType = op.operand()->getType().cast<RankedTensorType>();
+  auto operandType = op.operand().getType().cast<RankedTensorType>();
   auto operandRank = operandType.getRank();
   auto sizesSize = sizesType.getNumElements();
   auto expectedRank = operandRank + sizesSize;
@@ -341,7 +340,7 @@ static LogicalResult Verify(BroadcastOp op) {
 //===----------------------------------------------------------------------===//
 static LogicalResult Verify(BroadcastInDimOp op) {
-  auto operandType = op.operand()->getType().cast<RankedTensorType>();
+  auto operandType = op.operand().getType().cast<RankedTensorType>();
   auto operandRank = operandType.getRank();
   if (!op.broadcast_dimensions()) {
     if (operandRank == 0) {
@@ -368,7 +367,7 @@ static LogicalResult Verify(BroadcastInDimOp op) {
         dimensionsSize, operandRank));
   }
-  auto resultType = op.getResult()->getType().cast<RankedTensorType>();
+  auto resultType = op.getResult().getType().cast<RankedTensorType>();
   auto resultRank = resultType.getRank();
   if (resultRank < operandRank) {
     return op.emitOpError(
@@ -403,9 +402,9 @@ static LogicalResult Verify(BroadcastInDimOp op) {
 //===----------------------------------------------------------------------===//
 static LogicalResult Verify(ClampOp op) {
-  auto operandType = op.operand()->getType().cast<RankedTensorType>();
+  auto operandType = op.operand().getType().cast<RankedTensorType>();
   auto operandShape = operandType.getShape();
-  auto minType = op.min()->getType().cast<RankedTensorType>();
+  auto minType = op.min().getType().cast<RankedTensorType>();
   auto minShape = minType.getShape();
   if (minShape != operandShape && minType.getRank() != 0) {
@@ -415,7 +414,7 @@ static LogicalResult Verify(ClampOp op) {
         llvm::make_range(operandShape.begin(), operandShape.end())));
   }
-  auto maxType = op.max()->getType().cast<RankedTensorType>();
+  auto maxType = op.max().getType().cast<RankedTensorType>();
   auto maxShape = maxType.getShape();
   if (maxShape != operandShape && maxType.getRank() != 0) {
     return op.emitOpError(llvm::formatv(
@@ -433,7 +432,7 @@ static LogicalResult Verify(ClampOp op) {
 void ComplexOp::build(Builder* builder, OperationState& state, Value lhs,
                       Value rhs) {
-  auto type = lhs->getType();
+  auto type = lhs.getType();
   auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
   Type result_ty;
   if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
@@ -449,9 +448,9 @@ void ComplexOp::build(Builder* builder, OperationState& state, Value lhs,
 OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
   auto real_op =
-      dyn_cast_or_null<xla_hlo::RealOp>(getOperand(0)->getDefiningOp());
+      dyn_cast_or_null<xla_hlo::RealOp>(getOperand(0).getDefiningOp());
   auto imag_op =
-      dyn_cast_or_null<xla_hlo::ImagOp>(getOperand(1)->getDefiningOp());
+      dyn_cast_or_null<xla_hlo::ImagOp>(getOperand(1).getDefiningOp());
   if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) {
     return real_op.getOperand();
   }
@@ -477,12 +476,12 @@ Type CreateRealType(Type type) {
 }  // namespace
 void ImagOp::build(Builder* builder, OperationState& state, Value val) {
-  build(builder, state, CreateRealType(val->getType()), val);
+  build(builder, state, CreateRealType(val.getType()), val);
 }
 OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
   if (auto complex_op =
-          dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand()->getDefiningOp())) {
+          dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand().getDefiningOp())) {
     return complex_op.getOperand(1);
   }
@@ -490,12 +489,12 @@ OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
 }
 void RealOp::build(Builder* builder, OperationState& state, Value val) {
-  build(builder, state, CreateRealType(val->getType()), val);
+  build(builder, state, CreateRealType(val.getType()), val);
 }
 OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
   if (auto complex_op =
-          dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand()->getDefiningOp())) {
+          dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand().getDefiningOp())) {
     return complex_op.getOperand(0);
   }
@@ -512,12 +511,12 @@ OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
 }
 static LogicalResult Verify(ConcatenateOp op) {
-  auto firstType = op.getOperand(0)->getType().cast<RankedTensorType>();
+  auto firstType = op.getOperand(0).getType().cast<RankedTensorType>();
   auto firstShape = firstType.getShape();
   int numOperands = op.getNumOperands();
   for (int i = 1; i < numOperands; i++) {
-    auto secondType = op.getOperand(i)->getType().cast<RankedTensorType>();
+    auto secondType = op.getOperand(i).getType().cast<RankedTensorType>();
if (firstType.getRank() != secondType.getRank()) { if (firstType.getRank() != secondType.getRank()) {
return op.emitOpError( return op.emitOpError(
@ -552,18 +551,18 @@ void DynamicSliceOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) { OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
if (getOperand()->getType() == getType()) { if (getOperand().getType() == getType()) {
return getOperand(); return getOperand();
} }
if (auto prev_op = if (auto prev_op =
dyn_cast_or_null<ReshapeOp>(getOperand()->getDefiningOp())) { dyn_cast_or_null<ReshapeOp>(getOperand().getDefiningOp())) {
setOperand(prev_op.getOperand()); setOperand(prev_op.getOperand());
return getResult(); return getResult();
} }
if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) { if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
return elements.reshape(getResult()->getType().cast<ShapedType>()); return elements.reshape(getResult().getType().cast<ShapedType>());
} }
return {}; return {};
@ -613,7 +612,7 @@ void ReduceOp::build(Builder* builder, OperationState& state,
for (Value operand : operands) { for (Value operand : operands) {
result_ty.push_back( result_ty.push_back(
GetReduceResultType(operand->getType(), dimensions, builder)); GetReduceResultType(operand.getType(), dimensions, builder));
} }
build(builder, state, result_ty, operands, init_values, dimensions); build(builder, state, result_ty, operands, init_values, dimensions);
} }
@ -645,8 +644,8 @@ static LogicalResult Verify(SelectOp op) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult Verify(PadOp op) { static LogicalResult Verify(PadOp op) {
auto input_type = op.operand()->getType().cast<RankedTensorType>(); auto input_type = op.operand().getType().cast<RankedTensorType>();
auto pad_type = op.padding_value()->getType().cast<RankedTensorType>(); auto pad_type = op.padding_value().getType().cast<RankedTensorType>();
if (pad_type.getRank() != 0) { if (pad_type.getRank() != 0) {
return op.emitOpError( return op.emitOpError(
@ -678,7 +677,7 @@ static LogicalResult Verify(PadOp op) {
auto input_shape = input_type.getShape(); auto input_shape = input_type.getShape();
auto output_shape = auto output_shape =
op.getResult()->getType().cast<RankedTensorType>().getShape(); op.getResult().getType().cast<RankedTensorType>().getShape();
if (input_shape.size() != output_shape.size()) { if (input_shape.size() != output_shape.size()) {
return op.emitOpError( return op.emitOpError(
llvm::formatv("operand rank ({0}) and result rank({0}) should match", llvm::formatv("operand rank ({0}) and result rank({0}) should match",
@ -757,15 +756,15 @@ static Type GetBroadcastType(Builder* builder, Type x, Type y,
} }
} // namespace } // namespace
#define BINARY_BUILDER(Op) \ #define BINARY_BUILDER(Op) \
void Op::build(Builder* builder, OperationState& result, Value left, \ void Op::build(Builder* builder, OperationState& result, Value left, \
Value right, DenseIntElementsAttr broadcast_dimensions) { \ Value right, DenseIntElementsAttr broadcast_dimensions) { \
auto type = GetBroadcastType(builder, left->getType().cast<ShapedType>(), \ auto type = GetBroadcastType(builder, left.getType().cast<ShapedType>(), \
right->getType().cast<ShapedType>(), \ right.getType().cast<ShapedType>(), \
getElementTypeOrSelf(right->getType()), \ getElementTypeOrSelf(right.getType()), \
broadcast_dimensions); \ broadcast_dimensions); \
return Op::build(builder, result, type, left, right, \ return Op::build(builder, result, type, left, right, \
broadcast_dimensions); \ broadcast_dimensions); \
} }
BINARY_BUILDER(AddOp); BINARY_BUILDER(AddOp);
@ -815,7 +814,7 @@ Type SliceOp::InferOutputTypes(Builder* builder, Value operand,
DenseIntElementsAttr start_indices, DenseIntElementsAttr start_indices,
DenseIntElementsAttr limit_indices, DenseIntElementsAttr limit_indices,
DenseIntElementsAttr strides) { DenseIntElementsAttr strides) {
Type ty = operand->getType(); Type ty = operand.getType();
RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>(); RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
if (!ranked_ty) return ty; if (!ranked_ty) return ty;
int64_t rank = ranked_ty.getRank(); int64_t rank = ranked_ty.getRank();
@ -852,7 +851,7 @@ void SortOp::build(Builder* builder, OperationState& state, ValueRange operands,
SmallVector<Type, 2> element_types; SmallVector<Type, 2> element_types;
element_types.reserve(operands.size()); element_types.reserve(operands.size());
for (Value operand : operands) element_types.push_back(operand->getType()); for (Value operand : operands) element_types.push_back(operand.getType());
state.addTypes(builder->getTupleType(element_types)); state.addTypes(builder->getTupleType(element_types));
state.addRegion(); state.addRegion();
@ -864,14 +863,13 @@ static LogicalResult Verify(SortOp op) {
// TODO(antiagainst): verify partionally dynamic shapes // TODO(antiagainst): verify partionally dynamic shapes
if (llvm::all_of(operands, [](Value operand) { if (llvm::all_of(operands, [](Value operand) {
return operand->getType().cast<ShapedType>().hasRank(); return operand.getType().cast<ShapedType>().hasRank();
})) { })) {
ArrayRef<int64_t> input_shape = ArrayRef<int64_t> input_shape =
(*operands.begin())->getType().cast<ShapedType>().getShape(); (*operands.begin()).getType().cast<ShapedType>().getShape();
if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value operand) { if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value operand) {
return operand->getType().cast<ShapedType>().getShape() != return operand.getType().cast<ShapedType>().getShape() != input_shape;
input_shape;
})) }))
return op.emitOpError("requires all inputs to have the same dimensions"); return op.emitOpError("requires all inputs to have the same dimensions");
@ -889,10 +887,10 @@ static LogicalResult Verify(SortOp op) {
for (auto indexed_operand : llvm::enumerate(operands)) { for (auto indexed_operand : llvm::enumerate(operands)) {
int index = indexed_operand.index(); int index = indexed_operand.index();
Type element_type = Type element_type =
indexed_operand.value()->getType().cast<ShapedType>().getElementType(); indexed_operand.value().getType().cast<ShapedType>().getElementType();
Type tensor_type = RankedTensorType::get({}, element_type); Type tensor_type = RankedTensorType::get({}, element_type);
for (int i : {2 * index, 2 * index + 1}) { for (int i : {2 * index, 2 * index + 1}) {
Type arg_type = block.getArgument(i)->getType(); Type arg_type = block.getArgument(i).getType();
if (arg_type != tensor_type) if (arg_type != tensor_type)
return op.emitOpError("comparator block argument #") return op.emitOpError("comparator block argument #")
<< i << " should be of type " << tensor_type << " but got " << i << " should be of type " << tensor_type << " but got "
@ -926,7 +924,7 @@ static LogicalResult Verify(TransposeOp op) {
} }
auto permutationSize = permutationType.getNumElements(); auto permutationSize = permutationType.getNumElements();
auto operandType = op.operand()->getType().dyn_cast<RankedTensorType>(); auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
if (operandType) { if (operandType) {
auto operandRank = operandType.getRank(); auto operandRank = operandType.getRank();
if (operandRank != permutationSize) { if (operandRank != permutationSize) {
@ -936,7 +934,7 @@ static LogicalResult Verify(TransposeOp op) {
} }
} }
auto resultType = op.getResult()->getType().dyn_cast<RankedTensorType>(); auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
if (resultType) { if (resultType) {
auto resultRank = resultType.getRank(); auto resultRank = resultType.getRank();
if (resultRank != permutationSize) { if (resultRank != permutationSize) {
@ -972,14 +970,14 @@ static LogicalResult Verify(TransposeOp op) {
void GetTupleElementOp::build(Builder* builder, OperationState& result, void GetTupleElementOp::build(Builder* builder, OperationState& result,
Value tuple, int32_t index) { Value tuple, int32_t index) {
if (auto tuple_type = tuple->getType().dyn_cast<TupleType>()) { if (auto tuple_type = tuple.getType().dyn_cast<TupleType>()) {
auto element_type = tuple_type.getType(index); auto element_type = tuple_type.getType(index);
build(builder, result, element_type, tuple, build(builder, result, element_type, tuple,
builder->getI32IntegerAttr(index)); builder->getI32IntegerAttr(index));
return; return;
} }
build(builder, result, tuple->getType(), tuple, build(builder, result, tuple.getType(), tuple,
builder->getI32IntegerAttr(index)); builder->getI32IntegerAttr(index));
} }
@ -992,7 +990,7 @@ void TupleOp::build(Builder* builder, OperationState& result,
SmallVector<Type, 4> types; SmallVector<Type, 4> types;
types.reserve(values.size()); types.reserve(values.size());
for (auto val : values) { for (auto val : values) {
types.push_back(val->getType()); types.push_back(val.getType());
} }
build(builder, result, builder->getTupleType(types), values); build(builder, result, builder->getTupleType(types), values);
@ -1014,7 +1012,7 @@ void UnaryEinsumOp::getCanonicalizationPatterns(
void CompareOp::build(Builder* builder, OperationState& result, Value lhs, void CompareOp::build(Builder* builder, OperationState& result, Value lhs,
Value rhs, DenseIntElementsAttr broadcast_dimensions, Value rhs, DenseIntElementsAttr broadcast_dimensions,
StringAttr comparison_direction) { StringAttr comparison_direction) {
auto new_type = GetBroadcastType(builder, lhs->getType(), rhs->getType(), auto new_type = GetBroadcastType(builder, lhs.getType(), rhs.getType(),
builder->getI1Type(), broadcast_dimensions); builder->getI1Type(), broadcast_dimensions);
build(builder, result, new_type, lhs, rhs, broadcast_dimensions, build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
comparison_direction); comparison_direction);
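Every hunk in this commit makes the same mechanical substitution, so one before/after sketch covers them all. The snippet below is illustrative only, not code from the commit; the helper name and headers are assumptions based on the MLIR tree at this revision:

#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"

// Minimal sketch of the syntax change; GetElementShape is a hypothetical
// helper used only to show old vs. new member access on Value.
static llvm::ArrayRef<int64_t> GetElementShape(mlir::Value v) {
  // Before this commit, Value exposed operator-> as a transition aid:
  //   return v->getType().cast<mlir::ShapedType>().getShape();
  // Now that Value is value-typed, members are reached with '.':
  return v.getType().cast<mlir::ShapedType>().getShape();
}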

@@ -23,8 +23,8 @@ namespace mlir {
 namespace xla {
 DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y) {
-  TensorType xType = x->getType().dyn_cast<RankedTensorType>();
-  TensorType yType = y->getType().dyn_cast<RankedTensorType>();
+  TensorType xType = x.getType().dyn_cast<RankedTensorType>();
+  TensorType yType = y.getType().dyn_cast<RankedTensorType>();
   if (xType == yType || !xType || !yType) return {};
   // If the shapes have the same rank, then there is nothing to do.

@@ -35,8 +35,8 @@ mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b,
 /// Get a constant splat for the given value type.
 template <typename T>
 static ElementsAttr getSplat(Builder* b, Value val, T constant) {
-  auto valType = val->getType().cast<TensorType>();
-  auto valElementType = getElementTypeOrSelf(val->getType());
+  auto valType = val.getType().cast<TensorType>();
+  auto valElementType = getElementTypeOrSelf(val.getType());
   // Handle integer elements.
   Attribute elementAttr;

@@ -542,7 +542,7 @@ LogicalResult ExportXlaOp(OutfeedOp op, OpLoweringContext ctx) {
   auto& value_map = *ctx.values;
   value_map[op] = xla::OutfeedWithToken(
       value_map[op.operand()], value_map[op.token()],
-      xla::TypeToShape(op.operand()->getType()), op.outfeed_config());
+      xla::TypeToShape(op.operand().getType()), op.outfeed_config());
   return success();
 }
@@ -883,7 +883,7 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
     std::vector<xla::Shape> arg_shapes;
     arg_shapes.reserve(bb.getNumArguments());
     for (auto& arg : bb.getArguments())
-      arg_shapes.push_back(xla::TypeToShape(arg->getType()));
+      arg_shapes.push_back(xla::TypeToShape(arg.getType()));
     xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes);
     auto tuple = xla::Parameter(builder, 0, input_shape, "arg_tuple");
     for (auto& it : llvm::enumerate(bb.getArguments())) {
@@ -893,7 +893,7 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
     for (auto& it : llvm::enumerate(bb.getArguments())) {
       auto arg = it.value();
       auto num = it.index();
-      xla::Shape shape = xla::TypeToShape(arg->getType());
+      xla::Shape shape = xla::TypeToShape(arg.getType());
       lowering[arg] =
           xla::Parameter(builder, num, shape, absl::StrCat("Arg_", num));
     }
@@ -1024,7 +1024,7 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
     llvm::SmallDenseSet<int32_t, 4> used_shape_indices;
     auto arg_type =
-        entry_func.getArgument(i)->getType().dyn_cast<RankedTensorType>();
+        entry_func.getArgument(i).getType().dyn_cast<RankedTensorType>();
     for (auto shape_and_padding : llvm::enumerate(llvm::zip(
              shape_indices.getValue(), padding_arg_indices.getValue()))) {
       const int element_index = shape_and_padding.index();
@@ -1059,7 +1059,7 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
             kPaddingArgIndicesAttr, i, element_index, e, padding_arg_index));
       Type padding_arg_type =
-          entry_func.getArgument(padding_arg_index)->getType();
+          entry_func.getArgument(padding_arg_index).getType();
       if (auto tensor_type = padding_arg_type.dyn_cast<RankedTensorType>())
         if (tensor_type.getRank() != 0)
           return entry_func.emitError()

@@ -29,7 +29,7 @@ def BuildSliceLimits : NativeCodeCall<
 def BuildSliceStrides : NativeCodeCall<
     "GetI64ElementsAttr(SmallVector<int64_t, 4>("
-    "$0->getType().cast<RankedTensorType>().getRank(), 1), &$_builder)">;
+    "$0.getType().cast<RankedTensorType>().getRank(), 1), &$_builder)">;
 def DynamicSliceToSlice: Pat<(HLO_DynamicSliceOp HLO_Tensor:$input,
                               (HLO_ConstOp I64ElementsAttr:$starting_indices),

@@ -40,7 +40,7 @@ namespace {
 constexpr StringRef kTempBufferAttr = "temp";
 Value GetTensorStoreOrReturnMemRef(Value value) {
-  for (const auto& user : value->getUsers()) {
+  for (const auto& user : value.getUsers()) {
     if (auto tensor_store = dyn_cast<TensorStoreOp>(user)) {
       if (tensor_store.getOperand(0) == value) {
         return tensor_store.getOperand(1);
@@ -57,8 +57,8 @@ Value GetTensorStoreOrReturnMemRef(Value value) {
 }
 Operation* GetLastUse(Value value) {
-  Operation* last = value->getDefiningOp();
-  for (auto& user : value->getUses()) {
+  Operation* last = value.getDefiningOp();
+  for (auto& user : value.getUses()) {
     Operation* user_op = user.getOwner();
     if (!user_op->isBeforeInBlock(last)) {
       last = user_op;
@@ -69,7 +69,7 @@ Operation* GetLastUse(Value value) {
 Value InsertAllocAndDealloc(Location loc, Value result,
                             ConversionPatternRewriter* rewriter) {
-  auto result_type = result->getType().dyn_cast<ShapedType>();
+  auto result_type = result.getType().dyn_cast<ShapedType>();
   if (!result_type || !result_type.hasStaticShape()) {
     emitError(loc,
               "tensor to buffer conversion expects statically shaped results");
@@ -79,7 +79,7 @@ Value InsertAllocAndDealloc(Location loc, Value result,
   Operation* last = GetLastUse(result);
-  Operation* op = result->getDefiningOp();
+  Operation* op = result.getDefiningOp();
   OpBuilder allocBuilder(op);
   auto alloc = allocBuilder.create<AllocOp>(loc, memref_type);
   alloc.setAttr(kTempBufferAttr, rewriter->getBoolAttr(true));
@@ -161,7 +161,7 @@ struct HloToLHloReduceConverter
     int original_arg_count = entry_block.getNumArguments();
     for (int i = 0; i < original_arg_count; ++i) {
       auto old_arg = entry_block.getArgument(i);
-      auto old_type = old_arg->getType().cast<TensorType>();
+      auto old_type = old_arg.getType().cast<TensorType>();
       auto new_type =
           MemRefType::get(old_type.getShape(), old_type.getElementType());
       auto new_arg = entry_block.addArgument(new_type);
@@ -169,7 +169,7 @@ struct HloToLHloReduceConverter
     }
     // Add an argument for the result.
     entry_block.addArgument(
-        entry_block.getArgument(original_arg_count)->getType());
+        entry_block.getArgument(original_arg_count).getType());
     // Remove the old arguments.
     for (int i = original_arg_count - 1; i >= 0; --i) {
       entry_block.eraseArgument(i);

@@ -99,8 +99,8 @@ LogicalResult LowerConditionalOp(mlir::xla_hlo::ConditionalOp conditional_op) {
                              mapper, &builder)))
     return failure();
-  tail_block->addArguments(conditional_op.getResult()->getType());
-  conditional_op.getResult()->replaceAllUsesWith(tail_block->getArgument(0));
+  tail_block->addArguments(conditional_op.getResult().getType());
+  conditional_op.getResult().replaceAllUsesWith(tail_block->getArgument(0));
   op_inst->erase();
   return success();
@@ -201,7 +201,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
   // Erase the original while loop.
   tail_block->addArgument(while_op.getType());
-  while_op.getResult()->replaceAllUsesWith(tail_block->getArgument(0));
+  while_op.getResult().replaceAllUsesWith(tail_block->getArgument(0));
   op_inst->erase();
   return success();

@@ -242,7 +242,7 @@ static Value ApplyReduction(Location loc, Value input,
 static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format,
                                            Value input) {
   return b.getI64IntegerAttr(
-      getFeatureDimension(format, input->getType().cast<RankedTensorType>()));
+      getFeatureDimension(format, input.getType().cast<RankedTensorType>()));
 }
 //===----------------------------------------------------------------------===//
@@ -254,7 +254,7 @@ static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format,
 static DenseIntElementsAttr getBiasFeatureDimension(Builder &b,
                                                     StringAttr format,
                                                     Value input) {
-  auto inputType = input->getType().cast<RankedTensorType>();
+  auto inputType = input.getType().cast<RankedTensorType>();
   size_t featureDim = getFeatureDimension(format, inputType);
   RankedTensorType type = RankedTensorType::get(1, b.getIntegerType(64));
   return DenseIntElementsAttr::get(type, featureDim);
@@ -319,8 +319,8 @@ static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) {
 // must be broadcasted with a size 1 tensor or another dynamic dimension.
 // Returns false on rankless.
 static bool AreBroadcastCompatible(Value x, Value y) {
-  auto x_rankless = x->getType().dyn_cast<RankedTensorType>();
-  auto y_rankless = y->getType().dyn_cast<RankedTensorType>();
+  auto x_rankless = x.getType().dyn_cast<RankedTensorType>();
+  auto y_rankless = y.getType().dyn_cast<RankedTensorType>();
   if (!x_rankless || !y_rankless) {
     return false;
   }
@@ -418,7 +418,7 @@ static void BuildArgMinMaxReductionBody(Type input_element_type,
 static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices,
                                           DenseIntElementsAttr slice_sizes) {
-  auto input_ty = input->getType().dyn_cast<RankedTensorType>();
+  auto input_ty = input.getType().dyn_cast<RankedTensorType>();
   int64_t input_rank = input_ty.getRank();
   ArrayRef<int64_t> input_shape = input_ty.getShape();
   DenseIntElementsAttr constant_start_indices;
@@ -465,7 +465,7 @@ static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes(
             .cast<DenseIntElementsAttr>();
   }
-  auto input_ty = input->getType().dyn_cast<RankedTensorType>();
+  auto input_ty = input.getType().dyn_cast<RankedTensorType>();
   int64_t input_rank = input_ty.getRank();
   ArrayRef<int64_t> input_shape = input_ty.getShape();
   SmallVector<int64_t, 4> normalized_sizes;
@@ -574,9 +574,9 @@ class ConvertConv : public OpRewritePattern<OpT> {
     std::string data_format = op.data_format().str();
     if (!FormatFromString(data_format, &format)) return Pattern::matchFailure();
-    auto input_ty = op.input()->getType().template dyn_cast<RankedTensorType>();
+    auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
     auto filter_ty =
-        op.filter()->getType().template dyn_cast<RankedTensorType>();
+        op.filter().getType().template dyn_cast<RankedTensorType>();
     auto result_ty = op.getType().template dyn_cast<RankedTensorType>();
     // Input, filter and the result needs to have static shape for calculation
@@ -698,10 +698,10 @@ class ConvertBF16FloorDivOp : public OpRewritePattern<TF::FloorDivOp> {
                                      PatternRewriter &rewriter) const override {
     auto l = op.x();
     auto r = op.y();
-    auto element_type = getElementTypeOrSelf(l->getType());
+    auto element_type = getElementTypeOrSelf(l.getType());
     if (!element_type.isBF16()) return matchFailure();
-    auto out_type = op.z()->getType().cast<TensorType>();
+    auto out_type = op.z().getType().cast<TensorType>();
     l = rewriter.create<ConvertOp>(op.getLoc(), l, rewriter.getF32Type());
     r = rewriter.create<ConvertOp>(op.getLoc(), r, rewriter.getF32Type());
@@ -765,13 +765,13 @@ class ConvertFusedBatchNormGradBase
     // activation shape needs to be static to convert negative indices in
     // TensorFlow to absolute indices required by HLO.
     RankedTensorType act_type =
-        act->getType().template dyn_cast<RankedTensorType>();
+        act.getType().template dyn_cast<RankedTensorType>();
     if (!act_type) return Pattern::matchFailure();
     Type act_ele_type = act_type.getElementType();
     // To support mixed precision, the statistics type, which maybe more
     // precise than the input types, are used for this op.
     Type kernel_type =
-        scale->getType().template cast<TensorType>().getElementType();
+        scale.getType().template cast<TensorType>().getElementType();
     grad = rewriter.create<ConvertOp>(loc, grad, kernel_type);
     act = rewriter.create<ConvertOp>(loc, act, kernel_type);
@@ -787,7 +787,7 @@ class ConvertFusedBatchNormGradBase
     Type feature_type = RankedTensorType::get(
         {GetDimSize(act_type, feature_dim)}, kernel_type);
     Type result_type = TupleType::get(
-        {act->getType(), feature_type, feature_type}, rewriter.getContext());
+        {act.getType(), feature_type, feature_type}, rewriter.getContext());
     auto training_op = rewriter.create<BatchNormGradOp>(
         loc, result_type, act, scale, mean, var, grad, op.epsilon(),
@@ -870,10 +870,10 @@ class ConvertFusedBatchNormV3Op
     auto feature_dim =
         getFeatureDimensionAttr(rewriter, op.data_formatAttr(), op.x());
-    auto input_type_tensor = op.x()->getType().dyn_cast<TensorType>();
+    auto input_type_tensor = op.x().getType().dyn_cast<TensorType>();
     auto input_element_type = input_type_tensor.getElementType();
-    auto scale_type_tensor = op.scale()->getType().dyn_cast<TensorType>();
+    auto scale_type_tensor = op.scale().getType().dyn_cast<TensorType>();
     auto scale_element_type = scale_type_tensor.getElementType();
     // TODO(b/69928690): Support mixed precision in the XLA batch
@@ -922,7 +922,7 @@ class ConvertFusedBatchNormV3Op
           op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor));
       auto corrected_variance = rewriter.create<xla_hlo::MulOp>(
-          op.getLoc(), batch_variance->getType(), batch_variance,
+          op.getLoc(), batch_variance.getType(), batch_variance,
           factor_const_op, /*DenseIntElementsAttr=*/DenseIntElementsAttr());
       // Convert back to input type to stay aligned with expected output type
@@ -1016,12 +1016,12 @@ class ConvertMaxPoolOp : public OpRewritePattern<TF::MaxPoolOp> {
   PatternMatchResult matchAndRewrite(TF::MaxPoolOp op,
                                      PatternRewriter &rewriter) const override {
     Type element_type =
-        op.input()->getType().cast<TensorType>().getElementType();
+        op.input().getType().cast<TensorType>().getElementType();
     if (!element_type.isIntOrFloat()) return matchFailure();
     Location loc = op.getLoc();
     ConstOp init = GetMinValueForType(element_type, loc, &rewriter);
-    auto input_ty = op.input()->getType().dyn_cast<RankedTensorType>();
+    auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
     if (!input_ty) return matchFailure();
     DenseIntElementsAttr paddings_attr = GetReduceWindowPadding(
         input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
@@ -1067,9 +1067,9 @@ class ConvertSigmoidOp : public OpRewritePattern<TF::SigmoidOp> {
     auto scalar_one = rewriter.create<ConstOp>(
         op.getLoc(),
-        rewriter.getFloatAttr(getElementTypeOrSelf(operand->getType()), 0.5));
-    auto shaped_type = operand->getType().cast<ShapedType>();
+        rewriter.getFloatAttr(getElementTypeOrSelf(operand.getType()), 0.5));
+    auto shaped_type = operand.getType().cast<ShapedType>();
     auto constant_ones = rewriter.create<BroadcastOp>(
         op.getLoc(), shaped_type, scalar_one,
         DenseIntElementsAttr::get(
@@ -1080,7 +1080,7 @@ class ConvertSigmoidOp : public OpRewritePattern<TF::SigmoidOp> {
     auto scaled_input = rewriter.create<MulOp>(
         op.getLoc(), operand, constant_ones, DenseIntElementsAttr());
     auto tanh_op =
-        rewriter.create<TanhOp>(op.getLoc(), operand->getType(), scaled_input);
+        rewriter.create<TanhOp>(op.getLoc(), operand.getType(), scaled_input);
     auto mul_op =
         rewriter.create<MulOp>(op.getLoc(), tanh_op, constant_ones,
                                /*DenseIntElementsAttr=*/DenseIntElementsAttr());
@@ -1129,7 +1129,7 @@ class ConvertSoftmaxOp : public OpRewritePattern<OpTy> {
     // Softmax converter requires ranked type because the XLA reduce ops used
     // while lowering requires dimensions attribute to reduce along.
-    RankedTensorType type = logits->getType().dyn_cast<RankedTensorType>();
+    RankedTensorType type = logits.getType().dyn_cast<RankedTensorType>();
     if (!type) return Pattern::matchFailure();
     auto loc = op.getLoc();
@@ -1202,11 +1202,11 @@ class ConvertSizeOp : public OpRewritePattern<TF::SizeOp> {
   PatternMatchResult matchAndRewrite(TF::SizeOp op,
                                      PatternRewriter &rewriter) const override {
     Value input = op.input();
-    auto input_ty = input->getType().dyn_cast<RankedTensorType>();
+    auto input_ty = input.getType().dyn_cast<RankedTensorType>();
     if (!input_ty) return Pattern::matchFailure();
     const int64_t rank = input_ty.getRank();
-    auto result_type = op.getResult()->getType();
+    auto result_type = op.getResult().getType();
     Operation *size =
         GetScalarConstOfType(result_type.cast<TensorType>().getElementType(),
                              op.getLoc(), 1, &rewriter);
@@ -1264,7 +1264,7 @@ class ConvertSplitOp : public OpRewritePattern<TF::SplitOp> {
   PatternMatchResult matchAndRewrite(TF::SplitOp op,
                                      PatternRewriter &rewriter) const override {
     // We can only split along static dimensions.
-    auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
+    auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
     if (!input_type) return matchFailure();
     // We can only match when the split dimension is a constant scalar.
@@ -1356,7 +1356,7 @@ class ConvertSplitVOp : public OpRewritePattern<TF::SplitVOp> {
                                      PatternRewriter &rewriter) const override {
     // We can only split along static dimensions.
     // TODO(b/145731001): enhance to support dynamic-shaped inputs.
-    auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
+    auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
     if (!input_type) return matchFailure();
     // We can only match when the split dimension is a constant scalar.
@@ -1453,7 +1453,7 @@ class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
     //
     // TODO(hinsu): Relax this constraint for ops without negative indices and
     // strides.
-    auto input_ty = op.input()->getType().dyn_cast<RankedTensorType>();
+    auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
     if (!input_ty || !input_ty.hasStaticShape()) return matchFailure();
     ArrayRef<int64_t> input_shape = input_ty.getShape();
@@ -1553,7 +1553,7 @@ class ConvertStridedSliceGradOp
       return matchFailure();
     Value grad = op.dy();
-    Type element_type = grad->getType().cast<ShapedType>().getElementType();
+    Type element_type = grad.getType().cast<ShapedType>().getElementType();
     // Perform reshape to undo any new/shrink axies done by strided slice.
     grad = rewriter.create<xla_hlo::ReshapeOp>(
@@ -1593,7 +1593,7 @@ class ConvertStridedSliceGradOp
     if (!dims_to_reverse.empty()) {
       grad = rewriter.create<xla_hlo::ReverseOp>(
-          op.getLoc(), grad->getType(), grad,
+          op.getLoc(), grad.getType(), grad,
          GetI64ElementsAttr(dims_to_reverse, &rewriter));
     }
@@ -1631,7 +1631,7 @@ class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> {
   PatternMatchResult matchAndRewrite(TF::RangeOp op,
                                      PatternRewriter &rewriter) const override {
     auto result = op.getResult();
-    auto result_type = result->getType();
+    auto result_type = result.getType();
     if (!result_type.cast<ShapedType>().hasStaticShape()) {
       return matchFailure();
     }
@@ -1663,7 +1663,7 @@ class GenericConvertReductionOp : public OpRewritePattern<OpTy> {
     // TODO(b/141785544): Update this to not require static shapes.
     // Input shape needs to be static to convert negative indices in TensorFlow
     // to absolute indices required by HLO.
-    auto input_ty = op.input()->getType().template dyn_cast<RankedTensorType>();
+    auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
     if (!input_ty) return this->matchFailure();
     ArrayRef<int64_t> input_shape = input_ty.getShape();
@@ -1826,7 +1826,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> {
   PatternMatchResult matchAndRewrite(OpTy op,
                                      PatternRewriter &rewriter) const override {
     RankedTensorType input_type =
-        op.input()->getType().template dyn_cast<RankedTensorType>();
+        op.input().getType().template dyn_cast<RankedTensorType>();
     if (!input_type) {
       return this->matchFailure();
     }
@@ -1841,7 +1841,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> {
         Derived::GetInitialValue(input_element_type, loc, rewriter);
     RankedTensorType output_type =
-        op.output()->getType().template dyn_cast<RankedTensorType>();
+        op.output().getType().template dyn_cast<RankedTensorType>();
     if (!output_type) {
       return this->matchFailure();
     }
@@ -1918,9 +1918,9 @@ class ConvertTensorScatterUpdateOp
   PatternMatchResult matchAndRewrite(TF::TensorScatterUpdateOp op,
                                      PatternRewriter &rewriter) const override {
-    auto tensor_ty = op.tensor()->getType().dyn_cast<RankedTensorType>();
-    auto indices_ty = op.indices()->getType().dyn_cast<RankedTensorType>();
-    auto updates_ty = op.updates()->getType().dyn_cast<RankedTensorType>();
+    auto tensor_ty = op.tensor().getType().dyn_cast<RankedTensorType>();
+    auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
+    auto updates_ty = op.updates().getType().dyn_cast<RankedTensorType>();
     if (!tensor_ty || !indices_ty || !updates_ty) return matchFailure();
     // Last dimension of the indices needs to known at compile time for
@@ -1977,7 +1977,7 @@ class ConvertTileOp : public OpRewritePattern<TF::TileOp> {
   PatternMatchResult matchAndRewrite(TF::TileOp op,
                                      PatternRewriter &rewriter) const override {
-    auto input_ty = op.input()->getType().dyn_cast<RankedTensorType>();
+    auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
     if (!input_ty || !input_ty.hasStaticShape()) return matchFailure();
     ArrayRef<int64_t> input_shape = input_ty.getShape();
     Type element_type = input_ty.getElementType();
@@ -2041,12 +2041,12 @@ class ConvertMaxPoolGradOp : public OpRewritePattern<TF::MaxPoolGradOp> {
     Location loc = op.getLoc();
     Type element_type =
-        op.orig_input()->getType().cast<TensorType>().getElementType();
+        op.orig_input().getType().cast<TensorType>().getElementType();
     // Compute paddings using the original input and kernel shape and strides.
     // Here, ReduceWindow op as used as the MaxPool op is lowered to the
     // ReduceWindow op.
-    auto input_ty = op.orig_input()->getType().dyn_cast<RankedTensorType>();
+    auto input_ty = op.orig_input().getType().dyn_cast<RankedTensorType>();
     if (!input_ty) return matchFailure();
     DenseIntElementsAttr paddings_attr = GetReduceWindowPadding(
         input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
@@ -2099,11 +2099,11 @@ class ConvertConv2DBackpropInputOp
       return Pattern::matchFailure();
     auto out_backprop_ty =
-        op.out_backprop()->getType().dyn_cast<RankedTensorType>();
+        op.out_backprop().getType().dyn_cast<RankedTensorType>();
     if (!out_backprop_ty || !out_backprop_ty.hasStaticShape())
       return matchFailure();
     ArrayRef<int64_t> out_backprop_shape = out_backprop_ty.getShape();
-    auto filter_ty = op.filter()->getType().dyn_cast<RankedTensorType>();
+    auto filter_ty = op.filter().getType().dyn_cast<RankedTensorType>();
     if (!filter_ty || !filter_ty.hasStaticShape()) return matchFailure();
     ArrayRef<int64_t> filter_shape = filter_ty.getShape();
     int num_spatial_dims = 2;
@@ -2243,11 +2243,11 @@ class ConvertConv2DBackpropFilterOp
       return Pattern::matchFailure();
     auto out_backprop_ty =
-        op.out_backprop()->getType().dyn_cast<RankedTensorType>();
+        op.out_backprop().getType().dyn_cast<RankedTensorType>();
     if (!out_backprop_ty || !out_backprop_ty.hasStaticShape())
       return matchFailure();
     ArrayRef<int64_t> out_backprop_shape = out_backprop_ty.getShape();
-    auto input_ty = op.input()->getType().dyn_cast<RankedTensorType>();
+    auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
     if (!input_ty || !input_ty.hasStaticShape()) return matchFailure();
     ArrayRef<int64_t> input_shape = input_ty.getShape();
@@ -2432,7 +2432,7 @@ class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
   PatternMatchResult matchAndRewrite(TF::OneHotOp op,
                                      PatternRewriter &rewriter) const override {
-    auto indices_ty = op.indices()->getType().dyn_cast<RankedTensorType>();
+    auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
     if (!indices_ty || !indices_ty.hasStaticShape()) return matchFailure();
     ArrayRef<int64_t> indices_shape = indices_ty.getShape();
     Type element_type = indices_ty.getElementType();
@@ -2522,7 +2522,7 @@ class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
     // The last dimension of the input tensor's shape should be known so we can
     // have clamped end_indices for slices.
-    TensorType input_type = op.input()->getType().cast<TensorType>();
+    TensorType input_type = op.input().getType().cast<TensorType>();
     if (!input_type.hasRank()) return matchFailure();
     int64_t input_rank = input_type.getRank();
     int64_t last_dim_index = input_rank - 1;
@@ -2587,7 +2587,7 @@ class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
   PatternMatchResult matchAndRewrite(TF::UnpackOp op,
                                      PatternRewriter &rewriter) const override {
-    auto value_type = op.value()->getType().cast<RankedTensorType>();
+    auto value_type = op.value().getType().cast<RankedTensorType>();
     if (!value_type) return matchFailure();
     int64_t value_rank = value_type.getRank();
@@ -2645,12 +2645,12 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern<OpTy> {
   PatternMatchResult matchAndRewrite(OpTy op,
                                      PatternRewriter &rewriter) const override {
-    auto data_type = op.data()->getType().template dyn_cast<RankedTensorType>();
+    auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
    if (!data_type) return this->matchFailure();
     int64_t data_rank = data_type.getRank();
     auto segment_ids_type =
-        op.segment_ids()->getType().template dyn_cast<RankedTensorType>();
+        op.segment_ids().getType().template dyn_cast<RankedTensorType>();
     if (!segment_ids_type) return this->matchFailure();
     int64_t segment_ids_rank = segment_ids_type.getRank();

@@ -68,8 +68,8 @@ void Detuple(Value tuple, Operation::result_range replace, OpBuilder* builder) {
   // De-tuple the results of the xla hlo conditional result.
   for (auto result_it : llvm::enumerate(replace)) {
     auto get_tuple_value = builder->create<xla_hlo::GetTupleElementOp>(
-        result_it.value()->getLoc(), tuple, result_it.index());
-    result_it.value()->replaceAllUsesWith(get_tuple_value);
+        result_it.value().getLoc(), tuple, result_it.index());
+    result_it.value().replaceAllUsesWith(get_tuple_value);
   }
 }
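The use-list side of the API moves the same way as type access. A standalone sketch, assuming MLIR's Value use-list API at this revision; CountUses is a hypothetical helper, shown only to illustrate the syntax:

// Minimal sketch: getUses(), getUsers(), and replaceAllUsesWith() are now
// reached with '.' on Value instead of the old transition-era '->'.
static unsigned CountUses(mlir::Value v) {
  unsigned n = 0;
  for (mlir::OpOperand& use : v.getUses()) {  // Previously: v->getUses().
    (void)use;  // Only counting; the operand itself is unused here.
    ++n;
  }
  return n;
}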

@@ -35,19 +35,19 @@ def FalseBoolAttr : AttrConstraint<CPred<"!$_self.getValue()">>;
 def TrueBoolAttr : AttrConstraint<CPred<"$_self.getValue()">>;
 def CastValueToI64: NativeCodeCall<
-  "CastValueToI64($0->getLoc(), $1, &$_builder)">;
+  "CastValueToI64($0.getLoc(), $1, &$_builder)">;
 // Here, $0 is an ElementsAttr with exactly one element of type integer. $1 is
 // the corresponding value of ranked tensor type whose axis is referred in $0.
 def GetHLOAxisFromTFAxis : NativeCodeCall<
   "GetHLOAxisFromTFAxis("
-  "$0, $1->getType().cast<RankedTensorType>().getRank(), &$_builder)">;
+  "$0, $1.getType().cast<RankedTensorType>().getRank(), &$_builder)">;
 // Same as the above but with $1 of type operand_range from variadic TensorFlow
 // input.
 def GetHLOAxisFromTFAxisVariadic : NativeCodeCall<
   "GetHLOAxisFromTFAxis("
-  "$0, (*$1.begin())->getType().cast<RankedTensorType>().getRank(), "
+  "$0, (*$1.begin()).getType().cast<RankedTensorType>().getRank(), "
   "&$_builder)">;
 def : Pattern<
@@ -251,10 +251,10 @@ def OneElementAttr
     "Scalar ElementsAttr">;
 def HasRankedFirstOperand
-  : Constraint<CPred<"(*$0.begin())->getType().isa<RankedTensorType>()">>;
+  : Constraint<CPred<"(*$0.begin()).getType().isa<RankedTensorType>()">>;
 def IsShapedTensor
-  : Constraint<CPred<"$0->getType().isa<RankedTensorType>()">>;
+  : Constraint<CPred<"$0.getType().isa<RankedTensorType>()">>;
 // This pattern converts TensorFlow axis format to HLO axis format which
 // doesn't wrap around like TensorFlow and is always positive. For this
@@ -405,7 +405,7 @@ def : Pat<(TF_SliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices,
 // Ternary op patterns.
 //===----------------------------------------------------------------------===//
-def BothTypesMatch : Constraint<CPred<"$0->getType() == $1->getType()">,
+def BothTypesMatch : Constraint<CPred<"$0.getType() == $1.getType()">,
     "types must be equal">;
 foreach src = [TF_SelectOp, TF_SelectV2Op] in

@@ -47,8 +47,8 @@ struct CompareIConvert : public RewritePattern {
     auto lhs = compare_op.lhs();
     auto rhs = compare_op.rhs();
-    auto lhs_type = lhs->getType().cast<TensorType>();
-    auto rhs_type = rhs->getType().cast<TensorType>();
+    auto lhs_type = lhs.getType().cast<TensorType>();
+    auto rhs_type = rhs.getType().cast<TensorType>();
     // Broadcasting not supported by this rewrite.
     if (lhs_type.getShape() != rhs_type.getShape()) return matchFailure();
@@ -86,8 +86,8 @@ struct CompareFConvert : public RewritePattern {
     auto lhs = compare_op.lhs();
     auto rhs = compare_op.rhs();
-    auto lhs_type = lhs->getType().cast<TensorType>();
-    auto rhs_type = rhs->getType().cast<TensorType>();
+    auto lhs_type = lhs.getType().cast<TensorType>();
+    auto rhs_type = rhs.getType().cast<TensorType>();
     // Broadcasting not supported by this rewrite.
     if (lhs_type.getShape() != rhs_type.getShape()) return matchFailure();

@@ -31,8 +31,8 @@ def : Pat<(HLO_ConstOp ElementsAttr:$value),
 //===----------------------------------------------------------------------===//
 def IsSameSizePred : CPred<
-    "$0->getType().cast<ShapedType>().getShape() "
-    "== $1->getType().cast<ShapedType>().getShape()">;
+    "$0.getType().cast<ShapedType>().getShape() "
+    "== $1.getType().cast<ShapedType>().getShape()">;
 def IsSameSizeConstraint : Constraint<IsSameSizePred, "inputs are same size">;

@@ -39,8 +39,8 @@ struct BinaryOpConverter : public OpRewritePattern<LhloOp> {
                                      PatternRewriter& rewriter) const override {
     const auto& lhs = op.lhs();
     const auto& rhs = op.rhs();
-    const auto& lhs_type = lhs->getType().template cast<MemRefType>();
-    const auto& rhs_type = rhs->getType().template cast<MemRefType>();
+    const auto& lhs_type = lhs.getType().template cast<MemRefType>();
+    const auto& rhs_type = rhs.getType().template cast<MemRefType>();
     const auto& element_type = lhs_type.getElementType();
     if (lhs_type.getShape() != rhs_type.getShape()) {

@@ -55,7 +55,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
     // Only support 1d reductions for now.
     int64_t size = 0;
     for (auto result : reduce_op.out()) {
-      auto shaped_type = result->getType().dyn_cast<ShapedType>();
+      auto shaped_type = result.getType().dyn_cast<ShapedType>();
       if (!shaped_type || shaped_type.getRank() != 1) {
         return matchFailure();
       }
@@ -71,7 +71,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
     // Require all inputs to have the same shape.
     int64_t reduce_dim_size = 0;
     for (auto input : reduce_op.operands()) {
-      auto shaped_type = input->getType().dyn_cast<ShapedType>();
+      auto shaped_type = input.getType().dyn_cast<ShapedType>();
       if (!shaped_type || !shaped_type.hasStaticShape()) {
         return matchFailure();
       }
@@ -128,7 +128,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
     auto output = mapping.lookup(*reduce_op.out().begin());
     // TODO(herhut) Move this to the SliceOp builder.
     auto resType = MemRefType::get(
-        llvm::None, output->getType().cast<MemRefType>().getElementType(),
+        llvm::None, output.getType().cast<MemRefType>().getElementType(),
         makeStridedLinearLayoutMap(llvm::None,
                                    MemRefType::getDynamicStrideOrOffset(),
                                    rewriter.getContext()));
@@ -136,7 +136,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
         loc, resType, output, ArrayRef<Value>{launch_op.getThreadIds().x});
     llvm::SmallVector<Value, 4> indexings;
     auto input_buffer = *reduce_op.operands().begin();
-    auto input_type = input_buffer->getType().cast<MemRefType>();
+    auto input_type = input_buffer.getType().cast<MemRefType>();
     for (int64_t dim = 0; dim < input_type.getRank(); ++dim) {
       indexings.push_back(dim == reducing_dimension
                               ? loop.getInductionVar()
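
A hedged sketch of how the slice type in the last hunk is assembled from the
output buffer, reusing the same calls; buildSliceType is a hypothetical
wrapper, and header paths are approximate for this era.

    #include "llvm/ADT/None.h"
    #include "mlir/IR/StandardTypes.h"

    // Dynamic 1-D memref with dynamic stride/offset, taking its element
    // type from the reduction output, as the MemRefType::get call above.
    static mlir::MemRefType buildSliceType(mlir::Value output,
                                           mlir::MLIRContext* context) {
      return mlir::MemRefType::get(
          llvm::None,
          output.getType().cast<mlir::MemRefType>().getElementType(),
          mlir::makeStridedLinearLayoutMap(
              llvm::None, mlir::MemRefType::getDynamicStrideOrOffset(),
              context));
    }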


@@ -57,7 +57,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<LhloOp> {
       ConversionPatternRewriter& rewriter) const final {
     auto loc = lhlo_op.getLoc();
     auto argType =
-        lhlo_op.getOperand(0)->getType().template dyn_cast<ShapedType>();
+        lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
     if (!argType || !argType.hasStaticShape()) {
       emitError(loc,
                 "lhlo to linalg conversion expects statically shaped args");
@@ -73,7 +73,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<LhloOp> {
     unsigned nloops = 0;
     int operandCount = args.size() - 1;
     for (const auto& arg : llvm::enumerate(args)) {
-      auto memrefType = arg.value()->getType().dyn_cast<MemRefType>();
+      auto memrefType = arg.value().getType().dyn_cast<MemRefType>();
       if (!memrefType) return ConversionPattern::matchFailure();
       unsigned rank = memrefType.getRank();
       if (!rank || (nloops && nloops != rank)) {
@@ -125,7 +125,7 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
       ConversionPatternRewriter& rewriter) const final {
     auto loc = lhlo_op.getLoc();
     auto argType =
-        lhlo_op.getOperand(0)->getType().template dyn_cast<ShapedType>();
+        lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
     if (!argType || !argType.getElementType().isIntOrFloat() ||
         (argType.getRank() != 0)) {
       return ConversionPattern::matchFailure();
@@ -151,9 +151,9 @@ class BroadcastInDimConverter : public OpConversionPattern<BroadcastInDimOp> {
       BroadcastInDimOp broadcastOp, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
     auto operandMemrefType =
-        broadcastOp.operand()->getType().dyn_cast<MemRefType>();
+        broadcastOp.operand().getType().dyn_cast<MemRefType>();
     auto resultMemrefType =
-        broadcastOp.output()->getType().dyn_cast<MemRefType>();
+        broadcastOp.output().getType().dyn_cast<MemRefType>();
     if (!operandMemrefType || !resultMemrefType) return matchFailure();
     auto broadcastDims = broadcastOp.broadcast_dimensions();
     if (!broadcastDims.hasValue()) return matchFailure();
@@ -253,7 +253,7 @@ class IotaConverter : public OpConversionPattern<IotaOp> {
       IotaOp iotaOp, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
     auto resultMemrefType =
-        iotaOp.getOperand()->getType().dyn_cast<MemRefType>();
+        iotaOp.getOperand().getType().dyn_cast<MemRefType>();
     if (!resultMemrefType) return matchFailure();
     auto resultElementType = resultMemrefType.getElementType();
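
The converters above all gate on operand types the same way. As a sketch
under the same assumptions, the rank bookkeeping in the second hunk reduces
to a check like this (sameRank is a hypothetical name):

    #include "llvm/ADT/ArrayRef.h"
    #include "mlir/IR/StandardTypes.h"
    #include "mlir/IR/Value.h"

    // Every operand must be a memref of one common, nonzero rank; this is
    // the nloops invariant the loop above maintains.
    static bool sameRank(llvm::ArrayRef<mlir::Value> args) {
      unsigned nloops = 0;
      for (mlir::Value arg : args) {
        auto memrefType = arg.getType().dyn_cast<mlir::MemRefType>();
        if (!memrefType) return false;
        unsigned rank = memrefType.getRank();
        if (!rank || (nloops && nloops != rank)) return false;
        nloops = rank;
      }
      return true;
    }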


@@ -49,7 +49,7 @@ Value TransposeReshape(Value arg, mlir::Location loc,
                        llvm::ArrayRef<int64_t> right_dims,
                        llvm::ArrayRef<int64_t> arg_shape,
                        PatternRewriter *rewriter) {
-  auto element_type = mlir::getElementTypeOrSelf(arg->getType());
+  auto element_type = mlir::getElementTypeOrSelf(arg.getType());
   int64_t left_size = 1;
   for (auto dim : left_dims) {
@@ -94,7 +94,7 @@ Value TransposeReshape(Value arg, mlir::Location loc,
 Value ProcessDotArg(Value arg, mlir::Location loc,
                     ElementsAttr contract_dims_attr, bool outer_dims_first,
                     PatternRewriter *rewriter) {
-  auto shape = arg->getType().cast<mlir::ShapedType>().getShape();
+  auto shape = arg.getType().cast<mlir::ShapedType>().getShape();
   llvm::SmallVector<bool, 5> is_outer_dim;
   is_outer_dim.resize(shape.size(), true);
@@ -154,8 +154,8 @@ struct GeneralDotConvert
                               /*outer_dims_first=*/false, &rewriter);
     // Dot resulting shape.
-    auto lhs_shape = lhs->getType().cast<mlir::ShapedType>().getShape();
-    auto rhs_shape = rhs->getType().cast<mlir::ShapedType>().getShape();
+    auto lhs_shape = lhs.getType().cast<mlir::ShapedType>().getShape();
+    auto rhs_shape = rhs.getType().cast<mlir::ShapedType>().getShape();
     auto new_dot_type =
         RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);
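
A sketch of the result-type computation above, assuming both operands were
already reshaped to rank 2 by ProcessDotArg; buildDotResultType is a
hypothetical name.

    #include "mlir/IR/StandardTypes.h"
    #include "mlir/IR/TypeUtilities.h"
    #include "mlir/IR/Value.h"

    // With rank-2 operands, the dot result is {lhs rows, rhs columns}
    // over the shared element type, as in the hunk above.
    static mlir::RankedTensorType buildDotResultType(mlir::Value lhs,
                                                     mlir::Value rhs) {
      auto lhs_shape = lhs.getType().cast<mlir::ShapedType>().getShape();
      auto rhs_shape = rhs.getType().cast<mlir::ShapedType>().getShape();
      auto element_type = mlir::getElementTypeOrSelf(lhs.getType());
      return mlir::RankedTensorType::get({lhs_shape[0], rhs_shape[1]},
                                         element_type);
    }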


@@ -61,7 +61,7 @@ using ScalarIOp = typename ScalarOp<LHLO_BinaryOp>::IOp;
 template <typename LhloOp>
 Operation* MapLhloOpToStdScalarOp(LhloOp lhlo_op, ArrayRef<Type> result_types,
                                   ArrayRef<Value> block_args, OpBuilder b) {
-  Type element_type = block_args.front()->getType();
+  Type element_type = block_args.front().getType();
   if (element_type.isa<IntegerType>()) {
     return b.template create<ScalarIOp<LhloOp>>(lhlo_op.getLoc(), result_types,
                                                 block_args, mlir::None);
@@ -79,7 +79,7 @@ inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>(
     ArrayRef<Value> block_args, OpBuilder b) {
   const auto& lhs = block_args[0];
   const auto& rhs = block_args[1];
-  Type element_type = lhs->getType();
+  Type element_type = lhs.getType();
   if (element_type.isa<IntegerType>()) {
     auto lhs_gt_rhs = b.create<ScalarIOp<CompareOp>>(
         lhlo_op.getLoc(), CmpIPredicate::sgt, lhs, rhs);
@@ -99,7 +99,7 @@ inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::MinOp>(
     ArrayRef<Value> block_args, OpBuilder b) {
   const auto& lhs = block_args[0];
   const auto& rhs = block_args[1];
-  Type element_type = lhs->getType();
+  Type element_type = lhs.getType();
   if (element_type.isa<IntegerType>()) {
     auto lhs_lt_rhs = b.create<ScalarIOp<CompareOp>>(
         lhlo_op.getLoc(), CmpIPredicate::slt, lhs, rhs);
@@ -117,7 +117,7 @@ template <>
 inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
     xla_lhlo::AndOp lhlo_op, ArrayRef<Type> result_types,
     ArrayRef<Value> block_args, OpBuilder b) {
-  Type element_type = block_args.front()->getType();
+  Type element_type = block_args.front().getType();
   return element_type.isa<IntegerType>()
              ? b.create<::mlir::AndOp>(lhlo_op.getLoc(), result_types,
                                        block_args, mlir::None)
@@ -153,7 +153,7 @@ inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::CompareOp>(
     ArrayRef<Value> block_args, OpBuilder b) {
   const auto& lhs = block_args[0];
   const auto& rhs = block_args[1];
-  Type element_type = lhs->getType();
+  Type element_type = lhs.getType();
   if (element_type.isa<IntegerType>()) {
     Optional<CmpIPredicate> predicate =
         getIntCmpPredicate(lhlo_op.comparison_direction());
@@ -181,7 +181,7 @@ template <>
 inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
     xla_lhlo::ExpOp lhlo_op, ArrayRef<Type> result_types,
     ArrayRef<Value> block_args, OpBuilder b) {
-  Type element_type = block_args.front()->getType();
+  Type element_type = block_args.front().getType();
   return element_type.isa<FloatType>()
              ? b.create<::mlir::ExpOp>(lhlo_op.getLoc(), result_types,
                                        block_args, mlir::None)
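
Each specialization above dispatches on the element type of a block
argument. A minimal hedged sketch of that dispatch for addition (createAdd
is a hypothetical name; AddIOp/AddFOp are the standard-dialect ops of this
era, and the header path is approximate):

    #include "mlir/Dialect/StandardOps/Ops.h"
    #include "mlir/IR/Builders.h"

    // Pick the integer or float flavor of the scalar op from the
    // operand's element type, as MapLhloOpToStdScalarOp does above.
    static mlir::Operation* createAdd(mlir::Location loc, mlir::Value lhs,
                                      mlir::Value rhs, mlir::OpBuilder& b) {
      mlir::Type element_type = lhs.getType();
      if (element_type.isa<mlir::IntegerType>())
        return b.create<mlir::AddIOp>(loc, lhs, rhs);
      if (element_type.isa<mlir::FloatType>())
        return b.create<mlir::AddFOp>(loc, lhs, rhs);
      return nullptr;
    }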


@@ -257,7 +257,7 @@ mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size,
   }
   for (mlir::IROperand& use :
-       llvm::make_early_inc_range(loop.getInductionVar()->getUses())) {
+       llvm::make_early_inc_range(loop.getInductionVar().getUses())) {
     mlir::Operation* owner = use.getOwner();
     BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
     unsigned new_dim = affine_map.operands.size();
@@ -330,7 +330,7 @@ mlir::Operation* HoistAndFix(llvm::iplist<mlir::Operation>::iterator begin_op,
     indvars.push_back(ancestor.getInductionVar());
   }
   for (mlir::IROperand& use :
-       llvm::make_early_inc_range(alloc.getResult()->getUses())) {
+       llvm::make_early_inc_range(alloc.getResult().getUses())) {
     mlir::Operation* owner = use.getOwner();
     BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
     affine_map.operands.insert(affine_map.operands.begin(), indvars.begin(),
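
getUses() moves to dot syntax as well. A hedged sketch of the
early-increment walk both hunks rely on (forEachUser is a hypothetical
name; the loop body may rewrite the use list, which is why
make_early_inc_range is needed):

    #include "llvm/ADT/STLExtras.h"
    #include "mlir/IR/Value.h"

    // Visit the owner of each use of a value; make_early_inc_range keeps
    // the iteration valid while fn mutates uses, as in TileLoop above.
    static void forEachUser(mlir::Value value,
                            llvm::function_ref<void(mlir::Operation*)> fn) {
      for (mlir::IROperand& use :
           llvm::make_early_inc_range(value.getUses()))
        fn(use.getOwner());
    }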


@@ -108,7 +108,7 @@ struct SingleTripLoopRemoval
     : public mlir::FunctionPass<SingleTripLoopRemoval> {
   void runOnFunction() override {
     auto getConstantValue = [](mlir::Value value) -> llvm::Optional<int64_t> {
-      auto definingOp = value->getDefiningOp();
+      auto definingOp = value.getDefiningOp();
       if (!definingOp) return llvm::None;
       auto constantOp = llvm::dyn_cast<mlir::ConstantOp>(definingOp);
       if (!constantOp) return llvm::None;
@@ -180,9 +180,9 @@ struct StoreForwardingPass : mlir::FunctionPass<StoreForwardingPass> {
   // Recursively checks defining ops until finds AllocOp. Return either AllocOp
   // if it is found or nullptr.
   mlir::Operation* SearchAllocOp(mlir::Value memref) {
-    mlir::Operation* defOp = memref->getDefiningOp();
+    mlir::Operation* defOp = memref.getDefiningOp();
     while (auto subviewOp = mlir::dyn_cast_or_null<mlir::SubViewOp>(defOp)) {
-      defOp = subviewOp.source()->getDefiningOp();
+      defOp = subviewOp.source().getDefiningOp();
     }
     if (auto allocOp = mlir::dyn_cast_or_null<mlir::AllocOp>(defOp)) {
       return allocOp.getOperation();
@@ -211,7 +211,7 @@ struct StoreForwardingPass : mlir::FunctionPass<StoreForwardingPass> {
 struct DeadTempBufferRemoval : mlir::FunctionPass<DeadTempBufferRemoval> {
   bool operationConsideredDead(mlir::Operation* op) {
     for (auto result : op->getResults()) {
-      if (!llvm::all_of(result->getUsers(), [&](mlir::Operation* op) {
+      if (!llvm::all_of(result.getUsers(), [&](mlir::Operation* op) {
            // Store and Dealloc is OK.
            if (llvm::isa<mlir::StoreOp>(op) ||
                llvm::isa<mlir::DeallocOp>(op)) {
@@ -235,7 +235,7 @@ struct DeadTempBufferRemoval : mlir::FunctionPass<DeadTempBufferRemoval> {
   void recursiveErase(mlir::Operation* op) {
     for (auto result : op->getResults()) {
-      for (auto user : llvm::make_early_inc_range(result->getUsers())) {
+      for (auto user : llvm::make_early_inc_range(result.getUsers())) {
        recursiveErase(user);
      }
    }
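
A hedged sketch of the defining-op chase SearchAllocOp performs above,
isolated into a hypothetical helper (peelSubViews); the header path is
approximate for this era.

    #include "mlir/Dialect/StandardOps/Ops.h"

    // Follow a memref through any chain of subviews back to the op that
    // defines the underlying buffer (nullptr for block arguments).
    static mlir::Operation* peelSubViews(mlir::Value memref) {
      mlir::Operation* defOp = memref.getDefiningOp();
      while (auto subviewOp =
                 mlir::dyn_cast_or_null<mlir::SubViewOp>(defOp))
        defOp = subviewOp.source().getDefiningOp();
      return defOp;
    }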


@@ -197,19 +197,19 @@ static absl::optional<int64> getLaunchBound(const mlir::gpu::KernelDim3& dim) {
     op->emitError() << "bound " << name << " is not constant";
     return absl::nullopt;
   };
-  auto y_op = dim.y->getDefiningOp();
+  auto y_op = dim.y.getDefiningOp();
   auto dim_y = get_constant(y_op, "y");
   if (!dim_y.has_value() || dim_y.value() != 1) {
     y_op->emitError() << "bound 'y' is not constant 1";
     return absl::nullopt;
   }
-  auto z_op = dim.z->getDefiningOp();
+  auto z_op = dim.z.getDefiningOp();
   auto dim_z = get_constant(z_op, "z");
   if (!dim_z.has_value() || dim_z.value() != 1) {
     z_op->emitError() << "bound 'z' is not constant 1";
     return absl::nullopt;
   }
-  return get_constant(dim.x->getDefiningOp(), "x");
+  return get_constant(dim.x.getDefiningOp(), "x");
 }
 
 using OperandToValueMap =
@@ -224,7 +224,7 @@ static StatusOr<std::vector<const HloInstruction*>> ComputeOperandToValueMap(
   for (int kernel_index = 0; kernel_index < launchOp.getNumKernelOperands();
        ++kernel_index) {
     auto launchop_operand =
-        launchOp.getKernelOperand(kernel_index)->dyn_cast<BlockArgument>();
+        launchOp.getKernelOperand(kernel_index).dyn_cast<BlockArgument>();
     if (!launchop_operand) {
       launchOp.emitError("argument to kernel is not a function input");
       has_failed = true;
@@ -233,7 +233,7 @@ static StatusOr<std::vector<const HloInstruction*>> ComputeOperandToValueMap(
     // host_index is the argument position to the surrounding function that
     // contains the launch. This index corresponds to HLO operand indices
     // by construction.
-    auto host_index = launchop_operand->getArgNumber();
+    auto host_index = launchop_operand.getArgNumber();
     // The trailing argument to the outer function are the results.
     auto operand =
         (host_index < operands.size()) ? operands[host_index] : instr;
@@ -304,7 +304,7 @@ Status InsertBufferLoadPreduleIntoKernel(
     // { baseptr, dataptr, offset, shape_vect, stride_vect }
     // where shape_vect and stride_vect are integer vectors with length
     // matching the rank of the tensor.
-    auto target_type = value->getType().cast<LLVMType>();
+    auto target_type = value.getType().cast<LLVMType>();
    auto struct_type = target_type.getPointerElementTy();
     auto descPtr =
         builder.create<mlir::LLVM::AllocaOp>(loc, target_type, one, 0);
@@ -367,7 +367,7 @@ Status InsertBufferLoadPreduleIntoKernel(
       }
     }
     // Now we can use the descriptor instead of the original argument.
-    value->replaceAllUsesWith(descPtr);
+    value.replaceAllUsesWith(descPtr);
   }
 }
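
A closing hedged sketch combining the argument-side pieces above: classify
a Value as a block argument, read its index, and redirect its uses in
place. redirectArgument is a hypothetical name, and the use made of the
index is illustrative only.

    #include "mlir/IR/Value.h"

    // BlockArgument is value-typed now too: dyn_cast it off a Value with
    // '.', read its position, and replace all uses of the original value.
    static bool redirectArgument(mlir::Value operand,
                                 mlir::Value replacement) {
      auto arg = operand.dyn_cast<mlir::BlockArgument>();
      if (!arg) return false;
      unsigned index = arg.getArgNumber();
      (void)index;  // e.g. map back to an HLO operand, as above.
      operand.replaceAllUsesWith(replacement);
      return true;
    }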