NFC: Remove usages of Value::operator* and Value::operator-> now that Value is properly value-typed.
These were temporary methods used to simplify the transition. PiperOrigin-RevId: 287902391 Change-Id: I249fc049e4f1b762a23939b8e8b79110103df4aa
This commit is contained in:
parent
b4fd6a5963
commit
043abbdf86
@ -217,7 +217,7 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b,
|
|||||||
// min/max stats is just for comments, so ignore it.
|
// min/max stats is just for comments, so ignore it.
|
||||||
if (!tensor.quantization || IsQuantized(tensor)) return nullptr;
|
if (!tensor.quantization || IsQuantized(tensor)) return nullptr;
|
||||||
// If the result isn't float and unquantizable, the min/max is ignored.
|
// If the result isn't float and unquantizable, the min/max is ignored.
|
||||||
if (!res->getType()
|
if (!res.getType()
|
||||||
.cast<mlir::ShapedType>()
|
.cast<mlir::ShapedType>()
|
||||||
.getElementType()
|
.getElementType()
|
||||||
.isa<mlir::FloatType>()) {
|
.isa<mlir::FloatType>()) {
|
||||||
|
@ -233,17 +233,17 @@ static bool IsConst(Operation* op) {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
static bool HasValidTFLiteType(Value value, T& error_handler) {
|
static bool HasValidTFLiteType(Value value, T& error_handler) {
|
||||||
// None type is allowed to represent unspecified operands.
|
// None type is allowed to represent unspecified operands.
|
||||||
if (value->getType().isa<NoneType>()) return true;
|
if (value.getType().isa<NoneType>()) return true;
|
||||||
|
|
||||||
auto type = value->getType().dyn_cast<TensorType>();
|
auto type = value.getType().dyn_cast<TensorType>();
|
||||||
if (!type) {
|
if (!type) {
|
||||||
if (auto op = value->getDefiningOp()) {
|
if (auto op = value.getDefiningOp()) {
|
||||||
error_handler.emitError()
|
error_handler.emitError()
|
||||||
<< '\'' << op << "' should produce value of tensor type instead of "
|
<< '\'' << op << "' should produce value of tensor type instead of "
|
||||||
<< value->getType();
|
<< value.getType();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
error_handler.emitError("expected tensor type, got ") << value->getType();
|
error_handler.emitError("expected tensor type, got ") << value.getType();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -282,7 +282,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
|
|||||||
|
|
||||||
for (auto arg : bb.getArguments()) {
|
for (auto arg : bb.getArguments()) {
|
||||||
if (!HasValidTFLiteType(arg, fn))
|
if (!HasValidTFLiteType(arg, fn))
|
||||||
return fn.emitError("invalid TFLite type: ") << arg->getType(), false;
|
return fn.emitError("invalid TFLite type: ") << arg.getType(), false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify that all operations except the terminator have exactly one
|
// Verify that all operations except the terminator have exactly one
|
||||||
@ -292,7 +292,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) {
|
|||||||
|
|
||||||
for (auto result : inst.getResults()) {
|
for (auto result : inst.getResults()) {
|
||||||
if (!HasValidTFLiteType(result, inst))
|
if (!HasValidTFLiteType(result, inst))
|
||||||
return fn.emitError("invalid TFLite type: ") << result->getType(),
|
return fn.emitError("invalid TFLite type: ") << result.getType(),
|
||||||
false;
|
false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -504,7 +504,7 @@ Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
|
|||||||
|
|
||||||
Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
|
Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
|
||||||
Value value, const std::string& name, unsigned buffer_idx) {
|
Value value, const std::string& name, unsigned buffer_idx) {
|
||||||
auto type = value->getType().cast<TensorType>();
|
auto type = value.getType().cast<TensorType>();
|
||||||
|
|
||||||
// TFLite requires tensor shape only for the inputs and constants.
|
// TFLite requires tensor shape only for the inputs and constants.
|
||||||
// However, we output all known shapes for better round-tripping
|
// However, we output all known shapes for better round-tripping
|
||||||
@ -516,7 +516,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
|
|||||||
|
|
||||||
if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range))
|
if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range))
|
||||||
return mlir::emitError(
|
return mlir::emitError(
|
||||||
value->getLoc(),
|
value.getLoc(),
|
||||||
"result shape dimensions out of 32 bit int type range");
|
"result shape dimensions out of 32 bit int type range");
|
||||||
|
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
@ -528,7 +528,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
|
|||||||
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
|
if (mlir::failed(check_shape(shape_ref))) return llvm::None;
|
||||||
|
|
||||||
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
|
shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
|
||||||
} else if (auto* inst = value->getDefiningOp()) {
|
} else if (auto* inst = value.getDefiningOp()) {
|
||||||
if (IsConst(inst)) {
|
if (IsConst(inst)) {
|
||||||
// Const op can have a result of dynamic shaped type (e.g. due to constant
|
// Const op can have a result of dynamic shaped type (e.g. due to constant
|
||||||
// folding), but we can still derive the shape of a constant tensor for
|
// folding), but we can still derive the shape of a constant tensor for
|
||||||
@ -571,7 +571,7 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
|
|||||||
// marked as a stateful. If so, set the tensor's is_variable as true
|
// marked as a stateful. If so, set the tensor's is_variable as true
|
||||||
// This is v1 ref variable semantics in the TFLite runtime.
|
// This is v1 ref variable semantics in the TFLite runtime.
|
||||||
bool is_variable = false;
|
bool is_variable = false;
|
||||||
for (auto& use : value->getUses()) {
|
for (auto& use : value.getUses()) {
|
||||||
is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber());
|
is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber());
|
||||||
if (is_variable) {
|
if (is_variable) {
|
||||||
break;
|
break;
|
||||||
@ -923,7 +923,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
|
|||||||
// on failure.
|
// on failure.
|
||||||
auto build_tensor_and_buffer = [&](Value value, const std::string& name) {
|
auto build_tensor_and_buffer = [&](Value value, const std::string& name) {
|
||||||
// NoneType represents optional and may be skipped here.
|
// NoneType represents optional and may be skipped here.
|
||||||
if (value->getType().isa<NoneType>()) {
|
if (value.getType().isa<NoneType>()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -936,7 +936,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
|
|||||||
// make the Buffer empty apart from setting the buffer_idx=0 in the Tensor.
|
// make the Buffer empty apart from setting the buffer_idx=0 in the Tensor.
|
||||||
// This does not seem to affect runtime behavior for RNN/LSTM, but would be
|
// This does not seem to affect runtime behavior for RNN/LSTM, but would be
|
||||||
// good for reducing memory footprint.
|
// good for reducing memory footprint.
|
||||||
if (auto* inst = value->getDefiningOp()) {
|
if (auto* inst = value.getDefiningOp()) {
|
||||||
auto buffer_or = BuildBuffer(inst);
|
auto buffer_or = BuildBuffer(inst);
|
||||||
if (!buffer_or) return false;
|
if (!buffer_or) return false;
|
||||||
buffers_.push_back(*buffer_or);
|
buffers_.push_back(*buffer_or);
|
||||||
@ -976,7 +976,7 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(FuncOp fn) {
|
|||||||
std::vector<int32_t> operands;
|
std::vector<int32_t> operands;
|
||||||
operands.reserve(inst.getNumOperands());
|
operands.reserve(inst.getNumOperands());
|
||||||
for (auto operand : inst.getOperands()) {
|
for (auto operand : inst.getOperands()) {
|
||||||
if (operand->getType().isa<NoneType>())
|
if (operand.getType().isa<NoneType>())
|
||||||
operands.push_back(kTfLiteOptionalTensor);
|
operands.push_back(kTfLiteOptionalTensor);
|
||||||
else
|
else
|
||||||
operands.push_back(tensor_index_map.lookup(operand));
|
operands.push_back(tensor_index_map.lookup(operand));
|
||||||
|
@ -304,11 +304,11 @@ Attribute ConstFoldUnaryOp(Type result_type, Attribute operand,
|
|||||||
void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs,
|
void buildComparisonBinOp(Builder *builder, OperationState &result, Value lhs,
|
||||||
Value rhs) {
|
Value rhs) {
|
||||||
auto result_type =
|
auto result_type =
|
||||||
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
|
OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
|
||||||
if (!result_type)
|
if (!result_type)
|
||||||
emitError(result.location)
|
emitError(result.location)
|
||||||
<< "non-broadcastable operands: " << lhs->getType() << " and "
|
<< "non-broadcastable operands: " << lhs.getType() << " and "
|
||||||
<< rhs->getType();
|
<< rhs.getType();
|
||||||
result.addOperands({lhs, rhs});
|
result.addOperands({lhs, rhs});
|
||||||
// Comparison binary ops always return i1 tensor.
|
// Comparison binary ops always return i1 tensor.
|
||||||
if (auto shaped_type = result_type.dyn_cast<RankedTensorType>()) {
|
if (auto shaped_type = result_type.dyn_cast<RankedTensorType>()) {
|
||||||
@ -324,12 +324,12 @@ void buildFusedBroadcastableBinOp(Builder *builder, OperationState &result,
|
|||||||
Value lhs, Value rhs,
|
Value lhs, Value rhs,
|
||||||
StringAttr fused_activation_function) {
|
StringAttr fused_activation_function) {
|
||||||
auto result_type =
|
auto result_type =
|
||||||
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
|
OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
|
||||||
|
|
||||||
if (!result_type)
|
if (!result_type)
|
||||||
emitError(result.location)
|
emitError(result.location)
|
||||||
<< "non-broadcastable operands: " << lhs->getType() << " and "
|
<< "non-broadcastable operands: " << lhs.getType() << " and "
|
||||||
<< rhs->getType();
|
<< rhs.getType();
|
||||||
|
|
||||||
result.addOperands({lhs, rhs});
|
result.addOperands({lhs, rhs});
|
||||||
result.addAttribute("fused_activation_function", fused_activation_function);
|
result.addAttribute("fused_activation_function", fused_activation_function);
|
||||||
@ -358,7 +358,7 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
int64_t GetConcatenationOpAxis(ConcatenationOp op) {
|
int64_t GetConcatenationOpAxis(ConcatenationOp op) {
|
||||||
auto output_type = op.output()->getType().cast<RankedTensorType>();
|
auto output_type = op.output().getType().cast<RankedTensorType>();
|
||||||
int64_t axis = op.axis().getSExtValue();
|
int64_t axis = op.axis().getSExtValue();
|
||||||
if (axis < 0) axis += output_type.getRank();
|
if (axis < 0) axis += output_type.getRank();
|
||||||
return axis;
|
return axis;
|
||||||
@ -452,7 +452,7 @@ LogicalResult VerifyConcatenationOpTypes(Operation *op,
|
|||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult Verify(ConcatenationOp op) {
|
LogicalResult Verify(ConcatenationOp op) {
|
||||||
auto output_type = op.output()->getType().dyn_cast<RankedTensorType>();
|
auto output_type = op.output().getType().dyn_cast<RankedTensorType>();
|
||||||
|
|
||||||
// If the output type is unranked, there is nothing else to be verified.
|
// If the output type is unranked, there is nothing else to be verified.
|
||||||
if (!output_type) return success();
|
if (!output_type) return success();
|
||||||
@ -463,7 +463,7 @@ LogicalResult Verify(ConcatenationOp op) {
|
|||||||
|
|
||||||
SmallVector<TensorType, 4> operand_types;
|
SmallVector<TensorType, 4> operand_types;
|
||||||
for (Value operand : op.values())
|
for (Value operand : op.values())
|
||||||
operand_types.push_back(operand->getType().cast<TensorType>());
|
operand_types.push_back(operand.getType().cast<TensorType>());
|
||||||
|
|
||||||
return VerifyConcatenationOpTypes(op.getOperation(), output_type,
|
return VerifyConcatenationOpTypes(op.getOperation(), output_type,
|
||||||
operand_types, axis);
|
operand_types, axis);
|
||||||
@ -520,7 +520,7 @@ DenseElementsAttr ConstFoldConcatenateOpDense(ArrayRef<Attribute> operands,
|
|||||||
|
|
||||||
OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
|
||||||
if (fused_activation_function() == "NONE") {
|
if (fused_activation_function() == "NONE") {
|
||||||
if (auto output_type = output()->getType().dyn_cast<RankedTensorType>()) {
|
if (auto output_type = output().getType().dyn_cast<RankedTensorType>()) {
|
||||||
const int64_t axis = GetConcatenationOpAxis(*this);
|
const int64_t axis = GetConcatenationOpAxis(*this);
|
||||||
if (IsConcatenationOpConstFoldable(*this, operands, output_type, axis))
|
if (IsConcatenationOpConstFoldable(*this, operands, output_type, axis))
|
||||||
return ConstFoldConcatenateOpDense(operands, output_type, axis);
|
return ConstFoldConcatenateOpDense(operands, output_type, axis);
|
||||||
@ -530,7 +530,7 @@ OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
// Remove all empty values.
|
// Remove all empty values.
|
||||||
SmallVector<Value, 4> non_empty_values;
|
SmallVector<Value, 4> non_empty_values;
|
||||||
for (Value value : this->values()) {
|
for (Value value : this->values()) {
|
||||||
const auto shaped_type = value->getType().cast<ShapedType>();
|
const auto shaped_type = value.getType().cast<ShapedType>();
|
||||||
if (shaped_type.hasStaticShape() && shaped_type.getNumElements() == 0) {
|
if (shaped_type.hasStaticShape() && shaped_type.getNumElements() == 0) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -559,8 +559,8 @@ OpFoldResult ConcatenationOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult Verify(FullyConnectedOp op) {
|
LogicalResult Verify(FullyConnectedOp op) {
|
||||||
ShapedType input_type = op.input()->getType().cast<ShapedType>();
|
ShapedType input_type = op.input().getType().cast<ShapedType>();
|
||||||
ShapedType filter_type = op.filter()->getType().cast<ShapedType>();
|
ShapedType filter_type = op.filter().getType().cast<ShapedType>();
|
||||||
if (filter_type.hasRank() && filter_type.getRank() != 2) {
|
if (filter_type.hasRank() && filter_type.getRank() != 2) {
|
||||||
return op.emitOpError("expect 2d filter, got ") << filter_type;
|
return op.emitOpError("expect 2d filter, got ") << filter_type;
|
||||||
}
|
}
|
||||||
@ -582,7 +582,7 @@ LogicalResult Verify(FullyConnectedOp op) {
|
|||||||
// format.
|
// format.
|
||||||
if (op.weights_format() == "DEFAULT") {
|
if (op.weights_format() == "DEFAULT") {
|
||||||
ShapedType output_type =
|
ShapedType output_type =
|
||||||
(*op.output().begin())->getType().cast<ShapedType>();
|
(*op.output().begin()).getType().cast<ShapedType>();
|
||||||
if (!output_type.hasStaticShape()) {
|
if (!output_type.hasStaticShape()) {
|
||||||
return mlir::success();
|
return mlir::success();
|
||||||
}
|
}
|
||||||
@ -610,8 +610,8 @@ LogicalResult Verify(FullyConnectedOp op) {
|
|||||||
|
|
||||||
static void BuildGatherOp(Builder *builder, OperationState &result,
|
static void BuildGatherOp(Builder *builder, OperationState &result,
|
||||||
Value params, Value indices, IntegerAttr axis) {
|
Value params, Value indices, IntegerAttr axis) {
|
||||||
auto params_type = params->getType().cast<TensorType>();
|
auto params_type = params.getType().cast<TensorType>();
|
||||||
auto indices_type = indices->getType().cast<TensorType>();
|
auto indices_type = indices.getType().cast<TensorType>();
|
||||||
|
|
||||||
// If params/indices is unranked, then output is unranked.
|
// If params/indices is unranked, then output is unranked.
|
||||||
if (!params_type.hasRank() || !indices_type.hasRank())
|
if (!params_type.hasRank() || !indices_type.hasRank())
|
||||||
@ -705,7 +705,7 @@ static LogicalResult Verify(PackOp op) {
|
|||||||
return op.emitOpError("input count should match 'values_count' attribute");
|
return op.emitOpError("input count should match 'values_count' attribute");
|
||||||
|
|
||||||
Value operand0 = op.getOperand(0);
|
Value operand0 = op.getOperand(0);
|
||||||
auto input_type = operand0->getType().cast<ShapedType>();
|
auto input_type = operand0.getType().cast<ShapedType>();
|
||||||
|
|
||||||
// Check axis bounds.
|
// Check axis bounds.
|
||||||
if (input_type.hasRank()) {
|
if (input_type.hasRank()) {
|
||||||
@ -718,7 +718,7 @@ static LogicalResult Verify(PackOp op) {
|
|||||||
// Make sure all inputs have the same shape and element type.
|
// Make sure all inputs have the same shape and element type.
|
||||||
// TODO(rahulsp): Simplify once b/135032064 is fixed.
|
// TODO(rahulsp): Simplify once b/135032064 is fixed.
|
||||||
for (Value operand : op.getOperands()) {
|
for (Value operand : op.getOperands()) {
|
||||||
auto other_type = operand->getType().cast<ShapedType>();
|
auto other_type = operand.getType().cast<ShapedType>();
|
||||||
if (input_type != other_type)
|
if (input_type != other_type)
|
||||||
return op.emitOpError("operands should be of the same type. got ")
|
return op.emitOpError("operands should be of the same type. got ")
|
||||||
<< input_type << ", " << other_type;
|
<< input_type << ", " << other_type;
|
||||||
@ -732,9 +732,9 @@ static LogicalResult Verify(PackOp op) {
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static LogicalResult Verify(PReluOp op) {
|
static LogicalResult Verify(PReluOp op) {
|
||||||
auto input_type = op.input()->getType().cast<ShapedType>();
|
auto input_type = op.input().getType().cast<ShapedType>();
|
||||||
auto alpha_type = op.alpha()->getType().cast<ShapedType>();
|
auto alpha_type = op.alpha().getType().cast<ShapedType>();
|
||||||
auto output_type = op.output()->getType().cast<ShapedType>();
|
auto output_type = op.output().getType().cast<ShapedType>();
|
||||||
|
|
||||||
if (input_type.hasStaticShape() && alpha_type.hasStaticShape()) {
|
if (input_type.hasStaticShape() && alpha_type.hasStaticShape()) {
|
||||||
if (input_type.getRank() != alpha_type.getRank() + 1) {
|
if (input_type.getRank() != alpha_type.getRank() + 1) {
|
||||||
@ -783,13 +783,13 @@ struct RemoveAdjacentReshape : public RewritePattern {
|
|||||||
|
|
||||||
PatternMatchResult match(Operation *op) const override {
|
PatternMatchResult match(Operation *op) const override {
|
||||||
auto thisOp = cast<ReshapeOp>(op);
|
auto thisOp = cast<ReshapeOp>(op);
|
||||||
auto prevOp = thisOp.getOperand(0)->getDefiningOp();
|
auto prevOp = thisOp.getOperand(0).getDefiningOp();
|
||||||
return isa_and_nonnull<ReshapeOp>(prevOp) ? matchSuccess() : matchFailure();
|
return isa_and_nonnull<ReshapeOp>(prevOp) ? matchSuccess() : matchFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
|
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
|
||||||
auto thisOp = cast<ReshapeOp>(op);
|
auto thisOp = cast<ReshapeOp>(op);
|
||||||
auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0)->getDefiningOp());
|
auto prevOp = cast<ReshapeOp>(thisOp.getOperand(0).getDefiningOp());
|
||||||
|
|
||||||
// Replace
|
// Replace
|
||||||
// %1 = "tfl.reshape"(%0, %shape0)
|
// %1 = "tfl.reshape"(%0, %shape0)
|
||||||
@ -807,7 +807,7 @@ struct RemoveAdjacentReshape : public RewritePattern {
|
|||||||
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
|
||||||
// Remove identity reshape with both static result and input shape.
|
// Remove identity reshape with both static result and input shape.
|
||||||
auto result_type = getType().cast<ShapedType>();
|
auto result_type = getType().cast<ShapedType>();
|
||||||
auto input_type = getOperand(0)->getType().cast<ShapedType>();
|
auto input_type = getOperand(0).getType().cast<ShapedType>();
|
||||||
if (result_type.hasStaticShape() && result_type == input_type) {
|
if (result_type.hasStaticShape() && result_type == input_type) {
|
||||||
return getOperand(0);
|
return getOperand(0);
|
||||||
}
|
}
|
||||||
@ -865,7 +865,7 @@ struct RemoveRedundantUnpackPack : public RewritePattern {
|
|||||||
PatternMatchResult matchAndRewrite(Operation *op,
|
PatternMatchResult matchAndRewrite(Operation *op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
TFL::PackOp pack_op = cast<TFL::PackOp>(op);
|
TFL::PackOp pack_op = cast<TFL::PackOp>(op);
|
||||||
Operation *first_input = pack_op.getOperand(0)->getDefiningOp();
|
Operation *first_input = pack_op.getOperand(0).getDefiningOp();
|
||||||
if (!first_input) return matchFailure();
|
if (!first_input) return matchFailure();
|
||||||
auto input_unpack_op = dyn_cast_or_null<TFL::UnpackOp>(first_input);
|
auto input_unpack_op = dyn_cast_or_null<TFL::UnpackOp>(first_input);
|
||||||
if (!input_unpack_op) return matchFailure();
|
if (!input_unpack_op) return matchFailure();
|
||||||
@ -905,9 +905,9 @@ void PackOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static LogicalResult Verify(SliceOp op) {
|
static LogicalResult Verify(SliceOp op) {
|
||||||
auto input_type = op.input()->getType().cast<ShapedType>();
|
auto input_type = op.input().getType().cast<ShapedType>();
|
||||||
auto begin_type = op.begin()->getType().cast<ShapedType>();
|
auto begin_type = op.begin().getType().cast<ShapedType>();
|
||||||
auto size_type = op.size()->getType().cast<ShapedType>();
|
auto size_type = op.size().getType().cast<ShapedType>();
|
||||||
if (input_type.hasStaticShape() && begin_type.hasStaticShape() &&
|
if (input_type.hasStaticShape() && begin_type.hasStaticShape() &&
|
||||||
size_type.hasStaticShape()) {
|
size_type.hasStaticShape()) {
|
||||||
if (input_type.getRank() != begin_type.getNumElements()) {
|
if (input_type.getRank() != begin_type.getNumElements()) {
|
||||||
@ -995,7 +995,7 @@ static void BuildTopKOp(Builder *builder, OperationState &result, Value input,
|
|||||||
// TODO(jpienaar): This should use a helper function.
|
// TODO(jpienaar): This should use a helper function.
|
||||||
const_k = cst.getValue<IntegerAttr>({}).getValue().getSExtValue();
|
const_k = cst.getValue<IntegerAttr>({}).getValue().getSExtValue();
|
||||||
|
|
||||||
auto val_type = input->getType().cast<TensorType>();
|
auto val_type = input.getType().cast<TensorType>();
|
||||||
// If value is unranked, then so is results.
|
// If value is unranked, then so is results.
|
||||||
if (!val_type.hasRank())
|
if (!val_type.hasRank())
|
||||||
return TFL::TopKV2Op::build(
|
return TFL::TopKV2Op::build(
|
||||||
@ -1035,7 +1035,7 @@ struct DropFakeQuant : public RewritePattern {
|
|||||||
// If all the users of this op have valid "minmax" attributes, it is matched
|
// If all the users of this op have valid "minmax" attributes, it is matched
|
||||||
// and can be removed.
|
// and can be removed.
|
||||||
auto fakeQuantOp = cast<FakeQuantOp>(op);
|
auto fakeQuantOp = cast<FakeQuantOp>(op);
|
||||||
for (auto *operand : fakeQuantOp.getResult()->getUsers())
|
for (auto *operand : fakeQuantOp.getResult().getUsers())
|
||||||
if (!HasValidMinMaxAttribute(operand)) return matchFailure();
|
if (!HasValidMinMaxAttribute(operand)) return matchFailure();
|
||||||
|
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
@ -1102,7 +1102,7 @@ static LogicalResult VerifySplitOpOutputTypes(
|
|||||||
for (int64_t i = 0; i < num_splits; ++i) {
|
for (int64_t i = 0; i < num_splits; ++i) {
|
||||||
auto expected_output_type = get_expected_output_type(i);
|
auto expected_output_type = get_expected_output_type(i);
|
||||||
Value output = op->getResult(i);
|
Value output = op->getResult(i);
|
||||||
auto output_type = output->getType().dyn_cast<RankedTensorType>();
|
auto output_type = output.getType().dyn_cast<RankedTensorType>();
|
||||||
if (!output_type || output_type != expected_output_type)
|
if (!output_type || output_type != expected_output_type)
|
||||||
return op->emitOpError()
|
return op->emitOpError()
|
||||||
<< "output #" << i << " should be " << expected_output_type;
|
<< "output #" << i << " should be " << expected_output_type;
|
||||||
@ -1121,7 +1121,7 @@ static LogicalResult Verify(SplitOp op) {
|
|||||||
if (!split_dim_opt) return success();
|
if (!split_dim_opt) return success();
|
||||||
|
|
||||||
// If 'input' is not a ranked tensor, there are no other checks.
|
// If 'input' is not a ranked tensor, there are no other checks.
|
||||||
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
|
auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!input_type) return success();
|
if (!input_type) return success();
|
||||||
|
|
||||||
int64_t split_dim = split_dim_opt.getValue();
|
int64_t split_dim = split_dim_opt.getValue();
|
||||||
@ -1157,7 +1157,7 @@ static LogicalResult Verify(SplitVOp op) {
|
|||||||
if (!split_dim_opt) return success();
|
if (!split_dim_opt) return success();
|
||||||
|
|
||||||
// If 'input' is not a ranked tensor, there are no other checks.
|
// If 'input' is not a ranked tensor, there are no other checks.
|
||||||
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
|
auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!input_type) return success();
|
if (!input_type) return success();
|
||||||
|
|
||||||
int64_t split_dim = split_dim_opt.getValue();
|
int64_t split_dim = split_dim_opt.getValue();
|
||||||
@ -1177,8 +1177,7 @@ static LogicalResult Verify(SplitVOp op) {
|
|||||||
return success();
|
return success();
|
||||||
|
|
||||||
if (size_splits_attr.getNumElements() != num_splits) {
|
if (size_splits_attr.getNumElements() != num_splits) {
|
||||||
auto size_splits_type =
|
auto size_splits_type = op.size_splits().getType().cast<RankedTensorType>();
|
||||||
op.size_splits()->getType().cast<RankedTensorType>();
|
|
||||||
RankedTensorType expected_size_splits_type =
|
RankedTensorType expected_size_splits_type =
|
||||||
RankedTensorType::get({num_splits}, size_splits_type.getElementType());
|
RankedTensorType::get({num_splits}, size_splits_type.getElementType());
|
||||||
return op.emitOpError("'size_splits' should be ")
|
return op.emitOpError("'size_splits' should be ")
|
||||||
@ -1414,7 +1413,7 @@ OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Also fold if `input` has a known rank.
|
// Also fold if `input` has a known rank.
|
||||||
auto input_type = input()->getType().cast<ShapedType>();
|
auto input_type = input().getType().cast<ShapedType>();
|
||||||
// Do not fold if rank is zero because the TFLite converter doesn't
|
// Do not fold if rank is zero because the TFLite converter doesn't
|
||||||
// distinguish between unranked input and scalar input due to b/138865275.
|
// distinguish between unranked input and scalar input due to b/138865275.
|
||||||
// TODO(b/138865275): Remove `input_type.getRank() != 0` in the following
|
// TODO(b/138865275): Remove `input_type.getRank() != 0` in the following
|
||||||
@ -1445,18 +1444,18 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
static void BuildSelectV2Op(Builder *builder, OperationState &result,
|
static void BuildSelectV2Op(Builder *builder, OperationState &result,
|
||||||
Value cond, Value x, Value y) {
|
Value cond, Value x, Value y) {
|
||||||
auto operand_type =
|
auto operand_type =
|
||||||
OpTrait::util::getBroadcastedType(x->getType(), y->getType());
|
OpTrait::util::getBroadcastedType(x.getType(), y.getType());
|
||||||
|
|
||||||
if (!operand_type)
|
if (!operand_type)
|
||||||
emitError(result.location) << "non-broadcastable operands: " << x->getType()
|
emitError(result.location) << "non-broadcastable operands: " << x.getType()
|
||||||
<< " and " << y->getType();
|
<< " and " << y.getType();
|
||||||
|
|
||||||
bool has_static_cond_shape = false;
|
bool has_static_cond_shape = false;
|
||||||
bool has_static_operand_shape = false;
|
bool has_static_operand_shape = false;
|
||||||
ArrayRef<int64_t> cond_shape;
|
ArrayRef<int64_t> cond_shape;
|
||||||
ArrayRef<int64_t> operand_shape;
|
ArrayRef<int64_t> operand_shape;
|
||||||
|
|
||||||
if (auto shaped_type = cond->getType().dyn_cast<ShapedType>()) {
|
if (auto shaped_type = cond.getType().dyn_cast<ShapedType>()) {
|
||||||
if (shaped_type.hasStaticShape()) {
|
if (shaped_type.hasStaticShape()) {
|
||||||
has_static_cond_shape = true;
|
has_static_cond_shape = true;
|
||||||
cond_shape = shaped_type.getShape();
|
cond_shape = shaped_type.getShape();
|
||||||
@ -1474,12 +1473,12 @@ static void BuildSelectV2Op(Builder *builder, OperationState &result,
|
|||||||
!OpTrait::util::getBroadcastedShape(cond_shape, operand_shape,
|
!OpTrait::util::getBroadcastedShape(cond_shape, operand_shape,
|
||||||
broadcastedShape)) {
|
broadcastedShape)) {
|
||||||
emitError(result.location) << "non-broadcastable operands: " << operand_type
|
emitError(result.location) << "non-broadcastable operands: " << operand_type
|
||||||
<< " and " << cond->getType();
|
<< " and " << cond.getType();
|
||||||
}
|
}
|
||||||
|
|
||||||
result.addOperands({cond, x, y});
|
result.addOperands({cond, x, y});
|
||||||
|
|
||||||
auto elementType = x->getType().dyn_cast<ShapedType>().getElementType();
|
auto elementType = x.getType().dyn_cast<ShapedType>().getElementType();
|
||||||
if (has_static_cond_shape && has_static_operand_shape) {
|
if (has_static_cond_shape && has_static_operand_shape) {
|
||||||
result.types.push_back(
|
result.types.push_back(
|
||||||
RankedTensorType::get(broadcastedShape, elementType));
|
RankedTensorType::get(broadcastedShape, elementType));
|
||||||
@ -1571,9 +1570,8 @@ OpFoldResult RangeOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static LogicalResult Verify(TransposeConvOp op) {
|
static LogicalResult Verify(TransposeConvOp op) {
|
||||||
ShapedType output_type = op.output()->getType().cast<ShapedType>();
|
ShapedType output_type = op.output().getType().cast<ShapedType>();
|
||||||
ShapedType output_shape_type =
|
ShapedType output_shape_type = op.output_shape().getType().cast<ShapedType>();
|
||||||
op.output_shape()->getType().cast<ShapedType>();
|
|
||||||
if (output_type.hasRank() && output_shape_type.hasStaticShape()) {
|
if (output_type.hasRank() && output_shape_type.hasStaticShape()) {
|
||||||
if (output_type.getRank() != output_shape_type.getDimSize(0)) {
|
if (output_type.getRank() != output_shape_type.getDimSize(0)) {
|
||||||
return op.emitOpError(llvm::formatv(
|
return op.emitOpError(llvm::formatv(
|
||||||
@ -1679,9 +1677,9 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult Verify(TransposeOp op) {
|
static LogicalResult Verify(TransposeOp op) {
|
||||||
auto input_type = op.x()->getType().cast<ShapedType>();
|
auto input_type = op.x().getType().cast<ShapedType>();
|
||||||
auto perm_type = op.perm()->getType().cast<ShapedType>();
|
auto perm_type = op.perm().getType().cast<ShapedType>();
|
||||||
auto output_type = op.y()->getType().cast<ShapedType>();
|
auto output_type = op.y().getType().cast<ShapedType>();
|
||||||
if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
|
if (input_type.hasStaticShape() && perm_type.hasStaticShape()) {
|
||||||
if (perm_type.getNumElements() != input_type.getRank()) {
|
if (perm_type.getNumElements() != input_type.getRank()) {
|
||||||
return op.emitOpError(
|
return op.emitOpError(
|
||||||
|
@ -135,7 +135,7 @@ def TFL_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TFL_Int32Or64]>;
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
class TFL_OperandIsUnrankedPred<int n> :
|
class TFL_OperandIsUnrankedPred<int n> :
|
||||||
CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">;
|
CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">;
|
||||||
|
|
||||||
// TODO: Some of these could be generalized and/or moved to more general
|
// TODO: Some of these could be generalized and/or moved to more general
|
||||||
// location.
|
// location.
|
||||||
@ -144,38 +144,38 @@ class TFL_OperandHasRank<int n, int m> :
|
|||||||
PredOpTrait<"operand " # n # " is " # m # "-D",
|
PredOpTrait<"operand " # n # " is " # m # "-D",
|
||||||
Or<[TFL_OperandIsUnrankedPred<n>,
|
Or<[TFL_OperandIsUnrankedPred<n>,
|
||||||
CPred<"$_op.getOperand(" # n #
|
CPred<"$_op.getOperand(" # n #
|
||||||
")->getType().cast<ShapedType>().getRank() == " # m>]>>;
|
").getType().cast<ShapedType>().getRank() == " # m>]>>;
|
||||||
|
|
||||||
// Returns true if the n-th operand is ranked and has rank dim.
|
// Returns true if the n-th operand is ranked and has rank dim.
|
||||||
class TFL_OperandHasKnownRank<int n, int dim> : And<[
|
class TFL_OperandHasKnownRank<int n, int dim> : And<[
|
||||||
CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
|
CPred<"$_op.getOperand(" # n # ").getType().isa<RankedTensorType>()">,
|
||||||
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() == "
|
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() == "
|
||||||
# dim>]>;
|
# dim>]>;
|
||||||
|
|
||||||
// True if operand n is ranked and has a rank > dim.
|
// True if operand n is ranked and has a rank > dim.
|
||||||
class TFL_OperandIsRankedAndHasDimPred<int n, int dim> : And<[
|
class TFL_OperandIsRankedAndHasDimPred<int n, int dim> : And<[
|
||||||
CPred<"$_op.getOperand(" # n # ")->getType().isa<RankedTensorType>()">,
|
CPred<"$_op.getOperand(" # n # ").getType().isa<RankedTensorType>()">,
|
||||||
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>().getRank() > "
|
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>().getRank() > "
|
||||||
# dim>]>;
|
# dim>]>;
|
||||||
|
|
||||||
class TFL_OperandDimEquals<int n, int dim, int size> : And<[
|
class TFL_OperandDimEquals<int n, int dim, int size> : And<[
|
||||||
TFL_OperandIsRankedAndHasDimPred<n, dim>,
|
TFL_OperandIsRankedAndHasDimPred<n, dim>,
|
||||||
CPred<"$_op.getOperand(" # n # ")->getType().cast<ShapedType>()"
|
CPred<"$_op.getOperand(" # n # ").getType().cast<ShapedType>()"
|
||||||
".getShape()[" # dim # " ] == " # size>]>;
|
".getShape()[" # dim # " ] == " # size>]>;
|
||||||
|
|
||||||
// Returns true if the n-th operand has unknown rank or at least rank m.
|
// Returns true if the n-th operand has unknown rank or at least rank m.
|
||||||
class TFL_OperandHasAtleastRank<int n, int m> :
|
class TFL_OperandHasAtleastRank<int n, int m> :
|
||||||
PredOpTrait<"operand " # n # " is " # m # "-D",
|
PredOpTrait<"operand " # n # " is " # m # "-D",
|
||||||
Or<[CPred<"$_op.getOperand(" # n # ")->getType().isa<UnrankedTensorType>()">,
|
Or<[CPred<"$_op.getOperand(" # n # ").getType().isa<UnrankedTensorType>()">,
|
||||||
CPred<"$_op.getOperand(" # n #
|
CPred<"$_op.getOperand(" # n #
|
||||||
")->getType().cast<ShapedType>().getRank() >= " # m>]>>;
|
").getType().cast<ShapedType>().getRank() >= " # m>]>>;
|
||||||
|
|
||||||
class TFL_OperandRankEquals1DimOfOperand<int x, int y> :
|
class TFL_OperandRankEquals1DimOfOperand<int x, int y> :
|
||||||
PredOpTrait<"operand " # x # "'s rank equals operand " # y # "'s size",
|
PredOpTrait<"operand " # x # "'s rank equals operand " # y # "'s size",
|
||||||
CPred<"$_op.getOperand(" # x #
|
CPred<"$_op.getOperand(" # x #
|
||||||
")->getType().cast<ShapedType>().getRank() == "
|
").getType().cast<ShapedType>().getRank() == "
|
||||||
"$_op.getOperand(" # y #
|
"$_op.getOperand(" # y #
|
||||||
")->getType().cast<ShapedType>().getShape()[0]">>;
|
").getType().cast<ShapedType>().getShape()[0]">>;
|
||||||
|
|
||||||
class TFL_Operand0DOr1ElementTensor<int x> :
|
class TFL_Operand0DOr1ElementTensor<int x> :
|
||||||
PredOpTrait<"operand #" # x # " is an 0-d tensor or 1-d tensor w/ 1 element",
|
PredOpTrait<"operand #" # x # " is an 0-d tensor or 1-d tensor w/ 1 element",
|
||||||
@ -195,7 +195,7 @@ class TFL_OperandHasRankLessThan<int n, int m> :
|
|||||||
PredOpTrait<"operand " # n # " is maximum " # m # "-D",
|
PredOpTrait<"operand " # n # " is maximum " # m # "-D",
|
||||||
Or<[TFL_OperandIsUnrankedPred<n>,
|
Or<[TFL_OperandIsUnrankedPred<n>,
|
||||||
CPred<"$_op.getOperand(" # n #
|
CPred<"$_op.getOperand(" # n #
|
||||||
")->getType().cast<ShapedType>().getRank() <= " # m>]>>;
|
").getType().cast<ShapedType>().getRank() <= " # m>]>>;
|
||||||
|
|
||||||
// This is a quantization-aware version of TCresVTEtIsSameAsOp
|
// This is a quantization-aware version of TCresVTEtIsSameAsOp
|
||||||
class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
|
class TFL_TCresVTEtIsSameAsOp<int i, int j> : And<[
|
||||||
@ -227,7 +227,7 @@ def TFL_BroadcastableBinaryBuilder : OpBuilder<
|
|||||||
"Builder *builder, OperationState &result, Value lhs, Value rhs",
|
"Builder *builder, OperationState &result, Value lhs, Value rhs",
|
||||||
[{
|
[{
|
||||||
auto resultType =
|
auto resultType =
|
||||||
OpTrait::util::getBroadcastedType(lhs->getType(), rhs->getType());
|
OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType());
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
mlir::emitError(result.location, "non-broadcastable operands");
|
mlir::emitError(result.location, "non-broadcastable operands");
|
||||||
result.addOperands({lhs, rhs});
|
result.addOperands({lhs, rhs});
|
||||||
@ -471,7 +471,7 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> {
|
|||||||
let hasOptions = 1;
|
let hasOptions = 1;
|
||||||
|
|
||||||
DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
|
DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
|
||||||
return getResult()->getType().cast<TensorType>().getElementType().
|
return getResult().getType().cast<TensorType>().getElementType().
|
||||||
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
|
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
|
||||||
tflite::TensorType_INT32;
|
tflite::TensorType_INT32;
|
||||||
}]>;
|
}]>;
|
||||||
@ -500,7 +500,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
|
|||||||
let hasOptions = 1;
|
let hasOptions = 1;
|
||||||
|
|
||||||
DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
|
DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{
|
||||||
return getResult()->getType().cast<TensorType>().getElementType().
|
return getResult().getType().cast<TensorType>().getElementType().
|
||||||
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
|
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
|
||||||
tflite::TensorType_INT32;
|
tflite::TensorType_INT32;
|
||||||
}]>;
|
}]>;
|
||||||
@ -1996,7 +1996,7 @@ def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect]> {
|
|||||||
let results = (outs AnyTensor:$output);
|
let results = (outs AnyTensor:$output);
|
||||||
|
|
||||||
DerivedTypeAttr out_type = DerivedTypeAttr<[{
|
DerivedTypeAttr out_type = DerivedTypeAttr<[{
|
||||||
return getResult()->getType().cast<TensorType>().getElementType();
|
return getResult().getType().cast<TensorType>().getElementType();
|
||||||
}]>;
|
}]>;
|
||||||
|
|
||||||
let hasOptions = 1;
|
let hasOptions = 1;
|
||||||
@ -2083,7 +2083,7 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect,
|
|||||||
let builders = [OpBuilder<"Builder *builder, OperationState &result, "
|
let builders = [OpBuilder<"Builder *builder, OperationState &result, "
|
||||||
"Value condition, Value x, Value y",
|
"Value condition, Value x, Value y",
|
||||||
[{
|
[{
|
||||||
auto resultType = x->getType();
|
auto resultType = x.getType();
|
||||||
result.addOperands({condition, x, y});
|
result.addOperands({condition, x, y});
|
||||||
result.types.push_back(resultType);
|
result.types.push_back(resultType);
|
||||||
}]>];
|
}]>];
|
||||||
@ -2733,7 +2733,7 @@ in the unique output `y`. In other words:
|
|||||||
);
|
);
|
||||||
|
|
||||||
DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{
|
DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{
|
||||||
return getResult(1)->getType().cast<TensorType>().getElementType().
|
return getResult(1).getType().cast<TensorType>().getElementType().
|
||||||
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
|
cast<IntegerType>().getWidth() > 32 ? tflite::TensorType_INT64 :
|
||||||
tflite::TensorType_INT32;
|
tflite::TensorType_INT32;
|
||||||
}]>;
|
}]>;
|
||||||
|
@ -78,8 +78,8 @@ class ImportQuantStatsPass : public FunctionPass<ImportQuantStatsPass> {
|
|||||||
bool IsQuantizableResult(Operation *op, int index) {
|
bool IsQuantizableResult(Operation *op, int index) {
|
||||||
if (index < 0 || index >= op->getNumResults()) return false;
|
if (index < 0 || index >= op->getNumResults()) return false;
|
||||||
Value res = op->getResult(index);
|
Value res = op->getResult(index);
|
||||||
return res->getType().isa<ShapedType>() &&
|
return res.getType().isa<ShapedType>() &&
|
||||||
res->getType().cast<ShapedType>().getElementType().isa<FloatType>();
|
res.getType().cast<ShapedType>().getElementType().isa<FloatType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
// A method to retrieve the name for the given op.
|
// A method to retrieve the name for the given op.
|
||||||
@ -123,7 +123,7 @@ void ImportQuantStatsPass::InsertStatsOpAtResult(OpBuilder b, Value res,
|
|||||||
IntegerAttr axis) {
|
IntegerAttr axis) {
|
||||||
auto stats_op = b.create<quant::StatisticsOp>(b.getUnknownLoc(), res,
|
auto stats_op = b.create<quant::StatisticsOp>(b.getUnknownLoc(), res,
|
||||||
layer_stats, axis_stats, axis);
|
layer_stats, axis_stats, axis);
|
||||||
res->replaceAllUsesWith(stats_op);
|
res.replaceAllUsesWith(stats_op);
|
||||||
stats_op.getOperation()->replaceUsesOfWith(stats_op, res);
|
stats_op.getOperation()->replaceUsesOfWith(stats_op, res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -146,14 +146,14 @@ class QuantizationDriver {
|
|||||||
|
|
||||||
// Adds all the users of index-th result of op to the work list.
|
// Adds all the users of index-th result of op to the work list.
|
||||||
void AddUserToList(Operation *op, int index) {
|
void AddUserToList(Operation *op, int index) {
|
||||||
for (auto *user : op->getResult(index)->getUsers()) {
|
for (auto *user : op->getResult(index).getUsers()) {
|
||||||
work_list_.push_back(user);
|
work_list_.push_back(user);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adds the defining op of index-th operand of op to the work list.
|
// Adds the defining op of index-th operand of op to the work list.
|
||||||
void AddOperandToList(Operation *op, int index) {
|
void AddOperandToList(Operation *op, int index) {
|
||||||
if (auto *inst = op->getOperand(index)->getDefiningOp()) {
|
if (auto *inst = op->getOperand(index).getDefiningOp()) {
|
||||||
work_list_.push_back(inst);
|
work_list_.push_back(inst);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -248,7 +248,7 @@ class QuantizationDriver {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
QuantParams params =
|
QuantParams params =
|
||||||
quant::QuantizedType::getQuantizedElementType(in->getType());
|
quant::QuantizedType::getQuantizedElementType(in.getType());
|
||||||
bool immutable = !EmptyParams(params);
|
bool immutable = !EmptyParams(params);
|
||||||
int next_state_index = states_.size();
|
int next_state_index = states_.size();
|
||||||
states_.push_back({params, immutable});
|
states_.push_back({params, immutable});
|
||||||
@ -338,7 +338,7 @@ bool QuantizationDriver::IsQuantized(Operation *op) {
|
|||||||
int QuantizationDriver::InitializeState(Operation *op, int index, Value val,
|
int QuantizationDriver::InitializeState(Operation *op, int index, Value val,
|
||||||
bool as_result) {
|
bool as_result) {
|
||||||
QuantParams params =
|
QuantParams params =
|
||||||
quant::QuantizedType::getQuantizedElementType(val->getType());
|
quant::QuantizedType::getQuantizedElementType(val.getType());
|
||||||
bool immutable = !EmptyParams(params);
|
bool immutable = !EmptyParams(params);
|
||||||
int next_state_index = states_.size();
|
int next_state_index = states_.size();
|
||||||
states_.push_back({params, immutable});
|
states_.push_back({params, immutable});
|
||||||
@ -447,13 +447,13 @@ void QuantizationDriver::QuantizeOpResult(Operation *op, int index,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void QuantizationDriver::QuantizeArg(BlockArgument arg, QuantParams params) {
|
void QuantizationDriver::QuantizeArg(BlockArgument arg, QuantParams params) {
|
||||||
builder_.setInsertionPointToStart(arg->getOwner());
|
builder_.setInsertionPointToStart(arg.getOwner());
|
||||||
QuantizeValue(arg, params, builder_.getUnknownLoc());
|
QuantizeValue(arg, params, builder_.getUnknownLoc());
|
||||||
}
|
}
|
||||||
|
|
||||||
void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
|
void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
Type expressed_type = value->getType();
|
Type expressed_type = value.getType();
|
||||||
Type new_type = params.castFromExpressedType(expressed_type);
|
Type new_type = params.castFromExpressedType(expressed_type);
|
||||||
// This value isn't an expressed type (float), skip.
|
// This value isn't an expressed type (float), skip.
|
||||||
if (!new_type) return;
|
if (!new_type) return;
|
||||||
@ -465,7 +465,7 @@ void QuantizationDriver::QuantizeValue(Value value, QuantParams params,
|
|||||||
quantize.output());
|
quantize.output());
|
||||||
// `original_result` has a use to `quantize`, so this will replace that use
|
// `original_result` has a use to `quantize`, so this will replace that use
|
||||||
// by the result of `dequantize`. Remember to reset that use afterwards
|
// by the result of `dequantize`. Remember to reset that use afterwards
|
||||||
value->replaceAllUsesWith(dequantize);
|
value.replaceAllUsesWith(dequantize);
|
||||||
quantize.getOperation()->replaceUsesOfWith(dequantize, value);
|
quantize.getOperation()->replaceUsesOfWith(dequantize, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -475,7 +475,7 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
|
|||||||
builder_.setInsertionPointAfter(op);
|
builder_.setInsertionPointAfter(op);
|
||||||
Value value = op->getResult(index);
|
Value value = op->getResult(index);
|
||||||
if (state->pos == RequantizeState::ON_OUTPUT) {
|
if (state->pos == RequantizeState::ON_OUTPUT) {
|
||||||
Operation *user = value->getUses().begin().getUser();
|
Operation *user = value.getUses().begin().getUser();
|
||||||
if (llvm::isa<TFL::QuantizeOp>(user)) {
|
if (llvm::isa<TFL::QuantizeOp>(user)) {
|
||||||
// The requantize op is inserted between `quantize` and `dequantize` ops.
|
// The requantize op is inserted between `quantize` and `dequantize` ops.
|
||||||
value = user->getResult(0);
|
value = user->getResult(0);
|
||||||
@ -488,12 +488,12 @@ void QuantizationDriver::RequantizeOpResult(Operation *op, int index,
|
|||||||
void QuantizationDriver::RequantizeArg(BlockArgument arg,
|
void QuantizationDriver::RequantizeArg(BlockArgument arg,
|
||||||
RequantizeState *state) {
|
RequantizeState *state) {
|
||||||
Value value = arg;
|
Value value = arg;
|
||||||
builder_.setInsertionPointToStart(arg->getOwner());
|
builder_.setInsertionPointToStart(arg.getOwner());
|
||||||
if (value->hasOneUse()) {
|
if (value.hasOneUse()) {
|
||||||
auto user = value->use_begin().getUser();
|
auto user = value.use_begin().getUser();
|
||||||
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
|
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
|
||||||
value = q.output();
|
value = q.output();
|
||||||
builder_.setInsertionPoint(arg->getOwner(), ++Block::iterator(user));
|
builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
RequantizeValue(value, state, builder_.getUnknownLoc());
|
RequantizeValue(value, state, builder_.getUnknownLoc());
|
||||||
@ -503,13 +503,13 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
|
|||||||
Location loc) {
|
Location loc) {
|
||||||
Type new_type;
|
Type new_type;
|
||||||
if (state->pos == RequantizeState::ON_INPUT) {
|
if (state->pos == RequantizeState::ON_INPUT) {
|
||||||
Type expressed_type = value->getType();
|
Type expressed_type = value.getType();
|
||||||
// The value needs to be requantized. A Quantize op will be created to use
|
// The value needs to be requantized. A Quantize op will be created to use
|
||||||
// it as the operand and replace its uses.
|
// it as the operand and replace its uses.
|
||||||
new_type = state->params.castFromExpressedType(expressed_type);
|
new_type = state->params.castFromExpressedType(expressed_type);
|
||||||
} else {
|
} else {
|
||||||
Type expressed_type =
|
Type expressed_type =
|
||||||
quant::QuantizedType::castToExpressedType(value->getType());
|
quant::QuantizedType::castToExpressedType(value.getType());
|
||||||
if (!expressed_type) return;
|
if (!expressed_type) return;
|
||||||
|
|
||||||
// The value needs to be requantized. A Quantize op will be created to use
|
// The value needs to be requantized. A Quantize op will be created to use
|
||||||
@ -522,7 +522,7 @@ void QuantizationDriver::RequantizeValue(Value value, RequantizeState *state,
|
|||||||
TypeAttr type_attr = TypeAttr::get(new_type);
|
TypeAttr type_attr = TypeAttr::get(new_type);
|
||||||
auto requantize_op =
|
auto requantize_op =
|
||||||
builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
|
builder_.create<TFL::QuantizeOp>(loc, new_type, value, type_attr);
|
||||||
value->replaceAllUsesWith(requantize_op);
|
value.replaceAllUsesWith(requantize_op);
|
||||||
requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
|
requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -603,7 +603,7 @@ void QuantizationDriver::PreprocessConstantOps() {
|
|||||||
Value value = cst.getResult();
|
Value value = cst.getResult();
|
||||||
SmallVector<std::pair<Operation *, int>, 4> bias_users;
|
SmallVector<std::pair<Operation *, int>, 4> bias_users;
|
||||||
bool used_as_weight = false;
|
bool used_as_weight = false;
|
||||||
for (auto &use : value->getUses()) {
|
for (auto &use : value.getUses()) {
|
||||||
auto spec = GetQuantSpec(use.getOwner());
|
auto spec = GetQuantSpec(use.getOwner());
|
||||||
auto biases = spec->biases_params;
|
auto biases = spec->biases_params;
|
||||||
Operation *user = use.getOwner();
|
Operation *user = use.getOwner();
|
||||||
@ -649,8 +649,8 @@ void QuantizationDriver::SetupAllStates() {
|
|||||||
args_.push_back(arg);
|
args_.push_back(arg);
|
||||||
Value value = arg;
|
Value value = arg;
|
||||||
// If the argument is quantized, it should only has one user.
|
// If the argument is quantized, it should only has one user.
|
||||||
if (arg->hasOneUse()) {
|
if (arg.hasOneUse()) {
|
||||||
auto user = value->use_begin().getUser();
|
auto user = value.use_begin().getUser();
|
||||||
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
|
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
|
||||||
value = q.output();
|
value = q.output();
|
||||||
}
|
}
|
||||||
@ -666,7 +666,7 @@ void QuantizationDriver::SetupAllStates() {
|
|||||||
|
|
||||||
for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
|
for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
|
||||||
auto operand = op->getOperand(i);
|
auto operand = op->getOperand(i);
|
||||||
if (auto *inst = operand->getDefiningOp()) {
|
if (auto *inst = operand.getDefiningOp()) {
|
||||||
// If the operand comes from a tfl.dequantize op, we use the quantized
|
// If the operand comes from a tfl.dequantize op, we use the quantized
|
||||||
// input of this tfl.dequantize op to set the state.
|
// input of this tfl.dequantize op to set the state.
|
||||||
if (auto dq = llvm::dyn_cast<TFL::DequantizeOp>(inst)) {
|
if (auto dq = llvm::dyn_cast<TFL::DequantizeOp>(inst)) {
|
||||||
@ -681,8 +681,8 @@ void QuantizationDriver::SetupAllStates() {
|
|||||||
// If the result has been quantized, it should only be used by a
|
// If the result has been quantized, it should only be used by a
|
||||||
// tfl.quantize op. For this case, we uses the quantized result to
|
// tfl.quantize op. For this case, we uses the quantized result to
|
||||||
// create the state and mark it immutable.
|
// create the state and mark it immutable.
|
||||||
if (result->hasOneUse()) {
|
if (result.hasOneUse()) {
|
||||||
auto user = result->use_begin().getUser();
|
auto user = result.use_begin().getUser();
|
||||||
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
|
if (auto q = llvm::dyn_cast<TFL::QuantizeOp>(user)) {
|
||||||
result = q.output();
|
result = q.output();
|
||||||
}
|
}
|
||||||
|
@ -70,7 +70,7 @@ class FixedResultUniformScale {
|
|||||||
QuantizedType GetResultQuantizedType(int index) {
|
QuantizedType GetResultQuantizedType(int index) {
|
||||||
auto op = this->getOperation();
|
auto op = this->getOperation();
|
||||||
auto result_type =
|
auto result_type =
|
||||||
op->getResult(index)->getType().template cast<TensorType>();
|
op->getResult(index).getType().template cast<TensorType>();
|
||||||
Builder builder(op->getContext());
|
Builder builder(op->getContext());
|
||||||
IntegerType storage_type = builder.getIntegerType(BitWidth);
|
IntegerType storage_type = builder.getIntegerType(BitWidth);
|
||||||
const double scale = static_cast<double>(ScaleMantissa) *
|
const double scale = static_cast<double>(ScaleMantissa) *
|
||||||
|
@ -367,7 +367,7 @@ ElementsAttr Quantize(Attribute real_value, Type tensor_type) {
|
|||||||
static bool PreferResultScale(Operation* op) {
|
static bool PreferResultScale(Operation* op) {
|
||||||
int float_operands = 0;
|
int float_operands = 0;
|
||||||
for (auto operand : op->getOperands()) {
|
for (auto operand : op->getOperands()) {
|
||||||
if (auto operand_type = operand->getType().dyn_cast<ShapedType>()) {
|
if (auto operand_type = operand.getType().dyn_cast<ShapedType>()) {
|
||||||
if (operand_type.getElementType().isa<FloatType>()) {
|
if (operand_type.getElementType().isa<FloatType>()) {
|
||||||
if (float_operands++ > 1) return true;
|
if (float_operands++ > 1) return true;
|
||||||
}
|
}
|
||||||
@ -400,22 +400,22 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
|
|||||||
quant::StatisticsOp stats_op = all_stats_ops.back();
|
quant::StatisticsOp stats_op = all_stats_ops.back();
|
||||||
all_stats_ops.pop_back();
|
all_stats_ops.pop_back();
|
||||||
|
|
||||||
if (auto def = stats_op.arg()->getDefiningOp()) {
|
if (auto def = stats_op.arg().getDefiningOp()) {
|
||||||
if (IsStatsRedundant(def, op_quant_spec_getter)) {
|
if (IsStatsRedundant(def, op_quant_spec_getter)) {
|
||||||
redundant_stats_ops.insert(stats_op);
|
redundant_stats_ops.insert(stats_op);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto user : stats_op.getResult()->getUsers()) {
|
for (auto user : stats_op.getResult().getUsers()) {
|
||||||
// We don't propagate this parameter down if it has multiple operands.
|
// We don't propagate this parameter down if it has multiple operands.
|
||||||
// We want to use the result parameter scales instead.
|
// We want to use the result parameter scales instead.
|
||||||
|
|
||||||
if (user->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
|
if (user->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
|
||||||
!PreferResultScale(user)) {
|
!PreferResultScale(user)) {
|
||||||
for (Value res : user->getResults()) {
|
for (Value res : user->getResults()) {
|
||||||
if (res->hasOneUse()) {
|
if (res.hasOneUse()) {
|
||||||
if (auto next_stats = llvm::dyn_cast<quant::StatisticsOp>(
|
if (auto next_stats = llvm::dyn_cast<quant::StatisticsOp>(
|
||||||
*res->getUsers().begin())) {
|
*res.getUsers().begin())) {
|
||||||
// quantization parameters can be propagated to next_stats
|
// quantization parameters can be propagated to next_stats
|
||||||
redundant_stats_ops.insert(next_stats);
|
redundant_stats_ops.insert(next_stats);
|
||||||
// add next_stats to the work list so propagation can
|
// add next_stats to the work list so propagation can
|
||||||
@ -440,12 +440,12 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
|
|||||||
quant::StatisticsOp stats_op = all_stats_ops.back();
|
quant::StatisticsOp stats_op = all_stats_ops.back();
|
||||||
all_stats_ops.pop_back();
|
all_stats_ops.pop_back();
|
||||||
|
|
||||||
if (auto def = stats_op.arg()->getDefiningOp()) {
|
if (auto def = stats_op.arg().getDefiningOp()) {
|
||||||
if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
|
if (def->hasTrait<OpTrait::quant::SameOperandsAndResultsScale>() &&
|
||||||
PreferResultScale(def)) {
|
PreferResultScale(def)) {
|
||||||
for (auto input : def->getOperands()) {
|
for (auto input : def->getOperands()) {
|
||||||
if (auto next_stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(
|
if (auto next_stats = llvm::dyn_cast_or_null<quant::StatisticsOp>(
|
||||||
input->getDefiningOp())) {
|
input.getDefiningOp())) {
|
||||||
redundant_stats_ops.insert(next_stats);
|
redundant_stats_ops.insert(next_stats);
|
||||||
all_stats_ops.push_back(next_stats);
|
all_stats_ops.push_back(next_stats);
|
||||||
}
|
}
|
||||||
@ -458,7 +458,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func,
|
|||||||
for (auto it : redundant_stats_ops) {
|
for (auto it : redundant_stats_ops) {
|
||||||
if (!llvm::isa<quant::StatisticsOp>(it)) return true;
|
if (!llvm::isa<quant::StatisticsOp>(it)) return true;
|
||||||
auto stats_op = llvm::cast<quant::StatisticsOp>(it);
|
auto stats_op = llvm::cast<quant::StatisticsOp>(it);
|
||||||
stats_op.getResult()->replaceAllUsesWith(stats_op.arg());
|
stats_op.getResult().replaceAllUsesWith(stats_op.arg());
|
||||||
stats_op.erase();
|
stats_op.erase();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -116,7 +116,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quant::StatisticsOp> {
|
|||||||
auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg(),
|
auto q = rewriter.create<Q>(op.getLoc(), result_type, op.arg(),
|
||||||
TypeAttr::get(result_type));
|
TypeAttr::get(result_type));
|
||||||
auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
|
auto dq = rewriter.create<DQ>(op.getLoc(), op.getType(), q);
|
||||||
op.getResult()->replaceAllUsesWith(dq);
|
op.getResult().replaceAllUsesWith(dq);
|
||||||
q.getOperation()->replaceUsesOfWith(dq, op.arg());
|
q.getOperation()->replaceUsesOfWith(dq, op.arg());
|
||||||
op.erase();
|
op.erase();
|
||||||
|
|
||||||
@ -162,7 +162,7 @@ struct QuantizationPattern : public RewritePattern {
|
|||||||
return matchFailure();
|
return matchFailure();
|
||||||
}
|
}
|
||||||
Value quantized_value = op->getResult(0);
|
Value quantized_value = op->getResult(0);
|
||||||
for (Operation* quantized_op : quantized_value->getUsers()) {
|
for (Operation* quantized_op : quantized_value.getUsers()) {
|
||||||
// If it is requantize op, we shouldn't rewrite this op.
|
// If it is requantize op, we shouldn't rewrite this op.
|
||||||
if (llvm::isa<Q>(quantized_op) || llvm::isa<DQ>(quantized_op)) {
|
if (llvm::isa<Q>(quantized_op) || llvm::isa<DQ>(quantized_op)) {
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
@ -179,14 +179,14 @@ struct QuantizationPattern : public RewritePattern {
|
|||||||
SmallVector<Value, 4> inputs;
|
SmallVector<Value, 4> inputs;
|
||||||
inputs.reserve(quantized_op->getNumOperands());
|
inputs.reserve(quantized_op->getNumOperands());
|
||||||
for (auto operand : quantized_op->getOperands()) {
|
for (auto operand : quantized_op->getOperands()) {
|
||||||
Type operand_type = operand->getType();
|
Type operand_type = operand.getType();
|
||||||
if (operand_type.isa<NoneType>()) {
|
if (operand_type.isa<NoneType>()) {
|
||||||
inputs.push_back(operand);
|
inputs.push_back(operand);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto ele_type = operand->getType().cast<TensorType>().getElementType();
|
auto ele_type = operand.getType().cast<TensorType>().getElementType();
|
||||||
if (auto op_inst = dyn_cast_or_null<DQ>(operand->getDefiningOp())) {
|
if (auto op_inst = dyn_cast_or_null<DQ>(operand.getDefiningOp())) {
|
||||||
inputs.push_back(op_inst.input());
|
inputs.push_back(op_inst.input());
|
||||||
} else if (ele_type.isa<IntegerType>()) {
|
} else if (ele_type.isa<IntegerType>()) {
|
||||||
// If the operand is an integer tensor, then it doesn't require the
|
// If the operand is an integer tensor, then it doesn't require the
|
||||||
@ -207,7 +207,7 @@ struct QuantizationPattern : public RewritePattern {
|
|||||||
for (auto enumerated_result :
|
for (auto enumerated_result :
|
||||||
llvm::enumerate(quantized_op->getResults())) {
|
llvm::enumerate(quantized_op->getResults())) {
|
||||||
Value result = enumerated_result.value();
|
Value result = enumerated_result.value();
|
||||||
Type result_type = result->getType();
|
Type result_type = result.getType();
|
||||||
// Add this to the test coverage once we create test ops with none type
|
// Add this to the test coverage once we create test ops with none type
|
||||||
// results.
|
// results.
|
||||||
if (result_type.isa<NoneType>()) {
|
if (result_type.isa<NoneType>()) {
|
||||||
@ -216,20 +216,20 @@ struct QuantizationPattern : public RewritePattern {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Type result_ele_type =
|
Type result_ele_type =
|
||||||
result->getType().cast<TensorType>().getElementType();
|
result.getType().cast<TensorType>().getElementType();
|
||||||
// If the user is the Quantize op, it must be the only user.
|
// If the user is the Quantize op, it must be the only user.
|
||||||
if (result->hasOneUse() && llvm::isa<Q>(*result->user_begin())) {
|
if (result.hasOneUse() && llvm::isa<Q>(*result.user_begin())) {
|
||||||
auto user = llvm::cast<Q>(*result->user_begin());
|
auto user = llvm::cast<Q>(*result.user_begin());
|
||||||
outputs_replaced.insert({user.output(), enumerated_result.index()});
|
outputs_replaced.insert({user.output(), enumerated_result.index()});
|
||||||
output_types.push_back(user.getType());
|
output_types.push_back(user.getType());
|
||||||
} else if (result_ele_type.template isa<IntegerType>()) {
|
} else if (result_ele_type.template isa<IntegerType>()) {
|
||||||
// If the result is an integer tensor, then it doesn't require the
|
// If the result is an integer tensor, then it doesn't require the
|
||||||
// D op in the pattern.
|
// D op in the pattern.
|
||||||
outputs_replaced.insert({result, enumerated_result.index()});
|
outputs_replaced.insert({result, enumerated_result.index()});
|
||||||
output_types.push_back(result->getType());
|
output_types.push_back(result.getType());
|
||||||
} else if (static_cast<const ConcretTy*>(this)->AllowHybridResult()) {
|
} else if (static_cast<const ConcretTy*>(this)->AllowHybridResult()) {
|
||||||
outputs_replaced.insert({result, enumerated_result.index()});
|
outputs_replaced.insert({result, enumerated_result.index()});
|
||||||
output_types.push_back(result->getType());
|
output_types.push_back(result.getType());
|
||||||
} else {
|
} else {
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
}
|
}
|
||||||
@ -241,7 +241,7 @@ struct QuantizationPattern : public RewritePattern {
|
|||||||
output_types, quantized_op->getAttrs());
|
output_types, quantized_op->getAttrs());
|
||||||
Operation* new_op = rewriter.createOperation(new_state);
|
Operation* new_op = rewriter.createOperation(new_state);
|
||||||
for (auto output : outputs_replaced) {
|
for (auto output : outputs_replaced) {
|
||||||
output.getFirst()->replaceAllUsesWith(
|
output.getFirst().replaceAllUsesWith(
|
||||||
new_op->getResult(output.getSecond()));
|
new_op->getResult(output.getSecond()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -252,7 +252,7 @@ struct QuantizationPattern : public RewritePattern {
|
|||||||
// For constant operands, the floating-point constant is duplicated in
|
// For constant operands, the floating-point constant is duplicated in
|
||||||
// case it is quantized.
|
// case it is quantized.
|
||||||
for (int i = 0, e = new_op->getNumOperands(); i != e; ++i) {
|
for (int i = 0, e = new_op->getNumOperands(); i != e; ++i) {
|
||||||
auto def = new_op->getOperand(i)->getDefiningOp();
|
auto def = new_op->getOperand(i).getDefiningOp();
|
||||||
if (auto q = llvm::dyn_cast_or_null<Q>(def)) {
|
if (auto q = llvm::dyn_cast_or_null<Q>(def)) {
|
||||||
DenseFPElementsAttr attr;
|
DenseFPElementsAttr attr;
|
||||||
if (!matchPattern(q.input(), m_Constant(&attr))) {
|
if (!matchPattern(q.input(), m_Constant(&attr))) {
|
||||||
@ -265,7 +265,7 @@ struct QuantizationPattern : public RewritePattern {
|
|||||||
|
|
||||||
for (int i = 0, e = new_op->getNumResults(); i != e; ++i) {
|
for (int i = 0, e = new_op->getNumResults(); i != e; ++i) {
|
||||||
if (!quantized_op->getResult(i)
|
if (!quantized_op->getResult(i)
|
||||||
->getType()
|
.getType()
|
||||||
.cast<ShapedType>()
|
.cast<ShapedType>()
|
||||||
.getElementType()
|
.getElementType()
|
||||||
.isa<FloatType>()) {
|
.isa<FloatType>()) {
|
||||||
@ -283,13 +283,13 @@ struct QuantizationPattern : public RewritePattern {
|
|||||||
// Find the Dequantize/Dequantize users of the new op results, and
|
// Find the Dequantize/Dequantize users of the new op results, and
|
||||||
// replace the usage. Then all the floating-point ops are connected.
|
// replace the usage. Then all the floating-point ops are connected.
|
||||||
// N.B. the return op will use this floating-point result.
|
// N.B. the return op will use this floating-point result.
|
||||||
for (auto user : new_op->getResult(i)->getUsers()) {
|
for (auto user : new_op->getResult(i).getUsers()) {
|
||||||
// Skip the Requantize op, and we know it has a single user.
|
// Skip the Requantize op, and we know it has a single user.
|
||||||
if (llvm::isa<Q>(user)) {
|
if (llvm::isa<Q>(user)) {
|
||||||
user = *user->getResult(0)->getUsers().begin();
|
user = *user->getResult(0).getUsers().begin();
|
||||||
}
|
}
|
||||||
if (auto dequantize = llvm::dyn_cast<DQ>(user)) {
|
if (auto dequantize = llvm::dyn_cast<DQ>(user)) {
|
||||||
dequantize.getResult()->replaceAllUsesWith(
|
dequantize.getResult().replaceAllUsesWith(
|
||||||
quantized_op->getResult(i));
|
quantized_op->getResult(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -316,7 +316,7 @@ struct ConvertUnsignedToSigned : public OpRewritePattern<Q> {
|
|||||||
|
|
||||||
PatternMatchResult matchAndRewrite(Q op,
|
PatternMatchResult matchAndRewrite(Q op,
|
||||||
PatternRewriter& rewriter) const override {
|
PatternRewriter& rewriter) const override {
|
||||||
Type output_type = op.output()->getType();
|
Type output_type = op.output().getType();
|
||||||
auto qtype = QType::getQuantizedElementType(output_type);
|
auto qtype = QType::getQuantizedElementType(output_type);
|
||||||
if (!qtype || qtype.isSigned()) return this->matchFailure();
|
if (!qtype || qtype.isSigned()) return this->matchFailure();
|
||||||
|
|
||||||
|
@ -103,7 +103,7 @@ static int PrintFunctionResultMapping(const std::string &result,
|
|||||||
i = 0;
|
i = 0;
|
||||||
for (auto output : *subgraph->outputs()) {
|
for (auto output : *subgraph->outputs()) {
|
||||||
print_buffer(*subgraph, i, output, [&](int i) {
|
print_buffer(*subgraph, i, output, [&](int i) {
|
||||||
return terminator ? terminator->getOperand(i)->getLoc() : unknown_loc;
|
return terminator ? terminator->getOperand(i).getLoc() : unknown_loc;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -205,7 +205,7 @@ struct OphintCompositeOp {
|
|||||||
Operation* current_identity_op = operand.ops.begin()->second;
|
Operation* current_identity_op = operand.ops.begin()->second;
|
||||||
Value input = current_identity_op->getOperand(0);
|
Value input = current_identity_op->getOperand(0);
|
||||||
RankedTensorType input_type =
|
RankedTensorType input_type =
|
||||||
input->getType().cast<RankedTensorType>();
|
input.getType().cast<RankedTensorType>();
|
||||||
// The Reshape will be {1, (original_shape)}
|
// The Reshape will be {1, (original_shape)}
|
||||||
SmallVector<int64_t, 4> reshape_op_shape;
|
SmallVector<int64_t, 4> reshape_op_shape;
|
||||||
reshape_op_shape.push_back(1);
|
reshape_op_shape.push_back(1);
|
||||||
@ -242,13 +242,13 @@ struct OphintCompositeOp {
|
|||||||
}
|
}
|
||||||
// Find the first op that consumes the last value of the aggregated
|
// Find the first op that consumes the last value of the aggregated
|
||||||
// inputs.
|
// inputs.
|
||||||
Operation* first_use = *(packed_input_consumers.back()->user_begin());
|
Operation* first_use = *(packed_input_consumers.back().user_begin());
|
||||||
// The pack reshape will be {N, (original_shape)}
|
// The pack reshape will be {N, (original_shape)}
|
||||||
SmallVector<int64_t, 4> pack_shape;
|
SmallVector<int64_t, 4> pack_shape;
|
||||||
pack_shape.push_back(pack_input_operands.size());
|
pack_shape.push_back(pack_input_operands.size());
|
||||||
RankedTensorType type = operand.ops.at(0)
|
RankedTensorType type = operand.ops.at(0)
|
||||||
->getResult(0)
|
->getResult(0)
|
||||||
->getType()
|
.getType()
|
||||||
.cast<RankedTensorType>();
|
.cast<RankedTensorType>();
|
||||||
for (const auto& dim : type.getShape()) {
|
for (const auto& dim : type.getShape()) {
|
||||||
pack_shape.push_back(dim);
|
pack_shape.push_back(dim);
|
||||||
@ -290,7 +290,7 @@ struct OphintCompositeOp {
|
|||||||
const int output_numer = operand.ops.size();
|
const int output_numer = operand.ops.size();
|
||||||
Value first_output = operand.ops.at(0)->getOperand(0);
|
Value first_output = operand.ops.at(0)->getOperand(0);
|
||||||
RankedTensorType first_output_type =
|
RankedTensorType first_output_type =
|
||||||
first_output->getType().cast<RankedTensorType>();
|
first_output.getType().cast<RankedTensorType>();
|
||||||
// The aggregated output shape will be {N, original_shape}.
|
// The aggregated output shape will be {N, original_shape}.
|
||||||
SmallVector<int64_t, 4> shape;
|
SmallVector<int64_t, 4> shape;
|
||||||
shape.push_back(output_numer);
|
shape.push_back(output_numer);
|
||||||
@ -302,10 +302,10 @@ struct OphintCompositeOp {
|
|||||||
} else if (operand.aggregation == kStrategyLast) {
|
} else if (operand.aggregation == kStrategyLast) {
|
||||||
Value last_output =
|
Value last_output =
|
||||||
operand.ops.at(operand.ops.size() - 1)->getOperand(0);
|
operand.ops.at(operand.ops.size() - 1)->getOperand(0);
|
||||||
aggregated_output_types[kv.first] = last_output->getType();
|
aggregated_output_types[kv.first] = last_output.getType();
|
||||||
} else {
|
} else {
|
||||||
Value first_output = operand.ops.at(0)->getOperand(0);
|
Value first_output = operand.ops.at(0)->getOperand(0);
|
||||||
aggregated_output_types[kv.first] = first_output->getType();
|
aggregated_output_types[kv.first] = first_output.getType();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return aggregated_output_types;
|
return aggregated_output_types;
|
||||||
@ -329,7 +329,7 @@ struct OphintCompositeOp {
|
|||||||
Operation* first_output = operand.ops.at(0);
|
Operation* first_output = operand.ops.at(0);
|
||||||
Location insert_loc = first_output->getLoc();
|
Location insert_loc = first_output->getLoc();
|
||||||
SmallVector<Type, 4> unpack_output_types(
|
SmallVector<Type, 4> unpack_output_types(
|
||||||
output_number, first_output->getOperand(0)->getType());
|
output_number, first_output->getOperand(0).getType());
|
||||||
|
|
||||||
builder->setInsertionPoint(first_output);
|
builder->setInsertionPoint(first_output);
|
||||||
Operation* unpack_op = builder->create<TFL::UnpackOp>(
|
Operation* unpack_op = builder->create<TFL::UnpackOp>(
|
||||||
@ -404,7 +404,7 @@ void PreprocessTopoSortGraph(
|
|||||||
// should only count as one.
|
// should only count as one.
|
||||||
llvm::DenseSet<Operation*> input_ops;
|
llvm::DenseSet<Operation*> input_ops;
|
||||||
for (int i = 0; i < op.getNumOperands(); ++i) {
|
for (int i = 0; i < op.getNumOperands(); ++i) {
|
||||||
Operation* input_op = op.getOperand(i)->getDefiningOp();
|
Operation* input_op = op.getOperand(i).getDefiningOp();
|
||||||
if (input_op) input_ops.insert(input_op);
|
if (input_op) input_ops.insert(input_op);
|
||||||
}
|
}
|
||||||
if (input_ops.empty()) {
|
if (input_ops.empty()) {
|
||||||
@ -515,7 +515,7 @@ Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type,
|
|||||||
SmallVector<int, 4> input_indexes;
|
SmallVector<int, 4> input_indexes;
|
||||||
for (const auto& kv : inputs) {
|
for (const auto& kv : inputs) {
|
||||||
Value input = kv.second;
|
Value input = kv.second;
|
||||||
input_types.push_back(input->getType());
|
input_types.push_back(input.getType());
|
||||||
input_values.push_back(input);
|
input_values.push_back(input);
|
||||||
input_indexes.push_back(kv.first);
|
input_indexes.push_back(kv.first);
|
||||||
}
|
}
|
||||||
@ -589,7 +589,7 @@ llvm::DenseSet<Operation*> BfsForReachableOps(ArrayRef<Operation*> input_ops) {
|
|||||||
std::queue<Operation*> ops_queue;
|
std::queue<Operation*> ops_queue;
|
||||||
for (auto& input_op : input_ops) {
|
for (auto& input_op : input_ops) {
|
||||||
for (Value value : input_op->getOperands()) {
|
for (Value value : input_op->getOperands()) {
|
||||||
Operation* op = value->getDefiningOp();
|
Operation* op = value.getDefiningOp();
|
||||||
if (op != nullptr) ops_queue.push(op);
|
if (op != nullptr) ops_queue.push(op);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -599,7 +599,7 @@ llvm::DenseSet<Operation*> BfsForReachableOps(ArrayRef<Operation*> input_ops) {
|
|||||||
ops_queue.pop();
|
ops_queue.pop();
|
||||||
reachable_ops.insert(current_op);
|
reachable_ops.insert(current_op);
|
||||||
for (Value value : current_op->getOperands()) {
|
for (Value value : current_op->getOperands()) {
|
||||||
Operation* upstream_op = value->getDefiningOp();
|
Operation* upstream_op = value.getDefiningOp();
|
||||||
// Not visited, put it into the queue.
|
// Not visited, put it into the queue.
|
||||||
if (upstream_op != nullptr &&
|
if (upstream_op != nullptr &&
|
||||||
!llvm::is_contained(reachable_ops, upstream_op)) {
|
!llvm::is_contained(reachable_ops, upstream_op)) {
|
||||||
@ -642,7 +642,7 @@ LogicalResult ConvertOphintToStub(StringRef stub_name,
|
|||||||
aggregated_inputs, aggregated_output_types, builder, module_op);
|
aggregated_inputs, aggregated_output_types, builder, module_op);
|
||||||
|
|
||||||
for (const auto& kv : aggregated_inputs) {
|
for (const auto& kv : aggregated_inputs) {
|
||||||
Operation* op = kv.second->getDefiningOp();
|
Operation* op = kv.second.getDefiningOp();
|
||||||
if (op == nullptr) return failure();
|
if (op == nullptr) return failure();
|
||||||
op->moveBefore(fused_op);
|
op->moveBefore(fused_op);
|
||||||
}
|
}
|
||||||
|
@ -103,7 +103,7 @@ LogicalResult BuildUnidirectionalSequenceRnnOp(FuncOp composite_func_op,
|
|||||||
Value hidden_state = call_op.getOperand(4);
|
Value hidden_state = call_op.getOperand(4);
|
||||||
|
|
||||||
// Build Output.
|
// Build Output.
|
||||||
auto output_type = call_op.getResult(0)->getType();
|
auto output_type = call_op.getResult(0).getType();
|
||||||
|
|
||||||
// Currently, ophinted RNN only supports time_major = True.
|
// Currently, ophinted RNN only supports time_major = True.
|
||||||
const bool time_major = true;
|
const bool time_major = true;
|
||||||
@ -170,11 +170,11 @@ LogicalResult BuildUnidirectionalSequenceLSTMOp(FuncOp composite_func_op,
|
|||||||
for (int i = 0; i < call_op.getNumResults() - 1; ++i) {
|
for (int i = 0; i < call_op.getNumResults() - 1; ++i) {
|
||||||
// This one should not be used.
|
// This one should not be used.
|
||||||
Value unused_output = call_op.getResult(i);
|
Value unused_output = call_op.getResult(i);
|
||||||
if (!unused_output->use_empty()) return failure();
|
if (!unused_output.use_empty()) return failure();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
output_types.push_back(
|
output_types.push_back(
|
||||||
call_op.getResult(call_op.getNumResults() - 1)->getType());
|
call_op.getResult(call_op.getNumResults() - 1).getType());
|
||||||
|
|
||||||
// Prepare attributes.
|
// Prepare attributes.
|
||||||
SmallVector<NamedAttribute, 4> attributes;
|
SmallVector<NamedAttribute, 4> attributes;
|
||||||
@ -207,10 +207,10 @@ LogicalResult ConvertTfLiteFusedOpIfAvailable(StringRef func_name,
|
|||||||
composite_func_op, call_op, builder, &fused_op);
|
composite_func_op, call_op, builder, &fused_op);
|
||||||
if (failed(build_fused_op_result)) return build_fused_op_result;
|
if (failed(build_fused_op_result)) return build_fused_op_result;
|
||||||
Value call_output = call_op.getResult(call_op.getNumResults() - 1);
|
Value call_output = call_op.getResult(call_op.getNumResults() - 1);
|
||||||
if (call_output->getType() != fused_op->getResult(0)->getType()) {
|
if (call_output.getType() != fused_op->getResult(0).getType()) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
call_output->replaceAllUsesWith(fused_op->getResult(0));
|
call_output.replaceAllUsesWith(fused_op->getResult(0));
|
||||||
} else { // If we support more fused op, we should add the conversion here.
|
} else { // If we support more fused op, we should add the conversion here.
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
@ -39,7 +39,7 @@ def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
|
|||||||
// Use the tensor type information from $0 and convert min $1, max $2 and
|
// Use the tensor type information from $0 and convert min $1, max $2 and
|
||||||
// numBits $3 and narrowRange $4 to a QuantizedType.
|
// numBits $3 and narrowRange $4 to a QuantizedType.
|
||||||
def ConvertToQuantTypeFromAttrs : NativeCodeCall<
|
def ConvertToQuantTypeFromAttrs : NativeCodeCall<
|
||||||
"GetQuantizedTypeAttr($_builder, $0->getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">;
|
"GetQuantizedTypeAttr($_builder, $0.getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">;
|
||||||
|
|
||||||
// Converts an integer attribute $0 to 32-bit with builder.
|
// Converts an integer attribute $0 to 32-bit with builder.
|
||||||
def convertIntAttrTo32Bit : NativeCodeCall<
|
def convertIntAttrTo32Bit : NativeCodeCall<
|
||||||
@ -50,7 +50,7 @@ def ExtractSingleElementAsInteger : NativeCodeCall<
|
|||||||
"ExtractSingleElementAsInteger($_self.cast<ElementsAttr>())">;
|
"ExtractSingleElementAsInteger($_self.cast<ElementsAttr>())">;
|
||||||
|
|
||||||
// Checks whether the given operation has static shapes and same shapes of all inputs.
|
// Checks whether the given operation has static shapes and same shapes of all inputs.
|
||||||
def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0->getDefiningOp())">;
|
def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0.getDefiningOp())">;
|
||||||
def HasSameStaticShapes : Constraint<HasSameStaticShapesPred, "op must have static same input shapes">;
|
def HasSameStaticShapes : Constraint<HasSameStaticShapesPred, "op must have static same input shapes">;
|
||||||
def HasNotSameStaticShapes : Constraint<Neg<HasSameStaticShapesPred>, "op must have not static same input shapes">;
|
def HasNotSameStaticShapes : Constraint<Neg<HasSameStaticShapesPred>, "op must have not static same input shapes">;
|
||||||
|
|
||||||
|
@ -72,7 +72,7 @@ bool HasSameStaticShapes(Operation* op) {
|
|||||||
int index = 0;
|
int index = 0;
|
||||||
ArrayRef<int64_t> shape;
|
ArrayRef<int64_t> shape;
|
||||||
for (Value value : values) {
|
for (Value value : values) {
|
||||||
auto shaped_type = value->getType().dyn_cast<ShapedType>();
|
auto shaped_type = value.getType().dyn_cast<ShapedType>();
|
||||||
if (!shaped_type && !shaped_type.hasStaticShape()) {
|
if (!shaped_type && !shaped_type.hasStaticShape()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -122,7 +122,7 @@ PatternMatchResult ConvertTFConcatOp::matchAndRewrite(
|
|||||||
auto tf_concat_op = cast<TF::ConcatOp>(op);
|
auto tf_concat_op = cast<TF::ConcatOp>(op);
|
||||||
|
|
||||||
auto values = tf_concat_op.values();
|
auto values = tf_concat_op.values();
|
||||||
auto output_type = tf_concat_op.output()->getType();
|
auto output_type = tf_concat_op.output().getType();
|
||||||
// Extract axis attribute from constant concat_dims tensor
|
// Extract axis attribute from constant concat_dims tensor
|
||||||
ElementsAttr axis;
|
ElementsAttr axis;
|
||||||
if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis)))
|
if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis)))
|
||||||
@ -141,7 +141,7 @@ PatternMatchResult ConvertTFConcatV2Op::matchAndRewrite(
|
|||||||
auto tf_concat_op = cast<TF::ConcatV2Op>(op);
|
auto tf_concat_op = cast<TF::ConcatV2Op>(op);
|
||||||
|
|
||||||
auto values = tf_concat_op.values();
|
auto values = tf_concat_op.values();
|
||||||
auto output_type = tf_concat_op.output()->getType();
|
auto output_type = tf_concat_op.output().getType();
|
||||||
// Extract axis attribute from constant axis tensor
|
// Extract axis attribute from constant axis tensor
|
||||||
ElementsAttr axis;
|
ElementsAttr axis;
|
||||||
if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis)))
|
if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis)))
|
||||||
@ -167,7 +167,7 @@ PatternMatchResult ConvertTFMatMulOp::matchAndRewrite(
|
|||||||
if (tf_matmul_op.transpose_a()) return matchFailure();
|
if (tf_matmul_op.transpose_a()) return matchFailure();
|
||||||
if (!tf_matmul_op.transpose_b()) return matchFailure();
|
if (!tf_matmul_op.transpose_b()) return matchFailure();
|
||||||
|
|
||||||
Type output_type = tf_matmul_op.getResult()->getType();
|
Type output_type = tf_matmul_op.getResult().getType();
|
||||||
// TODO(jpienaar): Follow up post shuffle discussion.
|
// TODO(jpienaar): Follow up post shuffle discussion.
|
||||||
auto no_input = rewriter.create<ConstantOp>(
|
auto no_input = rewriter.create<ConstantOp>(
|
||||||
op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
|
op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
|
||||||
@ -184,7 +184,7 @@ PatternMatchResult ConvertTFPackOp::matchAndRewrite(
|
|||||||
auto tf_pack_op = cast<TF::PackOp>(op);
|
auto tf_pack_op = cast<TF::PackOp>(op);
|
||||||
|
|
||||||
SmallVector<Value, 4> values(tf_pack_op.values());
|
SmallVector<Value, 4> values(tf_pack_op.values());
|
||||||
auto output_type = tf_pack_op.output()->getType();
|
auto output_type = tf_pack_op.output().getType();
|
||||||
auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N());
|
auto values_count = rewriter.getI32IntegerAttr(tf_pack_op.N());
|
||||||
// Axis can be negative.
|
// Axis can be negative.
|
||||||
auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis().getSExtValue());
|
auto axis = rewriter.getI32IntegerAttr(tf_pack_op.axis().getSExtValue());
|
||||||
@ -201,7 +201,7 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
|
|||||||
auto input = tf_reshape_op.tensor();
|
auto input = tf_reshape_op.tensor();
|
||||||
auto shape = tf_reshape_op.shape();
|
auto shape = tf_reshape_op.shape();
|
||||||
|
|
||||||
ShapedType shape_type = shape->getType().cast<ShapedType>();
|
ShapedType shape_type = shape.getType().cast<ShapedType>();
|
||||||
// The tfl reshape's #2 operand needs to i32 tensor type, so we have to cast.
|
// The tfl reshape's #2 operand needs to i32 tensor type, so we have to cast.
|
||||||
if (!shape_type.getElementType().isInteger(32)) {
|
if (!shape_type.getElementType().isInteger(32)) {
|
||||||
auto new_shape = shape_type.getShape();
|
auto new_shape = shape_type.getShape();
|
||||||
@ -213,7 +213,7 @@ PatternMatchResult ConvertTFReshapeOp::matchAndRewrite(
|
|||||||
rewriter.getBoolAttr(false))
|
rewriter.getBoolAttr(false))
|
||||||
.y();
|
.y();
|
||||||
}
|
}
|
||||||
rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output()->getType(),
|
rewriter.replaceOpWithNewOp<ReshapeOp>(op, tf_reshape_op.output().getType(),
|
||||||
input, shape);
|
input, shape);
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
@ -222,7 +222,7 @@ PatternMatchResult ConvertTFSplitOp::matchAndRewrite(
|
|||||||
Operation* op, PatternRewriter& rewriter) const {
|
Operation* op, PatternRewriter& rewriter) const {
|
||||||
auto tf_split_op = cast<TF::SplitOp>(op);
|
auto tf_split_op = cast<TF::SplitOp>(op);
|
||||||
|
|
||||||
auto output_types = functional::map([](Value v) { return v->getType(); },
|
auto output_types = functional::map([](Value v) { return v.getType(); },
|
||||||
tf_split_op.output());
|
tf_split_op.output());
|
||||||
// Number of splits cannot be negative.
|
// Number of splits cannot be negative.
|
||||||
auto num_split = rewriter.getI32IntegerAttr(tf_split_op.num_split());
|
auto num_split = rewriter.getI32IntegerAttr(tf_split_op.num_split());
|
||||||
@ -237,7 +237,7 @@ PatternMatchResult ConvertTFSplitVOp::matchAndRewrite(
|
|||||||
Operation* op, PatternRewriter& rewriter) const {
|
Operation* op, PatternRewriter& rewriter) const {
|
||||||
auto tf_splitv_op = cast<TF::SplitVOp>(op);
|
auto tf_splitv_op = cast<TF::SplitVOp>(op);
|
||||||
|
|
||||||
auto output_types = functional::map([](Value v) { return v->getType(); },
|
auto output_types = functional::map([](Value v) { return v.getType(); },
|
||||||
tf_splitv_op.output());
|
tf_splitv_op.output());
|
||||||
// Number of splits cannot be negative.
|
// Number of splits cannot be negative.
|
||||||
auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.num_split());
|
auto num_split = rewriter.getI32IntegerAttr(tf_splitv_op.num_split());
|
||||||
@ -254,7 +254,7 @@ Value PadStridedSliceAttributeArray(Operation* op, PatternRewriter& rewriter,
|
|||||||
DenseIntElementsAttr dense_elem_attr;
|
DenseIntElementsAttr dense_elem_attr;
|
||||||
SmallVector<int32_t, 8> padded_val;
|
SmallVector<int32_t, 8> padded_val;
|
||||||
|
|
||||||
auto ranked_attr_type = attribute->getType().dyn_cast<RankedTensorType>();
|
auto ranked_attr_type = attribute.getType().dyn_cast<RankedTensorType>();
|
||||||
if (!ranked_attr_type ||
|
if (!ranked_attr_type ||
|
||||||
!matchPattern(attribute, m_Constant(&dense_elem_attr))) {
|
!matchPattern(attribute, m_Constant(&dense_elem_attr))) {
|
||||||
// If the input attribute is neither ranked type nor constant, we
|
// If the input attribute is neither ranked type nor constant, we
|
||||||
@ -280,14 +280,14 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite(
|
|||||||
Operation* op, PatternRewriter& rewriter) const {
|
Operation* op, PatternRewriter& rewriter) const {
|
||||||
auto tf_strided_slice_op = cast<TF::StridedSliceOp>(op);
|
auto tf_strided_slice_op = cast<TF::StridedSliceOp>(op);
|
||||||
auto ranked_input_type =
|
auto ranked_input_type =
|
||||||
tf_strided_slice_op.input()->getType().dyn_cast<RankedTensorType>();
|
tf_strided_slice_op.input().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!ranked_input_type) {
|
if (!ranked_input_type) {
|
||||||
// If input is not a ranked tensor, we can't deduce the padding dimensions
|
// If input is not a ranked tensor, we can't deduce the padding dimensions
|
||||||
// from it, so we just do a plain conversion here.
|
// from it, so we just do a plain conversion here.
|
||||||
rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
|
rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
|
||||||
op, tf_strided_slice_op.output()->getType(),
|
op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
|
||||||
tf_strided_slice_op.input(), tf_strided_slice_op.begin(),
|
tf_strided_slice_op.begin(), tf_strided_slice_op.end(),
|
||||||
tf_strided_slice_op.end(), tf_strided_slice_op.strides(),
|
tf_strided_slice_op.strides(),
|
||||||
rewriter.getI32IntegerAttr(
|
rewriter.getI32IntegerAttr(
|
||||||
tf_strided_slice_op.begin_mask().getSExtValue()),
|
tf_strided_slice_op.begin_mask().getSExtValue()),
|
||||||
rewriter.getI32IntegerAttr(
|
rewriter.getI32IntegerAttr(
|
||||||
@ -318,7 +318,7 @@ PatternMatchResult ConvertTFStridedSliceOp::matchAndRewrite(
|
|||||||
Value padded_strides = PadStridedSliceAttributeArray(
|
Value padded_strides = PadStridedSliceAttributeArray(
|
||||||
op, rewriter, tf_strided_slice_op.strides(), strides_pad_val, nullptr);
|
op, rewriter, tf_strided_slice_op.strides(), strides_pad_val, nullptr);
|
||||||
rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
|
rewriter.replaceOpWithNewOp<TFL::StridedSliceOp>(
|
||||||
op, tf_strided_slice_op.output()->getType(), tf_strided_slice_op.input(),
|
op, tf_strided_slice_op.output().getType(), tf_strided_slice_op.input(),
|
||||||
padded_begin, padded_end, padded_strides,
|
padded_begin, padded_end, padded_strides,
|
||||||
rewriter.getI32IntegerAttr(begin_mask),
|
rewriter.getI32IntegerAttr(begin_mask),
|
||||||
rewriter.getI32IntegerAttr(end_mask),
|
rewriter.getI32IntegerAttr(end_mask),
|
||||||
@ -336,7 +336,7 @@ PatternMatchResult ConvertTFUnpackOp::matchAndRewrite(
|
|||||||
auto tf_unpack_op = cast<TF::UnpackOp>(op);
|
auto tf_unpack_op = cast<TF::UnpackOp>(op);
|
||||||
|
|
||||||
auto input = tf_unpack_op.value();
|
auto input = tf_unpack_op.value();
|
||||||
auto output_types = functional::map([](Value v) { return v->getType(); },
|
auto output_types = functional::map([](Value v) { return v.getType(); },
|
||||||
tf_unpack_op.output());
|
tf_unpack_op.output());
|
||||||
auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num());
|
auto num = rewriter.getI32IntegerAttr(tf_unpack_op.num());
|
||||||
// Axis can be negative.
|
// Axis can be negative.
|
||||||
@ -360,7 +360,7 @@ bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) {
|
|||||||
if (tf_matrix_diag_v2_or_v3_op.getNumOperands() != 5) return false;
|
if (tf_matrix_diag_v2_or_v3_op.getNumOperands() != 5) return false;
|
||||||
|
|
||||||
auto input = tf_matrix_diag_v2_or_v3_op.diagonal();
|
auto input = tf_matrix_diag_v2_or_v3_op.diagonal();
|
||||||
auto output_type = tf_matrix_diag_v2_or_v3_op.output()->getType();
|
auto output_type = tf_matrix_diag_v2_or_v3_op.output().getType();
|
||||||
|
|
||||||
// Extract k constant tensor and check value = 0.
|
// Extract k constant tensor and check value = 0.
|
||||||
ElementsAttr k;
|
ElementsAttr k;
|
||||||
@ -500,7 +500,7 @@ PatternMatchResult ConvertTFReciprocalOp::matchAndRewrite(
|
|||||||
|
|
||||||
auto status_or_const_op = CreateConstOpWithSingleValue(
|
auto status_or_const_op = CreateConstOpWithSingleValue(
|
||||||
&rewriter, op->getLoc(),
|
&rewriter, op->getLoc(),
|
||||||
tf_reciprocal_op.x()->getType().cast<ShapedType>(), 1);
|
tf_reciprocal_op.x().getType().cast<ShapedType>(), 1);
|
||||||
if (!status_or_const_op.ok()) {
|
if (!status_or_const_op.ok()) {
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
}
|
}
|
||||||
|
@ -71,7 +71,7 @@ struct LoadQuantizationRecipe : public FunctionPass<LoadQuantizationRecipe> {
|
|||||||
|
|
||||||
void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
|
void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
|
||||||
Type expressed_type =
|
Type expressed_type =
|
||||||
lstm.input()->getType().cast<ShapedType>().getElementType();
|
lstm.input().getType().cast<ShapedType>().getElementType();
|
||||||
Type int8_storage_type = builder->getIntegerType(8);
|
Type int8_storage_type = builder->getIntegerType(8);
|
||||||
Type int16_storage_type = builder->getIntegerType(16);
|
Type int16_storage_type = builder->getIntegerType(16);
|
||||||
auto flag = quant::QuantizationFlags::FlagValue::Signed;
|
auto flag = quant::QuantizationFlags::FlagValue::Signed;
|
||||||
@ -88,8 +88,8 @@ void LoadQuantizationRecipe::Initialize(LSTMOp lstm, OpBuilder* builder) {
|
|||||||
auto any_int16 = quant::AnyQuantizedType::get(
|
auto any_int16 = quant::AnyQuantizedType::get(
|
||||||
flag, int16_storage_type, expressed_type, int16_min, int16_max);
|
flag, int16_storage_type, expressed_type, int16_min, int16_max);
|
||||||
|
|
||||||
int8 = any_int8.castFromExpressedType(lstm.input()->getType());
|
int8 = any_int8.castFromExpressedType(lstm.input().getType());
|
||||||
int16 = any_int16.castFromExpressedType(lstm.input()->getType());
|
int16 = any_int16.castFromExpressedType(lstm.input().getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value in,
|
Operation* LoadQuantizationRecipe::CreateLayerNorm(Location loc, Value in,
|
||||||
|
@ -196,13 +196,13 @@ struct ConvertTensorListSetItem : public ConversionPattern {
|
|||||||
// Calculate `index` + 1, which is used to generate the start position for
|
// Calculate `index` + 1, which is used to generate the start position for
|
||||||
// the second slice op.
|
// the second slice op.
|
||||||
auto suffix_start =
|
auto suffix_start =
|
||||||
rewriter.create<TF::AddOp>(loc, index->getType(), index,
|
rewriter.create<TF::AddOp>(loc, index.getType(), index,
|
||||||
CreateI32SplatConst(loc, &rewriter, {}, 1));
|
CreateI32SplatConst(loc, &rewriter, {}, 1));
|
||||||
|
|
||||||
auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
|
auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
|
||||||
loc, RankedTensorType::get({1}, shape_dtype), item_rank, scalar_zero);
|
loc, RankedTensorType::get({1}, shape_dtype), item_rank, scalar_zero);
|
||||||
// Create two slice ops.
|
// Create two slice ops.
|
||||||
Type element_type = input->getType().cast<TensorType>().getElementType();
|
Type element_type = input.getType().cast<TensorType>().getElementType();
|
||||||
UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type);
|
UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type);
|
||||||
Value scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1);
|
Value scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1);
|
||||||
TF::SliceOp slice1 =
|
TF::SliceOp slice1 =
|
||||||
@ -225,7 +225,7 @@ struct ConvertTensorListSetItem : public ConversionPattern {
|
|||||||
|
|
||||||
// Concatenate three parts together to generate the final result.
|
// Concatenate three parts together to generate the final result.
|
||||||
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
|
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
|
||||||
op, input->getType(), scalar_zero,
|
op, input.getType(), scalar_zero,
|
||||||
ArrayRef<Value>({slice1, expanded_item, slice2}));
|
ArrayRef<Value>({slice1, expanded_item, slice2}));
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
@ -264,7 +264,7 @@ struct ConvertTensorListInitOp : public ConversionPattern {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Value element_shape = operands[0];
|
Value element_shape = operands[0];
|
||||||
Type shape_dtype = getElementTypeOrSelf(element_shape->getType());
|
Type shape_dtype = getElementTypeOrSelf(element_shape.getType());
|
||||||
|
|
||||||
DenseIntElementsAttr dense_elem_attr;
|
DenseIntElementsAttr dense_elem_attr;
|
||||||
if (matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
|
if (matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
|
||||||
@ -297,11 +297,10 @@ struct ConvertTensorListInitOp : public ConversionPattern {
|
|||||||
new_element_shape_values.push_back(dim_value);
|
new_element_shape_values.push_back(dim_value);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto attr =
|
auto attr = DenseIntElementsAttr::get(
|
||||||
DenseIntElementsAttr::get(element_shape->getType().cast<ShapedType>(),
|
element_shape.getType().cast<ShapedType>(), new_element_shape_values);
|
||||||
new_element_shape_values);
|
|
||||||
auto new_element_shape = rewriter.create<ConstantOp>(
|
auto new_element_shape = rewriter.create<ConstantOp>(
|
||||||
op.getLoc(), element_shape->getType(), attr);
|
op.getLoc(), element_shape.getType(), attr);
|
||||||
element_shape = new_element_shape;
|
element_shape = new_element_shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -355,7 +354,7 @@ struct ConvertTensorListReserve
|
|||||||
Value GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value> operands,
|
Value GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value> operands,
|
||||||
PatternRewriter *rewriter) const override {
|
PatternRewriter *rewriter) const override {
|
||||||
Value scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0);
|
Value scalar_zero = CreateI32SplatConst(op.getLoc(), rewriter, {}, 0);
|
||||||
Type shape_dtype = getElementTypeOrSelf(op.element_shape()->getType());
|
Type shape_dtype = getElementTypeOrSelf(op.element_shape().getType());
|
||||||
Value num_elements = operands[1];
|
Value num_elements = operands[1];
|
||||||
return rewriter->create<TF::ExpandDimsOp>(
|
return rewriter->create<TF::ExpandDimsOp>(
|
||||||
op.getLoc(), RankedTensorType::get({1}, shape_dtype), num_elements,
|
op.getLoc(), RankedTensorType::get({1}, shape_dtype), num_elements,
|
||||||
@ -392,14 +391,14 @@ struct ConvertTensorListPushBack : public ConversionPattern {
|
|||||||
// Expand the shape of the item so that it will have rank same as the input
|
// Expand the shape of the item so that it will have rank same as the input
|
||||||
// tensor and it is compatible for the Concat Op.
|
// tensor and it is compatible for the Concat Op.
|
||||||
Type expanded_item_type =
|
Type expanded_item_type =
|
||||||
PrependLeadingDimIfRanked(1, item->getType(), &rewriter);
|
PrependLeadingDimIfRanked(1, item.getType(), &rewriter);
|
||||||
Value scalar_zero = CreateI32SplatConst(op->getLoc(), &rewriter, {}, 0);
|
Value scalar_zero = CreateI32SplatConst(op->getLoc(), &rewriter, {}, 0);
|
||||||
auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
|
auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
|
||||||
op->getLoc(), expanded_item_type, item, scalar_zero);
|
op->getLoc(), expanded_item_type, item, scalar_zero);
|
||||||
|
|
||||||
Type elem_type = getElementTypeOrSelf(item);
|
Type elem_type = getElementTypeOrSelf(item);
|
||||||
auto handle_dtype =
|
auto handle_dtype =
|
||||||
getElementTypeOrSelf(push_back_op.output_handle()->getType())
|
getElementTypeOrSelf(push_back_op.output_handle().getType())
|
||||||
.cast<TF::VariantType>();
|
.cast<TF::VariantType>();
|
||||||
Type result_type =
|
Type result_type =
|
||||||
GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
|
GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
|
||||||
@ -446,7 +445,7 @@ struct ConvertTensorListResize : public ConversionPattern {
|
|||||||
// Infer result type of this op based on TF's shape inference result.
|
// Infer result type of this op based on TF's shape inference result.
|
||||||
Type elem_type = getElementTypeOrSelf(input_handle);
|
Type elem_type = getElementTypeOrSelf(input_handle);
|
||||||
auto handle_dtype =
|
auto handle_dtype =
|
||||||
getElementTypeOrSelf(resize_op.output_handle()->getType())
|
getElementTypeOrSelf(resize_op.output_handle().getType())
|
||||||
.cast<TF::VariantType>();
|
.cast<TF::VariantType>();
|
||||||
Type result_type =
|
Type result_type =
|
||||||
GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
|
GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
|
||||||
@ -463,8 +462,8 @@ struct ConvertTensorListResize : public ConversionPattern {
|
|||||||
auto input_shape = rewriter.create<TF::ShapeOp>(
|
auto input_shape = rewriter.create<TF::ShapeOp>(
|
||||||
loc, RankedTensorType::get({-1}, shape_dtype), input_handle);
|
loc, RankedTensorType::get({-1}, shape_dtype), input_handle);
|
||||||
|
|
||||||
Type branch_args_type[] = {input_handle->getType(), input_shape.getType(),
|
Type branch_args_type[] = {input_handle.getType(), input_shape.getType(),
|
||||||
size_diff.getType(), size->getType()};
|
size_diff.getType(), size.getType()};
|
||||||
Type branch_result_type[] = {result_type};
|
Type branch_result_type[] = {result_type};
|
||||||
auto func_type = FunctionType::get(branch_args_type, branch_result_type,
|
auto func_type = FunctionType::get(branch_args_type, branch_result_type,
|
||||||
rewriter.getContext());
|
rewriter.getContext());
|
||||||
@ -524,7 +523,7 @@ struct ConvertTensorListResize : public ConversionPattern {
|
|||||||
loc, RankedTensorType::get({-1}, shape_dtype), input_shape, slice_start,
|
loc, RankedTensorType::get({-1}, shape_dtype), input_shape, slice_start,
|
||||||
slice_size);
|
slice_size);
|
||||||
auto extended_part = rewriter->create<TF::TensorListReserveOp>(
|
auto extended_part = rewriter->create<TF::TensorListReserveOp>(
|
||||||
loc, resize_op.output_handle()->getType(), elem_shape, size_diff);
|
loc, resize_op.output_handle().getType(), elem_shape, size_diff);
|
||||||
// `ConcatOp` expects non-variant-typed input. Insert a
|
// `ConcatOp` expects non-variant-typed input. Insert a
|
||||||
// `TensorListStackOp` here to convert type from variant to non-variant.
|
// `TensorListStackOp` here to convert type from variant to non-variant.
|
||||||
// Note that we are using the same `result_type` for both the
|
// Note that we are using the same `result_type` for both the
|
||||||
@ -627,7 +626,7 @@ struct ConvertTensorListStack : public ConversionPattern {
|
|||||||
// trivial Reshape op (that doesn't actually change the input's shape) and
|
// trivial Reshape op (that doesn't actually change the input's shape) and
|
||||||
// also populate the shape info to the op result. The shape of the
|
// also populate the shape info to the op result. The shape of the
|
||||||
// tensorlist is inferred from `num_elements` and `element_shape`.
|
// tensorlist is inferred from `num_elements` and `element_shape`.
|
||||||
auto ranked_type = element_shape->getType().dyn_cast<RankedTensorType>();
|
auto ranked_type = element_shape.getType().dyn_cast<RankedTensorType>();
|
||||||
DenseIntElementsAttr dense_elem_attr;
|
DenseIntElementsAttr dense_elem_attr;
|
||||||
if ((ranked_type && ranked_type.getRank() == 0) ||
|
if ((ranked_type && ranked_type.getRank() == 0) ||
|
||||||
!matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
|
!matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
|
||||||
@ -659,7 +658,7 @@ struct ConvertIdentity : public ConversionPattern {
|
|||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto op = llvm::cast<TF::IdentityOp>(operation);
|
auto op = llvm::cast<TF::IdentityOp>(operation);
|
||||||
Value input = operands[0];
|
Value input = operands[0];
|
||||||
rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input->getType(), operands,
|
rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), operands,
|
||||||
op.getAttrs());
|
op.getAttrs());
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
@ -687,7 +686,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
|
|||||||
Type arg_type = func_type.getInput(i);
|
Type arg_type = func_type.getInput(i);
|
||||||
if (getElementTypeOrSelf(arg_type).isa<TF::VariantType>()) {
|
if (getElementTypeOrSelf(arg_type).isa<TF::VariantType>()) {
|
||||||
arg_type = UnrankedTensorType::get(
|
arg_type = UnrankedTensorType::get(
|
||||||
getElementTypeOrSelf(op.getOperand(i)->getType()));
|
getElementTypeOrSelf(op.getOperand(i).getType()));
|
||||||
}
|
}
|
||||||
updated_argument_types.push_back(arg_type);
|
updated_argument_types.push_back(arg_type);
|
||||||
}
|
}
|
||||||
@ -703,7 +702,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
|
|||||||
// from the corresponding input operand. This is correct because while
|
// from the corresponding input operand. This is correct because while
|
||||||
// body's inputs and results have the same type.
|
// body's inputs and results have the same type.
|
||||||
result_type = UnrankedTensorType::get(
|
result_type = UnrankedTensorType::get(
|
||||||
getElementTypeOrSelf(op.getOperand(i)->getType()));
|
getElementTypeOrSelf(op.getOperand(i).getType()));
|
||||||
}
|
}
|
||||||
updated_result_types.push_back(result_type);
|
updated_result_types.push_back(result_type);
|
||||||
}
|
}
|
||||||
@ -717,7 +716,7 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
|
|||||||
// Change the argument type for the first block.
|
// Change the argument type for the first block.
|
||||||
Block &body_first_bb = func.front();
|
Block &body_first_bb = func.front();
|
||||||
for (int i = 0; i < body_first_bb.getNumArguments(); ++i) {
|
for (int i = 0; i < body_first_bb.getNumArguments(); ++i) {
|
||||||
body_first_bb.getArgument(i)->setType(updated_argument_types[i]);
|
body_first_bb.getArgument(i).setType(updated_argument_types[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
@ -735,12 +734,12 @@ struct ConvertWhile : public ConversionPattern {
|
|||||||
llvm::SmallVector<Type, 8> result_types;
|
llvm::SmallVector<Type, 8> result_types;
|
||||||
result_types.reserve(op.getNumOperands());
|
result_types.reserve(op.getNumOperands());
|
||||||
for (int i = 0, e = operands.size(); i != e; ++i) {
|
for (int i = 0, e = operands.size(); i != e; ++i) {
|
||||||
Type result_ty = op.getResult(i)->getType();
|
Type result_ty = op.getResult(i).getType();
|
||||||
|
|
||||||
// If we notice the result type is a DT_VARIANT, we change the
|
// If we notice the result type is a DT_VARIANT, we change the
|
||||||
// corresponding result type to unranked tensor type.
|
// corresponding result type to unranked tensor type.
|
||||||
if (getElementTypeOrSelf(result_ty).isa<TF::VariantType>()) {
|
if (getElementTypeOrSelf(result_ty).isa<TF::VariantType>()) {
|
||||||
Type element_ty = getElementTypeOrSelf(operands[i]->getType());
|
Type element_ty = getElementTypeOrSelf(operands[i].getType());
|
||||||
result_ty = UnrankedTensorType::get(element_ty);
|
result_ty = UnrankedTensorType::get(element_ty);
|
||||||
}
|
}
|
||||||
result_types.push_back(result_ty);
|
result_types.push_back(result_ty);
|
||||||
|
@ -51,15 +51,15 @@ namespace TFL {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
|
bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) {
|
||||||
if (sq_op->getType().cast<ShapedType>().getRank() - 1 ==
|
if (sq_op.getType().cast<ShapedType>().getRank() - 1 ==
|
||||||
*axis.getValues<int>().begin() ||
|
*axis.getValues<int>().begin() ||
|
||||||
*axis.getValues<int>().begin() == -1) {
|
*axis.getValues<int>().begin() == -1) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (sq_op->getType().cast<ShapedType>().getRank() != axis.getNumElements()) {
|
if (sq_op.getType().cast<ShapedType>().getRank() != axis.getNumElements()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto shape = sq_op->getType().cast<ShapedType>();
|
auto shape = sq_op.getType().cast<ShapedType>();
|
||||||
SmallVector<int, 4> elems{axis.getValues<int>().begin(),
|
SmallVector<int, 4> elems{axis.getValues<int>().begin(),
|
||||||
axis.getValues<int>().end()};
|
axis.getValues<int>().end()};
|
||||||
for (int i = 0; i < shape.getRank(); ++i) {
|
for (int i = 0; i < shape.getRank(); ++i) {
|
||||||
@ -143,7 +143,7 @@ ElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) {
|
|||||||
// Returns shape of a ranked tensor.
|
// Returns shape of a ranked tensor.
|
||||||
// Precondition: output_val's is ranked tensor.
|
// Precondition: output_val's is ranked tensor.
|
||||||
DenseElementsAttr GetShape(Value output_val) {
|
DenseElementsAttr GetShape(Value output_val) {
|
||||||
auto output_type = output_val->getType().cast<RankedTensorType>();
|
auto output_type = output_val.getType().cast<RankedTensorType>();
|
||||||
auto shape_vector = output_type.getShape();
|
auto shape_vector = output_type.getShape();
|
||||||
std::vector<int32_t> shape(shape_vector.size());
|
std::vector<int32_t> shape(shape_vector.size());
|
||||||
for (int i = 0; i < shape_vector.size(); ++i) {
|
for (int i = 0; i < shape_vector.size(); ++i) {
|
||||||
@ -152,7 +152,7 @@ DenseElementsAttr GetShape(Value output_val) {
|
|||||||
return mlir::DenseElementsAttr::get(
|
return mlir::DenseElementsAttr::get(
|
||||||
RankedTensorType::get(
|
RankedTensorType::get(
|
||||||
{static_cast<int>(shape.size())},
|
{static_cast<int>(shape.size())},
|
||||||
mlir::IntegerType::get(32, output_val->getContext())),
|
mlir::IntegerType::get(32, output_val.getContext())),
|
||||||
llvm::makeArrayRef(shape));
|
llvm::makeArrayRef(shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -173,13 +173,13 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern<TFL::AddOp> {
|
|||||||
|
|
||||||
// Fully Connected.
|
// Fully Connected.
|
||||||
auto fc_op =
|
auto fc_op =
|
||||||
dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs()->getDefiningOp());
|
dyn_cast_or_null<TFL::FullyConnectedOp>(add_op.lhs().getDefiningOp());
|
||||||
if (!fc_op) return matchFailure();
|
if (!fc_op) return matchFailure();
|
||||||
|
|
||||||
Value filter = fc_op.filter();
|
Value filter = fc_op.filter();
|
||||||
Value bias = fc_op.bias();
|
Value bias = fc_op.bias();
|
||||||
ElementsAttr bias_value;
|
ElementsAttr bias_value;
|
||||||
const bool is_none_bias = bias->getType().isa<NoneType>();
|
const bool is_none_bias = bias.getType().isa<NoneType>();
|
||||||
if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value)))
|
if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value)))
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
if (fc_op.fused_activation_function() != "NONE") return matchFailure();
|
if (fc_op.fused_activation_function() != "NONE") return matchFailure();
|
||||||
@ -213,7 +213,7 @@ struct FuseFullyConnectedAndRelu : public OpRewritePattern<TFL::ReluOp> {
|
|||||||
|
|
||||||
PatternMatchResult matchAndRewrite(TFL::ReluOp relu_op,
|
PatternMatchResult matchAndRewrite(TFL::ReluOp relu_op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Operation *input = relu_op.getOperand()->getDefiningOp();
|
Operation *input = relu_op.getOperand().getDefiningOp();
|
||||||
if (!isa_and_nonnull<FullyConnectedOp>(input)) return matchFailure();
|
if (!isa_and_nonnull<FullyConnectedOp>(input)) return matchFailure();
|
||||||
auto fully_connected_op = cast<FullyConnectedOp>(input);
|
auto fully_connected_op = cast<FullyConnectedOp>(input);
|
||||||
if (fully_connected_op.fused_activation_function() != "NONE")
|
if (fully_connected_op.fused_activation_function() != "NONE")
|
||||||
@ -247,13 +247,13 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
|
|||||||
|
|
||||||
// Fully Connected.
|
// Fully Connected.
|
||||||
auto fc_op =
|
auto fc_op =
|
||||||
dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs()->getDefiningOp());
|
dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs().getDefiningOp());
|
||||||
if (!fc_op) return matchFailure();
|
if (!fc_op) return matchFailure();
|
||||||
Value filter = fc_op.filter();
|
Value filter = fc_op.filter();
|
||||||
Value bias = fc_op.bias();
|
Value bias = fc_op.bias();
|
||||||
ElementsAttr cst_tmp;
|
ElementsAttr cst_tmp;
|
||||||
if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure();
|
if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure();
|
||||||
if (!bias->getType().isa<NoneType>() &&
|
if (!bias.getType().isa<NoneType>() &&
|
||||||
!matchPattern(bias, m_Constant(&cst_tmp)))
|
!matchPattern(bias, m_Constant(&cst_tmp)))
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
if (fc_op.fused_activation_function().equals("None")) return matchFailure();
|
if (fc_op.fused_activation_function().equals("None")) return matchFailure();
|
||||||
@ -262,7 +262,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
|
|||||||
// filter input. We only support broadcasting the operand along the depth
|
// filter input. We only support broadcasting the operand along the depth
|
||||||
// dimension, when the operand's depth is 1.
|
// dimension, when the operand's depth is 1.
|
||||||
Value new_const_val = constant_val;
|
Value new_const_val = constant_val;
|
||||||
if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter->getType())) {
|
if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter.getType())) {
|
||||||
auto original_shape = cst.getType().getShape();
|
auto original_shape = cst.getType().getShape();
|
||||||
llvm::SmallVector<int64_t, 4> normalized_shape(original_shape.begin(),
|
llvm::SmallVector<int64_t, 4> normalized_shape(original_shape.begin(),
|
||||||
original_shape.end());
|
original_shape.end());
|
||||||
@ -270,7 +270,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
|
|||||||
auto new_cst = cst.reshape(RankedTensorType::get(
|
auto new_cst = cst.reshape(RankedTensorType::get(
|
||||||
normalized_shape, cst.getType().getElementType()));
|
normalized_shape, cst.getType().getElementType()));
|
||||||
Type new_type = new_cst.getType();
|
Type new_type = new_cst.getType();
|
||||||
if (!IsBroadcastableElementsAttrAndType(new_type, filter->getType())) {
|
if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
}
|
}
|
||||||
auto new_op =
|
auto new_op =
|
||||||
@ -285,7 +285,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
|
|||||||
auto new_filter =
|
auto new_filter =
|
||||||
rewriter.create<TF::MulOp>(loc, filter, new_const_val).z();
|
rewriter.create<TF::MulOp>(loc, filter, new_const_val).z();
|
||||||
// If bias isn't None, it needs to be multiplied as well.
|
// If bias isn't None, it needs to be multiplied as well.
|
||||||
if (!bias->getType().isa<NoneType>()) {
|
if (!bias.getType().isa<NoneType>()) {
|
||||||
bias = rewriter.create<TF::MulOp>(loc, bias, constant_val).z();
|
bias = rewriter.create<TF::MulOp>(loc, bias, constant_val).z();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -311,7 +311,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
|||||||
PatternMatchResult matchAndRewrite(AffineOpType fc_op,
|
PatternMatchResult matchAndRewrite(AffineOpType fc_op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
// Binary op.
|
// Binary op.
|
||||||
Operation *binary_op = fc_op.input()->getDefiningOp();
|
Operation *binary_op = fc_op.input().getDefiningOp();
|
||||||
if (!binary_op || binary_op->getNumOperands() != 2)
|
if (!binary_op || binary_op->getNumOperands() != 2)
|
||||||
return this->matchFailure();
|
return this->matchFailure();
|
||||||
// We only handle the cases the RHS is a scalar.
|
// We only handle the cases the RHS is a scalar.
|
||||||
@ -330,15 +330,15 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
|||||||
DenseFPElementsAttr filter_cst, bias_cst;
|
DenseFPElementsAttr filter_cst, bias_cst;
|
||||||
if (!matchPattern(filter, m_Constant(&filter_cst))) {
|
if (!matchPattern(filter, m_Constant(&filter_cst))) {
|
||||||
// The filter maybe quantized, then we should set it to the real constant.
|
// The filter maybe quantized, then we should set it to the real constant.
|
||||||
auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter->getDefiningOp());
|
auto dq = llvm::dyn_cast_or_null<DequantizeOp>(filter.getDefiningOp());
|
||||||
if (!dq) return this->matchFailure();
|
if (!dq) return this->matchFailure();
|
||||||
auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input()->getDefiningOp());
|
auto q = llvm::dyn_cast_or_null<QuantizeOp>(dq.input().getDefiningOp());
|
||||||
if (!q || !matchPattern(q.input(), m_Constant(&filter_cst))) {
|
if (!q || !matchPattern(q.input(), m_Constant(&filter_cst))) {
|
||||||
return this->matchFailure();
|
return this->matchFailure();
|
||||||
}
|
}
|
||||||
filter = q.input();
|
filter = q.input();
|
||||||
}
|
}
|
||||||
if (!bias->getType().isa<NoneType>() &&
|
if (!bias.getType().isa<NoneType>() &&
|
||||||
!matchPattern(bias, m_Constant(&bias_cst)))
|
!matchPattern(bias, m_Constant(&bias_cst)))
|
||||||
return this->matchFailure();
|
return this->matchFailure();
|
||||||
ShapedType filter_type = filter_cst.getType();
|
ShapedType filter_type = filter_cst.getType();
|
||||||
@ -362,7 +362,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
|||||||
// The new bias should be a 1-D tensor with length equals to the bias
|
// The new bias should be a 1-D tensor with length equals to the bias
|
||||||
// dimension of the weight.
|
// dimension of the weight.
|
||||||
SmallVector<APFloat, 4> new_bias_values;
|
SmallVector<APFloat, 4> new_bias_values;
|
||||||
if (bias->getType().isa<NoneType>()) { // none bias, a list of zeros
|
if (bias.getType().isa<NoneType>()) { // none bias, a list of zeros
|
||||||
new_bias_values.resize(bias_size, APFloat(0.0));
|
new_bias_values.resize(bias_size, APFloat(0.0));
|
||||||
} else if (bias_cst.getNumElements() == 1) { // scalar bias, broadcast it
|
} else if (bias_cst.getNumElements() == 1) { // scalar bias, broadcast it
|
||||||
new_bias_values.resize(bias_size, *bias_cst.float_value_begin());
|
new_bias_values.resize(bias_size, *bias_cst.float_value_begin());
|
||||||
@ -401,12 +401,12 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
|||||||
// We recreate the constant op in case it is shared by the other ops. This
|
// We recreate the constant op in case it is shared by the other ops. This
|
||||||
// might increase the model size.
|
// might increase the model size.
|
||||||
auto new_filter_op = rewriter.create<ConstOp>(
|
auto new_filter_op = rewriter.create<ConstOp>(
|
||||||
fc_op.getLoc(), filter->getType(), new_filter);
|
fc_op.getLoc(), filter.getType(), new_filter);
|
||||||
fc_op.setOperand(0, binary_op->getOperand(0));
|
fc_op.setOperand(0, binary_op->getOperand(0));
|
||||||
if (fc_op.filter() != filter) {
|
if (fc_op.filter() != filter) {
|
||||||
// This filter goes through quantize and dequantize ops. Then we just
|
// This filter goes through quantize and dequantize ops. Then we just
|
||||||
// need to update the weight to the quantize op.
|
// need to update the weight to the quantize op.
|
||||||
filter->replaceAllUsesWith(new_filter_op);
|
filter.replaceAllUsesWith(new_filter_op);
|
||||||
} else {
|
} else {
|
||||||
// This filter doesn't go through quantize and dequantize ops, Then
|
// This filter doesn't go through quantize and dequantize ops, Then
|
||||||
// we update the weight of the affine op directly.
|
// we update the weight of the affine op directly.
|
||||||
|
@ -55,7 +55,7 @@ foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
|
|||||||
defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;
|
defm : FuseActFnIntoConvOpPat<actFnPair[0], actFnPair[1]>;
|
||||||
|
|
||||||
// Checks if the value has only one user.
|
// Checks if the value has only one user.
|
||||||
def HasOneUse : Constraint<CPred<"$0->hasOneUse()">>;
|
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
|
||||||
|
|
||||||
// If we see a binary op (add, sub) op adding a constant value to a convolution
|
// If we see a binary op (add, sub) op adding a constant value to a convolution
|
||||||
// op with constant bias, we can fuse the binary op into the convolution op by
|
// op with constant bias, we can fuse the binary op into the convolution op by
|
||||||
@ -161,7 +161,7 @@ def EqualOperands : Constraint<CPred<"$0 == $1">>;
|
|||||||
|
|
||||||
// Checks if the operand has rank == n
|
// Checks if the operand has rank == n
|
||||||
class OperandHasRank<int n> : Constraint<
|
class OperandHasRank<int n> : Constraint<
|
||||||
CPred<"$0->getType().cast<ShapedType>().getRank() == " # n>>;
|
CPred<"$0.getType().cast<ShapedType>().getRank() == " # n>>;
|
||||||
|
|
||||||
// Matching HardSwish
|
// Matching HardSwish
|
||||||
def : Pat<
|
def : Pat<
|
||||||
@ -256,7 +256,7 @@ foreach L2NormalizePairs = [[TFL_MulOp, TFL_RsqrtOp], [TFL_DivOp, TFL_SqrtOp]]
|
|||||||
in defm : L2NormalizePatterns<L2NormalizePairs[0], L2NormalizePairs[1]>;
|
in defm : L2NormalizePatterns<L2NormalizePairs[0], L2NormalizePairs[1]>;
|
||||||
|
|
||||||
def AreBroadcastableTypes : Constraint<CPred<
|
def AreBroadcastableTypes : Constraint<CPred<
|
||||||
"TFL::IsBroadcastableElementsAttrAndType($0->getType(), $1->getType())">>;
|
"TFL::IsBroadcastableElementsAttrAndType($0.getType(), $1.getType())">>;
|
||||||
|
|
||||||
// Pattern for skipping Tile if it is mainly for broadcasting and the
|
// Pattern for skipping Tile if it is mainly for broadcasting and the
|
||||||
// Op is already supporting broadcasting.
|
// Op is already supporting broadcasting.
|
||||||
|
@ -71,29 +71,29 @@ void RemoveQuantizationAdaptorOps(FuncOp func) {
|
|||||||
|
|
||||||
auto remove_quantize_op = [&](QuantizeOp quantize_op) {
|
auto remove_quantize_op = [&](QuantizeOp quantize_op) {
|
||||||
auto quantize_output = quantize_op.output();
|
auto quantize_output = quantize_op.output();
|
||||||
auto quantize_type = quantize_output->getType();
|
auto quantize_type = quantize_output.getType();
|
||||||
input_types.push_back(quantize_type);
|
input_types.push_back(quantize_type);
|
||||||
auto new_arg = bb.addArgument(quantize_type);
|
auto new_arg = bb.addArgument(quantize_type);
|
||||||
quantize_output->replaceAllUsesWith(new_arg);
|
quantize_output.replaceAllUsesWith(new_arg);
|
||||||
quantize_op.erase();
|
quantize_op.erase();
|
||||||
arg->dropAllUses();
|
arg.dropAllUses();
|
||||||
bb.eraseArgument(0);
|
bb.eraseArgument(0);
|
||||||
};
|
};
|
||||||
|
|
||||||
// This is looking for a pattern: arg -> tfl.quantize
|
// This is looking for a pattern: arg -> tfl.quantize
|
||||||
if (arg->hasOneUse() && llvm::isa<QuantizeOp>(*arg->user_begin())) {
|
if (arg.hasOneUse() && llvm::isa<QuantizeOp>(*arg.user_begin())) {
|
||||||
auto quantize_op = llvm::cast<QuantizeOp>(*arg->user_begin());
|
auto quantize_op = llvm::cast<QuantizeOp>(*arg.user_begin());
|
||||||
remove_quantize_op(quantize_op);
|
remove_quantize_op(quantize_op);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make a copy of current argument and append it to the end of the list if
|
// Make a copy of current argument and append it to the end of the list if
|
||||||
// the pattern isn't found.
|
// the pattern isn't found.
|
||||||
Type arg_type = arg->getType();
|
Type arg_type = arg.getType();
|
||||||
input_types.push_back(arg_type);
|
input_types.push_back(arg_type);
|
||||||
auto new_arg = bb.addArgument(arg_type);
|
auto new_arg = bb.addArgument(arg_type);
|
||||||
arg->replaceAllUsesWith(new_arg);
|
arg.replaceAllUsesWith(new_arg);
|
||||||
arg->dropAllUses();
|
arg.dropAllUses();
|
||||||
bb.eraseArgument(0);
|
bb.eraseArgument(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -103,15 +103,15 @@ void RemoveQuantizationAdaptorOps(FuncOp func) {
|
|||||||
output_types.reserve(num_return_operands);
|
output_types.reserve(num_return_operands);
|
||||||
for (int i = 0; i != num_return_operands; ++i) {
|
for (int i = 0; i != num_return_operands; ++i) {
|
||||||
auto returned_value = terminator->getOperand(i);
|
auto returned_value = terminator->getOperand(i);
|
||||||
Operation* returned_op = returned_value->getDefiningOp();
|
Operation* returned_op = returned_value.getDefiningOp();
|
||||||
if (returned_op && llvm::isa<DequantizeOp>(returned_op)) {
|
if (returned_op && llvm::isa<DequantizeOp>(returned_op)) {
|
||||||
auto dequantize_op = llvm::cast<DequantizeOp>(returned_op);
|
auto dequantize_op = llvm::cast<DequantizeOp>(returned_op);
|
||||||
Value dequantized_result = dequantize_op.input();
|
Value dequantized_result = dequantize_op.input();
|
||||||
output_types.push_back(dequantized_result->getType());
|
output_types.push_back(dequantized_result.getType());
|
||||||
terminator->setOperand(i, dequantized_result);
|
terminator->setOperand(i, dequantized_result);
|
||||||
returned_op->erase();
|
returned_op->erase();
|
||||||
} else {
|
} else {
|
||||||
output_types.push_back(returned_value->getType());
|
output_types.push_back(returned_value.getType());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto new_func_type = builder.getFunctionType(input_types, output_types);
|
auto new_func_type = builder.getFunctionType(input_types, output_types);
|
||||||
|
@ -135,10 +135,10 @@ def : Pat<(TF_ReshapeOp
|
|||||||
// Casts result type of $1 to a quantized type by using the quantization
|
// Casts result type of $1 to a quantized type by using the quantization
|
||||||
// parameters from the type in $0.
|
// parameters from the type in $0.
|
||||||
class UpdateShapeWithAxis<int i> : NativeCodeCall<
|
class UpdateShapeWithAxis<int i> : NativeCodeCall<
|
||||||
"CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1->getType(), " # i # ")">;
|
"CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1.getType(), " # i # ")">;
|
||||||
|
|
||||||
class UsedBy<string op> : Constraint<
|
class UsedBy<string op> : Constraint<
|
||||||
CPred<"llvm::isa<mlir::TFL::" # op # "Op>(*$0->getUsers().begin())">>;
|
CPred<"llvm::isa<mlir::TFL::" # op # "Op>(*$0.getUsers().begin())">>;
|
||||||
|
|
||||||
// When the op is passing-through, the output types of the quantized ops need
|
// When the op is passing-through, the output types of the quantized ops need
|
||||||
// to be updated as well. Since the quantize op manages its own type by the
|
// to be updated as well. Since the quantize op manages its own type by the
|
||||||
|
@ -153,7 +153,7 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
|
|||||||
params);
|
params);
|
||||||
auto dq_op =
|
auto dq_op =
|
||||||
builder.create<TFL::DequantizeOp>(loc, input_type, q_op.output());
|
builder.create<TFL::DequantizeOp>(loc, input_type, q_op.output());
|
||||||
arg->replaceAllUsesWith(dq_op.output());
|
arg.replaceAllUsesWith(dq_op.output());
|
||||||
q_op.setOperand(arg);
|
q_op.setOperand(arg);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -161,8 +161,8 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
|
|||||||
|
|
||||||
for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
|
for (int i = 0, e = func.getNumArguments(); i != e; ++i) {
|
||||||
BlockArgument arg = func.getArgument(i);
|
BlockArgument arg = func.getArgument(i);
|
||||||
auto* arg_block = arg->getOwner();
|
auto* arg_block = arg.getOwner();
|
||||||
add_quantize_op(arg->getLoc(), arg->getType(), arg_block,
|
add_quantize_op(arg.getLoc(), arg.getType(), arg_block,
|
||||||
std::next(arg_block->begin(), i), arg, i);
|
std::next(arg_block->begin(), i), arg, i);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -115,7 +115,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
|
|||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
// We don't want to insert quantize/dequantize if the quantize op exists.
|
// We don't want to insert quantize/dequantize if the quantize op exists.
|
||||||
auto res = tf_op.outputs();
|
auto res = tf_op.outputs();
|
||||||
if (!res->hasOneUse() || isa<QuantizeOp>(*res->user_begin()))
|
if (!res.hasOneUse() || isa<QuantizeOp>(*res.user_begin()))
|
||||||
return this->matchFailure();
|
return this->matchFailure();
|
||||||
|
|
||||||
// Extract the min/max constant values from the operands. We also consider
|
// Extract the min/max constant values from the operands. We also consider
|
||||||
@ -123,9 +123,9 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
|
|||||||
// constants and the tf.FakeQuantWithMinMaxVarsOp.
|
// constants and the tf.FakeQuantWithMinMaxVarsOp.
|
||||||
Value min = tf_op.min(), max = tf_op.max();
|
Value min = tf_op.min(), max = tf_op.max();
|
||||||
DenseFPElementsAttr min_value, max_value;
|
DenseFPElementsAttr min_value, max_value;
|
||||||
if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min->getDefiningOp()))
|
if (auto id1 = dyn_cast_or_null<TF::IdentityOp>(min.getDefiningOp()))
|
||||||
min = id1.input();
|
min = id1.input();
|
||||||
if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max->getDefiningOp()))
|
if (auto id2 = dyn_cast_or_null<TF::IdentityOp>(max.getDefiningOp()))
|
||||||
max = id2.input();
|
max = id2.input();
|
||||||
if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure();
|
if (!matchPattern(min, m_Constant(&min_value))) return this->matchFailure();
|
||||||
if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure();
|
if (!matchPattern(max, m_Constant(&max_value))) return this->matchFailure();
|
||||||
@ -133,7 +133,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
|
|||||||
int quant_dim = -1;
|
int quant_dim = -1;
|
||||||
if (PerAxis) {
|
if (PerAxis) {
|
||||||
// This is a special case that the quant_dim is the last dimensions.
|
// This is a special case that the quant_dim is the last dimensions.
|
||||||
quant_dim = res->getType().template cast<ShapedType>().getRank() - 1;
|
quant_dim = res.getType().template cast<ShapedType>().getRank() - 1;
|
||||||
}
|
}
|
||||||
// Use the min/max from the operands and the num_bits and narrow_range
|
// Use the min/max from the operands and the num_bits and narrow_range
|
||||||
// attribute to create the quantization parameter for the new quantize op.
|
// attribute to create the quantization parameter for the new quantize op.
|
||||||
@ -155,7 +155,7 @@ struct InsertTFLQuantOpsAfterTFFakeQuantOp
|
|||||||
tf_op.getLoc(), qtype.getValue(), value, qtype);
|
tf_op.getLoc(), qtype.getValue(), value, qtype);
|
||||||
auto dequantize = rewriter.create<TFL::DequantizeOp>(
|
auto dequantize = rewriter.create<TFL::DequantizeOp>(
|
||||||
tf_op.getLoc(), res_type, quantize.output());
|
tf_op.getLoc(), res_type, quantize.output());
|
||||||
value->replaceAllUsesWith(dequantize);
|
value.replaceAllUsesWith(dequantize);
|
||||||
quantize.getOperation()->replaceUsesOfWith(dequantize, value);
|
quantize.getOperation()->replaceUsesOfWith(dequantize, value);
|
||||||
|
|
||||||
return this->matchSuccess();
|
return this->matchSuccess();
|
||||||
@ -240,7 +240,7 @@ struct ConvertTFConvOp : public RewritePattern {
|
|||||||
// that we can extract info from the shape (e.g., for constructing bias
|
// that we can extract info from the shape (e.g., for constructing bias
|
||||||
// tensor, for setting depth_multiplier attribute, etc.).
|
// tensor, for setting depth_multiplier attribute, etc.).
|
||||||
auto filter_type =
|
auto filter_type =
|
||||||
tf_op.filter()->getType().template dyn_cast<RankedTensorType>();
|
tf_op.filter().getType().template dyn_cast<RankedTensorType>();
|
||||||
if (filter_type && filter_type.getRank() == 4)
|
if (filter_type && filter_type.getRank() == 4)
|
||||||
return matchSuccess(std::move(state));
|
return matchSuccess(std::move(state));
|
||||||
|
|
||||||
@ -262,7 +262,7 @@ struct ConvertTFConvOp : public RewritePattern {
|
|||||||
|
|
||||||
// Get a splat zero tensor with the expected dimension for the bias tensor
|
// Get a splat zero tensor with the expected dimension for the bias tensor
|
||||||
auto filter = tf_op.filter();
|
auto filter = tf_op.filter();
|
||||||
auto filter_type = filter->getType().template cast<RankedTensorType>();
|
auto filter_type = filter.getType().template cast<RankedTensorType>();
|
||||||
auto elem_type = filter_type.getElementType();
|
auto elem_type = filter_type.getElementType();
|
||||||
auto bias_dim = static_cast<const ConcreteType *>(this)->getBiasDim(
|
auto bias_dim = static_cast<const ConcreteType *>(this)->getBiasDim(
|
||||||
filter_type.getShape());
|
filter_type.getShape());
|
||||||
@ -323,7 +323,7 @@ class ConvertTFConv2D : public ConvertTFConvOp<ConvertTFConv2D, TF::Conv2DOp> {
|
|||||||
auto perm_op = rewriter.create<TF::ConstOp>(loc, perm_type, perm_attr);
|
auto perm_op = rewriter.create<TF::ConstOp>(loc, perm_type, perm_attr);
|
||||||
|
|
||||||
// Create tensor type for the transpose result.
|
// Create tensor type for the transpose result.
|
||||||
auto filter_type = filter->getType().cast<RankedTensorType>();
|
auto filter_type = filter.getType().cast<RankedTensorType>();
|
||||||
auto result_shape = functional::map(
|
auto result_shape = functional::map(
|
||||||
[filter_type](int64_t dim) { return filter_type.getDimSize(dim); },
|
[filter_type](int64_t dim) { return filter_type.getDimSize(dim); },
|
||||||
perm);
|
perm);
|
||||||
@ -356,7 +356,7 @@ class ConvertTFDepthwiseConv2dNative
|
|||||||
// have a corresponding 'depth_multiplier' attribute; the multiplier is the
|
// have a corresponding 'depth_multiplier' attribute; the multiplier is the
|
||||||
// fourth dimension in the 4-D filter tensor. We query the multiplier from
|
// fourth dimension in the 4-D filter tensor. We query the multiplier from
|
||||||
// tf.DepthwiseConv2dNative and set it as the attribute value accordingly.
|
// tf.DepthwiseConv2dNative and set it as the attribute value accordingly.
|
||||||
auto multiplier = filter->getType().cast<RankedTensorType>().getDimSize(3);
|
auto multiplier = filter.getType().cast<RankedTensorType>().getDimSize(3);
|
||||||
|
|
||||||
filter = legalizeFilter(rewriter, loc, filter);
|
filter = legalizeFilter(rewriter, loc, filter);
|
||||||
return rewriter.create<TFL::DepthwiseConv2DOp>(
|
return rewriter.create<TFL::DepthwiseConv2DOp>(
|
||||||
@ -380,7 +380,7 @@ class ConvertTFDepthwiseConv2dNative
|
|||||||
/// RankedTensorType.
|
/// RankedTensorType.
|
||||||
Value legalizeFilter(PatternRewriter &rewriter, Location loc,
|
Value legalizeFilter(PatternRewriter &rewriter, Location loc,
|
||||||
Value filter) const {
|
Value filter) const {
|
||||||
auto filter_type = filter->getType().cast<RankedTensorType>();
|
auto filter_type = filter.getType().cast<RankedTensorType>();
|
||||||
auto filterShape = filter_type.getShape();
|
auto filterShape = filter_type.getShape();
|
||||||
SmallVector<int64_t, 4> result_shape = {1, filterShape[0], filterShape[1],
|
SmallVector<int64_t, 4> result_shape = {1, filterShape[0], filterShape[1],
|
||||||
filterShape[2] * filterShape[3]};
|
filterShape[2] * filterShape[3]};
|
||||||
@ -432,11 +432,11 @@ struct ConvertTFStridedSlice : public RewritePattern {
|
|||||||
// Insert a new reshape op.
|
// Insert a new reshape op.
|
||||||
Value original_input = strided_slice_op.input();
|
Value original_input = strided_slice_op.input();
|
||||||
RankedTensorType original_input_type =
|
RankedTensorType original_input_type =
|
||||||
original_input->getType().cast<RankedTensorType>();
|
original_input.getType().cast<RankedTensorType>();
|
||||||
const ArrayRef<int64_t> &original_input_shape =
|
const ArrayRef<int64_t> &original_input_shape =
|
||||||
original_input_type.getShape();
|
original_input_type.getShape();
|
||||||
RankedTensorType begin_type =
|
RankedTensorType begin_type =
|
||||||
strided_slice_op.begin()->getType().cast<RankedTensorType>();
|
strided_slice_op.begin().getType().cast<RankedTensorType>();
|
||||||
const int dim_size = begin_type.getShape()[0];
|
const int dim_size = begin_type.getShape()[0];
|
||||||
SmallVector<int64_t, 4> new_shape;
|
SmallVector<int64_t, 4> new_shape;
|
||||||
int mask = 1;
|
int mask = 1;
|
||||||
|
@ -83,7 +83,7 @@ LogicalResult DuplicateValueIfNeeded(Operation* op,
|
|||||||
// We can only clone the constant op at this point.
|
// We can only clone the constant op at this point.
|
||||||
// Since all ops have been legalized to tflite ops, so we only care about
|
// Since all ops have been legalized to tflite ops, so we only care about
|
||||||
// ConstOp or QConstOp or mlir constant op/
|
// ConstOp or QConstOp or mlir constant op/
|
||||||
Operation* input_op = operand->getDefiningOp();
|
Operation* input_op = operand.getDefiningOp();
|
||||||
if (input_op == nullptr) return failure();
|
if (input_op == nullptr) return failure();
|
||||||
|
|
||||||
Attribute attr;
|
Attribute attr;
|
||||||
|
@ -83,7 +83,7 @@ TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
|
|||||||
template <typename BatchMatMulOpType>
|
template <typename BatchMatMulOpType>
|
||||||
std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
|
std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
|
||||||
Value value, int batch_size, Location loc, PatternRewriter& rewriter) {
|
Value value, int batch_size, Location loc, PatternRewriter& rewriter) {
|
||||||
RankedTensorType tensorType = value->getType().cast<RankedTensorType>();
|
RankedTensorType tensorType = value.getType().cast<RankedTensorType>();
|
||||||
Type element_type = tensorType.getElementType();
|
Type element_type = tensorType.getElementType();
|
||||||
|
|
||||||
int rank = tensorType.getShape().size();
|
int rank = tensorType.getShape().size();
|
||||||
@ -127,7 +127,7 @@ std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
|
|||||||
template <typename BatchMatMulOpType>
|
template <typename BatchMatMulOpType>
|
||||||
TF::TransposeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createTransposeOp(
|
TF::TransposeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createTransposeOp(
|
||||||
Value value, Location loc, PatternRewriter& rewriter) {
|
Value value, Location loc, PatternRewriter& rewriter) {
|
||||||
auto value_type = value->getType().cast<RankedTensorType>();
|
auto value_type = value.getType().cast<RankedTensorType>();
|
||||||
auto shape = value_type.getShape();
|
auto shape = value_type.getShape();
|
||||||
int dims = shape.size();
|
int dims = shape.size();
|
||||||
|
|
||||||
@ -197,17 +197,17 @@ PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
|
|||||||
Value input_lhs = op.x();
|
Value input_lhs = op.x();
|
||||||
Value input_rhs = op.y();
|
Value input_rhs = op.y();
|
||||||
|
|
||||||
if (!input_lhs->getType().isa<RankedTensorType>()) {
|
if (!input_lhs.getType().isa<RankedTensorType>()) {
|
||||||
// LHS must be a ranked tensor type
|
// LHS must be a ranked tensor type
|
||||||
return this->matchFailure();
|
return this->matchFailure();
|
||||||
}
|
}
|
||||||
if (!input_rhs->getType().isa<RankedTensorType>()) {
|
if (!input_rhs.getType().isa<RankedTensorType>()) {
|
||||||
// RHS must be a ranked tensor type
|
// RHS must be a ranked tensor type
|
||||||
return this->matchFailure();
|
return this->matchFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto lhs_type = input_lhs->getType().cast<RankedTensorType>();
|
auto lhs_type = input_lhs.getType().cast<RankedTensorType>();
|
||||||
auto rhs_type = input_rhs->getType().cast<RankedTensorType>();
|
auto rhs_type = input_rhs.getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
auto element_type = lhs_type.getElementType();
|
auto element_type = lhs_type.getElementType();
|
||||||
|
|
||||||
@ -233,7 +233,7 @@ PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
|
|||||||
if (op.adj_x()) {
|
if (op.adj_x()) {
|
||||||
input_lhs = createTransposeOp(input_lhs, loc, rewriter);
|
input_lhs = createTransposeOp(input_lhs, loc, rewriter);
|
||||||
|
|
||||||
lhs_type = input_lhs->getType().cast<RankedTensorType>();
|
lhs_type = input_lhs.getType().cast<RankedTensorType>();
|
||||||
lhs_shape = lhs_type.getShape();
|
lhs_shape = lhs_type.getShape();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -241,7 +241,7 @@ PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
|
|||||||
if (op.adj_y()) {
|
if (op.adj_y()) {
|
||||||
input_rhs = createTransposeOp(input_rhs, loc, rewriter);
|
input_rhs = createTransposeOp(input_rhs, loc, rewriter);
|
||||||
|
|
||||||
rhs_type = input_rhs->getType().cast<RankedTensorType>();
|
rhs_type = input_rhs.getType().cast<RankedTensorType>();
|
||||||
rhs_shape = rhs_type.getShape();
|
rhs_shape = rhs_type.getShape();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -88,7 +88,7 @@ Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
|
|||||||
}
|
}
|
||||||
|
|
||||||
ArrayRef<int64_t> GetRankedTensorShape(Value value) {
|
ArrayRef<int64_t> GetRankedTensorShape(Value value) {
|
||||||
return value->getType().cast<RankedTensorType>().getShape();
|
return value.getType().cast<RankedTensorType>().getShape();
|
||||||
}
|
}
|
||||||
|
|
||||||
Value SliceRankedTensor(OpBuilder* builder, Value input,
|
Value SliceRankedTensor(OpBuilder* builder, Value input,
|
||||||
@ -120,7 +120,7 @@ Value SliceRankedTensor(OpBuilder* builder, Value input,
|
|||||||
location,
|
location,
|
||||||
RankedTensorType::get(
|
RankedTensorType::get(
|
||||||
size_values,
|
size_values,
|
||||||
input->getType().cast<RankedTensorType>().getElementType()),
|
input.getType().cast<RankedTensorType>().getElementType()),
|
||||||
input, slice_i2c_begin, slice_i2c_size);
|
input, slice_i2c_begin, slice_i2c_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -327,8 +327,7 @@ void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() {
|
|||||||
SmallVector<int64_t, 2> output_shape{1, -1};
|
SmallVector<int64_t, 2> output_shape{1, -1};
|
||||||
auto input_types = fused_func_op_.getType().getInputs();
|
auto input_types = fused_func_op_.getType().getInputs();
|
||||||
auto output_type = mlir::RankedTensorType::get(
|
auto output_type = mlir::RankedTensorType::get(
|
||||||
output_shape,
|
output_shape, input_.getType().cast<RankedTensorType>().getElementType());
|
||||||
input_->getType().cast<RankedTensorType>().getElementType());
|
|
||||||
fused_func_op_.setType(mlir::FunctionType::get(input_types, output_type,
|
fused_func_op_.setType(mlir::FunctionType::get(input_types, output_type,
|
||||||
fused_func_op_.getContext()));
|
fused_func_op_.getContext()));
|
||||||
}
|
}
|
||||||
@ -351,8 +350,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
|
|||||||
// Create the fused LSTM op.
|
// Create the fused LSTM op.
|
||||||
SmallVector<int64_t, 2> output_shape = {1, n_output_};
|
SmallVector<int64_t, 2> output_shape = {1, n_output_};
|
||||||
auto result_type = mlir::RankedTensorType::get(
|
auto result_type = mlir::RankedTensorType::get(
|
||||||
output_shape,
|
output_shape, input_.getType().cast<RankedTensorType>().getElementType());
|
||||||
input_->getType().cast<RankedTensorType>().getElementType());
|
|
||||||
lstm_ = builder_.create<mlir::TFL::LSTMOp>(
|
lstm_ = builder_.create<mlir::TFL::LSTMOp>(
|
||||||
fused_func_op_.getLoc(), result_type, input_, input2input_, input2forget_,
|
fused_func_op_.getLoc(), result_type, input_, input2input_, input2forget_,
|
||||||
input2cell_, input2output_, rec2input_, rec2forget_, rec2cell_,
|
input2cell_, input2output_, rec2input_, rec2forget_, rec2cell_,
|
||||||
@ -371,7 +369,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
|
|||||||
SmallVector<int64_t, 2> func_output_shape = {1, -1};
|
SmallVector<int64_t, 2> func_output_shape = {1, -1};
|
||||||
auto func_result_type = mlir::RankedTensorType::get(
|
auto func_result_type = mlir::RankedTensorType::get(
|
||||||
func_output_shape,
|
func_output_shape,
|
||||||
input_->getType().cast<RankedTensorType>().getElementType());
|
input_.getType().cast<RankedTensorType>().getElementType());
|
||||||
|
|
||||||
auto tensor_cast = builder_.create<mlir::TensorCastOp>(
|
auto tensor_cast = builder_.create<mlir::TensorCastOp>(
|
||||||
fused_func_op_.getLoc(), lstm_.getResult(), func_result_type);
|
fused_func_op_.getLoc(), lstm_.getResult(), func_result_type);
|
||||||
@ -426,7 +424,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() {
|
|||||||
bias_ = fused_func_op_.getArgument(2);
|
bias_ = fused_func_op_.getArgument(2);
|
||||||
|
|
||||||
weight_ = fused_func_op_.getArgument(1);
|
weight_ = fused_func_op_.getArgument(1);
|
||||||
weight_type_ = weight_->getType().cast<RankedTensorType>();
|
weight_type_ = weight_.getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
if (weight_type_.getRank() != 2) {
|
if (weight_type_.getRank() != 2) {
|
||||||
return fused_func_op_.emitError() << "The weight tensor was not of rank 2";
|
return fused_func_op_.emitError() << "The weight tensor was not of rank 2";
|
||||||
@ -440,7 +438,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() {
|
|||||||
n_cell_ = weight_type_.getDimSize(1) / num_gates_;
|
n_cell_ = weight_type_.getDimSize(1) / num_gates_;
|
||||||
|
|
||||||
projection_ = fused_func_op_.getArgument(3);
|
projection_ = fused_func_op_.getArgument(3);
|
||||||
projection_type_ = projection_->getType().cast<RankedTensorType>();
|
projection_type_ = projection_.getType().cast<RankedTensorType>();
|
||||||
if (projection_type_.getRank() != 2) {
|
if (projection_type_.getRank() != 2) {
|
||||||
n_output_ = n_cell_;
|
n_output_ = n_cell_;
|
||||||
} else {
|
} else {
|
||||||
@ -467,8 +465,7 @@ LogicalResult ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::Initialize() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
layer_norm_scale_ = fused_func_op_.getArgument(4);
|
layer_norm_scale_ = fused_func_op_.getArgument(4);
|
||||||
layer_norm_scale_type_ =
|
layer_norm_scale_type_ = layer_norm_scale_.getType().cast<RankedTensorType>();
|
||||||
layer_norm_scale_->getType().cast<RankedTensorType>();
|
|
||||||
if (layer_norm_scale_type_.getRank() != 1) {
|
if (layer_norm_scale_type_.getRank() != 1) {
|
||||||
return fused_func_op_.emitError()
|
return fused_func_op_.emitError()
|
||||||
<< "The layer_norm_scale tensor was not of rank 1";
|
<< "The layer_norm_scale tensor was not of rank 1";
|
||||||
|
@ -128,22 +128,20 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) {
|
|||||||
|
|
||||||
auto transpose_op = fused_lstm_func_.getBody().front().begin();
|
auto transpose_op = fused_lstm_func_.getBody().front().begin();
|
||||||
transpose_op++;
|
transpose_op++;
|
||||||
EXPECT_EQ(transpose_op->getOperand(0)
|
|
||||||
->getType()
|
|
||||||
.cast<RankedTensorType>()
|
|
||||||
.getDimSize(0),
|
|
||||||
3);
|
|
||||||
EXPECT_EQ(transpose_op->getOperand(0)
|
|
||||||
->getType()
|
|
||||||
.cast<RankedTensorType>()
|
|
||||||
.getDimSize(1),
|
|
||||||
12);
|
|
||||||
EXPECT_EQ(
|
EXPECT_EQ(
|
||||||
transpose_op->getResult(0)->getType().cast<RankedTensorType>().getDimSize(
|
transpose_op->getOperand(0).getType().cast<RankedTensorType>().getDimSize(
|
||||||
|
0),
|
||||||
|
3);
|
||||||
|
EXPECT_EQ(
|
||||||
|
transpose_op->getOperand(0).getType().cast<RankedTensorType>().getDimSize(
|
||||||
|
1),
|
||||||
|
12);
|
||||||
|
EXPECT_EQ(
|
||||||
|
transpose_op->getResult(0).getType().cast<RankedTensorType>().getDimSize(
|
||||||
0),
|
0),
|
||||||
12);
|
12);
|
||||||
EXPECT_EQ(
|
EXPECT_EQ(
|
||||||
transpose_op->getResult(0)->getType().cast<RankedTensorType>().getDimSize(
|
transpose_op->getResult(0).getType().cast<RankedTensorType>().getDimSize(
|
||||||
1),
|
1),
|
||||||
3);
|
3);
|
||||||
|
|
||||||
@ -156,12 +154,12 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) {
|
|||||||
EXPECT_EQ(it->getNumOperands(), 24);
|
EXPECT_EQ(it->getNumOperands(), 24);
|
||||||
EXPECT_EQ(it->getNumResults(), 1);
|
EXPECT_EQ(it->getNumResults(), 1);
|
||||||
// cifg = false, so input2input is not None.
|
// cifg = false, so input2input is not None.
|
||||||
EXPECT_FALSE(it->getOperand(1)->getType().isa<NoneType>());
|
EXPECT_FALSE(it->getOperand(1).getType().isa<NoneType>());
|
||||||
// input layer norm is None
|
// input layer norm is None
|
||||||
EXPECT_TRUE(it->getOperand(20)->getType().isa<NoneType>());
|
EXPECT_TRUE(it->getOperand(20).getType().isa<NoneType>());
|
||||||
// proj_bias is F32
|
// proj_bias is F32
|
||||||
EXPECT_TRUE(it->getOperand(17)
|
EXPECT_TRUE(it->getOperand(17)
|
||||||
->getType()
|
.getType()
|
||||||
.cast<RankedTensorType>()
|
.cast<RankedTensorType>()
|
||||||
.getElementType()
|
.getElementType()
|
||||||
.isF32());
|
.isF32());
|
||||||
@ -169,7 +167,7 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) {
|
|||||||
// output gate bias is 0 since it is out of bounds of the bias tensor, so
|
// output gate bias is 0 since it is out of bounds of the bias tensor, so
|
||||||
// we set its value as a const tensor of specified size and value 0.
|
// we set its value as a const tensor of specified size and value 0.
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
mlir::cast<mlir::ConstantOp>(it->getOpOperand(15).get()->getDefiningOp())
|
mlir::cast<mlir::ConstantOp>(it->getOpOperand(15).get().getDefiningOp())
|
||||||
.getValue()
|
.getValue()
|
||||||
.cast<ElementsAttr>()
|
.cast<ElementsAttr>()
|
||||||
.getValue<FloatAttr>(0)
|
.getValue<FloatAttr>(0)
|
||||||
@ -209,7 +207,7 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimpleToFusedLSTMCoupleInputForget) {
|
|||||||
EXPECT_EQ(it->getNumOperands(), 24);
|
EXPECT_EQ(it->getNumOperands(), 24);
|
||||||
EXPECT_EQ(it->getNumResults(), 1);
|
EXPECT_EQ(it->getNumResults(), 1);
|
||||||
// cifg = true, so input2input is None.
|
// cifg = true, so input2input is None.
|
||||||
EXPECT_TRUE(it->getOperand(1)->getType().isa<NoneType>());
|
EXPECT_TRUE(it->getOperand(1).getType().isa<NoneType>());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) {
|
TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) {
|
||||||
@ -235,15 +233,15 @@ TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) {
|
|||||||
EXPECT_EQ(it->getNumOperands(), 24);
|
EXPECT_EQ(it->getNumOperands(), 24);
|
||||||
EXPECT_EQ(it->getNumResults(), 1);
|
EXPECT_EQ(it->getNumResults(), 1);
|
||||||
// cifg = false, so input2input is not None.
|
// cifg = false, so input2input is not None.
|
||||||
EXPECT_FALSE(it->getOperand(1)->getType().isa<NoneType>());
|
EXPECT_FALSE(it->getOperand(1).getType().isa<NoneType>());
|
||||||
|
|
||||||
// input layer norm
|
// input layer norm
|
||||||
EXPECT_FALSE(it->getOperand(20)->getType().isa<NoneType>());
|
EXPECT_FALSE(it->getOperand(20).getType().isa<NoneType>());
|
||||||
EXPECT_EQ(
|
EXPECT_EQ(
|
||||||
it->getOperand(20)->getType().cast<RankedTensorType>().getShape().size(),
|
it->getOperand(20).getType().cast<RankedTensorType>().getShape().size(),
|
||||||
1);
|
1);
|
||||||
EXPECT_EQ(
|
EXPECT_EQ(it->getOperand(20).getType().cast<RankedTensorType>().getDimSize(0),
|
||||||
it->getOperand(20)->getType().cast<RankedTensorType>().getDimSize(0), 3);
|
3);
|
||||||
|
|
||||||
EXPECT_EQ(fused_ln_lstm_func_.getType().getNumResults(), 1);
|
EXPECT_EQ(fused_ln_lstm_func_.getType().getNumResults(), 1);
|
||||||
auto output_types = fused_ln_lstm_func_.getType().getResults();
|
auto output_types = fused_ln_lstm_func_.getType().getResults();
|
||||||
|
@ -52,7 +52,7 @@ bool TFIntListIsAllOnes(const ArrayAttr &attr);
|
|||||||
// Returns true iff the given value is a float tensor.
|
// Returns true iff the given value is a float tensor.
|
||||||
// is "DT_FLOAT".
|
// is "DT_FLOAT".
|
||||||
inline bool TFTypeIsFloatTensor(Value value) {
|
inline bool TFTypeIsFloatTensor(Value value) {
|
||||||
auto tensorType = value->getType().dyn_cast<TensorType>();
|
auto tensorType = value.getType().dyn_cast<TensorType>();
|
||||||
if (!tensorType) return false;
|
if (!tensorType) return false;
|
||||||
return tensorType.getElementType().isa<FloatType>();
|
return tensorType.getElementType().isa<FloatType>();
|
||||||
}
|
}
|
||||||
|
@ -149,17 +149,17 @@ std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) {
|
|||||||
return op->getName().getStringRef();
|
return op->getName().getStringRef();
|
||||||
}
|
}
|
||||||
auto val = op_or_val.dyn_cast<mlir::Value>();
|
auto val = op_or_val.dyn_cast<mlir::Value>();
|
||||||
auto name_from_loc = GetNameFromLoc(val->getLoc());
|
auto name_from_loc = GetNameFromLoc(val.getLoc());
|
||||||
if (!name_from_loc.empty()) return name_from_loc;
|
if (!name_from_loc.empty()) return name_from_loc;
|
||||||
// If the location is none of the expected types, then simply use name
|
// If the location is none of the expected types, then simply use name
|
||||||
// generated using the op type. Follow TF convention and append the result
|
// generated using the op type. Follow TF convention and append the result
|
||||||
// index unless 0.
|
// index unless 0.
|
||||||
if (auto result = val->dyn_cast<mlir::OpResult>()) {
|
if (auto result = val.dyn_cast<mlir::OpResult>()) {
|
||||||
if (result->getResultNumber() > 0)
|
if (result.getResultNumber() > 0)
|
||||||
return llvm::formatv("{0}:{1}",
|
return llvm::formatv("{0}:{1}",
|
||||||
result->getOwner()->getName().getStringRef(),
|
result.getOwner()->getName().getStringRef(),
|
||||||
result->getResultNumber());
|
result.getResultNumber());
|
||||||
return result->getOwner()->getName().getStringRef();
|
return result.getOwner()->getName().getStringRef();
|
||||||
}
|
}
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
@ -84,17 +84,17 @@ int64_t FindPassthroughArgumentForReturnValue(int64_t return_index,
|
|||||||
FuncOp func_op) {
|
FuncOp func_op) {
|
||||||
auto value =
|
auto value =
|
||||||
func_op.getBody().front().getTerminator()->getOperand(return_index);
|
func_op.getBody().front().getTerminator()->getOperand(return_index);
|
||||||
assert(mlir::getElementTypeOrSelf(value->getType()).isa<TF::ResourceType>());
|
assert(mlir::getElementTypeOrSelf(value.getType()).isa<TF::ResourceType>());
|
||||||
int64_t arg_index = -1;
|
int64_t arg_index = -1;
|
||||||
auto try_parse_arg_index = [&arg_index](Value v) {
|
auto try_parse_arg_index = [&arg_index](Value v) {
|
||||||
auto resource_arg = v->dyn_cast<BlockArgument>();
|
auto resource_arg = v.dyn_cast<BlockArgument>();
|
||||||
if (resource_arg) arg_index = resource_arg->getArgNumber();
|
if (resource_arg) arg_index = resource_arg.getArgNumber();
|
||||||
return arg_index;
|
return arg_index;
|
||||||
};
|
};
|
||||||
while (try_parse_arg_index(value) == -1) {
|
while (try_parse_arg_index(value) == -1) {
|
||||||
auto op = value->getDefiningOp();
|
auto op = value.getDefiningOp();
|
||||||
assert(op);
|
assert(op);
|
||||||
int64_t res_num = value->cast<OpResult>()->getResultNumber();
|
int64_t res_num = value.cast<OpResult>().getResultNumber();
|
||||||
if (auto graph = llvm::dyn_cast<tf_executor::GraphOp>(op)) {
|
if (auto graph = llvm::dyn_cast<tf_executor::GraphOp>(op)) {
|
||||||
value = graph.GetFetch().getOperand(res_num);
|
value = graph.GetFetch().getOperand(res_num);
|
||||||
} else if (auto island = llvm::dyn_cast<tf_executor::IslandOp>(op)) {
|
} else if (auto island = llvm::dyn_cast<tf_executor::IslandOp>(op)) {
|
||||||
@ -126,13 +126,13 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
|
|||||||
// Before having that, we assume function arguments do not alias each other.
|
// Before having that, we assume function arguments do not alias each other.
|
||||||
int64_t next_unique_id = 0;
|
int64_t next_unique_id = 0;
|
||||||
for (auto arg : func_op.getArguments()) {
|
for (auto arg : func_op.getArguments()) {
|
||||||
if (!mlir::getElementTypeOrSelf(arg->getType()).isa<TF::ResourceType>())
|
if (!mlir::getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>())
|
||||||
continue;
|
continue;
|
||||||
resource_value_to_ids_[arg].insert(next_unique_id++);
|
resource_value_to_ids_[arg].insert(next_unique_id++);
|
||||||
}
|
}
|
||||||
llvm::StringMap<int64_t> var_handle_name_id_map;
|
llvm::StringMap<int64_t> var_handle_name_id_map;
|
||||||
auto forward_input_to_output = [&](Value operand, Value result) {
|
auto forward_input_to_output = [&](Value operand, Value result) {
|
||||||
if (!mlir::getElementTypeOrSelf(result->getType()).isa<TF::ResourceType>())
|
if (!mlir::getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>())
|
||||||
return;
|
return;
|
||||||
auto& result_ids = resource_value_to_ids_[result];
|
auto& result_ids = resource_value_to_ids_[result];
|
||||||
auto operand_it = resource_value_to_ids_.find(operand);
|
auto operand_it = resource_value_to_ids_.find(operand);
|
||||||
@ -161,8 +161,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
|
|||||||
// analysis. Inside that block, we can still treat its block arguments as
|
// analysis. Inside that block, we can still treat its block arguments as
|
||||||
// different resources.
|
// different resources.
|
||||||
for (auto arg : replicate.GetBody().getArguments()) {
|
for (auto arg : replicate.GetBody().getArguments()) {
|
||||||
if (mlir::getElementTypeOrSelf(arg->getType())
|
if (mlir::getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>()) {
|
||||||
.isa<TF::ResourceType>()) {
|
|
||||||
resource_value_to_ids_[arg].insert(next_unique_id++);
|
resource_value_to_ids_[arg].insert(next_unique_id++);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -171,7 +170,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
|
|||||||
// If a result is a passthrough of the body input, use the corresponding
|
// If a result is a passthrough of the body input, use the corresponding
|
||||||
// operand's resource IDs.
|
// operand's resource IDs.
|
||||||
for (auto result : llvm::enumerate(while_op.getResults())) {
|
for (auto result : llvm::enumerate(while_op.getResults())) {
|
||||||
if (!mlir::getElementTypeOrSelf(result.value()->getType())
|
if (!mlir::getElementTypeOrSelf(result.value().getType())
|
||||||
.isa<TF::ResourceType>()) {
|
.isa<TF::ResourceType>()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -192,7 +191,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
|
|||||||
// If a result is a passthrough of both branches' inputs, merge the
|
// If a result is a passthrough of both branches' inputs, merge the
|
||||||
// resource IDs of corresponding operands for the two inputs.
|
// resource IDs of corresponding operands for the two inputs.
|
||||||
for (auto result : llvm::enumerate(if_op.getResults())) {
|
for (auto result : llvm::enumerate(if_op.getResults())) {
|
||||||
if (!mlir::getElementTypeOrSelf(result.value()->getType())
|
if (!mlir::getElementTypeOrSelf(result.value().getType())
|
||||||
.isa<TF::ResourceType>()) {
|
.isa<TF::ResourceType>()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -211,7 +210,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (auto result : op->getResults()) {
|
for (auto result : op->getResults()) {
|
||||||
if (!mlir::getElementTypeOrSelf(result->getType())
|
if (!mlir::getElementTypeOrSelf(result.getType())
|
||||||
.isa<TF::ResourceType>())
|
.isa<TF::ResourceType>())
|
||||||
continue;
|
continue;
|
||||||
resource_value_to_ids_[result].insert(kUnknownResourceId);
|
resource_value_to_ids_[result].insert(kUnknownResourceId);
|
||||||
@ -253,14 +252,14 @@ llvm::SmallDenseSet<int64_t, 8> FindAccessedResources(
|
|||||||
llvm::SmallDenseSet<int64_t, 8> resources;
|
llvm::SmallDenseSet<int64_t, 8> resources;
|
||||||
|
|
||||||
for (auto operand : op->getOperands()) {
|
for (auto operand : op->getOperands()) {
|
||||||
if (!mlir::getElementTypeOrSelf(operand->getType()).isa<TF::ResourceType>())
|
if (!mlir::getElementTypeOrSelf(operand.getType()).isa<TF::ResourceType>())
|
||||||
continue;
|
continue;
|
||||||
if (alias_analysis.IsUnknownResource(operand)) return UnknownResourceSet();
|
if (alias_analysis.IsUnknownResource(operand)) return UnknownResourceSet();
|
||||||
const auto& ids = alias_analysis.GetResourceUniqueIds(operand);
|
const auto& ids = alias_analysis.GetResourceUniqueIds(operand);
|
||||||
resources.insert(ids.begin(), ids.end());
|
resources.insert(ids.begin(), ids.end());
|
||||||
}
|
}
|
||||||
for (auto result : op->getResults()) {
|
for (auto result : op->getResults()) {
|
||||||
if (!mlir::getElementTypeOrSelf(result->getType()).isa<TF::ResourceType>())
|
if (!mlir::getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>())
|
||||||
continue;
|
continue;
|
||||||
if (alias_analysis.IsUnknownResource(result)) return UnknownResourceSet();
|
if (alias_analysis.IsUnknownResource(result)) return UnknownResourceSet();
|
||||||
const auto& ids = alias_analysis.GetResourceUniqueIds(result);
|
const auto& ids = alias_analysis.GetResourceUniqueIds(result);
|
||||||
|
@ -184,11 +184,11 @@ void Print(ReplicateOp op, OpAsmPrinter* p) {
|
|||||||
*p << '(';
|
*p << '(';
|
||||||
Block& block = op.body().front();
|
Block& block = op.body().front();
|
||||||
interleaveComma(block.getArguments(), *p, [&](BlockArgument arg) {
|
interleaveComma(block.getArguments(), *p, [&](BlockArgument arg) {
|
||||||
const int block_arg_num = arg->getArgNumber();
|
const int block_arg_num = arg.getArgNumber();
|
||||||
*p << '[';
|
*p << '[';
|
||||||
p->printOperands(std::next(op.operand_begin(), block_arg_num * n),
|
p->printOperands(std::next(op.operand_begin(), block_arg_num * n),
|
||||||
std::next(op.operand_begin(), (block_arg_num + 1) * n));
|
std::next(op.operand_begin(), (block_arg_num + 1) * n));
|
||||||
*p << "] as " << *arg << ": " << arg->getType();
|
*p << "] as " << arg << ": " << arg.getType();
|
||||||
});
|
});
|
||||||
*p << ')';
|
*p << ')';
|
||||||
}
|
}
|
||||||
@ -229,13 +229,13 @@ LogicalResult Verify(ReplicateOp op) {
|
|||||||
|
|
||||||
// Check replicated input types match block argument types.
|
// Check replicated input types match block argument types.
|
||||||
for (auto block_arg : block.getArguments()) {
|
for (auto block_arg : block.getArguments()) {
|
||||||
Type block_arg_type = block_arg->getType();
|
Type block_arg_type = block_arg.getType();
|
||||||
for (int i = n * block_arg->getArgNumber(), e = i + n; i < e; ++i)
|
for (int i = n * block_arg.getArgNumber(), e = i + n; i < e; ++i)
|
||||||
if (failed(VerifyCompatibleTypes(block_arg_type,
|
if (failed(VerifyCompatibleTypes(block_arg_type,
|
||||||
op.getOperand(i)->getType())))
|
op.getOperand(i).getType())))
|
||||||
return op.emitOpError()
|
return op.emitOpError()
|
||||||
<< "incompatible types for operand " << i
|
<< "incompatible types for operand " << i
|
||||||
<< " and block argument " << block_arg->getArgNumber();
|
<< " and block argument " << block_arg.getArgNumber();
|
||||||
}
|
}
|
||||||
|
|
||||||
Operation& terminator = block.back();
|
Operation& terminator = block.back();
|
||||||
@ -282,7 +282,7 @@ void BuildReplicateOp(
|
|||||||
DCHECK_EQ(llvm::size(replicated_input.first), n);
|
DCHECK_EQ(llvm::size(replicated_input.first), n);
|
||||||
for (auto input : replicated_input.first) {
|
for (auto input : replicated_input.first) {
|
||||||
DCHECK(succeeded(
|
DCHECK(succeeded(
|
||||||
VerifyCompatibleTypes(input->getType(), replicated_input.second)));
|
VerifyCompatibleTypes(input.getType(), replicated_input.second)));
|
||||||
state->addOperands(input);
|
state->addOperands(input);
|
||||||
}
|
}
|
||||||
block.addArgument(replicated_input.second);
|
block.addArgument(replicated_input.second);
|
||||||
|
@ -167,7 +167,7 @@ namespace {
|
|||||||
LogicalResult VerifyControlOperandsAfterAllData(Operation *op) {
|
LogicalResult VerifyControlOperandsAfterAllData(Operation *op) {
|
||||||
bool found_control = false;
|
bool found_control = false;
|
||||||
for (int operand_idx : llvm::seq<int>(0, op->getNumOperands())) {
|
for (int operand_idx : llvm::seq<int>(0, op->getNumOperands())) {
|
||||||
if (op->getOperand(operand_idx)->getType().isa<ControlType>()) {
|
if (op->getOperand(operand_idx).getType().isa<ControlType>()) {
|
||||||
found_control = true;
|
found_control = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -218,7 +218,7 @@ LogicalResult Verify(GraphOp graph) {
|
|||||||
for (int i : llvm::seq<int>(0, fetch.getNumOperands())) {
|
for (int i : llvm::seq<int>(0, fetch.getNumOperands())) {
|
||||||
Value operand = fetch.getOperand(i);
|
Value operand = fetch.getOperand(i);
|
||||||
// Break out of the loop at the first control operand encountered.
|
// Break out of the loop at the first control operand encountered.
|
||||||
if (operand->getType().isa<ControlType>()) {
|
if (operand.getType().isa<ControlType>()) {
|
||||||
if (i != graph.getNumResults())
|
if (i != graph.getNumResults())
|
||||||
return fetch.emitOpError()
|
return fetch.emitOpError()
|
||||||
<< "operand #" << i
|
<< "operand #" << i
|
||||||
@ -228,7 +228,7 @@ LogicalResult Verify(GraphOp graph) {
|
|||||||
if (i >= graph.getNumResults())
|
if (i >= graph.getNumResults())
|
||||||
return fetch.emitOpError()
|
return fetch.emitOpError()
|
||||||
<< "operand #" << i << " does not have a graph results to bind";
|
<< "operand #" << i << " does not have a graph results to bind";
|
||||||
if (graph.getResult(i)->getType() != operand->getType())
|
if (graph.getResult(i).getType() != operand.getType())
|
||||||
return fetch.emitOpError()
|
return fetch.emitOpError()
|
||||||
<< "operand #" << i << " type mismatch graph results";
|
<< "operand #" << i << " type mismatch graph results";
|
||||||
}
|
}
|
||||||
@ -331,8 +331,8 @@ LogicalResult Verify(IslandOp island) {
|
|||||||
<< "has " << yield.getNumOperands()
|
<< "has " << yield.getNumOperands()
|
||||||
<< " operand, but island returns " << result_count;
|
<< " operand, but island returns " << result_count;
|
||||||
for (int operand_idx : llvm::seq<int>(0, yield.getNumOperands())) {
|
for (int operand_idx : llvm::seq<int>(0, yield.getNumOperands())) {
|
||||||
if (island.getResult(operand_idx)->getType() !=
|
if (island.getResult(operand_idx).getType() !=
|
||||||
yield.getOperand(operand_idx)->getType())
|
yield.getOperand(operand_idx).getType())
|
||||||
return yield.emitOpError()
|
return yield.emitOpError()
|
||||||
<< "operand #" << operand_idx << " type mismatch island results";
|
<< "operand #" << operand_idx << " type mismatch island results";
|
||||||
}
|
}
|
||||||
@ -340,7 +340,7 @@ LogicalResult Verify(IslandOp island) {
|
|||||||
// Check that there aren't any control results other than the last one.
|
// Check that there aren't any control results other than the last one.
|
||||||
Type control_type = ControlType::get(island.getContext());
|
Type control_type = ControlType::get(island.getContext());
|
||||||
for (int operand_idx : llvm::seq<int>(0, island.getNumResults() - 1)) {
|
for (int operand_idx : llvm::seq<int>(0, island.getNumResults() - 1)) {
|
||||||
if (island.getResult(operand_idx)->getType() == control_type)
|
if (island.getResult(operand_idx).getType() == control_type)
|
||||||
return yield.emitOpError()
|
return yield.emitOpError()
|
||||||
<< "unexpected control type for operand #" << operand_idx;
|
<< "unexpected control type for operand #" << operand_idx;
|
||||||
}
|
}
|
||||||
@ -503,12 +503,12 @@ ParseResult ParseSwitchOp(OpAsmParser &parser, OperationState &result) {
|
|||||||
void Print(SwitchOp switch_op, OpAsmPrinter &p) {
|
void Print(SwitchOp switch_op, OpAsmPrinter &p) {
|
||||||
p << switch_op.getOperationName() << ' ';
|
p << switch_op.getOperationName() << ' ';
|
||||||
p.printOperands(switch_op.getOperands());
|
p.printOperands(switch_op.getOperands());
|
||||||
Type data_operand_ty = switch_op.data()->getType();
|
Type data_operand_ty = switch_op.data().getType();
|
||||||
// If the types aren't perfectly matching, print the functional type syntax
|
// If the types aren't perfectly matching, print the functional type syntax
|
||||||
// else print the shorter single type.
|
// else print the shorter single type.
|
||||||
p << " : ";
|
p << " : ";
|
||||||
if (switch_op.trueOutput()->getType() != data_operand_ty ||
|
if (switch_op.trueOutput().getType() != data_operand_ty ||
|
||||||
switch_op.falseOutput()->getType() != data_operand_ty) {
|
switch_op.falseOutput().getType() != data_operand_ty) {
|
||||||
p.printFunctionalType(switch_op.getOperation());
|
p.printFunctionalType(switch_op.getOperation());
|
||||||
} else {
|
} else {
|
||||||
p << switch_op.getType(0);
|
p << switch_op.getType(0);
|
||||||
@ -535,12 +535,12 @@ LogicalResult Verify(SwitchNOp switchn) {
|
|||||||
<< "expect `num_outs` (" << num_outs.getInt() << ") results but got "
|
<< "expect `num_outs` (" << num_outs.getInt() << ") results but got "
|
||||||
<< (switchn.getNumResults() - 1);
|
<< (switchn.getNumResults() - 1);
|
||||||
|
|
||||||
auto operand0_type = switchn.getOperand(0)->getType();
|
auto operand0_type = switchn.getOperand(0).getType();
|
||||||
for (Value result : switchn.outputs())
|
for (Value result : switchn.outputs())
|
||||||
if (operand0_type != result->getType())
|
if (operand0_type != result.getType())
|
||||||
return switchn.emitOpError()
|
return switchn.emitOpError()
|
||||||
<< "type mismatch between data operand and result: "
|
<< "type mismatch between data operand and result: "
|
||||||
<< operand0_type << " vs " << result->getType();
|
<< operand0_type << " vs " << result.getType();
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@ -616,12 +616,12 @@ LogicalResult Verify(MergeOp merge) {
|
|||||||
if (!merge.getNumOperands())
|
if (!merge.getNumOperands())
|
||||||
return merge.emitOpError() << "expects at least one operand";
|
return merge.emitOpError() << "expects at least one operand";
|
||||||
|
|
||||||
Type data_type = merge.getOperand(0)->getType();
|
Type data_type = merge.getOperand(0).getType();
|
||||||
if (data_type.isa<ControlType>())
|
if (data_type.isa<ControlType>())
|
||||||
return merge.emitOpError() << "expects a non-control input";
|
return merge.emitOpError() << "expects a non-control input";
|
||||||
|
|
||||||
// Check that each operand can be individually broadcasted to the output type.
|
// Check that each operand can be individually broadcasted to the output type.
|
||||||
Type output_type = merge.output()->getType();
|
Type output_type = merge.output().getType();
|
||||||
TensorType output_tensor_ty = output_type.dyn_cast<TensorType>();
|
TensorType output_tensor_ty = output_type.dyn_cast<TensorType>();
|
||||||
if (!output_tensor_ty) {
|
if (!output_tensor_ty) {
|
||||||
return merge.emitOpError()
|
return merge.emitOpError()
|
||||||
@ -666,7 +666,7 @@ void Print(MergeOp merge, OpAsmPrinter &p) {
|
|||||||
bool use_short_form = true;
|
bool use_short_form = true;
|
||||||
int num_data_operands = 0;
|
int num_data_operands = 0;
|
||||||
|
|
||||||
Type output_type = merge.output()->getType();
|
Type output_type = merge.output().getType();
|
||||||
for (Type operand_type : merge.getOperandTypes()) {
|
for (Type operand_type : merge.getOperandTypes()) {
|
||||||
if (operand_type.isa<ControlType>()) break;
|
if (operand_type.isa<ControlType>()) break;
|
||||||
num_data_operands++;
|
num_data_operands++;
|
||||||
@ -750,7 +750,7 @@ void Print(EnterOp enter, OpAsmPrinter &p) {
|
|||||||
// If the types aren't perfectly matching, print the functional type syntax
|
// If the types aren't perfectly matching, print the functional type syntax
|
||||||
// else print the shorter single type.
|
// else print the shorter single type.
|
||||||
p << " : ";
|
p << " : ";
|
||||||
if (enter.data()->getType() != enter.output()->getType()) {
|
if (enter.data().getType() != enter.output().getType()) {
|
||||||
p.printFunctionalType(enter.getOperation());
|
p.printFunctionalType(enter.getOperation());
|
||||||
} else {
|
} else {
|
||||||
p << enter.getType(0);
|
p << enter.getType(0);
|
||||||
@ -825,9 +825,9 @@ namespace {
|
|||||||
|
|
||||||
LogicalResult Verify(NextIterationSourceOp source) {
|
LogicalResult Verify(NextIterationSourceOp source) {
|
||||||
Value token = source.token();
|
Value token = source.token();
|
||||||
if (!token->hasOneUse())
|
if (!token.hasOneUse())
|
||||||
return source.emitOpError() << "expects a single user for produced token";
|
return source.emitOpError() << "expects a single user for produced token";
|
||||||
if (!isa<NextIterationSinkOp>(*token->user_begin()))
|
if (!isa<NextIterationSinkOp>(*token.user_begin()))
|
||||||
return source.emitOpError() << "token should be consumed by a sink op";
|
return source.emitOpError() << "token should be consumed by a sink op";
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@ -859,7 +859,7 @@ namespace {
|
|||||||
|
|
||||||
LogicalResult Verify(NextIterationSinkOp sink) {
|
LogicalResult Verify(NextIterationSinkOp sink) {
|
||||||
Value token = sink.token();
|
Value token = sink.token();
|
||||||
Operation *definingOp = token->getDefiningOp();
|
Operation *definingOp = token.getDefiningOp();
|
||||||
if (!definingOp)
|
if (!definingOp)
|
||||||
return sink.emitOpError() << "expects a token directly produced by a "
|
return sink.emitOpError() << "expects a token directly produced by a "
|
||||||
"tf_executor.NextIteration.Source op: ";
|
"tf_executor.NextIteration.Source op: ";
|
||||||
@ -867,11 +867,11 @@ LogicalResult Verify(NextIterationSinkOp sink) {
|
|||||||
if (!source)
|
if (!source)
|
||||||
return sink.emitOpError() << "expects a token produced by a "
|
return sink.emitOpError() << "expects a token produced by a "
|
||||||
"tf_executor.NextIteration.Source op: ";
|
"tf_executor.NextIteration.Source op: ";
|
||||||
if (source.output()->getType() != sink.input()->getType())
|
if (source.output().getType() != sink.input().getType())
|
||||||
return sink.emitOpError()
|
return sink.emitOpError()
|
||||||
<< "input type " << sink.input()->getType()
|
<< "input type " << sink.input().getType()
|
||||||
<< " mismatch the tf_executor.NextIteration.Source output type: "
|
<< " mismatch the tf_executor.NextIteration.Source output type: "
|
||||||
<< source.output()->getType();
|
<< source.output().getType();
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -880,7 +880,7 @@ void Print(NextIterationSinkOp next_iteration, OpAsmPrinter &p) {
|
|||||||
p.printOperand(next_iteration.getOperand(0));
|
p.printOperand(next_iteration.getOperand(0));
|
||||||
p << "] ";
|
p << "] ";
|
||||||
p.printOperands(llvm::drop_begin(next_iteration.getOperands(), 1));
|
p.printOperands(llvm::drop_begin(next_iteration.getOperands(), 1));
|
||||||
p << " : " << next_iteration.getOperand(1)->getType();
|
p << " : " << next_iteration.getOperand(1).getType();
|
||||||
p.printOptionalAttrDict(next_iteration.getAttrs());
|
p.printOptionalAttrDict(next_iteration.getAttrs());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -980,11 +980,11 @@ void Print(LoopCondOp loop_cond, OpAsmPrinter &p) {
|
|||||||
p.printOperands(loop_cond.getOperands());
|
p.printOperands(loop_cond.getOperands());
|
||||||
|
|
||||||
// If the types aren't matching (broadcast), print the functional type syntax.
|
// If the types aren't matching (broadcast), print the functional type syntax.
|
||||||
if (loop_cond.input()->getType() != loop_cond.output()->getType()) {
|
if (loop_cond.input().getType() != loop_cond.output().getType()) {
|
||||||
p << " : ";
|
p << " : ";
|
||||||
p.printFunctionalType(loop_cond.getOperation());
|
p.printFunctionalType(loop_cond.getOperation());
|
||||||
} else {
|
} else {
|
||||||
p << " : " << loop_cond.input()->getType();
|
p << " : " << loop_cond.input().getType();
|
||||||
}
|
}
|
||||||
|
|
||||||
p.printOptionalAttrDict(loop_cond.getAttrs());
|
p.printOptionalAttrDict(loop_cond.getAttrs());
|
||||||
@ -1090,15 +1090,15 @@ struct HoistInnerOpsSingleIslandGraph : public OpRewritePattern<GraphOp> {
|
|||||||
llvm::SmallVector<Value, 8> new_rets;
|
llvm::SmallVector<Value, 8> new_rets;
|
||||||
for (Value operand : fetch_op.fetches()) {
|
for (Value operand : fetch_op.fetches()) {
|
||||||
// Control results should not be propagated out.
|
// Control results should not be propagated out.
|
||||||
if (operand->getType().isa<ControlType>()) break;
|
if (operand.getType().isa<ControlType>()) break;
|
||||||
|
|
||||||
if (operand->getDefiningOp() != island_op) {
|
if (operand.getDefiningOp() != island_op) {
|
||||||
// Operand is not from island, simply propagate it out.
|
// Operand is not from island, simply propagate it out.
|
||||||
new_rets.push_back(operand);
|
new_rets.push_back(operand);
|
||||||
} else {
|
} else {
|
||||||
// Lookup yield operand in island for inner op result.
|
// Lookup yield operand in island for inner op result.
|
||||||
auto result = operand->cast<OpResult>();
|
auto result = operand.cast<OpResult>();
|
||||||
new_rets.push_back(yield_op.getOperand(result->getResultNumber()));
|
new_rets.push_back(yield_op.getOperand(result.getResultNumber()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1138,7 +1138,7 @@ struct DropEmptyIslandNoOperandNoDataResult
|
|||||||
!HasSingleOpInBlock<YieldOp>(&op.GetBody()))
|
!HasSingleOpInBlock<YieldOp>(&op.GetBody()))
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
|
|
||||||
for (auto &use : llvm::make_early_inc_range(op.control()->getUses()))
|
for (auto &use : llvm::make_early_inc_range(op.control().getUses()))
|
||||||
use.getOwner()->eraseOperand(use.getOperandNumber());
|
use.getOwner()->eraseOperand(use.getOperandNumber());
|
||||||
|
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
@ -1158,7 +1158,7 @@ struct DropEmptyIslandNoOperandOneDataResult
|
|||||||
PatternMatchResult matchAndRewrite(IslandOp op,
|
PatternMatchResult matchAndRewrite(IslandOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
if (op.getNumOperands() != 0 || op.getNumResults() != 2 ||
|
if (op.getNumOperands() != 0 || op.getNumResults() != 2 ||
|
||||||
!op.control()->use_empty() ||
|
!op.control().use_empty() ||
|
||||||
!HasSingleOpInBlock<YieldOp>(&op.GetBody()))
|
!HasSingleOpInBlock<YieldOp>(&op.GetBody()))
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
|
|
||||||
@ -1193,7 +1193,7 @@ struct DropEmptyControlTrigger : public OpRewritePattern<ControlTriggerOp> {
|
|||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
if (op.getNumOperands() != 0) return matchFailure();
|
if (op.getNumOperands() != 0) return matchFailure();
|
||||||
|
|
||||||
for (auto &use : llvm::make_early_inc_range(op.control()->getUses()))
|
for (auto &use : llvm::make_early_inc_range(op.control().getUses()))
|
||||||
use.getOwner()->eraseOperand(use.getOperandNumber());
|
use.getOwner()->eraseOperand(use.getOperandNumber());
|
||||||
|
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
@ -460,7 +460,7 @@ def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source",
|
|||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
NextIterationSinkOp GetSink() {
|
NextIterationSinkOp GetSink() {
|
||||||
return cast<NextIterationSinkOp>(*token()->user_begin());
|
return cast<NextIterationSinkOp>(*token().user_begin());
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
@ -302,7 +302,7 @@ class WithBroadcastableBinOpBuilder {
|
|||||||
"Builder *builder, OperationState &result, Value x, Value y",
|
"Builder *builder, OperationState &result, Value x, Value y",
|
||||||
[{
|
[{
|
||||||
auto resultType =
|
auto resultType =
|
||||||
OpTrait::util::getBroadcastedType(x->getType(), y->getType());
|
OpTrait::util::getBroadcastedType(x.getType(), y.getType());
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
mlir::emitError(result.location, "non-broadcastable operands");
|
mlir::emitError(result.location, "non-broadcastable operands");
|
||||||
return build(builder, result, resultType, x, y);
|
return build(builder, result, resultType, x, y);
|
||||||
@ -317,14 +317,14 @@ class WithBroadcastableCmpOpBuilder {
|
|||||||
"Builder *builder, OperationState &result, Value x, Value y",
|
"Builder *builder, OperationState &result, Value x, Value y",
|
||||||
[{
|
[{
|
||||||
Type resultType;
|
Type resultType;
|
||||||
if (x->getType().isa<UnrankedTensorType>() ||
|
if (x.getType().isa<UnrankedTensorType>() ||
|
||||||
y->getType().isa<UnrankedTensorType>()) {
|
y.getType().isa<UnrankedTensorType>()) {
|
||||||
resultType = UnrankedTensorType::get(builder->getI1Type());
|
resultType = UnrankedTensorType::get(builder->getI1Type());
|
||||||
} else {
|
} else {
|
||||||
SmallVector<int64_t, 4> resultShape;
|
SmallVector<int64_t, 4> resultShape;
|
||||||
if (!OpTrait::util::getBroadcastedShape(
|
if (!OpTrait::util::getBroadcastedShape(
|
||||||
x->getType().cast<ShapedType>().getShape(),
|
x.getType().cast<ShapedType>().getShape(),
|
||||||
y->getType().cast<ShapedType>().getShape(), resultShape)) {
|
y.getType().cast<ShapedType>().getShape(), resultShape)) {
|
||||||
mlir::emitError(result.location,
|
mlir::emitError(result.location,
|
||||||
"operands have no broadcastable shapes");
|
"operands have no broadcastable shapes");
|
||||||
}
|
}
|
||||||
|
@ -77,7 +77,7 @@ static RankedTensorType GetRankedTensorTypeForOperand(Value operand) {
|
|||||||
if (matchPattern(operand, m_Constant(&attr))) {
|
if (matchPattern(operand, m_Constant(&attr))) {
|
||||||
return attr.getType().dyn_cast<RankedTensorType>();
|
return attr.getType().dyn_cast<RankedTensorType>();
|
||||||
}
|
}
|
||||||
return operand->getType().dyn_cast<RankedTensorType>();
|
return operand.getType().dyn_cast<RankedTensorType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns true if the given `value` is of ranked float tensor type with the
|
// Returns true if the given `value` is of ranked float tensor type with the
|
||||||
@ -161,7 +161,7 @@ static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
|
|||||||
static Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x,
|
static Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x,
|
||||||
Value y, BoolAttr incompatible_shape_error) {
|
Value y, BoolAttr incompatible_shape_error) {
|
||||||
auto result_type =
|
auto result_type =
|
||||||
OpTrait::util::getBroadcastedType(x->getType(), y->getType());
|
OpTrait::util::getBroadcastedType(x.getType(), y.getType());
|
||||||
if (!result_type) {
|
if (!result_type) {
|
||||||
if (incompatible_shape_error.getValue()) {
|
if (incompatible_shape_error.getValue()) {
|
||||||
mlir::emitError(loc, "non-broadcastable operands");
|
mlir::emitError(loc, "non-broadcastable operands");
|
||||||
@ -187,7 +187,7 @@ static int64_t GetDimForAxis(int64_t axis, int64_t rank) {
|
|||||||
// inference functions.
|
// inference functions.
|
||||||
static Type InferReductionOpType(Value input, Value reduction_indices,
|
static Type InferReductionOpType(Value input, Value reduction_indices,
|
||||||
BoolAttr keep_dims, Builder *builder) {
|
BoolAttr keep_dims, Builder *builder) {
|
||||||
Type input_ty = input->getType();
|
Type input_ty = input.getType();
|
||||||
Type element_ty = getElementTypeOrSelf(input_ty);
|
Type element_ty = getElementTypeOrSelf(input_ty);
|
||||||
|
|
||||||
// Output type is unranked if input type is not ranked.
|
// Output type is unranked if input type is not ranked.
|
||||||
@ -330,12 +330,12 @@ void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||||||
// Verifies an reduction op's `input` and reduction `dims`.
|
// Verifies an reduction op's `input` and reduction `dims`.
|
||||||
static LogicalResult VerifyReductionInputAndDims(Value input, Value dims,
|
static LogicalResult VerifyReductionInputAndDims(Value input, Value dims,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
auto dims_type = dims->getType().dyn_cast<RankedTensorType>();
|
auto dims_type = dims.getType().dyn_cast<RankedTensorType>();
|
||||||
if (!dims_type) return success();
|
if (!dims_type) return success();
|
||||||
if (dims_type.getRank() > 1)
|
if (dims_type.getRank() > 1)
|
||||||
return emitError(loc, "dimensions can only be 0D or 1D tensor");
|
return emitError(loc, "dimensions can only be 0D or 1D tensor");
|
||||||
|
|
||||||
auto input_type = input->getType().dyn_cast<RankedTensorType>();
|
auto input_type = input.getType().dyn_cast<RankedTensorType>();
|
||||||
if (!input_type) return success();
|
if (!input_type) return success();
|
||||||
int64_t rank = input_type.getRank();
|
int64_t rank = input_type.getRank();
|
||||||
|
|
||||||
@ -441,9 +441,8 @@ static LogicalResult Verify(BiasAddOp op) {
|
|||||||
if (!IsOfRankOrUnranked(op.bias(), 1))
|
if (!IsOfRankOrUnranked(op.bias(), 1))
|
||||||
return op.emitOpError("requires bias operand to have rank exactly one");
|
return op.emitOpError("requires bias operand to have rank exactly one");
|
||||||
|
|
||||||
RankedTensorType value_ty =
|
RankedTensorType value_ty = op.value().getType().dyn_cast<RankedTensorType>();
|
||||||
op.value()->getType().dyn_cast<RankedTensorType>();
|
RankedTensorType bias_ty = op.bias().getType().dyn_cast<RankedTensorType>();
|
||||||
RankedTensorType bias_ty = op.bias()->getType().dyn_cast<RankedTensorType>();
|
|
||||||
if (!bias_ty || !value_ty) return success();
|
if (!bias_ty || !value_ty) return success();
|
||||||
|
|
||||||
// TODO(hinsu): Leverage tensor_format.h utility in TensorFlow to compute
|
// TODO(hinsu): Leverage tensor_format.h utility in TensorFlow to compute
|
||||||
@ -552,7 +551,7 @@ static LogicalResult Verify(ConcatOffsetOp op) {
|
|||||||
<< "requires sizes of shapes and offsets to be the same, got sizes "
|
<< "requires sizes of shapes and offsets to be the same, got sizes "
|
||||||
<< op.shape().size() << " and " << op.offset().size();
|
<< op.shape().size() << " and " << op.offset().size();
|
||||||
|
|
||||||
auto ranked_dim = op.concat_dim()->getType().dyn_cast<RankedTensorType>();
|
auto ranked_dim = op.concat_dim().getType().dyn_cast<RankedTensorType>();
|
||||||
if (ranked_dim && ranked_dim.getRank() != 0)
|
if (ranked_dim && ranked_dim.getRank() != 0)
|
||||||
return op.emitOpError()
|
return op.emitOpError()
|
||||||
<< "requires concat_dim to be a scalar, got tensor of rank "
|
<< "requires concat_dim to be a scalar, got tensor of rank "
|
||||||
@ -565,11 +564,11 @@ static LogicalResult Verify(ConcatOffsetOp op) {
|
|||||||
Value offset = std::get<1>(shape_offset_idx.value());
|
Value offset = std::get<1>(shape_offset_idx.value());
|
||||||
const size_t idx = shape_offset_idx.index();
|
const size_t idx = shape_offset_idx.index();
|
||||||
|
|
||||||
if (failed(verifyCompatibleShape(shape->getType(), offset->getType())))
|
if (failed(verifyCompatibleShape(shape.getType(), offset.getType())))
|
||||||
return op.emitOpError() << "requires operand and result " << idx
|
return op.emitOpError() << "requires operand and result " << idx
|
||||||
<< " to have compatible shapes";
|
<< " to have compatible shapes";
|
||||||
|
|
||||||
auto ranked_shape = shape->getType().dyn_cast<RankedTensorType>();
|
auto ranked_shape = shape.getType().dyn_cast<RankedTensorType>();
|
||||||
if (!ranked_shape) continue;
|
if (!ranked_shape) continue;
|
||||||
|
|
||||||
if (ranked_shape.getRank() != 1)
|
if (ranked_shape.getRank() != 1)
|
||||||
@ -786,7 +785,7 @@ static LogicalResult Verify(OpT op) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int64_t input_channels = -1;
|
int64_t input_channels = -1;
|
||||||
if (auto ty = op.input()->getType().template dyn_cast<RankedTensorType>()) {
|
if (auto ty = op.input().getType().template dyn_cast<RankedTensorType>()) {
|
||||||
std::string data_format = op.data_format().str();
|
std::string data_format = op.data_format().str();
|
||||||
tensorflow::TensorFormat format;
|
tensorflow::TensorFormat format;
|
||||||
auto is_valid = FormatFromString(data_format, &format);
|
auto is_valid = FormatFromString(data_format, &format);
|
||||||
@ -796,7 +795,7 @@ static LogicalResult Verify(OpT op) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int64_t filter_channels = -1;
|
int64_t filter_channels = -1;
|
||||||
if (auto ty = op.filter()->getType().template dyn_cast<RankedTensorType>()) {
|
if (auto ty = op.filter().getType().template dyn_cast<RankedTensorType>()) {
|
||||||
int idx = tensorflow::GetFilterTensorInputChannelsDimIndex(
|
int idx = tensorflow::GetFilterTensorInputChannelsDimIndex(
|
||||||
num_dims, tensorflow::FORMAT_HWIO);
|
num_dims, tensorflow::FORMAT_HWIO);
|
||||||
filter_channels = ty.getDimSize(idx);
|
filter_channels = ty.getDimSize(idx);
|
||||||
@ -876,8 +875,8 @@ static LogicalResult Verify(DynamicStitchOp op) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Value data = std::get<1>(it);
|
Value data = std::get<1>(it);
|
||||||
RankedTensorType index_ty = index->getType().dyn_cast<RankedTensorType>();
|
RankedTensorType index_ty = index.getType().dyn_cast<RankedTensorType>();
|
||||||
RankedTensorType data_ty = data->getType().dyn_cast<RankedTensorType>();
|
RankedTensorType data_ty = data.getType().dyn_cast<RankedTensorType>();
|
||||||
if (!index_ty || !data_ty) continue;
|
if (!index_ty || !data_ty) continue;
|
||||||
|
|
||||||
int64_t index_rank = index_ty.getRank();
|
int64_t index_rank = index_ty.getRank();
|
||||||
@ -993,10 +992,10 @@ void EqualOp::build(Builder *builder, OperationState &result, Value x, Value y,
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
Type InferExpandDimsOpType(Value input, Value dim) {
|
Type InferExpandDimsOpType(Value input, Value dim) {
|
||||||
Type element_ty = input->getType().cast<TensorType>().getElementType();
|
Type element_ty = input.getType().cast<TensorType>().getElementType();
|
||||||
auto unranked_ty = UnrankedTensorType::get(element_ty);
|
auto unranked_ty = UnrankedTensorType::get(element_ty);
|
||||||
|
|
||||||
auto input_ty = input->getType().dyn_cast<RankedTensorType>();
|
auto input_ty = input.getType().dyn_cast<RankedTensorType>();
|
||||||
if (!input_ty) return unranked_ty;
|
if (!input_ty) return unranked_ty;
|
||||||
|
|
||||||
DenseIntElementsAttr dim_attr;
|
DenseIntElementsAttr dim_attr;
|
||||||
@ -1076,14 +1075,14 @@ static LogicalResult Verify(FakeQuantWithMinMaxVarsPerChannelOp op) {
|
|||||||
|
|
||||||
Value inputs = op.inputs();
|
Value inputs = op.inputs();
|
||||||
if (!HasRankAtLeast(inputs, 1) ||
|
if (!HasRankAtLeast(inputs, 1) ||
|
||||||
inputs->getType().isa<UnrankedTensorType>()) {
|
inputs.getType().isa<UnrankedTensorType>()) {
|
||||||
return op.emitError("requires inputs to be at least 1d float tensor");
|
return op.emitError("requires inputs to be at least 1d float tensor");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto inputsType = inputs->getType().cast<ShapedType>();
|
auto inputsType = inputs.getType().cast<ShapedType>();
|
||||||
int depth = inputsType.getDimSize(inputsType.getRank() - 1);
|
int depth = inputsType.getDimSize(inputsType.getRank() - 1);
|
||||||
if (op.min()->getType().cast<ShapedType>().getDimSize(0) != depth ||
|
if (op.min().getType().cast<ShapedType>().getDimSize(0) != depth ||
|
||||||
op.max()->getType().cast<ShapedType>().getDimSize(0) != depth) {
|
op.max().getType().cast<ShapedType>().getDimSize(0) != depth) {
|
||||||
return op.emitOpError(
|
return op.emitOpError(
|
||||||
"requires min and max to have same size as last dimension of inputs");
|
"requires min and max to have same size as last dimension of inputs");
|
||||||
}
|
}
|
||||||
@ -1139,7 +1138,7 @@ static LogicalResult Verify(FusedBatchNormOp op) {
|
|||||||
|
|
||||||
static LogicalResult Verify(GatherV2Op op) {
|
static LogicalResult Verify(GatherV2Op op) {
|
||||||
int64_t batch_dims = op.batch_dims().getSExtValue();
|
int64_t batch_dims = op.batch_dims().getSExtValue();
|
||||||
if (auto ty = op.indices()->getType().dyn_cast<RankedTensorType>()) {
|
if (auto ty = op.indices().getType().dyn_cast<RankedTensorType>()) {
|
||||||
int64_t rank = ty.getRank();
|
int64_t rank = ty.getRank();
|
||||||
if (batch_dims > rank || batch_dims < -rank)
|
if (batch_dims > rank || batch_dims < -rank)
|
||||||
return op.emitOpError()
|
return op.emitOpError()
|
||||||
@ -1154,7 +1153,7 @@ static LogicalResult Verify(GatherV2Op op) {
|
|||||||
DenseIntElementsAttr axis_attr;
|
DenseIntElementsAttr axis_attr;
|
||||||
if (matchPattern(op.axis(), m_Constant(&axis_attr))) {
|
if (matchPattern(op.axis(), m_Constant(&axis_attr))) {
|
||||||
int64_t axis = (*axis_attr.begin()).getSExtValue();
|
int64_t axis = (*axis_attr.begin()).getSExtValue();
|
||||||
if (auto ty = op.params()->getType().dyn_cast<RankedTensorType>()) {
|
if (auto ty = op.params().getType().dyn_cast<RankedTensorType>()) {
|
||||||
int64_t rank = ty.getRank();
|
int64_t rank = ty.getRank();
|
||||||
if (axis >= rank || axis < -rank)
|
if (axis >= rank || axis < -rank)
|
||||||
return op.emitOpError() << "axis (" << axis << ") must be in range ["
|
return op.emitOpError() << "axis (" << axis << ") must be in range ["
|
||||||
@ -1197,7 +1196,7 @@ static LogicalResult Verify(IfOp op) {
|
|||||||
" inputs");
|
" inputs");
|
||||||
|
|
||||||
for (unsigned i = 0; i < expectedNumInputs; ++i) {
|
for (unsigned i = 0; i < expectedNumInputs; ++i) {
|
||||||
auto operandType = op.getOperand(i + 1)->getType().cast<TensorType>();
|
auto operandType = op.getOperand(i + 1).getType().cast<TensorType>();
|
||||||
auto thenInputType = thenFuncType.getInput(i).cast<TensorType>();
|
auto thenInputType = thenFuncType.getInput(i).cast<TensorType>();
|
||||||
if (!AreCastCompatible(operandType, thenInputType))
|
if (!AreCastCompatible(operandType, thenInputType))
|
||||||
return op.emitError(
|
return op.emitError(
|
||||||
@ -1228,7 +1227,7 @@ static LogicalResult Verify(IfOp op) {
|
|||||||
" results");
|
" results");
|
||||||
|
|
||||||
for (unsigned i = 0; i < expectedNumResults; ++i) {
|
for (unsigned i = 0; i < expectedNumResults; ++i) {
|
||||||
auto resultType = op.getResult(i)->getType().cast<TensorType>();
|
auto resultType = op.getResult(i).getType().cast<TensorType>();
|
||||||
auto thenResultType = thenFuncType.getResult(i).cast<TensorType>();
|
auto thenResultType = thenFuncType.getResult(i).cast<TensorType>();
|
||||||
if (!AreCastCompatible(thenResultType, resultType))
|
if (!AreCastCompatible(thenResultType, resultType))
|
||||||
return op.emitError(
|
return op.emitError(
|
||||||
@ -1364,7 +1363,7 @@ void NotEqualOp::build(Builder *builder, OperationState &result, Value x,
|
|||||||
static LogicalResult Verify(OneHotOp op) {
|
static LogicalResult Verify(OneHotOp op) {
|
||||||
int64_t axis = op.axis().getSExtValue();
|
int64_t axis = op.axis().getSExtValue();
|
||||||
|
|
||||||
auto indices_ty = op.indices()->getType().dyn_cast<RankedTensorType>();
|
auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
|
||||||
if (indices_ty &&
|
if (indices_ty &&
|
||||||
!(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) {
|
!(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) {
|
||||||
return op.emitOpError()
|
return op.emitOpError()
|
||||||
@ -1403,11 +1402,11 @@ static LogicalResult Verify(OneHotOp op) {
|
|||||||
static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value,
|
static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value,
|
||||||
Value off_value, IntegerAttr axis) {
|
Value off_value, IntegerAttr axis) {
|
||||||
int64_t axis_val = axis.getInt();
|
int64_t axis_val = axis.getInt();
|
||||||
Type element_ty = on_value->getType().cast<TensorType>().getElementType();
|
Type element_ty = on_value.getType().cast<TensorType>().getElementType();
|
||||||
auto unranked_ty = UnrankedTensorType::get(element_ty);
|
auto unranked_ty = UnrankedTensorType::get(element_ty);
|
||||||
if (axis_val < -1) return unranked_ty;
|
if (axis_val < -1) return unranked_ty;
|
||||||
|
|
||||||
auto indices_ty = indices->getType().dyn_cast<RankedTensorType>();
|
auto indices_ty = indices.getType().dyn_cast<RankedTensorType>();
|
||||||
if (!indices_ty) return unranked_ty;
|
if (!indices_ty) return unranked_ty;
|
||||||
|
|
||||||
auto shape = llvm::to_vector<2>(indices_ty.getShape());
|
auto shape = llvm::to_vector<2>(indices_ty.getShape());
|
||||||
@ -1446,7 +1445,7 @@ static LogicalResult Verify(PackOp op) {
|
|||||||
|
|
||||||
int64_t inputs_rank = -1;
|
int64_t inputs_rank = -1;
|
||||||
for (Value value : values) {
|
for (Value value : values) {
|
||||||
if (auto ty = value->getType().dyn_cast<RankedTensorType>()) {
|
if (auto ty = value.getType().dyn_cast<RankedTensorType>()) {
|
||||||
// Exit early as input types are verified to be compatible so all ranked
|
// Exit early as input types are verified to be compatible so all ranked
|
||||||
// tensors have the same rank.
|
// tensors have the same rank.
|
||||||
inputs_rank = ty.getRank();
|
inputs_rank = ty.getRank();
|
||||||
@ -1548,8 +1547,8 @@ static LogicalResult Verify(RandomUniformOp op) {
|
|||||||
|
|
||||||
void RangeOp::build(Builder *builder, OperationState &result, Value start,
|
void RangeOp::build(Builder *builder, OperationState &result, Value start,
|
||||||
Value limit, Value delta) {
|
Value limit, Value delta) {
|
||||||
assert(start->getType() == limit->getType());
|
assert(start.getType() == limit.getType());
|
||||||
assert(start->getType() == delta->getType());
|
assert(start.getType() == delta.getType());
|
||||||
DenseIntElementsAttr start_val;
|
DenseIntElementsAttr start_val;
|
||||||
DenseIntElementsAttr limit_val;
|
DenseIntElementsAttr limit_val;
|
||||||
DenseIntElementsAttr delta_val;
|
DenseIntElementsAttr delta_val;
|
||||||
@ -1563,13 +1562,13 @@ void RangeOp::build(Builder *builder, OperationState &result, Value start,
|
|||||||
builder, result,
|
builder, result,
|
||||||
RankedTensorType::get(
|
RankedTensorType::get(
|
||||||
size.getSExtValue(),
|
size.getSExtValue(),
|
||||||
start->getType().cast<TensorType>().getElementType()),
|
start.getType().cast<TensorType>().getElementType()),
|
||||||
start, limit, delta);
|
start, limit, delta);
|
||||||
}
|
}
|
||||||
return RangeOp::build(
|
return RangeOp::build(
|
||||||
builder, result,
|
builder, result,
|
||||||
RankedTensorType::get(
|
RankedTensorType::get(
|
||||||
{-1}, start->getType().cast<TensorType>().getElementType()),
|
{-1}, start.getType().cast<TensorType>().getElementType()),
|
||||||
start, limit, delta);
|
start, limit, delta);
|
||||||
}
|
}
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -1598,17 +1597,17 @@ void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||||||
// TODO(b/128020684): Verify the rank of the output and change to use
|
// TODO(b/128020684): Verify the rank of the output and change to use
|
||||||
// m_Constant.
|
// m_Constant.
|
||||||
static LogicalResult Verify(ReshapeOp op) {
|
static LogicalResult Verify(ReshapeOp op) {
|
||||||
auto shapeType = op.shape()->getType().cast<TensorType>();
|
auto shapeType = op.shape().getType().cast<TensorType>();
|
||||||
if (!shapeType.hasRank()) return success();
|
if (!shapeType.hasRank()) return success();
|
||||||
if (shapeType.getRank() != 1)
|
if (shapeType.getRank() != 1)
|
||||||
return op.emitOpError("shape must be 1D tensor");
|
return op.emitOpError("shape must be 1D tensor");
|
||||||
auto rankByShape = shapeType.getShape()[0];
|
auto rankByShape = shapeType.getShape()[0];
|
||||||
auto typeOfTensor = op.tensor()->getType().cast<TensorType>();
|
auto typeOfTensor = op.tensor().getType().cast<TensorType>();
|
||||||
// No compile time verification for unknown sized shape.
|
// No compile time verification for unknown sized shape.
|
||||||
if (rankByShape == -1 || !typeOfTensor.hasStaticShape()) return success();
|
if (rankByShape == -1 || !typeOfTensor.hasStaticShape()) return success();
|
||||||
// Check values if constant shape. No compiling time verification for
|
// Check values if constant shape. No compiling time verification for
|
||||||
// non-constant shape.
|
// non-constant shape.
|
||||||
auto *shapeOp = op.shape()->getDefiningOp();
|
auto *shapeOp = op.shape().getDefiningOp();
|
||||||
if (!shapeOp) return success();
|
if (!shapeOp) return success();
|
||||||
Attribute shapeCst;
|
Attribute shapeCst;
|
||||||
if (auto shapeStdOp = dyn_cast<ConstantOp>(shapeOp)) {
|
if (auto shapeStdOp = dyn_cast<ConstantOp>(shapeOp)) {
|
||||||
@ -1662,7 +1661,7 @@ static LogicalResult Verify(ReshapeOp op) {
|
|||||||
|
|
||||||
void ReshapeOp::build(Builder *builder, OperationState &result, Value tensor,
|
void ReshapeOp::build(Builder *builder, OperationState &result, Value tensor,
|
||||||
Value shape) {
|
Value shape) {
|
||||||
auto ttype = tensor->getType().cast<ShapedType>();
|
auto ttype = tensor.getType().cast<ShapedType>();
|
||||||
auto etype = ttype.getElementType();
|
auto etype = ttype.getElementType();
|
||||||
|
|
||||||
auto unranked = [builder, etype, &result, shape, tensor]() {
|
auto unranked = [builder, etype, &result, shape, tensor]() {
|
||||||
@ -1723,14 +1722,14 @@ void ReshapeOp::build(Builder *builder, OperationState &result, Value tensor,
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static Type InferSelectV2OpType(Value condition, Value e, Value t) {
|
static Type InferSelectV2OpType(Value condition, Value e, Value t) {
|
||||||
Type element_ty = e->getType().cast<TensorType>().getElementType();
|
Type element_ty = e.getType().cast<TensorType>().getElementType();
|
||||||
auto unranked_ty = UnrankedTensorType::get(element_ty);
|
auto unranked_ty = UnrankedTensorType::get(element_ty);
|
||||||
|
|
||||||
Type broadcasted_ty =
|
Type broadcasted_ty =
|
||||||
OpTrait::util::getBroadcastedType(e->getType(), t->getType());
|
OpTrait::util::getBroadcastedType(e.getType(), t.getType());
|
||||||
if (!broadcasted_ty) return unranked_ty;
|
if (!broadcasted_ty) return unranked_ty;
|
||||||
|
|
||||||
auto cond_ranked_ty = condition->getType().dyn_cast<RankedTensorType>();
|
auto cond_ranked_ty = condition.getType().dyn_cast<RankedTensorType>();
|
||||||
auto broadcasted_ranked_ty = broadcasted_ty.dyn_cast<RankedTensorType>();
|
auto broadcasted_ranked_ty = broadcasted_ty.dyn_cast<RankedTensorType>();
|
||||||
if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty;
|
if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty;
|
||||||
|
|
||||||
@ -1791,7 +1790,7 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
|
|||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
static LogicalResult Verify(ShapeOp op) {
|
static LogicalResult Verify(ShapeOp op) {
|
||||||
return VerifyShapeOperandAndResult(op, op.input()->getType(), op.getType());
|
return VerifyShapeOperandAndResult(op, op.input().getType(), op.getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts shape of the given type to attribute if it is of ranked tensor type.
|
// Converts shape of the given type to attribute if it is of ranked tensor type.
|
||||||
@ -1816,12 +1815,12 @@ static Attribute ConvertShapeToAttr(Type input_ty, int out_width) {
|
|||||||
OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) {
|
||||||
int width =
|
int width =
|
||||||
getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
|
getType().cast<ShapedType>().getElementType().getIntOrFloatBitWidth();
|
||||||
return ConvertShapeToAttr(getOperand()->getType(), width);
|
return ConvertShapeToAttr(getOperand().getType(), width);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ShapeOp::build(Builder *builder, OperationState &result, Value input,
|
void ShapeOp::build(Builder *builder, OperationState &result, Value input,
|
||||||
BoolAttr use32Bit) {
|
BoolAttr use32Bit) {
|
||||||
auto rankedTensorType = input->getType().dyn_cast<RankedTensorType>();
|
auto rankedTensorType = input.getType().dyn_cast<RankedTensorType>();
|
||||||
int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1;
|
int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1;
|
||||||
auto out_type = use32Bit.getValue() ? builder->getIntegerType(32)
|
auto out_type = use32Bit.getValue() ? builder->getIntegerType(32)
|
||||||
: builder->getIntegerType(64);
|
: builder->getIntegerType(64);
|
||||||
@ -1846,7 +1845,7 @@ static LogicalResult Verify(ShapeNOp op) {
|
|||||||
|
|
||||||
for (auto i : llvm::seq<uint64_t>(0, num_tensors)) {
|
for (auto i : llvm::seq<uint64_t>(0, num_tensors)) {
|
||||||
auto verification = VerifyShapeOperandAndResult(
|
auto verification = VerifyShapeOperandAndResult(
|
||||||
op, op.getOperand(i)->getType(), op.getResult(i)->getType(), i);
|
op, op.getOperand(i).getType(), op.getResult(i).getType(), i);
|
||||||
if (failed(verification)) return verification;
|
if (failed(verification)) return verification;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1919,7 +1918,7 @@ static LogicalResult Verify(SliceOp op) {
|
|||||||
" same number of elements";
|
" same number of elements";
|
||||||
}
|
}
|
||||||
|
|
||||||
auto input_ty = op.input()->getType().dyn_cast<RankedTensorType>();
|
auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
|
||||||
if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) {
|
if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) {
|
||||||
return op.emitOpError() << "requires number of elements in begin and size"
|
return op.emitOpError() << "requires number of elements in begin and size"
|
||||||
"are equal to input rank";
|
"are equal to input rank";
|
||||||
@ -1973,7 +1972,7 @@ static LogicalResult Verify(SoftmaxOp op) {
|
|||||||
//
|
//
|
||||||
static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) {
|
static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) {
|
||||||
auto broadcasted_ty = OpTrait::util::getBroadcastedType(
|
auto broadcasted_ty = OpTrait::util::getBroadcastedType(
|
||||||
op.features()->getType(), op.labels()->getType())
|
op.features().getType(), op.labels().getType())
|
||||||
.dyn_cast_or_null<ShapedType>();
|
.dyn_cast_or_null<ShapedType>();
|
||||||
if (!broadcasted_ty ||
|
if (!broadcasted_ty ||
|
||||||
(broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2))
|
(broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2))
|
||||||
@ -1994,8 +1993,8 @@ static LogicalResult Verify(SparseSoftmaxCrossEntropyWithLogitsOp op) {
|
|||||||
if (!IsOfRankOrUnranked(op.labels(), 1)) {
|
if (!IsOfRankOrUnranked(op.labels(), 1)) {
|
||||||
return op.emitOpError("requires labels operand of rank one");
|
return op.emitOpError("requires labels operand of rank one");
|
||||||
}
|
}
|
||||||
auto features_ty = op.features()->getType().dyn_cast<RankedTensorType>();
|
auto features_ty = op.features().getType().dyn_cast<RankedTensorType>();
|
||||||
auto labels_ty = op.labels()->getType().dyn_cast<RankedTensorType>();
|
auto labels_ty = op.labels().getType().dyn_cast<RankedTensorType>();
|
||||||
if (features_ty && labels_ty) {
|
if (features_ty && labels_ty) {
|
||||||
int64_t features_batches = features_ty.getDimSize(0);
|
int64_t features_batches = features_ty.getDimSize(0);
|
||||||
int64_t labels_batches = labels_ty.getDimSize(0);
|
int64_t labels_batches = labels_ty.getDimSize(0);
|
||||||
@ -2020,7 +2019,7 @@ LogicalResult VerifySplitInputAndSplitDim(Op op, Optional<int64_t> *dim_index) {
|
|||||||
*dim_index = llvm::None;
|
*dim_index = llvm::None;
|
||||||
|
|
||||||
Value split_dim = op.split_dim();
|
Value split_dim = op.split_dim();
|
||||||
if (auto split_dim_type = split_dim->getType().dyn_cast<RankedTensorType>())
|
if (auto split_dim_type = split_dim.getType().dyn_cast<RankedTensorType>())
|
||||||
if (split_dim_type.getRank() != 0)
|
if (split_dim_type.getRank() != 0)
|
||||||
return op.emitOpError(
|
return op.emitOpError(
|
||||||
"split dimension should be an integer scalar tensor");
|
"split dimension should be an integer scalar tensor");
|
||||||
@ -2028,7 +2027,7 @@ LogicalResult VerifySplitInputAndSplitDim(Op op, Optional<int64_t> *dim_index) {
|
|||||||
// We can perform further verification if the input tensor to be split has
|
// We can perform further verification if the input tensor to be split has
|
||||||
// known rank and the split dimension tensor is a constant.
|
// known rank and the split dimension tensor is a constant.
|
||||||
|
|
||||||
auto input_type = op.value()->getType().template dyn_cast<RankedTensorType>();
|
auto input_type = op.value().getType().template dyn_cast<RankedTensorType>();
|
||||||
if (!input_type) return success();
|
if (!input_type) return success();
|
||||||
|
|
||||||
int64_t input_rank = input_type.getRank();
|
int64_t input_rank = input_type.getRank();
|
||||||
@ -2057,7 +2056,7 @@ static LogicalResult Verify(SplitOp op) {
|
|||||||
if (!dim_index) return success();
|
if (!dim_index) return success();
|
||||||
|
|
||||||
int64_t input_dim_size =
|
int64_t input_dim_size =
|
||||||
op.value()->getType().cast<RankedTensorType>().getDimSize(*dim_index);
|
op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
|
||||||
if (input_dim_size == ShapedType::kDynamicSize) return success();
|
if (input_dim_size == ShapedType::kDynamicSize) return success();
|
||||||
|
|
||||||
if (input_dim_size % op.getNumResults() != 0)
|
if (input_dim_size % op.getNumResults() != 0)
|
||||||
@ -2073,7 +2072,7 @@ static LogicalResult Verify(SplitOp op) {
|
|||||||
|
|
||||||
static LogicalResult Verify(SplitVOp op) {
|
static LogicalResult Verify(SplitVOp op) {
|
||||||
auto split_sizes_type =
|
auto split_sizes_type =
|
||||||
op.size_splits()->getType().dyn_cast<RankedTensorType>();
|
op.size_splits().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!split_sizes_type) return success();
|
if (!split_sizes_type) return success();
|
||||||
|
|
||||||
if (split_sizes_type.getRank() != 1 ||
|
if (split_sizes_type.getRank() != 1 ||
|
||||||
@ -2086,7 +2085,7 @@ static LogicalResult Verify(SplitVOp op) {
|
|||||||
if (!dim_index) return success();
|
if (!dim_index) return success();
|
||||||
|
|
||||||
int64_t input_dim_size =
|
int64_t input_dim_size =
|
||||||
op.value()->getType().cast<RankedTensorType>().getDimSize(*dim_index);
|
op.value().getType().cast<RankedTensorType>().getDimSize(*dim_index);
|
||||||
if (input_dim_size == ShapedType::kDynamicSize) return success();
|
if (input_dim_size == ShapedType::kDynamicSize) return success();
|
||||||
|
|
||||||
// If split sizes come from a constant, they must sum to the dimension size
|
// If split sizes come from a constant, they must sum to the dimension size
|
||||||
@ -2178,7 +2177,7 @@ static LogicalResult VerifyStridedSliceBase(OpTy op) {
|
|||||||
int64_t expected_size = -1;
|
int64_t expected_size = -1;
|
||||||
|
|
||||||
for (Value val : {op.begin(), op.end(), op.strides()}) {
|
for (Value val : {op.begin(), op.end(), op.strides()}) {
|
||||||
auto operand_ty = val->getType().dyn_cast<ShapedType>();
|
auto operand_ty = val.getType().dyn_cast<ShapedType>();
|
||||||
if (!operand_ty || !operand_ty.hasStaticShape()) {
|
if (!operand_ty || !operand_ty.hasStaticShape()) {
|
||||||
// TensorFlow constant ops may have non-static shape because the shape is
|
// TensorFlow constant ops may have non-static shape because the shape is
|
||||||
// not propagated during constant folding. If the defining op for this
|
// not propagated during constant folding. If the defining op for this
|
||||||
@ -2336,7 +2335,7 @@ bool StridedSliceOp::GetSlicedBoundRanges(
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static LogicalResult Verify(StridedSliceGradOp op) {
|
static LogicalResult Verify(StridedSliceGradOp op) {
|
||||||
auto shape_type = op.shape()->getType().dyn_cast<RankedTensorType>();
|
auto shape_type = op.shape().getType().dyn_cast<RankedTensorType>();
|
||||||
if (shape_type && shape_type.getRank() != 1)
|
if (shape_type && shape_type.getRank() != 1)
|
||||||
return op.emitOpError("'shape' operand must be 1D tensor, but got ")
|
return op.emitOpError("'shape' operand must be 1D tensor, but got ")
|
||||||
<< shape_type.getRank() << "D tensor";
|
<< shape_type.getRank() << "D tensor";
|
||||||
@ -2433,8 +2432,8 @@ static LogicalResult Verify(TensorScatterUpdateOp op) {
|
|||||||
return op.emitOpError(
|
return op.emitOpError(
|
||||||
"requires updates operand to have at least 1 dimension");
|
"requires updates operand to have at least 1 dimension");
|
||||||
|
|
||||||
auto tensor_ty = op.tensor()->getType().dyn_cast<RankedTensorType>();
|
auto tensor_ty = op.tensor().getType().dyn_cast<RankedTensorType>();
|
||||||
auto indices_ty = op.indices()->getType().dyn_cast<RankedTensorType>();
|
auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!tensor_ty || !indices_ty) return success();
|
if (!tensor_ty || !indices_ty) return success();
|
||||||
|
|
||||||
int64_t num_index_dims = indices_ty.getShape().back();
|
int64_t num_index_dims = indices_ty.getShape().back();
|
||||||
@ -2478,7 +2477,7 @@ static LogicalResult Verify(TransposeOp op) {
|
|||||||
// TODO(jpienaar): perm could be optional too.
|
// TODO(jpienaar): perm could be optional too.
|
||||||
void TransposeOp::build(Builder *builder, OperationState &result, Value x,
|
void TransposeOp::build(Builder *builder, OperationState &result, Value x,
|
||||||
Value perm) {
|
Value perm) {
|
||||||
auto x_type = x->getType().cast<TensorType>();
|
auto x_type = x.getType().cast<TensorType>();
|
||||||
// If value is unranked, then so is results.
|
// If value is unranked, then so is results.
|
||||||
if (!x_type.hasRank())
|
if (!x_type.hasRank())
|
||||||
return TransposeOp::build(builder, result,
|
return TransposeOp::build(builder, result,
|
||||||
@ -2509,7 +2508,7 @@ void TransposeOp::build(Builder *builder, OperationState &result, Value x,
|
|||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
|
||||||
auto const_perm = dyn_cast_or_null<TF::ConstOp>(perm()->getDefiningOp());
|
auto const_perm = dyn_cast_or_null<TF::ConstOp>(perm().getDefiningOp());
|
||||||
|
|
||||||
if (!const_perm) {
|
if (!const_perm) {
|
||||||
return {};
|
return {};
|
||||||
@ -2541,7 +2540,7 @@ void TruncateDivOp::getCanonicalizationPatterns(
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static LogicalResult Verify(UnpackOp op) {
|
static LogicalResult Verify(UnpackOp op) {
|
||||||
auto value_type = op.value()->getType().dyn_cast<RankedTensorType>();
|
auto value_type = op.value().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!value_type) return success();
|
if (!value_type) return success();
|
||||||
|
|
||||||
int64_t value_rank = value_type.getRank();
|
int64_t value_rank = value_type.getRank();
|
||||||
@ -2569,9 +2568,9 @@ static LogicalResult VerifyUnsortedSegmentReduction(Op op) {
|
|||||||
if (!HasRankAtMost(op.num_segments(), 0))
|
if (!HasRankAtMost(op.num_segments(), 0))
|
||||||
return op.emitOpError("number of segments should be a 0-D tensor");
|
return op.emitOpError("number of segments should be a 0-D tensor");
|
||||||
|
|
||||||
auto data_type = op.data()->getType().template dyn_cast<RankedTensorType>();
|
auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
|
||||||
auto segment_ids_type =
|
auto segment_ids_type =
|
||||||
op.segment_ids()->getType().template dyn_cast<RankedTensorType>();
|
op.segment_ids().getType().template dyn_cast<RankedTensorType>();
|
||||||
if (data_type && segment_ids_type) {
|
if (data_type && segment_ids_type) {
|
||||||
if (data_type.getRank() < segment_ids_type.getRank())
|
if (data_type.getRank() < segment_ids_type.getRank())
|
||||||
return op.emitOpError(
|
return op.emitOpError(
|
||||||
@ -2609,7 +2608,7 @@ static LogicalResult VerifyUnsortedSegmentReduction(Op op) {
|
|||||||
|
|
||||||
static LogicalResult Verify(VariableShapeOp op) {
|
static LogicalResult Verify(VariableShapeOp op) {
|
||||||
auto resource_operand_type = op.input()
|
auto resource_operand_type = op.input()
|
||||||
->getType()
|
.getType()
|
||||||
.cast<TensorType>()
|
.cast<TensorType>()
|
||||||
.getElementType()
|
.getElementType()
|
||||||
.cast<TF::ResourceType>();
|
.cast<TF::ResourceType>();
|
||||||
@ -2763,7 +2762,7 @@ struct TFInlinerInterface : public DialectInlinerInterface {
|
|||||||
Operation *materializeCallConversion(OpBuilder &builder, Value input,
|
Operation *materializeCallConversion(OpBuilder &builder, Value input,
|
||||||
Type result_type,
|
Type result_type,
|
||||||
Location conversion_loc) const final {
|
Location conversion_loc) const final {
|
||||||
if (!result_type.isa<TensorType>() || !input->getType().isa<TensorType>())
|
if (!result_type.isa<TensorType>() || !input.getType().isa<TensorType>())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
return builder.create<TF::CastOp>(conversion_loc, result_type, input,
|
return builder.create<TF::CastOp>(conversion_loc, result_type, input,
|
||||||
/*truncate=*/builder.getBoolAttr(false));
|
/*truncate=*/builder.getBoolAttr(false));
|
||||||
|
@ -57,7 +57,7 @@ class TF_TensorListInitOp<string mnemonic> : TF_Op<mnemonic, [NoSideEffect]> {
|
|||||||
// Returns data type of the result handle. Returned type contains type of
|
// Returns data type of the result handle. Returned type contains type of
|
||||||
// the TensorList element as a subtype.
|
// the TensorList element as a subtype.
|
||||||
VariantType handle_dtype() {
|
VariantType handle_dtype() {
|
||||||
return getElementTypeOrSelf(handle()->getType()).cast<TF::VariantType>();
|
return getElementTypeOrSelf(handle().getType()).cast<TF::VariantType>();
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
@ -47,7 +47,7 @@ class OperandsSameAsResultsTypeOrRef
|
|||||||
LogicalResult shapeMatch = impl::verifySameOperandsAndResultShape(op);
|
LogicalResult shapeMatch = impl::verifySameOperandsAndResultShape(op);
|
||||||
if (failed(shapeMatch)) return shapeMatch;
|
if (failed(shapeMatch)) return shapeMatch;
|
||||||
|
|
||||||
auto type = getElementTypeOrSelf(op->getResult(0)->getType());
|
auto type = getElementTypeOrSelf(op->getResult(0).getType());
|
||||||
|
|
||||||
// Verify that the first result type is same as the rest of the results.
|
// Verify that the first result type is same as the rest of the results.
|
||||||
// We skip the comparison against itself.
|
// We skip the comparison against itself.
|
||||||
|
@ -23,7 +23,7 @@ def SingleResultAndOperandHaveSameElementType : Constraint<
|
|||||||
CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>;
|
CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>;
|
||||||
|
|
||||||
def SingleResultAndOperandHaveSameType : Constraint<
|
def SingleResultAndOperandHaveSameType : Constraint<
|
||||||
CPred<"$0->getType() == $1->getType()">>;
|
CPred<"$0.getType() == $1.getType()">>;
|
||||||
|
|
||||||
def IsRank2Tensor : Type<HasAnyRankOfPred<[2]>, "Rank 2 tensor">;
|
def IsRank2Tensor : Type<HasAnyRankOfPred<[2]>, "Rank 2 tensor">;
|
||||||
|
|
||||||
|
@ -70,9 +70,9 @@ StringRef GetDevice(Operation* op) {
|
|||||||
bool CanMergeIntoCluster(const Cluster& c, Operation* to_merge) {
|
bool CanMergeIntoCluster(const Cluster& c, Operation* to_merge) {
|
||||||
return llvm::all_of(to_merge->getOperands(), [&](Value operand) {
|
return llvm::all_of(to_merge->getOperands(), [&](Value operand) {
|
||||||
// Block arguments.
|
// Block arguments.
|
||||||
if (operand->isa<BlockArgument>()) return true;
|
if (operand.isa<BlockArgument>()) return true;
|
||||||
|
|
||||||
Operation* defining_op = operand->getDefiningOp();
|
Operation* defining_op = operand.getDefiningOp();
|
||||||
|
|
||||||
// Operand produced by other islands.
|
// Operand produced by other islands.
|
||||||
if (defining_op->getBlock() != c.ops.front()->getBlock()) return true;
|
if (defining_op->getBlock() != c.ops.front()->getBlock()) return true;
|
||||||
@ -100,7 +100,7 @@ void ReplaceLiveOutExternalUses(llvm::ArrayRef<Value> live_outs,
|
|||||||
Region* launch_op_region = &launch_op.body();
|
Region* launch_op_region = &launch_op.body();
|
||||||
for (const auto& p : llvm::zip(live_outs, launch_op.getResults())) {
|
for (const auto& p : llvm::zip(live_outs, launch_op.getResults())) {
|
||||||
Value from = std::get<0>(p);
|
Value from = std::get<0>(p);
|
||||||
for (auto& use : from->getUses()) {
|
for (auto& use : from.getUses()) {
|
||||||
if (launch_op_region->isAncestor(use.getOwner()->getParentRegion()))
|
if (launch_op_region->isAncestor(use.getOwner()->getParentRegion()))
|
||||||
continue;
|
continue;
|
||||||
use.set(std::get<1>(p));
|
use.set(std::get<1>(p));
|
||||||
@ -116,7 +116,7 @@ void GetLiveOuts(Region* region, llvm::SmallVectorImpl<Value>* live_outs) {
|
|||||||
for (Value v : op.getResults()) {
|
for (Value v : op.getResults()) {
|
||||||
// A value is live-out if any of its users are not inside value producer's
|
// A value is live-out if any of its users are not inside value producer's
|
||||||
// region.
|
// region.
|
||||||
bool is_live_out = llvm::any_of(v->getUsers(), [&](Operation* user) {
|
bool is_live_out = llvm::any_of(v.getUsers(), [&](Operation* user) {
|
||||||
return !region->isAncestor(user->getParentRegion());
|
return !region->isAncestor(user->getParentRegion());
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -158,7 +158,7 @@ void BuildLaunchForCluster(const Cluster& c, OpBuilder* builder) {
|
|||||||
llvm::SmallVector<Type, 4> live_out_types;
|
llvm::SmallVector<Type, 4> live_out_types;
|
||||||
live_out_types.reserve(live_outs.size());
|
live_out_types.reserve(live_outs.size());
|
||||||
for (Value v : live_outs) {
|
for (Value v : live_outs) {
|
||||||
live_out_types.emplace_back(v->getType());
|
live_out_types.emplace_back(v.getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
tf_device::LaunchOp launch_op = builder->create<tf_device::LaunchOp>(
|
tf_device::LaunchOp launch_op = builder->create<tf_device::LaunchOp>(
|
||||||
|
@ -56,7 +56,7 @@ FuncOp BuildFunction(StringRef device, llvm::ArrayRef<Value> live_ins,
|
|||||||
OpBuilder* builder) {
|
OpBuilder* builder) {
|
||||||
llvm::SmallVector<Type, 4> operand_types;
|
llvm::SmallVector<Type, 4> operand_types;
|
||||||
operand_types.reserve(live_ins.size());
|
operand_types.reserve(live_ins.size());
|
||||||
for (Value v : live_ins) operand_types.emplace_back(v->getType());
|
for (Value v : live_ins) operand_types.emplace_back(v.getType());
|
||||||
|
|
||||||
llvm::SmallVector<Type, 4> result_types(launch_op.getResultTypes());
|
llvm::SmallVector<Type, 4> result_types(launch_op.getResultTypes());
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ def CreateTFReadVariableOp: NativeCodeCall<
|
|||||||
"$_builder.create<TF::ReadVariableOp>("
|
"$_builder.create<TF::ReadVariableOp>("
|
||||||
" $0.getLoc(),"
|
" $0.getLoc(),"
|
||||||
" UnrankedTensorType::get("
|
" UnrankedTensorType::get("
|
||||||
" $1->getType().cast<TensorType>().getElementType()),"
|
" $1.getType().cast<TensorType>().getElementType()),"
|
||||||
" $2)"
|
" $2)"
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ llvm::Optional<IslandOp> GetOperandCandidateToMergeWith(IslandOp island) {
|
|||||||
|
|
||||||
// Check island control operands.
|
// Check island control operands.
|
||||||
for (Value input : island.controlInputs()) {
|
for (Value input : island.controlInputs()) {
|
||||||
Operation* def = input->getDefiningOp();
|
Operation* def = input.getDefiningOp();
|
||||||
DCHECK_EQ(def->getParentOp(), graph_op);
|
DCHECK_EQ(def->getParentOp(), graph_op);
|
||||||
if (!candidate || candidate->isBeforeInBlock(def)) candidate = def;
|
if (!candidate || candidate->isBeforeInBlock(def)) candidate = def;
|
||||||
}
|
}
|
||||||
@ -79,7 +79,7 @@ llvm::Optional<IslandOp> GetOperandCandidateToMergeWith(IslandOp island) {
|
|||||||
// Check island data operands.
|
// Check island data operands.
|
||||||
island.walk([graph_op, &candidate](Operation* op) {
|
island.walk([graph_op, &candidate](Operation* op) {
|
||||||
for (Value input : op->getOperands()) {
|
for (Value input : op->getOperands()) {
|
||||||
Operation* def = input->getDefiningOp();
|
Operation* def = input.getDefiningOp();
|
||||||
if (!def || def->getParentOp() != graph_op) continue;
|
if (!def || def->getParentOp() != graph_op) continue;
|
||||||
if (!candidate || candidate->isBeforeInBlock(def)) candidate = def;
|
if (!candidate || candidate->isBeforeInBlock(def)) candidate = def;
|
||||||
}
|
}
|
||||||
@ -99,7 +99,7 @@ llvm::Optional<IslandOp> GetResultCandidateToMergeWith(IslandOp island) {
|
|||||||
Operation* candidate = nullptr;
|
Operation* candidate = nullptr;
|
||||||
|
|
||||||
// Check island control results.
|
// Check island control results.
|
||||||
for (Operation* user : island.control()->getUsers()) {
|
for (Operation* user : island.control().getUsers()) {
|
||||||
DCHECK_EQ(user->getParentOp(), graph_op);
|
DCHECK_EQ(user->getParentOp(), graph_op);
|
||||||
if (!candidate || user->isBeforeInBlock(candidate)) candidate = user;
|
if (!candidate || user->isBeforeInBlock(candidate)) candidate = user;
|
||||||
}
|
}
|
||||||
@ -107,7 +107,7 @@ llvm::Optional<IslandOp> GetResultCandidateToMergeWith(IslandOp island) {
|
|||||||
// Check island data results.
|
// Check island data results.
|
||||||
Block& graph_body = llvm::cast<GraphOp>(graph_op).GetBody();
|
Block& graph_body = llvm::cast<GraphOp>(graph_op).GetBody();
|
||||||
for (Value result : island.outputs()) {
|
for (Value result : island.outputs()) {
|
||||||
for (Operation* user : result->getUsers()) {
|
for (Operation* user : result.getUsers()) {
|
||||||
Operation* def = graph_body.findAncestorOpInBlock(*user);
|
Operation* def = graph_body.findAncestorOpInBlock(*user);
|
||||||
DCHECK_NE(def, nullptr);
|
DCHECK_NE(def, nullptr);
|
||||||
if (!candidate || def->isBeforeInBlock(candidate)) candidate = def;
|
if (!candidate || def->isBeforeInBlock(candidate)) candidate = def;
|
||||||
@ -147,7 +147,7 @@ llvm::SmallVector<IslandResult, 8> GetNewIslandResultsAndForwardResults(
|
|||||||
bool result_captured = false;
|
bool result_captured = false;
|
||||||
Value inner_op_result = std::get<0>(ret_vals);
|
Value inner_op_result = std::get<0>(ret_vals);
|
||||||
Value island_result = std::get<1>(ret_vals);
|
Value island_result = std::get<1>(ret_vals);
|
||||||
for (auto& use : llvm::make_early_inc_range(island_result->getUses())) {
|
for (auto& use : llvm::make_early_inc_range(island_result.getUses())) {
|
||||||
if (child_body.findAncestorOpInBlock(*use.getOwner())) {
|
if (child_body.findAncestorOpInBlock(*use.getOwner())) {
|
||||||
// Forward result from inner op.
|
// Forward result from inner op.
|
||||||
use.set(inner_op_result);
|
use.set(inner_op_result);
|
||||||
@ -162,7 +162,7 @@ llvm::SmallVector<IslandResult, 8> GetNewIslandResultsAndForwardResults(
|
|||||||
llvm::zip(child.GetYield().getOperands(), child.outputs())) {
|
llvm::zip(child.GetYield().getOperands(), child.outputs())) {
|
||||||
Value inner_op_result = std::get<0>(ret_vals);
|
Value inner_op_result = std::get<0>(ret_vals);
|
||||||
Value island_result = std::get<1>(ret_vals);
|
Value island_result = std::get<1>(ret_vals);
|
||||||
if (!island_result->use_empty()) {
|
if (!island_result.use_empty()) {
|
||||||
results.emplace_back(inner_op_result, island_result);
|
results.emplace_back(inner_op_result, island_result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -178,7 +178,7 @@ IslandOp CreateNewIsland(IslandOp parent, IslandOp child,
|
|||||||
// Collect types from results.
|
// Collect types from results.
|
||||||
llvm::SmallVector<Type, 8> result_types;
|
llvm::SmallVector<Type, 8> result_types;
|
||||||
for (const auto& result : results)
|
for (const auto& result : results)
|
||||||
result_types.push_back(result.inner_op_result->getType());
|
result_types.push_back(result.inner_op_result.getType());
|
||||||
|
|
||||||
// IslandOps always have a control result.
|
// IslandOps always have a control result.
|
||||||
result_types.push_back(ControlType::get(parent.getContext()));
|
result_types.push_back(ControlType::get(parent.getContext()));
|
||||||
@ -201,7 +201,7 @@ YieldOp CreateNewIslandYieldOp(IslandOp new_island,
|
|||||||
const auto& old_result = std::get<0>(ret_vals);
|
const auto& old_result = std::get<0>(ret_vals);
|
||||||
|
|
||||||
// Replace original island result with new island result.
|
// Replace original island result with new island result.
|
||||||
old_result.island_result->replaceAllUsesWith(std::get<1>(ret_vals));
|
old_result.island_result.replaceAllUsesWith(std::get<1>(ret_vals));
|
||||||
|
|
||||||
// Add associated inner op result to operands of the YieldOp.
|
// Add associated inner op result to operands of the YieldOp.
|
||||||
yield_operands.push_back(old_result.inner_op_result);
|
yield_operands.push_back(old_result.inner_op_result);
|
||||||
@ -249,8 +249,8 @@ void MergeIslands(IslandOp parent, IslandOp child, IslandType insert_position) {
|
|||||||
MoveInnerOpsToNewIsland(parent, child, new_yield_op.getOperation());
|
MoveInnerOpsToNewIsland(parent, child, new_yield_op.getOperation());
|
||||||
|
|
||||||
// Update control inputs to point to the new merged island.
|
// Update control inputs to point to the new merged island.
|
||||||
child.control()->replaceAllUsesWith(new_island.control());
|
child.control().replaceAllUsesWith(new_island.control());
|
||||||
parent.control()->replaceAllUsesWith(new_island.control());
|
parent.control().replaceAllUsesWith(new_island.control());
|
||||||
|
|
||||||
// Remove merged islands.
|
// Remove merged islands.
|
||||||
child.erase();
|
child.erase();
|
||||||
@ -291,11 +291,11 @@ void InsertDummyIslandForFetch(FetchOp fetch) {
|
|||||||
llvm::SmallVector<Type, 4> data_types;
|
llvm::SmallVector<Type, 4> data_types;
|
||||||
llvm::SmallVector<Value, 4> control_fetches;
|
llvm::SmallVector<Value, 4> control_fetches;
|
||||||
for (auto value : fetch.fetches()) {
|
for (auto value : fetch.fetches()) {
|
||||||
if (value->getType().isa<ControlType>()) {
|
if (value.getType().isa<ControlType>()) {
|
||||||
control_fetches.push_back(value);
|
control_fetches.push_back(value);
|
||||||
} else {
|
} else {
|
||||||
data_fetches.push_back(value);
|
data_fetches.push_back(value);
|
||||||
data_types.push_back(value->getType());
|
data_types.push_back(value.getType());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto island = OpBuilder(fetch).create<IslandOp>(
|
auto island = OpBuilder(fetch).create<IslandOp>(
|
||||||
|
@ -66,12 +66,12 @@ class SwitchFoldPass : public mlir::FunctionPass<SwitchFoldPass> {
|
|||||||
|
|
||||||
// Returns the defining op for a value looking through islands.
|
// Returns the defining op for a value looking through islands.
|
||||||
static Operation* GetDefiningOp(Value val) {
|
static Operation* GetDefiningOp(Value val) {
|
||||||
Operation* op = val->getDefiningOp();
|
Operation* op = val.getDefiningOp();
|
||||||
auto island_op = dyn_cast<tf_executor::IslandOp>(op);
|
auto island_op = dyn_cast<tf_executor::IslandOp>(op);
|
||||||
if (!island_op) return op;
|
if (!island_op) return op;
|
||||||
auto yield_op = island_op.GetYield();
|
auto yield_op = island_op.GetYield();
|
||||||
auto index = val->cast<mlir::OpResult>()->getResultNumber();
|
auto index = val.cast<mlir::OpResult>().getResultNumber();
|
||||||
return yield_op.getOperand(index)->getDefiningOp();
|
return yield_op.getOperand(index).getDefiningOp();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns either the value or input to an IdentityOp.
|
// Returns either the value or input to an IdentityOp.
|
||||||
@ -114,7 +114,7 @@ class DeadQueue {
|
|||||||
// feeding into the Merge then we could have a null value here.
|
// feeding into the Merge then we could have a null value here.
|
||||||
count = 0;
|
count = 0;
|
||||||
for (auto operand : op->getOperands()) {
|
for (auto operand : op->getOperands()) {
|
||||||
if (operand && !operand->getType().isa<tf_executor::ControlType>())
|
if (operand && !operand.getType().isa<tf_executor::ControlType>())
|
||||||
++count;
|
++count;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -125,8 +125,8 @@ class DeadQueue {
|
|||||||
|
|
||||||
// Enqueue users of a value.
|
// Enqueue users of a value.
|
||||||
void EnqueueUsers(Value val) {
|
void EnqueueUsers(Value val) {
|
||||||
for (auto user : val->getUsers()) {
|
for (auto user : val.getUsers()) {
|
||||||
Enqueue(user, val->getType().isa<tf_executor::ControlType>());
|
Enqueue(user, val.getType().isa<tf_executor::ControlType>());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -189,7 +189,7 @@ static void MatchSwitchFoldOps(tf_executor::SwitchOp switch_op,
|
|||||||
bool taken = pred.getSplatValue<bool>();
|
bool taken = pred.getSplatValue<bool>();
|
||||||
Value dead = taken ? switch_op.falseOutput() : switch_op.trueOutput();
|
Value dead = taken ? switch_op.falseOutput() : switch_op.trueOutput();
|
||||||
Value live = !taken ? switch_op.falseOutput() : switch_op.trueOutput();
|
Value live = !taken ? switch_op.falseOutput() : switch_op.trueOutput();
|
||||||
live->replaceAllUsesWith(switch_op.data());
|
live.replaceAllUsesWith(switch_op.data());
|
||||||
queue->EnqueueUsers(dead);
|
queue->EnqueueUsers(dead);
|
||||||
|
|
||||||
// Delete switch op.
|
// Delete switch op.
|
||||||
@ -218,7 +218,7 @@ static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) {
|
|||||||
Value operand = e.value();
|
Value operand = e.value();
|
||||||
if (!operand) continue;
|
if (!operand) continue;
|
||||||
// Skip control operands.
|
// Skip control operands.
|
||||||
if (operand->getType().isa<tf_executor::ControlType>()) break;
|
if (operand.getType().isa<tf_executor::ControlType>()) break;
|
||||||
if (val != nullptr) {
|
if (val != nullptr) {
|
||||||
return merge->emitOpError("multiple valid inputs post switch folding");
|
return merge->emitOpError("multiple valid inputs post switch folding");
|
||||||
}
|
}
|
||||||
@ -226,26 +226,26 @@ static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) {
|
|||||||
index = e.index();
|
index = e.index();
|
||||||
}
|
}
|
||||||
assert(val != nullptr && "merge node should have been deleted");
|
assert(val != nullptr && "merge node should have been deleted");
|
||||||
merge_op.output()->replaceAllUsesWith(val);
|
merge_op.output().replaceAllUsesWith(val);
|
||||||
|
|
||||||
// Build and insert value_index only if needed.
|
// Build and insert value_index only if needed.
|
||||||
if (!merge_op.value_index()->use_empty()) {
|
if (!merge_op.value_index().use_empty()) {
|
||||||
merge_op.value_index()->replaceAllUsesWith(
|
merge_op.value_index().replaceAllUsesWith(
|
||||||
build_index(merge->getLoc(), index));
|
build_index(merge->getLoc(), index));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Propagate control dependencies if used.
|
// Propagate control dependencies if used.
|
||||||
if (!merge_op.control()->use_empty()) {
|
if (!merge_op.control().use_empty()) {
|
||||||
// Change control dependencies from the merge to being on the parent of
|
// Change control dependencies from the merge to being on the parent of
|
||||||
// the value being propagated.
|
// the value being propagated.
|
||||||
auto def_op = val->getDefiningOp();
|
auto def_op = val.getDefiningOp();
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
auto exec_dialect =
|
auto exec_dialect =
|
||||||
function.getContext()->getRegisteredDialect("tf_executor");
|
function.getContext()->getRegisteredDialect("tf_executor");
|
||||||
assert(def_op->getDialect() == exec_dialect &&
|
assert(def_op->getDialect() == exec_dialect &&
|
||||||
"unable to forward control dependencies");
|
"unable to forward control dependencies");
|
||||||
#endif
|
#endif
|
||||||
merge_op.control()->replaceAllUsesWith(
|
merge_op.control().replaceAllUsesWith(
|
||||||
def_op->getResult(def_op->getNumResults() - 1));
|
def_op->getResult(def_op->getNumResults() - 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -53,7 +53,7 @@ static Value LowerCondition(Location loc, Value value, OpBuilder* builder) {
|
|||||||
// FIXME: This is almost all wrong, but is a placeholder to unblock the one
|
// FIXME: This is almost all wrong, but is a placeholder to unblock the one
|
||||||
// testcases, later patches will build on this once I build the right infra to
|
// testcases, later patches will build on this once I build the right infra to
|
||||||
// support it.
|
// support it.
|
||||||
TensorType type = value->getType().cast<TensorType>();
|
TensorType type = value.getType().cast<TensorType>();
|
||||||
if (!type.hasRank() || type.getRank() != 0 ||
|
if (!type.hasRank() || type.getRank() != 0 ||
|
||||||
!type.getElementType().isInteger(1)) {
|
!type.getElementType().isInteger(1)) {
|
||||||
return emitError(loc, "only supports zero-D bool tensors now"), nullptr;
|
return emitError(loc, "only supports zero-D bool tensors now"), nullptr;
|
||||||
@ -79,7 +79,7 @@ static Operation* CallFn(Location loc, const std::function<Value(int)>& get_arg,
|
|||||||
for (int i = 0; i < num_operands; ++i) {
|
for (int i = 0; i < num_operands; ++i) {
|
||||||
Value val = get_arg(i);
|
Value val = get_arg(i);
|
||||||
Type expected = fn_type.getInput(i);
|
Type expected = fn_type.getInput(i);
|
||||||
if (val->getType() != expected) {
|
if (val.getType() != expected) {
|
||||||
val =
|
val =
|
||||||
builder->create<TF::CastOp>(loc, expected, val,
|
builder->create<TF::CastOp>(loc, expected, val,
|
||||||
/*Truncate=*/builder->getBoolAttr(false));
|
/*Truncate=*/builder->getBoolAttr(false));
|
||||||
@ -102,8 +102,8 @@ static llvm::SmallVector<Value, 4> PrepareValsForJump(
|
|||||||
result.reserve(num_vals);
|
result.reserve(num_vals);
|
||||||
for (int i = 0; i < num_vals; ++i) {
|
for (int i = 0; i < num_vals; ++i) {
|
||||||
Value val = get_val(i);
|
Value val = get_val(i);
|
||||||
Type expected = block->getArgument(i)->getType();
|
Type expected = block->getArgument(i).getType();
|
||||||
if (val->getType() != expected) {
|
if (val.getType() != expected) {
|
||||||
val =
|
val =
|
||||||
builder->create<TF::CastOp>(loc, expected, val,
|
builder->create<TF::CastOp>(loc, expected, val,
|
||||||
/*Truncate=*/builder->getBoolAttr(false));
|
/*Truncate=*/builder->getBoolAttr(false));
|
||||||
@ -137,12 +137,12 @@ static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
|
|||||||
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
|
for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
|
||||||
Value arg = block->getArgument(i);
|
Value arg = block->getArgument(i);
|
||||||
Value result = op->getResult(i);
|
Value result = op->getResult(i);
|
||||||
if (arg->getType() != result->getType()) {
|
if (arg.getType() != result.getType()) {
|
||||||
arg =
|
arg =
|
||||||
builder->create<TF::CastOp>(loc, result->getType(), arg,
|
builder->create<TF::CastOp>(loc, result.getType(), arg,
|
||||||
/*Truncate=*/builder->getBoolAttr(false));
|
/*Truncate=*/builder->getBoolAttr(false));
|
||||||
}
|
}
|
||||||
result->replaceAllUsesWith(arg);
|
result.replaceAllUsesWith(arg);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -174,7 +174,7 @@ static LogicalResult LowerIfOp(IfOp op) {
|
|||||||
// Add the block arguments to the merge point, and replace all uses of the
|
// Add the block arguments to the merge point, and replace all uses of the
|
||||||
// original operation results with them.
|
// original operation results with them.
|
||||||
for (Value value : op_inst->getResults())
|
for (Value value : op_inst->getResults())
|
||||||
merge_block->addArgument(value->getType());
|
merge_block->addArgument(value.getType());
|
||||||
ReplaceOpResultWithBlockArgs(loc, op_inst, merge_block, &builder);
|
ReplaceOpResultWithBlockArgs(loc, op_inst, merge_block, &builder);
|
||||||
|
|
||||||
// Get arguments to the branches after dropping the condition which is the
|
// Get arguments to the branches after dropping the condition which is the
|
||||||
|
@ -39,7 +39,7 @@ void PruneGraph(GraphOp graph) {
|
|||||||
// Visit an op's operands if it is output of an Operation in same graph.
|
// Visit an op's operands if it is output of an Operation in same graph.
|
||||||
auto visit_op = [&](Operation* op) {
|
auto visit_op = [&](Operation* op) {
|
||||||
for (Value operand : op->getOperands()) {
|
for (Value operand : op->getOperands()) {
|
||||||
Operation* def = operand->getDefiningOp();
|
Operation* def = operand.getDefiningOp();
|
||||||
if (def && def->getParentOp() == graph &&
|
if (def && def->getParentOp() == graph &&
|
||||||
reachable_ops.insert(def).second) {
|
reachable_ops.insert(def).second) {
|
||||||
// Op has not been visited, add to queue to visit later.
|
// Op has not been visited, add to queue to visit later.
|
||||||
|
@ -55,7 +55,7 @@ void InlineGlobalTensorsPass::runOnModule() {
|
|||||||
// Replace the arg with a tf.Const op in the function body.
|
// Replace the arg with a tf.Const op in the function body.
|
||||||
auto const_op = builder.create<TF::ConstOp>(global_tensor.getLoc(),
|
auto const_op = builder.create<TF::ConstOp>(global_tensor.getLoc(),
|
||||||
global_tensor.value());
|
global_tensor.value());
|
||||||
func.getArgument(i)->replaceAllUsesWith(const_op.getResult());
|
func.getArgument(i).replaceAllUsesWith(const_op.getResult());
|
||||||
args_to_erase.push_back(i);
|
args_to_erase.push_back(i);
|
||||||
}
|
}
|
||||||
func.eraseArguments(args_to_erase);
|
func.eraseArguments(args_to_erase);
|
||||||
|
@ -196,7 +196,7 @@ class LowerDynamicStitchOp : public OpRewritePattern<TF::DynamicStitchOp> {
|
|||||||
if (!matchPattern(index, m_Constant(&index_attr))) return matchFailure();
|
if (!matchPattern(index, m_Constant(&index_attr))) return matchFailure();
|
||||||
indices.push_back(index_attr);
|
indices.push_back(index_attr);
|
||||||
|
|
||||||
RankedTensorType data_ty = data->getType().dyn_cast<RankedTensorType>();
|
RankedTensorType data_ty = data.getType().dyn_cast<RankedTensorType>();
|
||||||
if (!data_ty || !data_ty.hasStaticShape()) return matchFailure();
|
if (!data_ty || !data_ty.hasStaticShape()) return matchFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -270,7 +270,7 @@ class LowerPackOp : public OpRewritePattern<TF::PackOp> {
|
|||||||
// If input type is different than the previous input type, infer the
|
// If input type is different than the previous input type, infer the
|
||||||
// output type. Otherwise, use the already inferred output type from the
|
// output type. Otherwise, use the already inferred output type from the
|
||||||
// previous iteration.
|
// previous iteration.
|
||||||
Type input_ty = input->getType();
|
Type input_ty = input.getType();
|
||||||
if (input_ty != prev_input_ty) {
|
if (input_ty != prev_input_ty) {
|
||||||
inferred_ty = InferExpandDimsType(input_ty, axis, &rewriter);
|
inferred_ty = InferExpandDimsType(input_ty, axis, &rewriter);
|
||||||
prev_input_ty = input_ty;
|
prev_input_ty = input_ty;
|
||||||
|
@ -37,7 +37,7 @@ class GetI64ScalarElementsAttr<int value> :
|
|||||||
|
|
||||||
def GetBiasAddGradReductionIndices : NativeCodeCall<
|
def GetBiasAddGradReductionIndices : NativeCodeCall<
|
||||||
"GetBiasAddGradReductionIndices("
|
"GetBiasAddGradReductionIndices("
|
||||||
"$0->getType().cast<RankedTensorType>().getRank(), $1, &$_builder)">;
|
"$0.getType().cast<RankedTensorType>().getRank(), $1, &$_builder)">;
|
||||||
|
|
||||||
def LowerBiasAddGradOp :
|
def LowerBiasAddGradOp :
|
||||||
Pat<(TF_BiasAddGradOp AnyRankedTensor:$out_backprop, $data_format),
|
Pat<(TF_BiasAddGradOp AnyRankedTensor:$out_backprop, $data_format),
|
||||||
@ -82,12 +82,12 @@ def LowerSoftmaxCrossEntropyWithLogitsOp : Pattern<
|
|||||||
// dimension should be known.
|
// dimension should be known.
|
||||||
class GetDimSizeOfType<int dim> : NativeCodeCall<
|
class GetDimSizeOfType<int dim> : NativeCodeCall<
|
||||||
"GetScalarOfType(getElementTypeOrSelf($1), "
|
"GetScalarOfType(getElementTypeOrSelf($1), "
|
||||||
"$0->getType().cast<RankedTensorType>().getDimSize(" # dim # "))">;
|
"$0.getType().cast<RankedTensorType>().getDimSize(" # dim # "))">;
|
||||||
|
|
||||||
// Same as the above with i32 element type.
|
// Same as the above with i32 element type.
|
||||||
class GetDimSizeAsI32<int dim> : NativeCodeCall<
|
class GetDimSizeAsI32<int dim> : NativeCodeCall<
|
||||||
"GetScalarOfType($_builder.getIntegerType(32), "
|
"GetScalarOfType($_builder.getIntegerType(32), "
|
||||||
"$0->getType().cast<RankedTensorType>().getDimSize(" # dim # "))">;
|
"$0.getType().cast<RankedTensorType>().getDimSize(" # dim # "))">;
|
||||||
|
|
||||||
// Sparse version of SoftmaxCrossEntropyWithLogits is lowered to dense by
|
// Sparse version of SoftmaxCrossEntropyWithLogits is lowered to dense by
|
||||||
// expanding the sparse labels using:
|
// expanding the sparse labels using:
|
||||||
@ -160,7 +160,7 @@ def LowerFillOp : Pat<(TF_FillOp $dims, $value),
|
|||||||
|
|
||||||
def GetAllAxes : NativeCodeCall<
|
def GetAllAxes : NativeCodeCall<
|
||||||
"GetI64ElementsAttrForSeq("
|
"GetI64ElementsAttrForSeq("
|
||||||
"0, $0->getType().cast<RankedTensorType>().getRank(), &$_builder)">;
|
"0, $0.getType().cast<RankedTensorType>().getRank(), &$_builder)">;
|
||||||
|
|
||||||
// L2Loss is lowered using the formula,
|
// L2Loss is lowered using the formula,
|
||||||
// L2Loss(input) = Sum(input * input) / 2
|
// L2Loss(input) = Sum(input * input) / 2
|
||||||
@ -220,7 +220,7 @@ def LowerTanhGradOp :
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def CreateTFShapeOp : NativeCodeCall<
|
def CreateTFShapeOp : NativeCodeCall<
|
||||||
"$_builder.create<TF::ShapeOp>($0->getLoc(), $1, $2)">;
|
"$_builder.create<TF::ShapeOp>($0.getLoc(), $1, $2)">;
|
||||||
|
|
||||||
// TODO(hinsu): Support inputs of TensorList types.
|
// TODO(hinsu): Support inputs of TensorList types.
|
||||||
def LowerZerosLikeOp :
|
def LowerZerosLikeOp :
|
||||||
|
@ -79,7 +79,7 @@ void MaterializePassthroughOpPass::runOnFunction() {
|
|||||||
Block &block = body.front();
|
Block &block = body.front();
|
||||||
for (const auto &arg_mapping :
|
for (const auto &arg_mapping :
|
||||||
llvm::zip(block.getArguments(), op->getOperands())) {
|
llvm::zip(block.getArguments(), op->getOperands())) {
|
||||||
std::get<0>(arg_mapping)->replaceAllUsesWith(std::get<1>(arg_mapping));
|
std::get<0>(arg_mapping).replaceAllUsesWith(std::get<1>(arg_mapping));
|
||||||
}
|
}
|
||||||
op->getBlock()->getOperations().splice(op->getIterator(),
|
op->getBlock()->getOperations().splice(op->getIterator(),
|
||||||
block.getOperations(), block.begin(),
|
block.getOperations(), block.begin(),
|
||||||
@ -87,7 +87,7 @@ void MaterializePassthroughOpPass::runOnFunction() {
|
|||||||
Operation &return_op = block.front();
|
Operation &return_op = block.front();
|
||||||
for (auto ret_mapping :
|
for (auto ret_mapping :
|
||||||
llvm::zip(op->getResults(), return_op.getOperands())) {
|
llvm::zip(op->getResults(), return_op.getOperands())) {
|
||||||
std::get<0>(ret_mapping)->replaceAllUsesWith(std::get<1>(ret_mapping));
|
std::get<0>(ret_mapping).replaceAllUsesWith(std::get<1>(ret_mapping));
|
||||||
}
|
}
|
||||||
op->erase();
|
op->erase();
|
||||||
});
|
});
|
||||||
|
@ -21,7 +21,7 @@ def BroadcastableElements :
|
|||||||
Constraint<CPred<"TFL::IsBroadcastableElementsAttrs($0, $1)">>;
|
Constraint<CPred<"TFL::IsBroadcastableElementsAttrs($0, $1)">>;
|
||||||
def F32ElementsAttr : ElementsAttrBase<
|
def F32ElementsAttr : ElementsAttrBase<
|
||||||
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
|
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
|
||||||
def DefinedByConv2D : Constraint<CPred<"llvm::isa_and_nonnull<mlir::TF::Conv2DOp>($0->getDefiningOp())">>;
|
def DefinedByConv2D : Constraint<CPred<"llvm::isa_and_nonnull<mlir::TF::Conv2DOp>($0.getDefiningOp())">>;
|
||||||
|
|
||||||
// If we see a Conv2D op followed by Mul, then multiply the filter
|
// If we see a Conv2D op followed by Mul, then multiply the filter
|
||||||
// with the value in Mul.
|
// with the value in Mul.
|
||||||
|
@ -54,7 +54,7 @@ bool IsReadOnlyVariableOp(Operation* op) { return isa<TF::ReadVariableOp>(op); }
|
|||||||
|
|
||||||
void RewriteReadOnlyVariableOpToTensorOp(Operation* op, Value tensor_value) {
|
void RewriteReadOnlyVariableOpToTensorOp(Operation* op, Value tensor_value) {
|
||||||
auto read_variable = cast<TF::ReadVariableOp>(op);
|
auto read_variable = cast<TF::ReadVariableOp>(op);
|
||||||
read_variable.value()->replaceAllUsesWith(tensor_value);
|
read_variable.value().replaceAllUsesWith(tensor_value);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsFreezable(GlobalTensorOp global_tensor,
|
bool IsFreezable(GlobalTensorOp global_tensor,
|
||||||
@ -74,7 +74,7 @@ bool IsFreezable(GlobalTensorOp global_tensor,
|
|||||||
// or control flow, we fail to prove it is freezable even though we could.
|
// or control flow, we fail to prove it is freezable even though we could.
|
||||||
for (auto& global_tensor_use : global_tensor_uses) {
|
for (auto& global_tensor_use : global_tensor_uses) {
|
||||||
auto arg = global_tensor_use.func.getArgument(global_tensor_use.arg_index);
|
auto arg = global_tensor_use.func.getArgument(global_tensor_use.arg_index);
|
||||||
for (auto user : arg->getUsers()) {
|
for (auto user : arg.getUsers()) {
|
||||||
if (!IsReadOnlyVariableOp(user)) {
|
if (!IsReadOnlyVariableOp(user)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -130,12 +130,12 @@ void FreezeGlobalTensors(ModuleOp module,
|
|||||||
auto func = global_tensor_use.func;
|
auto func = global_tensor_use.func;
|
||||||
auto arg_index = global_tensor_use.arg_index;
|
auto arg_index = global_tensor_use.arg_index;
|
||||||
Value arg = func.getArgument(arg_index);
|
Value arg = func.getArgument(arg_index);
|
||||||
for (Operation* user : llvm::make_early_inc_range(arg->getUsers())) {
|
for (Operation* user : llvm::make_early_inc_range(arg.getUsers())) {
|
||||||
RewriteReadOnlyVariableOpToTensorOp(user, arg);
|
RewriteReadOnlyVariableOpToTensorOp(user, arg);
|
||||||
user->erase();
|
user->erase();
|
||||||
}
|
}
|
||||||
Type new_type = global_tensor.value().Attribute::getType();
|
Type new_type = global_tensor.value().Attribute::getType();
|
||||||
arg->setType(new_type);
|
arg.setType(new_type);
|
||||||
auto old_ftype = func.getType();
|
auto old_ftype = func.getType();
|
||||||
auto input_types = old_ftype.getInputs().vec();
|
auto input_types = old_ftype.getInputs().vec();
|
||||||
input_types[arg_index] = new_type;
|
input_types[arg_index] = new_type;
|
||||||
@ -168,7 +168,7 @@ void EraseUnusedBoundInputs(ModuleOp module) {
|
|||||||
SmallVector<unsigned, 4> args_to_erase;
|
SmallVector<unsigned, 4> args_to_erase;
|
||||||
for (int i = 0, e = func.getNumArguments(); i < e; i++) {
|
for (int i = 0, e = func.getNumArguments(); i < e; i++) {
|
||||||
if (func.getArgAttr(i, "tf_saved_model.bound_input") &&
|
if (func.getArgAttr(i, "tf_saved_model.bound_input") &&
|
||||||
func.getArgument(i)->use_empty()) {
|
func.getArgument(i).use_empty()) {
|
||||||
args_to_erase.push_back(i);
|
args_to_erase.push_back(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -100,7 +100,7 @@ void RaiseTFControlFlow::rewriteOps() {
|
|||||||
// aren't necessary any more since the order within a block encodes the
|
// aren't necessary any more since the order within a block encodes the
|
||||||
// same information.
|
// same information.
|
||||||
for (auto &operand : op.getOpOperands()) {
|
for (auto &operand : op.getOpOperands()) {
|
||||||
if (!operand.get()->getType().isa<TFControlType>())
|
if (!operand.get().getType().isa<TFControlType>())
|
||||||
result.operands.push_back(operand.get());
|
result.operands.push_back(operand.get());
|
||||||
|
|
||||||
// Drop all operands from the old operation, eliminating any
|
// Drop all operands from the old operation, eliminating any
|
||||||
@ -111,13 +111,13 @@ void RaiseTFControlFlow::rewriteOps() {
|
|||||||
// Add a result type for each non-control result we find.
|
// Add a result type for each non-control result we find.
|
||||||
bool sawControlResult = false;
|
bool sawControlResult = false;
|
||||||
for (auto opResult : op.getResults()) {
|
for (auto opResult : op.getResults()) {
|
||||||
if (opResult->getType().isa<TFControlType>()) {
|
if (opResult.getType().isa<TFControlType>()) {
|
||||||
sawControlResult = true;
|
sawControlResult = true;
|
||||||
} else {
|
} else {
|
||||||
// We assume all control inputs are at the end of the result list.
|
// We assume all control inputs are at the end of the result list.
|
||||||
assert(!sawControlResult && "all control results must be last");
|
assert(!sawControlResult && "all control results must be last");
|
||||||
(void)sawControlResult;
|
(void)sawControlResult;
|
||||||
result.types.push_back(opResult->getType());
|
result.types.push_back(opResult.getType());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -129,7 +129,7 @@ void RaiseTFControlFlow::rewriteOps() {
|
|||||||
// We know that all the control results are last, so we can just rewrite
|
// We know that all the control results are last, so we can just rewrite
|
||||||
// the first results.
|
// the first results.
|
||||||
for (unsigned i = 0, e = result.types.size(); i != e; ++i)
|
for (unsigned i = 0, e = result.types.size(); i != e; ++i)
|
||||||
op.getResult(i)->replaceAllUsesWith(replacement->getResult(i));
|
op.getResult(i).replaceAllUsesWith(replacement->getResult(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,16 +74,16 @@ void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas,
|
|||||||
Value input = shape_op.input();
|
Value input = shape_op.input();
|
||||||
// If ShapeOp operand is replicate tensor block argument, replace with the
|
// If ShapeOp operand is replicate tensor block argument, replace with the
|
||||||
// associated first replica operand.
|
// associated first replica operand.
|
||||||
if (auto block_arg = input->dyn_cast<BlockArgument>()) {
|
if (auto block_arg = input.dyn_cast<BlockArgument>()) {
|
||||||
if (block_arg->getOwner() != replicate_block) return;
|
if (block_arg.getOwner() != replicate_block) return;
|
||||||
|
|
||||||
shape_op.setOperand(
|
shape_op.setOperand(
|
||||||
replicate_op.getOperand(num_replicas * block_arg->getArgNumber()));
|
replicate_op.getOperand(num_replicas * block_arg.getArgNumber()));
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operation* input_def = input->getDefiningOp();
|
Operation* input_def = input.getDefiningOp();
|
||||||
|
|
||||||
// If ShapeOp operand is a ReadVariableOp result where the ReadVariableOp
|
// If ShapeOp operand is a ReadVariableOp result where the ReadVariableOp
|
||||||
// operand is a replicate resource block argument, replace ShapeOp with
|
// operand is a replicate resource block argument, replace ShapeOp with
|
||||||
@ -96,13 +96,13 @@ void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas,
|
|||||||
// shape has not changed in replicate prior to read. Currently after both
|
// shape has not changed in replicate prior to read. Currently after both
|
||||||
// ResourceOpLiftingPass and TPURewritePass, there should not be any updates
|
// ResourceOpLiftingPass and TPURewritePass, there should not be any updates
|
||||||
// to resources prior to their respective ReadVariableOp.
|
// to resources prior to their respective ReadVariableOp.
|
||||||
if (auto block_arg = read_var_op.resource()->dyn_cast<BlockArgument>()) {
|
if (auto block_arg = read_var_op.resource().dyn_cast<BlockArgument>()) {
|
||||||
if (block_arg->getOwner() != replicate_block) return;
|
if (block_arg.getOwner() != replicate_block) return;
|
||||||
|
|
||||||
OpBuilder builder(shape_op);
|
OpBuilder builder(shape_op);
|
||||||
auto new_shape_op = builder.create<TF::VariableShapeOp>(
|
auto new_shape_op = builder.create<TF::VariableShapeOp>(
|
||||||
shape_op.getLoc(), shape_op.getType(),
|
shape_op.getLoc(), shape_op.getType(),
|
||||||
replicate_op.getOperand(num_replicas * block_arg->getArgNumber()));
|
replicate_op.getOperand(num_replicas * block_arg.getArgNumber()));
|
||||||
shape_op.replaceAllUsesWith(new_shape_op.getOperation());
|
shape_op.replaceAllUsesWith(new_shape_op.getOperation());
|
||||||
shape_op.erase();
|
shape_op.erase();
|
||||||
}
|
}
|
||||||
@ -112,7 +112,7 @@ void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas,
|
|||||||
bool IsOpReplicateInvariant(Region* replicate_region, Operation* op) {
|
bool IsOpReplicateInvariant(Region* replicate_region, Operation* op) {
|
||||||
auto result = op->walk([&](Operation* inner_op) {
|
auto result = op->walk([&](Operation* inner_op) {
|
||||||
for (Value operand : inner_op->getOperands()) {
|
for (Value operand : inner_op->getOperands()) {
|
||||||
Region* parent_region = operand->getParentRegion();
|
Region* parent_region = operand.getParentRegion();
|
||||||
if (!parent_region || !parent_region->isProperAncestor(replicate_region))
|
if (!parent_region || !parent_region->isProperAncestor(replicate_region))
|
||||||
return WalkResult::interrupt();
|
return WalkResult::interrupt();
|
||||||
}
|
}
|
||||||
|
@ -83,7 +83,7 @@ llvm::SmallVector<tf_executor::IslandOp, 8> ExpandReplicateIntoReplicas(
|
|||||||
mapping.clear();
|
mapping.clear();
|
||||||
for (auto& block_arg : replicate_op.GetBody().getArguments())
|
for (auto& block_arg : replicate_op.GetBody().getArguments())
|
||||||
mapping.map(block_arg, replicate_op.getOperand(
|
mapping.map(block_arg, replicate_op.getOperand(
|
||||||
block_arg->getArgNumber() * num_replicas + i));
|
block_arg.getArgNumber() * num_replicas + i));
|
||||||
|
|
||||||
// Copy over replicate region into replica island.
|
// Copy over replicate region into replica island.
|
||||||
replicate_op.body().cloneInto(&replica.body(), mapping);
|
replicate_op.body().cloneInto(&replica.body(), mapping);
|
||||||
|
@ -127,16 +127,16 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
|
|||||||
OpBuilder builder(func_op);
|
OpBuilder builder(func_op);
|
||||||
// Function arguments.
|
// Function arguments.
|
||||||
for (auto arg : func_op.getArguments()) {
|
for (auto arg : func_op.getArguments()) {
|
||||||
if (!mlir::getElementTypeOrSelf(arg->getType()).isa<TF::ResourceType>()) {
|
if (!mlir::getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto device_attr = func_op.getArgAttrOfType<mlir::StringAttr>(
|
auto device_attr = func_op.getArgAttrOfType<mlir::StringAttr>(
|
||||||
arg->getArgNumber(), kFuncDeviceAttr);
|
arg.getArgNumber(), kFuncDeviceAttr);
|
||||||
if (!device_attr || device_attr.getValue() == "") {
|
if (!device_attr || device_attr.getValue() == "") {
|
||||||
// If device_attr does not exist, try to construct it from any recorded
|
// If device_attr does not exist, try to construct it from any recorded
|
||||||
// assignment.
|
// assignment.
|
||||||
if (auto device = result->DeviceForResource(arg)) {
|
if (auto device = result->DeviceForResource(arg)) {
|
||||||
func_op.setArgAttr(arg->getArgNumber(), kFuncDeviceAttr,
|
func_op.setArgAttr(arg.getArgNumber(), kFuncDeviceAttr,
|
||||||
builder.getStringAttr(*device));
|
builder.getStringAttr(*device));
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
@ -160,7 +160,7 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
|
|||||||
}
|
}
|
||||||
if (auto identity = llvm::dyn_cast<TF::IdentityOp>(op)) {
|
if (auto identity = llvm::dyn_cast<TF::IdentityOp>(op)) {
|
||||||
// Try to construct IdentityOp's attribute from recorded assignment.
|
// Try to construct IdentityOp's attribute from recorded assignment.
|
||||||
if (!mlir::getElementTypeOrSelf(identity.output()->getType())
|
if (!mlir::getElementTypeOrSelf(identity.output().getType())
|
||||||
.isa<TF::ResourceType>()) {
|
.isa<TF::ResourceType>()) {
|
||||||
return WalkResult::advance();
|
return WalkResult::advance();
|
||||||
}
|
}
|
||||||
@ -176,7 +176,7 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op,
|
|||||||
// Propagate and record output device assignment for other ops based on
|
// Propagate and record output device assignment for other ops based on
|
||||||
// existing recording. E.g., IdentityN.
|
// existing recording. E.g., IdentityN.
|
||||||
for (auto output : op->getResults()) {
|
for (auto output : op->getResults()) {
|
||||||
if (!mlir::getElementTypeOrSelf(output->getType())
|
if (!mlir::getElementTypeOrSelf(output.getType())
|
||||||
.isa<TF::ResourceType>()) {
|
.isa<TF::ResourceType>()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -212,7 +212,7 @@ void ResourceDeviceInference::runOnModule() {
|
|||||||
for (auto operand_and_argument :
|
for (auto operand_and_argument :
|
||||||
llvm::zip(caller_operands, callee.getArguments())) {
|
llvm::zip(caller_operands, callee.getArguments())) {
|
||||||
if (!mlir::getElementTypeOrSelf(
|
if (!mlir::getElementTypeOrSelf(
|
||||||
std::get<0>(operand_and_argument)->getType())
|
std::get<0>(operand_and_argument).getType())
|
||||||
.isa<TF::ResourceType>()) {
|
.isa<TF::ResourceType>()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -100,7 +100,7 @@ void ForwardStoreToLoad(tf_device::LaunchOp launch_op) {
|
|||||||
|
|
||||||
// Use stored value in last_store to replace all uses of current resource
|
// Use stored value in last_store to replace all uses of current resource
|
||||||
// load's result, then erase this resource load.
|
// load's result, then erase this resource load.
|
||||||
read_variable_op.value()->replaceAllUsesWith(last_store.value());
|
read_variable_op.value().replaceAllUsesWith(last_store.value());
|
||||||
read_variable_op.erase();
|
read_variable_op.erase();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -130,7 +130,7 @@ void HoistResourceLoads(tf_device::LaunchOp launch_op) {
|
|||||||
Value resource = read_variable_op.resource();
|
Value resource = read_variable_op.resource();
|
||||||
|
|
||||||
// Skip resources created inside of launch_op.
|
// Skip resources created inside of launch_op.
|
||||||
if (resource->getParentRegion() == &launch_op.body()) continue;
|
if (resource.getParentRegion() == &launch_op.body()) continue;
|
||||||
|
|
||||||
auto p = resource_to_read_ops.insert({resource, read_variable_op});
|
auto p = resource_to_read_ops.insert({resource, read_variable_op});
|
||||||
if (p.second) {
|
if (p.second) {
|
||||||
@ -167,7 +167,7 @@ bool AppendResourceStoreValueToReturn(tf_device::LaunchOp launch_op) {
|
|||||||
if (!resource) continue;
|
if (!resource) continue;
|
||||||
|
|
||||||
// Skip resources created inside of launch_op.
|
// Skip resources created inside of launch_op.
|
||||||
if (resource->getParentRegion() == &launch_op.body()) continue;
|
if (resource.getParentRegion() == &launch_op.body()) continue;
|
||||||
|
|
||||||
// TODO(ycao): Prevent same value from being returned multiple times.
|
// TODO(ycao): Prevent same value from being returned multiple times.
|
||||||
// TODO(ycao): Do not return resource store value if it is defined outside
|
// TODO(ycao): Do not return resource store value if it is defined outside
|
||||||
@ -207,7 +207,7 @@ void SinkResourceStores(tf_device::LaunchOp launch_op, OpBuilder* builder) {
|
|||||||
|
|
||||||
// Replace uses of old launch_op results with those of new_launch_op.
|
// Replace uses of old launch_op results with those of new_launch_op.
|
||||||
for (auto p : llvm::zip(launch_op.getResults(), new_launch_op.getResults())) {
|
for (auto p : llvm::zip(launch_op.getResults(), new_launch_op.getResults())) {
|
||||||
std::get<0>(p)->replaceAllUsesWith(std::get<1>(p));
|
std::get<0>(p).replaceAllUsesWith(std::get<1>(p));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a mapping from operands of new_return_op operands to new_launch_op
|
// Create a mapping from operands of new_return_op operands to new_launch_op
|
||||||
|
@ -68,24 +68,24 @@ Optional<llvm::SmallVector<mlir::Type, 4>> InferShapeForFunctionReturnType(
|
|||||||
// Manually fold tf.Cast that precedes the return instruction and only differs
|
// Manually fold tf.Cast that precedes the return instruction and only differs
|
||||||
// in shape refinement level.
|
// in shape refinement level.
|
||||||
for (OpOperand& arg_op : return_op.getOperation()->getOpOperands()) {
|
for (OpOperand& arg_op : return_op.getOperation()->getOpOperands()) {
|
||||||
Operation* arg_defining_op = arg_op.get()->getDefiningOp();
|
Operation* arg_defining_op = arg_op.get().getDefiningOp();
|
||||||
if (auto cast_op = dyn_cast_or_null<CastOp>(arg_defining_op)) {
|
if (auto cast_op = dyn_cast_or_null<CastOp>(arg_defining_op)) {
|
||||||
// Shape inference should not change the element type.
|
// Shape inference should not change the element type.
|
||||||
if (cast_op.SrcT() != cast_op.DstT()) continue;
|
if (cast_op.SrcT() != cast_op.DstT()) continue;
|
||||||
// We only refine the result shape if the result a dynamic shape, the
|
// We only refine the result shape if the result a dynamic shape, the
|
||||||
// input has static shape, and the two shapes are compatible.
|
// input has static shape, and the two shapes are compatible.
|
||||||
auto has_static_shape = [](const Value value) {
|
auto has_static_shape = [](const Value value) {
|
||||||
auto shaped_type = value->getType().dyn_cast<ShapedType>();
|
auto shaped_type = value.getType().dyn_cast<ShapedType>();
|
||||||
return shaped_type && shaped_type.hasStaticShape();
|
return shaped_type && shaped_type.hasStaticShape();
|
||||||
};
|
};
|
||||||
Value input = cast_op.x();
|
Value input = cast_op.x();
|
||||||
Value result = cast_op.y();
|
Value result = cast_op.y();
|
||||||
if (!has_static_shape(input) || has_static_shape(result) ||
|
if (!has_static_shape(input) || has_static_shape(result) ||
|
||||||
failed(verifyCompatibleShape(input->getType(), result->getType())))
|
failed(verifyCompatibleShape(input.getType(), result.getType())))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
arg_op.set(cast_op.x());
|
arg_op.set(cast_op.x());
|
||||||
if (cast_op.y()->use_empty()) cast_op.erase();
|
if (cast_op.y().use_empty()) cast_op.erase();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -111,7 +111,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
|||||||
// This is necessary to avoid reprocessing the tf.Cast that are inserted at
|
// This is necessary to avoid reprocessing the tf.Cast that are inserted at
|
||||||
// the end of this function.
|
// the end of this function.
|
||||||
if (isa<CastOp>(op) &&
|
if (isa<CastOp>(op) &&
|
||||||
llvm::all_of(op->getResult(0)->getUsers(), [&](Operation* user) {
|
llvm::all_of(op->getResult(0).getUsers(), [&](Operation* user) {
|
||||||
return user->getDialect() != tf_dialect;
|
return user->getDialect() != tf_dialect;
|
||||||
})) {
|
})) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Skipping inference for tf.Cast with no TF "
|
LLVM_DEBUG(llvm::dbgs() << "Skipping inference for tf.Cast with no TF "
|
||||||
@ -178,7 +178,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Type operand_type = operand->getType();
|
Type operand_type = operand.getType();
|
||||||
if (auto ranked_type = operand_type.dyn_cast<RankedTensorType>()) {
|
if (auto ranked_type = operand_type.dyn_cast<RankedTensorType>()) {
|
||||||
// Convert the MLIR shape indices (int64_t) to TensorFlow indices (int64).
|
// Convert the MLIR shape indices (int64_t) to TensorFlow indices (int64).
|
||||||
ArrayRef<int64_t> shape = ranked_type.getShape();
|
ArrayRef<int64_t> shape = ranked_type.getShape();
|
||||||
@ -215,7 +215,7 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
|||||||
for (int output : llvm::seq<int>(0, c.num_outputs())) {
|
for (int output : llvm::seq<int>(0, c.num_outputs())) {
|
||||||
// Skip already statically shaped results.
|
// Skip already statically shaped results.
|
||||||
Value result = op->getResult(output);
|
Value result = op->getResult(output);
|
||||||
auto shaped_type = result->getType().dyn_cast<ShapedType>();
|
auto shaped_type = result.getType().dyn_cast<ShapedType>();
|
||||||
if (!shaped_type || shaped_type.hasStaticShape()) continue;
|
if (!shaped_type || shaped_type.hasStaticShape()) continue;
|
||||||
|
|
||||||
tensorflow::shape_inference::ShapeHandle shape_handle = c.output(output);
|
tensorflow::shape_inference::ShapeHandle shape_handle = c.output(output);
|
||||||
@ -235,18 +235,18 @@ bool InferShapeForSingleOperation(Operation* op, Dialect* tf_dialect,
|
|||||||
auto get_cast_op = [&]() {
|
auto get_cast_op = [&]() {
|
||||||
if (!cast_op)
|
if (!cast_op)
|
||||||
cast_op =
|
cast_op =
|
||||||
builder.create<TF::CastOp>(op->getLoc(), result->getType(), result,
|
builder.create<TF::CastOp>(op->getLoc(), result.getType(), result,
|
||||||
/*truncate=*/builder.getBoolAttr(false));
|
/*truncate=*/builder.getBoolAttr(false));
|
||||||
return cast_op;
|
return cast_op;
|
||||||
};
|
};
|
||||||
for (OpOperand& use : llvm::make_early_inc_range(result->getUses())) {
|
for (OpOperand& use : llvm::make_early_inc_range(result.getUses())) {
|
||||||
if (use.getOwner()->getDialect() != tf_dialect) use.set(get_cast_op());
|
if (use.getOwner()->getDialect() != tf_dialect) use.set(get_cast_op());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (result->getType() == new_type) continue;
|
if (result.getType() == new_type) continue;
|
||||||
|
|
||||||
// Finally we inferred the shape and replace the type for this result.
|
// Finally we inferred the shape and replace the type for this result.
|
||||||
result->setType(new_type);
|
result.setType(new_type);
|
||||||
changed = true;
|
changed = true;
|
||||||
}
|
}
|
||||||
if (changed)
|
if (changed)
|
||||||
@ -284,7 +284,7 @@ LogicalResult RefineShapeForControlFlowFunc(FuncOp func,
|
|||||||
func.getContext()));
|
func.getContext()));
|
||||||
|
|
||||||
for (auto arg_and_idx : llvm::enumerate(func.getArguments())) {
|
for (auto arg_and_idx : llvm::enumerate(func.getArguments())) {
|
||||||
arg_and_idx.value()->setType(input_types[arg_and_idx.index()]);
|
arg_and_idx.value().setType(input_types[arg_and_idx.index()]);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto res =
|
auto res =
|
||||||
@ -307,7 +307,7 @@ LogicalResult PropagateShapeToIfWhileOpFunctions(
|
|||||||
llvm::SmallVector<Type, 4> input_types;
|
llvm::SmallVector<Type, 4> input_types;
|
||||||
input_types.reserve(std::distance(op.input().begin(), op.input().end()));
|
input_types.reserve(std::distance(op.input().begin(), op.input().end()));
|
||||||
for (Value v : op.input()) {
|
for (Value v : op.input()) {
|
||||||
input_types.push_back(v->getType());
|
input_types.push_back(v.getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
ModuleOp module = op.template getParentOfType<ModuleOp>();
|
ModuleOp module = op.template getParentOfType<ModuleOp>();
|
||||||
@ -414,7 +414,7 @@ LogicalResult InferShapeForFunction(FuncOp func,
|
|||||||
auto new_arg_type = mlir::RankedTensorType::get(shape, element_type);
|
auto new_arg_type = mlir::RankedTensorType::get(shape, element_type);
|
||||||
if (new_arg_type != func_type.getInput(i)) {
|
if (new_arg_type != func_type.getInput(i)) {
|
||||||
// If the new type is more detailed, trigger shape inference.
|
// If the new type is more detailed, trigger shape inference.
|
||||||
func.getArgument(i)->setType(new_arg_type);
|
func.getArgument(i).setType(new_arg_type);
|
||||||
needs_refinement = true;
|
needs_refinement = true;
|
||||||
}
|
}
|
||||||
new_arg_types.push_back(new_arg_type);
|
new_arg_types.push_back(new_arg_type);
|
||||||
|
@ -52,8 +52,7 @@ class ExecutorConstantSinking
|
|||||||
Region &body = launch.body();
|
Region &body = launch.body();
|
||||||
visitUsedValuesDefinedAbove(body, [&](OpOperand *use) {
|
visitUsedValuesDefinedAbove(body, [&](OpOperand *use) {
|
||||||
Value constant = use->get();
|
Value constant = use->get();
|
||||||
auto const_op =
|
auto const_op = dyn_cast_or_null<TF::ConstOp>(constant.getDefiningOp());
|
||||||
dyn_cast_or_null<TF::ConstOp>(constant->getDefiningOp());
|
|
||||||
if (!const_op) return;
|
if (!const_op) return;
|
||||||
|
|
||||||
// We found a constant, try to insert it in the map and re-use its
|
// We found a constant, try to insert it in the map and re-use its
|
||||||
@ -62,13 +61,13 @@ class ExecutorConstantSinking
|
|||||||
if (!map_entry.second) {
|
if (!map_entry.second) {
|
||||||
// This constant has already been cloned into the region, reuse it.
|
// This constant has already been cloned into the region, reuse it.
|
||||||
use->set(map_entry.first->getSecond().getResult());
|
use->set(map_entry.first->getSecond().getResult());
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Re-use sunk constant " << *use->get()
|
LLVM_DEBUG(llvm::dbgs() << "Re-use sunk constant " << use->get()
|
||||||
<< "\n in " << *use->get() << "\n");
|
<< "\n in " << use->get() << "\n");
|
||||||
if (constant->use_empty()) const_op.erase();
|
if (constant.use_empty()) const_op.erase();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (constant->hasOneUse()) {
|
if (constant.hasOneUse()) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Moved constant " << *constant << "\n");
|
LLVM_DEBUG(llvm::dbgs() << "Moved constant " << constant << "\n");
|
||||||
const_op.getOperation()->moveBefore(&body.begin()->front());
|
const_op.getOperation()->moveBefore(&body.begin()->front());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -76,8 +75,8 @@ class ExecutorConstantSinking
|
|||||||
body.begin()->getOperations().insert(body.begin()->begin(),
|
body.begin()->getOperations().insert(body.begin()->begin(),
|
||||||
map_entry.first->getSecond());
|
map_entry.first->getSecond());
|
||||||
use->set(map_entry.first->getSecond().getResult());
|
use->set(map_entry.first->getSecond().getResult());
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Sunk cloned constant " << *use->get()
|
LLVM_DEBUG(llvm::dbgs() << "Sunk cloned constant " << use->get()
|
||||||
<< "\n in " << *use->get() << "\n");
|
<< "\n in " << use->get() << "\n");
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -141,7 +141,7 @@ bool ShouldMoveOpAfterCluster(
|
|||||||
const llvm::SmallSetVector<Operation*, 8>& preceding_users) {
|
const llvm::SmallSetVector<Operation*, 8>& preceding_users) {
|
||||||
auto result = op->walk([&](Operation* op) {
|
auto result = op->walk([&](Operation* op) {
|
||||||
for (Value operand : op->getOperands()) {
|
for (Value operand : op->getOperands()) {
|
||||||
Operation* def = operand->getDefiningOp();
|
Operation* def = operand.getDefiningOp();
|
||||||
// Operands may not have a defining op (BlockArgument) or is from a
|
// Operands may not have a defining op (BlockArgument) or is from a
|
||||||
// different block.
|
// different block.
|
||||||
if (!def || def->getBlock() != block) continue;
|
if (!def || def->getBlock() != block) continue;
|
||||||
@ -185,7 +185,7 @@ llvm::SmallVector<Value, 8> CollectClusterResults(
|
|||||||
|
|
||||||
for (Operation* op : cluster_ops) {
|
for (Operation* op : cluster_ops) {
|
||||||
for (Value result : op->getResults()) {
|
for (Value result : op->getResults()) {
|
||||||
for (Operation* user : result->getUsers()) {
|
for (Operation* user : result.getUsers()) {
|
||||||
// Check if user is not an op in the cluster.
|
// Check if user is not an op in the cluster.
|
||||||
if (cluster_ops.count(block->findAncestorOpInBlock(*user)) == 0) {
|
if (cluster_ops.count(block->findAncestorOpInBlock(*user)) == 0) {
|
||||||
results.push_back(result);
|
results.push_back(result);
|
||||||
@ -206,7 +206,7 @@ tf_device::LaunchOp CreateLaunchOpForCluster(Operation* last_cluster_op,
|
|||||||
OpBuilder builder(last_cluster_op);
|
OpBuilder builder(last_cluster_op);
|
||||||
|
|
||||||
llvm::SmallVector<Type, 8> result_types;
|
llvm::SmallVector<Type, 8> result_types;
|
||||||
for (Value result : results) result_types.push_back(result->getType());
|
for (Value result : results) result_types.push_back(result.getType());
|
||||||
|
|
||||||
// An empty string placeholder is used for the device as that will be later
|
// An empty string placeholder is used for the device as that will be later
|
||||||
// populated with the device of the associated TPUReplicateMetadata op.
|
// populated with the device of the associated TPUReplicateMetadata op.
|
||||||
@ -246,7 +246,7 @@ void UpdateLaunchOpResultExternalUses(tf_device::LaunchOp launch_op,
|
|||||||
for (auto ret_vals : llvm::zip(results, launch_op.getResults())) {
|
for (auto ret_vals : llvm::zip(results, launch_op.getResults())) {
|
||||||
Value old_ret = std::get<0>(ret_vals);
|
Value old_ret = std::get<0>(ret_vals);
|
||||||
Value new_ret = std::get<1>(ret_vals);
|
Value new_ret = std::get<1>(ret_vals);
|
||||||
for (auto& use : old_ret->getUses())
|
for (auto& use : old_ret.getUses())
|
||||||
if (!launch_op_block.findAncestorOpInBlock(*use.getOwner()))
|
if (!launch_op_block.findAncestorOpInBlock(*use.getOwner()))
|
||||||
use.set(new_ret);
|
use.set(new_ret);
|
||||||
}
|
}
|
||||||
@ -307,7 +307,7 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op,
|
|||||||
llvm::SmallSetVector<Operation*, 8> unique_replicated_input_ops;
|
llvm::SmallSetVector<Operation*, 8> unique_replicated_input_ops;
|
||||||
mlir::visitUsedValuesDefinedAbove(
|
mlir::visitUsedValuesDefinedAbove(
|
||||||
launch_op.body(), launch_op.body(), [&](mlir::OpOperand* operand) {
|
launch_op.body(), launch_op.body(), [&](mlir::OpOperand* operand) {
|
||||||
Operation* def = operand->get()->getDefiningOp();
|
Operation* def = operand->get().getDefiningOp();
|
||||||
if (def && llvm::isa<TF::TPUReplicatedInputOp>(def))
|
if (def && llvm::isa<TF::TPUReplicatedInputOp>(def))
|
||||||
unique_replicated_input_ops.insert(def);
|
unique_replicated_input_ops.insert(def);
|
||||||
});
|
});
|
||||||
@ -339,7 +339,7 @@ LogicalResult ReplicateCluster(tf_device::LaunchOp launch_op,
|
|||||||
for (auto result_and_idx : llvm::enumerate(launch_op.getResults())) {
|
for (auto result_and_idx : llvm::enumerate(launch_op.getResults())) {
|
||||||
Value result = result_and_idx.value();
|
Value result = result_and_idx.value();
|
||||||
int idx = result_and_idx.index();
|
int idx = result_and_idx.index();
|
||||||
for (auto& use : result->getUses()) {
|
for (auto& use : result.getUses()) {
|
||||||
Operation* def = use.getOwner();
|
Operation* def = use.getOwner();
|
||||||
if (!def || !llvm::isa<TF::TPUReplicatedOutputOp>(def))
|
if (!def || !llvm::isa<TF::TPUReplicatedOutputOp>(def))
|
||||||
return launch_op.emitError()
|
return launch_op.emitError()
|
||||||
@ -470,7 +470,7 @@ void TPUClusterFormation::runOnFunction() {
|
|||||||
// `tf_device.replicate` is created and replicated (1) operands/results are
|
// `tf_device.replicate` is created and replicated (1) operands/results are
|
||||||
// untouched.
|
// untouched.
|
||||||
if (op->getNumOperands() == 1 && op->getNumResults() == 1)
|
if (op->getNumOperands() == 1 && op->getNumResults() == 1)
|
||||||
op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
|
op->getResult(0).replaceAllUsesWith(op->getOperand(0));
|
||||||
|
|
||||||
// Leftover TPUReplicatedInput/TPUReplicatedOutput that are not of
|
// Leftover TPUReplicatedInput/TPUReplicatedOutput that are not of
|
||||||
// `num_replicas` to 1.
|
// `num_replicas` to 1.
|
||||||
|
@ -60,9 +60,9 @@ llvm::SmallDenseMap<int32_t, int32_t> GetRemappedReplicatedInputIndices(
|
|||||||
|
|
||||||
llvm::SmallDenseMap<int32_t, int32_t> remapped_indices;
|
llvm::SmallDenseMap<int32_t, int32_t> remapped_indices;
|
||||||
for (auto operand_and_idx : llvm::enumerate(launch_func.getOperands()))
|
for (auto operand_and_idx : llvm::enumerate(launch_func.getOperands()))
|
||||||
if (auto block_arg = operand_and_idx.value()->dyn_cast<BlockArgument>())
|
if (auto block_arg = operand_and_idx.value().dyn_cast<BlockArgument>())
|
||||||
if (block_arg->getOwner() == replicate_block)
|
if (block_arg.getOwner() == replicate_block)
|
||||||
remapped_indices[block_arg->getArgNumber()] = operand_and_idx.index();
|
remapped_indices[block_arg.getArgNumber()] = operand_and_idx.index();
|
||||||
|
|
||||||
return remapped_indices;
|
return remapped_indices;
|
||||||
}
|
}
|
||||||
|
@ -135,23 +135,23 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
|
|||||||
// Find inputs that are variable reads.
|
// Find inputs that are variable reads.
|
||||||
for (auto operand : llvm::enumerate(execute->getOpOperands())) {
|
for (auto operand : llvm::enumerate(execute->getOpOperands())) {
|
||||||
infos.new_operand_values.push_back(operand.value().get());
|
infos.new_operand_values.push_back(operand.value().get());
|
||||||
if (!operand.value().get()->getDefiningOp()) continue;
|
if (!operand.value().get().getDefiningOp()) continue;
|
||||||
auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(
|
auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(
|
||||||
operand.value().get()->getDefiningOp());
|
operand.value().get().getDefiningOp());
|
||||||
if (!read_op) continue;
|
if (!read_op) continue;
|
||||||
auto resource = read_op.resource();
|
auto resource = read_op.resource();
|
||||||
|
|
||||||
if (check_device) {
|
if (check_device) {
|
||||||
if (auto resource_op = resource->getDefiningOp()) {
|
if (auto resource_op = resource.getDefiningOp()) {
|
||||||
auto resource_attr = resource_op->getAttr(kDeviceAttr);
|
auto resource_attr = resource_op->getAttr(kDeviceAttr);
|
||||||
// Check device matching for the node defining the resource.
|
// Check device matching for the node defining the resource.
|
||||||
if (!resource_attr || resource_attr != device_attr) continue;
|
if (!resource_attr || resource_attr != device_attr) continue;
|
||||||
} else {
|
} else {
|
||||||
auto resource_arg = resource->dyn_cast<BlockArgument>();
|
auto resource_arg = resource.dyn_cast<BlockArgument>();
|
||||||
assert(resource_arg);
|
assert(resource_arg);
|
||||||
// Check device matching for the argument defining the resource.
|
// Check device matching for the argument defining the resource.
|
||||||
auto resource_attr = func.getArgAttrOfType<mlir::StringAttr>(
|
auto resource_attr = func.getArgAttrOfType<mlir::StringAttr>(
|
||||||
resource_arg->getArgNumber(), kFuncDeviceAttr);
|
resource_arg.getArgNumber(), kFuncDeviceAttr);
|
||||||
if (!resource_attr || resource_attr != device_attr) continue;
|
if (!resource_attr || resource_attr != device_attr) continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -222,9 +222,8 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
|
|||||||
llvm::SmallVector<bool, 8> output_fused(execute->getNumResults(), false);
|
llvm::SmallVector<bool, 8> output_fused(execute->getNumResults(), false);
|
||||||
for (int i = 0; i < execute->getNumResults(); ++i) {
|
for (int i = 0; i < execute->getNumResults(); ++i) {
|
||||||
auto result = execute->getResult(i);
|
auto result = execute->getResult(i);
|
||||||
if (!result->hasOneUse()) continue;
|
if (!result.hasOneUse()) continue;
|
||||||
auto assign_op =
|
auto assign_op = llvm::dyn_cast<TF::AssignVariableOp>(*result.user_begin());
|
||||||
llvm::dyn_cast<TF::AssignVariableOp>(*result->user_begin());
|
|
||||||
if (!assign_op) continue;
|
if (!assign_op) continue;
|
||||||
auto resource = assign_op.resource();
|
auto resource = assign_op.resource();
|
||||||
auto it = infos.per_resource_info.find(resource);
|
auto it = infos.per_resource_info.find(resource);
|
||||||
@ -330,7 +329,7 @@ void MergeForOneTPUExecute(Operation* execute, bool check_device,
|
|||||||
// Replace the uses.
|
// Replace the uses.
|
||||||
for (int i = 0; i < infos.old_to_new_output_mapping.size(); ++i) {
|
for (int i = 0; i < infos.old_to_new_output_mapping.size(); ++i) {
|
||||||
if (infos.old_to_new_output_mapping[i] < 0) continue;
|
if (infos.old_to_new_output_mapping[i] < 0) continue;
|
||||||
execute->getResult(i)->replaceAllUsesWith(
|
execute->getResult(i).replaceAllUsesWith(
|
||||||
merged_execute.getResult(infos.old_to_new_output_mapping[i]));
|
merged_execute.getResult(infos.old_to_new_output_mapping[i]));
|
||||||
}
|
}
|
||||||
// Remove the assign ops.
|
// Remove the assign ops.
|
||||||
|
@ -457,7 +457,7 @@ LogicalResult Rewrite(
|
|||||||
// the other ops that are intended to consume the compile result.
|
// the other ops that are intended to consume the compile result.
|
||||||
Block* block = launch_func.getOperation()->getBlock();
|
Block* block = launch_func.getOperation()->getBlock();
|
||||||
for (auto compile_result_op : block->getOps<TF::TPUCompilationResultOp>())
|
for (auto compile_result_op : block->getOps<TF::TPUCompilationResultOp>())
|
||||||
compile_result_op.output()->replaceAllUsesWith(compile_op->getResult(0));
|
compile_result_op.output().replaceAllUsesWith(compile_op->getResult(0));
|
||||||
|
|
||||||
BuildTPUCompileSucceededAssertOp(compile_op, builder);
|
BuildTPUCompileSucceededAssertOp(compile_op, builder);
|
||||||
|
|
||||||
|
@ -97,11 +97,11 @@ void BreakUpIslands::runOnOperation() {
|
|||||||
dups.clear();
|
dups.clear();
|
||||||
|
|
||||||
for (Value input : edges) {
|
for (Value input : edges) {
|
||||||
dups.insert(input->getDefiningOp());
|
dups.insert(input.getDefiningOp());
|
||||||
}
|
}
|
||||||
// Insert new control edges removing duplicates.
|
// Insert new control edges removing duplicates.
|
||||||
for (Value value : llvm::reverse(edge.second)) {
|
for (Value value : llvm::reverse(edge.second)) {
|
||||||
if (dups.insert(value->getDefiningOp()).second) edges.push_back(value);
|
if (dups.insert(value.getDefiningOp()).second) edges.push_back(value);
|
||||||
}
|
}
|
||||||
state.addOperands(edges);
|
state.addOperands(edges);
|
||||||
Operation* new_op = builder.createOperation(state);
|
Operation* new_op = builder.createOperation(state);
|
||||||
@ -160,7 +160,7 @@ IslandSourcesAndSinks FindSourcesAndSinksInIsland(
|
|||||||
for (auto predecessor : predecessors) result.sinks.erase(predecessor);
|
for (auto predecessor : predecessors) result.sinks.erase(predecessor);
|
||||||
bool has_in_island_operands = false;
|
bool has_in_island_operands = false;
|
||||||
for (auto operand : sub_op.getOperands()) {
|
for (auto operand : sub_op.getOperands()) {
|
||||||
auto defining_op = operand->getDefiningOp();
|
auto defining_op = operand.getDefiningOp();
|
||||||
if (!defining_op || defining_op->getParentOp() != island) continue;
|
if (!defining_op || defining_op->getParentOp() != island) continue;
|
||||||
// Remove operands from sinks.
|
// Remove operands from sinks.
|
||||||
result.sinks.erase(defining_op);
|
result.sinks.erase(defining_op);
|
||||||
@ -190,16 +190,16 @@ void BreakUpIslands::BreakUpIsland(
|
|||||||
// the island that defines that fetched value.
|
// the island that defines that fetched value.
|
||||||
for (auto fetch : op.GetYield().fetches()) {
|
for (auto fetch : op.GetYield().fetches()) {
|
||||||
// Ok, because there is no op to add control to (eg: function args).
|
// Ok, because there is no op to add control to (eg: function args).
|
||||||
if (!fetch->getDefiningOp()) continue;
|
if (!fetch.getDefiningOp()) continue;
|
||||||
if (fetch->getDefiningOp()->getParentOp() == op) {
|
if (fetch.getDefiningOp()->getParentOp() == op) {
|
||||||
// OK, because it is the same island.
|
// OK, because it is the same island.
|
||||||
} else if (auto island_op = llvm::dyn_cast<tf_executor::IslandOp>(
|
} else if (auto island_op = llvm::dyn_cast<tf_executor::IslandOp>(
|
||||||
fetch->getDefiningOp())) {
|
fetch.getDefiningOp())) {
|
||||||
island_control_inputs.push_back(island_op.control());
|
island_control_inputs.push_back(island_op.control());
|
||||||
} else {
|
} else {
|
||||||
// TODO(parkers): Any defining op that has a control output can be handled
|
// TODO(parkers): Any defining op that has a control output can be handled
|
||||||
// just like an island.
|
// just like an island.
|
||||||
fetch->getDefiningOp()->emitError("Fetching non-island as dependency.");
|
fetch.getDefiningOp()->emitError("Fetching non-island as dependency.");
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -255,11 +255,11 @@ void BreakUpIslands::BreakUpIsland(
|
|||||||
sink_island_controls.push_back(island.control());
|
sink_island_controls.push_back(island.control());
|
||||||
}
|
}
|
||||||
assert(sink_island_controls.size() == 1);
|
assert(sink_island_controls.size() == 1);
|
||||||
op.control()->replaceAllUsesWith(sink_island_controls[0]);
|
op.control().replaceAllUsesWith(sink_island_controls[0]);
|
||||||
// All existing outputs need to add a control flow edge from
|
// All existing outputs need to add a control flow edge from
|
||||||
// sink_island_controls[0].
|
// sink_island_controls[0].
|
||||||
for (Value out : op.outputs()) {
|
for (Value out : op.outputs()) {
|
||||||
for (auto& use : out->getUses()) {
|
for (auto& use : out.getUses()) {
|
||||||
Operation* owner = use.getOwner();
|
Operation* owner = use.getOwner();
|
||||||
if (auto island_op =
|
if (auto island_op =
|
||||||
llvm::dyn_cast<tf_executor::IslandOp>(owner->getParentOp())) {
|
llvm::dyn_cast<tf_executor::IslandOp>(owner->getParentOp())) {
|
||||||
@ -275,7 +275,7 @@ void BreakUpIslands::BreakUpIsland(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto item : llvm::zip(op.outputs(), op.GetYield().fetches()))
|
for (auto item : llvm::zip(op.outputs(), op.GetYield().fetches()))
|
||||||
std::get<0>(item)->replaceAllUsesWith(std::get<1>(item));
|
std::get<0>(item).replaceAllUsesWith(std::get<1>(item));
|
||||||
op.erase();
|
op.erase();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ tf_executor::IslandOp ControlToExecutorDialectConversion::CreateIslandForOp(
|
|||||||
// Create a new region for the tf_executor.island body
|
// Create a new region for the tf_executor.island body
|
||||||
SmallVector<Value, 8> operands;
|
SmallVector<Value, 8> operands;
|
||||||
for (Value operand : op->getOperands())
|
for (Value operand : op->getOperands())
|
||||||
if (operand->getType().isa<tf_executor::ControlType>())
|
if (operand.getType().isa<tf_executor::ControlType>())
|
||||||
operands.push_back(operand);
|
operands.push_back(operand);
|
||||||
SmallVector<Type, 8> types;
|
SmallVector<Type, 8> types;
|
||||||
for (Type result_type : op->getResultTypes())
|
for (Type result_type : op->getResultTypes())
|
||||||
@ -155,7 +155,7 @@ void ControlToExecutorDialectConversion::runOnFunction() {
|
|||||||
loc, types, operands, ArrayRef<NamedAttribute>{});
|
loc, types, operands, ArrayRef<NamedAttribute>{});
|
||||||
} else if (op.getName().getStringRef() == "_tf.NextIteration.source") {
|
} else if (op.getName().getStringRef() == "_tf.NextIteration.source") {
|
||||||
replacement = builder.create<tf_executor::NextIterationSourceOp>(
|
replacement = builder.create<tf_executor::NextIterationSourceOp>(
|
||||||
loc, op.getResult(0)->getType());
|
loc, op.getResult(0).getType());
|
||||||
// Record a mapping of the name to the nextiteration.source so that when
|
// Record a mapping of the name to the nextiteration.source so that when
|
||||||
// we convert the sink we can get the token.
|
// we convert the sink we can get the token.
|
||||||
StringAttr frame = op.getAttrOfType<StringAttr>("name");
|
StringAttr frame = op.getAttrOfType<StringAttr>("name");
|
||||||
@ -164,9 +164,9 @@ void ControlToExecutorDialectConversion::runOnFunction() {
|
|||||||
cast<tf_executor::NextIterationSourceOp>(replacement);
|
cast<tf_executor::NextIterationSourceOp>(replacement);
|
||||||
// Replace the results here since the _tf source does not produce a token
|
// Replace the results here since the _tf source does not produce a token
|
||||||
// there isn't a mapping for the new result #1.
|
// there isn't a mapping for the new result #1.
|
||||||
op.getResult(0)->replaceAllUsesWith(replacement->getResult(0));
|
op.getResult(0).replaceAllUsesWith(replacement->getResult(0));
|
||||||
for (int i : llvm::seq<int>(1, op.getNumResults()))
|
for (int i : llvm::seq<int>(1, op.getNumResults()))
|
||||||
op.getResult(i)->replaceAllUsesWith(replacement->getResult(i + 1));
|
op.getResult(i).replaceAllUsesWith(replacement->getResult(i + 1));
|
||||||
replacement->setAttrs(op.getAttrList());
|
replacement->setAttrs(op.getAttrList());
|
||||||
op.erase();
|
op.erase();
|
||||||
continue;
|
continue;
|
||||||
@ -202,7 +202,7 @@ void ControlToExecutorDialectConversion::runOnFunction() {
|
|||||||
// Only the non-control operands are carried over, the island is handling
|
// Only the non-control operands are carried over, the island is handling
|
||||||
// the control input.
|
// the control input.
|
||||||
for (Value operand : op.getOperands())
|
for (Value operand : op.getOperands())
|
||||||
if (!operand->getType().isa<tf_executor::ControlType>())
|
if (!operand.getType().isa<tf_executor::ControlType>())
|
||||||
result.operands.push_back(operand);
|
result.operands.push_back(operand);
|
||||||
|
|
||||||
// Add a result type for each non-control result we find
|
// Add a result type for each non-control result we find
|
||||||
@ -232,7 +232,7 @@ void ControlToExecutorDialectConversion::runOnFunction() {
|
|||||||
if (!isa<tf_executor::IslandOp>(replacement))
|
if (!isa<tf_executor::IslandOp>(replacement))
|
||||||
replacement->setAttrs(op.getAttrList());
|
replacement->setAttrs(op.getAttrList());
|
||||||
for (int i : llvm::seq<int>(0, op.getNumResults()))
|
for (int i : llvm::seq<int>(0, op.getNumResults()))
|
||||||
op.getResult(i)->replaceAllUsesWith(replacement->getResult(i));
|
op.getResult(i).replaceAllUsesWith(replacement->getResult(i));
|
||||||
op.erase();
|
op.erase();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -79,7 +79,7 @@ void ExecutorToControlDialectConversion::runOnFunction() {
|
|||||||
for (auto ops_and_ret_vals :
|
for (auto ops_and_ret_vals :
|
||||||
llvm::zip(graph.getResults(), fetch.getOperands()))
|
llvm::zip(graph.getResults(), fetch.getOperands()))
|
||||||
std::get<0>(ops_and_ret_vals)
|
std::get<0>(ops_and_ret_vals)
|
||||||
->replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
|
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
|
||||||
op.erase();
|
op.erase();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -106,7 +106,7 @@ void ExecutorToControlDialectConversion::runOnFunction() {
|
|||||||
for (auto ops_and_ret_vals :
|
for (auto ops_and_ret_vals :
|
||||||
llvm::zip(island.getResults(), wrapped_op.getOperands()))
|
llvm::zip(island.getResults(), wrapped_op.getOperands()))
|
||||||
std::get<0>(ops_and_ret_vals)
|
std::get<0>(ops_and_ret_vals)
|
||||||
->replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
|
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
// Add a leading _ off the name.
|
// Add a leading _ off the name.
|
||||||
@ -141,7 +141,7 @@ void ExecutorToControlDialectConversion::runOnFunction() {
|
|||||||
for (auto ops_and_ret_vals :
|
for (auto ops_and_ret_vals :
|
||||||
llvm::zip(wrapped_op.getResults(), replacement->getResults()))
|
llvm::zip(wrapped_op.getResults(), replacement->getResults()))
|
||||||
std::get<0>(ops_and_ret_vals)
|
std::get<0>(ops_and_ret_vals)
|
||||||
->replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
|
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
|
||||||
|
|
||||||
ctl_sequence = replacement->getResult(replacement->getNumResults() - 1);
|
ctl_sequence = replacement->getResult(replacement->getNumResults() - 1);
|
||||||
}
|
}
|
||||||
@ -151,13 +151,13 @@ void ExecutorToControlDialectConversion::runOnFunction() {
|
|||||||
// been rewritten from ops in island. Last op rewritten must logically
|
// been rewritten from ops in island. Last op rewritten must logically
|
||||||
// carry // all the island control inputs, we can simply use it to
|
// carry // all the island control inputs, we can simply use it to
|
||||||
// replace all uses of island's control output.
|
// replace all uses of island's control output.
|
||||||
island.control()->replaceAllUsesWith(ctl_sequence);
|
island.control().replaceAllUsesWith(ctl_sequence);
|
||||||
} else if (island.getNumOperands() > 0) {
|
} else if (island.getNumOperands() > 0) {
|
||||||
// Getting here means island had an effectively empty body and there is
|
// Getting here means island had an effectively empty body and there is
|
||||||
// just one control input. In this case, island's control output should
|
// just one control input. In this case, island's control output should
|
||||||
// be replaced with the control input.
|
// be replaced with the control input.
|
||||||
assert(island.getNumOperands() == 1);
|
assert(island.getNumOperands() == 1);
|
||||||
island.control()->replaceAllUsesWith(island.getOperand(0));
|
island.control().replaceAllUsesWith(island.getOperand(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
op.erase();
|
op.erase();
|
||||||
@ -192,7 +192,7 @@ void ExecutorToControlDialectConversion::runOnFunction() {
|
|||||||
// dialect.
|
// dialect.
|
||||||
auto non_null_operands = llvm::make_filter_range(
|
auto non_null_operands = llvm::make_filter_range(
|
||||||
op.getOperands(),
|
op.getOperands(),
|
||||||
[](Value v) { return !v->getType().isa<tf_executor::TokenType>(); });
|
[](Value v) { return !v.getType().isa<tf_executor::TokenType>(); });
|
||||||
state.operands.append(non_null_operands.begin(), non_null_operands.end());
|
state.operands.append(non_null_operands.begin(), non_null_operands.end());
|
||||||
for (Type result_type : op.getResultTypes()) {
|
for (Type result_type : op.getResultTypes()) {
|
||||||
// Filter out TokenType, they don't exist in the control dialect.
|
// Filter out TokenType, they don't exist in the control dialect.
|
||||||
@ -212,14 +212,14 @@ void ExecutorToControlDialectConversion::runOnFunction() {
|
|||||||
|
|
||||||
if (auto next_iteration =
|
if (auto next_iteration =
|
||||||
dyn_cast<tf_executor::NextIterationSourceOp>(op)) {
|
dyn_cast<tf_executor::NextIterationSourceOp>(op)) {
|
||||||
next_iteration.output()->replaceAllUsesWith(replacement->getResult(0));
|
next_iteration.output().replaceAllUsesWith(replacement->getResult(0));
|
||||||
next_iteration.token()->dropAllUses();
|
next_iteration.token().dropAllUses();
|
||||||
next_iteration.control()->replaceAllUsesWith(replacement->getResult(1));
|
next_iteration.control().replaceAllUsesWith(replacement->getResult(1));
|
||||||
} else {
|
} else {
|
||||||
for (auto ops_and_ret_vals :
|
for (auto ops_and_ret_vals :
|
||||||
llvm::zip(op.getResults(), replacement->getResults()))
|
llvm::zip(op.getResults(), replacement->getResults()))
|
||||||
std::get<0>(ops_and_ret_vals)
|
std::get<0>(ops_and_ret_vals)
|
||||||
->replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
|
.replaceAllUsesWith(std::get<1>(ops_and_ret_vals));
|
||||||
}
|
}
|
||||||
op.erase();
|
op.erase();
|
||||||
}
|
}
|
||||||
|
@ -236,7 +236,7 @@ std::string Exporter::UniqueName(Operation* op) {
|
|||||||
|
|
||||||
StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
|
StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
|
||||||
BlockArgument arg, unsigned index, llvm::StringRef name) {
|
BlockArgument arg, unsigned index, llvm::StringRef name) {
|
||||||
auto func = arg->getParentRegion()->getParentOfType<mlir::FuncOp>();
|
auto func = arg.getParentRegion()->getParentOfType<mlir::FuncOp>();
|
||||||
|
|
||||||
auto node_def = absl::make_unique<NodeDef>();
|
auto node_def = absl::make_unique<NodeDef>();
|
||||||
if (!name.empty())
|
if (!name.empty())
|
||||||
@ -248,7 +248,7 @@ StatusOr<std::unique_ptr<NodeDef>> Exporter::GetArgumentNode(
|
|||||||
|
|
||||||
DataType dtype;
|
DataType dtype;
|
||||||
TF_RETURN_IF_ERROR(ConvertToDataType(
|
TF_RETURN_IF_ERROR(ConvertToDataType(
|
||||||
arg->getType().cast<mlir::TensorType>().getElementType(), &dtype));
|
arg.getType().cast<mlir::TensorType>().getElementType(), &dtype));
|
||||||
AttrValue type_attr;
|
AttrValue type_attr;
|
||||||
type_attr.set_type(dtype);
|
type_attr.set_type(dtype);
|
||||||
(*node_def->mutable_attr())["T"] = type_attr;
|
(*node_def->mutable_attr())["T"] = type_attr;
|
||||||
@ -286,7 +286,7 @@ StatusOr<std::unique_ptr<NodeDef>> Exporter::GetReturnNode(
|
|||||||
auto inst_op = inst->getOperand(index);
|
auto inst_op = inst->getOperand(index);
|
||||||
DataType dtype;
|
DataType dtype;
|
||||||
TF_RETURN_IF_ERROR(ConvertToDataType(
|
TF_RETURN_IF_ERROR(ConvertToDataType(
|
||||||
inst_op->getType().cast<mlir::TensorType>().getElementType(), &dtype));
|
inst_op.getType().cast<mlir::TensorType>().getElementType(), &dtype));
|
||||||
AttrValue type_attr;
|
AttrValue type_attr;
|
||||||
type_attr.set_type(dtype);
|
type_attr.set_type(dtype);
|
||||||
(*node_def->mutable_attr())["T"] = type_attr;
|
(*node_def->mutable_attr())["T"] = type_attr;
|
||||||
@ -298,8 +298,8 @@ StatusOr<std::unique_ptr<NodeDef>> Exporter::GetReturnNode(
|
|||||||
|
|
||||||
Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node,
|
Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node,
|
||||||
unsigned dst_index) {
|
unsigned dst_index) {
|
||||||
if (auto input_result = src->dyn_cast<mlir::OpResult>()) {
|
if (auto input_result = src.dyn_cast<mlir::OpResult>()) {
|
||||||
auto* input_inst = input_result->getOwner();
|
auto* input_inst = input_result.getOwner();
|
||||||
// replaces the input node by the sink one if it is an NextIteration source:
|
// replaces the input node by the sink one if it is an NextIteration source:
|
||||||
auto it = source_to_sink_.find(input_inst);
|
auto it = source_to_sink_.find(input_inst);
|
||||||
if (it != source_to_sink_.end()) {
|
if (it != source_to_sink_.end()) {
|
||||||
@ -308,16 +308,16 @@ Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node,
|
|||||||
auto node_it = nodes_.find(input_inst);
|
auto node_it = nodes_.find(input_inst);
|
||||||
TF_RET_CHECK(node_it != nodes_.end())
|
TF_RET_CHECK(node_it != nodes_.end())
|
||||||
<< "Use of OpResult encountered before def!";
|
<< "Use of OpResult encountered before def!";
|
||||||
if (input_result->getType().isa<mlir::TFControlFlow::TFControlType>()) {
|
if (input_result.getType().isa<mlir::TFControlFlow::TFControlType>()) {
|
||||||
graph_->AddControlEdge(node_it->second, dst_node);
|
graph_->AddControlEdge(node_it->second, dst_node);
|
||||||
} else {
|
} else {
|
||||||
graph_->AddEdge(node_it->second, input_result->getResultNumber(),
|
graph_->AddEdge(node_it->second, input_result.getResultNumber(), dst_node,
|
||||||
dst_node, dst_index);
|
dst_index);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto input_arg = src->cast<BlockArgument>();
|
auto input_arg = src.cast<BlockArgument>();
|
||||||
auto input_node_it = args_.find(input_arg);
|
auto input_node_it = args_.find(input_arg);
|
||||||
TF_RET_CHECK(input_node_it != args_.end())
|
TF_RET_CHECK(input_node_it != args_.end())
|
||||||
<< "Use of BlockArgument encounted before def!";
|
<< "Use of BlockArgument encounted before def!";
|
||||||
@ -366,7 +366,7 @@ Status Exporter::AddInstructionNode(Operation* inst) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool IsEntryFunctionArg(BlockArgument arg) {
|
bool IsEntryFunctionArg(BlockArgument arg) {
|
||||||
return arg->getParentRegion()->getParentOfType<mlir::FuncOp>().getName() ==
|
return arg.getParentRegion()->getParentOfType<mlir::FuncOp>().getName() ==
|
||||||
"main";
|
"main";
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -387,21 +387,21 @@ Status Exporter::AddArgumentNode(BlockArgument arg, unsigned index,
|
|||||||
// is an input node. We recover the original input node and skip adding the
|
// is an input node. We recover the original input node and skip adding the
|
||||||
// argument node. The new input node will be handled as normal in the
|
// argument node. The new input node will be handled as normal in the
|
||||||
// following steps.
|
// following steps.
|
||||||
if (!arg->hasOneUse()) {
|
if (!arg.hasOneUse()) {
|
||||||
return errors::FailedPrecondition(
|
return errors::FailedPrecondition(
|
||||||
"Arg in 'main' should only have one user.");
|
"Arg in 'main' should only have one user.");
|
||||||
}
|
}
|
||||||
auto* input = *arg->user_begin();
|
auto* input = *arg.user_begin();
|
||||||
auto input_name = input->getName().getStringRef();
|
auto input_name = input->getName().getStringRef();
|
||||||
input_name.consume_back(".input");
|
input_name.consume_back(".input");
|
||||||
mlir::OpBuilder builder(arg->getOwner());
|
mlir::OpBuilder builder(arg.getOwner());
|
||||||
auto loc = mlir::NameLoc::get(builder.getIdentifier(UniqueName(input)),
|
auto loc = mlir::NameLoc::get(builder.getIdentifier(UniqueName(input)),
|
||||||
builder.getContext());
|
builder.getContext());
|
||||||
OperationState state(loc, input_name.str());
|
OperationState state(loc, input_name.str());
|
||||||
state.attributes.append(input->getAttrs().begin(), input->getAttrs().end());
|
state.attributes.append(input->getAttrs().begin(), input->getAttrs().end());
|
||||||
for (auto op : input->getOperands()) {
|
for (auto op : input->getOperands()) {
|
||||||
// Skip the argument in the new operation.
|
// Skip the argument in the new operation.
|
||||||
if (op->isa<BlockArgument>()) continue;
|
if (op.isa<BlockArgument>()) continue;
|
||||||
state.operands.push_back(op);
|
state.operands.push_back(op);
|
||||||
}
|
}
|
||||||
state.types.append(input->getResultTypes().begin(),
|
state.types.append(input->getResultTypes().begin(),
|
||||||
@ -419,7 +419,7 @@ Status Exporter::AddArgumentNode(BlockArgument arg, unsigned index,
|
|||||||
<< inst << " to an empty string.";
|
<< inst << " to an empty string.";
|
||||||
mapped_name.assign(input_mapped_name);
|
mapped_name.assign(input_mapped_name);
|
||||||
for (int index : llvm::seq<int>(0, input->getNumResults())) {
|
for (int index : llvm::seq<int>(0, input->getNumResults())) {
|
||||||
input->getResult(index)->replaceAllUsesWith(inst->getResult(index));
|
input->getResult(index).replaceAllUsesWith(inst->getResult(index));
|
||||||
}
|
}
|
||||||
input->dropAllReferences();
|
input->dropAllReferences();
|
||||||
input->erase();
|
input->erase();
|
||||||
@ -524,7 +524,7 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
|
|||||||
// the main graph did not have its _Retval nodes lifted into the functions
|
// the main graph did not have its _Retval nodes lifted into the functions
|
||||||
// returns.
|
// returns.
|
||||||
if (!graph_as_function) {
|
if (!graph_as_function) {
|
||||||
auto defining_op = it.value()->getDefiningOp();
|
auto defining_op = it.value().getDefiningOp();
|
||||||
auto& mapped_name = exporter.op_to_name_[defining_op];
|
auto& mapped_name = exporter.op_to_name_[defining_op];
|
||||||
DCHECK(mapped_name.empty())
|
DCHECK(mapped_name.empty())
|
||||||
<< "Convert() attempted to change the op_to_name_ mapping for "
|
<< "Convert() attempted to change the op_to_name_ mapping for "
|
||||||
@ -541,7 +541,7 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
|
|||||||
// Only assign user of argument the input name if the main graph did not
|
// Only assign user of argument the input name if the main graph did not
|
||||||
// have its _Arg nodes lifted into the functions arguments.
|
// have its _Arg nodes lifted into the functions arguments.
|
||||||
if (!graph_as_function) {
|
if (!graph_as_function) {
|
||||||
auto first_user = *it.value()->user_begin();
|
auto first_user = *it.value().user_begin();
|
||||||
auto& mapped_name = exporter.op_to_name_[first_user];
|
auto& mapped_name = exporter.op_to_name_[first_user];
|
||||||
DCHECK(mapped_name.empty())
|
DCHECK(mapped_name.empty())
|
||||||
<< "Convert() attempted to change the op_to_name_ mapping for "
|
<< "Convert() attempted to change the op_to_name_ mapping for "
|
||||||
@ -556,7 +556,7 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
|
|||||||
for (auto it : llvm::enumerate(block.getArguments())) {
|
for (auto it : llvm::enumerate(block.getArguments())) {
|
||||||
int index = it.index();
|
int index = it.index();
|
||||||
auto arg = it.value();
|
auto arg = it.value();
|
||||||
mlir::Type type = arg->getType();
|
mlir::Type type = arg.getType();
|
||||||
if (!type.isa<mlir::TensorType>()) {
|
if (!type.isa<mlir::TensorType>()) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"FuncOps arguments must have tensor types. Found ",
|
"FuncOps arguments must have tensor types. Found ",
|
||||||
|
@ -132,7 +132,7 @@ StatusOr<std::unique_ptr<NodeDef>> ConvertTFDialectOpToNodeDef(
|
|||||||
mlir::OperationState result(inst->getLoc(),
|
mlir::OperationState result(inst->getLoc(),
|
||||||
inst->getName().getStringRef().drop_front());
|
inst->getName().getStringRef().drop_front());
|
||||||
for (mlir::Value operand : inst->getOperands())
|
for (mlir::Value operand : inst->getOperands())
|
||||||
if (!operand->getType().isa<mlir::TFControlFlow::TFControlType>())
|
if (!operand.getType().isa<mlir::TFControlFlow::TFControlType>())
|
||||||
result.operands.push_back(operand);
|
result.operands.push_back(operand);
|
||||||
|
|
||||||
// Add a result type for each non-control result we find
|
// Add a result type for each non-control result we find
|
||||||
|
@ -1192,9 +1192,9 @@ Status ImporterBase::ConvertFunctionArgAndRets(
|
|||||||
|
|
||||||
// Collect mapping of OutputTensor to associated block arg.
|
// Collect mapping of OutputTensor to associated block arg.
|
||||||
arg_nodes_to_values.try_emplace({arg_node.node, arg_node.index}, arg_def);
|
arg_nodes_to_values.try_emplace({arg_node.node, arg_node.index}, arg_def);
|
||||||
island->getResult(0)->replaceAllUsesWith(arg_def);
|
island->getResult(0).replaceAllUsesWith(arg_def);
|
||||||
// Erase control outputs from feed.
|
// Erase control outputs from feed.
|
||||||
auto control_uses = island->getResult(1)->getUses();
|
auto control_uses = island->getResult(1).getUses();
|
||||||
for (auto& control_use : llvm::make_early_inc_range(control_uses))
|
for (auto& control_use : llvm::make_early_inc_range(control_uses))
|
||||||
control_use.getOwner()->eraseOperand(control_use.getOperandNumber());
|
control_use.getOwner()->eraseOperand(control_use.getOperandNumber());
|
||||||
|
|
||||||
@ -1389,7 +1389,7 @@ mlir::Operation* ImporterBase::createOperation(
|
|||||||
builder_.getBlock()->begin());
|
builder_.getBlock()->begin());
|
||||||
auto source_op =
|
auto source_op =
|
||||||
builder_at_begin.create<mlir::tf_executor::NextIterationSourceOp>(
|
builder_at_begin.create<mlir::tf_executor::NextIterationSourceOp>(
|
||||||
loc, operands[0]->getType(), result.attributes);
|
loc, operands[0].getType(), result.attributes);
|
||||||
return builder_.create<mlir::tf_executor::NextIterationSinkOp>(
|
return builder_.create<mlir::tf_executor::NextIterationSinkOp>(
|
||||||
loc, source_op.token(), operands, result.attributes);
|
loc, source_op.token(), operands, result.attributes);
|
||||||
}
|
}
|
||||||
@ -1654,7 +1654,7 @@ Status ImporterBase::AddBackedges() {
|
|||||||
Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
|
Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
|
||||||
int dst_input) {
|
int dst_input) {
|
||||||
// Get the NextIteration.Source operation from the token operand of the sink.
|
// Get the NextIteration.Source operation from the token operand of the sink.
|
||||||
mlir::Operation* source = sink->getOperand(0)->getDefiningOp();
|
mlir::Operation* source = sink->getOperand(0).getDefiningOp();
|
||||||
|
|
||||||
// Adds the "source" to the operands of the dst by creating a new dst
|
// Adds the "source" to the operands of the dst by creating a new dst
|
||||||
// operation.
|
// operation.
|
||||||
@ -1680,7 +1680,7 @@ Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst,
|
|||||||
// result of the new operation, and deletes the old operation.
|
// result of the new operation, and deletes the old operation.
|
||||||
for (unsigned i = 0, e = dst->getNumResults(); i != e; ++i) {
|
for (unsigned i = 0, e = dst->getNumResults(); i != e; ++i) {
|
||||||
auto new_output = new_dst->getResult(i);
|
auto new_output = new_dst->getResult(i);
|
||||||
dst->getResult(i)->replaceAllUsesWith(new_output);
|
dst->getResult(i).replaceAllUsesWith(new_output);
|
||||||
}
|
}
|
||||||
dst->dropAllReferences();
|
dst->dropAllReferences();
|
||||||
dst->erase();
|
dst->erase();
|
||||||
|
@ -222,12 +222,12 @@ static bool IsRefTypeControlOp(mlir::Operation* op) {
|
|||||||
|
|
||||||
auto op_name = op_name_or_status.ConsumeValueOrDie();
|
auto op_name = op_name_or_status.ConsumeValueOrDie();
|
||||||
if (op_name.equals("NextIteration"))
|
if (op_name.equals("NextIteration"))
|
||||||
return mlir::getElementTypeOrSelf(op->getOperand(0)->getType())
|
return mlir::getElementTypeOrSelf(op->getOperand(0).getType())
|
||||||
.isa<mlir::TF::TensorFlowRefType>();
|
.isa<mlir::TF::TensorFlowRefType>();
|
||||||
|
|
||||||
if (op_name.equals("Enter") || op_name.equals("Exit") ||
|
if (op_name.equals("Enter") || op_name.equals("Exit") ||
|
||||||
op_name.equals("Switch") || op_name.equals("Merge")) {
|
op_name.equals("Switch") || op_name.equals("Merge")) {
|
||||||
return getElementTypeOrSelf(op->getResult(0)->getType())
|
return getElementTypeOrSelf(op->getResult(0).getType())
|
||||||
.isa<mlir::TF::TensorFlowRefType>();
|
.isa<mlir::TF::TensorFlowRefType>();
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
@ -407,7 +407,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
|||||||
}
|
}
|
||||||
case HloOpcode::kWhile: {
|
case HloOpcode::kWhile: {
|
||||||
auto op = func_builder->create<mlir::xla_hlo::WhileOp>(
|
auto op = func_builder->create<mlir::xla_hlo::WhileOp>(
|
||||||
loc, operands[0]->getType(), operands[0]);
|
loc, operands[0].getType(), operands[0]);
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
ImportComputation(instruction->while_condition(), &op.cond()));
|
ImportComputation(instruction->while_condition(), &op.cond()));
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
|
@ -175,7 +175,7 @@ void ConstOp::build(Builder* builder, OperationState& result, Attribute value) {
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
|
||||||
const auto output_type = getResult()->getType().cast<ShapedType>();
|
const auto output_type = getResult().getType().cast<ShapedType>();
|
||||||
const auto output_size = output_type.getNumElements();
|
const auto output_size = output_type.getNumElements();
|
||||||
const auto dimension = iota_dimension().getSExtValue();
|
const auto dimension = iota_dimension().getSExtValue();
|
||||||
const auto max_dim_size = output_type.getDimSize(dimension);
|
const auto max_dim_size = output_type.getDimSize(dimension);
|
||||||
@ -204,15 +204,14 @@ OpFoldResult IotaOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void AbsOp::build(Builder* builder, OperationState& result, Value operand) {
|
void AbsOp::build(Builder* builder, OperationState& result, Value operand) {
|
||||||
auto shaped_type = operand->getType().cast<ShapedType>();
|
auto shaped_type = operand.getType().cast<ShapedType>();
|
||||||
Type new_type;
|
Type new_type;
|
||||||
if (!shaped_type.getElementType().isa<ComplexType>()) {
|
if (!shaped_type.getElementType().isa<ComplexType>()) {
|
||||||
new_type = operand->getType();
|
new_type = operand.getType();
|
||||||
} else if (shaped_type.hasRank()) {
|
} else if (shaped_type.hasRank()) {
|
||||||
new_type =
|
new_type = RankedTensorType::get(shaped_type.getShape(), operand.getType());
|
||||||
RankedTensorType::get(shaped_type.getShape(), operand->getType());
|
|
||||||
} else {
|
} else {
|
||||||
new_type = UnrankedTensorType::get(operand->getType());
|
new_type = UnrankedTensorType::get(operand.getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
return AbsOp::build(builder, result, new_type, operand);
|
return AbsOp::build(builder, result, new_type, operand);
|
||||||
@ -225,7 +224,7 @@ void AbsOp::build(Builder* builder, OperationState& result, Value operand) {
|
|||||||
void ConvertOp::build(Builder* builder, OperationState& result, Value operand,
|
void ConvertOp::build(Builder* builder, OperationState& result, Value operand,
|
||||||
Type result_element_ty) {
|
Type result_element_ty) {
|
||||||
Type result_ty;
|
Type result_ty;
|
||||||
Type operand_ty = operand->getType();
|
Type operand_ty = operand.getType();
|
||||||
if (auto ranked_ty = operand_ty.dyn_cast<RankedTensorType>()) {
|
if (auto ranked_ty = operand_ty.dyn_cast<RankedTensorType>()) {
|
||||||
result_ty = RankedTensorType::get(ranked_ty.getShape(), result_element_ty);
|
result_ty = RankedTensorType::get(ranked_ty.getShape(), result_element_ty);
|
||||||
} else {
|
} else {
|
||||||
@ -235,7 +234,7 @@ void ConvertOp::build(Builder* builder, OperationState& result, Value operand,
|
|||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
|
||||||
if (getOperand()->getType() == getResult()->getType()) return getOperand();
|
if (getOperand().getType() == getResult().getType()) return getOperand();
|
||||||
|
|
||||||
// If the operand is constant, we can do the conversion now.
|
// If the operand is constant, we can do the conversion now.
|
||||||
if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
|
if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
|
||||||
@ -252,7 +251,7 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
|
|
||||||
static LogicalResult Verify(GetTupleElementOp op) {
|
static LogicalResult Verify(GetTupleElementOp op) {
|
||||||
auto indexVal = op.index().getZExtValue();
|
auto indexVal = op.index().getZExtValue();
|
||||||
auto operandType = op.getOperand()->getType().cast<TupleType>();
|
auto operandType = op.getOperand().getType().cast<TupleType>();
|
||||||
if (indexVal >= operandType.size()) {
|
if (indexVal >= operandType.size()) {
|
||||||
return op.emitOpError(
|
return op.emitOpError(
|
||||||
llvm::formatv("index {0} is out of bounds of operand with size {1}",
|
llvm::formatv("index {0} is out of bounds of operand with size {1}",
|
||||||
@ -269,7 +268,7 @@ static LogicalResult Verify(GetTupleElementOp op) {
|
|||||||
|
|
||||||
OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
|
||||||
if (auto tupleOp =
|
if (auto tupleOp =
|
||||||
dyn_cast_or_null<xla_hlo::TupleOp>(getOperand()->getDefiningOp())) {
|
dyn_cast_or_null<xla_hlo::TupleOp>(getOperand().getDefiningOp())) {
|
||||||
return tupleOp.getOperand(index().getLimitedValue());
|
return tupleOp.getOperand(index().getLimitedValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -305,9 +304,9 @@ static LogicalResult Verify(BroadcastOp op) {
|
|||||||
"broadcast_sizes has rank {0} instead of rank 1", sizesRank));
|
"broadcast_sizes has rank {0} instead of rank 1", sizesRank));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto resultType = op.getResult()->getType().cast<RankedTensorType>();
|
auto resultType = op.getResult().getType().cast<RankedTensorType>();
|
||||||
auto resultRank = resultType.getRank();
|
auto resultRank = resultType.getRank();
|
||||||
auto operandType = op.operand()->getType().cast<RankedTensorType>();
|
auto operandType = op.operand().getType().cast<RankedTensorType>();
|
||||||
auto operandRank = operandType.getRank();
|
auto operandRank = operandType.getRank();
|
||||||
auto sizesSize = sizesType.getNumElements();
|
auto sizesSize = sizesType.getNumElements();
|
||||||
auto expectedRank = operandRank + sizesSize;
|
auto expectedRank = operandRank + sizesSize;
|
||||||
@ -341,7 +340,7 @@ static LogicalResult Verify(BroadcastOp op) {
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static LogicalResult Verify(BroadcastInDimOp op) {
|
static LogicalResult Verify(BroadcastInDimOp op) {
|
||||||
auto operandType = op.operand()->getType().cast<RankedTensorType>();
|
auto operandType = op.operand().getType().cast<RankedTensorType>();
|
||||||
auto operandRank = operandType.getRank();
|
auto operandRank = operandType.getRank();
|
||||||
if (!op.broadcast_dimensions()) {
|
if (!op.broadcast_dimensions()) {
|
||||||
if (operandRank == 0) {
|
if (operandRank == 0) {
|
||||||
@ -368,7 +367,7 @@ static LogicalResult Verify(BroadcastInDimOp op) {
|
|||||||
dimensionsSize, operandRank));
|
dimensionsSize, operandRank));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto resultType = op.getResult()->getType().cast<RankedTensorType>();
|
auto resultType = op.getResult().getType().cast<RankedTensorType>();
|
||||||
auto resultRank = resultType.getRank();
|
auto resultRank = resultType.getRank();
|
||||||
if (resultRank < operandRank) {
|
if (resultRank < operandRank) {
|
||||||
return op.emitOpError(
|
return op.emitOpError(
|
||||||
@ -403,9 +402,9 @@ static LogicalResult Verify(BroadcastInDimOp op) {
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static LogicalResult Verify(ClampOp op) {
|
static LogicalResult Verify(ClampOp op) {
|
||||||
auto operandType = op.operand()->getType().cast<RankedTensorType>();
|
auto operandType = op.operand().getType().cast<RankedTensorType>();
|
||||||
auto operandShape = operandType.getShape();
|
auto operandShape = operandType.getShape();
|
||||||
auto minType = op.min()->getType().cast<RankedTensorType>();
|
auto minType = op.min().getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
auto minShape = minType.getShape();
|
auto minShape = minType.getShape();
|
||||||
if (minShape != operandShape && minType.getRank() != 0) {
|
if (minShape != operandShape && minType.getRank() != 0) {
|
||||||
@ -415,7 +414,7 @@ static LogicalResult Verify(ClampOp op) {
|
|||||||
llvm::make_range(operandShape.begin(), operandShape.end())));
|
llvm::make_range(operandShape.begin(), operandShape.end())));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto maxType = op.max()->getType().cast<RankedTensorType>();
|
auto maxType = op.max().getType().cast<RankedTensorType>();
|
||||||
auto maxShape = maxType.getShape();
|
auto maxShape = maxType.getShape();
|
||||||
if (maxShape != operandShape && maxType.getRank() != 0) {
|
if (maxShape != operandShape && maxType.getRank() != 0) {
|
||||||
return op.emitOpError(llvm::formatv(
|
return op.emitOpError(llvm::formatv(
|
||||||
@ -433,7 +432,7 @@ static LogicalResult Verify(ClampOp op) {
|
|||||||
|
|
||||||
void ComplexOp::build(Builder* builder, OperationState& state, Value lhs,
|
void ComplexOp::build(Builder* builder, OperationState& state, Value lhs,
|
||||||
Value rhs) {
|
Value rhs) {
|
||||||
auto type = lhs->getType();
|
auto type = lhs.getType();
|
||||||
auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
|
auto element_ty = ComplexType::get(getElementTypeOrSelf(type));
|
||||||
Type result_ty;
|
Type result_ty;
|
||||||
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
|
if (auto ranked_type = type.dyn_cast<RankedTensorType>()) {
|
||||||
@ -449,9 +448,9 @@ void ComplexOp::build(Builder* builder, OperationState& state, Value lhs,
|
|||||||
|
|
||||||
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
|
||||||
auto real_op =
|
auto real_op =
|
||||||
dyn_cast_or_null<xla_hlo::RealOp>(getOperand(0)->getDefiningOp());
|
dyn_cast_or_null<xla_hlo::RealOp>(getOperand(0).getDefiningOp());
|
||||||
auto imag_op =
|
auto imag_op =
|
||||||
dyn_cast_or_null<xla_hlo::ImagOp>(getOperand(1)->getDefiningOp());
|
dyn_cast_or_null<xla_hlo::ImagOp>(getOperand(1).getDefiningOp());
|
||||||
if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) {
|
if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) {
|
||||||
return real_op.getOperand();
|
return real_op.getOperand();
|
||||||
}
|
}
|
||||||
@ -477,12 +476,12 @@ Type CreateRealType(Type type) {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void ImagOp::build(Builder* builder, OperationState& state, Value val) {
|
void ImagOp::build(Builder* builder, OperationState& state, Value val) {
|
||||||
build(builder, state, CreateRealType(val->getType()), val);
|
build(builder, state, CreateRealType(val.getType()), val);
|
||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
|
||||||
if (auto complex_op =
|
if (auto complex_op =
|
||||||
dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand()->getDefiningOp())) {
|
dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand().getDefiningOp())) {
|
||||||
return complex_op.getOperand(1);
|
return complex_op.getOperand(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -490,12 +489,12 @@ OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void RealOp::build(Builder* builder, OperationState& state, Value val) {
|
void RealOp::build(Builder* builder, OperationState& state, Value val) {
|
||||||
build(builder, state, CreateRealType(val->getType()), val);
|
build(builder, state, CreateRealType(val.getType()), val);
|
||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
|
||||||
if (auto complex_op =
|
if (auto complex_op =
|
||||||
dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand()->getDefiningOp())) {
|
dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand().getDefiningOp())) {
|
||||||
return complex_op.getOperand(0);
|
return complex_op.getOperand(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -512,12 +511,12 @@ OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult Verify(ConcatenateOp op) {
|
static LogicalResult Verify(ConcatenateOp op) {
|
||||||
auto firstType = op.getOperand(0)->getType().cast<RankedTensorType>();
|
auto firstType = op.getOperand(0).getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
auto firstShape = firstType.getShape();
|
auto firstShape = firstType.getShape();
|
||||||
int numOperands = op.getNumOperands();
|
int numOperands = op.getNumOperands();
|
||||||
for (int i = 1; i < numOperands; i++) {
|
for (int i = 1; i < numOperands; i++) {
|
||||||
auto secondType = op.getOperand(i)->getType().cast<RankedTensorType>();
|
auto secondType = op.getOperand(i).getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
if (firstType.getRank() != secondType.getRank()) {
|
if (firstType.getRank() != secondType.getRank()) {
|
||||||
return op.emitOpError(
|
return op.emitOpError(
|
||||||
@ -552,18 +551,18 @@ void DynamicSliceOp::getCanonicalizationPatterns(
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
|
||||||
if (getOperand()->getType() == getType()) {
|
if (getOperand().getType() == getType()) {
|
||||||
return getOperand();
|
return getOperand();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto prev_op =
|
if (auto prev_op =
|
||||||
dyn_cast_or_null<ReshapeOp>(getOperand()->getDefiningOp())) {
|
dyn_cast_or_null<ReshapeOp>(getOperand().getDefiningOp())) {
|
||||||
setOperand(prev_op.getOperand());
|
setOperand(prev_op.getOperand());
|
||||||
return getResult();
|
return getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
|
if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
|
||||||
return elements.reshape(getResult()->getType().cast<ShapedType>());
|
return elements.reshape(getResult().getType().cast<ShapedType>());
|
||||||
}
|
}
|
||||||
|
|
||||||
return {};
|
return {};
|
||||||
@ -613,7 +612,7 @@ void ReduceOp::build(Builder* builder, OperationState& state,
|
|||||||
|
|
||||||
for (Value operand : operands) {
|
for (Value operand : operands) {
|
||||||
result_ty.push_back(
|
result_ty.push_back(
|
||||||
GetReduceResultType(operand->getType(), dimensions, builder));
|
GetReduceResultType(operand.getType(), dimensions, builder));
|
||||||
}
|
}
|
||||||
build(builder, state, result_ty, operands, init_values, dimensions);
|
build(builder, state, result_ty, operands, init_values, dimensions);
|
||||||
}
|
}
|
||||||
@ -645,8 +644,8 @@ static LogicalResult Verify(SelectOp op) {
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static LogicalResult Verify(PadOp op) {
|
static LogicalResult Verify(PadOp op) {
|
||||||
auto input_type = op.operand()->getType().cast<RankedTensorType>();
|
auto input_type = op.operand().getType().cast<RankedTensorType>();
|
||||||
auto pad_type = op.padding_value()->getType().cast<RankedTensorType>();
|
auto pad_type = op.padding_value().getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
if (pad_type.getRank() != 0) {
|
if (pad_type.getRank() != 0) {
|
||||||
return op.emitOpError(
|
return op.emitOpError(
|
||||||
@ -678,7 +677,7 @@ static LogicalResult Verify(PadOp op) {
|
|||||||
|
|
||||||
auto input_shape = input_type.getShape();
|
auto input_shape = input_type.getShape();
|
||||||
auto output_shape =
|
auto output_shape =
|
||||||
op.getResult()->getType().cast<RankedTensorType>().getShape();
|
op.getResult().getType().cast<RankedTensorType>().getShape();
|
||||||
if (input_shape.size() != output_shape.size()) {
|
if (input_shape.size() != output_shape.size()) {
|
||||||
return op.emitOpError(
|
return op.emitOpError(
|
||||||
llvm::formatv("operand rank ({0}) and result rank({0}) should match",
|
llvm::formatv("operand rank ({0}) and result rank({0}) should match",
|
||||||
@ -757,15 +756,15 @@ static Type GetBroadcastType(Builder* builder, Type x, Type y,
|
|||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
#define BINARY_BUILDER(Op) \
|
#define BINARY_BUILDER(Op) \
|
||||||
void Op::build(Builder* builder, OperationState& result, Value left, \
|
void Op::build(Builder* builder, OperationState& result, Value left, \
|
||||||
Value right, DenseIntElementsAttr broadcast_dimensions) { \
|
Value right, DenseIntElementsAttr broadcast_dimensions) { \
|
||||||
auto type = GetBroadcastType(builder, left->getType().cast<ShapedType>(), \
|
auto type = GetBroadcastType(builder, left.getType().cast<ShapedType>(), \
|
||||||
right->getType().cast<ShapedType>(), \
|
right.getType().cast<ShapedType>(), \
|
||||||
getElementTypeOrSelf(right->getType()), \
|
getElementTypeOrSelf(right.getType()), \
|
||||||
broadcast_dimensions); \
|
broadcast_dimensions); \
|
||||||
return Op::build(builder, result, type, left, right, \
|
return Op::build(builder, result, type, left, right, \
|
||||||
broadcast_dimensions); \
|
broadcast_dimensions); \
|
||||||
}
|
}
|
||||||
|
|
||||||
BINARY_BUILDER(AddOp);
|
BINARY_BUILDER(AddOp);
|
||||||
@ -815,7 +814,7 @@ Type SliceOp::InferOutputTypes(Builder* builder, Value operand,
|
|||||||
DenseIntElementsAttr start_indices,
|
DenseIntElementsAttr start_indices,
|
||||||
DenseIntElementsAttr limit_indices,
|
DenseIntElementsAttr limit_indices,
|
||||||
DenseIntElementsAttr strides) {
|
DenseIntElementsAttr strides) {
|
||||||
Type ty = operand->getType();
|
Type ty = operand.getType();
|
||||||
RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
|
RankedTensorType ranked_ty = ty.dyn_cast<RankedTensorType>();
|
||||||
if (!ranked_ty) return ty;
|
if (!ranked_ty) return ty;
|
||||||
int64_t rank = ranked_ty.getRank();
|
int64_t rank = ranked_ty.getRank();
|
||||||
@ -852,7 +851,7 @@ void SortOp::build(Builder* builder, OperationState& state, ValueRange operands,
|
|||||||
|
|
||||||
SmallVector<Type, 2> element_types;
|
SmallVector<Type, 2> element_types;
|
||||||
element_types.reserve(operands.size());
|
element_types.reserve(operands.size());
|
||||||
for (Value operand : operands) element_types.push_back(operand->getType());
|
for (Value operand : operands) element_types.push_back(operand.getType());
|
||||||
state.addTypes(builder->getTupleType(element_types));
|
state.addTypes(builder->getTupleType(element_types));
|
||||||
|
|
||||||
state.addRegion();
|
state.addRegion();
|
||||||
@ -864,14 +863,13 @@ static LogicalResult Verify(SortOp op) {
|
|||||||
|
|
||||||
// TODO(antiagainst): verify partionally dynamic shapes
|
// TODO(antiagainst): verify partionally dynamic shapes
|
||||||
if (llvm::all_of(operands, [](Value operand) {
|
if (llvm::all_of(operands, [](Value operand) {
|
||||||
return operand->getType().cast<ShapedType>().hasRank();
|
return operand.getType().cast<ShapedType>().hasRank();
|
||||||
})) {
|
})) {
|
||||||
ArrayRef<int64_t> input_shape =
|
ArrayRef<int64_t> input_shape =
|
||||||
(*operands.begin())->getType().cast<ShapedType>().getShape();
|
(*operands.begin()).getType().cast<ShapedType>().getShape();
|
||||||
|
|
||||||
if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value operand) {
|
if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value operand) {
|
||||||
return operand->getType().cast<ShapedType>().getShape() !=
|
return operand.getType().cast<ShapedType>().getShape() != input_shape;
|
||||||
input_shape;
|
|
||||||
}))
|
}))
|
||||||
return op.emitOpError("requires all inputs to have the same dimensions");
|
return op.emitOpError("requires all inputs to have the same dimensions");
|
||||||
|
|
||||||
@ -889,10 +887,10 @@ static LogicalResult Verify(SortOp op) {
|
|||||||
for (auto indexed_operand : llvm::enumerate(operands)) {
|
for (auto indexed_operand : llvm::enumerate(operands)) {
|
||||||
int index = indexed_operand.index();
|
int index = indexed_operand.index();
|
||||||
Type element_type =
|
Type element_type =
|
||||||
indexed_operand.value()->getType().cast<ShapedType>().getElementType();
|
indexed_operand.value().getType().cast<ShapedType>().getElementType();
|
||||||
Type tensor_type = RankedTensorType::get({}, element_type);
|
Type tensor_type = RankedTensorType::get({}, element_type);
|
||||||
for (int i : {2 * index, 2 * index + 1}) {
|
for (int i : {2 * index, 2 * index + 1}) {
|
||||||
Type arg_type = block.getArgument(i)->getType();
|
Type arg_type = block.getArgument(i).getType();
|
||||||
if (arg_type != tensor_type)
|
if (arg_type != tensor_type)
|
||||||
return op.emitOpError("comparator block argument #")
|
return op.emitOpError("comparator block argument #")
|
||||||
<< i << " should be of type " << tensor_type << " but got "
|
<< i << " should be of type " << tensor_type << " but got "
|
||||||
@ -926,7 +924,7 @@ static LogicalResult Verify(TransposeOp op) {
|
|||||||
}
|
}
|
||||||
auto permutationSize = permutationType.getNumElements();
|
auto permutationSize = permutationType.getNumElements();
|
||||||
|
|
||||||
auto operandType = op.operand()->getType().dyn_cast<RankedTensorType>();
|
auto operandType = op.operand().getType().dyn_cast<RankedTensorType>();
|
||||||
if (operandType) {
|
if (operandType) {
|
||||||
auto operandRank = operandType.getRank();
|
auto operandRank = operandType.getRank();
|
||||||
if (operandRank != permutationSize) {
|
if (operandRank != permutationSize) {
|
||||||
@ -936,7 +934,7 @@ static LogicalResult Verify(TransposeOp op) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto resultType = op.getResult()->getType().dyn_cast<RankedTensorType>();
|
auto resultType = op.getResult().getType().dyn_cast<RankedTensorType>();
|
||||||
if (resultType) {
|
if (resultType) {
|
||||||
auto resultRank = resultType.getRank();
|
auto resultRank = resultType.getRank();
|
||||||
if (resultRank != permutationSize) {
|
if (resultRank != permutationSize) {
|
||||||
@ -972,14 +970,14 @@ static LogicalResult Verify(TransposeOp op) {
|
|||||||
|
|
||||||
void GetTupleElementOp::build(Builder* builder, OperationState& result,
|
void GetTupleElementOp::build(Builder* builder, OperationState& result,
|
||||||
Value tuple, int32_t index) {
|
Value tuple, int32_t index) {
|
||||||
if (auto tuple_type = tuple->getType().dyn_cast<TupleType>()) {
|
if (auto tuple_type = tuple.getType().dyn_cast<TupleType>()) {
|
||||||
auto element_type = tuple_type.getType(index);
|
auto element_type = tuple_type.getType(index);
|
||||||
build(builder, result, element_type, tuple,
|
build(builder, result, element_type, tuple,
|
||||||
builder->getI32IntegerAttr(index));
|
builder->getI32IntegerAttr(index));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
build(builder, result, tuple->getType(), tuple,
|
build(builder, result, tuple.getType(), tuple,
|
||||||
builder->getI32IntegerAttr(index));
|
builder->getI32IntegerAttr(index));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -992,7 +990,7 @@ void TupleOp::build(Builder* builder, OperationState& result,
|
|||||||
SmallVector<Type, 4> types;
|
SmallVector<Type, 4> types;
|
||||||
types.reserve(values.size());
|
types.reserve(values.size());
|
||||||
for (auto val : values) {
|
for (auto val : values) {
|
||||||
types.push_back(val->getType());
|
types.push_back(val.getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
build(builder, result, builder->getTupleType(types), values);
|
build(builder, result, builder->getTupleType(types), values);
|
||||||
@ -1014,7 +1012,7 @@ void UnaryEinsumOp::getCanonicalizationPatterns(
|
|||||||
void CompareOp::build(Builder* builder, OperationState& result, Value lhs,
|
void CompareOp::build(Builder* builder, OperationState& result, Value lhs,
|
||||||
Value rhs, DenseIntElementsAttr broadcast_dimensions,
|
Value rhs, DenseIntElementsAttr broadcast_dimensions,
|
||||||
StringAttr comparison_direction) {
|
StringAttr comparison_direction) {
|
||||||
auto new_type = GetBroadcastType(builder, lhs->getType(), rhs->getType(),
|
auto new_type = GetBroadcastType(builder, lhs.getType(), rhs.getType(),
|
||||||
builder->getI1Type(), broadcast_dimensions);
|
builder->getI1Type(), broadcast_dimensions);
|
||||||
build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
|
build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
|
||||||
comparison_direction);
|
comparison_direction);
|
||||||
|
@ -23,8 +23,8 @@ namespace mlir {
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y) {
|
DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y) {
|
||||||
TensorType xType = x->getType().dyn_cast<RankedTensorType>();
|
TensorType xType = x.getType().dyn_cast<RankedTensorType>();
|
||||||
TensorType yType = y->getType().dyn_cast<RankedTensorType>();
|
TensorType yType = y.getType().dyn_cast<RankedTensorType>();
|
||||||
if (xType == yType || !xType || !yType) return {};
|
if (xType == yType || !xType || !yType) return {};
|
||||||
|
|
||||||
// If the shapes have the same rank, then there is nothing to do.
|
// If the shapes have the same rank, then there is nothing to do.
|
||||||
|
@ -35,8 +35,8 @@ mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b,
|
|||||||
/// Get a constant splat for the given value type.
|
/// Get a constant splat for the given value type.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static ElementsAttr getSplat(Builder* b, Value val, T constant) {
|
static ElementsAttr getSplat(Builder* b, Value val, T constant) {
|
||||||
auto valType = val->getType().cast<TensorType>();
|
auto valType = val.getType().cast<TensorType>();
|
||||||
auto valElementType = getElementTypeOrSelf(val->getType());
|
auto valElementType = getElementTypeOrSelf(val.getType());
|
||||||
|
|
||||||
// Handle integer elements.
|
// Handle integer elements.
|
||||||
Attribute elementAttr;
|
Attribute elementAttr;
|
||||||
|
@ -542,7 +542,7 @@ LogicalResult ExportXlaOp(OutfeedOp op, OpLoweringContext ctx) {
|
|||||||
auto& value_map = *ctx.values;
|
auto& value_map = *ctx.values;
|
||||||
value_map[op] = xla::OutfeedWithToken(
|
value_map[op] = xla::OutfeedWithToken(
|
||||||
value_map[op.operand()], value_map[op.token()],
|
value_map[op.operand()], value_map[op.token()],
|
||||||
xla::TypeToShape(op.operand()->getType()), op.outfeed_config());
|
xla::TypeToShape(op.operand().getType()), op.outfeed_config());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -883,7 +883,7 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
|
|||||||
std::vector<xla::Shape> arg_shapes;
|
std::vector<xla::Shape> arg_shapes;
|
||||||
arg_shapes.reserve(bb.getNumArguments());
|
arg_shapes.reserve(bb.getNumArguments());
|
||||||
for (auto& arg : bb.getArguments())
|
for (auto& arg : bb.getArguments())
|
||||||
arg_shapes.push_back(xla::TypeToShape(arg->getType()));
|
arg_shapes.push_back(xla::TypeToShape(arg.getType()));
|
||||||
xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes);
|
xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes);
|
||||||
auto tuple = xla::Parameter(builder, 0, input_shape, "arg_tuple");
|
auto tuple = xla::Parameter(builder, 0, input_shape, "arg_tuple");
|
||||||
for (auto& it : llvm::enumerate(bb.getArguments())) {
|
for (auto& it : llvm::enumerate(bb.getArguments())) {
|
||||||
@ -893,7 +893,7 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
|
|||||||
for (auto& it : llvm::enumerate(bb.getArguments())) {
|
for (auto& it : llvm::enumerate(bb.getArguments())) {
|
||||||
auto arg = it.value();
|
auto arg = it.value();
|
||||||
auto num = it.index();
|
auto num = it.index();
|
||||||
xla::Shape shape = xla::TypeToShape(arg->getType());
|
xla::Shape shape = xla::TypeToShape(arg.getType());
|
||||||
lowering[arg] =
|
lowering[arg] =
|
||||||
xla::Parameter(builder, num, shape, absl::StrCat("Arg_", num));
|
xla::Parameter(builder, num, shape, absl::StrCat("Arg_", num));
|
||||||
}
|
}
|
||||||
@ -1024,7 +1024,7 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
|
|||||||
|
|
||||||
llvm::SmallDenseSet<int32_t, 4> used_shape_indices;
|
llvm::SmallDenseSet<int32_t, 4> used_shape_indices;
|
||||||
auto arg_type =
|
auto arg_type =
|
||||||
entry_func.getArgument(i)->getType().dyn_cast<RankedTensorType>();
|
entry_func.getArgument(i).getType().dyn_cast<RankedTensorType>();
|
||||||
for (auto shape_and_padding : llvm::enumerate(llvm::zip(
|
for (auto shape_and_padding : llvm::enumerate(llvm::zip(
|
||||||
shape_indices.getValue(), padding_arg_indices.getValue()))) {
|
shape_indices.getValue(), padding_arg_indices.getValue()))) {
|
||||||
const int element_index = shape_and_padding.index();
|
const int element_index = shape_and_padding.index();
|
||||||
@ -1059,7 +1059,7 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
|
|||||||
kPaddingArgIndicesAttr, i, element_index, e, padding_arg_index));
|
kPaddingArgIndicesAttr, i, element_index, e, padding_arg_index));
|
||||||
|
|
||||||
Type padding_arg_type =
|
Type padding_arg_type =
|
||||||
entry_func.getArgument(padding_arg_index)->getType();
|
entry_func.getArgument(padding_arg_index).getType();
|
||||||
if (auto tensor_type = padding_arg_type.dyn_cast<RankedTensorType>())
|
if (auto tensor_type = padding_arg_type.dyn_cast<RankedTensorType>())
|
||||||
if (tensor_type.getRank() != 0)
|
if (tensor_type.getRank() != 0)
|
||||||
return entry_func.emitError()
|
return entry_func.emitError()
|
||||||
|
@ -29,7 +29,7 @@ def BuildSliceLimits : NativeCodeCall<
|
|||||||
|
|
||||||
def BuildSliceStrides : NativeCodeCall<
|
def BuildSliceStrides : NativeCodeCall<
|
||||||
"GetI64ElementsAttr(SmallVector<int64_t, 4>("
|
"GetI64ElementsAttr(SmallVector<int64_t, 4>("
|
||||||
"$0->getType().cast<RankedTensorType>().getRank(), 1), &$_builder)">;
|
"$0.getType().cast<RankedTensorType>().getRank(), 1), &$_builder)">;
|
||||||
|
|
||||||
def DynamicSliceToSlice: Pat<(HLO_DynamicSliceOp HLO_Tensor:$input,
|
def DynamicSliceToSlice: Pat<(HLO_DynamicSliceOp HLO_Tensor:$input,
|
||||||
(HLO_ConstOp I64ElementsAttr:$starting_indices),
|
(HLO_ConstOp I64ElementsAttr:$starting_indices),
|
||||||
|
@ -40,7 +40,7 @@ namespace {
|
|||||||
constexpr StringRef kTempBufferAttr = "temp";
|
constexpr StringRef kTempBufferAttr = "temp";
|
||||||
|
|
||||||
Value GetTensorStoreOrReturnMemRef(Value value) {
|
Value GetTensorStoreOrReturnMemRef(Value value) {
|
||||||
for (const auto& user : value->getUsers()) {
|
for (const auto& user : value.getUsers()) {
|
||||||
if (auto tensor_store = dyn_cast<TensorStoreOp>(user)) {
|
if (auto tensor_store = dyn_cast<TensorStoreOp>(user)) {
|
||||||
if (tensor_store.getOperand(0) == value) {
|
if (tensor_store.getOperand(0) == value) {
|
||||||
return tensor_store.getOperand(1);
|
return tensor_store.getOperand(1);
|
||||||
@ -57,8 +57,8 @@ Value GetTensorStoreOrReturnMemRef(Value value) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Operation* GetLastUse(Value value) {
|
Operation* GetLastUse(Value value) {
|
||||||
Operation* last = value->getDefiningOp();
|
Operation* last = value.getDefiningOp();
|
||||||
for (auto& user : value->getUses()) {
|
for (auto& user : value.getUses()) {
|
||||||
Operation* user_op = user.getOwner();
|
Operation* user_op = user.getOwner();
|
||||||
if (!user_op->isBeforeInBlock(last)) {
|
if (!user_op->isBeforeInBlock(last)) {
|
||||||
last = user_op;
|
last = user_op;
|
||||||
@ -69,7 +69,7 @@ Operation* GetLastUse(Value value) {
|
|||||||
|
|
||||||
Value InsertAllocAndDealloc(Location loc, Value result,
|
Value InsertAllocAndDealloc(Location loc, Value result,
|
||||||
ConversionPatternRewriter* rewriter) {
|
ConversionPatternRewriter* rewriter) {
|
||||||
auto result_type = result->getType().dyn_cast<ShapedType>();
|
auto result_type = result.getType().dyn_cast<ShapedType>();
|
||||||
if (!result_type || !result_type.hasStaticShape()) {
|
if (!result_type || !result_type.hasStaticShape()) {
|
||||||
emitError(loc,
|
emitError(loc,
|
||||||
"tensor to buffer conversion expects statically shaped results");
|
"tensor to buffer conversion expects statically shaped results");
|
||||||
@ -79,7 +79,7 @@ Value InsertAllocAndDealloc(Location loc, Value result,
|
|||||||
|
|
||||||
Operation* last = GetLastUse(result);
|
Operation* last = GetLastUse(result);
|
||||||
|
|
||||||
Operation* op = result->getDefiningOp();
|
Operation* op = result.getDefiningOp();
|
||||||
OpBuilder allocBuilder(op);
|
OpBuilder allocBuilder(op);
|
||||||
auto alloc = allocBuilder.create<AllocOp>(loc, memref_type);
|
auto alloc = allocBuilder.create<AllocOp>(loc, memref_type);
|
||||||
alloc.setAttr(kTempBufferAttr, rewriter->getBoolAttr(true));
|
alloc.setAttr(kTempBufferAttr, rewriter->getBoolAttr(true));
|
||||||
@ -161,7 +161,7 @@ struct HloToLHloReduceConverter
|
|||||||
int original_arg_count = entry_block.getNumArguments();
|
int original_arg_count = entry_block.getNumArguments();
|
||||||
for (int i = 0; i < original_arg_count; ++i) {
|
for (int i = 0; i < original_arg_count; ++i) {
|
||||||
auto old_arg = entry_block.getArgument(i);
|
auto old_arg = entry_block.getArgument(i);
|
||||||
auto old_type = old_arg->getType().cast<TensorType>();
|
auto old_type = old_arg.getType().cast<TensorType>();
|
||||||
auto new_type =
|
auto new_type =
|
||||||
MemRefType::get(old_type.getShape(), old_type.getElementType());
|
MemRefType::get(old_type.getShape(), old_type.getElementType());
|
||||||
auto new_arg = entry_block.addArgument(new_type);
|
auto new_arg = entry_block.addArgument(new_type);
|
||||||
@ -169,7 +169,7 @@ struct HloToLHloReduceConverter
|
|||||||
}
|
}
|
||||||
// Add an argument for the result.
|
// Add an argument for the result.
|
||||||
entry_block.addArgument(
|
entry_block.addArgument(
|
||||||
entry_block.getArgument(original_arg_count)->getType());
|
entry_block.getArgument(original_arg_count).getType());
|
||||||
// Remove the old arguments.
|
// Remove the old arguments.
|
||||||
for (int i = original_arg_count - 1; i >= 0; --i) {
|
for (int i = original_arg_count - 1; i >= 0; --i) {
|
||||||
entry_block.eraseArgument(i);
|
entry_block.eraseArgument(i);
|
||||||
|
@ -99,8 +99,8 @@ LogicalResult LowerConditionalOp(mlir::xla_hlo::ConditionalOp conditional_op) {
|
|||||||
mapper, &builder)))
|
mapper, &builder)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
tail_block->addArguments(conditional_op.getResult()->getType());
|
tail_block->addArguments(conditional_op.getResult().getType());
|
||||||
conditional_op.getResult()->replaceAllUsesWith(tail_block->getArgument(0));
|
conditional_op.getResult().replaceAllUsesWith(tail_block->getArgument(0));
|
||||||
|
|
||||||
op_inst->erase();
|
op_inst->erase();
|
||||||
return success();
|
return success();
|
||||||
@ -201,7 +201,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
|
|||||||
|
|
||||||
// Erase the original while loop.
|
// Erase the original while loop.
|
||||||
tail_block->addArgument(while_op.getType());
|
tail_block->addArgument(while_op.getType());
|
||||||
while_op.getResult()->replaceAllUsesWith(tail_block->getArgument(0));
|
while_op.getResult().replaceAllUsesWith(tail_block->getArgument(0));
|
||||||
op_inst->erase();
|
op_inst->erase();
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
|
@ -242,7 +242,7 @@ static Value ApplyReduction(Location loc, Value input,
|
|||||||
static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format,
|
static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format,
|
||||||
Value input) {
|
Value input) {
|
||||||
return b.getI64IntegerAttr(
|
return b.getI64IntegerAttr(
|
||||||
getFeatureDimension(format, input->getType().cast<RankedTensorType>()));
|
getFeatureDimension(format, input.getType().cast<RankedTensorType>()));
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -254,7 +254,7 @@ static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format,
|
|||||||
static DenseIntElementsAttr getBiasFeatureDimension(Builder &b,
|
static DenseIntElementsAttr getBiasFeatureDimension(Builder &b,
|
||||||
StringAttr format,
|
StringAttr format,
|
||||||
Value input) {
|
Value input) {
|
||||||
auto inputType = input->getType().cast<RankedTensorType>();
|
auto inputType = input.getType().cast<RankedTensorType>();
|
||||||
size_t featureDim = getFeatureDimension(format, inputType);
|
size_t featureDim = getFeatureDimension(format, inputType);
|
||||||
RankedTensorType type = RankedTensorType::get(1, b.getIntegerType(64));
|
RankedTensorType type = RankedTensorType::get(1, b.getIntegerType(64));
|
||||||
return DenseIntElementsAttr::get(type, featureDim);
|
return DenseIntElementsAttr::get(type, featureDim);
|
||||||
@ -319,8 +319,8 @@ static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) {
|
|||||||
// must be broadcasted with a size 1 tensor or another dynamic dimension.
|
// must be broadcasted with a size 1 tensor or another dynamic dimension.
|
||||||
// Returns false on rankless.
|
// Returns false on rankless.
|
||||||
static bool AreBroadcastCompatible(Value x, Value y) {
|
static bool AreBroadcastCompatible(Value x, Value y) {
|
||||||
auto x_rankless = x->getType().dyn_cast<RankedTensorType>();
|
auto x_rankless = x.getType().dyn_cast<RankedTensorType>();
|
||||||
auto y_rankless = y->getType().dyn_cast<RankedTensorType>();
|
auto y_rankless = y.getType().dyn_cast<RankedTensorType>();
|
||||||
if (!x_rankless || !y_rankless) {
|
if (!x_rankless || !y_rankless) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -418,7 +418,7 @@ static void BuildArgMinMaxReductionBody(Type input_element_type,
|
|||||||
|
|
||||||
static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices,
|
static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices,
|
||||||
DenseIntElementsAttr slice_sizes) {
|
DenseIntElementsAttr slice_sizes) {
|
||||||
auto input_ty = input->getType().dyn_cast<RankedTensorType>();
|
auto input_ty = input.getType().dyn_cast<RankedTensorType>();
|
||||||
int64_t input_rank = input_ty.getRank();
|
int64_t input_rank = input_ty.getRank();
|
||||||
ArrayRef<int64_t> input_shape = input_ty.getShape();
|
ArrayRef<int64_t> input_shape = input_ty.getShape();
|
||||||
DenseIntElementsAttr constant_start_indices;
|
DenseIntElementsAttr constant_start_indices;
|
||||||
@ -465,7 +465,7 @@ static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes(
|
|||||||
.cast<DenseIntElementsAttr>();
|
.cast<DenseIntElementsAttr>();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto input_ty = input->getType().dyn_cast<RankedTensorType>();
|
auto input_ty = input.getType().dyn_cast<RankedTensorType>();
|
||||||
int64_t input_rank = input_ty.getRank();
|
int64_t input_rank = input_ty.getRank();
|
||||||
ArrayRef<int64_t> input_shape = input_ty.getShape();
|
ArrayRef<int64_t> input_shape = input_ty.getShape();
|
||||||
SmallVector<int64_t, 4> normalized_sizes;
|
SmallVector<int64_t, 4> normalized_sizes;
|
||||||
@ -574,9 +574,9 @@ class ConvertConv : public OpRewritePattern<OpT> {
|
|||||||
std::string data_format = op.data_format().str();
|
std::string data_format = op.data_format().str();
|
||||||
if (!FormatFromString(data_format, &format)) return Pattern::matchFailure();
|
if (!FormatFromString(data_format, &format)) return Pattern::matchFailure();
|
||||||
|
|
||||||
auto input_ty = op.input()->getType().template dyn_cast<RankedTensorType>();
|
auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
|
||||||
auto filter_ty =
|
auto filter_ty =
|
||||||
op.filter()->getType().template dyn_cast<RankedTensorType>();
|
op.filter().getType().template dyn_cast<RankedTensorType>();
|
||||||
auto result_ty = op.getType().template dyn_cast<RankedTensorType>();
|
auto result_ty = op.getType().template dyn_cast<RankedTensorType>();
|
||||||
|
|
||||||
// Input, filter and the result needs to have static shape for calculation
|
// Input, filter and the result needs to have static shape for calculation
|
||||||
@ -698,10 +698,10 @@ class ConvertBF16FloorDivOp : public OpRewritePattern<TF::FloorDivOp> {
|
|||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto l = op.x();
|
auto l = op.x();
|
||||||
auto r = op.y();
|
auto r = op.y();
|
||||||
auto element_type = getElementTypeOrSelf(l->getType());
|
auto element_type = getElementTypeOrSelf(l.getType());
|
||||||
if (!element_type.isBF16()) return matchFailure();
|
if (!element_type.isBF16()) return matchFailure();
|
||||||
|
|
||||||
auto out_type = op.z()->getType().cast<TensorType>();
|
auto out_type = op.z().getType().cast<TensorType>();
|
||||||
|
|
||||||
l = rewriter.create<ConvertOp>(op.getLoc(), l, rewriter.getF32Type());
|
l = rewriter.create<ConvertOp>(op.getLoc(), l, rewriter.getF32Type());
|
||||||
r = rewriter.create<ConvertOp>(op.getLoc(), r, rewriter.getF32Type());
|
r = rewriter.create<ConvertOp>(op.getLoc(), r, rewriter.getF32Type());
|
||||||
@ -765,13 +765,13 @@ class ConvertFusedBatchNormGradBase
|
|||||||
// activation shape needs to be static to convert negative indices in
|
// activation shape needs to be static to convert negative indices in
|
||||||
// TensorFlow to absolute indices required by HLO.
|
// TensorFlow to absolute indices required by HLO.
|
||||||
RankedTensorType act_type =
|
RankedTensorType act_type =
|
||||||
act->getType().template dyn_cast<RankedTensorType>();
|
act.getType().template dyn_cast<RankedTensorType>();
|
||||||
if (!act_type) return Pattern::matchFailure();
|
if (!act_type) return Pattern::matchFailure();
|
||||||
Type act_ele_type = act_type.getElementType();
|
Type act_ele_type = act_type.getElementType();
|
||||||
// To support mixed precision, the statistics type, which maybe more
|
// To support mixed precision, the statistics type, which maybe more
|
||||||
// precise than the input types, are used for this op.
|
// precise than the input types, are used for this op.
|
||||||
Type kernel_type =
|
Type kernel_type =
|
||||||
scale->getType().template cast<TensorType>().getElementType();
|
scale.getType().template cast<TensorType>().getElementType();
|
||||||
grad = rewriter.create<ConvertOp>(loc, grad, kernel_type);
|
grad = rewriter.create<ConvertOp>(loc, grad, kernel_type);
|
||||||
act = rewriter.create<ConvertOp>(loc, act, kernel_type);
|
act = rewriter.create<ConvertOp>(loc, act, kernel_type);
|
||||||
|
|
||||||
@ -787,7 +787,7 @@ class ConvertFusedBatchNormGradBase
|
|||||||
Type feature_type = RankedTensorType::get(
|
Type feature_type = RankedTensorType::get(
|
||||||
{GetDimSize(act_type, feature_dim)}, kernel_type);
|
{GetDimSize(act_type, feature_dim)}, kernel_type);
|
||||||
Type result_type = TupleType::get(
|
Type result_type = TupleType::get(
|
||||||
{act->getType(), feature_type, feature_type}, rewriter.getContext());
|
{act.getType(), feature_type, feature_type}, rewriter.getContext());
|
||||||
|
|
||||||
auto training_op = rewriter.create<BatchNormGradOp>(
|
auto training_op = rewriter.create<BatchNormGradOp>(
|
||||||
loc, result_type, act, scale, mean, var, grad, op.epsilon(),
|
loc, result_type, act, scale, mean, var, grad, op.epsilon(),
|
||||||
@ -870,10 +870,10 @@ class ConvertFusedBatchNormV3Op
|
|||||||
auto feature_dim =
|
auto feature_dim =
|
||||||
getFeatureDimensionAttr(rewriter, op.data_formatAttr(), op.x());
|
getFeatureDimensionAttr(rewriter, op.data_formatAttr(), op.x());
|
||||||
|
|
||||||
auto input_type_tensor = op.x()->getType().dyn_cast<TensorType>();
|
auto input_type_tensor = op.x().getType().dyn_cast<TensorType>();
|
||||||
auto input_element_type = input_type_tensor.getElementType();
|
auto input_element_type = input_type_tensor.getElementType();
|
||||||
|
|
||||||
auto scale_type_tensor = op.scale()->getType().dyn_cast<TensorType>();
|
auto scale_type_tensor = op.scale().getType().dyn_cast<TensorType>();
|
||||||
auto scale_element_type = scale_type_tensor.getElementType();
|
auto scale_element_type = scale_type_tensor.getElementType();
|
||||||
|
|
||||||
// TODO(b/69928690): Support mixed precision in the XLA batch
|
// TODO(b/69928690): Support mixed precision in the XLA batch
|
||||||
@ -922,7 +922,7 @@ class ConvertFusedBatchNormV3Op
|
|||||||
op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor));
|
op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor));
|
||||||
|
|
||||||
auto corrected_variance = rewriter.create<xla_hlo::MulOp>(
|
auto corrected_variance = rewriter.create<xla_hlo::MulOp>(
|
||||||
op.getLoc(), batch_variance->getType(), batch_variance,
|
op.getLoc(), batch_variance.getType(), batch_variance,
|
||||||
factor_const_op, /*DenseIntElementsAttr=*/DenseIntElementsAttr());
|
factor_const_op, /*DenseIntElementsAttr=*/DenseIntElementsAttr());
|
||||||
|
|
||||||
// Convert back to input type to stay aligned with expected output type
|
// Convert back to input type to stay aligned with expected output type
|
||||||
@ -1016,12 +1016,12 @@ class ConvertMaxPoolOp : public OpRewritePattern<TF::MaxPoolOp> {
|
|||||||
PatternMatchResult matchAndRewrite(TF::MaxPoolOp op,
|
PatternMatchResult matchAndRewrite(TF::MaxPoolOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Type element_type =
|
Type element_type =
|
||||||
op.input()->getType().cast<TensorType>().getElementType();
|
op.input().getType().cast<TensorType>().getElementType();
|
||||||
if (!element_type.isIntOrFloat()) return matchFailure();
|
if (!element_type.isIntOrFloat()) return matchFailure();
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
ConstOp init = GetMinValueForType(element_type, loc, &rewriter);
|
ConstOp init = GetMinValueForType(element_type, loc, &rewriter);
|
||||||
|
|
||||||
auto input_ty = op.input()->getType().dyn_cast<RankedTensorType>();
|
auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!input_ty) return matchFailure();
|
if (!input_ty) return matchFailure();
|
||||||
DenseIntElementsAttr paddings_attr = GetReduceWindowPadding(
|
DenseIntElementsAttr paddings_attr = GetReduceWindowPadding(
|
||||||
input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
|
input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
|
||||||
@ -1067,9 +1067,9 @@ class ConvertSigmoidOp : public OpRewritePattern<TF::SigmoidOp> {
|
|||||||
|
|
||||||
auto scalar_one = rewriter.create<ConstOp>(
|
auto scalar_one = rewriter.create<ConstOp>(
|
||||||
op.getLoc(),
|
op.getLoc(),
|
||||||
rewriter.getFloatAttr(getElementTypeOrSelf(operand->getType()), 0.5));
|
rewriter.getFloatAttr(getElementTypeOrSelf(operand.getType()), 0.5));
|
||||||
|
|
||||||
auto shaped_type = operand->getType().cast<ShapedType>();
|
auto shaped_type = operand.getType().cast<ShapedType>();
|
||||||
auto constant_ones = rewriter.create<BroadcastOp>(
|
auto constant_ones = rewriter.create<BroadcastOp>(
|
||||||
op.getLoc(), shaped_type, scalar_one,
|
op.getLoc(), shaped_type, scalar_one,
|
||||||
DenseIntElementsAttr::get(
|
DenseIntElementsAttr::get(
|
||||||
@ -1080,7 +1080,7 @@ class ConvertSigmoidOp : public OpRewritePattern<TF::SigmoidOp> {
|
|||||||
auto scaled_input = rewriter.create<MulOp>(
|
auto scaled_input = rewriter.create<MulOp>(
|
||||||
op.getLoc(), operand, constant_ones, DenseIntElementsAttr());
|
op.getLoc(), operand, constant_ones, DenseIntElementsAttr());
|
||||||
auto tanh_op =
|
auto tanh_op =
|
||||||
rewriter.create<TanhOp>(op.getLoc(), operand->getType(), scaled_input);
|
rewriter.create<TanhOp>(op.getLoc(), operand.getType(), scaled_input);
|
||||||
auto mul_op =
|
auto mul_op =
|
||||||
rewriter.create<MulOp>(op.getLoc(), tanh_op, constant_ones,
|
rewriter.create<MulOp>(op.getLoc(), tanh_op, constant_ones,
|
||||||
/*DenseIntElementsAttr=*/DenseIntElementsAttr());
|
/*DenseIntElementsAttr=*/DenseIntElementsAttr());
|
||||||
@ -1129,7 +1129,7 @@ class ConvertSoftmaxOp : public OpRewritePattern<OpTy> {
|
|||||||
|
|
||||||
// Softmax converter requires ranked type because the XLA reduce ops used
|
// Softmax converter requires ranked type because the XLA reduce ops used
|
||||||
// while lowering requires dimensions attribute to reduce along.
|
// while lowering requires dimensions attribute to reduce along.
|
||||||
RankedTensorType type = logits->getType().dyn_cast<RankedTensorType>();
|
RankedTensorType type = logits.getType().dyn_cast<RankedTensorType>();
|
||||||
if (!type) return Pattern::matchFailure();
|
if (!type) return Pattern::matchFailure();
|
||||||
|
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
@ -1202,11 +1202,11 @@ class ConvertSizeOp : public OpRewritePattern<TF::SizeOp> {
|
|||||||
PatternMatchResult matchAndRewrite(TF::SizeOp op,
|
PatternMatchResult matchAndRewrite(TF::SizeOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Value input = op.input();
|
Value input = op.input();
|
||||||
auto input_ty = input->getType().dyn_cast<RankedTensorType>();
|
auto input_ty = input.getType().dyn_cast<RankedTensorType>();
|
||||||
if (!input_ty) return Pattern::matchFailure();
|
if (!input_ty) return Pattern::matchFailure();
|
||||||
|
|
||||||
const int64_t rank = input_ty.getRank();
|
const int64_t rank = input_ty.getRank();
|
||||||
auto result_type = op.getResult()->getType();
|
auto result_type = op.getResult().getType();
|
||||||
Operation *size =
|
Operation *size =
|
||||||
GetScalarConstOfType(result_type.cast<TensorType>().getElementType(),
|
GetScalarConstOfType(result_type.cast<TensorType>().getElementType(),
|
||||||
op.getLoc(), 1, &rewriter);
|
op.getLoc(), 1, &rewriter);
|
||||||
@ -1264,7 +1264,7 @@ class ConvertSplitOp : public OpRewritePattern<TF::SplitOp> {
|
|||||||
PatternMatchResult matchAndRewrite(TF::SplitOp op,
|
PatternMatchResult matchAndRewrite(TF::SplitOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
// We can only split along static dimensions.
|
// We can only split along static dimensions.
|
||||||
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
|
auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!input_type) return matchFailure();
|
if (!input_type) return matchFailure();
|
||||||
|
|
||||||
// We can only match when the split dimension is a constant scalar.
|
// We can only match when the split dimension is a constant scalar.
|
||||||
@ -1356,7 +1356,7 @@ class ConvertSplitVOp : public OpRewritePattern<TF::SplitVOp> {
|
|||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
// We can only split along static dimensions.
|
// We can only split along static dimensions.
|
||||||
// TODO(b/145731001): enhance to support dynamic-shaped inputs.
|
// TODO(b/145731001): enhance to support dynamic-shaped inputs.
|
||||||
auto input_type = op.value()->getType().dyn_cast<RankedTensorType>();
|
auto input_type = op.value().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!input_type) return matchFailure();
|
if (!input_type) return matchFailure();
|
||||||
|
|
||||||
// We can only match when the split dimension is a constant scalar.
|
// We can only match when the split dimension is a constant scalar.
|
||||||
@ -1453,7 +1453,7 @@ class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
|
|||||||
//
|
//
|
||||||
// TODO(hinsu): Relax this constraint for ops without negative indices and
|
// TODO(hinsu): Relax this constraint for ops without negative indices and
|
||||||
// strides.
|
// strides.
|
||||||
auto input_ty = op.input()->getType().dyn_cast<RankedTensorType>();
|
auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!input_ty || !input_ty.hasStaticShape()) return matchFailure();
|
if (!input_ty || !input_ty.hasStaticShape()) return matchFailure();
|
||||||
ArrayRef<int64_t> input_shape = input_ty.getShape();
|
ArrayRef<int64_t> input_shape = input_ty.getShape();
|
||||||
|
|
||||||
@ -1553,7 +1553,7 @@ class ConvertStridedSliceGradOp
|
|||||||
return matchFailure();
|
return matchFailure();
|
||||||
|
|
||||||
Value grad = op.dy();
|
Value grad = op.dy();
|
||||||
Type element_type = grad->getType().cast<ShapedType>().getElementType();
|
Type element_type = grad.getType().cast<ShapedType>().getElementType();
|
||||||
|
|
||||||
// Perform reshape to undo any new/shrink axies done by strided slice.
|
// Perform reshape to undo any new/shrink axies done by strided slice.
|
||||||
grad = rewriter.create<xla_hlo::ReshapeOp>(
|
grad = rewriter.create<xla_hlo::ReshapeOp>(
|
||||||
@ -1593,7 +1593,7 @@ class ConvertStridedSliceGradOp
|
|||||||
|
|
||||||
if (!dims_to_reverse.empty()) {
|
if (!dims_to_reverse.empty()) {
|
||||||
grad = rewriter.create<xla_hlo::ReverseOp>(
|
grad = rewriter.create<xla_hlo::ReverseOp>(
|
||||||
op.getLoc(), grad->getType(), grad,
|
op.getLoc(), grad.getType(), grad,
|
||||||
GetI64ElementsAttr(dims_to_reverse, &rewriter));
|
GetI64ElementsAttr(dims_to_reverse, &rewriter));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1631,7 +1631,7 @@ class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> {
|
|||||||
PatternMatchResult matchAndRewrite(TF::RangeOp op,
|
PatternMatchResult matchAndRewrite(TF::RangeOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto result = op.getResult();
|
auto result = op.getResult();
|
||||||
auto result_type = result->getType();
|
auto result_type = result.getType();
|
||||||
if (!result_type.cast<ShapedType>().hasStaticShape()) {
|
if (!result_type.cast<ShapedType>().hasStaticShape()) {
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
}
|
}
|
||||||
@ -1663,7 +1663,7 @@ class GenericConvertReductionOp : public OpRewritePattern<OpTy> {
|
|||||||
// TODO(b/141785544): Update this to not require static shapes.
|
// TODO(b/141785544): Update this to not require static shapes.
|
||||||
// Input shape needs to be static to convert negative indices in TensorFlow
|
// Input shape needs to be static to convert negative indices in TensorFlow
|
||||||
// to absolute indices required by HLO.
|
// to absolute indices required by HLO.
|
||||||
auto input_ty = op.input()->getType().template dyn_cast<RankedTensorType>();
|
auto input_ty = op.input().getType().template dyn_cast<RankedTensorType>();
|
||||||
if (!input_ty) return this->matchFailure();
|
if (!input_ty) return this->matchFailure();
|
||||||
ArrayRef<int64_t> input_shape = input_ty.getShape();
|
ArrayRef<int64_t> input_shape = input_ty.getShape();
|
||||||
|
|
||||||
@ -1826,7 +1826,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> {
|
|||||||
PatternMatchResult matchAndRewrite(OpTy op,
|
PatternMatchResult matchAndRewrite(OpTy op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
RankedTensorType input_type =
|
RankedTensorType input_type =
|
||||||
op.input()->getType().template dyn_cast<RankedTensorType>();
|
op.input().getType().template dyn_cast<RankedTensorType>();
|
||||||
if (!input_type) {
|
if (!input_type) {
|
||||||
return this->matchFailure();
|
return this->matchFailure();
|
||||||
}
|
}
|
||||||
@ -1841,7 +1841,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> {
|
|||||||
Derived::GetInitialValue(input_element_type, loc, rewriter);
|
Derived::GetInitialValue(input_element_type, loc, rewriter);
|
||||||
|
|
||||||
RankedTensorType output_type =
|
RankedTensorType output_type =
|
||||||
op.output()->getType().template dyn_cast<RankedTensorType>();
|
op.output().getType().template dyn_cast<RankedTensorType>();
|
||||||
if (!output_type) {
|
if (!output_type) {
|
||||||
return this->matchFailure();
|
return this->matchFailure();
|
||||||
}
|
}
|
||||||
@ -1918,9 +1918,9 @@ class ConvertTensorScatterUpdateOp
|
|||||||
|
|
||||||
PatternMatchResult matchAndRewrite(TF::TensorScatterUpdateOp op,
|
PatternMatchResult matchAndRewrite(TF::TensorScatterUpdateOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto tensor_ty = op.tensor()->getType().dyn_cast<RankedTensorType>();
|
auto tensor_ty = op.tensor().getType().dyn_cast<RankedTensorType>();
|
||||||
auto indices_ty = op.indices()->getType().dyn_cast<RankedTensorType>();
|
auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
|
||||||
auto updates_ty = op.updates()->getType().dyn_cast<RankedTensorType>();
|
auto updates_ty = op.updates().getType().dyn_cast<RankedTensorType>();
|
||||||
|
|
||||||
if (!tensor_ty || !indices_ty || !updates_ty) return matchFailure();
|
if (!tensor_ty || !indices_ty || !updates_ty) return matchFailure();
|
||||||
// Last dimension of the indices needs to known at compile time for
|
// Last dimension of the indices needs to known at compile time for
|
||||||
@ -1977,7 +1977,7 @@ class ConvertTileOp : public OpRewritePattern<TF::TileOp> {
|
|||||||
|
|
||||||
PatternMatchResult matchAndRewrite(TF::TileOp op,
|
PatternMatchResult matchAndRewrite(TF::TileOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto input_ty = op.input()->getType().dyn_cast<RankedTensorType>();
|
auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!input_ty || !input_ty.hasStaticShape()) return matchFailure();
|
if (!input_ty || !input_ty.hasStaticShape()) return matchFailure();
|
||||||
ArrayRef<int64_t> input_shape = input_ty.getShape();
|
ArrayRef<int64_t> input_shape = input_ty.getShape();
|
||||||
Type element_type = input_ty.getElementType();
|
Type element_type = input_ty.getElementType();
|
||||||
@ -2041,12 +2041,12 @@ class ConvertMaxPoolGradOp : public OpRewritePattern<TF::MaxPoolGradOp> {
|
|||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
|
|
||||||
Type element_type =
|
Type element_type =
|
||||||
op.orig_input()->getType().cast<TensorType>().getElementType();
|
op.orig_input().getType().cast<TensorType>().getElementType();
|
||||||
|
|
||||||
// Compute paddings using the original input and kernel shape and strides.
|
// Compute paddings using the original input and kernel shape and strides.
|
||||||
// Here, ReduceWindow op as used as the MaxPool op is lowered to the
|
// Here, ReduceWindow op as used as the MaxPool op is lowered to the
|
||||||
// ReduceWindow op.
|
// ReduceWindow op.
|
||||||
auto input_ty = op.orig_input()->getType().dyn_cast<RankedTensorType>();
|
auto input_ty = op.orig_input().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!input_ty) return matchFailure();
|
if (!input_ty) return matchFailure();
|
||||||
DenseIntElementsAttr paddings_attr = GetReduceWindowPadding(
|
DenseIntElementsAttr paddings_attr = GetReduceWindowPadding(
|
||||||
input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
|
input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter);
|
||||||
@ -2099,11 +2099,11 @@ class ConvertConv2DBackpropInputOp
|
|||||||
return Pattern::matchFailure();
|
return Pattern::matchFailure();
|
||||||
|
|
||||||
auto out_backprop_ty =
|
auto out_backprop_ty =
|
||||||
op.out_backprop()->getType().dyn_cast<RankedTensorType>();
|
op.out_backprop().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!out_backprop_ty || !out_backprop_ty.hasStaticShape())
|
if (!out_backprop_ty || !out_backprop_ty.hasStaticShape())
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
ArrayRef<int64_t> out_backprop_shape = out_backprop_ty.getShape();
|
ArrayRef<int64_t> out_backprop_shape = out_backprop_ty.getShape();
|
||||||
auto filter_ty = op.filter()->getType().dyn_cast<RankedTensorType>();
|
auto filter_ty = op.filter().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!filter_ty || !filter_ty.hasStaticShape()) return matchFailure();
|
if (!filter_ty || !filter_ty.hasStaticShape()) return matchFailure();
|
||||||
ArrayRef<int64_t> filter_shape = filter_ty.getShape();
|
ArrayRef<int64_t> filter_shape = filter_ty.getShape();
|
||||||
int num_spatial_dims = 2;
|
int num_spatial_dims = 2;
|
||||||
@ -2243,11 +2243,11 @@ class ConvertConv2DBackpropFilterOp
|
|||||||
return Pattern::matchFailure();
|
return Pattern::matchFailure();
|
||||||
|
|
||||||
auto out_backprop_ty =
|
auto out_backprop_ty =
|
||||||
op.out_backprop()->getType().dyn_cast<RankedTensorType>();
|
op.out_backprop().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!out_backprop_ty || !out_backprop_ty.hasStaticShape())
|
if (!out_backprop_ty || !out_backprop_ty.hasStaticShape())
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
ArrayRef<int64_t> out_backprop_shape = out_backprop_ty.getShape();
|
ArrayRef<int64_t> out_backprop_shape = out_backprop_ty.getShape();
|
||||||
auto input_ty = op.input()->getType().dyn_cast<RankedTensorType>();
|
auto input_ty = op.input().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!input_ty || !input_ty.hasStaticShape()) return matchFailure();
|
if (!input_ty || !input_ty.hasStaticShape()) return matchFailure();
|
||||||
ArrayRef<int64_t> input_shape = input_ty.getShape();
|
ArrayRef<int64_t> input_shape = input_ty.getShape();
|
||||||
|
|
||||||
@ -2432,7 +2432,7 @@ class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
|
|||||||
|
|
||||||
PatternMatchResult matchAndRewrite(TF::OneHotOp op,
|
PatternMatchResult matchAndRewrite(TF::OneHotOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto indices_ty = op.indices()->getType().dyn_cast<RankedTensorType>();
|
auto indices_ty = op.indices().getType().dyn_cast<RankedTensorType>();
|
||||||
if (!indices_ty || !indices_ty.hasStaticShape()) return matchFailure();
|
if (!indices_ty || !indices_ty.hasStaticShape()) return matchFailure();
|
||||||
ArrayRef<int64_t> indices_shape = indices_ty.getShape();
|
ArrayRef<int64_t> indices_shape = indices_ty.getShape();
|
||||||
Type element_type = indices_ty.getElementType();
|
Type element_type = indices_ty.getElementType();
|
||||||
@ -2522,7 +2522,7 @@ class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
|
|||||||
|
|
||||||
// The last dimension of the input tensor's shape should be known so we can
|
// The last dimension of the input tensor's shape should be known so we can
|
||||||
// have clamped end_indices for slices.
|
// have clamped end_indices for slices.
|
||||||
TensorType input_type = op.input()->getType().cast<TensorType>();
|
TensorType input_type = op.input().getType().cast<TensorType>();
|
||||||
if (!input_type.hasRank()) return matchFailure();
|
if (!input_type.hasRank()) return matchFailure();
|
||||||
int64_t input_rank = input_type.getRank();
|
int64_t input_rank = input_type.getRank();
|
||||||
int64_t last_dim_index = input_rank - 1;
|
int64_t last_dim_index = input_rank - 1;
|
||||||
@ -2587,7 +2587,7 @@ class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
|
|||||||
|
|
||||||
PatternMatchResult matchAndRewrite(TF::UnpackOp op,
|
PatternMatchResult matchAndRewrite(TF::UnpackOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto value_type = op.value()->getType().cast<RankedTensorType>();
|
auto value_type = op.value().getType().cast<RankedTensorType>();
|
||||||
if (!value_type) return matchFailure();
|
if (!value_type) return matchFailure();
|
||||||
|
|
||||||
int64_t value_rank = value_type.getRank();
|
int64_t value_rank = value_type.getRank();
|
||||||
@ -2645,12 +2645,12 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern<OpTy> {
|
|||||||
|
|
||||||
PatternMatchResult matchAndRewrite(OpTy op,
|
PatternMatchResult matchAndRewrite(OpTy op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto data_type = op.data()->getType().template dyn_cast<RankedTensorType>();
|
auto data_type = op.data().getType().template dyn_cast<RankedTensorType>();
|
||||||
if (!data_type) return this->matchFailure();
|
if (!data_type) return this->matchFailure();
|
||||||
int64_t data_rank = data_type.getRank();
|
int64_t data_rank = data_type.getRank();
|
||||||
|
|
||||||
auto segment_ids_type =
|
auto segment_ids_type =
|
||||||
op.segment_ids()->getType().template dyn_cast<RankedTensorType>();
|
op.segment_ids().getType().template dyn_cast<RankedTensorType>();
|
||||||
if (!segment_ids_type) return this->matchFailure();
|
if (!segment_ids_type) return this->matchFailure();
|
||||||
int64_t segment_ids_rank = segment_ids_type.getRank();
|
int64_t segment_ids_rank = segment_ids_type.getRank();
|
||||||
|
|
||||||
|
@ -68,8 +68,8 @@ void Detuple(Value tuple, Operation::result_range replace, OpBuilder* builder) {
|
|||||||
// De-tuple the results of the xla hlo conditional result.
|
// De-tuple the results of the xla hlo conditional result.
|
||||||
for (auto result_it : llvm::enumerate(replace)) {
|
for (auto result_it : llvm::enumerate(replace)) {
|
||||||
auto get_tuple_value = builder->create<xla_hlo::GetTupleElementOp>(
|
auto get_tuple_value = builder->create<xla_hlo::GetTupleElementOp>(
|
||||||
result_it.value()->getLoc(), tuple, result_it.index());
|
result_it.value().getLoc(), tuple, result_it.index());
|
||||||
result_it.value()->replaceAllUsesWith(get_tuple_value);
|
result_it.value().replaceAllUsesWith(get_tuple_value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -35,19 +35,19 @@ def FalseBoolAttr : AttrConstraint<CPred<"!$_self.getValue()">>;
|
|||||||
def TrueBoolAttr : AttrConstraint<CPred<"$_self.getValue()">>;
|
def TrueBoolAttr : AttrConstraint<CPred<"$_self.getValue()">>;
|
||||||
|
|
||||||
def CastValueToI64: NativeCodeCall<
|
def CastValueToI64: NativeCodeCall<
|
||||||
"CastValueToI64($0->getLoc(), $1, &$_builder)">;
|
"CastValueToI64($0.getLoc(), $1, &$_builder)">;
|
||||||
|
|
||||||
// Here, $0 is an ElementsAttr with exactly one element of type integer. $1 is
|
// Here, $0 is an ElementsAttr with exactly one element of type integer. $1 is
|
||||||
// the corresponding value of ranked tensor type whose axis is referred in $0.
|
// the corresponding value of ranked tensor type whose axis is referred in $0.
|
||||||
def GetHLOAxisFromTFAxis : NativeCodeCall<
|
def GetHLOAxisFromTFAxis : NativeCodeCall<
|
||||||
"GetHLOAxisFromTFAxis("
|
"GetHLOAxisFromTFAxis("
|
||||||
"$0, $1->getType().cast<RankedTensorType>().getRank(), &$_builder)">;
|
"$0, $1.getType().cast<RankedTensorType>().getRank(), &$_builder)">;
|
||||||
|
|
||||||
// Same as the above but with $1 of type operand_range from variadic TensorFlow
|
// Same as the above but with $1 of type operand_range from variadic TensorFlow
|
||||||
// input.
|
// input.
|
||||||
def GetHLOAxisFromTFAxisVariadic : NativeCodeCall<
|
def GetHLOAxisFromTFAxisVariadic : NativeCodeCall<
|
||||||
"GetHLOAxisFromTFAxis("
|
"GetHLOAxisFromTFAxis("
|
||||||
"$0, (*$1.begin())->getType().cast<RankedTensorType>().getRank(), "
|
"$0, (*$1.begin()).getType().cast<RankedTensorType>().getRank(), "
|
||||||
"&$_builder)">;
|
"&$_builder)">;
|
||||||
|
|
||||||
def : Pattern<
|
def : Pattern<
|
||||||
@ -251,10 +251,10 @@ def OneElementAttr
|
|||||||
"Scalar ElementsAttr">;
|
"Scalar ElementsAttr">;
|
||||||
|
|
||||||
def HasRankedFirstOperand
|
def HasRankedFirstOperand
|
||||||
: Constraint<CPred<"(*$0.begin())->getType().isa<RankedTensorType>()">>;
|
: Constraint<CPred<"(*$0.begin()).getType().isa<RankedTensorType>()">>;
|
||||||
|
|
||||||
def IsShapedTensor
|
def IsShapedTensor
|
||||||
: Constraint<CPred<"$0->getType().isa<RankedTensorType>()">>;
|
: Constraint<CPred<"$0.getType().isa<RankedTensorType>()">>;
|
||||||
|
|
||||||
// This pattern converts TensorFlow axis format to HLO axis format which
|
// This pattern converts TensorFlow axis format to HLO axis format which
|
||||||
// doesn't wrap around like TensorFlow and is always positive. For this
|
// doesn't wrap around like TensorFlow and is always positive. For this
|
||||||
@ -405,7 +405,7 @@ def : Pat<(TF_SliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices,
|
|||||||
// Ternary op patterns.
|
// Ternary op patterns.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def BothTypesMatch : Constraint<CPred<"$0->getType() == $1->getType()">,
|
def BothTypesMatch : Constraint<CPred<"$0.getType() == $1.getType()">,
|
||||||
"types must be equal">;
|
"types must be equal">;
|
||||||
|
|
||||||
foreach src = [TF_SelectOp, TF_SelectV2Op] in
|
foreach src = [TF_SelectOp, TF_SelectV2Op] in
|
||||||
|
@ -47,8 +47,8 @@ struct CompareIConvert : public RewritePattern {
|
|||||||
|
|
||||||
auto lhs = compare_op.lhs();
|
auto lhs = compare_op.lhs();
|
||||||
auto rhs = compare_op.rhs();
|
auto rhs = compare_op.rhs();
|
||||||
auto lhs_type = lhs->getType().cast<TensorType>();
|
auto lhs_type = lhs.getType().cast<TensorType>();
|
||||||
auto rhs_type = rhs->getType().cast<TensorType>();
|
auto rhs_type = rhs.getType().cast<TensorType>();
|
||||||
|
|
||||||
// Broadcasting not supported by this rewrite.
|
// Broadcasting not supported by this rewrite.
|
||||||
if (lhs_type.getShape() != rhs_type.getShape()) return matchFailure();
|
if (lhs_type.getShape() != rhs_type.getShape()) return matchFailure();
|
||||||
@ -86,8 +86,8 @@ struct CompareFConvert : public RewritePattern {
|
|||||||
|
|
||||||
auto lhs = compare_op.lhs();
|
auto lhs = compare_op.lhs();
|
||||||
auto rhs = compare_op.rhs();
|
auto rhs = compare_op.rhs();
|
||||||
auto lhs_type = lhs->getType().cast<TensorType>();
|
auto lhs_type = lhs.getType().cast<TensorType>();
|
||||||
auto rhs_type = rhs->getType().cast<TensorType>();
|
auto rhs_type = rhs.getType().cast<TensorType>();
|
||||||
|
|
||||||
// Broadcasting not supported by this rewrite.
|
// Broadcasting not supported by this rewrite.
|
||||||
if (lhs_type.getShape() != rhs_type.getShape()) return matchFailure();
|
if (lhs_type.getShape() != rhs_type.getShape()) return matchFailure();
|
||||||
|
@ -31,8 +31,8 @@ def : Pat<(HLO_ConstOp ElementsAttr:$value),
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def IsSameSizePred : CPred<
|
def IsSameSizePred : CPred<
|
||||||
"$0->getType().cast<ShapedType>().getShape() "
|
"$0.getType().cast<ShapedType>().getShape() "
|
||||||
"== $1->getType().cast<ShapedType>().getShape()">;
|
"== $1.getType().cast<ShapedType>().getShape()">;
|
||||||
def IsSameSizeConstraint : Constraint<IsSameSizePred, "inputs are same size">;
|
def IsSameSizeConstraint : Constraint<IsSameSizePred, "inputs are same size">;
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,8 +39,8 @@ struct BinaryOpConverter : public OpRewritePattern<LhloOp> {
|
|||||||
PatternRewriter& rewriter) const override {
|
PatternRewriter& rewriter) const override {
|
||||||
const auto& lhs = op.lhs();
|
const auto& lhs = op.lhs();
|
||||||
const auto& rhs = op.rhs();
|
const auto& rhs = op.rhs();
|
||||||
const auto& lhs_type = lhs->getType().template cast<MemRefType>();
|
const auto& lhs_type = lhs.getType().template cast<MemRefType>();
|
||||||
const auto& rhs_type = rhs->getType().template cast<MemRefType>();
|
const auto& rhs_type = rhs.getType().template cast<MemRefType>();
|
||||||
const auto& element_type = lhs_type.getElementType();
|
const auto& element_type = lhs_type.getElementType();
|
||||||
|
|
||||||
if (lhs_type.getShape() != rhs_type.getShape()) {
|
if (lhs_type.getShape() != rhs_type.getShape()) {
|
||||||
|
@ -55,7 +55,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
|
|||||||
// Only support 1d reductions for now.
|
// Only support 1d reductions for now.
|
||||||
int64_t size = 0;
|
int64_t size = 0;
|
||||||
for (auto result : reduce_op.out()) {
|
for (auto result : reduce_op.out()) {
|
||||||
auto shaped_type = result->getType().dyn_cast<ShapedType>();
|
auto shaped_type = result.getType().dyn_cast<ShapedType>();
|
||||||
if (!shaped_type || shaped_type.getRank() != 1) {
|
if (!shaped_type || shaped_type.getRank() != 1) {
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
}
|
}
|
||||||
@ -71,7 +71,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
|
|||||||
// Require all inputs to have the same shape.
|
// Require all inputs to have the same shape.
|
||||||
int64_t reduce_dim_size = 0;
|
int64_t reduce_dim_size = 0;
|
||||||
for (auto input : reduce_op.operands()) {
|
for (auto input : reduce_op.operands()) {
|
||||||
auto shaped_type = input->getType().dyn_cast<ShapedType>();
|
auto shaped_type = input.getType().dyn_cast<ShapedType>();
|
||||||
if (!shaped_type || !shaped_type.hasStaticShape()) {
|
if (!shaped_type || !shaped_type.hasStaticShape()) {
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
}
|
}
|
||||||
@ -128,7 +128,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
|
|||||||
auto output = mapping.lookup(*reduce_op.out().begin());
|
auto output = mapping.lookup(*reduce_op.out().begin());
|
||||||
// TODO(herhut) Move this to the SliceOp builder.
|
// TODO(herhut) Move this to the SliceOp builder.
|
||||||
auto resType = MemRefType::get(
|
auto resType = MemRefType::get(
|
||||||
llvm::None, output->getType().cast<MemRefType>().getElementType(),
|
llvm::None, output.getType().cast<MemRefType>().getElementType(),
|
||||||
makeStridedLinearLayoutMap(llvm::None,
|
makeStridedLinearLayoutMap(llvm::None,
|
||||||
MemRefType::getDynamicStrideOrOffset(),
|
MemRefType::getDynamicStrideOrOffset(),
|
||||||
rewriter.getContext()));
|
rewriter.getContext()));
|
||||||
@ -136,7 +136,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
|
|||||||
loc, resType, output, ArrayRef<Value>{launch_op.getThreadIds().x});
|
loc, resType, output, ArrayRef<Value>{launch_op.getThreadIds().x});
|
||||||
llvm::SmallVector<Value, 4> indexings;
|
llvm::SmallVector<Value, 4> indexings;
|
||||||
auto input_buffer = *reduce_op.operands().begin();
|
auto input_buffer = *reduce_op.operands().begin();
|
||||||
auto input_type = input_buffer->getType().cast<MemRefType>();
|
auto input_type = input_buffer.getType().cast<MemRefType>();
|
||||||
for (int64_t dim = 0; dim < input_type.getRank(); ++dim) {
|
for (int64_t dim = 0; dim < input_type.getRank(); ++dim) {
|
||||||
indexings.push_back(dim == reducing_dimension
|
indexings.push_back(dim == reducing_dimension
|
||||||
? loop.getInductionVar()
|
? loop.getInductionVar()
|
||||||
|
@ -57,7 +57,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<LhloOp> {
|
|||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
auto loc = lhlo_op.getLoc();
|
auto loc = lhlo_op.getLoc();
|
||||||
auto argType =
|
auto argType =
|
||||||
lhlo_op.getOperand(0)->getType().template dyn_cast<ShapedType>();
|
lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
|
||||||
if (!argType || !argType.hasStaticShape()) {
|
if (!argType || !argType.hasStaticShape()) {
|
||||||
emitError(loc,
|
emitError(loc,
|
||||||
"lhlo to linalg conversion expects statically shaped args");
|
"lhlo to linalg conversion expects statically shaped args");
|
||||||
@ -73,7 +73,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<LhloOp> {
|
|||||||
unsigned nloops = 0;
|
unsigned nloops = 0;
|
||||||
int operandCount = args.size() - 1;
|
int operandCount = args.size() - 1;
|
||||||
for (const auto& arg : llvm::enumerate(args)) {
|
for (const auto& arg : llvm::enumerate(args)) {
|
||||||
auto memrefType = arg.value()->getType().dyn_cast<MemRefType>();
|
auto memrefType = arg.value().getType().dyn_cast<MemRefType>();
|
||||||
if (!memrefType) return ConversionPattern::matchFailure();
|
if (!memrefType) return ConversionPattern::matchFailure();
|
||||||
unsigned rank = memrefType.getRank();
|
unsigned rank = memrefType.getRank();
|
||||||
if (!rank || (nloops && nloops != rank)) {
|
if (!rank || (nloops && nloops != rank)) {
|
||||||
@ -125,7 +125,7 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
|
|||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
auto loc = lhlo_op.getLoc();
|
auto loc = lhlo_op.getLoc();
|
||||||
auto argType =
|
auto argType =
|
||||||
lhlo_op.getOperand(0)->getType().template dyn_cast<ShapedType>();
|
lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
|
||||||
if (!argType || !argType.getElementType().isIntOrFloat() ||
|
if (!argType || !argType.getElementType().isIntOrFloat() ||
|
||||||
(argType.getRank() != 0)) {
|
(argType.getRank() != 0)) {
|
||||||
return ConversionPattern::matchFailure();
|
return ConversionPattern::matchFailure();
|
||||||
@ -151,9 +151,9 @@ class BroadcastInDimConverter : public OpConversionPattern<BroadcastInDimOp> {
|
|||||||
BroadcastInDimOp broadcastOp, ArrayRef<Value> args,
|
BroadcastInDimOp broadcastOp, ArrayRef<Value> args,
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
auto operandMemrefType =
|
auto operandMemrefType =
|
||||||
broadcastOp.operand()->getType().dyn_cast<MemRefType>();
|
broadcastOp.operand().getType().dyn_cast<MemRefType>();
|
||||||
auto resultMemrefType =
|
auto resultMemrefType =
|
||||||
broadcastOp.output()->getType().dyn_cast<MemRefType>();
|
broadcastOp.output().getType().dyn_cast<MemRefType>();
|
||||||
if (!operandMemrefType || !resultMemrefType) return matchFailure();
|
if (!operandMemrefType || !resultMemrefType) return matchFailure();
|
||||||
auto broadcastDims = broadcastOp.broadcast_dimensions();
|
auto broadcastDims = broadcastOp.broadcast_dimensions();
|
||||||
if (!broadcastDims.hasValue()) return matchFailure();
|
if (!broadcastDims.hasValue()) return matchFailure();
|
||||||
@ -253,7 +253,7 @@ class IotaConverter : public OpConversionPattern<IotaOp> {
|
|||||||
IotaOp iotaOp, ArrayRef<Value> args,
|
IotaOp iotaOp, ArrayRef<Value> args,
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
auto resultMemrefType =
|
auto resultMemrefType =
|
||||||
iotaOp.getOperand()->getType().dyn_cast<MemRefType>();
|
iotaOp.getOperand().getType().dyn_cast<MemRefType>();
|
||||||
if (!resultMemrefType) return matchFailure();
|
if (!resultMemrefType) return matchFailure();
|
||||||
|
|
||||||
auto resultElementType = resultMemrefType.getElementType();
|
auto resultElementType = resultMemrefType.getElementType();
|
||||||
|
@ -49,7 +49,7 @@ Value TransposeReshape(Value arg, mlir::Location loc,
|
|||||||
llvm::ArrayRef<int64_t> right_dims,
|
llvm::ArrayRef<int64_t> right_dims,
|
||||||
llvm::ArrayRef<int64_t> arg_shape,
|
llvm::ArrayRef<int64_t> arg_shape,
|
||||||
PatternRewriter *rewriter) {
|
PatternRewriter *rewriter) {
|
||||||
auto element_type = mlir::getElementTypeOrSelf(arg->getType());
|
auto element_type = mlir::getElementTypeOrSelf(arg.getType());
|
||||||
|
|
||||||
int64_t left_size = 1;
|
int64_t left_size = 1;
|
||||||
for (auto dim : left_dims) {
|
for (auto dim : left_dims) {
|
||||||
@ -94,7 +94,7 @@ Value TransposeReshape(Value arg, mlir::Location loc,
|
|||||||
Value ProcessDotArg(Value arg, mlir::Location loc,
|
Value ProcessDotArg(Value arg, mlir::Location loc,
|
||||||
ElementsAttr contract_dims_attr, bool outer_dims_first,
|
ElementsAttr contract_dims_attr, bool outer_dims_first,
|
||||||
PatternRewriter *rewriter) {
|
PatternRewriter *rewriter) {
|
||||||
auto shape = arg->getType().cast<mlir::ShapedType>().getShape();
|
auto shape = arg.getType().cast<mlir::ShapedType>().getShape();
|
||||||
|
|
||||||
llvm::SmallVector<bool, 5> is_outer_dim;
|
llvm::SmallVector<bool, 5> is_outer_dim;
|
||||||
is_outer_dim.resize(shape.size(), true);
|
is_outer_dim.resize(shape.size(), true);
|
||||||
@ -154,8 +154,8 @@ struct GeneralDotConvert
|
|||||||
/*outer_dims_first=*/false, &rewriter);
|
/*outer_dims_first=*/false, &rewriter);
|
||||||
|
|
||||||
// Dot resulting shape.
|
// Dot resulting shape.
|
||||||
auto lhs_shape = lhs->getType().cast<mlir::ShapedType>().getShape();
|
auto lhs_shape = lhs.getType().cast<mlir::ShapedType>().getShape();
|
||||||
auto rhs_shape = rhs->getType().cast<mlir::ShapedType>().getShape();
|
auto rhs_shape = rhs.getType().cast<mlir::ShapedType>().getShape();
|
||||||
auto new_dot_type =
|
auto new_dot_type =
|
||||||
RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);
|
RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ using ScalarIOp = typename ScalarOp<LHLO_BinaryOp>::IOp;
|
|||||||
template <typename LhloOp>
|
template <typename LhloOp>
|
||||||
Operation* MapLhloOpToStdScalarOp(LhloOp lhlo_op, ArrayRef<Type> result_types,
|
Operation* MapLhloOpToStdScalarOp(LhloOp lhlo_op, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> block_args, OpBuilder b) {
|
ArrayRef<Value> block_args, OpBuilder b) {
|
||||||
Type element_type = block_args.front()->getType();
|
Type element_type = block_args.front().getType();
|
||||||
if (element_type.isa<IntegerType>()) {
|
if (element_type.isa<IntegerType>()) {
|
||||||
return b.template create<ScalarIOp<LhloOp>>(lhlo_op.getLoc(), result_types,
|
return b.template create<ScalarIOp<LhloOp>>(lhlo_op.getLoc(), result_types,
|
||||||
block_args, mlir::None);
|
block_args, mlir::None);
|
||||||
@ -79,7 +79,7 @@ inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>(
|
|||||||
ArrayRef<Value> block_args, OpBuilder b) {
|
ArrayRef<Value> block_args, OpBuilder b) {
|
||||||
const auto& lhs = block_args[0];
|
const auto& lhs = block_args[0];
|
||||||
const auto& rhs = block_args[1];
|
const auto& rhs = block_args[1];
|
||||||
Type element_type = lhs->getType();
|
Type element_type = lhs.getType();
|
||||||
if (element_type.isa<IntegerType>()) {
|
if (element_type.isa<IntegerType>()) {
|
||||||
auto lhs_gt_rhs = b.create<ScalarIOp<CompareOp>>(
|
auto lhs_gt_rhs = b.create<ScalarIOp<CompareOp>>(
|
||||||
lhlo_op.getLoc(), CmpIPredicate::sgt, lhs, rhs);
|
lhlo_op.getLoc(), CmpIPredicate::sgt, lhs, rhs);
|
||||||
@ -99,7 +99,7 @@ inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::MinOp>(
|
|||||||
ArrayRef<Value> block_args, OpBuilder b) {
|
ArrayRef<Value> block_args, OpBuilder b) {
|
||||||
const auto& lhs = block_args[0];
|
const auto& lhs = block_args[0];
|
||||||
const auto& rhs = block_args[1];
|
const auto& rhs = block_args[1];
|
||||||
Type element_type = lhs->getType();
|
Type element_type = lhs.getType();
|
||||||
if (element_type.isa<IntegerType>()) {
|
if (element_type.isa<IntegerType>()) {
|
||||||
auto lhs_lt_rhs = b.create<ScalarIOp<CompareOp>>(
|
auto lhs_lt_rhs = b.create<ScalarIOp<CompareOp>>(
|
||||||
lhlo_op.getLoc(), CmpIPredicate::slt, lhs, rhs);
|
lhlo_op.getLoc(), CmpIPredicate::slt, lhs, rhs);
|
||||||
@ -117,7 +117,7 @@ template <>
|
|||||||
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
|
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
|
||||||
xla_lhlo::AndOp lhlo_op, ArrayRef<Type> result_types,
|
xla_lhlo::AndOp lhlo_op, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> block_args, OpBuilder b) {
|
ArrayRef<Value> block_args, OpBuilder b) {
|
||||||
Type element_type = block_args.front()->getType();
|
Type element_type = block_args.front().getType();
|
||||||
return element_type.isa<IntegerType>()
|
return element_type.isa<IntegerType>()
|
||||||
? b.create<::mlir::AndOp>(lhlo_op.getLoc(), result_types,
|
? b.create<::mlir::AndOp>(lhlo_op.getLoc(), result_types,
|
||||||
block_args, mlir::None)
|
block_args, mlir::None)
|
||||||
@ -153,7 +153,7 @@ inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::CompareOp>(
|
|||||||
ArrayRef<Value> block_args, OpBuilder b) {
|
ArrayRef<Value> block_args, OpBuilder b) {
|
||||||
const auto& lhs = block_args[0];
|
const auto& lhs = block_args[0];
|
||||||
const auto& rhs = block_args[1];
|
const auto& rhs = block_args[1];
|
||||||
Type element_type = lhs->getType();
|
Type element_type = lhs.getType();
|
||||||
if (element_type.isa<IntegerType>()) {
|
if (element_type.isa<IntegerType>()) {
|
||||||
Optional<CmpIPredicate> predicate =
|
Optional<CmpIPredicate> predicate =
|
||||||
getIntCmpPredicate(lhlo_op.comparison_direction());
|
getIntCmpPredicate(lhlo_op.comparison_direction());
|
||||||
@ -181,7 +181,7 @@ template <>
|
|||||||
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
|
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
|
||||||
xla_lhlo::ExpOp lhlo_op, ArrayRef<Type> result_types,
|
xla_lhlo::ExpOp lhlo_op, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> block_args, OpBuilder b) {
|
ArrayRef<Value> block_args, OpBuilder b) {
|
||||||
Type element_type = block_args.front()->getType();
|
Type element_type = block_args.front().getType();
|
||||||
return element_type.isa<FloatType>()
|
return element_type.isa<FloatType>()
|
||||||
? b.create<::mlir::ExpOp>(lhlo_op.getLoc(), result_types,
|
? b.create<::mlir::ExpOp>(lhlo_op.getLoc(), result_types,
|
||||||
block_args, mlir::None)
|
block_args, mlir::None)
|
||||||
|
@ -257,7 +257,7 @@ mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size,
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (mlir::IROperand& use :
|
for (mlir::IROperand& use :
|
||||||
llvm::make_early_inc_range(loop.getInductionVar()->getUses())) {
|
llvm::make_early_inc_range(loop.getInductionVar().getUses())) {
|
||||||
mlir::Operation* owner = use.getOwner();
|
mlir::Operation* owner = use.getOwner();
|
||||||
BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
|
BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
|
||||||
unsigned new_dim = affine_map.operands.size();
|
unsigned new_dim = affine_map.operands.size();
|
||||||
@ -330,7 +330,7 @@ mlir::Operation* HoistAndFix(llvm::iplist<mlir::Operation>::iterator begin_op,
|
|||||||
indvars.push_back(ancestor.getInductionVar());
|
indvars.push_back(ancestor.getInductionVar());
|
||||||
}
|
}
|
||||||
for (mlir::IROperand& use :
|
for (mlir::IROperand& use :
|
||||||
llvm::make_early_inc_range(alloc.getResult()->getUses())) {
|
llvm::make_early_inc_range(alloc.getResult().getUses())) {
|
||||||
mlir::Operation* owner = use.getOwner();
|
mlir::Operation* owner = use.getOwner();
|
||||||
BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
|
BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
|
||||||
affine_map.operands.insert(affine_map.operands.begin(), indvars.begin(),
|
affine_map.operands.insert(affine_map.operands.begin(), indvars.begin(),
|
||||||
|
@ -108,7 +108,7 @@ struct SingleTripLoopRemoval
|
|||||||
: public mlir::FunctionPass<SingleTripLoopRemoval> {
|
: public mlir::FunctionPass<SingleTripLoopRemoval> {
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
auto getConstantValue = [](mlir::Value value) -> llvm::Optional<int64_t> {
|
auto getConstantValue = [](mlir::Value value) -> llvm::Optional<int64_t> {
|
||||||
auto definingOp = value->getDefiningOp();
|
auto definingOp = value.getDefiningOp();
|
||||||
if (!definingOp) return llvm::None;
|
if (!definingOp) return llvm::None;
|
||||||
auto constantOp = llvm::dyn_cast<mlir::ConstantOp>(definingOp);
|
auto constantOp = llvm::dyn_cast<mlir::ConstantOp>(definingOp);
|
||||||
if (!constantOp) return llvm::None;
|
if (!constantOp) return llvm::None;
|
||||||
@ -180,9 +180,9 @@ struct StoreForwardingPass : mlir::FunctionPass<StoreForwardingPass> {
|
|||||||
// Recursively checks defining ops until finds AllocOp. Return either AllocOp
|
// Recursively checks defining ops until finds AllocOp. Return either AllocOp
|
||||||
// if it is found or nullptr.
|
// if it is found or nullptr.
|
||||||
mlir::Operation* SearchAllocOp(mlir::Value memref) {
|
mlir::Operation* SearchAllocOp(mlir::Value memref) {
|
||||||
mlir::Operation* defOp = memref->getDefiningOp();
|
mlir::Operation* defOp = memref.getDefiningOp();
|
||||||
while (auto subviewOp = mlir::dyn_cast_or_null<mlir::SubViewOp>(defOp)) {
|
while (auto subviewOp = mlir::dyn_cast_or_null<mlir::SubViewOp>(defOp)) {
|
||||||
defOp = subviewOp.source()->getDefiningOp();
|
defOp = subviewOp.source().getDefiningOp();
|
||||||
}
|
}
|
||||||
if (auto allocOp = mlir::dyn_cast_or_null<mlir::AllocOp>(defOp)) {
|
if (auto allocOp = mlir::dyn_cast_or_null<mlir::AllocOp>(defOp)) {
|
||||||
return allocOp.getOperation();
|
return allocOp.getOperation();
|
||||||
@ -211,7 +211,7 @@ struct StoreForwardingPass : mlir::FunctionPass<StoreForwardingPass> {
|
|||||||
struct DeadTempBufferRemoval : mlir::FunctionPass<DeadTempBufferRemoval> {
|
struct DeadTempBufferRemoval : mlir::FunctionPass<DeadTempBufferRemoval> {
|
||||||
bool operationConsideredDead(mlir::Operation* op) {
|
bool operationConsideredDead(mlir::Operation* op) {
|
||||||
for (auto result : op->getResults()) {
|
for (auto result : op->getResults()) {
|
||||||
if (!llvm::all_of(result->getUsers(), [&](mlir::Operation* op) {
|
if (!llvm::all_of(result.getUsers(), [&](mlir::Operation* op) {
|
||||||
// Store and Dealloc is OK.
|
// Store and Dealloc is OK.
|
||||||
if (llvm::isa<mlir::StoreOp>(op) ||
|
if (llvm::isa<mlir::StoreOp>(op) ||
|
||||||
llvm::isa<mlir::DeallocOp>(op)) {
|
llvm::isa<mlir::DeallocOp>(op)) {
|
||||||
@ -235,7 +235,7 @@ struct DeadTempBufferRemoval : mlir::FunctionPass<DeadTempBufferRemoval> {
|
|||||||
|
|
||||||
void recursiveErase(mlir::Operation* op) {
|
void recursiveErase(mlir::Operation* op) {
|
||||||
for (auto result : op->getResults()) {
|
for (auto result : op->getResults()) {
|
||||||
for (auto user : llvm::make_early_inc_range(result->getUsers())) {
|
for (auto user : llvm::make_early_inc_range(result.getUsers())) {
|
||||||
recursiveErase(user);
|
recursiveErase(user);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -197,19 +197,19 @@ static absl::optional<int64> getLaunchBound(const mlir::gpu::KernelDim3& dim) {
|
|||||||
op->emitError() << "bound " << name << " is not constant";
|
op->emitError() << "bound " << name << " is not constant";
|
||||||
return absl::nullopt;
|
return absl::nullopt;
|
||||||
};
|
};
|
||||||
auto y_op = dim.y->getDefiningOp();
|
auto y_op = dim.y.getDefiningOp();
|
||||||
auto dim_y = get_constant(y_op, "y");
|
auto dim_y = get_constant(y_op, "y");
|
||||||
if (!dim_y.has_value() || dim_y.value() != 1) {
|
if (!dim_y.has_value() || dim_y.value() != 1) {
|
||||||
y_op->emitError() << "bound 'y' is not constant 1";
|
y_op->emitError() << "bound 'y' is not constant 1";
|
||||||
return absl::nullopt;
|
return absl::nullopt;
|
||||||
}
|
}
|
||||||
auto z_op = dim.z->getDefiningOp();
|
auto z_op = dim.z.getDefiningOp();
|
||||||
auto dim_z = get_constant(z_op, "z");
|
auto dim_z = get_constant(z_op, "z");
|
||||||
if (!dim_z.has_value() || dim_z.value() != 1) {
|
if (!dim_z.has_value() || dim_z.value() != 1) {
|
||||||
z_op->emitError() << "bound 'z' is not constant 1";
|
z_op->emitError() << "bound 'z' is not constant 1";
|
||||||
return absl::nullopt;
|
return absl::nullopt;
|
||||||
}
|
}
|
||||||
return get_constant(dim.x->getDefiningOp(), "x");
|
return get_constant(dim.x.getDefiningOp(), "x");
|
||||||
}
|
}
|
||||||
|
|
||||||
using OperandToValueMap =
|
using OperandToValueMap =
|
||||||
@ -224,7 +224,7 @@ static StatusOr<std::vector<const HloInstruction*>> ComputeOperandToValueMap(
|
|||||||
for (int kernel_index = 0; kernel_index < launchOp.getNumKernelOperands();
|
for (int kernel_index = 0; kernel_index < launchOp.getNumKernelOperands();
|
||||||
++kernel_index) {
|
++kernel_index) {
|
||||||
auto launchop_operand =
|
auto launchop_operand =
|
||||||
launchOp.getKernelOperand(kernel_index)->dyn_cast<BlockArgument>();
|
launchOp.getKernelOperand(kernel_index).dyn_cast<BlockArgument>();
|
||||||
if (!launchop_operand) {
|
if (!launchop_operand) {
|
||||||
launchOp.emitError("argument to kernel is not a function input");
|
launchOp.emitError("argument to kernel is not a function input");
|
||||||
has_failed = true;
|
has_failed = true;
|
||||||
@ -233,7 +233,7 @@ static StatusOr<std::vector<const HloInstruction*>> ComputeOperandToValueMap(
|
|||||||
// host_index is the argument position to the surrounding function that
|
// host_index is the argument position to the surrounding function that
|
||||||
// contains the launch. This index corresponds to HLO operand indices
|
// contains the launch. This index corresponds to HLO operand indices
|
||||||
// by construction.
|
// by construction.
|
||||||
auto host_index = launchop_operand->getArgNumber();
|
auto host_index = launchop_operand.getArgNumber();
|
||||||
// The trailing argument to the outer function are the results.
|
// The trailing argument to the outer function are the results.
|
||||||
auto operand =
|
auto operand =
|
||||||
(host_index < operands.size()) ? operands[host_index] : instr;
|
(host_index < operands.size()) ? operands[host_index] : instr;
|
||||||
@ -304,7 +304,7 @@ Status InsertBufferLoadPreduleIntoKernel(
|
|||||||
// { baseptr, dataptr, offset, shape_vect, stride_vect }
|
// { baseptr, dataptr, offset, shape_vect, stride_vect }
|
||||||
// where shape_vect and stride_vect are integer vectors with length
|
// where shape_vect and stride_vect are integer vectors with length
|
||||||
// matching the rank of the tensor.
|
// matching the rank of the tensor.
|
||||||
auto target_type = value->getType().cast<LLVMType>();
|
auto target_type = value.getType().cast<LLVMType>();
|
||||||
auto struct_type = target_type.getPointerElementTy();
|
auto struct_type = target_type.getPointerElementTy();
|
||||||
auto descPtr =
|
auto descPtr =
|
||||||
builder.create<mlir::LLVM::AllocaOp>(loc, target_type, one, 0);
|
builder.create<mlir::LLVM::AllocaOp>(loc, target_type, one, 0);
|
||||||
@ -367,7 +367,7 @@ Status InsertBufferLoadPreduleIntoKernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Now we can use the descriptor instead of the original argument.
|
// Now we can use the descriptor instead of the original argument.
|
||||||
value->replaceAllUsesWith(descPtr);
|
value.replaceAllUsesWith(descPtr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user