Rename xla_hlo dialect to mhlo

This is part of the current refactoring of the HLO related dialect.
`xla_hlo` will be reintroduced in a new form later.

PiperOrigin-RevId: 319916753
Change-Id: I2c1b426b8a293927af5569bd35990a54b6b0743e
This commit is contained in:
Mehdi Amini 2020-07-06 21:51:24 -07:00 committed by TensorFlower Gardener
parent f6ab4daebc
commit bafd347479
112 changed files with 3243 additions and 3257 deletions

View File

@ -482,8 +482,8 @@ tf_cc_test(
)
cc_library(
name = "xla_hlo_fusion",
srcs = ["lib/Dialect/mhlo/transforms/xla_hlo_fusion.cc"],
name = "mhlo_fusion",
srcs = ["lib/Dialect/mhlo/transforms/mhlo_fusion.cc"],
deps = [
":cycle_detector",
":hlo",
@ -696,7 +696,7 @@ cc_library(
":lhlo_legalize_to_affine",
":lhlo_legalize_to_gpu",
":lhlo_legalize_to_parallel_loops",
":xla_hlo_fusion",
":mhlo_fusion",
":xla_legalize_control_flow",
":xla_legalize_tanh_to_approximation",
":xla_legalize_to_linalg",

View File

@ -17,12 +17,12 @@ limitations under the License.
// These ops are not necessarily orthogonal or optimized for transformation but
// for ease of expression in certain cases deemed important for client
// libraries (i.e. implicit broadcasting, helper ops, etc).
// This dialect is considered to exist in addition to augment the xla_hlo
// This dialect is considered to exist in addition to augment the mhlo
// dialect for ergonomic needs, not duplicate/replace it.
//
// The typical use of this dialect is for client libraries to be able to emit
// less constrained ops and rely on the conversion framework to lower any
// xla_chlo ops to canonical xla_hlo ops.
// xla_chlo ops to canonical mhlo ops.
//
// See: https://www.tensorflow.org/xla/operation_semantics
@ -44,7 +44,7 @@ def HLOClient_Dialect : Dialect {
let description = [{
This dialect contains ops that align closely with the API surface area
of the XlaBuilder C++ API, where such ops have semantics that go beyond
what exists in the lower level dialects (such as `xla_hlo`). Essentially,
what exists in the lower level dialects (such as `mhlo`). Essentially,
whenever the client library uses syntactic sugar or composition
of multiple ops for an API call, this dialect tries to model the API call
and provide conversion patterns to fully materialize into lower level
@ -65,7 +65,7 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate
// shape broadcasting.
//
// These correspond to operations in the xla_hlo dialect without the
// These correspond to operations in the mhlo dialect without the
// "broadcast_" prefix, except that those ops require same-shaped operands and
// results.
//

View File

@ -37,12 +37,12 @@ class OpBuilder;
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.h.inc"
namespace xla_hlo {
namespace mhlo {
class XlaHloDialect : public Dialect {
public:
explicit XlaHloDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "xla_hlo"; }
static StringRef getDialectNamespace() { return "mhlo"; }
// Registered hook to materialize a constant operation from a given attribute
// value with the desired resultant type.
@ -82,7 +82,7 @@ class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
// %1 = index_cast %0 : index to i64
// %2 = dim %arg0, 1 : memref<?x?xf32>
// %3 = index_cast %2 : index to i64
// %4 = "xla_hlo.scalars_to_dimension_tensor"(%1, %3)
// %4 = "mhlo.scalars_to_dimension_tensor"(%1, %3)
// : (i64, i64) -> tensor<2xi64>
//
// and returns %4 as the shape value.
@ -93,7 +93,7 @@ LogicalResult deriveShapeFromFirstOperand(
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
} // end namespace xla_hlo
} // end namespace mhlo
} // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_

View File

@ -29,8 +29,8 @@ include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
def HLO_Dialect : Dialect {
let name = "xla_hlo";
let cppNamespace = "xla_hlo";
let name = "mhlo";
let cppNamespace = "mhlo";
}
class HLO_Op<string mnemonic, list<OpTrait> traits> :

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
namespace mlir {
namespace xla_hlo {
namespace mhlo {
template <typename HloOpTy>
struct HloToLhloOpImpl {
@ -33,7 +33,7 @@ using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;
#define MAP_HLO_TO_LHLO(OpName) \
template <> \
struct HloToLhloOpImpl<xla_hlo::OpName> { \
struct HloToLhloOpImpl<mhlo::OpName> { \
using Type = xla_lhlo::OpName; \
}
@ -74,7 +74,7 @@ MAP_HLO_TO_LHLO(TanhOp);
#undef MAP_HLO_TO_LHLO
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_

View File

@ -464,7 +464,7 @@ struct XlaOpToStdScalarOp {
template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
std::is_same<typename xla_hlo::HloToLhloOp<LhloOpTy>,
std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
std::false_type>::value>>
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
@ -472,8 +472,8 @@ struct XlaOpToStdScalarOp {
args, b);
}
// Implementation for HLO ops except xla_hlo::CompareOp.
template <typename XlaOpTy, typename LhloOpTy = xla_hlo::HloToLhloOp<XlaOpTy>,
// Implementation for HLO ops except mhlo::CompareOp.
template <typename XlaOpTy, typename LhloOpTy = mhlo::HloToLhloOp<XlaOpTy>,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
!std::is_same<LhloOpTy, std::false_type>::value>>
@ -493,10 +493,11 @@ struct XlaOpToStdScalarOp {
op.getLoc(), comparison_direction, result_types, args, b);
}
// Implementation for xla_hlo::CompareOp.
template <typename HloOpTy, typename = std::enable_if_t<std::is_same<
HloOpTy, xla_hlo::CompareOp>::value>>
static Value map(xla_hlo::CompareOp op, ArrayRef<Type> result_types,
// Implementation for mhlo::CompareOp.
template <typename HloOpTy,
typename =
std::enable_if_t<std::is_same<HloOpTy, mhlo::CompareOp>::value>>
static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
auto comparison_direction = op.comparison_direction();
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(

View File

@ -29,7 +29,7 @@ template <typename T>
class OperationPass;
class Pass;
namespace xla_hlo {
namespace mhlo {
/// Lowers HLO control flow ops to the Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
@ -55,10 +55,10 @@ std::unique_ptr<OperationPass<FuncOp>> createTransformUnrankedHloPass();
// necessary to export to XLA.
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
// fuse xla_hlo ops to kLoop/kInput fusion patterns
// fuse mhlo ops to kLoop/kInput fusion patterns
std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass();
} // namespace xla_hlo
} // namespace mhlo
namespace xla_lhlo {

View File

@ -27,7 +27,7 @@ class LLVMTypeConverter;
class LowerToLLVMOptions;
class OwningRewritePatternList;
class BufferAssignmentPlacer;
namespace xla_hlo {
namespace mhlo {
// Collection of rewrite patterns for lowering a general dot product.
void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
@ -73,7 +73,7 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
} // namespace xla_hlo
} // namespace mhlo
namespace xla_lhlo {

View File

@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
// Static initialization for XLA dialect registration.
static mlir::DialectRegistration<mlir::xla_hlo::XlaHloDialect> xla_hlo_ops;
static mlir::DialectRegistration<mlir::mhlo::XlaHloDialect> mhlo_ops;
static mlir::DialectRegistration<mlir::xla_chlo::XlaHloClientDialect>
xla_chlo_ops;
static mlir::DialectRegistration<mlir::xla_lhlo::XlaLhloDialect> xla_lhlo_ops;

View File

@ -60,7 +60,7 @@ limitations under the License.
namespace mlir {
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc"
namespace xla_hlo {
namespace mhlo {
Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
Attribute value, Type type,
@ -68,8 +68,7 @@ Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
// HLO dialect constants only support ElementsAttr unlike standard dialect
// constant which supports all attributes.
if (value.isa<ElementsAttr>())
return builder.create<xla_hlo::ConstOp>(loc, type,
value.cast<ElementsAttr>());
return builder.create<mhlo::ConstOp>(loc, type, value.cast<ElementsAttr>());
return nullptr;
}
@ -167,7 +166,7 @@ void ConstOp::build(OpBuilder& builder, OperationState& result,
}
// TODO: support other XLA specific types.
assert(type && "unsupported attribute type for building xla_hlo.constant");
assert(type && "unsupported attribute type for building mhlo.constant");
result.types.push_back(type);
result.addAttribute("value", value);
}
@ -387,7 +386,7 @@ static LogicalResult Verify(GetTupleElementOp op) {
OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
if (auto tupleOp =
dyn_cast_or_null<xla_hlo::TupleOp>(getOperand().getDefiningOp())) {
dyn_cast_or_null<mhlo::TupleOp>(getOperand().getDefiningOp())) {
return tupleOp.getOperand(index().getLimitedValue());
}
@ -693,10 +692,8 @@ void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs,
}
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
auto real_op =
dyn_cast_or_null<xla_hlo::RealOp>(getOperand(0).getDefiningOp());
auto imag_op =
dyn_cast_or_null<xla_hlo::ImagOp>(getOperand(1).getDefiningOp());
auto real_op = dyn_cast_or_null<mhlo::RealOp>(getOperand(0).getDefiningOp());
auto imag_op = dyn_cast_or_null<mhlo::ImagOp>(getOperand(1).getDefiningOp());
if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) {
return real_op.getOperand();
}
@ -727,7 +724,7 @@ void ImagOp::build(OpBuilder& builder, OperationState& state, Value val) {
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
if (auto complex_op =
dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand().getDefiningOp())) {
dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
return complex_op.getOperand(1);
}
@ -740,7 +737,7 @@ void RealOp::build(OpBuilder& builder, OperationState& state, Value val) {
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
if (auto complex_op =
dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand().getDefiningOp())) {
dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
return complex_op.getOperand(0);
}
@ -1148,7 +1145,7 @@ static LogicalResult Verify(MapOp op) {
// RecvOp
//===----------------------------------------------------------------------===//
// Checks that the result type is of the form `tuple<any_type, xla_hlo::token>`
// Checks that the result type is of the form `tuple<any_type, mhlo::token>`
static LogicalResult Verify(RecvOp op) {
auto result_ty = op.getResult().getType().cast<TupleType>();
auto subtypes = result_ty.getTypes();
@ -2020,7 +2017,7 @@ void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
//===----------------------------------------------------------------------===//
// xla_hlo Dialect Interfaces
// mhlo Dialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
@ -2032,7 +2029,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface {
BlockAndValueMapping& valueMapping) const final {
return true;
}
// Operations in xla_hlo dialect are always legal to inline since they are
// Operations in mhlo dialect are always legal to inline since they are
// pure.
bool isLegalToInline(Operation*, Region*, BlockAndValueMapping&) const final {
return true;
@ -2041,7 +2038,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface {
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// xla_hlo Dialect Constructor
// mhlo Dialect Constructor
//===----------------------------------------------------------------------===//
XlaHloDialect::XlaHloDialect(MLIRContext* context)
@ -2061,8 +2058,7 @@ Type XlaHloDialect::parseType(DialectAsmParser& parser) const {
if (parser.parseKeyword(&data_type)) return Type();
if (data_type == "token") return TokenType::get(getContext());
parser.emitError(parser.getNameLoc())
<< "unknown xla_hlo type: " << data_type;
parser.emitError(parser.getNameLoc()) << "unknown mhlo type: " << data_type;
return nullptr;
}
@ -2071,7 +2067,7 @@ void XlaHloDialect::printType(Type type, DialectAsmPrinter& os) const {
os << "token";
return;
}
os << "<unknown xla_hlo type>";
os << "<unknown mhlo type>";
}
//===----------------------------------------------------------------------===//
@ -2106,5 +2102,5 @@ LogicalResult deriveShapeFromFirstOperand(
return success();
}
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

View File

@ -30,7 +30,7 @@ namespace xla_chlo {
namespace {
// Converts binary ops that statically are determined to not broadcast directly
// to the corresponding xla_hlo non-broadcasting op.
// to the corresponding mhlo non-broadcasting op.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
@ -63,7 +63,7 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
};
// Converts a binary op with ranked broadcasting operands to explicitly
// broadcast and invoke the corresponding xla_hlo non-broadcasting op.
// broadcast and invoke the corresponding mhlo non-broadcasting op.
// Note that dynamic broadcasting supported by this pattern is only valid for
// "numpy" broadcasting semantics as defined here:
// https://docs.scipy.org/doc/numpy/reference/ufuncs.html
@ -136,7 +136,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
// properly.
auto lhs_broadcast_dimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(result_rank - lhs_type.getRank(), result_rank));
Value broadcasted_lhs = rewriter.create<xla_hlo::DynamicBroadcastInDimOp>(
Value broadcasted_lhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(result_type.getShape(),
lhs_type.getElementType()),
@ -144,7 +144,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
rewriter.getI64TensorAttr(lhs_broadcast_dimensions));
auto rhs_broadcast_dimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(result_rank - rhs_type.getRank(), result_rank));
Value broadcasted_rhs = rewriter.create<xla_hlo::DynamicBroadcastInDimOp>(
Value broadcasted_rhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(result_type.getShape(),
rhs_type.getElementType()),
@ -182,21 +182,19 @@ struct HloBinaryElementwiseAdaptor {
};
struct HloComplexAdaptor {
static xla_hlo::ComplexOp CreateOp(BroadcastComplexOp from_op,
Type result_type, Value broadcasted_lhs,
Value broadcasted_rhs,
static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<xla_hlo::ComplexOp>(from_op.getLoc(), result_type,
return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs);
}
};
struct HloCompareAdaptor {
static xla_hlo::CompareOp CreateOp(BroadcastCompareOp from_op,
Type result_type, Value broadcasted_lhs,
Value broadcasted_rhs,
static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<xla_hlo::CompareOp>(from_op.getLoc(), result_type,
return builder.create<mhlo::CompareOp>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs,
from_op.comparison_direction());
}
@ -214,28 +212,27 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
HloBinaryElementwiseAdaptor<ChloOp, HloOp>>(context, \
patterns);
POPULATE_BCAST(BroadcastAddOp, xla_hlo::AddOp);
POPULATE_BCAST(BroadcastAndOp, xla_hlo::AndOp);
POPULATE_BCAST(BroadcastAtan2Op, xla_hlo::Atan2Op);
POPULATE_BCAST(BroadcastDivOp, xla_hlo::DivOp);
POPULATE_BCAST(BroadcastMaxOp, xla_hlo::MaxOp);
POPULATE_BCAST(BroadcastMinOp, xla_hlo::MinOp);
POPULATE_BCAST(BroadcastMulOp, xla_hlo::MulOp);
POPULATE_BCAST(BroadcastOrOp, xla_hlo::OrOp);
POPULATE_BCAST(BroadcastPowOp, xla_hlo::PowOp);
POPULATE_BCAST(BroadcastRemOp, xla_hlo::RemOp);
POPULATE_BCAST(BroadcastShiftLeftOp, xla_hlo::ShiftLeftOp);
POPULATE_BCAST(BroadcastShiftRightArithmeticOp,
xla_hlo::ShiftRightArithmeticOp);
POPULATE_BCAST(BroadcastShiftRightLogicalOp, xla_hlo::ShiftRightLogicalOp);
POPULATE_BCAST(BroadcastSubOp, xla_hlo::SubOp);
POPULATE_BCAST(BroadcastXorOp, xla_hlo::XorOp);
POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);
// Broadcasting ops requiring special construction.
PopulateForBinaryOp<BroadcastComplexOp, xla_hlo::ComplexOp,
HloComplexAdaptor>(context, patterns);
PopulateForBinaryOp<BroadcastCompareOp, xla_hlo::CompareOp,
HloCompareAdaptor>(context, patterns);
PopulateForBinaryOp<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>(
context, patterns);
PopulateForBinaryOp<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>(
context, patterns);
}
} // namespace xla_chlo

View File

@ -32,8 +32,8 @@ struct TestChloLegalizeToHloPass
OwningRewritePatternList conversionPatterns;
conversionTarget.addIllegalDialect<XlaHloClientDialect>();
// Consider the xla_hlo dialect legal for tests.
conversionTarget.addLegalDialect<xla_hlo::XlaHloDialect>();
// Consider the mhlo dialect legal for tests.
conversionTarget.addLegalDialect<mhlo::XlaHloDialect>();
// The conversion uses helpers from the Standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();

View File

@ -37,7 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
template <typename T>
@ -128,7 +128,7 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
op->getLoc(), result.value(), results_shape.front(), &rewriter));
}
}
rewriter.create<xla_hlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
buffer_args, op->getAttrs());
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
return success();
@ -136,12 +136,12 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
};
struct HloToLhloDynamicBroadcastInDimOpConverter
: public BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp> {
: public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
public:
using BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp>::BaseOpConversion;
using BaseOpConversion<mhlo::DynamicBroadcastInDimOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
xla_hlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
Value resultBuffer = InsertDynamicAllocAndDealloc(
@ -162,7 +162,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
// and size of the target dimension if size-1 dimension expansion is
// necessary.
xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
xla_hlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
auto loc = op.getLoc();
auto operand_type = operand.getType().cast<MemRefType>();
auto operand_shape = operand_type.getShape();
@ -220,12 +220,12 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
}
};
struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {
struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
public:
using BaseOpConversion<xla_hlo::ReduceOp>::BaseOpConversion;
using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
xla_hlo::ReduceOp op, ArrayRef<Value> operands,
mhlo::ReduceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
// TODO(b/137624192) Implement variadic reduce.
@ -314,10 +314,10 @@ class HloToLhloTensorStoreOpConverter
// "xla_lhlo.fusion"() ({
// %0 = tensor_load %arg1 : memref<2x2xf32>
// %1 = tensor_load %arg2 : memref<2x2xf32>
// %2 = "xla_hlo.add"(%0, %1) :
// %2 = "mhlo.add"(%0, %1) :
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// %3 = tensor_load %arg0 : memref<2x2xf32>
// %4 = "xla_hlo.multiply"(%2, %3) :
// %4 = "mhlo.multiply"(%2, %3) :
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// tensor_store %4, %arg3 : memref<2x2xf32>
// "xla_lhlo.terminator"() : () -> ()
@ -344,8 +344,8 @@ class HloToLhloTensorStoreOpConverter
// FuncOp signature conversion example:
//
// func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// %0 = "xla_hlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) ->
// tensor<4xf32> %1 = "xla_hlo.add"(%arg0, %0) : (tensor<4xf32>,
// %0 = "mhlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) ->
// tensor<4xf32> %1 = "mhlo.add"(%arg0, %0) : (tensor<4xf32>,
// tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32>
// }
//
@ -388,7 +388,7 @@ struct HloLegalizeToLhlo
target.addIllegalOp<mlir::TensorStoreOp>();
target.addLegalOp<ModuleTerminatorOp>();
target.addLegalOp<TensorFromElementsOp>();
target.addIllegalDialect<xla_hlo::XlaHloDialect>();
target.addIllegalDialect<mhlo::XlaHloDialect>();
BufferAssignmentTypeConverter converter;
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
@ -442,38 +442,38 @@ void populateHLOToLHLOConversionPattern(
// clang-format off
patterns->insert<
HloToLhloDynamicBroadcastInDimOpConverter,
HloToLhloOpConverter<xla_hlo::AbsOp>,
HloToLhloOpConverter<xla_hlo::AddOp>,
HloToLhloOpConverter<xla_hlo::AndOp>,
HloToLhloOpConverter<xla_hlo::BroadcastInDimOp>,
HloToLhloOpConverter<xla_hlo::CeilOp>,
HloToLhloOpConverter<xla_hlo::CompareOp>,
HloToLhloOpConverter<xla_hlo::ComplexOp>,
HloToLhloOpConverter<xla_hlo::ConstOp>,
HloToLhloOpConverter<xla_hlo::ConvOp>,
HloToLhloOpConverter<xla_hlo::ConvertOp>,
HloToLhloOpConverter<xla_hlo::CopyOp>,
HloToLhloOpConverter<xla_hlo::CosOp>,
HloToLhloOpConverter<xla_hlo::DivOp>,
HloToLhloOpConverter<xla_hlo::DotOp>,
HloToLhloOpConverter<xla_hlo::ExpOp>,
HloToLhloOpConverter<xla_hlo::GatherOp>,
HloToLhloOpConverter<xla_hlo::ImagOp>,
HloToLhloOpConverter<xla_hlo::IotaOp>,
HloToLhloOpConverter<xla_hlo::LogOp>,
HloToLhloOpConverter<xla_hlo::MaxOp>,
HloToLhloOpConverter<xla_hlo::MinOp>,
HloToLhloOpConverter<xla_hlo::MulOp>,
HloToLhloOpConverter<xla_hlo::NegOp>,
HloToLhloOpConverter<xla_hlo::RealOp>,
HloToLhloOpConverter<xla_hlo::RemOp>,
HloToLhloOpConverter<xla_hlo::RsqrtOp>,
HloToLhloOpConverter<xla_hlo::ReshapeOp>,
HloToLhloOpConverter<xla_hlo::SelectOp>,
HloToLhloOpConverter<xla_hlo::SignOp>,
HloToLhloOpConverter<xla_hlo::SqrtOp>,
HloToLhloOpConverter<xla_hlo::SubOp>,
HloToLhloOpConverter<xla_hlo::TanhOp>,
HloToLhloOpConverter<mhlo::AbsOp>,
HloToLhloOpConverter<mhlo::AddOp>,
HloToLhloOpConverter<mhlo::AndOp>,
HloToLhloOpConverter<mhlo::BroadcastInDimOp>,
HloToLhloOpConverter<mhlo::CeilOp>,
HloToLhloOpConverter<mhlo::CompareOp>,
HloToLhloOpConverter<mhlo::ComplexOp>,
HloToLhloOpConverter<mhlo::ConstOp>,
HloToLhloOpConverter<mhlo::ConvOp>,
HloToLhloOpConverter<mhlo::ConvertOp>,
HloToLhloOpConverter<mhlo::CopyOp>,
HloToLhloOpConverter<mhlo::CosOp>,
HloToLhloOpConverter<mhlo::DivOp>,
HloToLhloOpConverter<mhlo::DotOp>,
HloToLhloOpConverter<mhlo::ExpOp>,
HloToLhloOpConverter<mhlo::GatherOp>,
HloToLhloOpConverter<mhlo::ImagOp>,
HloToLhloOpConverter<mhlo::IotaOp>,
HloToLhloOpConverter<mhlo::LogOp>,
HloToLhloOpConverter<mhlo::MaxOp>,
HloToLhloOpConverter<mhlo::MinOp>,
HloToLhloOpConverter<mhlo::MulOp>,
HloToLhloOpConverter<mhlo::NegOp>,
HloToLhloOpConverter<mhlo::RealOp>,
HloToLhloOpConverter<mhlo::RemOp>,
HloToLhloOpConverter<mhlo::RsqrtOp>,
HloToLhloOpConverter<mhlo::ReshapeOp>,
HloToLhloOpConverter<mhlo::SelectOp>,
HloToLhloOpConverter<mhlo::SignOp>,
HloToLhloOpConverter<mhlo::SqrtOp>,
HloToLhloOpConverter<mhlo::SubOp>,
HloToLhloOpConverter<mhlo::TanhOp>,
HloToLhloReduceOpConverter,
HloToLhloTensorLoadOpConverter,
HloToLhloTensorStoreOpConverter
@ -489,5 +489,5 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
static PassRegistration<HloLegalizeToLhlo> legalize_pass(
"hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect");
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

View File

@ -35,7 +35,7 @@ limitations under the License.
using mlir::PassRegistration;
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
struct LegalizeControlFlow
: public mlir::PassWrapper<LegalizeControlFlow, FunctionPass> {
@ -51,7 +51,7 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block,
OpBuilder* builder) {
for (auto& old_block : region->getBlocks()) {
Block* block = mapper.lookup(&old_block);
auto return_op = dyn_cast<xla_hlo::ReturnOp>(block->getTerminator());
auto return_op = dyn_cast<mhlo::ReturnOp>(block->getTerminator());
if (!return_op) continue;
builder->setInsertionPointToEnd(block);
builder->create<mlir::BranchOp>(loc, target_block, return_op.getOperands());
@ -61,7 +61,7 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block,
return success();
}
LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) {
LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) {
Operation* op_inst = if_op.getOperation();
mlir::OpBuilder builder(if_op);
auto orig_block = op_inst->getBlock();
@ -106,13 +106,13 @@ LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) {
return success();
}
LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
// Converts an XLA while loop into control flow. This generates a set of MLIR
// blocks and branches, along with inlining the regions provided by the XLA
// while loop. The structure should be similar to below:
//
// <prior operations>
// %0 = "xla_hlo.while"(%arg0) {^cond(...){...}, ^body(...){...}}
// %0 = "mhlo.while"(%arg0) {^cond(...){...}, ^body(...){...}}
// <post operations>
auto* op_inst = while_op.getOperation();
mlir::OpBuilder builder(while_op);
@ -147,7 +147,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
// extract_element and conditional branch. This changes the block below:
// ^cond(%0):
// <inlined conditional region>
// "xla_hlo".return(%1)
// "mhlo".return(%1)
//
// Into:
// ^cond(%0):
@ -156,14 +156,14 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
// cond_br %2, ^body(%0), ^tail(%0) // Branch.
builder.setInsertionPointToStart(cond_block);
// Replace the xla_hlo::ReturnOp with a branch back to the condition block.
// This is required as the xla_hlo::ReturnOp is used to mark the end of a
// Replace the mhlo::ReturnOp with a branch back to the condition block.
// This is required as the mhlo::ReturnOp is used to mark the end of a
// block for regions nested inside of a operations (MLIR ReturnOp cannot be
// nested within an non-function region).
for (auto& block : while_op.cond()) {
auto new_block = mapper.lookup(&block);
auto return_op = dyn_cast<xla_hlo::ReturnOp>(new_block->getTerminator());
auto return_op = dyn_cast<mhlo::ReturnOp>(new_block->getTerminator());
if (!return_op) continue;
builder.setInsertionPointToEnd(new_block);
@ -183,7 +183,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
// conditional block. This changes the block below:
// ^body(%0):
// <inlined body block>
// "xla_hlo".return(%1)
// "mhlo".return(%1)
//
// Into:
// ^body(%0):
@ -191,8 +191,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
// br ^cond(%0) // Branch.
for (auto& block : while_op.body()) {
auto new_block = mapper.lookup(&block);
auto return_op =
dyn_cast<mlir::xla_hlo::ReturnOp>(new_block->getTerminator());
auto return_op = dyn_cast<mlir::mhlo::ReturnOp>(new_block->getTerminator());
if (!return_op) continue;
builder.setInsertionPointToEnd(new_block);
builder.create<mlir::BranchOp>(loc, cond_block, return_op.getOperands());
@ -224,14 +223,14 @@ void LegalizeControlFlow::runOnFunction() {
}
}
} // namespace
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
mlir::xla_hlo::createLegalizeControlFlowPass() {
mlir::mhlo::createLegalizeControlFlowPass() {
return std::make_unique<LegalizeControlFlow>();
}
static PassRegistration<mlir::xla_hlo::LegalizeControlFlow> legalize_cf_pass(
static PassRegistration<mlir::mhlo::LegalizeControlFlow> legalize_cf_pass(
"xla-legalize-control-flow",
"Legalize from XLA control flow to MLIR control flow");

View File

@ -28,14 +28,14 @@ namespace mlir {
namespace {
#include "tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc"
} // end anonymous namespace
namespace xla_hlo {
namespace mhlo {
namespace {
class CompareIConvert : public OpRewritePattern<xla_hlo::CompareOp> {
class CompareIConvert : public OpRewritePattern<mhlo::CompareOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(xla_hlo::CompareOp op,
LogicalResult matchAndRewrite(mhlo::CompareOp op,
PatternRewriter &rewriter) const override {
auto lhs = op.lhs();
auto rhs = op.rhs();
@ -68,11 +68,11 @@ class CompareIConvert : public OpRewritePattern<xla_hlo::CompareOp> {
}
};
class CompareFConvert : public OpRewritePattern<xla_hlo::CompareOp> {
class CompareFConvert : public OpRewritePattern<mhlo::CompareOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(xla_hlo::CompareOp op,
LogicalResult matchAndRewrite(mhlo::CompareOp op,
PatternRewriter &rewriter) const override {
auto lhs = op.lhs();
auto rhs = op.rhs();
@ -109,11 +109,11 @@ class CompareFConvert : public OpRewritePattern<xla_hlo::CompareOp> {
// convert the integer constant to iota result type. For complex types, the real
// part is replaced with the generated constant and the imaginary part is
// replaced with zero tensor.
class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(xla_hlo::IotaOp op,
LogicalResult matchAndRewrite(mhlo::IotaOp op,
PatternRewriter &rewriter) const override {
auto output_type = op.getType().cast<ShapedType>();
auto output_size = output_type.getNumElements();
@ -168,8 +168,7 @@ class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0)));
auto imag_zeroes =
rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, zeroes);
rewriter.replaceOpWithNewOp<xla_hlo::ComplexOp>(op, iota_const,
imag_zeroes);
rewriter.replaceOpWithNewOp<mhlo::ComplexOp>(op, iota_const, imag_zeroes);
return success();
}
};
@ -197,12 +196,12 @@ void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
/// Perform the lowering to standard dialect.
void LegalizeToStandard::runOnFunction() {
OwningRewritePatternList patterns;
mlir::xla_hlo::PopulateXlaToStdPatterns(&patterns, &getContext());
mlir::mhlo::PopulateXlaToStdPatterns(&patterns, &getContext());
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
static PassRegistration<LegalizeToStandard> legalize_pass(
"xla-legalize-to-std", "Legalize from XLA dialect to standard dialect");
} // end namespace xla_hlo
} // end namespace mhlo
} // end namespace mlir

View File

@ -84,13 +84,13 @@ Value TransposeReshape(Value arg, mlir::Location loc,
transposed_shape.push_back(arg_shape[val]);
}
auto transpose_type = RankedTensorType::get(transposed_shape, element_type);
auto transpose_result = rewriter->create<mlir::xla_hlo::TransposeOp>(
auto transpose_result = rewriter->create<mlir::mhlo::TransposeOp>(
loc, transpose_type, arg, transpose_permutation_attr);
// Return the final result.
auto reshaped_type =
RankedTensorType::get({left_size, right_size}, element_type);
return rewriter->create<mlir::xla_hlo::ReshapeOp>(loc, reshaped_type,
return rewriter->create<mlir::mhlo::ReshapeOp>(loc, reshaped_type,
transpose_result);
}
@ -125,8 +125,7 @@ Value ProcessDotArg(Value arg, mlir::Location loc,
return TransposeReshape(arg, loc, contract_dims, outer_dims, shape, rewriter);
}
struct GeneralDotConvert
: public OpRewritePattern<mlir::xla_hlo::DotGeneralOp> {
struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
// Attempts to lower a General Dot operator to a standard Dot operator.
// General dots include batching dimensions and can have collapsing
// dimensions along any axis. Inserting correctly arrange transpose and
@ -138,7 +137,7 @@ struct GeneralDotConvert
explicit GeneralDotConvert(MLIRContext *context)
: OpRewritePattern(context) {}
LogicalResult matchAndRewrite(mlir::xla_hlo::DotGeneralOp op,
LogicalResult matchAndRewrite(mlir::mhlo::DotGeneralOp op,
PatternRewriter &rewriter) const override {
auto dot_element_type = mlir::getElementTypeOrSelf(op);
@ -162,10 +161,10 @@ struct GeneralDotConvert
auto new_dot_type =
RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);
auto new_dot_op = rewriter.create<mlir::xla_hlo::DotOp>(
auto new_dot_op = rewriter.create<mlir::mhlo::DotOp>(
op.getLoc(), new_dot_type, lhs, rhs, *(op.precision_config()));
rewriter.replaceOpWithNewOp<mlir::xla_hlo::ReshapeOp>(op, op.getType(),
rewriter.replaceOpWithNewOp<mlir::mhlo::ReshapeOp>(op, op.getType(),
new_dot_op);
return success();
}
@ -176,15 +175,14 @@ struct LegalizeGeneralDot
/// Lower all general dots that can be represented as a non-batched matmul.
void runOnFunction() override {
OwningRewritePatternList patterns;
mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns,
&getContext());
mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext());
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
};
} // namespace
void mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(
void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(
OwningRewritePatternList *patterns, MLIRContext *ctx) {
patterns->insert<GeneralDotConvert>(ctx);
}

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
@ -86,5 +86,5 @@ void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
patterns->insert<ClampWithBroadcastConvert>(context);
}
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
@ -33,7 +33,7 @@ struct TestMaterializeBroadcastsPass
ConversionTarget conversionTarget(getContext());
OwningRewritePatternList conversionPatterns;
// Consider the xla_hlo dialect legal for tests.
// Consider the mhlo dialect legal for tests.
conversionTarget.addLegalDialect<XlaHloDialect>();
// The conversion uses helpers from the Standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
@ -50,9 +50,9 @@ struct TestMaterializeBroadcastsPass
} // namespace
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir
static mlir::PassRegistration<mlir::xla_hlo::TestMaterializeBroadcastsPass>
pass("test-xla-materialize-broadcasts",
static mlir::PassRegistration<mlir::mhlo::TestMaterializeBroadcastsPass> pass(
"test-xla-materialize-broadcasts",
"Test pass for materializing 'broadcast_dimensions' attributes");

View File

@ -60,7 +60,7 @@ limitations under the License.
// shape dialect once it is ready.
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
using llvm::EquivalenceClasses;
@ -544,7 +544,7 @@ struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
}
FusionOp fusion =
b.create<xla_hlo::FusionOp>(fused_loc, output_types, inputs);
b.create<mhlo::FusionOp>(fused_loc, output_types, inputs);
Region& region = fusion.fused_computation();
region.push_back(new Block);
Block& block = region.front();
@ -552,7 +552,7 @@ struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
op->moveBefore(&block, block.end());
}
b.setInsertionPoint(&block, block.end());
b.create<xla_hlo::ReturnOp>(fused_loc, outputs);
b.create<mhlo::ReturnOp>(fused_loc, outputs);
for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) {
Value output = std::get<0>(output_and_result);
@ -572,8 +572,8 @@ std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusion() {
return std::make_unique<XlaHloFusion>();
}
static PassRegistration<XlaHloFusion> xla_hlo_fusion_pass(
"xla-hlo-fusion", "fuse xla_hlo ops to kLoop/kInput fusion patterns.");
static PassRegistration<XlaHloFusion> mhlo_fusion_pass(
"xla-hlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns.");
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
@ -81,5 +81,5 @@ std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() {
return std::make_unique<SinkConstantsToControlFlow>();
}
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
@ -40,11 +40,11 @@ Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type,
auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64));
auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim});
if (shape_value) {
return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
loc, result_type, value_1d, shape_value, dims);
}
assert(result_type.hasStaticShape());
return rewriter.create<xla_hlo::BroadcastInDimOp>(loc, result_type, value_1d,
return rewriter.create<mhlo::BroadcastInDimOp>(loc, result_type, value_1d,
dims);
}
@ -89,25 +89,25 @@ Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr,
auto epsilon_tensor_attr =
DenseElementsAttr::get(scalar_type, {epsilon_attr.cast<Attribute>()});
Value epsilon =
rewriter.create<xla_hlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
rewriter.create<mhlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64));
auto dims = DenseIntElementsAttr::get(dims_type, SmallVector<int64_t, 1>{});
if (broadcast_to_type.hasStaticShape()) {
return rewriter.create<xla_hlo::BroadcastInDimOp>(
return rewriter.create<mhlo::BroadcastInDimOp>(
op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims);
}
Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter);
return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
op->getLoc(), broadcast_to_type, epsilon, shape_value,
/*broadcast_dims=*/dims);
}
class UnfuseBatchNormInferencePattern
: public OpRewritePattern<xla_hlo::BatchNormInferenceOp> {
: public OpRewritePattern<mhlo::BatchNormInferenceOp> {
public:
using OpRewritePattern<xla_hlo::BatchNormInferenceOp>::OpRewritePattern;
using OpRewritePattern<mhlo::BatchNormInferenceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(xla_hlo::BatchNormInferenceOp bn_op,
LogicalResult matchAndRewrite(mhlo::BatchNormInferenceOp bn_op,
PatternRewriter& rewriter) const override {
// Enforce type invariants.
// Note that we deduce the actual element type from the variance,
@ -132,9 +132,9 @@ class UnfuseBatchNormInferencePattern
if (!epsilon) {
return failure();
}
Value stddev = rewriter.create<xla_hlo::AddOp>(bn_op.getLoc(),
bn_op.variance(), epsilon);
stddev = rewriter.create<xla_hlo::SqrtOp>(bn_op.getLoc(), stddev);
Value stddev =
rewriter.create<mhlo::AddOp>(bn_op.getLoc(), bn_op.variance(), epsilon);
stddev = rewriter.create<mhlo::SqrtOp>(bn_op.getLoc(), stddev);
// Broadcast all terms.
Value shape_value;
@ -156,14 +156,13 @@ class UnfuseBatchNormInferencePattern
// Compute:
// scale * (input - mean) / stddev + offset
Value result = rewriter.create<xla_hlo::SubOp>(
bn_op.getLoc(), bn_op.operand(), broadcast_mean);
result = rewriter.create<xla_hlo::MulOp>(bn_op.getLoc(), result,
broadcast_scale);
result = rewriter.create<xla_hlo::DivOp>(bn_op.getLoc(), result,
broadcast_stddev);
rewriter.replaceOpWithNewOp<xla_hlo::AddOp>(bn_op, result,
broadcast_offset);
Value result = rewriter.create<mhlo::SubOp>(bn_op.getLoc(), bn_op.operand(),
broadcast_mean);
result =
rewriter.create<mhlo::MulOp>(bn_op.getLoc(), result, broadcast_scale);
result =
rewriter.create<mhlo::DivOp>(bn_op.getLoc(), result, broadcast_stddev);
rewriter.replaceOpWithNewOp<mhlo::AddOp>(bn_op, result, broadcast_offset);
return success();
}
@ -180,5 +179,5 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext* context,
patterns->insert<UnfuseBatchNormInferencePattern>(context);
}
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
@ -38,9 +38,9 @@ struct TestUnfuseBatchNormPass
} // namespace
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir
static mlir::PassRegistration<mlir::xla_hlo::TestUnfuseBatchNormPass> pass(
static mlir::PassRegistration<mlir::mhlo::TestUnfuseBatchNormPass> pass(
"test-xla-unfuse-batch-norm",
"Test pass for materializing 'broadcast_dimensions' attributes");

View File

@ -182,7 +182,7 @@ struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern;
// This code has been adapted from IREE's
// (https://github.com/google/iree/) xla_hlo -> linalg conversion.
// (https://github.com/google/iree/) mhlo -> linalg conversion.
LogicalResult matchAndRewrite(
xla_lhlo::ConvOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
@ -348,14 +348,14 @@ class BroadcastConverter
class HloBroadcastInDimConverter
: public DataMovementOpConverter<HloBroadcastInDimConverter,
xla_hlo::BroadcastInDimOp, false> {
mhlo::BroadcastInDimOp, false> {
public:
using DataMovementOpConverter<HloBroadcastInDimConverter,
xla_hlo::BroadcastInDimOp,
mhlo::BroadcastInDimOp,
false>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(
xla_hlo::BroadcastInDimOp broadcastOp, Builder* b) {
mhlo::BroadcastInDimOp broadcastOp, Builder* b) {
auto resultType = getXLAOpResultType<false>(broadcastOp);
auto operandType =
broadcastOp.operand().getType().template cast<ShapedType>();
@ -845,7 +845,7 @@ struct HloLegalizeToLinalg
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
auto func = getFunction();
xla_hlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
signalPassFailure();
}
@ -863,40 +863,40 @@ static PassRegistration<LhloLegalizeToLinalg> legalize_lhlo_pass(
"lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect");
} // namespace xla_lhlo
namespace xla_hlo {
namespace mhlo {
void populateHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
patterns->insert<BroadcastConverter<xla_hlo::BroadcastOp, false>,
patterns->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
HloBroadcastInDimConverter,
PointwiseToLinalgConverter<xla_hlo::AbsOp, false>,
PointwiseToLinalgConverter<xla_hlo::AddOp, false>,
PointwiseToLinalgConverter<xla_hlo::AndOp, false>,
PointwiseToLinalgConverter<xla_hlo::CeilOp, false>,
PointwiseToLinalgConverter<xla_hlo::CompareOp, false>,
PointwiseToLinalgConverter<xla_hlo::ComplexOp, false>,
PointwiseToLinalgConverter<xla_hlo::ConvertOp, false>,
PointwiseToLinalgConverter<xla_hlo::CopyOp, false>,
PointwiseToLinalgConverter<xla_hlo::CosOp, false>,
PointwiseToLinalgConverter<xla_hlo::DivOp, false>,
PointwiseToLinalgConverter<xla_hlo::ExpOp, false>,
PointwiseToLinalgConverter<xla_hlo::ImagOp, false>,
PointwiseToLinalgConverter<xla_hlo::LogOp, false>,
PointwiseToLinalgConverter<xla_hlo::MaxOp, false>,
PointwiseToLinalgConverter<xla_hlo::MinOp, false>,
PointwiseToLinalgConverter<xla_hlo::MulOp, false>,
PointwiseToLinalgConverter<xla_hlo::NegOp, false>,
PointwiseToLinalgConverter<xla_hlo::RealOp, false>,
PointwiseToLinalgConverter<xla_hlo::RemOp, false>,
PointwiseToLinalgConverter<xla_hlo::RsqrtOp, false>,
PointwiseToLinalgConverter<xla_hlo::SelectOp, false>,
PointwiseToLinalgConverter<xla_hlo::SinOp, false>,
PointwiseToLinalgConverter<xla_hlo::SqrtOp, false>,
PointwiseToLinalgConverter<xla_hlo::SubOp, false>,
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>,
ReshapeOpConverter<xla_hlo::ReshapeOp, false>,
ReverseConverter<xla_hlo::ReverseOp, false>,
TransposeConverter<xla_hlo::TransposeOp, false>>(context);
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
PointwiseToLinalgConverter<mhlo::AddOp, false>,
PointwiseToLinalgConverter<mhlo::AndOp, false>,
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
PointwiseToLinalgConverter<mhlo::CopyOp, false>,
PointwiseToLinalgConverter<mhlo::CosOp, false>,
PointwiseToLinalgConverter<mhlo::DivOp, false>,
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
PointwiseToLinalgConverter<mhlo::LogOp, false>,
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
PointwiseToLinalgConverter<mhlo::MinOp, false>,
PointwiseToLinalgConverter<mhlo::MulOp, false>,
PointwiseToLinalgConverter<mhlo::NegOp, false>,
PointwiseToLinalgConverter<mhlo::RealOp, false>,
PointwiseToLinalgConverter<mhlo::RemOp, false>,
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SelectOp, false>,
PointwiseToLinalgConverter<mhlo::SinOp, false>,
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SubOp, false>,
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
ReshapeOpConverter<mhlo::ReshapeOp, false>,
ReverseConverter<mhlo::ReverseOp, false>,
TransposeConverter<mhlo::TransposeOp, false>>(context);
}
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
@ -905,5 +905,5 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
static PassRegistration<HloLegalizeToLinalg> legalize_hlo_pass(
"hlo-legalize-to-linalg", "Legalize from HLO dialect to Linalg dialect");
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

View File

@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
// TODO(frgossen): Make it variadic.
@ -69,7 +69,7 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex);
auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
operandTy.getElementType());
Value flatOperand = rewriter.create<xla_hlo::DynamicReshapeOp>(
Value flatOperand = rewriter.create<mhlo::DynamicReshapeOp>(
loc, flatTensorTy, operand, flatShapeAsDimTensor);
// Generate IR for the actual operation.
@ -80,7 +80,7 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
rewriter.getIndexType());
Value shapeAsExtentTensor =
rewriter.create<shape::ToExtentTensorOp>(loc, extentTensorTy, shape);
Value result = rewriter.create<xla_hlo::DynamicReshapeOp>(
Value result = rewriter.create<mhlo::DynamicReshapeOp>(
loc, operandTy, flatResult, shapeAsExtentTensor);
rewriter.replaceOp(op, result);
@ -184,5 +184,5 @@ static PassRegistration<TransformUnrankedHloPass> transform_unranked_hlo_pass(
"transform-unranked-hlo",
"Realize element-wise operations on ranked tensors where possible");
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

View File

@ -2,107 +2,107 @@
// CHECK-LABEL: add_fold
func @add_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
%1 = xla_hlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<[6, 8, 10, 12]>
%2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
%1 = mhlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
// CHECK: mhlo.constant dense<[6, 8, 10, 12]>
%2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: add_scalar_fold
func @add_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<1> : tensor<4xi64>
%1 = xla_hlo.constant dense<5> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<6>
%2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<1> : tensor<4xi64>
%1 = mhlo.constant dense<5> : tensor<4xi64>
// CHECK: mhlo.constant dense<6>
%2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: add_fold_float
func @add_fold_float() -> tensor<4xf64> {
%0 = xla_hlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64>
%1 = xla_hlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64>
// CHECK: xla_hlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
%2 = "xla_hlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
%0 = mhlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
%2 = "mhlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}
// CHECK-LABEL: sub_scalar_fold
func @sub_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<5> : tensor<4xi64>
%1 = xla_hlo.constant dense<1> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<4>
%2 = "xla_hlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<5> : tensor<4xi64>
%1 = mhlo.constant dense<1> : tensor<4xi64>
// CHECK: mhlo.constant dense<4>
%2 = "mhlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: multiply_scalar_fold
func @multiply_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<5> : tensor<4xi64>
%1 = xla_hlo.constant dense<3> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<15>
%2 = "xla_hlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<5> : tensor<4xi64>
%1 = mhlo.constant dense<3> : tensor<4xi64>
// CHECK: mhlo.constant dense<15>
%2 = "mhlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: divide_scalar_fold
func @divide_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
%1 = xla_hlo.constant dense<5> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<1>
%2 = "xla_hlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<7> : tensor<4xi64>
%1 = mhlo.constant dense<5> : tensor<4xi64>
// CHECK: mhlo.constant dense<1>
%2 = "mhlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: divide_fold_float
func @divide_fold_float() -> tensor<4xf64> {
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: xla_hlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]>
%2 = "xla_hlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
%0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]>
%2 = "mhlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}
// CHECK-LABEL: max_scalar_fold
func @max_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
%1 = xla_hlo.constant dense<5> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<7>
%2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<7> : tensor<4xi64>
%1 = mhlo.constant dense<5> : tensor<4xi64>
// CHECK: mhlo.constant dense<7>
%2 = "mhlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: max_fold_float
func @max_fold_float() -> tensor<4xf64> {
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: xla_hlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]>
%2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
%0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]>
%2 = "mhlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}
// CHECK-LABEL: min_scalar_fold
func @min_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
%1 = xla_hlo.constant dense<-5> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<-5>
%2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<7> : tensor<4xi64>
%1 = mhlo.constant dense<-5> : tensor<4xi64>
// CHECK: mhlo.constant dense<-5>
%2 = "mhlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: min_fold_float
func @min_fold_float() -> tensor<4xf64> {
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: xla_hlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]>
%2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
%0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]>
%2 = "mhlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}
// CHECK-LABEL: concatenate_noop
func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-SAME: [[ARG:%.+]]: tensor<4xi32>
%0 = "xla_hlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32>
%0 = "mhlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32>
// CHECK: return [[ARG]]
return %0 : tensor<4xi32>
@ -112,7 +112,7 @@ func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> tensor<4xi32> {
// CHECK-SAME: [[ARG0:%.+]]: tensor<4xi32>
// CHECK-SAME: [[ARG1:%.+]]: tensor<0xi32>
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32>
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32>
// CHECK: return [[ARG0]]
return %0 : tensor<4xi32>
@ -120,34 +120,34 @@ func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) ->
// CHECK-LABEL: concatenate_empty_bool
func @concatenate_empty_bool(%arg0: tensor<0xi1>, %arg1: tensor<0xi1>) -> tensor<0xi1> {
// CHECK: xla_hlo.constant
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>
// CHECK: mhlo.constant
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>
return %0 : tensor<0xi1>
}
// CHECK-LABEL: concatenate_empty_int
func @concatenate_empty_int(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<0xi32> {
// CHECK: xla_hlo.constant
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32>
// CHECK: mhlo.constant
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32>
return %0 : tensor<0xi32>
}
// CHECK-LABEL: concatenate_empty_float
func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
// CHECK: xla_hlo.constant
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
// CHECK: mhlo.constant
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
return %0 : tensor<0xf32>
}
// CHECK-LABEL: concatenate_const_1D
func @concatenate_const_1D() -> tensor<4xi32> {
// CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[0, 1, 2, 3]>
%0 = xla_hlo.constant dense<[0, 1]> : tensor<2xi32>
%1 = xla_hlo.constant dense<[2, 3]> : tensor<2xi32>
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>
// CHECK: [[VAL:%.+]]= mhlo.constant dense<[0, 1, 2, 3]>
%0 = mhlo.constant dense<[0, 1]> : tensor<2xi32>
%1 = mhlo.constant dense<[2, 3]> : tensor<2xi32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>
// CHECK: return [[VAL]]
return %2 : tensor<4xi32>
@ -155,11 +155,11 @@ func @concatenate_const_1D() -> tensor<4xi32> {
// CHECK-LABEL: concatenate_const_1D_float
func @concatenate_const_1D_float() -> tensor<4xf32> {
// CHECK: [[VAL:%.+]] = xla_hlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
// CHECK: [[VAL:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
%0 = xla_hlo.constant dense<[0.0, 1.0]> : tensor<2xf32>
%1 = xla_hlo.constant dense<[2.0, 3.0]> : tensor<2xf32>
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32>
%0 = mhlo.constant dense<[0.0, 1.0]> : tensor<2xf32>
%1 = mhlo.constant dense<[2.0, 3.0]> : tensor<2xf32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32>
// CHECK: return [[VAL]]
return %2 : tensor<4xf32>
@ -167,12 +167,12 @@ func @concatenate_const_1D_float() -> tensor<4xf32> {
// CHECK-LABEL: concatenate_const_2D_vertical
func @concatenate_const_2D_vertical() -> tensor<2x2xi32> {
// CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[
// CHECK: [[VAL:%.+]]= mhlo.constant dense<[
// CHECK-SAME: [0, 1], [2, 3]
// CHECK-SAME: ]>
%0 = xla_hlo.constant dense<[[0, 1]]> : tensor<1x2xi32>
%1 = xla_hlo.constant dense<[[2, 3]]> : tensor<1x2xi32>
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>
%0 = mhlo.constant dense<[[0, 1]]> : tensor<1x2xi32>
%1 = mhlo.constant dense<[[2, 3]]> : tensor<1x2xi32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>
// CHECK: return [[VAL]]
return %2 : tensor<2x2xi32>
@ -180,12 +180,12 @@ func @concatenate_const_2D_vertical() -> tensor<2x2xi32> {
// CHECK-LABEL: concatenate_const_2D_horizontal
func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
// CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[
// CHECK: [[VAL:%.+]]= mhlo.constant dense<[
// CHECK-SAME: [0, 2], [1, 3]
// CHECK-SAME: ]>
%0 = xla_hlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
%1 = xla_hlo.constant dense<[[2], [3]]> : tensor<2x1xi32>
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
%0 = mhlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
%1 = mhlo.constant dense<[[2], [3]]> : tensor<2x1xi32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
// CHECK: return [[VAL]]
return %2 : tensor<2x2xi32>
@ -193,40 +193,40 @@ func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
// CHECK-LABEL: dynamic_slice_variable_start
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
// CHECK: "xla_hlo.dynamic-slice"
%1 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
// CHECK: "mhlo.dynamic-slice"
%1 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
return %1 : tensor<1x4xi32>
}
// CHECK-LABEL: dynamic_slice_constant_start
func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
// CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0)
// CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0)
// CHECK-DAG-SAME: limit_indices = dense<3> : tensor<1xi64>
// CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>}
// CHECK: return %[[RESULT]] : tensor<2xi32>
%0 = xla_hlo.constant dense<1> : tensor<i64>
%1 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
%0 = mhlo.constant dense<1> : tensor<i64>
%1 = "mhlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
return %1 : tensor<2xi32>
}
// CHECK-LABEL: dynamic_slice_constant_start_dynamic_shape
func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<?x4xi32> {
// CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0)
// CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0)
// CHECK-DAG-SAME: limit_indices = dense<[2, 4]> : tensor<2xi64>
// CHECK-DAG-SAME: start_indices = dense<[1, 0]> : tensor<2xi64>
// CHECK-DAG-SAME: strides = dense<1> : tensor<2xi64>
// CHECK: return %[[RESULT]] : tensor<?x4xi32>
%0 = xla_hlo.constant dense<1> : tensor<i64>
%1 = xla_hlo.constant dense<0> : tensor<i64>
%2 = "xla_hlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<?x4xi32>
%0 = mhlo.constant dense<1> : tensor<i64>
%1 = mhlo.constant dense<0> : tensor<i64>
%2 = "mhlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<?x4xi32>
return %2 : tensor<?x4xi32>
}
// CHECK-LABEL: slice_2D_noop
// CHECK-SAME: [[ARG:%.+]]: tensor<2x2xi64>
func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> {
%0 = "xla_hlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>)
%0 = "mhlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>)
// CHECK-NEXT: return [[ARG]]
return %0 : tensor<2x2xi64>
@ -234,80 +234,80 @@ func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> {
// CHECK-LABEL: slice_1D_fold
func @slice_1D_fold() -> tensor<2xi64> {
%0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<[7, 9]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
%0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
// CHECK: mhlo.constant dense<[7, 9]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
return %1 : tensor<2xi64>
}
// CHECK-LABEL: slice_1D_fp
func @slice_1D_fp() -> tensor<2xf32> {
%0 = xla_hlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32>
// CHECK: xla_hlo.constant dense<[7.000000e+00, 9.000000e+00]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>)
%0 = mhlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32>
// CHECK: mhlo.constant dense<[7.000000e+00, 9.000000e+00]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>)
return %1 : tensor<2xf32>
}
// CHECK-LABEL: slice_1D_strided_fold
func @slice_1D_strided_fold() -> tensor<2xi64> {
%0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<[7, 10]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
%0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
// CHECK: mhlo.constant dense<[7, 10]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
return %1 : tensor<2xi64>
}
// CHECK-LABEL: slice_2D_fold
func @slice_2D_fold() -> tensor<2x2xi64> {
%0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: xla_hlo.constant dense<[
%0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: mhlo.constant dense<[
// CHECK-SAME: [6, 7],
// CHECK-SAME: [10, 11]
// CHECK-SAME: ]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>)
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>)
return %1 : tensor<2x2xi64>
}
// CHECK-LABEL: slice_2D_fold_horizontal
func @slice_2D_fold_horizontal() -> tensor<1x4xi64> {
%0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: xla_hlo.constant dense<[
%0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: mhlo.constant dense<[
// CHECK-SAME: [0, 1, 2, 3]
// CHECK-SAME: ]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>)
%1 = "mhlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>)
return %1 : tensor<1x4xi64>
}
// CHECK-LABEL: slice_2D_fold_vertical
func @slice_2D_fold_vertical() -> tensor<4x1xi64> {
%0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: xla_hlo.constant dense<[
%0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: mhlo.constant dense<[
// CHECK-SAME: [2], [6], [10], [14]
// CHECK-SAME: ]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>)
%1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>)
return %1 : tensor<4x1xi64>
}
// CHECK-LABEL: slice_concat_fold_first
func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
// CHECK: return %arg0
return %1 : tensor<1x5xf32>
}
// CHECK-LABEL: slice_concat_fold_second
func @slice_concat_fold_second(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
// CHECK: return %arg1
return %1 : tensor<1x5xf32>
}
// CHECK-LABEL: slice_concat_fold_second_with_slice
func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x4xf32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
// CHECK: [[SLICE:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x5xf32>) -> tensor<1x4xf32>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x4xf32>)
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
// CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x5xf32>) -> tensor<1x4xf32>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x4xf32>)
// CHECK: return [[SLICE]]
return %1 : tensor<1x4xf32>
@ -315,9 +315,9 @@ func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor<
// CHECK-LABEL: slice_concat_fold_middle
func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
// CHECK: [[SLICE:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<1x5xf32>)
%0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
// CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<1x5xf32>)
// CHECK: return [[SLICE]]
return %1 : tensor<1x5xf32>
@ -325,11 +325,11 @@ func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %
// CHECK-LABEL: slice_concat_fold_two
func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<2x5xf32> {
// CHECK: [[CONCAT:%.+]] = "xla_hlo.concatenate"(%arg1, %arg2) {dimension = 0 : i64}
%0 = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
// CHECK: [[CONCAT:%.+]] = "mhlo.concatenate"(%arg1, %arg2) {dimension = 0 : i64}
%0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
// CHECK: [[SLICE:%.+]] = "xla_hlo.slice"([[CONCAT]]) {limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<2x5xf32>)
// CHECK: [[SLICE:%.+]] = "mhlo.slice"([[CONCAT]]) {limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
%1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<2x5xf32>)
// CHECK: return [[SLICE]]
return %1 : tensor<2x5xf32>
@ -338,72 +338,72 @@ func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg
// CHECK-LABEL: func @broadcast_in_dim_identity
func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
// CHECK: return %arg0
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
return %0 : tensor<2x3x4xf32>
}
// CHECK-LABEL: func @broadcast_in_dim_not_identity_because_it_actually_broadcasts
func @broadcast_in_dim_not_identity_because_it_actually_broadcasts(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> {
// CHECK: xla_hlo.broadcast_in_dim
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<2x2xf32>
// CHECK: mhlo.broadcast_in_dim
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// CHECK-LABEL: func @broadcast_in_dim_not_identity_permutation
func @broadcast_in_dim_not_identity_permutation(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: xla_hlo.broadcast_in_dim
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: mhlo.broadcast_in_dim
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic
func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> {
// CHECK: %[[RESULT:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32>
%0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32>
// CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32>
%0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32>
// CHECK: return %[[RESULT]] : tensor<5x4xf32>
return %0 : tensor<5x4xf32>
}
// CHECK-LABEL: @complex_expand_fold
func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex<f32>>)
%1 = "xla_hlo.real"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%2 = "xla_hlo.imag"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex<f32>>)
%1 = "mhlo.real"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%2 = "mhlo.imag"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
// CHECK: return %arg0, %arg1
return %1, %2 : tensor<4xf32>, tensor<4xf32>
}
// CHECK-LABEL: @complex_collapse_fold
func @complex_collapse_fold(%arg0: tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>> {
%0 = "xla_hlo.real"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%1 = "xla_hlo.imag"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%2 = "xla_hlo.complex"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
%0 = "mhlo.real"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%1 = "mhlo.imag"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%2 = "mhlo.complex"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
// CHECK: return %arg0
return %2 : tensor<4xcomplex<f32>>
}
// CHECK-LABEL: @dynamic_iota_is_static
func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> {
// CHECK: [[RESULT:%.*]] = "xla_hlo.iota"
// CHECK: [[RESULT:%.*]] = "mhlo.iota"
// CHECK: return [[RESULT]]
%0 = "xla_hlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<4xi32>
%0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
// CHECK-LABEL: @iota_not_lowered_to_constant
func @iota_not_lowered_to_constant() -> tensor<4xi32> {
// CHECK: [[RESULT:%.*]] = "xla_hlo.iota"
// CHECK: [[RESULT:%.*]] = "mhlo.iota"
// CHECK: return [[RESULT]]
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
return %0 : tensor<4xi32>
}
// CHECK-LABEL: @unary_einsum
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
// CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: "xla_hlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"}
%0 = "xla_hlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
// CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: "mhlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"}
%0 = "mhlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -411,30 +411,30 @@ func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK: return [[ARG]]
%0 = "xla_hlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32>
%0 = "mhlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// CHECK-LABEL: func @dynamic_reshape_not_actually_dynamic
func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<2xindex>) -> tensor<4x1xf32> {
// CHECK: xla_hlo.reshape
%0 = "xla_hlo.dynamic_reshape"(%arg0, %shape) : (tensor<4xf32>, tensor<2xindex>) -> tensor<4x1xf32>
// CHECK: mhlo.reshape
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<4xf32>, tensor<2xindex>) -> tensor<4x1xf32>
return %0 : tensor<4x1xf32>
}
// CHECK-LABEL: do_not_dce_while_with_outfeed
func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK: xla_hlo.while
%0 = "xla_hlo.while"(%arg0) ( {
// CHECK: mhlo.while
%0 = "mhlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.create_token"() : () -> !xla_hlo.token
%1 = "mhlo.create_token"() : () -> !mhlo.token
// Side-effecting op outfeed present inside while.
%2 = "xla_hlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor<i64>, !xla_hlo.token) -> !xla_hlo.token
"xla_hlo.return"(%arg1) : (tensor<i64>) -> ()
%2 = "mhlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor<i64>, !mhlo.token) -> !mhlo.token
"mhlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %arg0 : tensor<i64>
@ -442,15 +442,15 @@ func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-LABEL: dce_while_without_side_effect
func @dce_while_without_side_effect(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-NOT: xla_hlo.while
%0 = "xla_hlo.while"(%arg0) ( {
// CHECK-NOT: mhlo.while
%0 = "mhlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.create_token"() : () -> !xla_hlo.token
"xla_hlo.return"(%arg1) : (tensor<i64>) -> ()
%1 = "mhlo.create_token"() : () -> !mhlo.token
"mhlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %arg0 : tensor<i64>

View File

@ -4,7 +4,7 @@
// representative op for detailed broadcast semantics.
// CHECK-LABEL: @addWithoutBroadcast
func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.add %arg0, %arg1
// CHECK: mhlo.add %arg0, %arg1
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -20,9 +20,9 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = xla_hlo.add %[[ARG0_B]], %[[ARG1_B]]
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[ARG0_B]], %[[ARG1_B]]
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xf32>
@ -41,9 +41,9 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-NEXT: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT:.+]] = "xla_hlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT:.+]] = "mhlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xcomplex<f32>>
@ -62,9 +62,9 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.+]] = "xla_hlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.+]] = "mhlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK: return %[[FINAL_RESULT]] : tensor<?x?xi1>
@ -76,7 +76,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// Verifies that broadcast_dimensions validity checks are valid.
// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions
func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK: xla_hlo.add
// CHECK: mhlo.add
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
@ -85,7 +85,7 @@ func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<
// Verifies that broadcast_dimensions validity checks are valid.
// CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions
func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<f32>) -> tensor<1x4xf32> {
// CHECK: xla_hlo.add
// CHECK: mhlo.add
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
@ -113,7 +113,7 @@ func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1:
// expansions. Tests below merely verify that the op has an expansion.
// CHECK-LABEL: @andWithoutBroadcast
func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: xla_hlo.and %arg0, %arg1
// CHECK: mhlo.and %arg0, %arg1
%0 = xla_chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
@ -121,7 +121,7 @@ func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
// -----
// CHECK-LABEL: @atan2WithoutBroadcast
func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.atan2 %arg0, %arg1
// CHECK: mhlo.atan2 %arg0, %arg1
%0 = xla_chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -129,7 +129,7 @@ func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso
// -----
// CHECK-LABEL: @compareWithoutBroadcast
func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> {
// CHECK: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
@ -137,7 +137,7 @@ func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// -----
// CHECK-LABEL: @complexWithoutBroadcast
func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex<f32>> {
// CHECK: "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
// CHECK: "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
return %0 : tensor<4xcomplex<f32>>
}
@ -145,7 +145,7 @@ func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// -----
// CHECK-LABEL: @divideWithoutBroadcast
func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.divide %arg0, %arg1
// CHECK: mhlo.divide %arg0, %arg1
%0 = xla_chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -153,7 +153,7 @@ func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tens
// -----
// CHECK-LABEL: @maximumWithoutBroadcast
func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.maximum %arg0, %arg1
// CHECK: mhlo.maximum %arg0, %arg1
%0 = xla_chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -161,7 +161,7 @@ func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// -----
// CHECK-LABEL: @minimumWithoutBroadcast
func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.minimum %arg0, %arg1
// CHECK: mhlo.minimum %arg0, %arg1
%0 = xla_chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -169,7 +169,7 @@ func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// -----
// CHECK-LABEL: @multiplyWithoutBroadcast
func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.multiply %arg0, %arg1
// CHECK: mhlo.multiply %arg0, %arg1
%0 = xla_chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -177,7 +177,7 @@ func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te
// -----
// CHECK-LABEL: @orWithoutBroadcast
func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: xla_hlo.or %arg0, %arg1
// CHECK: mhlo.or %arg0, %arg1
%0 = xla_chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
@ -185,7 +185,7 @@ func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi
// -----
// CHECK-LABEL: @powerWithoutBroadcast
func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.power %arg0, %arg1
// CHECK: mhlo.power %arg0, %arg1
%0 = xla_chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -193,7 +193,7 @@ func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso
// -----
// CHECK-LABEL: @remainderWithoutBroadcast
func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.remainder %arg0, %arg1
// CHECK: mhlo.remainder %arg0, %arg1
%0 = xla_chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -201,7 +201,7 @@ func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> t
// -----
// CHECK-LABEL: @shift_leftWithoutBroadcast
func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.shift_left %arg0, %arg1
// CHECK: mhlo.shift_left %arg0, %arg1
%0 = xla_chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -209,7 +209,7 @@ func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) ->
// -----
// CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast
func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1
// CHECK: mhlo.shift_right_arithmetic %arg0, %arg1
%0 = xla_chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -217,7 +217,7 @@ func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor
// -----
// CHECK-LABEL: @shift_right_logicalWithoutBroadcast
func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.shift_right_logical %arg0, %arg1
// CHECK: mhlo.shift_right_logical %arg0, %arg1
%0 = xla_chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -225,7 +225,7 @@ func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4x
// -----
// CHECK-LABEL: @subWithoutBroadcast
func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.subtract %arg0, %arg1
// CHECK: mhlo.subtract %arg0, %arg1
%0 = xla_chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -233,7 +233,7 @@ func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<
// -----
// CHECK-LABEL: @xorWithoutBroadcast
func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: xla_hlo.xor %arg0, %arg1
// CHECK: mhlo.xor %arg0, %arg1
%0 = xla_chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}

View File

@ -3,7 +3,7 @@
// CHECK-LABEL: func @single_operand
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @single_operand(%arg: tensor<1x2xf32>) -> tensor<1x2xf32> {
%0 = "xla_hlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32>
%0 = "mhlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32>
// CHECK-NEXT: return [[ARG]]
return %0 : tensor<1x2xf32>
}

View File

@ -5,7 +5,7 @@
// CHECK-LABEL: func @same_type
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @same_type(%arg: tensor<f32>) -> tensor<f32> {
%0 = "xla_hlo.convert"(%arg) : (tensor<f32>) -> tensor<f32>
%0 = "mhlo.convert"(%arg) : (tensor<f32>) -> tensor<f32>
// CHECK-NEXT: return [[ARG]]
return %0 : tensor<f32>
}
@ -15,8 +15,8 @@ func @same_type(%arg: tensor<f32>) -> tensor<f32> {
// CHECK-LABEL: func @int_widening
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @int_widening(%arg: tensor<i32>) -> tensor<i64> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i64>
%0 = "xla_hlo.convert"(%arg) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i64>
%0 = "mhlo.convert"(%arg) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<i64>
}
@ -26,8 +26,8 @@ func @int_widening(%arg: tensor<i32>) -> tensor<i64> {
// CHECK-LABEL: func @int_narrowing
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @int_narrowing(%arg: tensor<i32>) -> tensor<i16> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i16>
%0 = "xla_hlo.convert"(%arg) : (tensor<i32>) -> tensor<i16>
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i16>
%0 = "mhlo.convert"(%arg) : (tensor<i32>) -> tensor<i16>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<i16>
}
@ -37,8 +37,8 @@ func @int_narrowing(%arg: tensor<i32>) -> tensor<i16> {
// CHECK-LABEL: func @float_int
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @float_int(%arg: tensor<f32>) -> tensor<i32> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<f32>) -> tensor<i32>
%0 = "xla_hlo.convert"(%arg) : (tensor<f32>) -> tensor<i32>
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<f32>) -> tensor<i32>
%0 = "mhlo.convert"(%arg) : (tensor<f32>) -> tensor<i32>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<i32>
}
@ -48,8 +48,8 @@ func @float_int(%arg: tensor<f32>) -> tensor<i32> {
// CHECK-LABEL: func @int_float
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @int_float(%arg: tensor<i32>) -> tensor<f32> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<f32>
%0 = "xla_hlo.convert"(%arg) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<f32>
%0 = "mhlo.convert"(%arg) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<f32>
}
@ -59,8 +59,8 @@ func @int_float(%arg: tensor<i32>) -> tensor<f32> {
// CHECK-LABEL: func @high_rank_tensor
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @high_rank_tensor(%arg: tensor<2x3xi32>) -> tensor<2x3xf32> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<2x3xi32>) -> tensor<2x3xf32>
%0 = "xla_hlo.convert"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xf32>
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<2x3xi32>) -> tensor<2x3xf32>
%0 = "mhlo.convert"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xf32>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<2x3xf32>
}
@ -70,9 +70,9 @@ func @high_rank_tensor(%arg: tensor<2x3xi32>) -> tensor<2x3xf32> {
// CHECK-LABEL: func @const_same_type
func @const_same_type() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
%cst = xla_hlo.constant dense<42> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<i32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
%cst = mhlo.constant dense<42> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}
@ -81,9 +81,9 @@ func @const_same_type() -> tensor<i32> {
// CHECK-LABEL: func @const_float_int
func @const_float_int() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
%cst = xla_hlo.constant dense<42.0> : tensor<f32>
%0 = "xla_hlo.convert"(%cst) : (tensor<f32>) -> tensor<i32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
%cst = mhlo.constant dense<42.0> : tensor<f32>
%0 = "mhlo.convert"(%cst) : (tensor<f32>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}
@ -92,9 +92,9 @@ func @const_float_int() -> tensor<i32> {
// CHECK-LABEL: func @const_int_float
func @const_int_float() -> tensor<f32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.{{0*}}e+00> : tensor<f32>
%cst = xla_hlo.constant dense<4> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.{{0*}}e+00> : tensor<f32>
%cst = mhlo.constant dense<4> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<f32>
}
@ -103,9 +103,9 @@ func @const_int_float() -> tensor<f32> {
// CHECK-LABEL: func @const_negative_int_float
func @const_negative_int_float() -> tensor<f32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<-4.{{0*}}e+00> : tensor<f32>
%cst = xla_hlo.constant dense<-4> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-4.{{0*}}e+00> : tensor<f32>
%cst = mhlo.constant dense<-4> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<f32>
}
@ -114,9 +114,9 @@ func @const_negative_int_float() -> tensor<f32> {
// CHECK-LABEL: func @const_int_bf16
func @const_int_bf16() -> tensor<bf16> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.{{0*}}e+00> : tensor<bf16>
%cst = xla_hlo.constant dense<4> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<bf16>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.{{0*}}e+00> : tensor<bf16>
%cst = mhlo.constant dense<4> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<bf16>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<bf16>
}
@ -125,9 +125,9 @@ func @const_int_bf16() -> tensor<bf16> {
// CHECK-LABEL: func @const_bf16_int
func @const_bf16_int() -> tensor<i16> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i16>
%cst = xla_hlo.constant dense<42.0> : tensor<bf16>
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<i16>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i16>
%cst = mhlo.constant dense<42.0> : tensor<bf16>
%0 = "mhlo.convert"(%cst) : (tensor<bf16>) -> tensor<i16>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i16>
}
@ -136,9 +136,9 @@ func @const_bf16_int() -> tensor<i16> {
// CHECK-LABEL: func @const_int_narrowing
func @const_int_narrowing() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
%cst = xla_hlo.constant dense<42> : tensor<i64>
%0 = "xla_hlo.convert"(%cst) : (tensor<i64>) -> tensor<i32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
%cst = mhlo.constant dense<42> : tensor<i64>
%0 = "mhlo.convert"(%cst) : (tensor<i64>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}
@ -147,9 +147,9 @@ func @const_int_narrowing() -> tensor<i32> {
// CHECK-LABEL: func @const_int_widening
func @const_int_widening() -> tensor<i64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i64>
%cst = xla_hlo.constant dense<42> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
%cst = mhlo.constant dense<42> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i64>
}
@ -158,9 +158,9 @@ func @const_int_widening() -> tensor<i64> {
// CHECK-LABEL: func @const_negative_int_widening
func @const_negative_int_widening() -> tensor<i64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<-42> : tensor<i64>
%cst = xla_hlo.constant dense<-42> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor<i64>
%cst = mhlo.constant dense<-42> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i64>
}
@ -169,9 +169,9 @@ func @const_negative_int_widening() -> tensor<i64> {
// CHECK-LABEL: func @const_float_narrowing
func @const_float_narrowing() -> tensor<f32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor<f32>
%cst = xla_hlo.constant dense<4.2> : tensor<f64>
%0 = "xla_hlo.convert"(%cst) : (tensor<f64>) -> tensor<f32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+00> : tensor<f32>
%cst = mhlo.constant dense<4.2> : tensor<f64>
%0 = "mhlo.convert"(%cst) : (tensor<f64>) -> tensor<f32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<f32>
}
@ -180,9 +180,9 @@ func @const_float_narrowing() -> tensor<f32> {
// CHECK-LABEL: func @const_f32_bf16
func @const_f32_bf16() -> tensor<bf16> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+01> : tensor<bf16>
%cst = xla_hlo.constant dense<42.0> : tensor<f32>
%0 = "xla_hlo.convert"(%cst) : (tensor<f32>) -> tensor<bf16>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+01> : tensor<bf16>
%cst = mhlo.constant dense<42.0> : tensor<f32>
%0 = "mhlo.convert"(%cst) : (tensor<f32>) -> tensor<bf16>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<bf16>
}
@ -191,9 +191,9 @@ func @const_f32_bf16() -> tensor<bf16> {
// CHECK-LABEL: func @const_bf16_f64
func @const_bf16_f64() -> tensor<f64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.187500e+00> : tensor<f64>
%cst = xla_hlo.constant dense<4.2> : tensor<bf16>
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.187500e+00> : tensor<f64>
%cst = mhlo.constant dense<4.2> : tensor<bf16>
%0 = "mhlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<f64>
}
@ -202,9 +202,9 @@ func @const_bf16_f64() -> tensor<f64> {
// CHECK-LABEL: func @const_bf16_int
func @const_bf16_int() -> tensor<i64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i64>
%cst = xla_hlo.constant dense<42.0> : tensor<bf16>
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<i64>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
%cst = mhlo.constant dense<42.0> : tensor<bf16>
%0 = "mhlo.convert"(%cst) : (tensor<bf16>) -> tensor<i64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i64>
}
@ -214,11 +214,11 @@ func @const_bf16_int() -> tensor<i64> {
// CHECK-LABEL: func @const_high_rank_tensor
func @const_high_rank_tensor() -> tensor<2x3xi32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[
// CHECK-SAME: [1, 2, 3], [4, 5, 6]
// CHECK-SAME: ]> : tensor<2x3xi32>
%cst = xla_hlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
%0 = "xla_hlo.convert"(%cst) : (tensor<2x3xf32>) -> tensor<2x3xi32>
%cst = mhlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
%0 = "mhlo.convert"(%cst) : (tensor<2x3xf32>) -> tensor<2x3xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2x3xi32>
}

View File

@ -4,7 +4,7 @@
// BOTH-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.exponential"(%tensor_operand)
%tensor_result = "mhlo.exponential"(%tensor_operand)
{some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
@ -28,11 +28,11 @@ func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// BOTH-LABEL: func @func_op_long
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
%2 = xla_hlo.add %arg0, %1 : tensor<4xf32>
%3 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32>
%4 = xla_hlo.subtract %arg1, %3 : tensor<4xf32>
%5 = xla_hlo.multiply %2, %4 : tensor<4xf32>
%1 = mhlo.maximum %arg0, %arg1 : tensor<4xf32>
%2 = mhlo.add %arg0, %1 : tensor<4xf32>
%3 = mhlo.minimum %arg0, %arg1 : tensor<4xf32>
%4 = mhlo.subtract %arg1, %3 : tensor<4xf32>
%5 = mhlo.multiply %2, %4 : tensor<4xf32>
return %5 : tensor<4xf32>
}
// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
@ -65,12 +65,12 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
%tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
%sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2)
%sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
%tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier)
%tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
@ -86,7 +86,7 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
// BOTH-LABEL: func @copy
func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.copy"(%tensor_operand)
%tensor_result = "mhlo.copy"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -98,7 +98,7 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @exp
func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.exponential"(%tensor_operand)
%tensor_result = "mhlo.exponential"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -110,7 +110,7 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @log
func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.log"(%tensor_operand)
%tensor_result = "mhlo.log"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -125,7 +125,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%tensor_pred = tensor_load %pred : memref<2x2xi1>
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "xla_hlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
%tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -138,7 +138,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "xla_hlo.compare"(%tensor_lhs, %tensor_rhs)
%tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs)
{comparison_direction = "EQ"}
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
// BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
@ -151,7 +151,7 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x
// BOTH-LABEL: func @broadcast
func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
%tensor_operand = tensor_load %operand : memref<5xf32>
%tensor_result = "xla_hlo.broadcast_in_dim"(%tensor_operand)
%tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand)
{broadcast_dimensions = dense<1> : tensor<1xi64>}
: (tensor<5xf32>) -> tensor<10x5xf32>
// BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
@ -170,7 +170,7 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
// BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
%tensor_operand = tensor_load %operand : memref<?x?xf32>
%shape = call @external_func() : () -> tensor<3xi64>
%tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
%tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
// BOTH: %[[SHAPE:.*]] = call @external_func()
@ -226,7 +226,7 @@ func @complex(%real: memref<2x2xf32>,
%result: memref<2x2xcomplex<f32>>) {
%tensor_real = tensor_load %real : memref<2x2xf32>
%tensor_imag = tensor_load %imag : memref<2x2xf32>
%tensor_result = "xla_hlo.complex"(%tensor_real, %tensor_imag)
%tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
// BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>>
@ -238,7 +238,7 @@ func @complex(%real: memref<2x2xf32>,
// BOTH-LABEL: func @real
func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "xla_hlo.real"(%tensor_operand)
%tensor_result = "mhlo.real"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -250,7 +250,7 @@ func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @imag
func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "xla_hlo.imag"(%tensor_operand)
%tensor_result = "mhlo.imag"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -261,7 +261,7 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @iota
func @iota(%result: memref<10xi32>) {
%tensor_result = "xla_hlo.iota"()
%tensor_result = "mhlo.iota"()
{iota_dimension = 0 : i64} : () -> tensor<10xi32>
// BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
tensor_store %tensor_result, %result : memref<10xi32>
@ -273,7 +273,7 @@ func @iota(%result: memref<10xi32>) {
// BOTH-LABEL: func @abs
func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.abs"(%tensor_operand)
%tensor_result = "mhlo.abs"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -285,7 +285,7 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @ceil
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.ceil"(%tensor_operand)
%tensor_result = "mhlo.ceil"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -297,7 +297,7 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @convert
func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.convert"(%tensor_operand)
%tensor_result = "mhlo.convert"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
// BOTH-NOT: tensor_store
@ -310,7 +310,7 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @cos
func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.cosine"(%tensor_operand)
%tensor_result = "mhlo.cosine"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -322,7 +322,7 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @neg
func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.negate"(%tensor_operand)
%tensor_result = "mhlo.negate"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -334,7 +334,7 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @rsqrt
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.rsqrt"(%tensor_operand)
%tensor_result = "mhlo.rsqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -346,7 +346,7 @@ func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @sign
func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.sign"(%tensor_operand)
%tensor_result = "mhlo.sign"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -358,7 +358,7 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @sqrt
func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.sqrt"(%tensor_operand)
%tensor_result = "mhlo.sqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -370,7 +370,7 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @tanh
func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.tanh"(%tensor_operand)
%tensor_result = "mhlo.tanh"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -383,7 +383,7 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "xla_hlo.remainder"(%tensor_lhs, %tensor_rhs)
%tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
@ -395,7 +395,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
// Dynamic shape binary element-wise operation.
// BOTH-LABEL: func @add_dyn
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
%result = "xla_hlo.add"(%lhs, %rhs)
%result = "mhlo.add"(%lhs, %rhs)
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
@ -420,7 +420,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
// Dynamic shape unary element-wise operation.
// BOTH-LABEL: func @tanh_dyn
func @tanh_dyn(%arg0: tensor<?x?xf32>) {
%result = "xla_hlo.tanh"(%arg0)
%result = "mhlo.tanh"(%arg0)
: (tensor<?x?xf32>) -> tensor<?x?xf32>
// BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
@ -448,7 +448,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// BOTH-NEXT: %[[ALLOC:.*]] = alloc
// BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
%dot = "xla_hlo.dot"(%arg0, %arg0)
%dot = "mhlo.dot"(%arg0, %arg0)
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]])
// ESC: return %[[ALLOC]]
@ -466,7 +466,7 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor
// BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
// BOTH-SAME: rhs_dilation = dense<[1, 2]>
// BOTH-SAME: window_strides = dense<[2, 1]>
%out = "xla_hlo.convolution"(%filter, %input) {
%out = "mhlo.convolution"(%filter, %input) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,

View File

@ -10,7 +10,7 @@ func @float_add(%lhs: tensor<2x2xf32>,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32
// CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = addf %[[ARG0]], %[[ARG1]]
// CHECK: linalg.yield %[[RESULT]]
%0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xf32>,
%0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -22,7 +22,7 @@ func @integer_add(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: addi
%0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xi32>,
%0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -34,7 +34,7 @@ func @float_mul(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: mulf
%0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>,
%0 = "mhlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -46,7 +46,7 @@ func @integer_mul(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: muli
%0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>,
%0 = "mhlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -58,7 +58,7 @@ func @float_remainder(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: remf
%0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>,
%0 = "mhlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -70,7 +70,7 @@ func @integer_remainder(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: remi_signed
%0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>,
%0 = "mhlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -79,7 +79,7 @@ func @integer_remainder(%lhs: tensor<2x2xi32>,
// CHECK-LABEL: func @float_rsqrt
func @float_rsqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
%tensor_result = "xla_hlo.rsqrt"(%operand)
%tensor_result = "mhlo.rsqrt"(%operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: linalg.generic
// CHECK: rsqrt
@ -93,7 +93,7 @@ func @float_sub(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: subf
%0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>,
%0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -105,7 +105,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: subi
%0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>,
%0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -116,7 +116,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: absf
%0 = "xla_hlo.abs"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%0 = "mhlo.abs"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -126,7 +126,7 @@ func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: exp
%0 = "xla_hlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%0 = "mhlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -136,7 +136,7 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: log
%0 = "xla_hlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%0 = "mhlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -146,7 +146,7 @@ func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: ceilf
%0 = "xla_hlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%0 = "mhlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -156,7 +156,7 @@ func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: negf
%0 = "xla_hlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%0 = "mhlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -166,7 +166,7 @@ func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: tanh
%0 = "xla_hlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%0 = "mhlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -177,7 +177,7 @@ func @integer_and(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: and
%0 = "xla_hlo.and"(%lhs, %rhs) : (tensor<2x2xi32>,
%0 = "mhlo.and"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -187,7 +187,7 @@ func @integer_and(%lhs: tensor<2x2xi32>,
// CHECK-LABEL: func @float_cmp
func @float_cmp(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
%0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"}
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"}
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
return %0 : tensor<2x2xi1>
}
@ -201,7 +201,7 @@ func @float_cmp(%lhs: tensor<2x2xf32>,
// CHECK-LABEL: func @int_cmp
func @int_cmp(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi1> {
%0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "LT"}
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "LT"}
: (tensor<2x2xi32>, tensor<2x2xi32>) -> (tensor<2x2xi1>)
return %0 : tensor<2x2xi1>
}
@ -216,7 +216,7 @@ func @int_cmp(%lhs: tensor<2x2xi32>,
func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: cos
%0 = "xla_hlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%0 = "mhlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -226,7 +226,7 @@ func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: sin
%0 = "xla_hlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%0 = "mhlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -235,7 +235,7 @@ func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK-LABEL: func @copy
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
%0 = "xla_hlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>)
%0 = "mhlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>)
return %0 : tensor<2x4x8xf32>
}
// CHECK: return [[ARG]] : tensor<2x4x8xf32>
@ -245,7 +245,7 @@ func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
// CHECK-LABEL: func @select
func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = "xla_hlo.select"(%pred, %lhs, %rhs)
%0 = "mhlo.select"(%pred, %lhs, %rhs)
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
return %0 : tensor<2x2xf32>
}
@ -260,7 +260,7 @@ func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_scalar
func @broadcast_scalar(%arg: tensor<f32>) -> tensor<4x2x1xf32> {
%0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<f32>) -> tensor<4x2x1xf32>
%0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<f32>) -> tensor<4x2x1xf32>
return %0: tensor<4x2x1xf32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
@ -273,7 +273,7 @@ func @broadcast_scalar(%arg: tensor<f32>) -> tensor<4x2x1xf32> {
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
// CHECK-LABEL: func @broadcast
func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
%0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32>
%0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32>
return %0: tensor<4x2x1x4x?x16xf32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
@ -286,7 +286,7 @@ func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func @broadcast_in_dim
func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%operand)
%0 = "mhlo.broadcast_in_dim"(%operand)
{broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>}
: (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32>
return %0 : tensor<7x10x6x4x5xf32>
@ -302,7 +302,7 @@ func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> {
// CHECK-LABEL: func @broadcast_in_dim_with_one_to_one
func @broadcast_in_dim_with_one_to_one(
%operand: tensor<1xf32>) -> tensor<1x5xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%operand)
%0 = "mhlo.broadcast_in_dim"(%operand)
{broadcast_dimensions = dense<[0]> : tensor<1xi64>}
: (tensor<1xf32>) -> tensor<1x5xf32>
return %0 : tensor<1x5xf32>
@ -317,7 +317,7 @@ func @broadcast_in_dim_with_one_to_one(
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_scalar
func @broadcast_scalar(%operand: tensor<f32>) -> tensor<7x10x6xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%operand)
%0 = "mhlo.broadcast_in_dim"(%operand)
{broadcast_dimensions = dense<[]> : tensor<0xi64>}
: (tensor<f32>) -> tensor<7x10x6xf32>
return %0 : tensor<7x10x6xf32>
@ -332,7 +332,7 @@ func @broadcast_scalar(%operand: tensor<f32>) -> tensor<7x10x6xf32> {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func @transpose
func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
%0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}
%0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}
: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
return %0 : tensor<3x2x5x9xi32>
}
@ -344,7 +344,7 @@ func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-LABEL: func @reshape_3D_2D
func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>
%0 = "mhlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>
return %0 : tensor<12x42xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
@ -355,7 +355,7 @@ func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-LABEL: func @reshape_4D_2D
func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32>
%0 = "mhlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32>
return %0 : tensor<12x42xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
@ -366,7 +366,7 @@ func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> {
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-LABEL: func @reshape_2D_4D
func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32>
%0 = "mhlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32>
return %0 : tensor<12x1x42x1xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
@ -375,7 +375,7 @@ func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
// CHECK-LABEL: func @minf
func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = "xla_hlo.minimum"(%lhs, %rhs)
%0 = "mhlo.minimum"(%lhs, %rhs)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
@ -389,7 +389,7 @@ func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK-LABEL: func @maxi
func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
%0 = "xla_hlo.maximum"(%lhs, %rhs)
%0 = "mhlo.maximum"(%lhs, %rhs)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
@ -404,7 +404,7 @@ func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK-DAG: #[[MAP:.*]] = affine_map<() -> ()>
// CHECK-LABEL: func @add_scalar
func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
%0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%0 = "mhlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
// CHECK: linalg.generic
@ -417,7 +417,7 @@ func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
func @reshape_collapse_single_dim
(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x784xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>
%0 = "mhlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>
return %0 : tensor<1x784xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
@ -428,7 +428,7 @@ func @reshape_collapse_single_dim
// -----
func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32>
%0 = "mhlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32>
return %0 : tensor<2x4x3xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
@ -440,7 +440,7 @@ func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> {
// -----
func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32>
%0 = "mhlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32>
return %0 : tensor<2x4x2xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0)>
@ -451,7 +451,7 @@ func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> {
// -----
func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32>
%0 = "mhlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32>
return %0 : tensor<1x4x2xf32>
}
// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
@ -462,7 +462,7 @@ func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> {
func @reshape_multiple_collapse
(%arg0 : tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32>
%0 = "mhlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32>
return %0 : tensor<1x4x5x6xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)>
@ -476,7 +476,7 @@ func @reshape_multiple_collapse
// CHECK-LABEL: func @convert_i32_to_f32
func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32>
%result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32>
return %result : tensor<2x2xf32>
}
// CHECK: linalg.generic
@ -488,7 +488,7 @@ func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {
// CHECK-LABEL: func @convert_i16_to_i32
func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32>
%result = "mhlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32>
return %result : tensor<2x2xi32>
}
// CHECK: linalg.generic
@ -500,7 +500,7 @@ func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> {
// CHECK-LABEL: func @convert_i32_to_i16
func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16>
%result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16>
return %result : tensor<2x2xi16>
}
// CHECK: linalg.generic
@ -512,7 +512,7 @@ func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> {
// CHECK-LABEL: func @convert_f32_to_f64
func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64>
%result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64>
return %result : tensor<2x2xf64>
}
// CHECK: linalg.generic
@ -524,7 +524,7 @@ func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> {
// CHECK-LABEL: func @convert_f64_to_f32
func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32>
%result = "mhlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32>
return %result : tensor<2x2xf32>
}
// CHECK: linalg.generic
@ -536,7 +536,7 @@ func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> {
// CHECK-LABEL: func @convert_f32_to_i32
func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32>
%result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32>
return %result : tensor<2x2xi32>
}
// CHECK: linalg.generic
@ -550,7 +550,7 @@ func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @reverse
func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
%result = "xla_hlo.reverse"(%input) {
%result = "mhlo.reverse"(%input) {
dimensions = dense<1> : tensor<1xi64>
} : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %result : tensor<2x3xf32>

View File

@ -1,28 +1,28 @@
// RUN: mlir-hlo-opt %s -inline | FileCheck %s
// Test case: Basic test of inlining into xla_hlo.while.
// Test case: Basic test of inlining into mhlo.while.
// CHECK-LABEL: func @caller
// CHECK: "xla_hlo.while"{{.*}}( {
// CHECK: "mhlo.while"{{.*}}( {
// CHECK: }, {
// CHECK: "xla_hlo.exponential"
// CHECK: "mhlo.exponential"
// CHECK: })
// CHECK-LABEL: func @callee
func @caller(%arg0: tensor<f32>, %pred: tensor<i1>) -> tensor<f32> {
%0 = "xla_hlo.while"(%arg0) ( {
%0 = "mhlo.while"(%arg0) ( {
^entry(%unused: tensor<f32>):
"xla_hlo.return"(%pred) : (tensor<i1>) -> ()
"mhlo.return"(%pred) : (tensor<i1>) -> ()
}, {
^entry(%0: tensor<f32>):
%1 = call @callee(%0) : (tensor<f32>) -> (tensor<f32>)
"xla_hlo.return"(%1) : (tensor<f32>) -> ()
"mhlo.return"(%1) : (tensor<f32>) -> ()
} ) : (tensor<f32>) -> (tensor<f32>)
return %0 : tensor<f32>
}
func @callee(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "xla_hlo.exponential"(%arg0) : (tensor<f32>) -> tensor<f32>
%0 = "mhlo.exponential"(%arg0) : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}

View File

@ -4,21 +4,21 @@
func @while(%arg0: tensor<i64>) -> tensor<i64> {
//CHECK: br ^bb1(%arg0 : tensor<i64>)
//CHECK: ^bb1([[VAL0:%.+]]: tensor<i64>):
//CHECK: [[VAL1:%.+]] = "xla_hlo.compare"([[VAL0]], [[VAL0]])
//CHECK: [[VAL1:%.+]] = "mhlo.compare"([[VAL0]], [[VAL0]])
//CHECK: [[VAL2:%.+]] = extract_element [[VAL1]][] : tensor<i1>
//CHECK: cond_br [[VAL2]], ^bb2([[VAL0]] : tensor<i64>), ^bb3([[VAL0]] : tensor<i64>)
//CHECK: ^bb2([[VAL3:%.+]]: tensor<i64>):
//CHECK: [[VAL4:%.+]] = xla_hlo.add [[VAL3]], [[VAL3]]
//CHECK: [[VAL4:%.+]] = mhlo.add [[VAL3]], [[VAL3]]
//CHECK: br ^bb1([[VAL4]] : tensor<i64>)
//CHECK: ^bb3([[VAL5:%.+]]: tensor<i64>):
%0 = "xla_hlo.while"(%arg0) ( {
%0 = "mhlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
%1 = xla_hlo.add %arg1, %arg1 {name = "compare.0"} : tensor<i64>
"xla_hlo.return"(%1) : (tensor<i64>) -> ()
%1 = mhlo.add %arg1, %arg1 {name = "compare.0"} : tensor<i64>
"mhlo.return"(%1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
// CHECK-NEXT: return [[VAL5]]
@ -30,27 +30,27 @@ func @conditional(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: [[C0:%.+]] = constant dense<1.000000e+01> : tensor<f32>
%cst = constant dense<1.000000e+01> : tensor<f32>
// CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
%0 = "xla_hlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: [[VAL0:%.+]] = "mhlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
%0 = "mhlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: [[VAL1:%.+]] = extract_element [[VAL0]][] : tensor<i1>
// CHECK: cond_br [[VAL1]], ^bb1(%arg0 : tensor<f32>), ^bb2(%arg0 : tensor<f32>)
%1 = "xla_hlo.if"(%0, %arg0, %arg0) ( {
%1 = "mhlo.if"(%0, %arg0, %arg0) ( {
^bb0(%arg1: tensor<f32>):
// CHECK: ^bb1([[VAL2:%.+]]: tensor<f32>):
// CHECK: [[VAL3:%.+]] = "xla_hlo.log"([[VAL2]]) : (tensor<f32>) -> tensor<f32>
// CHECK: [[VAL3:%.+]] = "mhlo.log"([[VAL2]]) : (tensor<f32>) -> tensor<f32>
// CHECK: br ^bb3([[VAL3]] : tensor<f32>)
%2 = "xla_hlo.log"(%arg1) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
%2 = "mhlo.log"(%arg1) : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}, {
^bb0(%arg1: tensor<f32>):
// CHECK: ^bb2([[VAL4:%.+]]: tensor<f32>):
// CHECK: [[VAL5:%.+]] = "xla_hlo.exponential"([[VAL4]]) : (tensor<f32>) -> tensor<f32>
// CHECK: [[VAL5:%.+]] = "mhlo.exponential"([[VAL4]]) : (tensor<f32>) -> tensor<f32>
// CHECK: br ^bb3([[VAL5]] : tensor<f32>)
%2 = "xla_hlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
%2 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: ^bb3([[VAL6:%.+]]: tensor<f32>):
@ -62,27 +62,27 @@ func @conditional(%arg0: tensor<f32>) -> tensor<f32> {
func @while_with_multiple_blocks_in_body(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK: br ^[[COND_ENTRY:.+]](%arg0 : tensor<i64>)
// CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
// CHECK: %1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK: %1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK: %2 = extract_element %1[] : tensor<i1>
// CHECK: cond_br %2, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
// CHECK: ^[[BODY_ENTRY]](%3: tensor<i64>):
// CHECK: br ^[[BODY_SUCC:.+]](%3 : tensor<i64>)
// CHECK: ^[[BODY_SUCC]](%4: tensor<i64>):
// CHECK: %5 = xla_hlo.add %4, %4 : tensor<i64>
// CHECK: %5 = mhlo.add %4, %4 : tensor<i64>
// CHECK: br ^[[COND_ENTRY]](%5 : tensor<i64>)
// CHECK: ^[[EXIT]](%6: tensor<i64>):
// CHECK: return %6 : tensor<i64>
// CHECK: }
%0 = "xla_hlo.while"(%arg0) ( {
%0 = "mhlo.while"(%arg0) ( {
^cond_entry(%arg1: tensor<i64>):
%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^body_entry(%arg1: tensor<i64>):
br ^body_succ(%arg1: tensor<i64>)
^body_succ(%0: tensor<i64>):
%1 = xla_hlo.add %0, %0 : tensor<i64>
"xla_hlo.return"(%1) : (tensor<i64>) -> ()
%1 = mhlo.add %0, %0 : tensor<i64>
"mhlo.return"(%1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %0 : tensor<i64>
@ -94,7 +94,7 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
// CHECK: br ^[[COND_SUCC:.+]](%0 : tensor<i64>)
// CHECK: ^[[COND_SUCC]](%1: tensor<i64>):
// CHECK: %2 = "xla_hlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK: %2 = "mhlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK: %3 = extract_element %2[] : tensor<i1>
// CHECK: cond_br %3, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
// CHECK: ^[[BODY_ENTRY]](%4: tensor<i64>):
@ -102,15 +102,15 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK: ^[[EXIT]](%5: tensor<i64>):
// CHECK: return %5 : tensor<i64>
// CHECK: }
%0 = "xla_hlo.while"(%arg0) ( {
%0 = "mhlo.while"(%arg0) ( {
^cond_entry(%arg1: tensor<i64>):
br ^cond_succ(%arg1: tensor<i64>)
^cond_succ(%0: tensor<i64>):
%1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
%1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^body_entry(%arg1: tensor<i64>):
"xla_hlo.return"(%arg1) : (tensor<i64>) -> ()
"mhlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %0 : tensor<i64>
@ -123,24 +123,24 @@ func @conditional_with_multiple_blocks(%arg0: tensor<f32>, %arg1: tensor<f32>, %
// CHECK: ^[[THEN_ENTRY]](%1: tensor<f32>):
// CHECK: br ^[[THEN_SUCC:.+]](%1 : tensor<f32>)
// CHECK: ^[[THEN_SUCC]](%2: tensor<f32>):
// CHECK: %3 = "xla_hlo.log"(%2) : (tensor<f32>) -> tensor<f32>
// CHECK: %3 = "mhlo.log"(%2) : (tensor<f32>) -> tensor<f32>
// CHECK: br ^[[EXIT:.+]](%3 : tensor<f32>)
// CHECK: ^[[ELSE_ENTRY]](%4: tensor<f32>):
// CHECK: %5 = "xla_hlo.exponential"(%4) : (tensor<f32>) -> tensor<f32>
// CHECK: %5 = "mhlo.exponential"(%4) : (tensor<f32>) -> tensor<f32>
// CHECK: br ^[[EXIT]](%5 : tensor<f32>)
// CHECK: ^[[EXIT]](%6: tensor<f32>):
// CHECK: return %6 : tensor<f32>
// CHECK: }
%1 = "xla_hlo.if"(%pred, %arg0, %arg1) ( {
%1 = "mhlo.if"(%pred, %arg0, %arg1) ( {
^then_entry(%arg2: tensor<f32>):
br ^then_succ(%arg2: tensor<f32>)
^then_succ(%0: tensor<f32>):
%2 = "xla_hlo.log"(%0) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
%2 = "mhlo.log"(%0) : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}, {
^else_entry(%arg2: tensor<f32>):
%2 = "xla_hlo.exponential"(%arg2) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
%2 = "mhlo.exponential"(%arg2) : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %1 : tensor<f32>
}

View File

@ -3,19 +3,19 @@
// CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: %0 = addf %arg0, %arg1 : tensor<4xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "mhlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: %1 = mulf %0, %arg1 : tensor<4xf32>
%1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%1 = "mhlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: %2 = subf %1, %arg1 : tensor<4xf32>
%2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%2 = "mhlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: %3 = divf %2, %arg1 : tensor<4xf32>
%3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%3 = "mhlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: %4 = remf %3, %arg1 : tensor<4xf32>
%4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%4 = "mhlo.remainder"(%3, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: return %4 : tensor<4xf32>
return %4 : tensor<4xf32>
@ -24,19 +24,19 @@ func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf
// CHECK-LABEL: func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-NEXT: %0 = addi %arg0, %arg1 : tensor<4xi32>
%0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%0 = "mhlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: %1 = muli %0, %arg1 : tensor<4xi32>
%1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%1 = "mhlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: %2 = subi %1, %arg1 : tensor<4xi32>
%2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%2 = "mhlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: %3 = divi_signed %2, %arg1 : tensor<4xi32>
%3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%3 = "mhlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: %4 = remi_signed %3, %arg1 : tensor<4xi32>
%4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%4 = "mhlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: return %4 : tensor<4xi32>
return %4 : tensor<4xi32>
@ -45,17 +45,17 @@ func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32
// CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
// CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32>
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg0 : tensor<4xi32>
%1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg0 : tensor<4xi32>
%2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %3 = cmpi "sle", %arg0, %arg0 : tensor<4xi32>
%3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg0 : tensor<4xi32>
%4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg0 : tensor<4xi32>
%5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
}
@ -63,28 +63,28 @@ func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi
// CHECK-LABEL: func @compare_float
func @compare_float(%arg0: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
// CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg0 : tensor<4xf32>
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %1 = cmpf "une", %arg0, %arg0 : tensor<4xf32>
%1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg0 : tensor<4xf32>
%2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg0 : tensor<4xf32>
%3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg0 : tensor<4xf32>
%4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg0 : tensor<4xf32>
%5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
return %0, %1, %2, %3, %4, %5: tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
}
// CHECK-LABEL: func @int_constant
func @int_constant() -> (tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>) {
// CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor<i32>
%0 = "xla_hlo.constant"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>)
%0 = "mhlo.constant"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>)
// CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xi32>
%1 = "xla_hlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
%1 = "mhlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
// CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xi32>
%2 = "xla_hlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
%2 = "mhlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
// CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>
return %0, %1, %2: tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>
}
@ -92,11 +92,11 @@ func @int_constant() -> (tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>) {
// CHECK-LABEL: func @float_constant
func @float_constant() -> (tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>) {
// CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor<f32>
%0 = "xla_hlo.constant"() {value = dense<0.0> : tensor<f32>} : () -> (tensor<f32>)
%0 = "mhlo.constant"() {value = dense<0.0> : tensor<f32>} : () -> (tensor<f32>)
// CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xf32>
%1 = "xla_hlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
%1 = "mhlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
// CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xf32>
%2 = "xla_hlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
%2 = "mhlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
// CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
return %0, %1, %2: tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
}
@ -105,7 +105,7 @@ func @float_constant() -> (tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>) {
// CHECK-LABEL: func @iota.const.1() -> tensor<4xi32> {
func @iota.const.1() -> tensor<4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<4xi32>
return %0 : tensor<4xi32>
}
@ -113,7 +113,7 @@ func @iota.const.1() -> tensor<4xi32> {
// CHECK-LABEL: func @iota.const.2() -> tensor<2x4xi32> {
func @iota.const.2() -> tensor<2x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1]]> : tensor<2x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
return %0 : tensor<2x4xi32>
}
@ -121,7 +121,7 @@ func @iota.const.2() -> tensor<2x4xi32> {
// CHECK-LABEL: func @iota.const.3() -> tensor<2x4xi32> {
func @iota.const.3() -> tensor<2x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<2x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
return %0 : tensor<2x4xi32>
}
@ -129,7 +129,7 @@ func @iota.const.3() -> tensor<2x4xi32> {
// CHECK-LABEL: func @iota.const.4() -> tensor<2x3x4xi32> {
func @iota.const.4() -> tensor<2x3x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0{{\]\]}}, {{\[\[}}1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]> : tensor<2x3x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
return %0 : tensor<2x3x4xi32>
}
@ -137,7 +137,7 @@ func @iota.const.4() -> tensor<2x3x4xi32> {
// CHECK-LABEL: func @iota.const.5() -> tensor<2x3x4xi32> {
func @iota.const.5() -> tensor<2x3x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2{{\]\]}}, {{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]]> : tensor<2x3x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
return %0 : tensor<2x3x4xi32>
}
@ -145,7 +145,7 @@ func @iota.const.5() -> tensor<2x3x4xi32> {
// CHECK-LABEL: func @iota.const.6() -> tensor<2x3x4xi32> {
func @iota.const.6() -> tensor<2x3x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3{{\]\]}}, {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x3x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
return %0 : tensor<2x3x4xi32>
}
@ -153,7 +153,7 @@ func @iota.const.6() -> tensor<2x3x4xi32> {
// CHECK-LABEL: func @iota.const.f32
func @iota.const.f32() -> tensor<4xf32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
// CHECK-NEXT: return %[[CST]] : tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -161,7 +161,7 @@ func @iota.const.f32() -> tensor<4xf32> {
// CHECK-LABEL: func @iota.const.f64
func @iota.const.f64() -> tensor<4xf64> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64>
// CHECK-NEXT: return %[[CST]] : tensor<4xf64>
return %0 : tensor<4xf64>
}
@ -169,7 +169,7 @@ func @iota.const.f64() -> tensor<4xf64> {
// CHECK-LABEL: func @iota.const.bf16
func @iota.const.bf16() -> tensor<4xbf16> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xbf16>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16>
// CHECK-NEXT: return %[[CST]] : tensor<4xbf16>
return %0 : tensor<4xbf16>
}
@ -178,8 +178,8 @@ func @iota.const.bf16() -> tensor<4xbf16> {
func @iota.const.complex.f32() -> tensor<4xcomplex<f32>> {
// CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
// CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf32>
// CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f32>>
// CHECK-NEXT: [[COMPLEX:%.*]] = "mhlo.complex"([[REAL]], [[IMAG]])
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f32>>
// CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f32>>
return %0 : tensor<4xcomplex<f32>>
}
@ -188,8 +188,8 @@ func @iota.const.complex.f32() -> tensor<4xcomplex<f32>> {
func @iota.const.complex.f64() -> tensor<4xcomplex<f64>> {
// CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
// CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf64>
// CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f64>>
// CHECK-NEXT: [[COMPLEX:%.*]] = "mhlo.complex"([[REAL]], [[IMAG]])
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f64>>
// CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f64>>
return %0 : tensor<4xcomplex<f64>>
}

View File

@ -396,9 +396,9 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m
"xla_lhlo.fusion"() ( {
%0 = tensor_load %input1 : memref<10xf32>
%1 = tensor_load %input2 : memref<10xf32>
%2 = "xla_hlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%3 = tensor_load %input3 : memref<10xf32>
%4 = "xla_hlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
tensor_store %4, %out : memref<10xf32>
"xla_lhlo.terminator"() : () -> ()
} ) : () -> ()
@ -803,15 +803,15 @@ func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %a
func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () {
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
"mhlo.return"(%max) : (tensor<f32>) -> ()
})
{ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> ()
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
"mhlo.return"(%max) : (tensor<f32>) -> ()
})
{
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
@ -958,8 +958,8 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32
%updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () {
"xla_lhlo.scatter" (%input, %indices, %updates, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
%add = xla_hlo.add %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%add) : (tensor<f32>) -> ()
%add = mhlo.add %lhs, %rhs : tensor<f32>
"mhlo.return"(%add) : (tensor<f32>) -> ()
}) {
scatter_dimension_numbers = {
update_window_dims = dense<[1]> : tensor<1xi64>,
@ -979,8 +979,8 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
%c = xla_hlo.add %a, %b : tensor<f32>
"xla_hlo.return"(%c) : (tensor<f32>) -> ()
%c = mhlo.add %a, %b : tensor<f32>
"mhlo.return"(%c) : (tensor<f32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<20xf32>) -> ()
return
}
@ -991,8 +991,8 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref
// expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
%c = xla_hlo.add %a, %b : tensor<f32>
"xla_hlo.return"(%c) : (tensor<f32>) -> ()
%c = mhlo.add %a, %b : tensor<f32>
"mhlo.return"(%c) : (tensor<f32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<10xf32>) -> ()
return
}
@ -1012,8 +1012,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
return
}
@ -1025,8 +1025,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
return
}
@ -1038,8 +1038,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
return
}

View File

@ -2,14 +2,14 @@
// CHECK-LABEL: @add
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3
%4 = "xla_hlo.add"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.add %arg1, %arg3
%4 = "mhlo.add"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
@ -17,14 +17,14 @@ func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
// CHECK-LABEL: @add_unranked
func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3
%4 = "xla_hlo.add"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.add %arg1, %arg3
%4 = "mhlo.add"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
@ -32,14 +32,14 @@ func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<
// CHECK-LABEL: @sub
func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.subtract %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.subtract %arg1, %arg3
%4 = "xla_hlo.subtract"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.subtract %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.subtract %arg1, %arg3
%4 = "mhlo.subtract"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
@ -47,14 +47,14 @@ func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
// CHECK-LABEL: @sub_unranked
func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.subtract %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.subtract %arg1, %arg3
%4 = "xla_hlo.subtract"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.subtract %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.subtract %arg1, %arg3
%4 = "mhlo.subtract"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
@ -62,18 +62,18 @@ func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<
// CHECK-LABEL: @mul
func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]]
%4 = "xla_hlo.multiply"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = mhlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.add [[VAL3]], [[VAL4]]
%4 = "mhlo.multiply"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return %2, %5 : tensor<2xf32>, tensor<2xf32>
return %5, %6 : tensor<2xf32>, tensor<2xf32>
@ -81,18 +81,18 @@ func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
// CHECK-LABEL: @mul_unranked
func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]]
%4 = "xla_hlo.multiply"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = mhlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.add [[VAL3]], [[VAL4]]
%4 = "mhlo.multiply"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return %2, %5 : tensor<*xf32>, tensor<*xf32>
return %5, %6 : tensor<*xf32>, tensor<*xf32>
@ -100,36 +100,36 @@ func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<
// CHECK-LABEL: @div
func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3)
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.negate"(%arg3)
// Compute the numerator's real component:
// numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.multiply %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]]
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = mhlo.multiply %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.subtract [[VAL1]], [[VAL2]]
// Compute the real valued denominator as rhs * con(rhs):
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]]
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]]
// Compute the numerator's imaginary component:
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
// CHECK-DAG: [[VAL7:%.+]] = xla_hlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = xla_hlo.multiply %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]]
// CHECK-DAG: [[VAL7:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = mhlo.multiply %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = mhlo.add [[VAL8]], [[VAL7]]
// Divide the numerator by the real valued denominator.
// CHECK-DAG: [[VAL10:%.+]] = xla_hlo.divide [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = xla_hlo.divide [[VAL9]], [[VAL6]]
%4 = "xla_hlo.divide"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]]
%4 = "mhlo.divide"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL10]], [[VAL11]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
@ -139,36 +139,36 @@ func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %
// CHECK-LABEL: @div_unranked
func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3)
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.negate"(%arg3)
// Compute the numerator's real component:
// numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.multiply %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]]
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = mhlo.multiply %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.subtract [[VAL1]], [[VAL2]]
// Compute the real valued denominator as rhs * con(rhs):
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]]
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]]
// Compute the numerator's imaginary component:
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
// CHECK-DAG: [[VAL7:%.+]] = xla_hlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = xla_hlo.multiply %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]]
// CHECK-DAG: [[VAL7:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = mhlo.multiply %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = mhlo.add [[VAL8]], [[VAL7]]
// Divide the numerator by the real valued denominator.
// CHECK-DAG: [[VAL10:%.+]] = xla_hlo.divide [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = xla_hlo.divide [[VAL9]], [[VAL6]]
%4 = "xla_hlo.divide"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]]
%4 = "mhlo.divide"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL10]], [[VAL11]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
@ -176,14 +176,14 @@ func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<
// CHECK-LABEL: @abs
func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg0
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg1
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.add [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.sqrt"([[VAL2]])
%1 = "xla_hlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "xla_hlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg0
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg1
// CHECK-DAG: [[VAL2:%.+]] = mhlo.add [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = "mhlo.sqrt"([[VAL2]])
%1 = "mhlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL3]]
return %2 : tensor<2xf32>
@ -191,16 +191,16 @@ func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {
// CHECK-LABEL: @exp
func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]]
%1 = "xla_hlo.exponential"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "xla_hlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%3 = "xla_hlo.imag"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]]
%1 = "mhlo.exponential"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%3 = "mhlo.imag"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL3]], [[VAL4]]
return %2, %3 : tensor<2xf32>, tensor<2xf32>
@ -208,16 +208,16 @@ func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tenso
// CHECK-LABEL: @exp_unranked
func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]]
%1 = "xla_hlo.exponential"(%0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%2 = "xla_hlo.real"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%3 = "xla_hlo.imag"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]]
%1 = "mhlo.exponential"(%0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%2 = "mhlo.real"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%3 = "mhlo.imag"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL3]], [[VAL4]]
return %2, %3 : tensor<*xf32>, tensor<*xf32>

View File

@ -2,10 +2,10 @@
// CHECK-LABEL: @testDebatch1
func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> {
// CHECK-DAG: [[R0:%.+]] = "xla_hlo.reshape"(%arg0) : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
// CHECK-DAG: [[R1:%.+]] = "xla_hlo.dot"([[R0]], %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
// CHECK: [[R2:%.+]] = "xla_hlo.reshape"([[R1]]) : (tensor<1x3xf32>) -> tensor<1x1x3xf32>
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32>
// CHECK-DAG: [[R0:%.+]] = "mhlo.reshape"(%arg0) : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
// CHECK-DAG: [[R1:%.+]] = "mhlo.dot"([[R0]], %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
// CHECK: [[R2:%.+]] = "mhlo.reshape"([[R1]]) : (tensor<1x3xf32>) -> tensor<1x1x3xf32>
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32>
return %0 : tensor<1x1x3xf32>
}
@ -14,13 +14,13 @@ func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1
// CHECK-LABEL: @testDebatch2
func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<3x1x1xf32> {
// CHECK-DAG: [[R0:%.+]] = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
// CHECK-DAG: [[R1:%.+]] = "xla_hlo.transpose"(%arg1) {permutation = dense<[2, 0, 1]> : tensor<3xi64>} : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32>
// CHECK-DAG: [[R2:%.+]] = "xla_hlo.reshape"([[R1]]) : (tensor<2x1x1xf32>) -> tensor<2x1xf32>
// CHECK-DAG: [[R3:%.+]] = "xla_hlo.dot"([[R0]], [[R2]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32>
// CHECK: [[R4:%.+]] = "xla_hlo.reshape"([[R3]]) : (tensor<3x1xf32>) -> tensor<3x1x1xf32>
// CHECK-DAG: [[R0:%.+]] = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
// CHECK-DAG: [[R1:%.+]] = "mhlo.transpose"(%arg1) {permutation = dense<[2, 0, 1]> : tensor<3xi64>} : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32>
// CHECK-DAG: [[R2:%.+]] = "mhlo.reshape"([[R1]]) : (tensor<2x1x1xf32>) -> tensor<2x1xf32>
// CHECK-DAG: [[R3:%.+]] = "mhlo.dot"([[R0]], [[R2]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32>
// CHECK: [[R4:%.+]] = "mhlo.reshape"([[R3]]) : (tensor<3x1xf32>) -> tensor<3x1x1xf32>
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32>
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32>
return %0 : tensor<3x1x1xf32>
}
@ -28,8 +28,8 @@ func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<3
// CHECK-LABEL: @testBatchPassthrough
func @testBatchPassthrough(%arg0: tensor<2x2x3xf32>, %arg1: tensor<2x1x2xf32>) -> tensor<3x2x1xf32> {
// CHECK-NEXT: "xla_hlo.dot_general"(%arg0, %arg1)
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0]> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<[0]> : tensor<1xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<3x2x1xf32>
// CHECK-NEXT: "mhlo.dot_general"(%arg0, %arg1)
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0]> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<[0]> : tensor<1xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<3x2x1xf32>
return %0 : tensor<3x2x1xf32>
}

View File

@ -3,9 +3,9 @@
// CHECK-LABEL: @clampBroadcast
// CHECK-SAME: (%[[MIN:.+]]: tensor<f32>, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor<f32>)
func @clampBroadcast(%min: tensor<f32>, %value: tensor<4xf32>, %max: tensor<f32>) -> tensor<4xf32> {
// CHECK-DAG: %[[MIN_BC:.+]] = "xla_hlo.broadcast"(%[[MIN]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK-DAG: %[[MAX_BC:.+]] = "xla_hlo.broadcast"(%[[MAX]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK: "xla_hlo.clamp"(%[[MIN_BC]], %[[VAL]], %[[MAX_BC]]) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "xla_hlo.clamp"(%min, %value, %max) : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK-DAG: %[[MIN_BC:.+]] = "mhlo.broadcast"(%[[MIN]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK-DAG: %[[MAX_BC:.+]] = "mhlo.broadcast"(%[[MAX]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK: "mhlo.clamp"(%[[MIN_BC]], %[[VAL]], %[[MAX_BC]]) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "mhlo.clamp"(%min, %value, %max) : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

File diff suppressed because it is too large Load Diff

View File

@ -4,11 +4,11 @@
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x8xf32>)
// CHECK: return %[[ARG0]]
func @noop(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%2 = "xla_hlo.reduce"(%arg0, %0) ( {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%2 = "mhlo.reduce"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%4 = xla_hlo.add %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
%4 = mhlo.add %arg1, %arg2 : tensor<f32>
"mhlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
return %2 : tensor<4x8xf32>
}

View File

@ -2,9 +2,9 @@
// CHECK-LABEL: func @const_fold_collapse_to_scalar
func @const_fold_collapse_to_scalar() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
%cst = xla_hlo.constant dense<42> : tensor<1x1xi32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<1x1xi32>) -> tensor<i32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
%cst = mhlo.constant dense<42> : tensor<1x1xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<1x1xi32>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}
@ -13,9 +13,9 @@ func @const_fold_collapse_to_scalar() -> tensor<i32> {
// CHECK-LABEL: func @const_fold_collapse_to_tensor
func @const_fold_collapse_to_tensor() -> tensor<2xi32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<2xi32>
%cst = xla_hlo.constant dense<42> : tensor<1x2xi32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<1x2xi32>) -> tensor<2xi32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<2xi32>
%cst = mhlo.constant dense<42> : tensor<1x2xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<1x2xi32>) -> tensor<2xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2xi32>
}
@ -24,9 +24,9 @@ func @const_fold_collapse_to_tensor() -> tensor<2xi32> {
// CHECK-LABEL: func @const_fold_expand
func @const_fold_expand() -> tensor<1xi32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<1xi32>
%cst = xla_hlo.constant dense<42> : tensor<i32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<i32>) -> tensor<1xi32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<1xi32>
%cst = mhlo.constant dense<42> : tensor<i32>
%0 = "mhlo.reshape"(%cst) : (tensor<i32>) -> tensor<1xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<1xi32>
}
@ -35,9 +35,9 @@ func @const_fold_expand() -> tensor<1xi32> {
// CHECK-LABEL: func @const_fold_nontrivial
func @const_fold_nontrivial() -> tensor<16xi64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<16xi64>
%cst = xla_hlo.constant dense<42> : tensor<4x4xi64>
%0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<16xi64>
%cst = mhlo.constant dense<42> : tensor<4x4xi64>
%0 = "mhlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<16xi64>
}
@ -46,9 +46,9 @@ func @const_fold_nontrivial() -> tensor<16xi64> {
// CHECK-LABEL: func @const_fold_flatten
func @const_fold_flatten() -> tensor<16xi64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<16xi64>
%cst = xla_hlo.constant dense<42> : tensor<4x4xi64>
%0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<16xi64>
%cst = mhlo.constant dense<42> : tensor<4x4xi64>
%0 = "mhlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<16xi64>
}
@ -57,9 +57,9 @@ func @const_fold_flatten() -> tensor<16xi64> {
// CHECK-LABEL: func @const_fold_6
func @const_fold_6() -> tensor<6xi32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
%cst = xla_hlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
%cst = mhlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<6xi32>
}
@ -68,11 +68,11 @@ func @const_fold_6() -> tensor<6xi32> {
// CHECK-LABEL: func @const_fold_same_shape
func @const_fold_same_shape() -> tensor<2x3xi32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[
// CHECK-SAME: [1, 2, 3], [4, 5, 6]
// CHECK-SAME: ]> : tensor<2x3xi32>
%cst = xla_hlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32>
%cst = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2x3xi32>
}
@ -81,9 +81,9 @@ func @const_fold_same_shape() -> tensor<2x3xi32> {
// CHECK-LABEL: func @const_fold_float
func @const_fold_float() -> tensor<16xf64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor<16xf64>
%cst = xla_hlo.constant dense<4.2> : tensor<4x4xf64>
%0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+00> : tensor<16xf64>
%cst = mhlo.constant dense<4.2> : tensor<4x4xf64>
%0 = "mhlo.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<16xf64>
}
@ -94,7 +94,7 @@ func @const_fold_float() -> tensor<16xf64> {
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_same_shape(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK-NEXT: return [[ARG]]
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xi32>
return %0 : tensor<2x3xi32>
}
@ -103,10 +103,10 @@ func @non_const_same_shape(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK-LABEL: func @non_const_chained_reshape
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_chained_reshape(%arg : tensor<2x3xi32>) -> (tensor<3x2xi32>, tensor<6xi32>) {
// CHECK-NEXT: "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// CHECK-NEXT: "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
// CHECK-NEXT: "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// CHECK-NEXT: "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
return %0, %1 : tensor<3x2xi32>, tensor<6xi32> // return both so nothing is removed
}
@ -115,9 +115,9 @@ func @non_const_chained_reshape(%arg : tensor<2x3xi32>) -> (tensor<3x2xi32>, ten
// CHECK-LABEL: func @non_const_chained_reshape_unused_parent
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_chained_reshape_unused_parent(%arg : tensor<2x3xi32>) -> tensor<6xi32> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
// CHECK-NEXT: [[RES:%.+]] = "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
// CHECK-NEXT: return [[RES]]
return %1 : tensor<6xi32>
}
@ -127,8 +127,8 @@ func @non_const_chained_reshape_unused_parent(%arg : tensor<2x3xi32>) -> tensor<
// CHECK-LABEL: func @non_const_chained_reshape_becomes_noop
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_chained_reshape_becomes_noop(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<2x3xi32>
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<2x3xi32>
// CHECK-NEXT: return [[ARG]]
return %1 : tensor<2x3xi32>
}
@ -138,12 +138,12 @@ func @non_const_chained_reshape_becomes_noop(%arg : tensor<2x3xi32>) -> tensor<2
// CHECK-LABEL: func @non_const_many_chained_reshapes
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_many_chained_reshapes(%arg : tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.reshape"([[ARG]]) : (tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32>
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3x4xi32>) -> tensor<4x3x2xi32>
%1 = "xla_hlo.reshape"(%0) : (tensor<4x3x2xi32>) -> tensor<12x2xi32>
%2 = "xla_hlo.reshape"(%1) : (tensor<12x2xi32>) -> tensor<2x12xi32>
%3 = "xla_hlo.reshape"(%2) : (tensor<2x12xi32>) -> tensor<24xi32>
%4 = "xla_hlo.reshape"(%3) : (tensor<24xi32>) -> tensor<1x2x4x3xi32>
// CHECK-NEXT: [[RES:%.+]] = "mhlo.reshape"([[ARG]]) : (tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32>
%0 = "mhlo.reshape"(%arg) : (tensor<2x3x4xi32>) -> tensor<4x3x2xi32>
%1 = "mhlo.reshape"(%0) : (tensor<4x3x2xi32>) -> tensor<12x2xi32>
%2 = "mhlo.reshape"(%1) : (tensor<12x2xi32>) -> tensor<2x12xi32>
%3 = "mhlo.reshape"(%2) : (tensor<2x12xi32>) -> tensor<24xi32>
%4 = "mhlo.reshape"(%3) : (tensor<24xi32>) -> tensor<1x2x4x3xi32>
// CHECK-NEXT: return [[RES]]
return %4 : tensor<1x2x4x3xi32>
}

View File

@ -3,7 +3,7 @@
// CHECK-LABEL: func @noop
// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x2xf32>)
func @noop(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
%0 = "xla_hlo.reverse"(%arg0) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
%0 = "mhlo.reverse"(%arg0) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
// CHECK: return %[[ARG0]]
return %0 : tensor<1x2xf32>
}

View File

@ -4,27 +4,27 @@
// CHECK-LABEL: func @sink_const_to_while
func @sink_const_to_while(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-NEXT: xla_hlo.while
%c0 = xla_hlo.constant dense<1> : tensor<i64>
%c1 = xla_hlo.constant dense<2> : tensor<i64>
%0 = "xla_hlo.while"(%arg0) ( {
// CHECK-NEXT: mhlo.while
%c0 = mhlo.constant dense<1> : tensor<i64>
%c1 = mhlo.constant dense<2> : tensor<i64>
%0 = "mhlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
// CHECK: %[[ARG1A:.+]]: tensor<i64>
// CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor<i64>
// CHECK: "xla_hlo.compare"(%[[C0]], %[[ARG1A]])
%1 = "xla_hlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
// CHECK: %[[C0:.+]] = mhlo.constant dense<1> : tensor<i64>
// CHECK: "mhlo.compare"(%[[C0]], %[[ARG1A]])
%1 = "mhlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
// CHECK: %[[ARG1B:.+]]: tensor<i64>
// CHECK-DAG: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor<i64>
// CHECK-DAG: %[[ADD0:.+]] = xla_hlo.add %[[ARG1B]], %[[ARG1B]]
%2 = xla_hlo.add %arg1, %arg1 : tensor<i64>
// CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], %[[ADD0]]
%3 = xla_hlo.add %c1, %2 : tensor<i64>
// CHECK: %[[ADD2:.+]] = xla_hlo.add %[[C1]], %[[ADD1]]
%4 = xla_hlo.add %c1, %3 : tensor<i64>
"xla_hlo.return"(%4) : (tensor<i64>) -> ()
// CHECK-DAG: %[[C1:.+]] = mhlo.constant dense<2> : tensor<i64>
// CHECK-DAG: %[[ADD0:.+]] = mhlo.add %[[ARG1B]], %[[ARG1B]]
%2 = mhlo.add %arg1, %arg1 : tensor<i64>
// CHECK: %[[ADD1:.+]] = mhlo.add %[[C1]], %[[ADD0]]
%3 = mhlo.add %c1, %2 : tensor<i64>
// CHECK: %[[ADD2:.+]] = mhlo.add %[[C1]], %[[ADD1]]
%4 = mhlo.add %c1, %3 : tensor<i64>
"mhlo.return"(%4) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %0 : tensor<i64>
}
@ -33,28 +33,28 @@ func @sink_const_to_while(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-LABEL: func @sink_const_to_conditional
func @sink_const_to_conditional(%arg0: tensor<i64>) -> tensor<i64> {
%c0 = xla_hlo.constant dense<1> : tensor<i64>
%c1 = xla_hlo.constant dense<2> : tensor<i64>
%0 = "xla_hlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
%1 = "xla_hlo.tuple"(%arg0) : (tensor<i64>) -> tuple<tensor<i64>>
// CHECK: xla_hlo.if
%2 = "xla_hlo.if"(%0, %1, %1) ( {
%c0 = mhlo.constant dense<1> : tensor<i64>
%c1 = mhlo.constant dense<2> : tensor<i64>
%0 = "mhlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
%1 = "mhlo.tuple"(%arg0) : (tensor<i64>) -> tuple<tensor<i64>>
// CHECK: mhlo.if
%2 = "mhlo.if"(%0, %1, %1) ( {
^bb0(%arg1: tuple<tensor<i64>>):
// CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor<i64>
%3 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD0:.+]] = xla_hlo.add %[[C0]],
%4 = xla_hlo.add %c0, %3 : tensor<i64>
%5 = "xla_hlo.tuple"(%4) : (tensor<i64>) -> tuple<tensor<i64>>
"xla_hlo.return"(%5) : (tuple<tensor<i64>>) -> ()
// CHECK: %[[C0:.+]] = mhlo.constant dense<1> : tensor<i64>
%3 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD0:.+]] = mhlo.add %[[C0]],
%4 = mhlo.add %c0, %3 : tensor<i64>
%5 = "mhlo.tuple"(%4) : (tensor<i64>) -> tuple<tensor<i64>>
"mhlo.return"(%5) : (tuple<tensor<i64>>) -> ()
}, {
^bb0(%arg1: tuple<tensor<i64>>):
// CHECK: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor<i64>
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]],
%7 = xla_hlo.add %c1, %6 : tensor<i64>
%8 = "xla_hlo.tuple"(%7) : (tensor<i64>) -> tuple<tensor<i64>>
"xla_hlo.return"(%8) : (tuple<tensor<i64>>) -> ()
// CHECK: %[[C1:.+]] = mhlo.constant dense<2> : tensor<i64>
%6 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD1:.+]] = mhlo.add %[[C1]],
%7 = mhlo.add %c1, %6 : tensor<i64>
%8 = "mhlo.tuple"(%7) : (tensor<i64>) -> tuple<tensor<i64>>
"mhlo.return"(%8) : (tuple<tensor<i64>>) -> ()
}) : (tensor<i1>, tuple<tensor<i64>>, tuple<tensor<i64>>) -> tuple<tensor<i64>>
%9 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
%9 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
return %9 : tensor<i64>
}

View File

@ -3,7 +3,7 @@
// CHECK-LABEL: func @remove_noop
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> {
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32>
%0 = "mhlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32>
// CHECK-NEXT: return [[ARG]]
return %0 : tensor<2x3x9x5xi32>
}
@ -13,8 +13,8 @@ func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> {
// CHECK-LABEL: func @keep_real_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// CHECK-NEXT: "xla_hlo.transpose"([[ARG]])
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
// CHECK-NEXT: "mhlo.transpose"([[ARG]])
%0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
return %0 : tensor<3x2x5x9xi32>
}
@ -23,7 +23,7 @@ func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// CHECK-LABEL: func @keep_same_shape_real_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @keep_same_shape_real_transpose(%arg : tensor<4x4xi32>) -> tensor<4x4xi32> {
// CHECK-NEXT: "xla_hlo.transpose"([[ARG]])
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32>
// CHECK-NEXT: "mhlo.transpose"([[ARG]])
%0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32>
return %0 : tensor<4x4xi32>
}

View File

@ -4,7 +4,7 @@
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @fold_access(%arg : tensor<i32>) -> tensor<i32> {
// CHECK-NEXT: return [[ARG]]
%tuple = "xla_hlo.tuple"(%arg) : (tensor<i32>) -> tuple<tensor<i32>>
%element = "xla_hlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
%tuple = "mhlo.tuple"(%arg) : (tensor<i32>) -> tuple<tensor<i32>>
%element = "mhlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
return %element : tensor<i32>
}

View File

@ -10,19 +10,19 @@ func @batchNormInference_2D_inner_features(
%x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>,
%mean: tensor<256xf32>, %variance: tensor<256xf32>)
-> (tensor<4x256xf32>) {
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.001000e-05> : tensor<f32>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<256xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32>
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.001000e-05> : tensor<f32>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<256xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32>
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} :
(tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>,
tensor<256xf32>) -> tensor<4x256xf32>
@ -36,12 +36,12 @@ func @batchNormInference_2D_inner_features(
// the verifier to enforce the rest.
// CHECK-SAME: %[[X:[^:]+]]
// CHECK-SAME: %[[SCALE:[^:]+]]
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32>
func @batchNormInference_4D_middle_features(
%x: tensor<3x4x256x6xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>,
%mean: tensor<256xf32>, %variance: tensor<256xf32>)
-> (tensor<3x4x256x6xf32>) {
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.001000e-05 : f32, feature_index = 2 : i64} :
(tensor<3x4x256x6xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>,
tensor<256xf32>) -> tensor<3x4x256x6xf32>
@ -51,12 +51,12 @@ func @batchNormInference_4D_middle_features(
// -----
// CHECK-LABEL: @batchNormInference_f64
// Validate that epsilon is properly promoted to f64
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<f64>
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<f64>
func @batchNormInference_f64(
%x: tensor<4x256xf64>, %scale: tensor<256xf64>, %offset: tensor<256xf64>,
%mean: tensor<256xf64>, %variance: tensor<256xf64>)
-> (tensor<4x256xf64>) {
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.0 : f32, feature_index = 1 : i64} :
(tensor<4x256xf64>, tensor<256xf64>, tensor<256xf64>, tensor<256xf64>,
tensor<256xf64>) -> tensor<4x256xf64>
@ -66,12 +66,12 @@ func @batchNormInference_f64(
// -----
// CHECK-LABEL: @batchNormInference_f16
// Validate that epsilon is properly promoted to f64
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<f16>
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<f16>
func @batchNormInference_f16(
%x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>,
%mean: tensor<256xf16>, %variance: tensor<256xf16>)
-> (tensor<4x256xf16>) {
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.0 : f32, feature_index = 1 : i64} :
(tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>,
tensor<256xf16>) -> tensor<4x256xf16>
@ -85,7 +85,7 @@ func @batchNormInference_f16_overflow(
%mean: tensor<256xf16>, %variance: tensor<256xf16>)
-> (tensor<4x256xf16>) {
// expected-warning @+1 {{Could not convert batch_norm epsilon to target fp type: opStatus = 24}}
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 0.00000001 : f32, feature_index = 1 : i64} :
(tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>,
tensor<256xf16>) -> tensor<4x256xf16>
@ -108,26 +108,26 @@ func @batchNormInference_dynamic_shape(
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[C3:.*]] = constant 3 : index
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
// CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements(%[[DIM]]) : tensor<1xindex>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : tensor<4xindex>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<?x?x?x?xf32>
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<?x?x?x?xf32>
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 0.001 : f32, feature_index = 1 : i64} :
(tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>,
tensor<?xf32>) -> tensor<?x?x?x?xf32>

View File

@ -2,14 +2,14 @@
// CHECK-LABEL: func @multi_outputs_same
func @multi_outputs_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "xla_hlo.add"(%1, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:2 = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.subtract
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.return
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "mhlo.add"(%1, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:2 = "mhlo.fusion"
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.subtract
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.return
return %1, %2 : tensor<?x?xf32>, tensor<?x?xf32>
}
@ -17,18 +17,18 @@ func @multi_outputs_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (ten
// CHECK-LABEL: func @multi_outputs_same_2
func @multi_outputs_same_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "xla_hlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.abs"(%arg1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "xla_hlo.add"(%0, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%3 = "xla_hlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%4 = "xla_hlo.abs"(%1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:3 = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.return
%0 = "mhlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.abs"(%arg1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "mhlo.add"(%0, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%3 = "mhlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%4 = "mhlo.abs"(%1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:3 = "mhlo.fusion"
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.return
return %2, %3, %4 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
}
@ -36,9 +36,9 @@ func @multi_outputs_same_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (t
// CHECK-LABEL: func @multi_outputs_not_sure_same
func @multi_outputs_not_sure_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: xla_hlo.fusion
%1 = "xla_hlo.subtract"(%arg1, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%0 = "mhlo.add"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: mhlo.fusion
%1 = "mhlo.subtract"(%arg1, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
}
@ -46,25 +46,25 @@ func @multi_outputs_not_sure_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>
// CHECK-LABEL: func @reduce
func @reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET0:.*]] = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.subtract
// CHECK-NEXT: xla_hlo.return
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET0:.*]] = "mhlo.fusion"
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.subtract
// CHECK-NEXT: mhlo.return
// Currently we do not support fuse arguments and ops without direct producer-consumer
// relationship. Thus Reduce Op should not be fused with above two ops.
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "xla_hlo.reduce"(%arg0, %2) ( {
%2 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "mhlo.reduce"(%arg0, %2) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%4 = "xla_hlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
%4 = "mhlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"mhlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
%4 = "xla_hlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%4 = "mhlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// Above two ops should not be fused since reduce op can not be
// fused with its consumer.
// CHECK-NOT: xla_hlo.fusion
// CHECK-NOT: mhlo.fusion
return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
}
@ -73,25 +73,25 @@ func @reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>
// CHECK-LABEL: func @reduce_2
func @reduce_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "xla_hlo.reduce"(%1, %2) ( {
%2 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "mhlo.reduce"(%1, %2) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%4 = "xla_hlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
%4 = "mhlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"mhlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
// CHECK: %[[RET0:.*]]:2 = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.subtract
// CHECK-NEXT: xla_hlo.constant
// CHECK-NEXT: xla_hlo.reduce
// CHECK: xla_hlo.return
// CHECK: %[[RET0:.*]]:2 = "mhlo.fusion"
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.subtract
// CHECK-NEXT: mhlo.constant
// CHECK-NEXT: mhlo.reduce
// CHECK: mhlo.return
// Following op should not be fused with the above ops since reduce op can not be
// fused with its consumer.
// CHECK-NOT: xla_hlo.fusion
%4 = "xla_hlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NOT: mhlo.fusion
%4 = "mhlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
}

View File

@ -9,15 +9,15 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
%num_elements = shape.num_elements %shape
%num_elements_as_index = shape.size_to_index %num_elements
%flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex>
%flat_a = "xla_hlo.dynamic_reshape"(%a, %flat_shape)
%flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape)
: (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// Apply operation.
%flat_b = "xla_hlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
%flat_b = "mhlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
// Restore original shape.
%shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor<?xindex>
%b = "xla_hlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
%b = "mhlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
: (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
return %b : tensor<*xf32>
@ -33,12 +33,12 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
// CHECK-NEXT: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
// CHECK-NEXT: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLAT_B:.*]] = "xla_hlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: return %[[B]] : tensor<*xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
%b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
return %b : tensor<*xf32>
}
@ -48,9 +48,9 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-LABEL: @sqrt_ranked
// CHECK-SAME: (%[[A:.*]]: tensor<3x?xf32>)
func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> {
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32>
// CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32>
// CHECK-NEXT: return %[[B]] : tensor<3x?xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32>
%b = "mhlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32>
return %b : tensor<3x?xf32>
}
@ -60,9 +60,9 @@ func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> {
// CHECK-LABEL: @sqrt_static
// CHECK-SAME: (%[[A:.*]]: tensor<2x3xf32>)
func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: return %[[B]] : tensor<2x3xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%b = "mhlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %b : tensor<2x3xf32>
}
@ -77,12 +77,12 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
// CHECK: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
// CHECK: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_B:.*]] = "xla_hlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_RESULT:.*]] = xla_hlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
// CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESULT]] : tensor<*xf32>
%result = xla_hlo.add %a, %b : tensor<*xf32>
%result = mhlo.add %a, %b : tensor<*xf32>
return %result : tensor<*xf32>
}

View File

@ -19,7 +19,7 @@ module attributes {tf.versions = {producer = 888 : i32}} {
// CHECK-LABEL: func @_func
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xi32>,
// CHECK-SAME: %[[ARG1:.*]]: tensor<?xi32> {xla_hlo.is_same_data_across_replicas}
// CHECK-SAME: %[[ARG1:.*]]: tensor<?xi32> {mhlo.is_same_data_across_replicas}
// CHECK-SAME: %[[ARG2:.*]]: tensor<?xi32>)
func @_func(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>, %arg2: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf._D"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
@ -54,9 +54,9 @@ module attributes {tf.versions = {producer = 888 : i32}} {
}
// CHECK-LABEL: func @_func
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xi32> {xla_hlo.is_same_data_across_replicas},
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xi32> {mhlo.is_same_data_across_replicas},
// CHECK-SAME: %[[ARG1:.*]]: tensor<?xi32>,
// CHECK-SAME: %[[ARG2:.*]]: tensor<!tf.resource<tensor<?xi32>>> {xla_hlo.is_same_data_across_replicas}
// CHECK-SAME: %[[ARG2:.*]]: tensor<!tf.resource<tensor<?xi32>>> {mhlo.is_same_data_across_replicas}
func @_func(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>, %arg2: tensor<!tf.resource<tensor<?xi32>>>) -> tensor<?xi32> {
%0 = "tf._D"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
@ -78,7 +78,7 @@ module attributes {tf.versions = {producer = 888 : i32}} {
}
// CHECK-LABEL: func @_func
// CHECK-NOT: xla_hlo.is_same_data_across_replicas
// CHECK-NOT: mhlo.is_same_data_across_replicas
func @_func(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf._D"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>

View File

@ -17,8 +17,8 @@ func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tenso
}
func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = xla_hlo.add %arg0, %arg0 : tensor<2xi32>
%1 = xla_hlo.add %0, %arg0 : tensor<2xi32>
%0 = mhlo.add %arg0, %arg0 : tensor<2xi32>
%1 = mhlo.add %0, %arg0 : tensor<2xi32>
return %1 : tensor<2xi32>
}
@ -33,7 +33,7 @@ func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi3
}
func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32>
%0 = mhlo.divide %arg0, %arg0 : tensor<2xi32>
return %0 : tensor<2xi32>
}
@ -43,7 +43,7 @@ func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
}
func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
%0 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
%0 = mhlo.shift_left %arg0, %arg1 : tensor<4xi32>
return %0 : tensor<4xi32>
}
@ -53,17 +53,17 @@ func @div_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3
}
func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%0 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
%0 = mhlo.maximum %arg0, %arg1 : tensor<4xf32>
return %0 : tensor<4xf32>
}
func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%0 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32>
%0 = mhlo.minimum %arg0, %arg1 : tensor<4xf32>
return %0 : tensor<4xf32>
}
func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32>
%0 = mhlo.multiply %arg0, %arg0 : tensor<2xi32>
return %0 : tensor<2xi32>
}
@ -73,7 +73,7 @@ func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
}
func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32>
%0 = mhlo.divide %arg0, %arg0 : tensor<2xi32>
return %0 : tensor<2xi32>
}
@ -83,7 +83,7 @@ func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor
}
func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32>
%0 = mhlo.subtract %arg0, %arg0 : tensor<2xi32>
return %0 : tensor<2xi32>
}
@ -93,7 +93,7 @@ func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
}
func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
%0 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
%0 = mhlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
return %0 : tensor<4xi32>
}
@ -103,7 +103,7 @@ func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> ten
}
func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> {
%0 = xla_hlo.and %arg0, %arg0 : tensor<2xi1>
%0 = mhlo.and %arg0, %arg0 : tensor<2xi1>
return %0 : tensor<2xi1>
}
@ -118,7 +118,7 @@ func @and_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
}
func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> {
%0 = xla_hlo.or %arg0, %arg0 : tensor<2xi1>
%0 = mhlo.or %arg0, %arg0 : tensor<2xi1>
return %0 : tensor<2xi1>
}
@ -133,7 +133,7 @@ func @or_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
}
func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
%0 = xla_hlo.or %arg0, %arg1 : tensor<4xi32>
%0 = mhlo.or %arg0, %arg1 : tensor<4xi32>
return %0 : tensor<4xi32>
}
@ -148,7 +148,7 @@ func @bitwise_or_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?
}
func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
%0 = xla_hlo.and %arg0, %arg1 : tensor<4xi32>
%0 = mhlo.and %arg0, %arg1 : tensor<4xi32>
return %0 : tensor<4xi32>
}
@ -163,69 +163,69 @@ func @bitwise_and_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<
}
func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = xla_hlo.power %arg0, %arg0 : tensor<2xf32>
%0 = mhlo.power %arg0, %arg0 : tensor<2xf32>
return %0 : tensor<2xf32>
}
func @pow_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = xla_hlo.power %arg0, %arg0 : tensor<?xf32>
%0 = mhlo.power %arg0, %arg0 : tensor<?xf32>
return %0 : tensor<?xf32>
}
func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
%0 = xla_hlo.constant dense<0> : tensor<2x3xi32>
%0 = mhlo.constant dense<0> : tensor<2x3xi32>
%1 = "xla_chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
%2 = xla_hlo.constant dense<0> : tensor<3xi32>
%2 = mhlo.constant dense<0> : tensor<3xi32>
%3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
%4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
%5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%6 = "xla_hlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%7 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
%8 = xla_hlo.constant dense<1> : tensor<3xi32>
%9 = xla_hlo.subtract %7, %8 : tensor<3xi32>
%6 = "mhlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%7 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
%8 = mhlo.constant dense<1> : tensor<3xi32>
%9 = mhlo.subtract %7, %8 : tensor<3xi32>
%10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
%11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
%13 = "xla_chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%14 = "mhlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
return %14 : tensor<2x3xi32>
}
func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
%0 = xla_hlo.constant dense<0> : tensor<3xi32>
%1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
%2 = xla_hlo.constant dense<0> : tensor<2x3xi32>
%0 = mhlo.constant dense<0> : tensor<3xi32>
%1 = "mhlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
%2 = mhlo.constant dense<0> : tensor<2x3xi32>
%3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
%4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
%5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%6 = "xla_hlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
%7 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%8 = xla_hlo.constant dense<1> : tensor<2x3xi32>
%9 = xla_hlo.subtract %7, %8 : tensor<2x3xi32>
%6 = "mhlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
%7 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%8 = mhlo.constant dense<1> : tensor<2x3xi32>
%9 = mhlo.subtract %7, %8 : tensor<2x3xi32>
%10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%13 = xla_hlo.divide %11, %12 : tensor<2x3xi32>
%14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%13 = mhlo.divide %11, %12 : tensor<2x3xi32>
%14 = "mhlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
return %14 : tensor<2x3xi32>
}
func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = xla_hlo.divide %arg0, %arg0 : tensor<2xf32>
%1 = xla_hlo.divide %arg0, %arg0 : tensor<2xf32>
%2 = "xla_hlo.floor"(%1) : (tensor<2xf32>) -> tensor<2xf32>
%0 = mhlo.divide %arg0, %arg0 : tensor<2xf32>
%1 = mhlo.divide %arg0, %arg0 : tensor<2xf32>
%2 = "mhlo.floor"(%1) : (tensor<2xf32>) -> tensor<2xf32>
return %2 : tensor<2xf32>
}
func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> {
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
%1 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
%2 = "xla_hlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16>
%2 = "mhlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16>
return %2 : tensor<2x3xf16>
}
func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0 : tensor<2xi1>
}
@ -250,7 +250,7 @@ func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor
}
func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0 : tensor<2xi1>
}
@ -270,7 +270,7 @@ func @notequal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: ten
}
func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0 : tensor<2xi1>
}
@ -280,7 +280,7 @@ func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<
}
func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0 : tensor<2xi1>
}
@ -290,7 +290,7 @@ func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> t
}
func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0 : tensor<2xi1>
}
@ -300,7 +300,7 @@ func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2
}
func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0 : tensor<2xi1>
}
@ -310,426 +310,426 @@ func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tens
}
func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> {
%2 = "xla_hlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>
%2 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>
return %2 : tensor<6x3xf32>
}
func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> {
%2 = "xla_hlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32>
%2 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32>
return %2 : tensor<3x6xf32>
}
func @const() -> tensor<2xi32> {
%0 = xla_hlo.constant dense<0> : tensor<2xi32>
%0 = mhlo.constant dense<0> : tensor<2xi32>
return %0 : tensor<2xi32>
}
func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
%0 = xla_hlo.constant dense<0> : tensor<i32>
%0 = mhlo.constant dense<0> : tensor<i32>
%1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
return %1 : tensor<1xi32>
}
func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = xla_hlo.constant dense<0> : tensor<i32>
%0 = mhlo.constant dense<0> : tensor<i32>
%1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
return %1 : tensor<?xi32>
}
func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = xla_hlo.constant dense<6> : tensor<i32>
%0 = mhlo.constant dense<0> : tensor<i32>
%1 = mhlo.constant dense<6> : tensor<i32>
%2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
%3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
return %3 : tensor<1xi32>
}
func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = xla_hlo.constant dense<6> : tensor<i32>
%0 = mhlo.constant dense<0> : tensor<i32>
%1 = mhlo.constant dense<6> : tensor<i32>
%2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
%3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
return %3 : tensor<?xi32>
}
func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor<?x?xf32>) -> tensor<4x8xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "xla_chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32>
%3 = "xla_hlo.select"(%1, %arg0, %2) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
%2 = mhlo.constant dense<0.000000e+00> : tensor<4x8xf32>
%3 = "mhlo.select"(%1, %arg0, %2) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
return %3 : tensor<4x8xf32>
}
func @select(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}
func @select_float(%arg0: tensor<2xi1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @select_multidimensional(%arg0: tensor<3x2xi1>, %arg1: tensor<3x2xi32>, %arg2: tensor<3x2xi32>) -> tensor<3x2xi32> {
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
%0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
return %0 : tensor<3x2xi32>
}
func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}
func @selectv2_pred_scalar(%arg0: tensor<i1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}
func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
%0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
%1 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
%2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
%0 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
%1 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
%2 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
return %2 : tensor<3x2xf32>
}
func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
%0 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi32>
%1 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi64>
%2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
%0 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi32>
%1 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64>
%2 = "mhlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
return %2 : tensor<3x2x1xf32>
}
func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
%0 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi64>
%1 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi64>
%2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
%0 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64>
%1 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64>
%2 = "mhlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
return %2 : tensor<3x2x1xf32>
}
func @transpose_dynamic_2d(%arg0: tensor<?x4xf32>) -> tensor<4x?xf32> {
%0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
%1 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
%2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<?x4xf32>) -> tensor<4x?xf32>
%0 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
%1 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
%2 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<?x4xf32>) -> tensor<4x?xf32>
return %2 : tensor<4x?xf32>
}
func @transpose_unranked_2d(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
%1 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
%2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32>
%0 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
%1 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
%2 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32>
return %2 : tensor<*xf32>
}
func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
func @ceil_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @complex_abs(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
%0 = "xla_hlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
%0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.cosine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.cosine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.exponential"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.exponential"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
func @floor_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> {
%0 = "xla_hlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1>
%0 = "mhlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1>
return %0 : tensor<2xi1>
}
func @is_finite_dynamic(%arg0: tensor<?xf32>) -> tensor<?xi1> {
%0 = "xla_hlo.is_finite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1>
%0 = "mhlo.is_finite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1>
return %0 : tensor<?xi1>
}
func @is_finite_unranked(%arg0: tensor<*xf32>) -> tensor<*xi1> {
%0 = "xla_hlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1>
%0 = "mhlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1>
return %0 : tensor<*xi1>
}
func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @log_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
func @log_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @log1p_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.log_plus_one"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.log_plus_one"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.negate"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.negate"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
%1 = xla_hlo.constant dense<2> : tensor<1xi64>
%2 = xla_hlo.constant dense<5.000000e-01> : tensor<2xf32>
%3 = xla_hlo.multiply %arg0, %2 : tensor<2xf32>
%4 = "xla_hlo.tanh"(%3) : (tensor<2xf32>) -> tensor<2xf32>
%5 = xla_hlo.multiply %4, %2 : tensor<2xf32>
%6 = xla_hlo.add %5, %2 : tensor<2xf32>
%0 = mhlo.constant dense<5.000000e-01> : tensor<f32>
%1 = mhlo.constant dense<2> : tensor<1xi64>
%2 = mhlo.constant dense<5.000000e-01> : tensor<2xf32>
%3 = mhlo.multiply %arg0, %2 : tensor<2xf32>
%4 = "mhlo.tanh"(%3) : (tensor<2xf32>) -> tensor<2xf32>
%5 = mhlo.multiply %4, %2 : tensor<2xf32>
%6 = mhlo.add %5, %2 : tensor<2xf32>
return %6 : tensor<2xf32>
}
func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @sin_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.sine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.sine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
func @rsqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @sqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
func @sqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @bitcast_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.bitcast_convert"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
func @bitcast_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.bitcast_convert"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> {
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32>
%0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}
func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
%1 = xla_hlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
%2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
%3 = xla_hlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
%4 = "xla_hlo.sign"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
%5 = "xla_hlo.select"(%2, %3, %4) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
%6 = "xla_hlo.select"(%0, %1, %5) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
%1 = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
%2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
%3 = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
%4 = "mhlo.sign"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
%5 = "mhlo.select"(%2, %3, %4) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
%6 = "mhlo.select"(%0, %1, %5) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
return %6 : tensor<1x2x3x4xf32>
}
func @size_rank_one_i32(%arg0: tensor<f32>) -> tensor<i32> {
%0 = xla_hlo.constant dense<1> : tensor<i32>
%0 = mhlo.constant dense<1> : tensor<i32>
return %0 : tensor<i32>
}
func @size_rank_one_i64(%arg0: tensor<f32>) -> tensor<i64> {
%0 = xla_hlo.constant dense<1> : tensor<i64>
%0 = mhlo.constant dense<1> : tensor<i64>
return %0 : tensor<i64>
}
func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex<f32>> {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
return %0 : tensor<3xcomplex<f32>>
}
func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> {
%0 = "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
%0 = "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
func @convert_slice(%arg0: tensor<1x4672xf32>) -> tensor<1x519xf32> {
%0 = "xla_hlo.slice"(%arg0) {limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x4672xf32>) -> tensor<1x519xf32>
%0 = "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x4672xf32>) -> tensor<1x519xf32>
return %0 : tensor<1x519xf32>
}
func @reshape(%arg0: tensor<4x6xf32>) -> tensor<2x2x6xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<4x6xf32>) -> tensor<2x2x6xf32>
%0 = "mhlo.reshape"(%arg0) : (tensor<4x6xf32>) -> tensor<2x2x6xf32>
return %0 : tensor<2x2x6xf32>
}
func @convert_dot_1d_2d(%arg0: tensor<256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1xf32> {
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256x1xf32>) -> tensor<1xf32>
%0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256x1xf32>) -> tensor<1xf32>
return %0 : tensor<1xf32>
}
func @convert_dot_2d_1d(%arg0: tensor<1x256xf32>, %arg1: tensor<256xf32>) -> tensor<1xf32> {
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256xf32>) -> tensor<1xf32>
%0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256xf32>) -> tensor<1xf32>
return %0 : tensor<1xf32>
}
func @convert_dot_1d_1d(%arg0: tensor<256xf32>, %arg1: tensor<256xf32>) -> tensor<f32> {
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32>
%0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32>
return %0 : tensor<f32>
}
func @convert_dot_2d_2d(%arg0: tensor<1x256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1x1xf32> {
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32>
%0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32>
return %0 : tensor<1x1xf32>
}
func @broadcast_in_dim_tf_style(%arg0: tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32>
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32>
return %0 : tensor<3x8x8x16xf32>
}
func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32>
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32>
return %0 : tensor<3x8x8x16xf32>
}
func @convert_dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> {
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[1, 2]> : tensor<2xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<[1, 3]> : tensor<2xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2x6x5x1xf32>, tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32>
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[1, 2]> : tensor<2xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<[1, 3]> : tensor<2xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2x6x5x1xf32>, tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32>
return %0 : tensor<3x5x1x4xf32>
}
func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
%0 = "xla_hlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
%0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
{input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
(tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
@ -737,7 +737,7 @@ func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>
}
func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
%0 = "xla_hlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
%0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
{input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
feature_group_count = 207 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
(tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
@ -745,7 +745,7 @@ func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x2
}
func @convert_conv2d_valid_padding(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
%0 = "xla_hlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
%0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
{input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
(tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
@ -753,22 +753,22 @@ func @convert_conv2d_valid_padding(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3
}
func @convert_reduce_to_sum(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "xla_hlo.reduce"(%arg0, %0) ( {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "mhlo.reduce"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%2 = xla_hlo.add %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
%2 = mhlo.add %arg1, %arg2 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
return %1 : tensor<1xf32>
}
func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
// "0xFF800000" represents -INF for f32.
%0 = xla_hlo.constant dense<0xFF800000> : tensor<f32>
%1 = "xla_hlo.reduce"(%arg0, %0) ( {
%0 = mhlo.constant dense<0xFF800000> : tensor<f32>
%1 = "mhlo.reduce"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%2 = xla_hlo.maximum %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
%2 = mhlo.maximum %arg1, %arg2 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
return %1 : tensor<1xf32>
}
@ -776,11 +776,11 @@ func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
func @convert_reduce_to_min(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
// "0x7F800000" represents INF for f32.
%0 = xla_hlo.constant dense<0x7F800000> : tensor<f32>
%1 = "xla_hlo.reduce"(%arg0, %0) ( {
%0 = mhlo.constant dense<0x7F800000> : tensor<f32>
%1 = "mhlo.reduce"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%2 = xla_hlo.minimum %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
%2 = mhlo.minimum %arg1, %arg2 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
return %1 : tensor<1xf32>
}

View File

@ -17,7 +17,7 @@ func @single_arg_single_shape(%arg0: tensor<i1>) {
}
// CHECK-LABEL: func @func0
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [1 : i32], shape_indices = [2 : i32]}}, %{{[a-z0-9]+}}: tensor<i1>)
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1> {mhlo.padding_map = {padding_arg_indices = [1 : i32], shape_indices = [2 : i32]}}, %{{[a-z0-9]+}}: tensor<i1>)
func @func0(%arg0: tensor<i1>, %arg1: tensor<i1>) {
return
}
@ -44,7 +44,7 @@ func @single_arg_multiple_shapes(%arg0: tensor<i1>) {
}
// CHECK-LABEL: func @func1
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [1 : i32, 2 : i32], shape_indices = [2 : i32, 3 : i32]}}, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>)
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1> {mhlo.padding_map = {padding_arg_indices = [1 : i32, 2 : i32], shape_indices = [2 : i32, 3 : i32]}}, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>)
func @func1(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
return
}
@ -76,7 +76,7 @@ func @multiple_args(%arg0: tensor<i1>) {
}
// CHECK-LABEL: func @func2
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [1 : i32, 2 : i32], shape_indices = [2 : i32, 3 : i32]}}, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [3 : i32], shape_indices = [1 : i32]}})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1> {mhlo.padding_map = {padding_arg_indices = [1 : i32, 2 : i32], shape_indices = [2 : i32, 3 : i32]}}, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {mhlo.padding_map = {padding_arg_indices = [3 : i32], shape_indices = [1 : i32]}})
func @func2(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>, %arg3: tensor<i1>, %arg4: tensor<i1>) {
return
}
@ -97,7 +97,7 @@ func @remap_indices(%arg0: tensor<i1>) {
}
// CHECK-LABEL: func @func3
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [0 : i32], shape_indices = [2 : i32]}})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {mhlo.padding_map = {padding_arg_indices = [0 : i32], shape_indices = [2 : i32]}})
func @func3(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
return
}
@ -196,7 +196,7 @@ func @missing_padding_arg(%arg0: tensor<i1>) {
}
// CHECK-LABEL: func @func8
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [2 : i32], shape_indices = [2 : i32]}}, %{{[a-z0-9]+}}: tensor<i1>)
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {mhlo.padding_map = {padding_arg_indices = [2 : i32], shape_indices = [2 : i32]}}, %{{[a-z0-9]+}}: tensor<i1>)
func @func8(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
return
}
@ -218,7 +218,7 @@ func @missing_replicated_input_indices(%arg0: tensor<i1>) {
}
// CHECK-LABEL: func @func9
// CHECK-NOT: xla_hlo.padding_map
// CHECK-NOT: mhlo.padding_map
func @func9(%arg0: tensor<i1>, %arg1: tensor<i1>) {
return
}
@ -240,7 +240,7 @@ func @non_contigous_indices(%arg0: tensor<i1>) {
}
// CHECK-LABEL: func @func10
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [0 : i32], shape_indices = [6 : i32]}})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {mhlo.padding_map = {padding_arg_indices = [0 : i32], shape_indices = [6 : i32]}})
func @func10(%arg0: tensor<i1>, %arg1: tensor<i1>) {
return
}

View File

@ -30,8 +30,8 @@ func @check_default_sharding_for_block_arg_inputs_outputs(%arg0: tensor<*xi32>)
}
// CHECK-LABEL: func @func_without_sharding
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"})
// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"})
func @func_without_sharding(%arg0: tensor<*xi32>) -> tensor<*xi32> {
return %arg0 : tensor<*xi32>
}
@ -51,8 +51,8 @@ func @check_default_sharding_for_inputs_outputs(%arg0: tensor<*xi32>) {
}
// CHECK-LABEL: func @func_without_sharding
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"})
// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"})
func @func_without_sharding(%arg0: tensor<*xi32>) -> tensor<*xi32> {
%0 = "tf.A"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
return %0 : tensor<*xi32>
@ -72,8 +72,8 @@ func @check_sharding_for_input_correctly_identified(%arg0: tensor<*xi32>) {
}
// CHECK-LABEL: func @inputs_with_sharding_func
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"})
// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"})
func @inputs_with_sharding_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
%0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.A"(%0) : (tensor<*xi32>) -> (tensor<*xi32>)
@ -94,8 +94,8 @@ return
}
// CHECK-LABEL: func @func_with_sharding
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {xla_hlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {mhlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"})
func @func_with_sharding(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
%0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
@ -119,8 +119,8 @@ func @check_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
}
// CHECK-LABEL: func @func_with_sharding_after_identity
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {xla_hlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {mhlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"})
func @func_with_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
%0 = "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
@ -145,8 +145,8 @@ func @check_sharding_after_read_variable(%arg0: tensor<*xi32>, %arg1: tensor<*xi
}
// CHECK-LABEL: func @func_with_sharding_after_read_variable
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*x!tf.resource<tensor<32xf32>>> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*x!tf.resource<tensor<32xf32>>> {xla_hlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*x!tf.resource<tensor<32xf32>>> {mhlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*x!tf.resource<tensor<32xf32>>> {mhlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"})
func @func_with_sharding_after_read_variable(%arg0: tensor<*x!tf.resource<tensor<32xf32>>>, %arg1: tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<*xi32>, tensor<*xi1>) {
%0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
%1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<32xf32>) -> tensor<32xf32>
@ -173,8 +173,8 @@ func @check_sharding_after_cast_op(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
}
// CHECK-LABEL: func @func_with_sharding_after_cast
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {xla_hlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {mhlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"})
func @func_with_sharding_after_cast(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
%0 = "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.Cast"(%0) : (tensor<*xi32>) -> tensor<*xi1>
@ -200,8 +200,8 @@ func @check_sharding_inside_functional_op(%arg0: tensor<*xi32>, %arg1: tensor<*x
}
// CHECK-LABEL: func @func_with_device_training_loop
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {xla_hlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {mhlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"})
func @func_with_device_training_loop(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
%1:2 = "tf.StatefulPartitionedCall"(%arg0){f= @func_body, config="", config_proto="", executor_type=""}
: (tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)

View File

@ -45,8 +45,8 @@ module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0"
return %10 : tensor<i1>
}
// CHECK-LABEL: func @_func
// CHECK-SAME: [[FUNCINPUT0:.*]]: tensor<2x112x112x12xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT1:%.*]]: tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT2:%.*]]: tensor<f32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[VAL_59:%.*]]: tensor<i64> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<i64> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
func @_func(%arg0: tensor<2x224x224x3xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<f32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<i64> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<i64> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
// CHECK-SAME: [[FUNCINPUT0:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT1:%.*]]: tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT2:%.*]]: tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[VAL_59:%.*]]: tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
func @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
%0 = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
%1 = "tf.Const"() {value = dense<0> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
%2 = "tf.Const"() {value = dense<[7, 7, 3, 64]> : tensor<4xi32>} : () -> tensor<4xi32>

View File

@ -33,7 +33,7 @@ namespace TFDevice {
namespace {
constexpr char kReplicationAttr[] = "xla_hlo.is_same_data_across_replicas";
constexpr char kReplicationAttr[] = "mhlo.is_same_data_across_replicas";
constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";
// Analyzes the inputs to ClusterFuncOps in the module, and annotates their

View File

@ -47,14 +47,14 @@ namespace mlir {
namespace TF {
namespace {
using xla_hlo::DotDimensionNumbers;
using mhlo::DotDimensionNumbers;
class ConvertConvOp : public OpConversionPattern<xla_hlo::ConvOp> {
class ConvertConvOp : public OpConversionPattern<mhlo::ConvOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_hlo::ConvOp conv_op, ArrayRef<Value> args,
mhlo::ConvOp conv_op, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const final {
if (!IsSupportedConvOp(conv_op)) {
return failure();
@ -120,7 +120,7 @@ class ConvertConvOp : public OpConversionPattern<xla_hlo::ConvOp> {
};
private:
bool IsSamePadding(xla_hlo::ConvOp conv_op, int num_spatial_dims,
bool IsSamePadding(mhlo::ConvOp conv_op, int num_spatial_dims,
ArrayRef<int64_t> strides, ArrayRef<int64_t> dilation,
ArrayRef<int64_t> padding_array) const {
for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
@ -142,7 +142,7 @@ class ConvertConvOp : public OpConversionPattern<xla_hlo::ConvOp> {
return true;
}
void CreateConvOp(xla_hlo::ConvOp conv_op, ArrayRef<int64_t> strides,
void CreateConvOp(mhlo::ConvOp conv_op, ArrayRef<int64_t> strides,
StringRef padding, ArrayRef<int64_t> dilation,
bool is_depthwise_conv,
ConversionPatternRewriter &rewriter) const {
@ -167,13 +167,13 @@ class ConvertConvOp : public OpConversionPattern<xla_hlo::ConvOp> {
}
}
bool IsSupportedConvOp(xla_hlo::ConvOp conv_op) const {
bool IsSupportedConvOp(mhlo::ConvOp conv_op) const {
if (!conv_op.lhs().getType().cast<ShapedType>().hasStaticShape() ||
!conv_op.rhs().getType().cast<ShapedType>().hasStaticShape() ||
!conv_op.getType().cast<ShapedType>().hasStaticShape())
return false;
// All ones in "lhs_dilation" means this "xla_hlo.conv" op should be
// All ones in "lhs_dilation" means this "mhlo.conv" op should be
// converted to "tf.Conv2D" or "tf.DepthwiseConv2dNativeOp".
if (conv_op.lhs_dilation().hasValue()) {
auto lhs_dilation = conv_op.lhs_dilation().getValue();
@ -236,15 +236,15 @@ class ConvertConvOp : public OpConversionPattern<xla_hlo::ConvOp> {
}
};
class ConvertSliceOp : public OpConversionPattern<xla_hlo::SliceOp> {
class ConvertSliceOp : public OpConversionPattern<mhlo::SliceOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_hlo::SliceOp slice_op, ArrayRef<Value> args,
mhlo::SliceOp slice_op, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const final {
DenseIntElementsAttr strides = slice_op.strides();
// Strides must be 1 otherwise we cannot legalize this `xla_hlo.slice` op.
// Strides must be 1 otherwise we cannot legalize this `mhlo.slice` op.
if (!strides.isSplat() ||
strides.getSplatValue().cast<IntegerAttr>().getInt() != 1)
return failure();
@ -374,10 +374,10 @@ class DotDimensionsInfo {
DimensionSetVector out_dimensions_;
};
// Converts xla_hlo.dot to tf.BatchMatMul. Reshape or Transpose ops will also be
// Converts mhlo.dot to tf.BatchMatMul. Reshape or Transpose ops will also be
// inserted to convert to well-formed matrix multiply.
Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
auto dot_general_op = cast<xla_hlo::DotGeneralOp>(old_op);
auto dot_general_op = cast<mhlo::DotGeneralOp>(old_op);
auto lhs_type = dot_general_op.lhs().getType().cast<ShapedType>();
auto rhs_type = dot_general_op.rhs().getType().cast<ShapedType>();
auto result_type = dot_general_op.getResult().getType().cast<ShapedType>();
@ -405,7 +405,7 @@ Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
lhs_dot_dimensions_info.out_dimensions().SizesArray(),
lhs_dot_dimensions_info.contracting_dimensions().SizesArray());
auto lhs_transposed = rewriter.create<xla_hlo::TransposeOp>(
auto lhs_transposed = rewriter.create<mhlo::TransposeOp>(
loc,
RankedTensorType::get(lhs_transposed_shape, lhs_type.getElementType()),
dot_general_op.lhs(),
@ -423,7 +423,7 @@ Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
rhs_dot_dimensions_info.contracting_dimensions().SizesArray(),
rhs_dot_dimensions_info.out_dimensions().SizesArray());
auto rhs_transposed = rewriter.create<xla_hlo::TransposeOp>(
auto rhs_transposed = rewriter.create<mhlo::TransposeOp>(
loc,
RankedTensorType::get(rhs_transposed_shape, rhs_type.getElementType()),
dot_general_op.rhs(),
@ -438,7 +438,7 @@ Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
llvm::ArrayRef<int64_t>{
lhs_dot_dimensions_info.FlattenedContractingDimensionSize()});
auto lhs_flattend = rewriter.create<xla_hlo::ReshapeOp>(
auto lhs_flattend = rewriter.create<mhlo::ReshapeOp>(
loc,
RankedTensorType::get(lhs_flattened_shape, lhs_type.getElementType()),
lhs_transposed.getResult());
@ -450,7 +450,7 @@ Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
rhs_dot_dimensions_info.FlattenedContractingDimensionSize()},
llvm::ArrayRef<int64_t>{
rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
auto rhs_flattend = rewriter.create<xla_hlo::ReshapeOp>(
auto rhs_flattend = rewriter.create<mhlo::ReshapeOp>(
loc,
RankedTensorType::get(rhs_flattened_shape, rhs_type.getElementType()),
rhs_transposed.getResult());
@ -466,14 +466,14 @@ Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
loc, RankedTensorType::get(matmul_shape, result_type.getElementType()),
lhs_flattend.getResult(), rhs_flattend.getResult());
auto reshaped =
rewriter.create<xla_hlo::ReshapeOp>(loc, result_type, matmul.getResult());
rewriter.create<mhlo::ReshapeOp>(loc, result_type, matmul.getResult());
return reshaped.getResult();
}
// This function tries to match that the "xla_hlo::ReduceOp" only has one
// input, one init_value and one result. Also "xla_hlo::ReduceOp" has two ops
// This function tries to match that the "mhlo::ReduceOp" only has one
// input, one init_value and one result. Also "mhlo::ReduceOp" has two ops
// in the region, and the last one is return op.
LogicalResult MatchReduceOpInput(xla_hlo::ReduceOp reduce_op) {
LogicalResult MatchReduceOpInput(mhlo::ReduceOp reduce_op) {
if (reduce_op.operands().size() != 1 || reduce_op.init_values().size() != 1 ||
reduce_op.getResults().size() != 1)
return failure();
@ -489,23 +489,23 @@ LogicalResult MatchReduceOpInput(xla_hlo::ReduceOp reduce_op) {
return success();
}
// TODO(jingpu): This "xla_hlo::ReduceOp" can corresponds to many TF ops
// TODO(jingpu): This "mhlo::ReduceOp" can corresponds to many TF ops
// with different ops in reduce_op.body. Now we only match to "tf.Max", "tf.Min"
// and "tf.Sum".
class ConvertReduceOpToTfSum : public OpConversionPattern<xla_hlo::ReduceOp> {
class ConvertReduceOpToTfSum : public OpConversionPattern<mhlo::ReduceOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_hlo::ReduceOp reduce_op, ArrayRef<Value> args,
mhlo::ReduceOp reduce_op, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const final {
if (failed(MatchReduceOpInput(reduce_op))) return failure();
Operation *first_op = &reduce_op.body().front().front();
if (!llvm::isa<xla_hlo::AddOp>(first_op)) return failure();
if (!llvm::isa<mhlo::AddOp>(first_op)) return failure();
// In `MatchReduceOpInput` function, we already match that the
// "xla_hlo::ReduceOp" only has one input, one init_value and one result.
// "mhlo::ReduceOp" only has one input, one init_value and one result.
auto input = reduce_op.operands()[0];
// Get reduction dimension.
DenseIntElementsAttr dimension = reduce_op.dimensions();
@ -531,20 +531,20 @@ class ConvertReduceOpToTfSum : public OpConversionPattern<xla_hlo::ReduceOp> {
};
};
class ConvertReduceOpToTfMax : public OpConversionPattern<xla_hlo::ReduceOp> {
class ConvertReduceOpToTfMax : public OpConversionPattern<mhlo::ReduceOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_hlo::ReduceOp reduce_op, ArrayRef<Value> args,
mhlo::ReduceOp reduce_op, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const final {
if (failed(MatchReduceOpInput(reduce_op))) return failure();
Operation *first_op = &reduce_op.body().front().front();
if (!llvm::isa<xla_hlo::MaxOp>(first_op)) return failure();
if (!llvm::isa<mhlo::MaxOp>(first_op)) return failure();
// In `MatchReduceOpInput` function, we already match that the
// "xla_hlo::ReduceOp" only has one input, one init_value and one result.
// "mhlo::ReduceOp" only has one input, one init_value and one result.
auto input = reduce_op.operands()[0];
// Get reduction dimension.
DenseIntElementsAttr dimension = reduce_op.dimensions();
@ -572,20 +572,20 @@ class ConvertReduceOpToTfMax : public OpConversionPattern<xla_hlo::ReduceOp> {
};
};
class ConvertReduceOpToTfMin : public OpConversionPattern<xla_hlo::ReduceOp> {
class ConvertReduceOpToTfMin : public OpConversionPattern<mhlo::ReduceOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_hlo::ReduceOp reduce_op, ArrayRef<Value> args,
mhlo::ReduceOp reduce_op, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const final {
if (failed(MatchReduceOpInput(reduce_op))) return failure();
Operation *first_op = &reduce_op.body().front().front();
if (!llvm::isa<xla_hlo::MinOp>(first_op)) return failure();
if (!llvm::isa<mhlo::MinOp>(first_op)) return failure();
// In `MatchReduceOpInput` function, we already match that the
// "xla_hlo::ReduceOp" only has one input, one init_value and one result.
// "mhlo::ReduceOp" only has one input, one init_value and one result.
Value input = reduce_op.operands()[0];
// Get reduction dimension.
DenseIntElementsAttr dimension = reduce_op.dimensions();
@ -645,10 +645,10 @@ ConstantOp ShapeToConst(PatternRewriter &rewriter, Value value) {
return rewriter.create<ConstantOp>(value.getLoc(), attr_type, attr);
}
// Converts xla_hlo.dot to tf.MatMul. Reshape ops will be inserted when
// Converts mhlo.dot to tf.MatMul. Reshape ops will be inserted when
// necessary.
Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) {
auto dot_op = cast<xla_hlo::DotOp>(old_op);
auto dot_op = cast<mhlo::DotOp>(old_op);
const mlir::Location loc = dot_op.getLoc();
// Normalizes a ShapedType to 2d if the ShapedType is less than 2d by
// inserting dummy 1-element dimensions in the begining. Does nothing if the
@ -677,7 +677,7 @@ Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) {
return input;
}
auto reshape = rewriter.create<xla_hlo::ReshapeOp>(
auto reshape = rewriter.create<mhlo::ReshapeOp>(
loc, normalize_rank(input_type), input);
return reshape.getResult();
};
@ -694,7 +694,7 @@ Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) {
loc, normalize_rank(output_type), a, b,
/*transpose_a=*/rewriter.getBoolAttr(false), transpose_b);
auto reshape =
rewriter.create<xla_hlo::ReshapeOp>(loc, output_type, matmul.product());
rewriter.create<mhlo::ReshapeOp>(loc, output_type, matmul.product());
return reshape.getResult();
}
@ -752,7 +752,7 @@ void LegalizeHloToTf::runOnFunction() {
target.addLegalDialect<TensorFlowDialect>();
target.addLegalOp<CallOp, ConstantOp>();
if (failed(applyPartialConversion(getFunction(), target, patterns))) {
getFunction().emitError("xla_hlo to TF legalization failed.");
getFunction().emitError("mhlo to TF legalization failed.");
signalPassFailure();
}
}

View File

@ -151,7 +151,7 @@ LogicalResult GetRemappedPaddings(
// Inserts padding maps for relevant arguments as argument attributes on the
// encapsulated function. The padding maps will be in the form of:
// %arg0 : type {xla_hlo.padding_map = {shape_indices = [...],
// %arg0 : type {mhlo.padding_map = {shape_indices = [...],
// padding_arg_indices = [...]}}
void AnnotateFunctionArgumentsWithPaddings(
FuncOp func,
@ -174,7 +174,7 @@ void AnnotateFunctionArgumentsWithPaddings(
"padding_arg_indices",
builder.getI32ArrayAttr(padding.getSecond().second));
func.setArgAttr(
padding.getFirst(), "xla_hlo.padding_map",
padding.getFirst(), "mhlo.padding_map",
builder.getDictionaryAttr({shape_indices, padding_arg_indices}));
}
}

View File

@ -39,7 +39,7 @@ namespace mlir {
namespace TFTPU {
namespace {
constexpr char kShardingAttr[] = "xla_hlo.sharding";
constexpr char kShardingAttr[] = "mhlo.sharding";
struct TPUShardingIdentificationPass
: public PassWrapper<TPUShardingIdentificationPass,

View File

@ -108,7 +108,7 @@ Status GetXlaInputShapes(
// Rewrite layout with sharding, if sharding is set.
auto sharding =
main_func.getArgAttrOfType<mlir::StringAttr>(i, "xla_hlo.sharding");
main_func.getArgAttrOfType<mlir::StringAttr>(i, "mhlo.sharding");
if (!sharding) continue;
absl::optional<xla::HloSharding> arg_sharding;
@ -253,7 +253,7 @@ static void RegisterDialects() {
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
mlir::registerDialect<mlir::shape::ShapeDialect>();
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
mlir::registerDialect<mlir::mhlo::XlaHloDialect>();
return true;
}();
(void)init_once;
@ -279,9 +279,9 @@ Status ConvertMLIRToXlaComputation(
// LegalizeTFControlFlow encapsulates arguments for control flow operations
// with a tuple argument which break the assumption of resource lifting
// inside PromoteResourcesToArgs.
tf2xla.addPass(mlir::xla_hlo::createLegalizeTFControlFlowPass());
tf2xla.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());
tf2xla.addNestedPass<mlir::FuncOp>(mlir::xla_hlo::createLegalizeTFPass(true));
tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(true));
for (auto& target_pass : custom_legalization_passes) {
tf2xla.addNestedPass<mlir::FuncOp>(std::move(target_pass));
}
@ -290,7 +290,7 @@ Status ConvertMLIRToXlaComputation(
// Leverage tf2xla kernels for ops that didn't get lowered in the previous
// legalization pass.
tf2xla.addPass(mlir::xla_hlo::createLegalizeTfWithTf2XlaPass(device_type));
tf2xla.addPass(mlir::mhlo::createLegalizeTfWithTf2XlaPass(device_type));
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
// Run shape inference pass to propagate shapes through tensor_cast operations
@ -303,12 +303,11 @@ Status ConvertMLIRToXlaComputation(
// expose more graph pruning and canonicalization opportunities that are
// necessary for the second LegalizeTFPass(allow_partial_conversion=false)
// invocation.
tf2xla.addNestedPass<mlir::FuncOp>(
mlir::xla_hlo::createLegalizeTFPass(false));
tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(false));
// In order to export to XLA, we must sink constants to control flow regions,
// since XLA uses functional control flow.
tf2xla.addNestedPass<mlir::FuncOp>(
mlir::xla_hlo::createSinkConstantsToControlFlowPass());
mlir::mhlo::createSinkConstantsToControlFlowPass());
if (VLOG_IS_ON(1)) {
// Print the whole module after each pass which requires disabling

View File

@ -184,7 +184,7 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
// only be lowered when tf.Shape is folded into a constant.
constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {xla_hlo.is_same_data_across_replicas}) -> tensor<10x19xf32> {
func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {mhlo.is_same_data_across_replicas}) -> tensor<10x19xf32> {
%0 = "tf.Shape"(%arg0) : (tensor<10x19xf32>) -> tensor<2xi64>
%1 = "tf.Reshape"(%arg1, %0) : (tensor<19x10xf32>, tensor<2xi64>) -> tensor<10x19xf32>
return %1 : tensor<10x19xf32>
@ -344,7 +344,7 @@ ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) {
TEST(CompileSerializedMlirToXlaHloTest, ArgumentSharding) {
constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<128x10xf32> {xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {xla_hlo.sharding = ""}) {
func @main(%arg0: tensor<128x10xf32> {mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {mhlo.sharding = ""}) {
return
}
}
@ -383,7 +383,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[128,10], f32[10,1024], f32[128,1024])) -> () {
TEST(CompileSerializedMlirToXlaHloTest, BadArgumentSharding) {
constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<128x10xf32> {xla_hlo.sharding = "bad_sharding"}) {
func @main(%arg0: tensor<128x10xf32> {mhlo.sharding = "bad_sharding"}) {
return
}
}
@ -403,7 +403,7 @@ module attributes {tf.versions = {producer = 179 : i32}} {
TEST(CompileSerializedMlirToXlaHloTest, ResultSharding) {
constexpr char mlir_module[] = R"(
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 351 : i32}} {
func @main(%arg0: tensor<128x10xf32>, %arg1: tensor<10x1024xf32>, %arg2: tensor<128x1024xf32>) -> (tensor<128x10xf32> {xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, tensor<10x1024xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<128x1024xf32> {xla_hlo.sharding = ""}) {
func @main(%arg0: tensor<128x10xf32>, %arg1: tensor<10x1024xf32>, %arg2: tensor<128x1024xf32>) -> (tensor<128x10xf32> {mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<128x1024xf32> {mhlo.sharding = ""}) {
return %arg0, %arg1, %arg2 : tensor<128x10xf32>, tensor<10x1024xf32>, tensor<128x1024xf32>
}
}

View File

@ -87,14 +87,14 @@ struct MaterializeBroadcastsPass
mlir::ConversionTarget conversionTarget(getContext());
mlir::OwningRewritePatternList conversionPatterns;
// Consider the xla_hlo dialect legal for tests.
conversionTarget.addLegalDialect<mlir::xla_hlo::XlaHloDialect>();
// Consider the mhlo dialect legal for tests.
conversionTarget.addLegalDialect<mlir::mhlo::XlaHloDialect>();
// The conversion uses helpers from the Standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
mlir::xla_hlo::SetupMaterializeBroadcastsLegality(&getContext(),
mlir::mhlo::SetupMaterializeBroadcastsLegality(&getContext(),
&conversionTarget);
mlir::xla_hlo::PopulateMaterializeBroadcastsPatterns(&getContext(),
mlir::mhlo::PopulateMaterializeBroadcastsPatterns(&getContext(),
&conversionPatterns);
if (failed(applyPartialConversion(getFunction(), conversionTarget,
@ -108,7 +108,7 @@ struct UnfuseBatchNormPass
: public mlir::PassWrapper<UnfuseBatchNormPass, mlir::FunctionPass> {
void runOnFunction() override {
mlir::OwningRewritePatternList patterns;
mlir::xla_hlo::PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
mlir::mhlo::PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
mlir::applyPatternsAndFoldGreedily(getOperation(), patterns);
}
};
@ -122,11 +122,11 @@ Status LowerTfOpToLhloWithDynamicShapes(mlir::ModuleOp module) {
/*shouldPrintAfterPass=*/enable_if_vlog_is_on,
/*printModuleScope=*/false,
/*printAfterOnlyOnChange=*/false, llvm::dbgs());
pm.addNestedPass<mlir::FuncOp>(mlir::xla_hlo::createLegalizeTFPass(false));
pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(false));
pm.addNestedPass<mlir::FuncOp>(
absl::make_unique<MaterializeBroadcastsPass>());
pm.addNestedPass<mlir::FuncOp>(absl::make_unique<UnfuseBatchNormPass>());
pm.addPass(mlir::xla_hlo::createLegalizeToLhloPass(
pm.addPass(mlir::mhlo::createLegalizeToLhloPass(
/*results_escape_functions=*/true));
pm.addNestedPass<mlir::FuncOp>(mlir::xla_lhlo::createLhloCopyRemovalPass());

View File

@ -115,9 +115,9 @@ cc_library(
)
cc_library(
name = "xla_hlo_to_lhlo_with_xla",
srcs = ["transforms/xla_hlo_to_lhlo_with_xla.cc"],
hdrs = ["transforms/xla_hlo_to_lhlo_with_xla.h"],
name = "mhlo_to_lhlo_with_xla",
srcs = ["transforms/mhlo_to_lhlo_with_xla.cc"],
hdrs = ["transforms/mhlo_to_lhlo_with_xla.h"],
deps = [
":hlo_module_importer",
":hlo_utils",
@ -363,7 +363,7 @@ cc_library(
"//tensorflow/compiler/mlir:__subpackages__",
],
deps = [
":xla_hlo_to_lhlo_with_xla",
":mhlo_to_lhlo_with_xla",
":xla_legalize_tf",
":xla_legalize_tf_with_tf2xla",
"//tensorflow/compiler/mlir/hlo",
@ -376,7 +376,7 @@ cc_library(
"//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_affine",
"//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu",
"//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_parallel_loops",
"//tensorflow/compiler/mlir/hlo:xla_hlo_fusion",
"//tensorflow/compiler/mlir/hlo:mhlo_fusion",
"//tensorflow/compiler/mlir/hlo:xla_legalize_control_flow",
"//tensorflow/compiler/mlir/hlo:xla_legalize_tanh_to_approximation",
"//tensorflow/compiler/mlir/hlo:xla_legalize_to_linalg",

View File

@ -42,7 +42,7 @@ mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
}
// Converts the gather dimensions to attributes.
mlir::xla_hlo::GatherDimensionNumbers ConvertGatherDimensionNumbers(
mlir::mhlo::GatherDimensionNumbers ConvertGatherDimensionNumbers(
const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder) {
std::vector<int64_t> offset_dims(dnums.offset_dims().begin(),
dnums.offset_dims().end());
@ -50,14 +50,14 @@ mlir::xla_hlo::GatherDimensionNumbers ConvertGatherDimensionNumbers(
dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end());
std::vector<int64_t> start_index_map(dnums.start_index_map().begin(),
dnums.start_index_map().end());
return mlir::xla_hlo::GatherDimensionNumbers::get(
return mlir::mhlo::GatherDimensionNumbers::get(
Convert(offset_dims, builder), Convert(collapsed_slice_dims, builder),
Convert(start_index_map, builder),
builder->getI64IntegerAttr(dnums.index_vector_dim()),
builder->getContext());
}
mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
mlir::mhlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder) {
std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
dnums.update_window_dims().end());
@ -66,7 +66,7 @@ mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
std::vector<int64_t> scatter_dims_to_operand_dims(
dnums.scatter_dims_to_operand_dims().begin(),
dnums.scatter_dims_to_operand_dims().end());
return mlir::xla_hlo::ScatterDimensionNumbers::get(
return mlir::mhlo::ScatterDimensionNumbers::get(
Convert(update_window_dims, builder),
Convert(inserted_window_dims, builder),
Convert(scatter_dims_to_operand_dims, builder),
@ -74,7 +74,7 @@ mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
builder->getContext());
}
mlir::xla_hlo::DotDimensionNumbers ConvertDotDimensionNumbers(
mlir::mhlo::DotDimensionNumbers ConvertDotDimensionNumbers(
const DotDimensionNumbers& dnums, mlir::Builder* builder) {
std::vector<int64_t> rhs_contracting_dimensions(
dnums.rhs_contracting_dimensions().begin(),
@ -93,12 +93,12 @@ mlir::xla_hlo::DotDimensionNumbers ConvertDotDimensionNumbers(
auto lhs_contracting_dims_attr = Convert(lhs_contracting_dimensions, builder);
auto rhs_contracting_dims_attr = Convert(rhs_contracting_dimensions, builder);
return mlir::xla_hlo::DotDimensionNumbers::get(
return mlir::mhlo::DotDimensionNumbers::get(
lhs_batch_dims_attr, rhs_batch_dims_attr, lhs_contracting_dims_attr,
rhs_contracting_dims_attr, builder->getContext());
}
mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
mlir::mhlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder) {
llvm::SmallVector<int64_t, 4> input_spatial_dims(
dnums.input_spatial_dimensions().begin(),
@ -109,7 +109,7 @@ mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
llvm::SmallVector<int64_t, 4> output_spatial_dims(
dnums.output_spatial_dimensions().begin(),
dnums.output_spatial_dimensions().end());
return mlir::xla_hlo::ConvDimensionNumbers::get(
return mlir::mhlo::ConvDimensionNumbers::get(
builder->getI64IntegerAttr(dnums.input_batch_dimension()),
builder->getI64IntegerAttr(dnums.input_feature_dimension()),
Convert(input_spatial_dims, builder),

View File

@ -29,19 +29,19 @@ mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
mlir::Builder* builder);
// Converts the gather dimensions to attributes.
mlir::xla_hlo::GatherDimensionNumbers ConvertGatherDimensionNumbers(
mlir::mhlo::GatherDimensionNumbers ConvertGatherDimensionNumbers(
const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder);
// Converts the scatter dimensions to attributes.
mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
mlir::mhlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder);
// Converts the dot dimensions to attributes.
mlir::xla_hlo::DotDimensionNumbers ConvertDotDimensionNumbers(
mlir::mhlo::DotDimensionNumbers ConvertDotDimensionNumbers(
const DotDimensionNumbers& dnums, mlir::Builder* builder);
// Converts the conv dimensions to attributes.
mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
mlir::mhlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder);
} // namespace xla

View File

@ -171,7 +171,7 @@ tensorflow::Status HloFunctionImporter::ImportInstructions(
if (llvm::isa<FuncOp>(block->getParentOp())) {
builder.create<mlir::ReturnOp>(loc, result);
} else {
builder.create<mlir::xla_hlo::ReturnOp>(loc, result);
builder.create<mlir::mhlo::ReturnOp>(loc, result);
}
return tensorflow::Status::OK();
}
@ -202,7 +202,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
}
case HloOpcode::kIota: {
return func_builder
->create<mlir::xla_hlo::IotaOp>(
->create<mlir::mhlo::IotaOp>(
loc, result_type,
func_builder->getI64IntegerAttr(
Cast<HloIotaInstruction>(instruction)->iota_dimension()))
@ -211,8 +211,8 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
#define MakeAndReturn(mlir_op) \
{ \
mlir::Operation* new_operation = \
func_builder->create<mlir::xla_hlo::mlir_op>(loc, result_type, \
operands, attributes); \
func_builder->create<mlir::mhlo::mlir_op>(loc, result_type, operands, \
attributes); \
return new_operation; \
}
case HloOpcode::kBroadcast: {
@ -314,14 +314,14 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
instruction->dynamic_slice_sizes().begin(),
instruction->dynamic_slice_sizes().end());
return func_builder
->create<mlir::xla_hlo::DynamicSliceOp>(
->create<mlir::mhlo::DynamicSliceOp>(
loc, result_type, operands[0],
makeArrayRef(operands).drop_front(), Convert(slice_sizes))
.getOperation();
}
case HloOpcode::kDynamicUpdateSlice: {
return func_builder
->create<mlir::xla_hlo::DynamicUpdateSliceOp>(
->create<mlir::mhlo::DynamicUpdateSliceOp>(
loc, result_type, operands[0], operands[1],
llvm::ArrayRef<Value>(operands.begin() + 2, operands.end()))
.getOperation();
@ -354,7 +354,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
}
return func_builder
->create<mlir::xla_hlo::PadOp>(loc, result_type, operands[0],
->create<mlir::mhlo::PadOp>(loc, result_type, operands[0],
operands[1], Convert(edge_padding_low),
Convert(edge_padding_high),
Convert(interior_padding))
@ -372,7 +372,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
attributes.push_back(builder_->getNamedAttr(
"unique_indices", builder_->getBoolAttr(scatter->unique_indices())));
auto scatter_op = func_builder->create<mlir::xla_hlo::ScatterOp>(
auto scatter_op = func_builder->create<mlir::mhlo::ScatterOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(),
&scatter_op.update_computation()));
@ -394,7 +394,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
Convert(window_dimensions)));
attributes.push_back(ConvertPadding(padding));
auto select_scatter_op =
func_builder->create<mlir::xla_hlo::SelectAndScatterOp>(
func_builder->create<mlir::mhlo::SelectAndScatterOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->select(),
&select_scatter_op.select()));
@ -410,7 +410,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
}
case HloOpcode::kSlice: {
return func_builder
->create<mlir::xla_hlo::SliceOp>(
->create<mlir::mhlo::SliceOp>(
loc, result_type, operands[0],
ConvertDimensions(instruction->slice_starts()),
ConvertDimensions(instruction->slice_limits()),
@ -419,7 +419,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
}
case HloOpcode::kSort: {
auto sort_instruction = Cast<HloSortInstruction>(instruction);
auto sort_op = func_builder->create<mlir::xla_hlo::SortOp>(
auto sort_op = func_builder->create<mlir::mhlo::SortOp>(
loc, result_type, operands,
builder_->getI64IntegerAttr(sort_instruction->sort_dimension()),
builder_->getBoolAttr(sort_instruction->is_stable()));
@ -437,7 +437,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
TF_RETURN_IF_ERROR(GetMlirTypes(
{instruction->true_computation()->root_instruction()}, &rets));
auto op = func_builder->create<mlir::xla_hlo::IfOp>(loc, rets, operands,
auto op = func_builder->create<mlir::mhlo::IfOp>(loc, rets, operands,
attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->true_computation(),
&op.true_branch()));
@ -451,7 +451,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
{instruction->branch_computation(0)->root_instruction()}, &rets));
int num_branches = instruction->branch_count();
auto op = func_builder->create<mlir::xla_hlo::CaseOp>(
auto op = func_builder->create<mlir::mhlo::CaseOp>(
loc, rets, operands, attributes, num_branches);
for (auto index_and_computation :
llvm::enumerate(instruction->branch_computations())) {
@ -465,7 +465,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
// TODO(b/132057942): Support taking an uint64_t instead of an IntegerAttr
// for concatenate dimension.
return func_builder
->create<mlir::xla_hlo::ConcatenateOp>(
->create<mlir::mhlo::ConcatenateOp>(
loc, result_type, operands,
builder_->getI64IntegerAttr(instruction->concatenate_dimension()))
.getOperation();
@ -474,7 +474,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
attributes.push_back(ConvertReplicaGroups(all_reduce->replica_groups()));
attributes.push_back(ConvertChannelHandle(all_reduce->channel_id()));
auto all_reduce_op = func_builder->create<mlir::xla_hlo::AllReduceOp>(
auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
&all_reduce_op.computation()));
@ -484,7 +484,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
// Operands in the first half are reduction inputs and the remaining
// operands are corresponding initial values.
size_t num_inputs = operands.size() / 2;
auto reduce = func_builder->create<mlir::xla_hlo::ReduceOp>(
auto reduce = func_builder->create<mlir::mhlo::ReduceOp>(
loc, result_type, llvm::makeArrayRef(operands).take_front(num_inputs),
llvm::makeArrayRef(operands).drop_front(num_inputs),
ConvertDimensions(instruction->dimensions()));
@ -494,7 +494,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
}
case HloOpcode::kReverse: {
return func_builder
->create<mlir::xla_hlo::ReverseOp>(
->create<mlir::mhlo::ReverseOp>(
loc, result_type, operands[0],
ConvertDimensions(instruction->dimensions()))
.getOperation();
@ -505,14 +505,14 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
switch (instruction->random_distribution()) {
case xla::RNG_UNIFORM:
return func_builder
->create<mlir::xla_hlo::RngUniformOp>(
loc, result_type, operands[0], operands[1], shape)
->create<mlir::mhlo::RngUniformOp>(loc, result_type, operands[0],
operands[1], shape)
.getOperation();
case xla::RNG_NORMAL:
return func_builder
->create<mlir::xla_hlo::RngNormalOp>(
loc, result_type, operands[0], operands[1], shape)
->create<mlir::mhlo::RngNormalOp>(loc, result_type, operands[0],
operands[1], shape)
.getOperation();
default:
@ -522,7 +522,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
}
}
case HloOpcode::kWhile: {
auto op = func_builder->create<mlir::xla_hlo::WhileOp>(
auto op = func_builder->create<mlir::mhlo::WhileOp>(
loc, operands[0].getType(), operands[0]);
TF_RETURN_IF_ERROR(
ImportAsRegion(*instruction->while_condition(), &op.cond()));
@ -585,14 +585,14 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
attributes.push_back(builder_->getNamedAttr(
"window_dilations", ConvertDimensions(win_dilations)));
attributes.push_back(ConvertPadding(padding));
auto reduce = func_builder->create<mlir::xla_hlo::ReduceWindowOp>(
auto reduce = func_builder->create<mlir::mhlo::ReduceWindowOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(
ImportAsRegion(*instruction->to_apply(), &reduce.body()));
return reduce.getOperation();
}
case HloOpcode::kMap: {
auto op = func_builder->create<mlir::xla_hlo::MapOp>(
auto op = func_builder->create<mlir::mhlo::MapOp>(
loc, result_type, operands,
ConvertDimensions(instruction->dimensions()));
TF_RETURN_IF_ERROR(
@ -714,7 +714,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
// is not mentioned in xla client anywhere or in the hlo of our sample
// models.
default: {
mlir::OperationState result(loc, "xla_hlo.unknown");
mlir::OperationState result(loc, "mhlo.unknown");
result.addOperands(operands);
result.addTypes(result_type);
for (auto attr : attributes) {
@ -840,7 +840,7 @@ mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
const xla::ChannelHandle& channel) {
return builder_->getNamedAttr(
"channel_handle",
mlir::xla_hlo::ChannelHandle::get(
mlir::mhlo::ChannelHandle::get(
builder_->getI64IntegerAttr(channel.handle()),
builder_->getI64IntegerAttr(channel.type()), context_));
}

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_FUNCTION_IMPORTER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_FUNCTION_IMPORTER_H_
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_
#include <unordered_map>
@ -143,4 +143,4 @@ class HloFunctionImporter {
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_FUNCTION_IMPORTER_H_
#endif // TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_MODULE_IMPORTER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_MODULE_IMPORTER_H_
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_MODULE_IMPORTER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_MODULE_IMPORTER_H_
#include <unordered_map>
@ -59,4 +59,4 @@ class HloModuleImporter {
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_MODULE_IMPORTER_H_
#endif // TENSORFLOW_COMPILER_MLIR_XLA_MODULE_IMPORTER_H_

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_TO_MLIR_HLO_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_TO_MLIR_HLO_H_
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TO_MLIR_HLO_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TO_MLIR_HLO_H_
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/xla/status.h"
@ -36,4 +36,4 @@ Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModule* hlo_module);
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_TO_MLIR_HLO_H_
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TO_MLIR_HLO_H_

View File

@ -197,7 +197,7 @@ StatusOr<mlir::Type> ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,
}
}
mlir::xla_hlo::GatherDimensionNumbers CreateGatherDimensionNumbers(
mlir::mhlo::GatherDimensionNumbers CreateGatherDimensionNumbers(
const GatherDimensionNumbers& input, mlir::Builder builder) {
auto offset_dims = CreateDenseIntElementsAttrFromVector(
llvm::SmallVector<int64, 4>{input.offset_dims().begin(),
@ -215,7 +215,7 @@ mlir::xla_hlo::GatherDimensionNumbers CreateGatherDimensionNumbers(
mlir::IntegerAttr index_vector_dim =
builder.getI64IntegerAttr(input.index_vector_dim());
return mlir::xla_hlo::GatherDimensionNumbers::get(
return mlir::mhlo::GatherDimensionNumbers::get(
offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim,
builder.getContext());
}

View File

@ -15,8 +15,8 @@ limitations under the License.
// This file defines helpers useful when creating or manipulating lhlo/hlo.
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_UTILS_H_
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_UTILS_H_
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
@ -39,7 +39,7 @@ mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
StatusOr<mlir::Type> ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,
mlir::Builder builder);
mlir::xla_hlo::GatherDimensionNumbers CreateGatherDimensionNumbers(
mlir::mhlo::GatherDimensionNumbers CreateGatherDimensionNumbers(
const GatherDimensionNumbers& input, mlir::Builder builder);
template <typename TypeT>
@ -77,11 +77,11 @@ static StatusOr<mlir::Type> ConvertShapeToType(const Shape& shape,
return builder.getTupleType(contents);
}
if (shape.IsToken()) {
return mlir::xla_hlo::TokenType::get(builder.getContext());
return mlir::mhlo::TokenType::get(builder.getContext());
}
return ConvertTensorShapeToType<TypeT>(shape, builder);
}
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_UTILS_H_
#endif // TENSORFLOW_COMPILER_MLIR_XLA_UTILS_H_

View File

@ -33,8 +33,7 @@ namespace xla {
static std::string GetMlirOpName(HloOpcode opcode) {
std::string op_name = HloOpcodeString(opcode);
absl::c_replace(op_name, '-', '_');
return mlir::xla_hlo::XlaHloDialect::getDialectNamespace().str() + "." +
op_name;
return mlir::mhlo::XlaHloDialect::getDialectNamespace().str() + "." + op_name;
}
static std::string ToString(mlir::Type ty) {
@ -90,7 +89,7 @@ XlaOp MlirHloBuilder::ConstantLiteral(const LiteralSlice& literal) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(mlir::DenseElementsAttr attr,
CreateDenseElementsAttrFromLiteral(literal, builder_));
auto op = builder_.create<mlir::xla_hlo::ConstOp>(loc_, attr);
auto op = builder_.create<mlir::mhlo::ConstOp>(loc_, attr);
return MakeXlaOp(op);
});
}
@ -108,7 +107,7 @@ StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal(
mlir::ArrayAttr config_attr;
if (precision_config)
config_attr = ConvertPrecisionConfig(precision_config, &builder_);
auto op = builder_.create<mlir::xla_hlo::ConvOp>(
auto op = builder_.create<mlir::mhlo::ConvOp>(
loc_, ty, GetValue(lhs), GetValue(rhs),
GetI64ElementsAttr(window_strides, &builder_),
ConvertPadding(padding, &builder_),
@ -125,7 +124,7 @@ StatusOr<XlaOp> MlirHloBuilder::FftInternal(
absl::Span<const int64> fft_length) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::FftOp>(
auto op = builder_.create<mlir::mhlo::FftOp>(
loc_, ty, GetValue(operand),
builder_.getStringAttr(FftType_Name(fft_type)),
GetI64ElementsAttr(fft_length, &builder_));
@ -141,7 +140,7 @@ StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
"CustomCall doesn't support operands shapes with layout");
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::CustomCallOp>(
auto op = builder_.create<mlir::mhlo::CustomCallOp>(
loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
/*has_side_effect=*/builder_.getBoolAttr(false),
builder_.getStringAttr(opaque));
@ -155,13 +154,13 @@ StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
// Reduce takes two set of variadic operands inputs and init_values.
// all_operands contains both of these so split operands into two parts.
int64_t num_args = all_operands.size() / 2;
auto op = builder_.create<mlir::xla_hlo::ReduceOp>(
auto op = builder_.create<mlir::mhlo::ReduceOp>(
loc_, GetValues(all_operands.first(num_args)),
GetValues(all_operands.subspan(num_args)),
GetI64ElementsAttr(dimensions_to_reduce, &builder_));
TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
if (op.getNumResults() == 1) return MakeXlaOp(op.getResult(0));
auto tuple = builder_.create<mlir::xla_hlo::TupleOp>(loc_, op.getResults());
auto tuple = builder_.create<mlir::mhlo::TupleOp>(loc_, op.getResults());
return MakeXlaOp(tuple);
}
@ -183,7 +182,7 @@ StatusOr<XlaOp> MlirHloBuilder::ReduceWindowInternal(
auto padding_ty =
mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
builder_.getIntegerType(64));
auto op = builder_.create<mlir::xla_hlo::ReduceWindowOp>(
auto op = builder_.create<mlir::mhlo::ReduceWindowOp>(
loc_, ty, GetValue(operand), GetValue(init_value),
GetI64ElementsAttr(sizes, &builder_),
GetI64ElementsAttr(strides, &builder_),
@ -199,7 +198,7 @@ XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) {
TF_ASSIGN_OR_RETURN(
mlir::Type ty,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
auto op = builder_.create<mlir::xla_hlo::IotaOp>(
auto op = builder_.create<mlir::mhlo::IotaOp>(
loc_, ty,
builder_.getIntegerAttr(builder_.getI64Type(), iota_dimension));
return MakeXlaOp(op);
@ -210,7 +209,7 @@ StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::TransposeOp>(
auto op = builder_.create<mlir::mhlo::TransposeOp>(
loc_, ty, GetValue(operand), GetI64ElementsAttr(permutation, &builder_));
return MakeXlaOp(op);
}
@ -219,7 +218,7 @@ StatusOr<XlaOp> MlirHloBuilder::RevInternal(
const Shape& shape, XlaOp operand, absl::Span<const int64> dimensions) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::ReverseOp>(
auto op = builder_.create<mlir::mhlo::ReverseOp>(
loc_, ty, GetValue(operand), GetI64ElementsAttr(dimensions, &builder_));
return MakeXlaOp(op);
}
@ -230,7 +229,7 @@ StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::GatherOp>(
auto op = builder_.create<mlir::mhlo::GatherOp>(
loc_, ty, GetValue(input), GetValue(start_indices),
ConvertGatherDimensionNumbers(dimension_numbers, &builder_),
GetI64ElementsAttr(slice_sizes, &builder_));
@ -244,7 +243,7 @@ StatusOr<XlaOp> MlirHloBuilder::ScatterInternal(
bool unique_indices) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::ScatterOp>(
auto op = builder_.create<mlir::mhlo::ScatterOp>(
loc_, ty, GetValue(input), GetValue(scatter_indices), GetValue(updates),
ConvertScatterDimensionNumbers(dimension_numbers, &builder_),
builder_.getBoolAttr(indices_are_sorted),
@ -262,11 +261,11 @@ StatusOr<XlaOp> MlirHloBuilder::RngOpInternal(
// and RngNormal can be mapped to the new op.
std::string op_name;
if (distribution == xla::RandomDistribution::RNG_UNIFORM) {
op_name = "xla_hlo.rng_uniform";
op_name = "mhlo.rng_uniform";
} else {
TF_RET_CHECK(distribution == xla::RandomDistribution::RNG_NORMAL)
<< "Unexpected distribution: " << distribution;
op_name = "xla_hlo.rng_normal";
op_name = "mhlo.rng_normal";
}
if (shape.is_dynamic())
@ -288,7 +287,7 @@ StatusOr<XlaOp> MlirHloBuilder::ReshapeInternal(const Shape& shape,
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
mlir::Value value = GetValue(operand);
auto op = builder_.create<mlir::xla_hlo::ReshapeOp>(loc_, ty, value);
auto op = builder_.create<mlir::mhlo::ReshapeOp>(loc_, ty, value);
return MakeXlaOp(op.getResult());
}
@ -298,7 +297,7 @@ StatusOr<XlaOp> MlirHloBuilder::DotGeneralInternal(
const PrecisionConfig* precision_config) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::DotGeneralOp>(
auto op = builder_.create<mlir::mhlo::DotGeneralOp>(
loc_, ty, GetValue(lhs), GetValue(rhs),
ConvertDotDimensionNumbers(dimension_number, &builder_),
ConvertPrecisionConfig(precision_config, &builder_));
@ -312,7 +311,7 @@ StatusOr<XlaOp> MlirHloBuilder::InDimBroadcast(
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
mlir::Value value = GetValue(operand);
auto op = builder_.create<mlir::xla_hlo::BroadcastInDimOp>(
auto op = builder_.create<mlir::mhlo::BroadcastInDimOp>(
loc_, ty, value, GetI64ElementsAttr(broadcast_dimensions, &builder_));
return MakeXlaOp(op.getResult());
}
@ -322,7 +321,7 @@ StatusOr<XlaOp> MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs,
ComparisonDirection direction) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::CompareOp>(
auto op = builder_.create<mlir::mhlo::CompareOp>(
loc_, ty, GetValue(lhs), GetValue(rhs),
builder_.getStringAttr(ComparisonDirectionToString(direction)));
return MakeXlaOp(op.getResult());
@ -343,8 +342,8 @@ StatusOr<XlaOp> MlirHloBuilder::AddOpWithShape(
XlaOp MlirHloBuilder::CreateToken() {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
return MakeXlaOp(builder_.create<mlir::xla_hlo::CreateTokenOp>(
loc_, mlir::xla_hlo::TokenType::get(builder_.getContext())));
return MakeXlaOp(builder_.create<mlir::mhlo::CreateTokenOp>(
loc_, mlir::mhlo::TokenType::get(builder_.getContext())));
});
}
@ -353,16 +352,16 @@ StatusOr<XlaOp> MlirHloBuilder::InfeedWithTokenInternal(
TF_ASSIGN_OR_RETURN(mlir::Type result_type,
ConvertShapeToType<mlir::RankedTensorType>(
infeed_instruction_shape, builder_));
return MakeXlaOp(builder_.create<mlir::xla_hlo::InfeedOp>(
loc_, result_type, GetValue(token),
return MakeXlaOp(
builder_.create<mlir::mhlo::InfeedOp>(loc_, result_type, GetValue(token),
/*infeed_config=*/config));
}
StatusOr<XlaOp> MlirHloBuilder::OutfeedWithTokenInternal(
XlaOp operand, XlaOp token, const Shape& shape_with_layout,
const string& outfeed_config) {
auto token_type = mlir::xla_hlo::TokenType::get(builder_.getContext());
return MakeXlaOp(builder_.create<mlir::xla_hlo::OutfeedOp>(
auto token_type = mlir::mhlo::TokenType::get(builder_.getContext());
return MakeXlaOp(builder_.create<mlir::mhlo::OutfeedOp>(
loc_, token_type, GetValue(operand), GetValue(token), outfeed_config));
}
@ -372,7 +371,7 @@ StatusOr<XlaOp> MlirHloBuilder::ConcatInDimInternal(
mlir::Type result_type,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
auto mlir_operands = GetValues(operands);
return MakeXlaOp(builder_.create<mlir::xla_hlo::ConcatenateOp>(
return MakeXlaOp(builder_.create<mlir::mhlo::ConcatenateOp>(
loc_, result_type, mlir_operands, builder_.getI64IntegerAttr(dimension)));
}
@ -382,7 +381,7 @@ StatusOr<XlaOp> MlirHloBuilder::GetTupleElementInternal(const Shape& shape,
TF_ASSIGN_OR_RETURN(
mlir::Type result_type,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
return MakeXlaOp(builder_.create<mlir::xla_hlo::GetTupleElementOp>(
return MakeXlaOp(builder_.create<mlir::mhlo::GetTupleElementOp>(
loc_, result_type, GetValue(tuple_data),
builder_.getI32IntegerAttr(index)));
}
@ -390,7 +389,7 @@ StatusOr<XlaOp> MlirHloBuilder::GetTupleElementInternal(const Shape& shape,
StatusOr<XlaOp> MlirHloBuilder::SliceInternal(
const Shape& shape, XlaOp operand, absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
return MakeXlaOp(builder_.create<mlir::xla_hlo::SliceOp>(
return MakeXlaOp(builder_.create<mlir::mhlo::SliceOp>(
loc_, GetValue(operand), GetI64ElementsAttr(start_indices, &builder_),
GetI64ElementsAttr(limit_indices, &builder_),
GetI64ElementsAttr(strides, &builder_)));
@ -402,7 +401,7 @@ StatusOr<XlaOp> MlirHloBuilder::DynamicSliceInternal(
TF_ASSIGN_OR_RETURN(
mlir::Type result_ty,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
return MakeXlaOp(builder_.create<mlir::xla_hlo::DynamicSliceOp>(
return MakeXlaOp(builder_.create<mlir::mhlo::DynamicSliceOp>(
loc_, result_ty, GetValue(operand), GetValues(start_indices),
GetI64ElementsAttr(slice_sizes, &builder_)));
}
@ -413,7 +412,7 @@ StatusOr<XlaOp> MlirHloBuilder::DynamicUpdateSliceInternal(
TF_ASSIGN_OR_RETURN(
mlir::Type result_ty,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
return MakeXlaOp(builder_.create<mlir::xla_hlo::DynamicUpdateSliceOp>(
return MakeXlaOp(builder_.create<mlir::mhlo::DynamicUpdateSliceOp>(
loc_, result_ty, GetValue(operand), GetValue(update),
GetValues(start_indices)));
}
@ -432,7 +431,7 @@ StatusOr<XlaOp> MlirHloBuilder::PadInternal(
high.push_back(dimension.edge_padding_high());
internal.push_back(dimension.interior_padding());
}
return MakeXlaOp(builder_.create<mlir::xla_hlo::PadOp>(
return MakeXlaOp(builder_.create<mlir::mhlo::PadOp>(
loc_, result_type, GetValue(operand), GetValue(padding_value),
GetI64ElementsAttr(low, &builder_), GetI64ElementsAttr(high, &builder_),
GetI64ElementsAttr(internal, &builder_)));
@ -444,7 +443,7 @@ StatusOr<XlaOp> MlirHloBuilder::TupleInternal(
for (auto& element : elements) {
operands.push_back(GetValue(element));
}
return MakeXlaOp(builder_.create<mlir::xla_hlo::TupleOp>(loc_, operands));
return MakeXlaOp(builder_.create<mlir::mhlo::TupleOp>(loc_, operands));
}
StatusOr<XlaOp> MlirHloBuilder::CreateOp(

View File

@ -34,7 +34,7 @@ limitations under the License.
namespace xla {
// Provides a way to construct xla_hlo dialect ops in MLIR using XlaBuilder
// Provides a way to construct mhlo dialect ops in MLIR using XlaBuilder
// interface.
//
// Requires that all XlaOp arguments are either returned by any of the builder

View File

@ -69,12 +69,12 @@ using ::tensorflow::uint32;
using ::tensorflow::uint64;
using ::tensorflow::uint8;
constexpr char kPaddingMapAttr[] = "xla_hlo.padding_map";
constexpr char kPaddingMapAttr[] = "mhlo.padding_map";
constexpr char kShapeIndicesAttr[] = "shape_indices";
constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices";
constexpr char kShardingAttr[] = "xla_hlo.sharding";
constexpr char kFrontendAttributesAttr[] = "xla_hlo.frontend_attributes";
constexpr char kRepicationAttr[] = "xla_hlo.is_same_data_across_replicas";
constexpr char kShardingAttr[] = "mhlo.sharding";
constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes";
constexpr char kRepicationAttr[] = "mhlo.is_same_data_across_replicas";
// Passes through everything except for unique_ptr, on which it calls get().
// This exists to allow the generated code to call XLA functions that take a raw
@ -247,7 +247,7 @@ static std::unique_ptr<xla::PrecisionConfig> Convert_precision_config(
}
static xla::DotDimensionNumbers Convert_dot_dimension_numbers(
mlir::xla_hlo::DotDimensionNumbers dot_dimension_numbers_attr) {
mlir::mhlo::DotDimensionNumbers dot_dimension_numbers_attr) {
xla::DotDimensionNumbers dot_dimension_numbers;
auto rhs_contracting_dimensions =
@ -282,7 +282,7 @@ static xla::DotDimensionNumbers Convert_dot_dimension_numbers(
}
static xla::ConvolutionDimensionNumbers Convert_dimension_numbers(
mlir::xla_hlo::ConvDimensionNumbers input) {
mlir::mhlo::ConvDimensionNumbers input) {
xla::ConvolutionDimensionNumbers output;
output.set_input_batch_dimension(
@ -315,7 +315,7 @@ static xla::ConvolutionDimensionNumbers Convert_dimension_numbers(
return output;
}
xla::ChannelHandle Convert_channel_handle(mlir::xla_hlo::ChannelHandle attr) {
xla::ChannelHandle Convert_channel_handle(mlir::mhlo::ChannelHandle attr) {
xla::ChannelHandle channel_handle;
channel_handle.set_handle(ConvertAPInt(attr.handle().getValue()));
channel_handle.set_type(static_cast<xla::ChannelHandle::ChannelType>(
@ -333,7 +333,7 @@ static xla::ComparisonDirection Convert_comparison_direction(
}
static xla::GatherDimensionNumbers Convert_dimension_numbers(
mlir::xla_hlo::GatherDimensionNumbers input) {
mlir::mhlo::GatherDimensionNumbers input) {
xla::GatherDimensionNumbers output;
auto offset_dims = ConvertDenseIntAttr(input.offset_dims());
@ -357,7 +357,7 @@ static xla::GatherDimensionNumbers Convert_dimension_numbers(
}
static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
mlir::xla_hlo::ScatterDimensionNumbers input) {
mlir::mhlo::ScatterDimensionNumbers input) {
xla::ScatterDimensionNumbers output;
auto update_window_dims = ConvertDenseIntAttr(input.update_window_dims());
@ -574,7 +574,7 @@ llvm::SmallVector<xla::XlaOp, 4> GetTuple(mlir::Operation::operand_range values,
} // namespace
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
LogicalResult ExportXlaOp(AllReduceOp op, OpLoweringContext ctx) {
@ -829,7 +829,7 @@ LogicalResult ExportXlaOp(ReshapeOp op, OpLoweringContext ctx) {
}
LogicalResult ExportXlaOp(ReturnOp op, OpLoweringContext ctx) {
// Failure on purpose because `xla_hlo::ReturnOp` will be handled by
// Failure on purpose because `mhlo::ReturnOp` will be handled by
// special purpose logic in `ConvertToHloModule::Lower`.
return failure();
}
@ -943,7 +943,7 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
}
} // namespace
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir
#include "tensorflow/compiler/mlir/xla/operator_writers.inc"
@ -1060,7 +1060,7 @@ LogicalResult ConvertToHloModule::Lower(
return success();
}
if (isa<xla_hlo::ReturnOp, mlir::ReturnOp>(inst)) {
if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
// Construct the return value for the function. If there are multiple
// values returned, then create a tuple, else return value directly.
xla::XlaOp return_value;
@ -1405,7 +1405,7 @@ void AddDynamicParameterBindingEntry(xla::DynamicParameterBindingProto* binding,
}
// Validates and populates dynamic parameter bindings from a module's entry
// function `xla_hlo.padding_map` argument attributes to a `xla::HloModuleProto`
// function `mhlo.padding_map` argument attributes to a `xla::HloModuleProto`
// `DynamicParameterBindingProto`.
LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
xla::HloModuleProto* hlo_module_proto,

View File

@ -73,8 +73,8 @@ static StringRef GetClientBuilder(const Operator& op) {
}
static void BuildOperator(const Operator& op, raw_ostream& os) {
os << "mlir::LogicalResult ExportXlaOp(mlir::xla_hlo::"
<< op.getCppClassName() << " op, OpLoweringContext ctx) {\n"
os << "mlir::LogicalResult ExportXlaOp(mlir::mhlo::" << op.getCppClassName()
<< " op, OpLoweringContext ctx) {\n"
<< " auto& value_map = *ctx.values;\n"
<< " auto result = op.getResult();\n";
@ -164,12 +164,12 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
Operator op(def);
// Cast to the current operation and build the exporter.
os << " if (auto xla_op = llvm::dyn_cast<mlir::xla_hlo::"
os << " if (auto xla_op = llvm::dyn_cast<mlir::mhlo::"
<< op.getCppClassName() << ">(op)) {\n";
os << " return ";
// The autogenerated converters aren't in the same namespace.
// TODO(jpienaar): Reconsider this.
if (def->getValueAsBit("hasCustomHLOConverter")) os << "mlir::xla_hlo::";
if (def->getValueAsBit("hasCustomHLOConverter")) os << "mlir::mhlo::";
os << "ExportXlaOp(xla_op, lowering_context);\n";
os << " }\n";
}

View File

@ -7,7 +7,7 @@ func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.abs
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%abs = "xla_hlo.abs"(%value) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%abs = "mhlo.abs"(%value) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %abs : tensor<2x2xf32>
}
@ -22,7 +22,7 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
// CHECK: lhlo.add
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.add"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.add"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
@ -37,7 +37,7 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32
// CHECK: lhlo.and
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.and"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
%res = "mhlo.and"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %res : tensor<2x2xi32>
}
@ -50,7 +50,7 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.ceil
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.ceil"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.ceil"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
@ -65,7 +65,7 @@ func @main(%value0: tensor<1x2xf32>, %value1: tensor<1x2xf32>) -> tensor<1x2xcom
// CHECK: lhlo.complex
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.complex"(%value0, %value1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
%res = "mhlo.complex"(%value0, %value1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
return %res : tensor<1x2xcomplex<f32>>
}
@ -79,7 +79,7 @@ func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xcomplex<f32>> {
// CHECK: lhlo.cosine
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.cosine"(%value0) : (tensor<1x2xcomplex<f32>>) -> tensor<1x2xcomplex<f32>>
%res = "mhlo.cosine"(%value0) : (tensor<1x2xcomplex<f32>>) -> tensor<1x2xcomplex<f32>>
return %res : tensor<1x2xcomplex<f32>>
}
@ -94,7 +94,7 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
// CHECK: lhlo.divide
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.divide"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.divide"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
@ -107,7 +107,7 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.exponential
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.exponential"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.exponential"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
@ -120,7 +120,7 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.log
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.log"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.log"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
@ -135,7 +135,7 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
// CHECK: lhlo.maximum
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.maximum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.maximum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
@ -150,7 +150,7 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
// CHECK: lhlo.minimum
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.minimum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.minimum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
@ -165,7 +165,7 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
// CHECK: lhlo.multiply
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.multiply"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.multiply"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
@ -178,7 +178,7 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.negate
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.negate"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.negate"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
@ -191,7 +191,7 @@ func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<1x2xf32>
// CHECK: lhlo.real
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.real"(%value0) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
%res = "mhlo.real"(%value0) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
return %res : tensor<1x2xf32>
}
@ -204,7 +204,7 @@ func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<1x2xf32>
// CHECK: lhlo.imag
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.imag"(%value0) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
%res = "mhlo.imag"(%value0) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
return %res : tensor<1x2xf32>
}
@ -219,7 +219,7 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32
// CHECK: lhlo.remainder
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.remainder"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
%res = "mhlo.remainder"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %res : tensor<2x2xi32>
}
@ -232,7 +232,7 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.rsqrt
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.rsqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.rsqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
@ -248,7 +248,7 @@ func @main(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>)
// CHECK: lhlo.select
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[VIEW]]
// CHECK-NEXT: return
%0 = "xla_hlo.select"(%pred, %lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
%0 = "mhlo.select"(%pred, %lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
return %0 : tensor<2x2xf32>
}
@ -261,7 +261,7 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.sign
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.sign"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.sign"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
@ -274,7 +274,7 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.sqrt
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.sqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.sqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
@ -289,7 +289,7 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32
// CHECK: lhlo.subtract
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.subtract"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
%res = "mhlo.subtract"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %res : tensor<2x2xi32>
}
@ -302,7 +302,7 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.tanh
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.tanh"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.tanh"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
@ -317,10 +317,10 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW1:.*]] = std.view %[[ARG3]]{{.*}} : memref<100xi8> to memref<5x5xf32>
// CHECK: "xla_lhlo.sort"(%[[ARG0]], %[[ARG1]], %[[VIEW0]], %[[VIEW1]])
func @main(%key: tensor<5x5xi32>, %value: tensor<5x5xf32>) -> tuple<tensor<5x5xi32>, tensor<5x5xf32>> {
%res = "xla_hlo.sort"(%key, %value) ({
%res = "mhlo.sort"(%key, %value) ({
^bb0(%a: tensor<i32>, %b: tensor<i32>, %c: tensor<f32>, %d: tensor<f32>):
%ret = "xla_hlo.compare"(%c, %d) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%ret) : (tensor<i1>) -> ()
%ret = "mhlo.compare"(%c, %d) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%ret) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true}: (tensor<5x5xi32>, tensor<5x5xf32>) -> tuple<tensor<5x5xi32>, tensor<5x5xf32>>
return %res : tuple<tensor<5x5xi32>, tensor<5x5xf32>>

View File

@ -15,11 +15,11 @@ func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) ->
// CHECK: [[BCASTHEAD:%.*]] = "shape.broadcast"([[LHSHEAD]], [[RHSHEAD]]) : (!shape.shape, !shape.shape) -> !shape.shape
// CHECK: [[LHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[LHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape
// CHECK: [[LHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[LHSBCASTSHAPE]] : tensor<3xindex>
// CHECK: [[LHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32>
// CHECK: [[LHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32>
// CHECK: [[RHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[RHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape
// CHECK: [[RHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[RHSBCASTSHAPE]] : tensor<3xindex>
// CHECK: [[RHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[RHS]], [[RHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, tensor<3xindex>) -> tensor<3x2x4xf32>
// CHECK: [[RESULT:%.*]] = "xla_hlo.dot_general"([[LHSBCAST]], [[RHSBCAST]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
// CHECK: [[RHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[RHS]], [[RHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, tensor<3xindex>) -> tensor<3x2x4xf32>
// CHECK: [[RESULT:%.*]] = "mhlo.dot_general"([[LHSBCAST]], [[RHSBCAST]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
// CHECK: return [[RESULT]] : tensor<3x4x4xf32>
// CHECK: }
@ -29,9 +29,9 @@ func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) ->
func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x4xf32> {
// CHECK-LABEL: func @batchmatmulv2_lhs_batch
// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>,
@ -42,9 +42,9 @@ func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>)
func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> {
// CHECK-LABEL: func @batchmatmulv2_rhs_batch
// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>,
@ -55,7 +55,7 @@ func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>)
func @batchmatmulv2_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
// CHECK-LABEL: func @batchmatmulv2_dynamic
// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>,
@ -66,7 +66,7 @@ func @batchmatmulv2_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>)
func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<5x4xf32> {
// CHECK-LABEL: func @batchmatmulv2_adj_real
// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
// CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
@ -78,14 +78,14 @@ func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) ->
func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex<f32>>, %arg1: tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
// CHECK-LABEL: func @batchmatmulv2_adj_complex(
// CHECK-SAME: [[LHS:%.*]]: tensor<5x2xcomplex<f32>>, [[RHS:%.*]]: tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
// CHECK: [[LHSRE:%.*]] = "xla_hlo.real"([[LHS]])
// CHECK: [[LHSIM:%.*]] = "xla_hlo.imag"([[LHS]])
// CHECK: [[LHSIMNEG:%.*]] = "xla_hlo.negate"([[LHSIM]])
// CHECK: [[LHSCONJ:%.*]] = "xla_hlo.complex"([[LHSRE]], [[LHSIMNEG]])
// CHECK: [[RHSRE:%.*]] = "xla_hlo.real"([[RHS]])
// CHECK: [[RHSIM:%.*]] = "xla_hlo.imag"([[RHS]])
// CHECK: [[RHSIMNEG:%.*]] = "xla_hlo.negate"([[RHSIM]])
// CHECK: [[RHSCONJ:%.*]] = "xla_hlo.complex"([[RHSRE]], [[RHSIMNEG]])
// CHECK: [[LHSRE:%.*]] = "mhlo.real"([[LHS]])
// CHECK: [[LHSIM:%.*]] = "mhlo.imag"([[LHS]])
// CHECK: [[LHSIMNEG:%.*]] = "mhlo.negate"([[LHSIM]])
// CHECK: [[LHSCONJ:%.*]] = "mhlo.complex"([[LHSRE]], [[LHSIMNEG]])
// CHECK: [[RHSRE:%.*]] = "mhlo.real"([[RHS]])
// CHECK: [[RHSIM:%.*]] = "mhlo.imag"([[RHS]])
// CHECK: [[RHSIMNEG:%.*]] = "mhlo.negate"([[RHSIM]])
// CHECK: [[RHSCONJ:%.*]] = "mhlo.complex"([[RHSRE]], [[RHSIMNEG]])
// CHECK: shape.shape_of [[LHSCONJ]]
// CHECK: shape.shape_of [[RHSCONJ]]
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex<f32>>, tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>

View File

@ -11,8 +11,8 @@
// CHECK-LABEL: func @add
func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: %[[SUM0:.*]] = xla_hlo.add %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: %[[SUM1:.*]] = xla_hlo.add %[[SUM0]], %arg0 : tensor<2xi32>
// CHECK-NEXT: %[[SUM0:.*]] = mhlo.add %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: %[[SUM1:.*]] = mhlo.add %[[SUM0]], %arg0 : tensor<2xi32>
// CHECK-NEXT: return %[[SUM1]] : tensor<2xi32>
%0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%1 = "tf.AddV2"(%0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
@ -24,8 +24,8 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// patterns unambiguous and more interesting (once broadcastable trait is
// fixed upstream).
func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-NEXT: xla_hlo.add %[[LHS_BCAST]], %arg1
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-NEXT: mhlo.add %[[LHS_BCAST]], %arg1
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
return %0: tensor<1x2xi32>
}
@ -34,8 +34,8 @@ func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
// TODO(laurenzo): Change this to a (4x1x1 + 1x4x4x4) shaped add once upstream
// broadcastable bug is fixed (helps make the CHECK matching unambiguous)
func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> {
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}
// CHECK-NEXT: xla_hlo.add %[[LHS_BCAST]], %arg1
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}
// CHECK-NEXT: mhlo.add %[[LHS_BCAST]], %arg1
%0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
return %0: tensor<4x4x4x4xi32>
}
@ -50,9 +50,9 @@ func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1
// CHECK-NEXT: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-NEXT: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] : tensor<?x?xi32>
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-NEXT: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] : tensor<?x?xi32>
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
%0 = "tf.Add"(%arg0, %arg1) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
return %0: tensor<?x?xi32>
@ -60,7 +60,7 @@ func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3
// CHECK-LABEL: func @div
func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: return %0 : tensor<2xi32>
%0 = "tf.Div"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
@ -68,7 +68,7 @@ func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-LABEL: func @shift_left
func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK: xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
// CHECK: mhlo.shift_left %arg0, %arg1 : tensor<4xi32>
%0 = "tf.LeftShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
@ -82,21 +82,21 @@ func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi
// CHECK-LABEL: func @maximum
func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
// CHECK-NEXT: mhlo.maximum %arg0, %arg1 : tensor<4xf32>
%0 = "tf.Maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// CHECK-LABEL: func @minimum
func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: xla_hlo.minimum %arg0, %arg1 : tensor<4xf32>
// CHECK-NEXT: mhlo.minimum %arg0, %arg1 : tensor<4xf32>
%0 = "tf.Minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// CHECK-LABEL: func @mul
func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: %0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: %0 = mhlo.multiply %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: return %0 : tensor<2xi32>
%0 = "tf.Mul"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
@ -104,14 +104,14 @@ func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-LABEL: func @real_div
func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32>
%0 = "tf.RealDiv"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
}
// CHECK-LABEL: func @sub
func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: %0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: %0 = mhlo.subtract %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: return %0 : tensor<2xi32>
%0 = "tf.Sub"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
@ -119,7 +119,7 @@ func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-LABEL: func @shift_right
func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
// CHECK: mhlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
@ -140,7 +140,7 @@ func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8
// CHECK-LABEL: func @and
func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> {
// CHECK-NEXT: xla_hlo.and
// CHECK-NEXT: mhlo.and
%0 = "tf.LogicalAnd"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
@ -154,28 +154,28 @@ func @and_unranked(%arg0: tensor<*xi1>, %arg1: tensor<*xi1>) -> tensor<*xi1> {
// CHECK-LABEL: func @or
func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> {
// CHECK-NEXT: xla_hlo.or
// CHECK-NEXT: mhlo.or
%0 = "tf.LogicalOr"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
// CHECK-LABEL: func @bitwise_or
func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-NEXT: xla_hlo.or
// CHECK-NEXT: mhlo.or
%0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0: tensor<4xi32>
}
// CHECK-LABEL: func @bitwise_and
func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-NEXT: xla_hlo.and
// CHECK-NEXT: mhlo.and
%0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0: tensor<4xi32>
}
// CHECK-LABEL: func @pow
func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-NEXT: xla_hlo.power
// CHECK-NEXT: mhlo.power
%0 = "tf.Pow"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %0: tensor<2xf32>
}
@ -188,7 +188,7 @@ func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-LABEL: func @equal
func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"}
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"}
%0 = "tf.Equal"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
@ -202,9 +202,9 @@ func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1>
// CHECK-DAG: %[[LHS_SHAPE1:.+]] = shape.shape_of %arg0
// CHECK-NEXT: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE1]], %[[RHS_SHAPE]])
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"}
// CHECK-DAG: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = "mhlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"}
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
return %0: tensor<?xi1>
@ -212,8 +212,8 @@ func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1>
// CHECK-LABEL: func @equal_broadcast
func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-NEXT: "xla_hlo.compare"(%[[LHS_BCAST]], %arg1) {comparison_direction = "EQ"}
// CHECK-DAG: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-NEXT: "mhlo.compare"(%[[LHS_BCAST]], %arg1) {comparison_direction = "EQ"}
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0: tensor<1x2xi1>
}
@ -255,7 +255,7 @@ func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1>
// CHECK-LABEL: func @notequal
func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"}
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"}
%0 = "tf.NotEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
@ -268,15 +268,15 @@ func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-LABEL: func @greater
func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"}
// CHECK: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"}
%0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
// CHECK-LABEL: func @broadcast_greater
func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-NEXT: "xla_hlo.compare"(%[[LHS_BCAST]], %arg1) {comparison_direction = "GT"}
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-NEXT: "mhlo.compare"(%[[LHS_BCAST]], %arg1) {comparison_direction = "GT"}
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0: tensor<1x2xi1>
}
@ -291,9 +291,9 @@ func @greater_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi1
// CHECK-DAG: %[[RHS_SHAPE1:.+]] = shape.shape_of %arg1
// CHECK-NEXT: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE1]], %[[RHS_SHAPE1]])
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-NEXT: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"}
// CHECK-DAG: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-NEXT: "mhlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"}
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi1>
return %0: tensor<?xi1>
}
@ -307,21 +307,21 @@ func @greater_uranked(%arg0: tensor<*xi32>) -> tensor<*xi1> {
// CHECK-LABEL: func @greater_equal
func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"}
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"}
%0 = "tf.GreaterEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
// CHECK-LABEL: func @less
func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"}
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"}
%0 = "tf.Less"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
// CHECK-LABEL: func @less_equal
func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"}
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"}
%0 = "tf.LessEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}

View File

@ -3,40 +3,40 @@
// CHECK-LABEL: @if
func @if(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>)
attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} {
// CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
%0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: [[VAL1:%.+]] = "xla_hlo.tuple"(%arg0, %arg1)
// CHECK: [[VAL2:%.+]] = "xla_hlo.if"([[VAL0]], [[VAL1]], [[VAL1]]) ( {
// CHECK: [[VAL0:%.+]] = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
%0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: [[VAL1:%.+]] = "mhlo.tuple"(%arg0, %arg1)
// CHECK: [[VAL2:%.+]] = "mhlo.if"([[VAL0]], [[VAL1]], [[VAL1]]) ( {
// CHECK: ^bb0(%arg2: tuple<tensor<f32>, tensor<f32>>):
// CHECK: [[VAL4:%.+]] = "xla_hlo.get_tuple_element"(%arg2) {index = 0 : i32}
// CHECK: [[VAL5:%.+]] = "xla_hlo.get_tuple_element"(%arg2) {index = 1 : i32}
// CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32}
// CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"(%arg2) {index = 1 : i32}
// CHECK: [[VAL6:%.+]] = call @cond_true([[VAL4]], [[VAL5]])
// CHECK: [[VAL7:%.+]] = "xla_hlo.tuple"([[VAL6]])
// CHECK: "xla_hlo.return"([[VAL7]]) : (tuple<tensor<f32>>) -> ()
// CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL6]])
// CHECK: "mhlo.return"([[VAL7]]) : (tuple<tensor<f32>>) -> ()
// CHECK: }, {
// CHECK: ^bb0(%arg2: tuple<tensor<f32>, tensor<f32>>)
// CHECK: [[VAL4:%.+]] = "xla_hlo.get_tuple_element"(%arg2) {index = 0 : i32}
// CHECK: [[VAL5:%.+]] = "xla_hlo.get_tuple_element"(%arg2) {index = 1 : i32}
// CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32}
// CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"(%arg2) {index = 1 : i32}
// CHECK: [[VAL6:%.+]] = call @cond_false([[VAL4]], [[VAL5]])
// CHECK: [[VAL7:%.+]] = "xla_hlo.tuple"([[VAL6]])
// CHECK: "xla_hlo.return"([[VAL7]]) : (tuple<tensor<f32>>) -> ()
// CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL6]])
// CHECK: "mhlo.return"([[VAL7]]) : (tuple<tensor<f32>>) -> ()
// CHECK: })
%1 = "tf.If"(%0, %arg0, %arg1) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _lower_using_switch_merge = true, _output_shapes = ["tfshape$"], device = "", else_branch = @cond_false, is_stateless = true, name = "cond", output_shapes = [#tf.shape<>], then_branch = @cond_true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: [[VAL3:%.+]] = "xla_hlo.get_tuple_element"([[VAL2]]) {index = 0 : i32}
// CHECK: [[VAL3:%.+]] = "mhlo.get_tuple_element"([[VAL2]]) {index = 0 : i32}
// CHECK: return [[VAL3]]
return %1 : tensor<f32>
}
func @cond_false(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32>
attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} {
%0 = "xla_hlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
%0 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
func @cond_true(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32>
attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} {
%0 = "xla_hlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
%0 = "mhlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
@ -45,42 +45,42 @@ attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} {
// CHECK-SAME: %[[BRANCH_INDEX:.*]]: tensor<i32>, %[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>)
func @case(%index: tensor<i32>, %arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
%0:2 = "tf.Case"(%index, %arg0, %arg1) {branches = [@exponential, @log, @floor]} : (tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// CHECK: %[[TUPLE_INPUT:.*]] = "xla_hlo.tuple"(%[[ARG0]], %[[ARG1]]) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
// CHECK: %[[CASE:.*]]:2 = "xla_hlo.case"(%[[BRANCH_INDEX]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]]) ( {
// CHECK: %[[TUPLE_INPUT:.*]] = "mhlo.tuple"(%[[ARG0]], %[[ARG1]]) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
// CHECK: %[[CASE:.*]]:2 = "mhlo.case"(%[[BRANCH_INDEX]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]]) ( {
// CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple<tensor<f32>, tensor<f32>>):
// CHECK: %[[TUPLE_ELEMENT_0:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[TUPLE_ELEMENT_1:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[TUPLE_ELEMENT_0:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[TUPLE_ELEMENT_1:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[CALL_EXP:.*]]:2 = call @exponential(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// CHECK: "xla_hlo.return"(%[[CALL_EXP]]#0, %[[CALL_EXP]]#1) : (tensor<f32>, tensor<f32>) -> ()
// CHECK: "mhlo.return"(%[[CALL_EXP]]#0, %[[CALL_EXP]]#1) : (tensor<f32>, tensor<f32>) -> ()
// CHECK: }, {
// CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple<tensor<f32>, tensor<f32>>):
// CHECK: %[[TUPLE_ELEMENT_0:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[TUPLE_ELEMENT_1:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[TUPLE_ELEMENT_0:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[TUPLE_ELEMENT_1:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[CALL_LOG:.*]]:2 = call @log(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// CHECK: "xla_hlo.return"(%[[CALL_LOG]]#0, %[[CALL_LOG]]#1) : (tensor<f32>, tensor<f32>) -> ()
// CHECK: "mhlo.return"(%[[CALL_LOG]]#0, %[[CALL_LOG]]#1) : (tensor<f32>, tensor<f32>) -> ()
// CHECK: }, {
// CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple<tensor<f32>, tensor<f32>>):
// CHECK: %[[TUPLE_ELEMENT_0:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[TUPLE_ELEMENT_1:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[TUPLE_ELEMENT_0:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[TUPLE_ELEMENT_1:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[CALL_FLOOR:.*]]:2 = call @floor(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// CHECK: "xla_hlo.return"(%[[CALL_FLOOR]]#0, %[[CALL_FLOOR]]#1) : (tensor<f32>, tensor<f32>) -> ()
// CHECK: "mhlo.return"(%[[CALL_FLOOR]]#0, %[[CALL_FLOOR]]#1) : (tensor<f32>, tensor<f32>) -> ()
// CHECK: }) : (tensor<i32>, tuple<tensor<f32>, tensor<f32>>, tuple<tensor<f32>, tensor<f32>>, tuple<tensor<f32>, tensor<f32>>) -> (tensor<f32>, tensor<f32>)
return %0#0, %0#1 : tensor<f32>, tensor<f32>
// CHECK: return %[[CASE]]#0, %[[CASE]]#1 : tensor<f32>, tensor<f32>
}
func @exponential(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
%0 = "xla_hlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
%0 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
return %0, %arg1 : tensor<f32>, tensor<f32>
}
func @log(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
%0 = "xla_hlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
%0 = "mhlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
return %0, %arg1 : tensor<f32>, tensor<f32>
}
func @floor(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
%0 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
%0 = "mhlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
return %0, %arg1 : tensor<f32>, tensor<f32>
}
@ -88,44 +88,44 @@ func @floor(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>
// CHECK-LABEL: func @while
func @while(%arg0: tensor<f32> {tf_saved_model.index_path = [0]}) -> (tensor<i32> {tf_saved_model.index_path = []})
attributes {tf._input_shapes = ["tfshape$"]} {
// CHECK: [[VAL0:%.+]] = xla_hlo.constant dense<0>
// CHECK: [[VAL1:%.+]] = xla_hlo.constant dense<-1>
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = xla_hlo.constant dense<-1> : tensor<i32>
// CHECK: [[VAL2:%.+]] = "xla_hlo.tuple"([[VAL0]], [[VAL1]], [[VAL0]])
// CHECK: [[VAL3:%.+]] = "xla_hlo.while"([[VAL2]]) ( {
// CHECK: [[VAL0:%.+]] = mhlo.constant dense<0>
// CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1>
%0 = mhlo.constant dense<0> : tensor<i32>
%1 = mhlo.constant dense<-1> : tensor<i32>
// CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[VAL0]], [[VAL1]], [[VAL0]])
// CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( {
// CHECK: ^bb0(%arg1: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
// CHECK: [[VAL7:%.+]] = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32}
// CHECK: [[VAL8:%.+]] = "xla_hlo.get_tuple_element"(%arg1) {index = 1 : i32}
// CHECK: [[VAL9:%.+]] = "xla_hlo.get_tuple_element"(%arg1) {index = 2 : i32}
// CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32}
// CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 1 : i32}
// CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 2 : i32}
// CHECK: [[VAL10:%.+]] = call @while_cond([[VAL7]], [[VAL8]], [[VAL9]])
// CHECK: "xla_hlo.return"([[VAL10]])
// CHECK: "mhlo.return"([[VAL10]])
// CHECK: }, {
// CHECK: ^bb0(%arg1: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
// CHECK: [[VAL7:%.+]] = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32}
// CHECK: [[VAL8:%.+]] = "xla_hlo.get_tuple_element"(%arg1) {index = 1 : i32}
// CHECK: [[VAL9:%.+]] = "xla_hlo.get_tuple_element"(%arg1) {index = 2 : i32}
// CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32}
// CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 1 : i32}
// CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 2 : i32}
// CHECK: [[VAL10:%.+]]:3 = call @while_body([[VAL7]], [[VAL8]], [[VAL9]])
// CHECK: [[VAL11:%.+]] = "xla_hlo.tuple"([[VAL10]]#0, [[VAL10]]#1, [[VAL10]]#2)
// CHECK: "xla_hlo.return"([[VAL11]])
// CHECK: [[VAL11:%.+]] = "mhlo.tuple"([[VAL10]]#0, [[VAL10]]#1, [[VAL10]]#2)
// CHECK: "mhlo.return"([[VAL11]])
// CHECK: }) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
// CHECK: [[VAL4:%.+]] = "xla_hlo.get_tuple_element"([[VAL3]]) {index = 0 : i32}
// CHECK: [[VAL5:%.+]] = "xla_hlo.get_tuple_element"([[VAL3]]) {index = 1 : i32}
// CHECK: [[VAL6:%.+]] = "xla_hlo.get_tuple_element"([[VAL3]]) {index = 2 : i32}
// CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 0 : i32}
// CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 1 : i32}
// CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 2 : i32}
// CHECK: return [[VAL6]]
%2:3 = "tf.While"(%0, %1, %0) {T = ["tfdtype$DT_INT32", "tfdtype$DT_INT32", "tfdtype$DT_INT32"], _lower_using_switch_merge = true, _num_original_outputs = 3 : i64, _output_shapes = ["tfshape$", "tfshape$", "tfshape$"], body = @while_body, cond = @while_cond, device = "", is_stateless = true, name = "while", output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
return %2#2 : tensor<i32>
}
func @while_cond(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i1>
attributes {tf._input_shapes = ["tfshape$", "tfshape$", "tfshape$"]} {
%0 = xla_hlo.constant dense<10> : tensor<i32>
%1 = "xla_hlo.compare"(%arg2, %0) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
%0 = mhlo.constant dense<10> : tensor<i32>
%1 = "mhlo.compare"(%arg2, %0) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
return %1 : tensor<i1>
}
func @while_body(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
attributes {tf._input_shapes = ["tfshape$", "tfshape$", "tfshape$"]} {
%0 = xla_hlo.constant dense<1> : tensor<i32>
%1 = xla_hlo.add %arg2, %0 : tensor<i32>
%2 = xla_hlo.add %arg0, %0 : tensor<i32>
%0 = mhlo.constant dense<1> : tensor<i32>
%1 = mhlo.add %arg2, %0 : tensor<i32>
%2 = mhlo.add %arg0, %0 : tensor<i32>
return %2, %arg1, %1 : tensor<i32>, tensor<i32>, tensor<i32>
}

View File

@ -7,7 +7,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// CHECK-LABEL: abs
// expected-error@+1 {{unsupported device}}
func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: %[[RESULT:.*]] = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
// return %[[RESULT]]
@ -54,7 +54,7 @@ func @dynamic_operand(%arg0: tensor<?xf32>) -> tensor<?xf32> {
func @tuple_type(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
// Verifies that the pass can handle operands of non-tensor type like tuple
// from non TensorFlow ops.
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
%0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
return %0 : tensor<f32>
}
@ -69,9 +69,9 @@ func @unsupported_dtype(%arg0: tensor<2x!tf.variant>) -> tensor<2x!tf.variant> {
// CHECK-LABEL: multiple_dialect_ops
func @multiple_dialect_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: xla_hlo.negate
%0 = "xla_hlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: xla_hlo.abs
// CHECK: mhlo.negate
%0 = "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: mhlo.abs
%1 = "tf.Abs"(%0) : (tensor<2xf32>) -> tensor<2xf32>
return %1 : tensor<2xf32>
@ -79,21 +79,21 @@ func @multiple_dialect_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-LABEL: binary_op
func @binary_op(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: xla_hlo.atan2 %arg0, %arg1 : tensor<2xf32>
// CHECK: mhlo.atan2 %arg0, %arg1 : tensor<2xf32>
%0 = "tf.Atan2"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// CHECK-LABEL: binary_op_broadcast
func @binary_op_broadcast(%arg0: tensor<4x1xf32>, %arg1: tensor<4x1x4xf32>) -> tensor<4x4x4xf32> {
// CHECK: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x1xf32>) -> tensor<4x4x1xf32>
// CHECK: %[[RESHAPE0:.*]] = "xla_hlo.reshape"(%[[BROADCAST0]]) : (tensor<4x4x1xf32>) -> tensor<4x4xf32>
// CHECK: %[[UPDATED_ARG0:.*]] = "xla_hlo.broadcast_in_dim"(%[[RESHAPE0]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4x4xf32>
// CHECK: %[[BROADCAST0:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x1xf32>) -> tensor<4x4x1xf32>
// CHECK: %[[RESHAPE0:.*]] = "mhlo.reshape"(%[[BROADCAST0]]) : (tensor<4x4x1xf32>) -> tensor<4x4xf32>
// CHECK: %[[UPDATED_ARG0:.*]] = "mhlo.broadcast_in_dim"(%[[RESHAPE0]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4x4xf32>
// CHECK: %[[RESHAPE1:.*]] = "xla_hlo.reshape"(%arg1) : (tensor<4x1x4xf32>) -> tensor<4x4xf32>
// CHECK: %[[UPDATED_ARG1:.*]] = "xla_hlo.broadcast_in_dim"(%[[RESHAPE1]]) {broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4x4xf32>
// CHECK: %[[RESHAPE1:.*]] = "mhlo.reshape"(%arg1) : (tensor<4x1x4xf32>) -> tensor<4x4xf32>
// CHECK: %[[UPDATED_ARG1:.*]] = "mhlo.broadcast_in_dim"(%[[RESHAPE1]]) {broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4x4xf32>
// CHECK: %[[RESULT:.*]] = xla_hlo.atan2 %[[UPDATED_ARG0]], %[[UPDATED_ARG1]] : tensor<4x4x4xf32>
// CHECK: %[[RESULT:.*]] = mhlo.atan2 %[[UPDATED_ARG0]], %[[UPDATED_ARG1]] : tensor<4x4x4xf32>
// CHECK: return %[[RESULT]] : tensor<4x4x4xf32>
%0 = "tf.Atan2"(%arg0, %arg1) : (tensor<4x1xf32>, tensor<4x1x4xf32>) -> tensor<4x4x4xf32>
@ -102,23 +102,23 @@ func @binary_op_broadcast(%arg0: tensor<4x1xf32>, %arg1: tensor<4x1x4xf32>) -> t
// CHECK-LABEL: func @ternary_op
func @ternary_op(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
// CHECK: "xla_hlo.select"(%arg0, %arg1, %arg2)
// CHECK: "mhlo.select"(%arg0, %arg1, %arg2)
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
}
// CHECK-LABEL: func @convert
func @convert(%arg0: tensor<2xi32>) -> tensor<2xf32> {
// CHECK: "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
// CHECK: "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
%0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<2xi32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// CHECK-LABEL: func @constant
func @constant(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: %[[SCALAR_ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[ONE:.*]] = "xla_hlo.broadcast_in_dim"(%[[SCALAR_ONE]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<2xf32>
// CHECK: %[[RESULT:.*]] = xla_hlo.divide %[[ONE]], %arg0 : tensor<2xf32>
// CHECK: %[[SCALAR_ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[ONE:.*]] = "mhlo.broadcast_in_dim"(%[[SCALAR_ONE]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<2xf32>
// CHECK: %[[RESULT:.*]] = mhlo.divide %[[ONE]], %arg0 : tensor<2xf32>
// CHECK: return %[[RESULT]]
%0 = "tf.Inv"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
@ -127,7 +127,7 @@ func @constant(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-LABEL: func @greater
func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"}
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"}
%0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
@ -136,14 +136,14 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x2xf64>, %[[ARG1:.*]]: tensor<f64>,
func @const_inputs(%arg0: tensor<2x2xf64>, %arg1: tensor<f64>, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>, %arg4: tensor<2xi32>) -> tensor<6x5xf64> {
// CHECK: "xla_hlo.pad"(%[[ARG0]], %[[ARG1]])
// CHECK: "mhlo.pad"(%[[ARG0]], %[[ARG1]])
// CHECK-SAME-DAG: edge_padding_high = dense<[1, 2]> : tensor<2xi64>
// CHECK-SAME-DAG: edge_padding_low = dense<[2, 1]> : tensor<2xi64>
// CHECK-SAME-DAG: interior_padding = dense<[1, 0]> : tensor<2xi64>
%0 = xla_hlo.constant dense<[2, 1]> : tensor<2xi32>
%1 = xla_hlo.constant dense<[1, 2]> : tensor<2xi32>
%2 = xla_hlo.constant dense<[1, 0]> : tensor<2xi32>
%0 = mhlo.constant dense<[2, 1]> : tensor<2xi32>
%1 = mhlo.constant dense<[1, 2]> : tensor<2xi32>
%2 = mhlo.constant dense<[1, 0]> : tensor<2xi32>
%3 = "tf.XlaPad"(%arg0, %arg1, %0, %1, %2) : (tensor<2x2xf64>, tensor<f64>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<6x5xf64>
return %3 : tensor<6x5xf64>
}
@ -156,7 +156,7 @@ func @non_const_inputs(%arg0: tensor<2x2xf64>, %arg1: tensor<f64>, %arg2: tensor
// CHECK-LABEL: dynamic_result_type
func @dynamic_result_type(%arg0: tensor<2xf32>) -> tensor<*xf32> {
// CHECK: %[[RESULT:.*]] = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: tensor_cast %0 : tensor<2xf32> to tensor<*xf32>
%0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<*xf32>
@ -166,7 +166,7 @@ func @dynamic_result_type(%arg0: tensor<2xf32>) -> tensor<*xf32> {
func @truncated_normal() -> tensor<2x2xf32> {
// CHECK-NOT: tf.TruncatedNormal
%0 = xla_hlo.constant dense<[2, 2]> : tensor<2xi32>
%0 = mhlo.constant dense<[2, 2]> : tensor<2xi32>
%1 = "tf.TruncatedNormal"(%0) {T = i32, device = "", dtype = f32, seed = 0 : i64, seed2 = 1950157571 : i64} : (tensor<2xi32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
}
@ -175,21 +175,21 @@ func @truncated_normal() -> tensor<2x2xf32> {
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xi32>, %[[ARG1:.*]]: tensor<2x2xi32>, %[[ARG2:.*]]: tensor<2xi32>
func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2: tensor<2xi32>) -> tensor<3x4xi32> {
// CHECK: %[[SLICE0:.*]] = "xla_hlo.slice"(%[[ARG2]])
// CHECK: %[[SLICE0:.*]] = "mhlo.slice"(%[[ARG2]])
// CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>
// CHECK-DAG-SAME: limit_indices = dense<1> : tensor<1xi64>
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>
// CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32>
// CHECK: %[[DIM0:.*]] = "xla_hlo.reshape"(%[[SLICE0]]) : (tensor<1xi32>) -> tensor<i32>
// CHECK: %[[DIM0:.*]] = "mhlo.reshape"(%[[SLICE0]]) : (tensor<1xi32>) -> tensor<i32>
// CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%[[ARG2]])
// CHECK: %[[SLICE1:.*]] = "mhlo.slice"(%[[ARG2]])
// CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>
// CHECK-DAG-SAME: limit_indices = dense<2> : tensor<1xi64>
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>
// CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32>
// CHECK: %[[DIM1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<1xi32>) -> tensor<i32>
// CHECK: %[[DIM1:.*]] = "mhlo.reshape"(%[[SLICE1]]) : (tensor<1xi32>) -> tensor<i32>
// CHECK: "xla_hlo.dynamic-update-slice"(%[[ARG0]], %[[ARG1]], %[[DIM0]], %[[DIM1]])
// CHECK: "mhlo.dynamic-update-slice"(%[[ARG0]], %[[ARG1]], %[[DIM0]], %[[DIM1]])
%0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<3x4xi32>, tensor<2x2xi32>, tensor<2xi32>) -> tensor<3x4xi32>
return %0: tensor<3x4xi32>
@ -199,12 +199,12 @@ func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xi32>, %[[ARG1:.*]]: tensor<3xf32>, %[[ARG2:.*]]: tensor<f32>)
func @sparse_to_dense(%arg0: tensor<3x2xi32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>) -> tensor<3x3xf32> {
// CHECK: %[[CST:.*]] = xla_hlo.constant dense<3> : tensor<2xi32>
// CHECK: %[[DEFAULT:.*]] = "xla_hlo.broadcast_in_dim"(%[[ARG2]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<3x3xf32>
// CHECK: %[[CST:.*]] = mhlo.constant dense<3> : tensor<2xi32>
// CHECK: %[[DEFAULT:.*]] = "mhlo.broadcast_in_dim"(%[[ARG2]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<3x3xf32>
// CHECK: %[[RESULT:.*]] = "xla_hlo.scatter"(%[[DEFAULT]], %[[ARG0]], %[[ARG1]]) ( {
// CHECK: %[[RESULT:.*]] = "mhlo.scatter"(%[[DEFAULT]], %[[ARG0]], %[[ARG1]]) ( {
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>): // no predecessors
// CHECK: "xla_hlo.return"(%[[ARG4]]) : (tensor<f32>) -> ()
// CHECK: "mhlo.return"(%[[ARG4]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: indices_are_sorted = false
// CHECK-SAME: scatter_dimension_numbers
@ -217,14 +217,14 @@ func @sparse_to_dense(%arg0: tensor<3x2xi32>, %arg1: tensor<3xf32>, %arg2: tenso
// return %[[RESULT]] : tensor<3x3xf32>
%cst = xla_hlo.constant dense<3> : tensor<2xi32>
%cst = mhlo.constant dense<3> : tensor<2xi32>
%0 = "tf.SparseToDense"(%arg0, %cst, %arg1, %arg2) {validate_indices = true}: (tensor<3x2xi32>, tensor<2xi32>, tensor<3xf32>, tensor<f32>) -> tensor<3x3xf32>
return %0 : tensor<3x3xf32>
}
// CHECK-LABEL: fft
func @fft(%arg0: tensor<3x5x8xcomplex<f32>>) -> tensor<3x5x8xcomplex<f32>> {
// CHECK: "xla_hlo.fft"(%arg0)
// CHECK: "mhlo.fft"(%arg0)
%0 = "tf.FFT"(%arg0) : (tensor<3x5x8xcomplex<f32>>) -> tensor<3x5x8xcomplex<f32>>
return %0 : tensor<3x5x8xcomplex<f32>>
}
@ -238,7 +238,7 @@ func @reverse_sequence(%arg0: tensor<4x2x3x1x1xi32>, %arg1: tensor<3xi32>) -> te
// CHECK-LABEL: mirror_pad
func @mirror_pad(%arg0: tensor<2x3xcomplex<f64>>) -> tensor<4x7xcomplex<f64>> {
%0 = xla_hlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32>
%0 = mhlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32>
// CHECK-NOT: tf.MirrorPad
%1 = "tf.MirrorPad"(%arg0, %0) {mode = "SYMMETRIC"} : (tensor<2x3xcomplex<f64>>, tensor<2x2xi32>) -> tensor<4x7xcomplex<f64>>
return %1 : tensor<4x7xcomplex<f64>>
@ -254,7 +254,7 @@ func @bucketize(%arg0: tensor<2x5xf32>) -> tensor<2x5xi32> {
// CHECK-LABEL: arg_min
func @arg_min(%arg0: tensor<6xf64>) -> tensor<i32> {
// CHECK-NOT: ArgMin
%0 = xla_hlo.constant dense<0> : tensor<i32>
%0 = mhlo.constant dense<0> : tensor<i32>
%1 = "tf.ArgMin"(%arg0, %0) : (tensor<6xf64>, tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}

File diff suppressed because it is too large Load Diff

View File

@ -48,7 +48,7 @@ class XlaBuilderTest : public ::testing::Test {
xla_builder_(name_, builder_, module_->getLoc()) {}
string SetupTest() {
mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
mlir::registerDialect<mlir::mhlo::XlaHloDialect>();
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
}
@ -75,7 +75,7 @@ TEST_F(XlaBuilderTest, CreateToken) {
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(GetMlirOpString(token),
R"("xla_hlo.create_token"() : () -> !xla_hlo.token)");
R"("mhlo.create_token"() : () -> !mhlo.token)");
}
TEST_F(XlaBuilderTest, Infeed) {
@ -85,7 +85,7 @@ TEST_F(XlaBuilderTest, Infeed) {
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(infeed),
R"("xla_hlo.infeed"(%0) {infeed_config = ""} : (!xla_hlo.token) -> tuple<tensor<4x8xf32>, !xla_hlo.token>)");
R"("mhlo.infeed"(%0) {infeed_config = ""} : (!mhlo.token) -> tuple<tensor<4x8xf32>, !mhlo.token>)");
}
TEST_F(XlaBuilderTest, Outfeed) {
@ -99,7 +99,7 @@ TEST_F(XlaBuilderTest, Outfeed) {
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(outfeed),
R"("xla_hlo.outfeed"(%0, %1) {outfeed_config = ""} : (tensor<4x8xf32>, !xla_hlo.token) -> !xla_hlo.token)");
R"("mhlo.outfeed"(%0, %1) {outfeed_config = ""} : (tensor<4x8xf32>, !mhlo.token) -> !mhlo.token)");
}
TEST_F(XlaBuilderTest, ConcatInDim) {
@ -112,7 +112,7 @@ TEST_F(XlaBuilderTest, ConcatInDim) {
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(concat),
R"("xla_hlo.concatenate"(%0, %1) {dimension = 1 : i64} : (tensor<2x4x5xf32>, tensor<2x6x5xf32>) -> tensor<2x10x5xf32>)");
R"("mhlo.concatenate"(%0, %1) {dimension = 1 : i64} : (tensor<2x4x5xf32>, tensor<2x6x5xf32>) -> tensor<2x10x5xf32>)");
}
TEST_F(XlaBuilderTest, Tuple) {
@ -125,7 +125,7 @@ TEST_F(XlaBuilderTest, Tuple) {
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(tuple),
R"("xla_hlo.tuple"(%0, %1) : (tensor<3x7xf32>, tensor<f32>) -> tuple<tensor<3x7xf32>, tensor<f32>>)");
R"("mhlo.tuple"(%0, %1) : (tensor<3x7xf32>, tensor<f32>) -> tuple<tensor<3x7xf32>, tensor<f32>>)");
}
TEST_F(XlaBuilderTest, GetTupleElement) {
@ -139,7 +139,7 @@ TEST_F(XlaBuilderTest, GetTupleElement) {
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(gte),
R"("xla_hlo.get_tuple_element"(%2) {index = 1 : i32} : (tuple<tensor<3x7xf32>, tensor<f32>>) -> tensor<f32>)");
R"("mhlo.get_tuple_element"(%2) {index = 1 : i32} : (tuple<tensor<3x7xf32>, tensor<f32>>) -> tensor<f32>)");
}
TEST_F(XlaBuilderTest, Slice) {
@ -150,7 +150,7 @@ TEST_F(XlaBuilderTest, Slice) {
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(slice),
R"("xla_hlo.slice"(%0) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x7xf32>) -> tensor<2x4xf32>)");
R"("mhlo.slice"(%0) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x7xf32>) -> tensor<2x4xf32>)");
}
TEST_F(XlaBuilderTest, Pad) {
@ -172,7 +172,7 @@ TEST_F(XlaBuilderTest, Pad) {
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(pad),
R"("xla_hlo.pad"(%0, %1) {edge_padding_high = dense<[2, 0]> : tensor<2xi64>, edge_padding_low = dense<[1, 3]> : tensor<2xi64>, interior_padding = dense<[0, 1]> : tensor<2xi64>} : (tensor<3x7xf32>, tensor<f32>) -> tensor<6x16xf32>)");
R"("mhlo.pad"(%0, %1) {edge_padding_high = dense<[2, 0]> : tensor<2xi64>, edge_padding_low = dense<[1, 3]> : tensor<2xi64>, interior_padding = dense<[0, 1]> : tensor<2xi64>} : (tensor<3x7xf32>, tensor<f32>) -> tensor<6x16xf32>)");
}
} // namespace

View File

@ -12,9 +12,9 @@ func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: ROOT %add.4 = f32[4] add(f32[4] %add.3, f32[4] %Arg_1.2)
%1 = "xla_hlo.add"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%1 = "mhlo.add"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %1 : tensor<4xf32>
}

View File

@ -5,18 +5,18 @@ func @main() -> tensor<f32> {
%cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor<f32>
%cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor<f32>
%cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor<f32>
%0 = "xla_hlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( {
%0 = "mhlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( {
^bb0(%arg0: tensor<f32>):
%1 = "xla_hlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%1) : (tensor<f32>) -> ()
%1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}, {
^bb0(%arg0: tensor<f32>):
%1 = "xla_hlo.copy"(%arg0) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%1) : (tensor<f32>) -> ()
%1 = "mhlo.copy"(%arg0) : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}, {
^bb0(%arg0: tensor<f32>):
%1 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%1) : (tensor<f32>) -> ()
%1 = "mhlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {name = "conditional"} : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
@ -52,18 +52,18 @@ func @main() -> (tensor<f32>, tensor<f32>) {
%cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor<f32>
%cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor<f32>
%cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor<f32>
%0:2 = "xla_hlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( {
%0:2 = "mhlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( {
^bb0(%arg0: tensor<f32>):
%1 = "xla_hlo.negate"(%arg0) {name = "negate"} : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
%1 = "mhlo.negate"(%arg0) {name = "negate"} : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
}, {
^bb0(%arg0: tensor<f32>):
%1 = "xla_hlo.copy"(%arg0) {name = "copy"} : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
%1 = "mhlo.copy"(%arg0) {name = "copy"} : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
}, {
^bb0(%arg0: tensor<f32>):
%1 = "xla_hlo.floor"(%arg0) {name = "floor"} : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
%1 = "mhlo.floor"(%arg0) {name = "floor"} : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
}) {name = "conditional"} : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
return %0#0, %0#1 : tensor<f32>, tensor<f32>
}

View File

@ -30,17 +30,17 @@ ENTRY %indexed_conditional () -> f32[] {
// CHECK: %[[OPERAND_1:.*]] = constant {name = "{{.*}}"} dense<5.600000e+01> : tensor<f32>
// CHECK: %[[OPERAND_2:.*]] = constant {name = "{{.*}}"} dense<1.200000e+01> : tensor<f32>
// CHECK: %[[OPERAND_3:.*]] = constant {name = "{{.*}}"} dense<1.300000e+01> : tensor<f32>
// CHECK: %[[RESULT:.*]] = "xla_hlo.case"(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]) ( {
// CHECK: %[[RESULT:.*]] = "mhlo.case"(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]) ( {
// CHECK: ^bb0(%[[ARG_1:.*]]: tensor<f32>):
// CHECK: %[[RES_1:.*]] = "xla_hlo.negate"(%[[ARG_1]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
// CHECK: "xla_hlo.return"(%[[RES_1]]) : (tensor<f32>) -> ()
// CHECK: %[[RES_1:.*]] = "mhlo.negate"(%[[ARG_1]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
// CHECK: "mhlo.return"(%[[RES_1]]) : (tensor<f32>) -> ()
// CHECK: }, {
// CHECK: ^bb0(%[[ARG_2:.*]]: tensor<f32>):
// CHECK: %[[RES_2:.*]] = "xla_hlo.copy"(%[[ARG_2]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
// CHECK: "xla_hlo.return"(%[[RES_2]]) : (tensor<f32>) -> ()
// CHECK: %[[RES_2:.*]] = "mhlo.copy"(%[[ARG_2]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
// CHECK: "mhlo.return"(%[[RES_2]]) : (tensor<f32>) -> ()
// CHECK: }, {
// CHECK: ^bb0(%[[ARG_3:.*]]: tensor<f32>):
// CHECK: %[[RES_3:.*]] = "xla_hlo.floor"(%[[ARG_3]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
// CHECK: "xla_hlo.return"(%[[RES_3]]) : (tensor<f32>) -> ()
// CHECK: %[[RES_3:.*]] = "mhlo.floor"(%[[ARG_3]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
// CHECK: "mhlo.return"(%[[RES_3]]) : (tensor<f32>) -> ()
// CHECK: }) {name = "{{.*}}"} : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: return %[[RESULT]] : tensor<f32>

View File

@ -19,7 +19,7 @@ func @main(%arg0: tensor<10xf32>, %arg1: tensor<i32>) {
// Test entry function with single dynamic parameter binding on an argument.
func @main(%arg0: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [0 : i32], padding_arg_indices = [1 : i32]}}, %arg1: tensor<i32>) {
func @main(%arg0: tensor<10xf32> {mhlo.padding_map = {shape_indices = [0 : i32], padding_arg_indices = [1 : i32]}}, %arg1: tensor<i32>) {
return
}
@ -42,7 +42,7 @@ func @main(%arg0: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [0 : i3
// Test entry function with multiple dynamic parameter bindings on an argument.
func @main(%arg0: tensor<8x10xf32> {xla_hlo.padding_map = {shape_indices = [0 : i32, 1 : i32], padding_arg_indices = [1 : i32, 2 : i32]}}, %arg1: tensor<i32>, %arg2: tensor<i32>) {
func @main(%arg0: tensor<8x10xf32> {mhlo.padding_map = {shape_indices = [0 : i32, 1 : i32], padding_arg_indices = [1 : i32, 2 : i32]}}, %arg1: tensor<i32>, %arg2: tensor<i32>) {
return
}
@ -75,7 +75,7 @@ func @main(%arg0: tensor<8x10xf32> {xla_hlo.padding_map = {shape_indices = [0 :
// Test entry function with multiple dynamic parameter bindings on multiple
// arguments.
func @main(%arg0: tensor<8x10xf32> {xla_hlo.padding_map = {shape_indices = [0 : i32, 1 : i32], padding_arg_indices = [1 : i32, 2 : i32]}}, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10x8x6xi32> {xla_hlo.padding_map = {shape_indices = [2 : i32], padding_arg_indices = [4 : i32]}}, %arg4: tensor<i32>) {
func @main(%arg0: tensor<8x10xf32> {mhlo.padding_map = {shape_indices = [0 : i32, 1 : i32], padding_arg_indices = [1 : i32, 2 : i32]}}, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10x8x6xi32> {mhlo.padding_map = {shape_indices = [2 : i32], padding_arg_indices = [4 : i32]}}, %arg4: tensor<i32>) {
return
}

View File

@ -1,148 +1,148 @@
// RUN: not tf-mlir-translate -split-input-file -mlir-hlo-to-hlo %s -o - 2>&1 | FileCheck %s
// Test bad `xla_hlo.padding_map` attribute type.
// Test bad `mhlo.padding_map` attribute type.
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = ""}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = ""}) {
return
}
// CHECK: requires 'xla_hlo.padding_map' dict attribute at arg 1
// CHECK: requires 'mhlo.padding_map' dict attribute at arg 1
// -----
// Test missing `shape_indices` attribute in `xla_hlo.padding_map`.
// Test missing `shape_indices` attribute in `mhlo.padding_map`.
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {}}) {
return
}
// CHECK: requires 'shape_indices' array attribute in 'xla_hlo.padding_map' dict at arg 1
// CHECK: requires 'shape_indices' array attribute in 'mhlo.padding_map' dict at arg 1
// -----
// Test bad `shape_indices` attribute type in `xla_hlo.padding_map`.
// Test bad `shape_indices` attribute type in `mhlo.padding_map`.
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = ""}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = ""}}) {
return
}
// CHECK: requires 'shape_indices' array attribute in 'xla_hlo.padding_map' dict at arg 1
// CHECK: requires 'shape_indices' array attribute in 'mhlo.padding_map' dict at arg 1
// -----
// Test missing `padding_arg_indices` attribute in `xla_hlo.padding_map`.
// Test missing `padding_arg_indices` attribute in `mhlo.padding_map`.
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = []}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = []}}) {
return
}
// CHECK: requires 'padding_arg_indices' array attribute in 'xla_hlo.padding_map' dict at arg 1
// CHECK: requires 'padding_arg_indices' array attribute in 'mhlo.padding_map' dict at arg 1
// -----
// Test bad `padding_arg_indices` attribute type in `xla_hlo.padding_map`.
// Test bad `padding_arg_indices` attribute type in `mhlo.padding_map`.
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [], padding_arg_indices = ""}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [], padding_arg_indices = ""}}) {
return
}
// CHECK: requires 'padding_arg_indices' array attribute in 'xla_hlo.padding_map' dict at arg 1
// CHECK: requires 'padding_arg_indices' array attribute in 'mhlo.padding_map' dict at arg 1
// -----
// Test mismatched `shape_indices` and `padding_arg_indices` lengths.
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 0: i32, 0 : i32 ]}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 0: i32, 0 : i32 ]}}) {
return
}
// CHECK: requires 'shape_indices' and 'padding_arg_indices' array attributes in 'xla_hlo.padding_map' dic at arg 1 to be of the same size, got sizes 1 and 2
// CHECK: requires 'shape_indices' and 'padding_arg_indices' array attributes in 'mhlo.padding_map' dic at arg 1 to be of the same size, got sizes 1 and 2
// -----
// Test non integer attribute in `shape_indices`.
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32, 0.0: f32 ], padding_arg_indices = [ 0: i32, 0: i32 ]}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32, 0.0: f32 ], padding_arg_indices = [ 0: i32, 0: i32 ]}}) {
return
}
// CHECK: requires element 1 in 'shape_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be an int attribute
// CHECK: requires element 1 in 'shape_indices' array of 'mhlo.padding_map' dict at arg 1 to be an int attribute
// -----
// Test non integer attribute in `padding_arg_indices`.
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32, 0: i32 ], padding_arg_indices = [ 0: i32, 0.0: f32 ]}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32, 0: i32 ], padding_arg_indices = [ 0: i32, 0.0: f32 ]}}) {
return
}
// CHECK: requires element 1 in 'padding_arg_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be an int attribute
// CHECK: requires element 1 in 'padding_arg_indices' array of 'mhlo.padding_map' dict at arg 1 to be an int attribute
// -----
// Test negative out of range shape index in `shape_indices`.
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ -1: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ -1: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
return
}
// CHECK: requires element 0 in 'shape_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be in range [0, 1), got -1
// CHECK: requires element 0 in 'shape_indices' array of 'mhlo.padding_map' dict at arg 1 to be in range [0, 1), got -1
// -----
// Test positive out of range shape index in `shape_indices`.
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 1: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 1: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
return
}
// CHECK: requires element 0 in 'shape_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be in range [0, 1), got 1
// CHECK: requires element 0 in 'shape_indices' array of 'mhlo.padding_map' dict at arg 1 to be in range [0, 1), got 1
// -----
// Test negative shape index in `shape_indices` for unranked argument.
func @main(%arg0: tensor<i32>, %arg1: tensor<*xf32> {xla_hlo.padding_map = {shape_indices = [ -1: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<*xf32> {mhlo.padding_map = {shape_indices = [ -1: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
return
}
// CHECK: requires element 0 in 'shape_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be non-negative, got -1
// CHECK: requires element 0 in 'shape_indices' array of 'mhlo.padding_map' dict at arg 1 to be non-negative, got -1
// -----
// Test duplicate shape indices in `shape_indices`.
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32, 0: i32 ], padding_arg_indices = [ 0: i32, 0: i32 ]}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32, 0: i32 ], padding_arg_indices = [ 0: i32, 0: i32 ]}}) {
return
}
// CHECK: requires elements in 'shape_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be unique, got duplicate element 0 at index 1
// CHECK: requires elements in 'shape_indices' array of 'mhlo.padding_map' dict at arg 1 to be unique, got duplicate element 0 at index 1
// -----
// Test negative out of range shape index in `padding_arg_indices`.
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ -1: i32 ]}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ -1: i32 ]}}) {
return
}
// CHECK: requires element 0 in 'padding_arg_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be in range [0, 2), got -1
// CHECK: requires element 0 in 'padding_arg_indices' array of 'mhlo.padding_map' dict at arg 1 to be in range [0, 2), got -1
// -----
// Test positive out of range shape index in `padding_arg_indices`.
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 2: i32 ]}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 2: i32 ]}}) {
return
}
// CHECK: requires element 0 in 'padding_arg_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be in range [0, 2), got 2
// CHECK: requires element 0 in 'padding_arg_indices' array of 'mhlo.padding_map' dict at arg 1 to be in range [0, 2), got 2
// -----
// Test non scalar padding argument.
func @main(%arg0: tensor<8xi32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
func @main(%arg0: tensor<8xi32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
return
}
@ -152,7 +152,7 @@ func @main(%arg0: tensor<8xi32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {s
// Test non integer type padding argument.
func @main(%arg0: tensor<f32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
func @main(%arg0: tensor<f32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
return
}

View File

@ -1,9 +1,9 @@
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s
// CHECK: HloModule
func @main(%arg0: !xla_hlo.token, %arg1: !xla_hlo.token) -> !xla_hlo.token {
%0 = "xla_hlo.after_all"(%arg0, %arg1) : (!xla_hlo.token, !xla_hlo.token) -> !xla_hlo.token
return %0 : !xla_hlo.token
func @main(%arg0: !mhlo.token, %arg1: !mhlo.token) -> !mhlo.token {
%0 = "mhlo.after_all"(%arg0, %arg1) : (!mhlo.token, !mhlo.token) -> !mhlo.token
return %0 : !mhlo.token
}
// CHECK: ENTRY
@ -15,11 +15,11 @@ func @main(%arg0: !xla_hlo.token, %arg1: !xla_hlo.token) -> !xla_hlo.token {
// CHECK: HloModule
func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
%0 = "xla_hlo.all_reduce"(%arg0) ({
%0 = "mhlo.all_reduce"(%arg0) ({
// Perform max reduction inside the region
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
"mhlo.return"(%max) : (tensor<f32>) -> ()
})
{
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
@ -43,7 +43,7 @@ func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
// CHECK: HloModule
func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>> {
%0 = "xla_hlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
%0 = "mhlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
return %0 : tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
}
@ -60,7 +60,7 @@ func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf
// CHECK: HloModule
func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>> {
%0 = "xla_hlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 3 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
%0 = "mhlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 3 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
return %0 : tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
}
@ -78,16 +78,16 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor
// CHECK: [[VAL_1:%.*]] = s32[4] parameter(0)
// CHECK: [[VAL_2:%.*]] = s32[4] parameter(1)
// CHECK: [[ATAN2:%.*]] = s32[4] atan2(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%0 = xla_hlo.atan2 %arg0, %arg1 : tensor<4xi32>
%0 = mhlo.atan2 %arg0, %arg1 : tensor<4xi32>
// CHECK: [[SHL:%.*]] = s32[4] shift-left(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%1 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
%1 = mhlo.shift_left %arg0, %arg1 : tensor<4xi32>
// CHECK: [[SHRA:%.*]] = s32[4] shift-right-arithmetic(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%2 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
%2 = mhlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
// CHECK: [[SHRL:%.*]] = s32[4] shift-right-logical(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%3 = xla_hlo.shift_right_logical %arg0, %arg1 : tensor<4xi32>
%3 = mhlo.shift_right_logical %arg0, %arg1 : tensor<4xi32>
// CHECK: ROOT
// CHECK-SAME: [[VAL_7:%.*]] = (s32[4], s32[4], s32[4], s32[4]) tuple(s32[4] [[ATAN2]], s32[4] [[SHL]], s32[4] [[SHRA]], s32[4] [[SHRL]])
@ -98,7 +98,7 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor
// CHECK: HloModule
func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
%0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@ -112,7 +112,7 @@ func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {
func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> {
// CHECK: [[ARG:%.*]] = s32[4] parameter(0)
// CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] [[ARG]]), dimensions={3}
%0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
%0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
return %0 : tensor<1x2x3x4xi32>
}
@ -120,7 +120,7 @@ func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> {
// CHECK: HloModule
func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> {
%result = "xla_hlo.broadcast_in_dim"(%arg0) {
%result = "mhlo.broadcast_in_dim"(%arg0) {
broadcast_dimensions = dense<0> : tensor<1xi64>
} : (tensor<1xf32>) -> tensor<1x10xf32>
return %result : tensor<1x10xf32>
@ -133,9 +133,9 @@ func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> {
// -----
// CHECK: HloModule
func @main() -> !xla_hlo.token {
%0 = "xla_hlo.create_token"() : () -> !xla_hlo.token
return %0 : !xla_hlo.token
func @main() -> !mhlo.token {
%0 = "mhlo.create_token"() : () -> !mhlo.token
return %0 : !mhlo.token
}
// CHECK: ROOT [[TOKEN:%.*]] = token[] after-all()
@ -150,7 +150,7 @@ func @main(%arg0: tensor<4xi32>) -> tensor<4xi32> {
}
func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
@ -181,8 +181,8 @@ func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
}
func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%1 = "xla_hlo.multiply"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%1 = "mhlo.multiply"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0, %1 : tensor<4xi32>, tensor<4xi32>
}
@ -202,7 +202,7 @@ func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tens
// CHECK: HloModule
func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> {
%0 = "xla_hlo.collective_permute"(%arg0) {
%0 = "mhlo.collective_permute"(%arg0) {
source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
} : (tensor<128x32xf32>) -> tensor<128x32xf32>
return %0 : tensor<128x32xf32>
@ -217,7 +217,7 @@ func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> {
func @main(%arg0 : tensor<5x2xf32>,
%arg1 : tensor<5x5xf32>,
%arg2 : tensor<5x7xf32>) -> tensor<5x14xf32> {
%result = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) {
%result = "mhlo.concatenate"(%arg0, %arg1, %arg2) {
dimension = 1 : i64
} : (tensor<5x2xf32>, tensor<5x5xf32>, tensor<5x7xf32>) -> tensor<5x14xf32>
return %result : tensor<5x14xf32>
@ -279,7 +279,7 @@ func @main() {
// CHECK: HloModule
func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> {
%result = "xla_hlo.convolution"(%arg0, %arg1) {
%result = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
@ -312,7 +312,7 @@ func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> te
// CHECK: HloModule
func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {
%0 = "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
%0 = "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
@ -324,7 +324,7 @@ func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {
// CHECK: HloModule
func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "xla_hlo.copy"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
%0 = "mhlo.copy"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}
@ -336,8 +336,8 @@ func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK: HloModule
func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
%0 = xla_hlo.constant dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
%1 = "xla_hlo.cross-replica-sum"(%arg0) {replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>} : (tensor<10xf32>) -> tensor<10xf32>
%0 = mhlo.constant dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
%1 = "mhlo.cross-replica-sum"(%arg0) {replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>} : (tensor<10xf32>) -> tensor<10xf32>
return %1 : tensor<10xf32>
}
@ -354,7 +354,7 @@ func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
// CHECK: HloModule
func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<5x5xf32>) -> tensor<1x2x3xf32> {
%0 = "xla_hlo.custom_call"(%arg0, %arg1) {backend_config = "bar", call_target_name = "foo"} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32>
%0 = "mhlo.custom_call"(%arg0, %arg1) {backend_config = "bar", call_target_name = "foo"} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32>
return %0 : tensor<1x2x3xf32>
}
@ -369,7 +369,7 @@ func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<5x5xf32>) -> tensor<1x2x3xf32>
// CHECK: HloModule
func @main(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> {
%0 = "xla_hlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x64xbf16>
%0 = "mhlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x64xbf16>
return %0 : tensor<16x64xbf16>
}
@ -388,7 +388,7 @@ func @main(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> {
// CHECK: HloModule
func @main(%arg: tensor<16x16xi32>) -> tensor<16x32xbf16> {
%0 = "xla_hlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false, is_16bits = true} : (tensor<16x16xi32>) -> tensor<16x32xbf16>
%0 = "mhlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false, is_16bits = true} : (tensor<16x16xi32>) -> tensor<16x32xbf16>
return %0 : tensor<16x32xbf16>
}
@ -408,7 +408,7 @@ func @main(%arg: tensor<16x16xi32>) -> tensor<16x32xbf16> {
func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> {
// Simple einsum is lowered to HLO dot op.
// CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0}
%0 = "xla_hlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32>
%0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32>
return %0 : tensor<3x5xi32>
}
@ -416,7 +416,7 @@ func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> {
// CHECK: HloModule
func @main(%arg0: tensor<3x9xf32>) -> tensor<3x5xcomplex<f32>> {
%0 = "xla_hlo.fft"(%arg0) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<3x9xf32>) -> tensor<3x5xcomplex<f32>>
%0 = "mhlo.fft"(%arg0) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<3x9xf32>) -> tensor<3x5xcomplex<f32>>
return %0 : tensor<3x5xcomplex<f32>>
}
@ -437,7 +437,7 @@ func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10
// CHECK-SAME: index_vector_dim=1
// CHECK-SAME: slice_sizes={1,1,300}
// CHECK-SAME: indices_are_sorted=true
%0 = "xla_hlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<[0, 1]> : tensor<2xi64>}, indices_are_sorted = true, name = "gather", slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>} : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32>
%0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<[0, 1]> : tensor<2xi64>}, indices_are_sorted = true, name = "gather", slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>} : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32>
return %0 : tensor<10x300xf32>
}
@ -445,8 +445,8 @@ func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10
// CHECK: HloModule
func @main(%arg: tensor<4x2xf32>, %size: tensor<i32>) -> tensor<i32> {
%0 = "xla_hlo.set_dimension_size"(%arg, %size) {dimension = 1 : i32} : (tensor<4x2xf32>, tensor<i32>) -> tensor<4x2xf32>
%1 = "xla_hlo.get_dimension_size"(%0) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor<i32>
%0 = "mhlo.set_dimension_size"(%arg, %size) {dimension = 1 : i32} : (tensor<4x2xf32>, tensor<i32>) -> tensor<4x2xf32>
%1 = "mhlo.get_dimension_size"(%0) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor<i32>
return %1 : tensor<i32>
}
@ -461,7 +461,7 @@ func @main(%arg: tensor<4x2xf32>, %size: tensor<i32>) -> tensor<i32> {
// CHECK: HloModule
func @main(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
%0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
return %0 : tensor<f32>
}
@ -472,9 +472,9 @@ func @main(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
// -----
// CHECK: HloModule
func @main(%arg0: !xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token> {
%0 = "xla_hlo.infeed"(%arg0) {infeed_config = "foobar"} : (!xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token>
return %0 : tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token>
func @main(%arg0: !mhlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !mhlo.token> {
%0 = "mhlo.infeed"(%arg0) {infeed_config = "foobar"} : (!mhlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !mhlo.token>
return %0 : tuple<tuple<tensor<3xi32>, tensor<i1>>, !mhlo.token>
}
// CHECK: ENTRY
@ -485,7 +485,7 @@ func @main(%arg0: !xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !xl
// CHECK: HloModule
func @main() -> tensor<1x10xf32> {
%result = "xla_hlo.iota"() {
%result = "mhlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<1x10xf32>
return %result : tensor<1x10xf32>
@ -498,10 +498,10 @@ func @main() -> tensor<1x10xf32> {
// CHECK: HloModule
func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%0 = "xla_hlo.map"(%arg0, %arg1) ( {
%0 = "mhlo.map"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): // no predecessors
%1 = xla_hlo.add %arg2, %arg3 {name = "add"} : tensor<f32>
"xla_hlo.return"(%1) : (tensor<f32>) -> ()
%1 = mhlo.add %arg2, %arg3 {name = "add"} : tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -522,9 +522,9 @@ func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// -----
// CHECK: HloModule
func @main(%data: tensor<3xi32>, %token: !xla_hlo.token) -> !xla_hlo.token {
%0 = "xla_hlo.outfeed"(%data, %token) {outfeed_config = "foobar"} : (tensor<3xi32>, !xla_hlo.token) -> !xla_hlo.token
return %0 : !xla_hlo.token
func @main(%data: tensor<3xi32>, %token: !mhlo.token) -> !mhlo.token {
%0 = "mhlo.outfeed"(%data, %token) {outfeed_config = "foobar"} : (tensor<3xi32>, !mhlo.token) -> !mhlo.token
return %0 : !mhlo.token
}
// CHECK: ENTRY
@ -536,7 +536,7 @@ func @main(%data: tensor<3xi32>, %token: !xla_hlo.token) -> !xla_hlo.token {
// CHECK: HloModule
func @main(%arg: tensor<4x6xf32>, %pad: tensor<f32>) -> tensor<13x19xf32> {
%0 = "xla_hlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<13x19xf32>
%0 = "mhlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<13x19xf32>
return %0 : tensor<13x19xf32>
}
@ -549,15 +549,15 @@ func @main(%arg: tensor<4x6xf32>, %pad: tensor<f32>) -> tensor<13x19xf32> {
// -----
// CHECK: HloModule
func @main(%token: !xla_hlo.token) -> tuple<tensor<3x4xi32>, !xla_hlo.token> {
%0 = "xla_hlo.recv"(%token) {
func @main(%token: !mhlo.token) -> tuple<tensor<3x4xi32>, !mhlo.token> {
%0 = "mhlo.recv"(%token) {
channel_id = {
handle = 5 : i64,
type = 3 : i64 // Host to device channel
},
is_host_transfer = true
} : (!xla_hlo.token) -> tuple<tensor<3x4xi32>, !xla_hlo.token>
return %0 : tuple<tensor<3x4xi32>, !xla_hlo.token>
} : (!mhlo.token) -> tuple<tensor<3x4xi32>, !mhlo.token>
return %0 : tuple<tensor<3x4xi32>, !mhlo.token>
}
// CHECK: ENTRY
@ -569,15 +569,15 @@ func @main(%token: !xla_hlo.token) -> tuple<tensor<3x4xi32>, !xla_hlo.token> {
// -----
// CHECK: HloModule
func @main(%token: !xla_hlo.token) -> tuple<tensor<3x4xi32>, !xla_hlo.token> {
%0 = "xla_hlo.recv"(%token) {
func @main(%token: !mhlo.token) -> tuple<tensor<3x4xi32>, !mhlo.token> {
%0 = "mhlo.recv"(%token) {
channel_id = {
handle = 5 : i64,
type = 1 : i64 // Device to device channel
},
is_host_transfer = false
} : (!xla_hlo.token) -> tuple<tensor<3x4xi32>, !xla_hlo.token>
return %0 : tuple<tensor<3x4xi32>, !xla_hlo.token>
} : (!mhlo.token) -> tuple<tensor<3x4xi32>, !mhlo.token>
return %0 : tuple<tensor<3x4xi32>, !mhlo.token>
}
// CHECK: ENTRY
@ -591,11 +591,11 @@ func @main(%token: !xla_hlo.token) -> tuple<tensor<3x4xi32>, !xla_hlo.token> {
// CHECK: HloModule
func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor<f32>, %arg3 : tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>) {
%result0, %result1 = "xla_hlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( {
%result0, %result1 = "mhlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( {
^bb0(%fa: tensor<f32>, %ia : tensor<i32>, %fb: tensor<f32>, %ib: tensor<i32>): // no predecessors
%fmax = "xla_hlo.maximum"(%fa, %fb) {} : (tensor<f32>, tensor<f32>) -> tensor<f32>
%imax = "xla_hlo.maximum"(%ia, %ib) {} : (tensor<i32>, tensor<i32>) -> tensor<i32>
"xla_hlo.return"(%fmax, %imax) : (tensor<f32>, tensor<i32>) -> ()
%fmax = "mhlo.maximum"(%fa, %fb) {} : (tensor<f32>, tensor<f32>) -> tensor<f32>
%imax = "mhlo.maximum"(%ia, %ib) {} : (tensor<i32>, tensor<i32>) -> tensor<i32>
"mhlo.return"(%fmax, %imax) : (tensor<f32>, tensor<i32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor<f32>, tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>)
return %result0, %result1 : tensor<1xf32>, tensor<1xi32>
}
@ -617,11 +617,11 @@ func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor<f3
// CHECK: HloModule
func @main(%arg0: tensor<2x17x31x7xi32>) -> tensor<2x3x5x7xi32> {
%0 = xla_hlo.constant dense<-2147483648> : tensor<i32>
%1 = "xla_hlo.reduce_window"(%arg0, %0) ( {
%0 = mhlo.constant dense<-2147483648> : tensor<i32>
%1 = "mhlo.reduce_window"(%arg0, %0) ( {
^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>): // no predecessors
%2 = xla_hlo.maximum %arg1, %arg2 : tensor<i32>
"xla_hlo.return"(%2) : (tensor<i32>) -> ()
%2 = mhlo.maximum %arg1, %arg2 : tensor<i32>
"mhlo.return"(%2) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>,
@ -646,7 +646,7 @@ func @main(%arg0: tensor<2x17x31x7xi32>) -> tensor<2x3x5x7xi32> {
// CHECK: HloModule
func @main(%arg0: tensor<2xf32>) -> tensor<1x2xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x2xf32>
%0 = "mhlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x2xf32>
return %0 : tensor<1x2xf32>
}
@ -658,7 +658,7 @@ func @main(%arg0: tensor<2xf32>) -> tensor<1x2xf32> {
// CHECK: HloModule
func @main(%arg0 : tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {
%result = "xla_hlo.reverse"(%arg0) {
%result = "mhlo.reverse"(%arg0) {
dimensions = dense<[1,2]> : tensor<2xi64>
} : (tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32>
return %result : tensor<10x11x12x13xf32>
@ -672,8 +672,8 @@ func @main(%arg0 : tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {
// CHECK: HloModule
func @main(%mu: tensor<f32>, %sigma: tensor<f32>) -> tensor<2x3x5xf32> {
%shape = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64>
%0 = "xla_hlo.rng_normal"(%mu, %sigma, %shape) : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
%shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64>
%0 = "mhlo.rng_normal"(%mu, %sigma, %shape) : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
return %0 : tensor<2x3x5xf32>
}
@ -686,10 +686,10 @@ func @main(%mu: tensor<f32>, %sigma: tensor<f32>) -> tensor<2x3x5xf32> {
// CHECK: HloModule
func @main() -> tensor<2x3x5xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%1 = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
%2 = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64>
%3 = "xla_hlo.rng_uniform"(%0, %1, %2) : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%1 = mhlo.constant dense<1.000000e+00> : tensor<f32>
%2 = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64>
%3 = "mhlo.rng_uniform"(%0, %1, %2) : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
return %3 : tensor<2x3x5xf32>
}
@ -702,10 +702,10 @@ func @main() -> tensor<2x3x5xf32> {
// CHECK: HloModule
func @main(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> tensor<200x100x300xf32> {
%0 = "xla_hlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
%0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
%add = xla_hlo.add %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%add) : (tensor<f32>) -> ()
%add = mhlo.add %lhs, %rhs : tensor<f32>
"mhlo.return"(%add) : (tensor<f32>) -> ()
}) {
scatter_dimension_numbers = {
update_window_dims = dense<[1]> : tensor<1xi64>,
@ -737,7 +737,7 @@ func @main(%arg0: tensor<i1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) ->
// CHECK: %[[ARG2:.*]] = s32[2,3] parameter(2)
// CHECK: ROOT %[[RES:.*]] = s32[2,3] select(pred[2,3] %[[COND]], s32[2,3] %[[ARG1]], s32[2,3] %[[ARG2]])
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<i1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%0 = "mhlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<i1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
return %0 : tensor<2x3xi32>
}
@ -745,15 +745,15 @@ func @main(%arg0: tensor<i1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) ->
// CHECK: HloModule
func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "xla_hlo.select_and_scatter"(%arg0, %arg1, %0) ( {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ( {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
%2 = "xla_hlo.compare"(%arg3, %arg4) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%2) : (tensor<i1>) -> ()
%2 = "mhlo.compare"(%arg3, %arg4) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
%2 = xla_hlo.add %arg3, %arg4 : tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
@ -780,15 +780,15 @@ func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> te
// -----
// CHECK: HloModule
func @main(%arg: tensor<3x4xi32>, %token: !xla_hlo.token) -> !xla_hlo.token {
%0 = "xla_hlo.send"(%arg, %token) {
func @main(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token {
%0 = "mhlo.send"(%arg, %token) {
channel_id = {
handle = 5 : i64,
type = 2 : i64 // Device to host channel
},
is_host_transfer = true
} : (tensor<3x4xi32>, !xla_hlo.token) -> !xla_hlo.token
return %0 : !xla_hlo.token
} : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.token
return %0 : !mhlo.token
}
// CHECK: ENTRY
@ -801,15 +801,15 @@ func @main(%arg: tensor<3x4xi32>, %token: !xla_hlo.token) -> !xla_hlo.token {
// -----
// CHECK: HloModule
func @main(%arg: tensor<3x4xi32>, %token: !xla_hlo.token) -> !xla_hlo.token {
%0 = "xla_hlo.send"(%arg, %token) {
func @main(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token {
%0 = "mhlo.send"(%arg, %token) {
channel_id = {
handle = 5 : i64,
type = 1 : i64 // Device to device channel
},
is_host_transfer = false
} : (tensor<3x4xi32>, !xla_hlo.token) -> !xla_hlo.token
return %0 : !xla_hlo.token
} : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.token
return %0 : !mhlo.token
}
// CHECK: ENTRY
@ -823,7 +823,7 @@ func @main(%arg: tensor<3x4xi32>, %token: !xla_hlo.token) -> !xla_hlo.token {
// CHECK: HloModule
func @main(%arg: tensor<4x4xf32>, %size: tensor<i32>) -> tensor<4x4xf32> {
%0 = "xla_hlo.set_dimension_size"(%arg, %size) {dimension = 1 : i32} : (tensor<4x4xf32>, tensor<i32>) -> tensor<4x4xf32>
%0 = "mhlo.set_dimension_size"(%arg, %size) {dimension = 1 : i32} : (tensor<4x4xf32>, tensor<i32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}
@ -837,7 +837,7 @@ func @main(%arg: tensor<4x4xf32>, %size: tensor<i32>) -> tensor<4x4xf32> {
// CHECK: HloModule
func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> {
%0 = "xla_hlo.slice"(%arg) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
%0 = "mhlo.slice"(%arg) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
@ -850,7 +850,7 @@ func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> {
// CHECK: HloModule
func @main(%arg: tensor<3x4xi32>, %start1: tensor<i64>, %start2: tensor<i64>) -> tensor<1x4xi32> {
%0 = "xla_hlo.dynamic-slice"(%arg, %start1, %start2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
%0 = "mhlo.dynamic-slice"(%arg, %start1, %start2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}
@ -865,7 +865,7 @@ func @main(%arg: tensor<3x4xi32>, %start1: tensor<i64>, %start2: tensor<i64>) ->
// CHECK: HloModule
func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> {
"xla_hlo.trace"(%arg0) {tag = "This is a random test"} : (tensor<2xi32>) -> ()
"mhlo.trace"(%arg0) {tag = "This is a random test"} : (tensor<2xi32>) -> ()
return %arg0: tensor<2xi32>
}
@ -880,7 +880,7 @@ func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
// CHECK: [[ARG:%.*]] = s32[1,2,3,4] parameter(0)
// CHECK-NEXT: ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] [[ARG]]), dimensions={1,0,3,2}
%0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
%0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
return %0 : tensor<2x1x4x3xi32>
}
@ -888,7 +888,7 @@ func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
// CHECK: HloModule
func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> {
%0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32>
%0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32>
return %0 : tensor<4x3xf32>
}
@ -901,7 +901,7 @@ func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> {
// CHECK: HloModule
func @main(%arg0: tensor<f32>, %arg1 : tensor<i32>) -> tuple<tensor<f32>, tensor<i32>> {
%result = "xla_hlo.tuple"(%arg0, %arg1) {} : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>>
%result = "mhlo.tuple"(%arg0, %arg1) {} : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>>
return %result : tuple<tensor<f32>, tensor<i32>>
}
@ -916,17 +916,17 @@ func @main(%arg0: tensor<f32>, %arg1 : tensor<i32>) -> tuple<tensor<f32>, tensor
func @main(%arg_f32: tensor<4xf32>, %arg_i32: tensor<4xi32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>) {
// CHECK: [[ARG_F32:%.*]] = f32[4] parameter(0)
// CHECK: [[EXPM1:%.*]] = f32[4] exponential-minus-one(f32[4] [[ARG_F32]])
%expm1 = "xla_hlo.exponential_minus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32>
%expm1 = "mhlo.exponential_minus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: [[LOG1P:%.*]] = f32[4] log-plus-one(f32[4] [[ARG_F32]])
%log1p = "xla_hlo.log_plus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32>
%log1p = "mhlo.log_plus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32>
// CHECK: [[ARG_I32:%.*]] = s32[4] parameter(1)
// CHECK: [[NOT:%.*]] = s32[4] not(s32[4] [[ARG_I32]])
%not = "xla_hlo.not"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32>
%not = "mhlo.not"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32>
// CHECK: [[POPCNT:%.*]] = s32[4] popcnt(s32[4] [[ARG_I32]])
%popcnt = "xla_hlo.popcnt"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32>
%popcnt = "mhlo.popcnt"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32>
return %expm1, %log1p, %not, %popcnt : tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>
}
@ -937,7 +937,7 @@ func @main(%arg_f32: tensor<4xf32>, %arg_i32: tensor<4xi32>) -> (tensor<4xf32>,
func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: [[VAL_1:%.*]] = pred[4] parameter(0)
// CHECK: [[VAL_2:%.*]] = pred[4] parameter(1)
%0 = xla_hlo.xor %arg0, %arg1 : tensor<4xi1>
%0 = mhlo.xor %arg0, %arg1 : tensor<4xi1>
// CHECK: ROOT [[VAL_3:%.*]] = pred[4] xor(pred[4] [[VAL_1]], pred[4] [[VAL_2]])
return %0 : tensor<4xi1>
}
@ -946,10 +946,10 @@ func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: HloModule
func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
%0 = "xla_hlo.sort"(%input0, %input1) ( {
%0 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
return
}
@ -975,7 +975,7 @@ func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// CHECK: HloModule
func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> {
%0 = "xla_hlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"} : (tensor<16x16xf32>) -> tensor<16x16xf32>
%0 = "mhlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"} : (tensor<16x16xf32>) -> tensor<16x16xf32>
return %0 : tensor<16x16xf32>
}
@ -988,8 +988,8 @@ func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> {
// Tests that the exported HLO module keeps parameter replication annotation.
// CHECK: HloModule
func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {xla_hlo.is_same_data_across_replicas}) -> tensor<16x16xf32> {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32>
func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {mhlo.is_same_data_across_replicas}) -> tensor<16x16xf32> {
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32>
return %0 : tensor<16x16xf32>
}
@ -1003,8 +1003,8 @@ func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {xla_hlo.is_same_d
// CHECK: HloModule
func @main(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f64>>) -> (tensor<2xf32>, tensor<2xf64>) {
%0 = "xla_hlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%1 = "xla_hlo.abs"(%arg1) : (tensor<2xcomplex<f64>>) -> (tensor<2xf64>)
%0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%1 = "mhlo.abs"(%arg1) : (tensor<2xcomplex<f64>>) -> (tensor<2xf64>)
return %0, %1 : tensor<2xf32>, tensor<2xf64>
}
@ -1019,7 +1019,7 @@ func @main(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f64>>) -> (ten
// CHECK: HloModule
func @main(%arg0: tensor<4xui8>) -> tensor<4xui8> {
%0 = "xla_hlo.not"(%arg0) : (tensor<4xui8>) -> tensor<4xui8>
%0 = "mhlo.not"(%arg0) : (tensor<4xui8>) -> tensor<4xui8>
return %0 : tensor<4xui8>
}
@ -1031,7 +1031,7 @@ func @main(%arg0: tensor<4xui8>) -> tensor<4xui8> {
// CHECK: HloModule
func @main(%arg0: tensor<4xi32>) -> tensor<*xi32> {
%0 = "xla_hlo.not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
%0 = "mhlo.not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
%1 = tensor_cast %0 : tensor<4xi32> to tensor<*xi32>
return %1 : tensor<*xi32>
}
@ -1046,10 +1046,10 @@ func @main(%arg0: tensor<4xi32>) -> tensor<*xi32> {
// correctly in HloModule as frontend_attributes.
// CHECK: HloModule
func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> tuple<tensor<3x4xf32>, !xla_hlo.token> {
%0 = "xla_hlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true, xla_hlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "channel_dtoh_0"}} : (tensor<3x4xf32>, !xla_hlo.token) -> !xla_hlo.token
%1 = "xla_hlo.recv"(%0) {channel_id = {handle = 2 : i64, type = 3 : i64}, is_host_transfer = true, xla_hlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "channel_htod_0"}} : (!xla_hlo.token) -> tuple<tensor<3x4xf32>, !xla_hlo.token>
return %1 : tuple<tensor<3x4xf32>, !xla_hlo.token>
func @main(%arg: tensor<3x4xf32>, %token: !mhlo.token) -> tuple<tensor<3x4xf32>, !mhlo.token> {
%0 = "mhlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true, mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "channel_dtoh_0"}} : (tensor<3x4xf32>, !mhlo.token) -> !mhlo.token
%1 = "mhlo.recv"(%0) {channel_id = {handle = 2 : i64, type = 3 : i64}, is_host_transfer = true, mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "channel_htod_0"}} : (!mhlo.token) -> tuple<tensor<3x4xf32>, !mhlo.token>
return %1 : tuple<tensor<3x4xf32>, !mhlo.token>
}
// CHECK: ENTRY
@ -1068,9 +1068,9 @@ func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> tuple<tensor<3x4xf3
// populated in HloModule.
// CHECK: HloModule
func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> !xla_hlo.token {
%0 = "xla_hlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true, xla_hlo.frontend_attributes = {}} : (tensor<3x4xf32>, !xla_hlo.token) -> !xla_hlo.token
return %0 : !xla_hlo.token
func @main(%arg: tensor<3x4xf32>, %token: !mhlo.token) -> !mhlo.token {
%0 = "mhlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true, mhlo.frontend_attributes = {}} : (tensor<3x4xf32>, !mhlo.token) -> !mhlo.token
return %0 : !mhlo.token
}
// CHECK-NOT: frontend_attributes
@ -1081,9 +1081,9 @@ func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> !xla_hlo.token {
// populated in HloModule.
// CHECK: HloModule
func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> !xla_hlo.token {
%0 = "xla_hlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true} : (tensor<3x4xf32>, !xla_hlo.token) -> !xla_hlo.token
return %0 : !xla_hlo.token
func @main(%arg: tensor<3x4xf32>, %token: !mhlo.token) -> !mhlo.token {
%0 = "mhlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true} : (tensor<3x4xf32>, !mhlo.token) -> !mhlo.token
return %0 : !mhlo.token
}
// CHECK-NOT: frontend_attributes

View File

@ -9,95 +9,95 @@ ENTRY %tfcompile.48 {
%arg0.1 = f32[1,300] parameter(0)
%arg1.2 = f32[1,300,3,1] parameter(1)
// CHECK-NEXT: %0 = "xla_hlo.reshape"(%arg0) {name = "reshape.3"} : (tensor<1x300xf32>) -> tensor<1x300xf32>
// CHECK-NEXT: %0 = "mhlo.reshape"(%arg0) {name = "reshape.3"} : (tensor<1x300xf32>) -> tensor<1x300xf32>
%reshape.3 = f32[1,300] reshape(%arg0.1)
// CHECK-NEXT: %1 = "xla_hlo.transpose"(%0) {name = "transpose.27", permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32>
// CHECK-NEXT: %1 = "mhlo.transpose"(%0) {name = "transpose.27", permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32>
%transpose.27 = f32[300,1] transpose(%reshape.3), dimensions={1,0}
// CHECK-NEXT: %2 = "xla_hlo.reshape"(%1) {name = "reshape.28"} : (tensor<300x1xf32>) -> tensor<300x1x1xf32>
// CHECK-NEXT: %2 = "mhlo.reshape"(%1) {name = "reshape.28"} : (tensor<300x1xf32>) -> tensor<300x1x1xf32>
%reshape.28 = f32[300,1,1] reshape(%transpose.27)
// CHECK-NEXT: %3 = "xla_hlo.reshape"(%2) {name = "reshape.29"} : (tensor<300x1x1xf32>) -> tensor<300x1xf32>
// CHECK-NEXT: %3 = "mhlo.reshape"(%2) {name = "reshape.29"} : (tensor<300x1x1xf32>) -> tensor<300x1xf32>
%reshape.29 = f32[300,1] reshape(%reshape.28)
// CHECK-NEXT: %4 = "xla_hlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.30"} : (tensor<300x1xf32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.30"} : (tensor<300x1xf32>) -> tensor<300x1x5xf32>
%broadcast.30 = f32[300,1,5] broadcast(%reshape.29), dimensions={0,1}
// CHECK-NEXT: %cst = constant {name = "constant.8"} dense<1.000000e+00> : tensor<f32>
%constant.8 = f32[] constant(1)
// CHECK-NEXT: %5 = "xla_hlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.9"} : (tensor<f32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %5 = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.9"} : (tensor<f32>) -> tensor<300x1x5xf32>
%broadcast.9 = f32[300,1,5] broadcast(%constant.8), dimensions={}
// CHECK-NEXT: %6 = xla_hlo.multiply %4, %5 {name = "multiply.31"} : tensor<300x1x5xf32>
// CHECK-NEXT: %6 = mhlo.multiply %4, %5 {name = "multiply.31"} : tensor<300x1x5xf32>
%multiply.31 = f32[300,1,5] multiply(%broadcast.30, %broadcast.9)
// CHECK-NEXT: %cst_0 = constant {name = "constant.32"} dense<0.000000e+00> : tensor<f32>
%constant.32 = f32[] constant(0)
// CHECK-NEXT: %7 = "xla_hlo.broadcast_in_dim"(%cst_0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.33"} : (tensor<f32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %7 = "mhlo.broadcast_in_dim"(%cst_0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.33"} : (tensor<f32>) -> tensor<300x1x5xf32>
%broadcast.33 = f32[300,1,5] broadcast(%constant.32), dimensions={}
// CHECK-NEXT: %8 = "xla_hlo.compare"(%6, %7) {comparison_direction = "GT", name = "compare.34"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1>
// CHECK-NEXT: %8 = "mhlo.compare"(%6, %7) {comparison_direction = "GT", name = "compare.34"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1>
%compare.34 = pred[300,1,5] compare(%multiply.31, %broadcast.33), direction=GT
// CHECK-NEXT: %cst_1 = constant {name = "constant.10"} dense<0.000000e+00> : tensor<f32>
%constant.10 = f32[] constant(0)
// CHECK-NEXT: %9 = "xla_hlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.11"} : (tensor<f32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %9 = "mhlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.11"} : (tensor<f32>) -> tensor<300x1x5xf32>
%broadcast.11 = f32[300,1,5] broadcast(%constant.10), dimensions={}
// CHECK-NEXT: %cst_2 = constant {name = "constant.40"} dense<0.000000e+00> : tensor<f32>
%constant.40 = f32[] constant(0)
// CHECK-NEXT: %10 = "xla_hlo.broadcast_in_dim"(%cst_2) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.41"} : (tensor<f32>) -> tensor<300x5xf32>
// CHECK-NEXT: %10 = "mhlo.broadcast_in_dim"(%cst_2) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.41"} : (tensor<f32>) -> tensor<300x5xf32>
%broadcast.41 = f32[300,5] broadcast(%constant.40), dimensions={}
// CHECK-NEXT: %11 = "xla_hlo.copy"(%arg1) {name = "copy.1"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
// CHECK-NEXT: %11 = "mhlo.copy"(%arg1) {name = "copy.1"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
%copy.1 = f32[1,300,3,1] copy(%arg1.2)
// CHECK-NEXT: %12 = "xla_hlo.reshape"(%11) {name = "reshape.4"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
// CHECK-NEXT: %12 = "mhlo.reshape"(%11) {name = "reshape.4"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
%reshape.4 = f32[1,300,3,1] reshape(%copy.1)
// CHECK-NEXT: %13 = "xla_hlo.reshape"(%12) {name = "reshape.24"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32>
// CHECK-NEXT: %13 = "mhlo.reshape"(%12) {name = "reshape.24"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32>
%reshape.24 = f32[1,300,3] reshape(%reshape.4)
// CHECK-NEXT: %14 = "xla_hlo.transpose"(%13) {name = "transpose.25", permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32>
// CHECK-NEXT: %14 = "mhlo.transpose"(%13) {name = "transpose.25", permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32>
%transpose.25 = f32[300,1,3] transpose(%reshape.24), dimensions={1,0,2}
// CHECK-NEXT: %15 = "xla_hlo.reshape"(%14) {name = "reshape.26"} : (tensor<300x1x3xf32>) -> tensor<300x3xf32>
// CHECK-NEXT: %15 = "mhlo.reshape"(%14) {name = "reshape.26"} : (tensor<300x1x3xf32>) -> tensor<300x3xf32>
%reshape.26 = f32[300,3] reshape(%transpose.25)
// CHECK-NEXT: %cst_3 = constant {name = "constant.35"} dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32>
%constant.35 = f32[3,5] constant({ { -0.106023, 0.121505, 0.800239, -0.768885, 0.0966113 }, { 0.689014, -0.407056, -0.797853, 0.00378925, -0.208881 }, { -0.608529, 0.0276617, 0.268557, 0.577401, -0.428437 } })
// TODO(b/129709049) consider making this default precision config implied.
// CHECK-NEXT: %16 = "xla_hlo.dot"(%15, %cst_3) {name = "dot.36", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
// CHECK-NEXT: %16 = "mhlo.dot"(%15, %cst_3) {name = "dot.36", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
%dot.36 = f32[300,5] dot(%reshape.26, %constant.35), lhs_contracting_dims={1}, rhs_contracting_dims={0}
// CHECK-NEXT: %cst_4 = constant {name = "constant.37"} dense<0.000000e+00> : tensor<5xf32>
%constant.37 = f32[5]{0} constant({0, 0, 0, 0, 0})
// CHECK-NEXT: %17 = "xla_hlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.38"} : (tensor<5xf32>) -> tensor<300x5xf32>
// CHECK-NEXT: %17 = "mhlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.38"} : (tensor<5xf32>) -> tensor<300x5xf32>
%broadcast.38 = f32[300,5] broadcast(%constant.37), dimensions={1}
// CHECK-NEXT: %18 = xla_hlo.add %16, %17 {name = "add.39"} : tensor<300x5xf32>
// CHECK-NEXT: %18 = mhlo.add %16, %17 {name = "add.39"} : tensor<300x5xf32>
%add.39 = f32[300,5] add(%dot.36, %broadcast.38)
// CHECK-NEXT: %19 = xla_hlo.maximum %10, %18 {name = "maximum.42"} : tensor<300x5xf32>
// CHECK-NEXT: %19 = mhlo.maximum %10, %18 {name = "maximum.42"} : tensor<300x5xf32>
%maximum.42 = f32[300,5] maximum(%broadcast.41, %add.39)
// CHECK-NEXT: %20 = "xla_hlo.reshape"(%19) {name = "reshape.44"} : (tensor<300x5xf32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %20 = "mhlo.reshape"(%19) {name = "reshape.44"} : (tensor<300x5xf32>) -> tensor<300x1x5xf32>
%reshape.44 = f32[300,1,5] reshape(%maximum.42)
// CHECK-NEXT: %21 = "xla_hlo.select"(%8, %9, %20) {name = "select.45"} : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %21 = "mhlo.select"(%8, %9, %20) {name = "select.45"} : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
%select.45 = f32[300,1,5] select(%compare.34, %broadcast.11, %reshape.44)
// CHECK-NEXT: %22 = "xla_hlo.reshape"(%21) {name = "reshape.46"} : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %22 = "mhlo.reshape"(%21) {name = "reshape.46"} : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
%reshape.46 = f32[300,1,5] reshape(%select.45)
// CHECK-NEXT: %23 = "xla_hlo.tuple"(%22) {name = "tuple.47"} : (tensor<300x1x5xf32>) -> tuple<tensor<300x1x5xf32>>
// CHECK-NEXT: %23 = "mhlo.tuple"(%22) {name = "tuple.47"} : (tensor<300x1x5xf32>) -> tuple<tensor<300x1x5xf32>>
// CHECK-NEXT: return %23 : tuple<tensor<300x1x5xf32>>
ROOT %tuple.47 = (f32[300,1,5]) tuple(%reshape.46)
}

View File

@ -4,13 +4,13 @@
// CHECK: %[[A0]] = (f32[]) parameter(0)
func @then_branch(%arg0: tuple<tensor<f32>>) -> tuple<tensor<f32>> {
// CHECK: %[[VAL0:.+]] = f32[] get-tuple-element((f32[]) %[[A0]]), index=0
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
// CHECK: %[[VAL1:.+]] = f32[] log(f32[] %[[VAL0]])
%1 = "xla_hlo.log"(%0) : (tensor<f32>) -> tensor<f32>
%1 = "mhlo.log"(%0) : (tensor<f32>) -> tensor<f32>
// CHECK: ROOT %[[VAl2:.+]] = (f32[]) tuple(f32[] %[[VAL1]])
%2 = "xla_hlo.tuple"(%1) : (tensor<f32>) -> tuple<tensor<f32>>
%2 = "mhlo.tuple"(%1) : (tensor<f32>) -> tuple<tensor<f32>>
return %2 : tuple<tensor<f32>>
}
@ -18,13 +18,13 @@ func @then_branch(%arg0: tuple<tensor<f32>>) -> tuple<tensor<f32>> {
// CHECK: %[[A0]] = (f32[]) parameter(0)
func @else_branch(%arg0: tuple<tensor<f32>>) -> tuple<tensor<f32>> {
// CHECK: %[[VAL0:.+]] = f32[] get-tuple-element((f32[]) %[[A0]]), index=0
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
// CHECK: %[[VAL1:.+]] = f32[] exponential(f32[] %[[VAL0]])
%1 = "xla_hlo.exponential"(%0) : (tensor<f32>) -> tensor<f32>
%1 = "mhlo.exponential"(%0) : (tensor<f32>) -> tensor<f32>
// CHECK: ROOT %[[VAL2:.+]] = (f32[]) tuple(f32[] %[[VAL1]])
%2 = "xla_hlo.tuple"(%1) : (tensor<f32>) -> tuple<tensor<f32>>
%2 = "mhlo.tuple"(%1) : (tensor<f32>) -> tuple<tensor<f32>>
return %2 : tuple<tensor<f32>>
}
@ -35,30 +35,30 @@ func @main(%arg0: tensor<f32>) -> tuple<tensor<f32>> {
%cst = constant dense<1.000000e+01> : tensor<f32>
// CHECK: %[[VAL1:.+]] = pred[] compare(f32[] %[[A0]], f32[] %[[VAL0]]), direction=LT
%0 = "xla_hlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
%0 = "mhlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL2:.+]] = (f32[]) tuple(f32[] %[[A0]])
%1 = "xla_hlo.tuple"(%arg0) : (tensor<f32>) -> tuple<tensor<f32>>
%1 = "mhlo.tuple"(%arg0) : (tensor<f32>) -> tuple<tensor<f32>>
// CHECK: %[[VAL3:.+]] = (f32[]) conditional(pred[] %[[VAL1]], (f32[]) %[[VAL2]], (f32[]) %[[VAL2]]), true_computation=[[R0]], false_computation=[[R1]]
%2 = "xla_hlo.if"(%0, %1, %1) ( {
%2 = "mhlo.if"(%0, %1, %1) ( {
^bb0(%arg1: tuple<tensor<f32>>):
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%7 = "xla_hlo.log"(%6) : (tensor<f32>) -> tensor<f32>
%8 = "xla_hlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>
"xla_hlo.return"(%8) : (tuple<tensor<f32>>) -> ()
%6 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%7 = "mhlo.log"(%6) : (tensor<f32>) -> tensor<f32>
%8 = "mhlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>
"mhlo.return"(%8) : (tuple<tensor<f32>>) -> ()
}, {
^bb0(%arg1: tuple<tensor<f32>>):
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%7 = "xla_hlo.exponential"(%6) : (tensor<f32>) -> tensor<f32>
%8 = "xla_hlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>
"xla_hlo.return"(%8) : (tuple<tensor<f32>>) -> ()
%6 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%7 = "mhlo.exponential"(%6) : (tensor<f32>) -> tensor<f32>
%8 = "mhlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>
"mhlo.return"(%8) : (tuple<tensor<f32>>) -> ()
}) : (tensor<i1>, tuple<tensor<f32>>, tuple<tensor<f32>>) -> tuple<tensor<f32>>
// CHECK: %[[VAL4:.+]] = f32[] get-tuple-element((f32[]) %[[VAL3]]), index=0
%3 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%3 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
// CHECK: ROOT %[[VAL5:.+]] = (f32[]) tuple(f32[] %[[VAL4]])
%4 = "xla_hlo.tuple"(%3) : (tensor<f32>) -> tuple<tensor<f32>>
%4 = "mhlo.tuple"(%3) : (tensor<f32>) -> tuple<tensor<f32>>
return %4 : tuple<tensor<f32>>
}

View File

@ -23,31 +23,31 @@ ENTRY %tfcompile.20 {
// CHECK: [[C0:%.+]] = constant
%constant.3 = f32[] constant(10), metadata={op_type="Less" op_name="Less"}
// CHECK: [[R1:%.+]] = "xla_hlo.compare"([[A0]], [[C0]])
// CHECK: [[R1:%.+]] = "mhlo.compare"([[A0]], [[C0]])
%compare.4 = pred[] compare(%arg0.1, %constant.3), direction=LT, metadata={op_type="Less" op_name="Less"}
// CHECK: [[R2:%.+]] = "xla_hlo.tuple"([[A0]])
// CHECK: [[R2:%.+]] = "mhlo.tuple"([[A0]])
%tuple.5 = (f32[]) tuple(%arg0.1), metadata={op_type="If" op_name="cond/Merge_if"}
// CHECK: [[R3:%.+]] = "xla_hlo.if"([[R1]], [[R2]], [[R2]]) ( {
// CHECK: [[R3:%.+]] = "mhlo.if"([[R1]], [[R2]], [[R2]]) ( {
// CHECK: ^bb0([[A1:%.+]]: tuple<tensor<f32>>):
// CHECK: [[R7:%.+]] = "xla_hlo.get_tuple_element"([[A1]])
// CHECK: [[R8:%.+]] = "xla_hlo.log"([[R7]])
// CHECK: [[R9:%.+]] = "xla_hlo.tuple"([[R8]])
// CHECK: "xla_hlo.return"([[R9]])
// CHECK: [[R7:%.+]] = "mhlo.get_tuple_element"([[A1]])
// CHECK: [[R8:%.+]] = "mhlo.log"([[R7]])
// CHECK: [[R9:%.+]] = "mhlo.tuple"([[R8]])
// CHECK: "mhlo.return"([[R9]])
// CHECK: }, {
// CHECK: ^bb0([[A1:%.+]]: tuple<tensor<f32>>):
// CHECK: [[R7:%.+]] = "xla_hlo.get_tuple_element"([[A1]])
// CHECK: [[R8:%.+]] = "xla_hlo.exponential"([[R7]])
// CHECK: [[R9:%.+]] = "xla_hlo.tuple"([[R8]])
// CHECK: "xla_hlo.return"([[R9]])
// CHECK: [[R7:%.+]] = "mhlo.get_tuple_element"([[A1]])
// CHECK: [[R8:%.+]] = "mhlo.exponential"([[R7]])
// CHECK: [[R9:%.+]] = "mhlo.tuple"([[R8]])
// CHECK: "mhlo.return"([[R9]])
// CHECK: })
%conditional.16 = (f32[]) conditional(%compare.4, %tuple.5, %tuple.5), true_computation=%then_branch, false_computation=%else_branch, metadata={op_type="If" op_name="cond/Merge_if"}
// CHECK: [[R4:%.+]] = "xla_hlo.get_tuple_element"([[R3]])
// CHECK: [[R4:%.+]] = "mhlo.get_tuple_element"([[R3]])
%get-tuple-element.17 = f32[] get-tuple-element(%conditional.16), index=0, metadata={op_type="If" op_name="cond/Merge_if"}
// CHECK: [[R5:%.+]] = "xla_hlo.tuple"([[R4]])
// CHECK: [[R5:%.+]] = "mhlo.tuple"([[R4]])
// CHECK: return [[R5]]
ROOT %tuple.19 = (f32[]) tuple(%get-tuple-element.17), metadata={op_name="XLA_Retvals"}
}

View File

@ -13,20 +13,20 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
%Arg_0.1 = f32[4]{0} parameter(0)
%Arg_1.2 = f32[4]{0} parameter(1)
// CHECK-NEXT: xla_hlo.add %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
// CHECK-NEXT: mhlo.add %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
%add.3 = f32[4]{0} add(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2)
// TODO(b/129709049) consider making this default precision config inferred.
// CHECK-NEXT: "xla_hlo.dot"(%0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
// CHECK-NEXT: "mhlo.dot"(%0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
ROOT %dot.4 = f32[] dot(f32[4]{0} %add.3, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}
}
// CHECK-LABEL: func @test_after_all
// CHECK-SAME: ([[VAL_0:%.*]]: !xla_hlo.token, [[VAL_1:%.*]]: !xla_hlo.token) -> !xla_hlo.token [[PRIVATE]]
// CHECK-SAME: ([[VAL_0:%.*]]: !mhlo.token, [[VAL_1:%.*]]: !mhlo.token) -> !mhlo.token [[PRIVATE]]
%test_after_all (token0: token[], token1: token[] ) -> token[] {
token0 = token[] parameter(0)
token1 = token[] parameter(1)
// CHECK-NEXT: "xla_hlo.after_all"([[VAL_0]], [[VAL_1]]) {name = "{{.*}}"} : (!xla_hlo.token, !xla_hlo.token) -> !xla_hlo.token
// CHECK-NEXT: "mhlo.after_all"([[VAL_0]], [[VAL_1]]) {name = "{{.*}}"} : (!mhlo.token, !mhlo.token) -> !mhlo.token
ROOT after-all = token[] after-all(token0, token1)
}
@ -41,10 +41,10 @@ add {
// CHECK-SAME: ([[INPUT:%.*]]: tensor<8xf32>)
%test_all_reduce {
input = f32[8] parameter(0)
// CHECK-NEXT: "xla_hlo.all_reduce"([[INPUT]])
// CHECK-NEXT: "mhlo.all_reduce"([[INPUT]])
// CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = xla_hlo.add [[ARG0]], [[ARG1]]
// CHECK: "xla_hlo.return"([[ADD]]) : (tensor<f32>) -> ()
// CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
// CHECK: "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
// CHECK: }) {
// CHECK-SAME: channel_handle = {handle = 1 : i64, type = 0 : i64}
// CHECK-SAME: replica_groups = dense<{{\[\[}}0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>
@ -57,7 +57,7 @@ add {
%Arg_0.1 = pred[4] parameter(0)
%Arg_1.2 = pred[4] parameter(1)
// CHECK-NEXT: xla_hlo.and %arg0, %arg1
// CHECK-NEXT: mhlo.and %arg0, %arg1
ROOT %and.3 = pred[4] and(pred[4] %Arg_0.1, pred[4] %Arg_1.2)
}
@ -67,7 +67,7 @@ add {
%Arg_0.1 = s32[4] parameter(0)
%Arg_1.2 = s32[4] parameter(1)
// CHECK: xla_hlo.atan2 [[VAL_0]], [[VAL_1]]
// CHECK: mhlo.atan2 [[VAL_0]], [[VAL_1]]
ROOT %atan2 = s32[4] atan2(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
}
@ -75,10 +75,10 @@ add {
%test_broadcast_in_dim {
%Arg_0.1 = f32[1, 2] parameter(0)
// CHECK-NEXT: "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<1x2x3xf32>
// CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<1x2x3xf32>
%broadcast.2 = f32[1,2,3] broadcast(%Arg_0.1), dimensions={0,1}
// CHECK-NEXT: "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<3x1x2xf32>
// CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<3x1x2xf32>
ROOT broadcast.4 = f32[3,1,2] broadcast(%Arg_0.1), dimensions={1, 2}
}
@ -90,7 +90,7 @@ add {
%variance = f32[2] parameter(3)
%grad_output = f32[2,2,2,2] parameter(4)
// CHECK: "xla_hlo.batch_norm_grad"
// CHECK: "mhlo.batch_norm_grad"
// CHECK-SAME: epsilon = 1.000000e-03 : f32
// CHECK-SAME: feature_index = 1 : i64
ROOT %batch-norm-grad = (f32[2,2,2,2], f32[2], f32[2]) batch-norm-grad(f32[2,2,2,2] %input, f32[2] %scale, f32[2] %mean, f32[2] %variance, f32[2,2,2,2] %grad_output), epsilon=0.001, feature_index=1
@ -113,7 +113,7 @@ add {
// CHECK-SAME: ([[ARG:%.*]]: tensor<1x291x291xf32>) -> tensor<1x291x291xf32>
%test_cholesky (a: f32[1,291,291]) -> f32[1,291,291] {
%a = f32[1,291,291] parameter(0)
// CHECK-NEXT: "xla_hlo.cholesky"([[ARG]]) {lower = true, name = {{.*}}} : (tensor<1x291x291xf32>) -> tensor<1x291x291xf32>
// CHECK-NEXT: "mhlo.cholesky"([[ARG]]) {lower = true, name = {{.*}}} : (tensor<1x291x291xf32>) -> tensor<1x291x291xf32>
ROOT %out = f32[1,291,291] cholesky(f32[1,291,291] %a), lower=true
}
@ -124,7 +124,7 @@ add {
%Arg_1.2 = f32[4] parameter(1)
%Arg_2.3 = f32[] parameter(2)
// CHECK-NEXT: "xla_hlo.clamp"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK-NEXT: "mhlo.clamp"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
ROOT %clamp.3 = f32[4] clamp(f32[] %Arg_0.1, f32[4] %Arg_1.2, f32[] %Arg_2.3)
}
@ -132,7 +132,7 @@ add {
// CHECK-SAME: ([[ARG:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32>
%test_collective_permute (input: f32[128,32]) -> f32[128,32] {
%input = f32[128,32]{0,1} parameter(0)
// CHECK-NEXT: "xla_hlo.collective_permute"([[ARG]]) {name = {{.*}}, source_target_pairs = dense<{{\[\[}}0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>} : (tensor<128x32xf32>) -> tensor<128x32xf32>
// CHECK-NEXT: "mhlo.collective_permute"([[ARG]]) {name = {{.*}}, source_target_pairs = dense<{{\[\[}}0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>} : (tensor<128x32xf32>) -> tensor<128x32xf32>
ROOT root = f32[128,32]{0,1} collective-permute(%input), source_target_pairs={{0,1},{1,2},{2,3}}
}
@ -143,14 +143,14 @@ add {
%Arg_1.2 = f32[3] parameter(1)
%Arg_2.3 = f32[3] parameter(2)
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
%compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "LE", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
%compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE
// Requires broadcast of compatible tensors.
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
ROOT %compare.6 = pred[3] compare(Arg_0.1, Arg_2.3), direction=GT
}
@ -159,7 +159,7 @@ add {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: "xla_hlo.complex"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
// CHECK-NEXT: "mhlo.complex"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
ROOT %complex.3 = c64[4] complex(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
@ -168,7 +168,7 @@ add {
%Arg_0.1 = f32[4, 1] parameter(0)
%Arg_1.2 = f32[4, 2] parameter(1)
// CHECK-NEXT: "xla_hlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<4x1xf32>, tensor<4x2xf32>) -> tensor<4x3xf32>
// CHECK-NEXT: "mhlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<4x1xf32>, tensor<4x2xf32>) -> tensor<4x3xf32>
ROOT %concatenate.3 = f32[4, 3] concatenate(f32[4, 1] %Arg_0.1, f32[4, 2] %Arg_1.2), dimensions={1}
}
@ -206,10 +206,10 @@ add {
%test_conv {
%arg0.1 = f32[256,32,32,6]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
// CHECK-NEXT: %0 = "xla_hlo.copy"(%arg0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
// CHECK-NEXT: %0 = "mhlo.copy"(%arg0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
%copy.1 = f32[256,32,32,6]{2,1,3,0} copy(%arg0.1), metadata={op_name="HLO_Args"}
// CHECK-NEXT: %1 = "xla_hlo.reshape"(%0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
// CHECK-NEXT: %1 = "mhlo.reshape"(%0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
%reshape.2 = f32[256,32,32,6]{2,1,3,0} reshape(%copy.1)
// Note that double brackets "[[" have to be escaped as they denote variables
@ -217,7 +217,7 @@ add {
// CHECK-NEXT: %cst = constant {name = "{{.*}}"} dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32>
%constant.3 = f32[2,2,1,1]{3,2,1,0} constant({{{{0.5}}, {{-0.6}}}, {{{0.3}}, {{-0.1}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
// CHECK-NEXT: %2 = "xla_hlo.convolution"(%1, %cst) {
// CHECK-NEXT: %2 = "mhlo.convolution"(%1, %cst) {
// CHECK-SAME: batch_group_count = 1 : i64
// CHECK-SAME: dimension_numbers = {
// CHECK-SAME: input_batch_dimension = 0 : i64
@ -241,10 +241,10 @@ add {
%convolution.4 = f32[16,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
// CHECK-NEXT: %3 = "xla_hlo.reshape"(%2) {name = "{{.*}}"} : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32>
// CHECK-NEXT: %3 = "mhlo.reshape"(%2) {name = "{{.*}}"} : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32>
%reshape.5 = f32[256,30,30,16]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"}
// CHECK-NEXT: "xla_hlo.tuple"(%3) {name = "{{.*}}"} : (tensor<256x30x30x16xf32>) -> tuple<tensor<256x30x30x16xf32>>
// CHECK-NEXT: "mhlo.tuple"(%3) {name = "{{.*}}"} : (tensor<256x30x30x16xf32>) -> tuple<tensor<256x30x30x16xf32>>
ROOT %tuple.6 = (f32[256,30,30,16]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="HLO_Retvals"}
}
@ -253,7 +253,7 @@ add {
%test_convolve1D_padding (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,5,1] {
%input = f32[1,2,1] parameter(0)
%filter = f32[1,1,1] parameter(1)
// CHECK: "xla_hlo.convolution"
// CHECK: "mhlo.convolution"
// CHECK-SAME: padding = dense<{{\[\[}}1, 2]]> : tensor<1x2xi64>
ROOT %convolution = f32[1,5,1] convolution(f32[1,2,1] %input, f32[1,1,1] %filter), feature_group_count=1, dim_labels=b0f_0io->b0f, window={pad=1_2 size=1}
}
@ -263,13 +263,13 @@ add {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: %0 = "xla_hlo.convert"(%arg0) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64>
// CHECK-NEXT: %0 = "mhlo.convert"(%arg0) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64>
%convert.3 = f64[4] convert(f32[4] %Arg_0.1)
// CHECK-NEXT: %1 = "xla_hlo.convert"(%arg1) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64>
// CHECK-NEXT: %1 = "mhlo.convert"(%arg1) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64>
%convert.4 = f64[4] convert(f32[4] %Arg_1.2)
// CHECK-NEXT: xla_hlo.add %0, %1
// CHECK-NEXT: mhlo.add %0, %1
ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[4] %convert.4)
}
@ -277,7 +277,7 @@ add {
%test_cosine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
%arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
// CHECK-NEXT: "xla_hlo.cosine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
// CHECK-NEXT: "mhlo.cosine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
ROOT %cosine.3 = f32[1,16,16,3]{3,2,1,0} cosine(f32[1,16,16,3]{3,2,1,0} %arg0.1)
}
@ -286,7 +286,7 @@ add {
%test_custom_call (arg1: f32[2,3], arg2: f32[5,5]) -> f32[1,2,3] {
%arg1 = f32[2,3] parameter(0)
%arg2 = f32[5,5] parameter(1)
// CHECK: "xla_hlo.custom_call"([[ARG_0]], [[ARG_1]]) {backend_config = "bar", call_target_name = "foo", has_side_effect = true, name = {{.*}}} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32>
// CHECK: "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {backend_config = "bar", call_target_name = "foo", has_side_effect = true, name = {{.*}}} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32>
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[2,3] %arg1, f32[5,5] %arg2), custom_call_target="foo", backend_config="bar", custom_call_has_side_effect=true
}
@ -295,7 +295,7 @@ add {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: xla_hlo.divide %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
// CHECK-NEXT: mhlo.divide %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
ROOT %divide.3 = f32[4] divide(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
@ -304,17 +304,17 @@ add {
%Arg_0.1 = f32[1, 4] parameter(0)
%Arg_1.2 = f32[4, 1] parameter(1)
// CHECK-NEXT: %0 = "xla_hlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
// CHECK-NEXT: %0 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={high,highest}
// CHECK-NEXT: %1 = "xla_hlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
// CHECK-NEXT: %1 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,default}
// CHECK-NEXT: %2 = "xla_hlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
// CHECK-NEXT: %2 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
%dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={default,default}
// TODO(b/129709049) consider making this default precision config inferred.
// CHECK-NEXT: "xla_hlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
// CHECK-NEXT: "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
@ -325,17 +325,17 @@ add {
%Arg_0.1 = f32[4, 1] parameter(0)
%Arg_1.2 = f32[1, 4] parameter(1)
// CHECK-NEXT: [[R0:%.+]] = "xla_hlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]}
// CHECK-NEXT: [[R0:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]}
dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={high,highest}
// CHECK-NEXT: [[R1:%.+]] = "xla_hlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]}
// CHECK-NEXT: [[R1:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]}
dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={highest,default}
// CHECK-NEXT: [[R2:%.+]] = "xla_hlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]}
// CHECK-NEXT: [[R2:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]}
%dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={default,default}
// TODO(b/129709049) consider making this default precision config inferred.
// CHECK-NEXT: "xla_hlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]}
// CHECK-NEXT: "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]}
ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}
}
@ -346,7 +346,7 @@ add {
%start_idx_1 = s32[] parameter(1)
%start_idx_2 = s32[] parameter(2)
%start_idx_3 = s32[] parameter(3)
// CHECK: "xla_hlo.dynamic-slice"([[OPERAND]], [[START_IDX_1]], [[START_IDX_2]], [[START_IDX_3]])
// CHECK: "mhlo.dynamic-slice"([[OPERAND]], [[START_IDX_1]], [[START_IDX_2]], [[START_IDX_3]])
// CHECK-SAME: slice_sizes = dense<[1, 1, 32]> : tensor<3xi64>
ROOT %dynamic-slice = s32[1,1,32] dynamic-slice(s32[2,2,258] %operand, s32[] %start_idx_1, s32[] %start_idx_2, s32[] %start_idx_3), dynamic_slice_sizes={1,1,32}
}
@ -358,7 +358,7 @@ add {
%Arg_2.3 = s32[] parameter(2)
%Arg_3.4 = s32[] parameter(3)
// CHECK-NEXT: "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor<i32>, tensor<i32>) -> tensor<4x4xf32>
// CHECK-NEXT: "mhlo.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor<i32>, tensor<i32>) -> tensor<4x4xf32>
ROOT %dynamic-update-slice.5 = f32[4, 4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4)
}
@ -368,7 +368,7 @@ add {
%Arg_1.2 = f32[2] parameter(1)
%Arg_2.3 = s32[] parameter(2)
// CHECK-NEXT: "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<i32>) -> tensor<4xf32>
// CHECK-NEXT: "mhlo.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<i32>) -> tensor<4xf32>
ROOT %dynamic-update-slice.5 = f32[4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3)
}
@ -376,7 +376,7 @@ add {
%test_exponential (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
// CHECK-NEXT: "xla_hlo.exponential"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
// CHECK-NEXT: "mhlo.exponential"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
ROOT %exp.2 = f32[16] exponential(f32[16] %arg0.1)
}
@ -384,14 +384,14 @@ add {
%test_expm1 (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
// CHECK: "xla_hlo.exponential_minus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
// CHECK: "mhlo.exponential_minus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
ROOT %expm1.2 = f32[16] exponential-minus-one(f32[16] %arg0.1)
}
// CHECK-LABEL: func @test_fft(%arg0: tensor<3x9xf32>) -> tensor<3x5xcomplex<f32>>
%test_fft {
%arg0.1 = f32[3,9]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
// CHECK: "xla_hlo.fft"(%arg0) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"
// CHECK: "mhlo.fft"(%arg0) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"
ROOT %fft.2 = c64[3,5]{1,0} fft(%arg0.1), fft_type=RFFT, fft_length={9}, metadata={op_type="RFFT" op_name="rfft"}
}
@ -400,7 +400,7 @@ add {
%test_floor (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
// CHECK-NEXT: "xla_hlo.floor"([[A0]]) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
// CHECK-NEXT: "mhlo.floor"([[A0]]) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
ROOT %floor.2 = f32[16] floor(f32[16] %arg0.1)
}
@ -409,7 +409,7 @@ add {
%test_gather (arg.0: f32[200,100,300], arg.1: s32[10,2]) -> f32[10,300] {
%arg.0 = f32[200,100,300] parameter(0)
%arg.1 = s32[10,2] parameter(1)
// CHECK: "xla_hlo.gather"([[ARG0]], [[ARG1]])
// CHECK: "mhlo.gather"([[ARG0]], [[ARG1]])
// CHECK-SAME: dimension_numbers
// CHECK-SAME: collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>
// CHECK-SAME: index_vector_dim = 1 : i64
@ -430,7 +430,7 @@ add {
// CHECK-SAME: ([[ARG:%.*]]: tensor<4x2xf32>)
%test_get_dimension_size (Arg_0.1: f32[4,2]) -> s32[] {
%Arg_0.1 = f32[4,2] parameter(0)
// CHECK-NEXT: "xla_hlo.get_dimension_size"([[ARG]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x2xf32>) -> tensor<i32>
// CHECK-NEXT: "mhlo.get_dimension_size"([[ARG]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x2xf32>) -> tensor<i32>
ROOT %get-dimension-size.2 = s32[] get-dimension-size(f32[4,2] %Arg_0.1), dimensions={1}
}
@ -438,15 +438,15 @@ add {
%test_imag (Arg_0.1: c64[4]) -> f32[4] {
%Arg_0.1 = c64[4] parameter(0)
// CHECK-NEXT: "xla_hlo.imag"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
// CHECK-NEXT: "mhlo.imag"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
ROOT %imag.3 = f32[4] imag(c64[4] %Arg_0.1)
}
// CHECK-LABEL: func @test_infeed
// CHECK-SAME: ([[TOKEN:%.*]]: !xla_hlo.token) -> tuple<tensor<3xi32>, !xla_hlo.token>
// CHECK-SAME: ([[TOKEN:%.*]]: !mhlo.token) -> tuple<tensor<3xi32>, !mhlo.token>
%test_infeed (token0: token[]) -> (s32[3], token[]) {
%token0 = token[] parameter(0)
// CHECK-NEXT: "xla_hlo.infeed"([[TOKEN]])
// CHECK-NEXT: "mhlo.infeed"([[TOKEN]])
// CHECK-SAME: infeed_config = "foobar"
ROOT %infeed = (s32[3], token[]) infeed(token[] %token0), infeed_config="foobar"
}
@ -454,13 +454,13 @@ add {
// CHECK-LABEL: func @test_iota_1() -> tensor<4xf32>
%test_iota_1 () -> f32[4] {
// CHECK-NEXT: "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
// CHECK-NEXT: "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
ROOT %iota.0 = f32[4] iota(), iota_dimension=0
}
// CHECK-LABEL: func @test_iota_2() -> tensor<4x5xf32>
%test_iota_2 () -> f32[4, 5] {
// CHECK-NEXT: "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4x5xf32>
// CHECK-NEXT: "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4x5xf32>
ROOT %iota.0 = f32[4, 5] iota(), iota_dimension=1
}
@ -468,7 +468,7 @@ add {
%test_log (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
// CHECK-NEXT: "xla_hlo.log"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
// CHECK-NEXT: "mhlo.log"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
ROOT %log.2 = f32[16] log(f32[16] %arg0.1)
}
@ -476,11 +476,11 @@ add {
%test_log1p (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
// CHECK: "xla_hlo.log_plus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
// CHECK: "mhlo.log_plus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
ROOT %log1p.2 = f32[16] log-plus-one(f32[16] %arg0.1)
}
// Test xla_hlo.map
// Test mhlo.map
%map_computation {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
@ -492,10 +492,10 @@ add {
%test_map {
param0 = f32[4]{0} parameter(0)
param1 = f32[4]{0} parameter(1)
// CHECK: "xla_hlo.map"([[ARG_0]], [[ARG_1]]) ( {
// CHECK: "mhlo.map"([[ARG_0]], [[ARG_1]]) ( {
// CHECK: ^bb0([[ARG_2:%.*]]: tensor<f32>, [[ARG_3:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = xla_hlo.add [[ARG_2]], [[ARG_3]]
// CHECK: "xla_hlo.return"([[ADD]]) : (tensor<f32>) -> ()
// CHECK: [[ADD:%.*]] = mhlo.add [[ARG_2]], [[ARG_3]]
// CHECK: "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
// CHECK: }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=%map_computation
}
@ -507,7 +507,7 @@ add {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: xla_hlo.maximum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
// CHECK-NEXT: mhlo.maximum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
ROOT %maximum.3 = f32[4] maximum(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
@ -516,7 +516,7 @@ add {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: xla_hlo.minimum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
// CHECK-NEXT: mhlo.minimum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
ROOT %minimum.3 = f32[4] minimum(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
@ -525,7 +525,7 @@ add {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: %0 = xla_hlo.multiply %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
// CHECK-NEXT: %0 = mhlo.multiply %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
ROOT %multiply.3 = f32[4] multiply(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
@ -533,7 +533,7 @@ add {
%test_negate (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
// CHECK-NEXT: "xla_hlo.negate"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
// CHECK-NEXT: "mhlo.negate"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
ROOT %negate.2 = f32[16] negate(f32[16] %arg0.1)
}
@ -541,7 +541,7 @@ add {
%test_not (arg0.1: pred[16]) -> pred[16] {
%arg0.1 = pred[16] parameter(0)
// CHECK: "xla_hlo.not"(%arg0) {name = "{{.*}}"} : (tensor<16xi1>) -> tensor<16xi1>
// CHECK: "mhlo.not"(%arg0) {name = "{{.*}}"} : (tensor<16xi1>) -> tensor<16xi1>
ROOT %not.2 = pred[16] not(pred[16] %arg0.1)
}
@ -550,16 +550,16 @@ add {
%Arg_0.1 = pred[4] parameter(0)
%Arg_1.2 = pred[4] parameter(1)
// CHECK-NEXT: xla_hlo.or %arg0, %arg1
// CHECK-NEXT: mhlo.or %arg0, %arg1
ROOT %or.3 = pred[4] or(pred[4] %Arg_0.1, pred[4] %Arg_1.2)
}
// CHECK-LABEL: func @test_outfeed
// CHECK-SAME: ([[DATA:%.*]]: tensor<3xi32>, [[TOKEN:%.*]]: !xla_hlo.token) -> !xla_hlo.token
// CHECK-SAME: ([[DATA:%.*]]: tensor<3xi32>, [[TOKEN:%.*]]: !mhlo.token) -> !mhlo.token
%test_outfeed (Arg_0.1: s32[3], Arg_1.2: token[]) -> token[] {
%Arg_0.1 = s32[3] parameter(0)
%Arg_1.2 = token[] parameter(1)
// CHECK-NEXT: "xla_hlo.outfeed"([[DATA]], [[TOKEN]])
// CHECK-NEXT: "mhlo.outfeed"([[DATA]], [[TOKEN]])
// CHECK-SAME: outfeed_config = "foobar"
ROOT %outfeed.3 = token[] outfeed(s32[3] %Arg_0.1, token[] %Arg_1.2), outfeed_config="foobar"
}
@ -569,7 +569,7 @@ add {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[] parameter(1)
// CHECK-NEXT: "xla_hlo.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK-NEXT: "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
ROOT %pad.3 = f32[4] pad(%Arg_0.1, %Arg_1.2), padding=0_0_0
}
@ -578,7 +578,7 @@ add {
%Arg_0.1 = f32[4, 4, 4] parameter(0)
%Arg_1.2 = f32[] parameter(1)
// CHECK-NEXT: "xla_hlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[2, 4, 6]> : tensor<3xi64>, edge_padding_low = dense<[1, 3, 5]> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<4x4x4xf32>, tensor<f32>) -> tensor<7x11x15xf32>
// CHECK-NEXT: "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[2, 4, 6]> : tensor<3xi64>, edge_padding_low = dense<[1, 3, 5]> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<4x4x4xf32>, tensor<f32>) -> tensor<7x11x15xf32>
ROOT %pad.3 = f32[7, 11, 15] pad(%Arg_0.1, %Arg_1.2), padding=1_2x3_4x5_6
}
@ -587,7 +587,7 @@ add {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[] parameter(1)
// CHECK-NEXT: "xla_hlo.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<2> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<10xf32>
// CHECK-NEXT: "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<2> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<10xf32>
ROOT %pad.3 = f32[10] pad(%Arg_0.1, %Arg_1.2), padding=0_0_2
}
@ -595,7 +595,7 @@ add {
%test_popcnt (arg0.1: s32[16]) -> s32[16] {
%arg0.1 = s32[16] parameter(0)
// CHECK: "xla_hlo.popcnt"(%arg0) {name = "{{.*}}"} : (tensor<16xi32>) -> tensor<16xi32>
// CHECK: "mhlo.popcnt"(%arg0) {name = "{{.*}}"} : (tensor<16xi32>) -> tensor<16xi32>
ROOT %popcnt.2 = s32[16] popcnt(s32[16] %arg0.1)
}
@ -604,7 +604,7 @@ add {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: xla_hlo.power %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
// CHECK-NEXT: mhlo.power %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
ROOT %power.3 = f32[4] power(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
@ -614,7 +614,7 @@ add {
%Arg_0.1 = f32[] parameter(0)
%Arg_1.2 = f32[] parameter(1)
// CHECK: [[CST:%.*]] = constant dense<[2, 3, 5]> : tensor<3xi64>
// CHECK: "xla_hlo.rng_normal"([[ARG0]], [[ARG1]], [[CST]])
// CHECK: "mhlo.rng_normal"([[ARG0]], [[ARG1]], [[CST]])
ROOT %rng.4 = f32[2,3,5] rng(f32[] %Arg_0.1, f32[] %Arg_1.2), distribution=rng_normal
}
@ -624,7 +624,7 @@ add {
%Arg_0.1 = f32[] parameter(0)
%Arg_1.2 = f32[] parameter(1)
// CHECK: [[CST:%.*]] = constant dense<[2, 3, 5]> : tensor<3xi64>
// CHECK: "xla_hlo.rng_uniform"([[ARG0]], [[ARG1]], [[CST]])
// CHECK: "mhlo.rng_uniform"([[ARG0]], [[ARG1]], [[CST]])
ROOT %rng.4 = f32[2,3,5] rng(f32[] %Arg_0.1, f32[] %Arg_1.2), distribution=rng_uniform
}
@ -632,7 +632,7 @@ add {
%test_real (Arg_0.1: c64[4]) -> f32[4] {
%Arg_0.1 = c64[4] parameter(0)
// CHECK-NEXT: "xla_hlo.real"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
// CHECK-NEXT: "mhlo.real"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
ROOT %real.3 = f32[4] real(c64[4] %Arg_0.1)
}
@ -666,28 +666,28 @@ add {
%Arg_1.2 = f32[4] parameter(1)
%Arg_2.3 = f32[] parameter(2)
// CHECK: "xla_hlo.reduce"([[ARG0]], [[ARG0]], [[ARG2]], [[ARG2]])
// CHECK: xla_hlo.add{{.*}} : tensor<f32>
// CHECK: xla_hlo.add{{.*}} : tensor<f32>
// CHECK: "mhlo.reduce"([[ARG0]], [[ARG0]], [[ARG2]], [[ARG2]])
// CHECK: mhlo.add{{.*}} : tensor<f32>
// CHECK: mhlo.add{{.*}} : tensor<f32>
// CHECK: {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
%reduce.1 = (f32[], f32[]) reduce(%Arg_0.1, %Arg_0.1, %Arg_2.3, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.1
// CHECK: [[VAL2:%.*]] = "xla_hlo.reduce"([[ARG0]], [[ARG2]])
// CHECK: xla_hlo.add{{.*}} : tensor<f32>
// CHECK: [[VAL2:%.*]] = "mhlo.reduce"([[ARG0]], [[ARG2]])
// CHECK: mhlo.add{{.*}} : tensor<f32>
// CHECK: {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor<f32>) -> tensor<f32>
%reduce.3 = f32[] reduce(%Arg_0.1, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.3
// CHECK: [[VAL3:%.*]] = "xla_hlo.reduce"([[ARG0]], [[ARG1]])
// CHECK: xla_hlo.add{{.*}} : tensor<4xf32>
// CHECK: [[VAL3:%.*]] = "mhlo.reduce"([[ARG0]], [[ARG1]])
// CHECK: mhlo.add{{.*}} : tensor<4xf32>
// CHECK: {dimensions = dense<0> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4xf32>
%reduce.2 = f32[4] reduce(%Arg_0.1, %Arg_1.2), dimensions={0}, to_apply=%reduce_helper.2
// CHECK: [[VAL4:%.*]] = "xla_hlo.reduce"([[VAL3]], [[ARG2]])
// CHECK: xla_hlo.add{{.*}} : tensor<f32>
// CHECK: [[VAL4:%.*]] = "mhlo.reduce"([[VAL3]], [[ARG2]])
// CHECK: mhlo.add{{.*}} : tensor<f32>
// CHECK: {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<f32>
%reduce.4 = f32[] reduce(%reduce.2, %Arg_2.3), dimensions={0}, to_apply=%reduce_helper.3
// CHECK: %4 = xla_hlo.subtract [[VAL2]], [[VAL4]] {name = "{{.*}}"} : tensor<f32>
// CHECK: %4 = mhlo.subtract [[VAL2]], [[VAL4]] {name = "{{.*}}"} : tensor<f32>
%sub.5 = f32[] subtract(%reduce.3, %reduce.4)
ROOT %tuple.6 = ((f32[], f32[]), f32[]) tuple(%reduce.1, %sub.5)
@ -699,8 +699,8 @@ add {
%Arg_0.1 = f32[2,17,31,7] parameter(0)
%Arg_1.2 = f32[] parameter(1)
// CHECK: "xla_hlo.reduce_window"([[ARG0]], [[ARG1]]) ( {
// CHECK: xla_hlo.add {{.*}} : tensor<f32>
// CHECK: "mhlo.reduce_window"([[ARG0]], [[ARG1]]) ( {
// CHECK: mhlo.add {{.*}} : tensor<f32>
// CHECK: }) {
// CHECK-SAME: base_dilations = dense<1> : tensor<4xi64>
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64>
@ -716,7 +716,7 @@ add {
%test_remainder (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK: xla_hlo.remainder [[VAL_0]], [[VAL_1]]
// CHECK: mhlo.remainder [[VAL_0]], [[VAL_1]]
ROOT %remainder.3 = f32[4] remainder(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
@ -724,7 +724,7 @@ add {
%test_reverse_1d (Arg_0.1: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
// CHECK-NEXT: "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
ROOT reverse.2 = f32[4] reverse(%Arg_0.1), dimensions={0}
}
@ -732,7 +732,7 @@ add {
%test_reverse_2d (Arg_0.1: f32[4, 4]) -> f32[4, 4] {
%Arg_0.1 = f32[4, 4] parameter(0)
// CHECK-NEXT: "xla_hlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4xf32>
// CHECK-NEXT: "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4xf32>
ROOT reverse.2 = f32[4, 4] reverse(%Arg_0.1), dimensions={0, 1}
}
@ -741,7 +741,7 @@ add {
%test_rsqrt (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
// CHECK: "xla_hlo.rsqrt"([[ARG0]]) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
// CHECK: "mhlo.rsqrt"([[ARG0]]) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
ROOT %rsqrt.2 = f32[16] rsqrt(f32[16] %arg0.1)
}
@ -767,10 +767,10 @@ add {
// CHECK-LABEL: func @test_scatter
// CHECK-SAME: [[ARG_0:%.*]]: tensor<200x100x300xf32>, [[ARG_1:%.*]]: tensor<10x2xi64>, [[ARG_2:%.*]]: tensor<10x300xf32>) -> tensor<200x100x300xf32>
// CHECK: "xla_hlo.scatter"([[ARG_0]], [[ARG_1]], [[ARG_2]]) ( {
// CHECK: "mhlo.scatter"([[ARG_0]], [[ARG_1]], [[ARG_2]]) ( {
// CHECK: ^bb0([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = xla_hlo.add [[LHS]], [[RHS]]
// CHECK: "xla_hlo.return"([[ADD]]) : (tensor<f32>) -> ()
// CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]]
// CHECK: "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: indices_are_sorted = false
// CHECK-SAME: scatter_dimension_numbers = {
@ -788,7 +788,7 @@ add {
%Arg_1.2 = s32[2,3] parameter(1)
%Arg_2.3 = s32[2,3] parameter(2)
// CHECK: "xla_hlo.select"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK: "mhlo.select"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
ROOT %select.4 = s32[2,3] select(pred[2,3] %Arg_0.1, s32[2,3] %Arg_1.2, s32[2,3] %Arg_2.3)
}
@ -814,14 +814,14 @@ add {
ROOT %select-and-scatter = f32[4,5] select-and-scatter(f32[4,5] %input, f32[2,2] %source, f32[] %init_value), window={size=2x3 stride=2x3 pad=0_0x0_1}, select=%ge_select, scatter=%add_gather
}
// CHECK: [[RESULT:%.*]] = "xla_hlo.select_and_scatter"([[INPUT]], [[SOURCE]], [[INIT_VAL]]) ( {
// CHECK: [[RESULT:%.*]] = "mhlo.select_and_scatter"([[INPUT]], [[SOURCE]], [[INIT_VAL]]) ( {
// CHECK: ^bb0([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
// CHECK: [[CMP:%.*]] = "xla_hlo.compare"([[LHS]], [[RHS]])
// CHECK: "xla_hlo.return"([[CMP]]) : (tensor<i1>) -> ()
// CHECK: [[CMP:%.*]] = "mhlo.compare"([[LHS]], [[RHS]])
// CHECK: "mhlo.return"([[CMP]]) : (tensor<i1>) -> ()
// CHECK: }, {
// CHECK: ^bb0([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = xla_hlo.add [[LHS]], [[RHS]]
// CHECK: "xla_hlo.return"([[ADD]]) : (tensor<f32>) -> ()
// CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]]
// CHECK: "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
// CHECK: }) {
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 1]]> : tensor<2x2xi64>
// CHECK-SAME: window_dimensions = dense<[2, 3]> : tensor<2xi64>
@ -835,7 +835,7 @@ add {
%test_set_dimension_size (Arg_0.1: f32[4,4], Arg_1.2: s32[]) -> f32[4,<=4] {
%Arg_0.1 = f32[4,4] parameter(0)
%Arg_1.2 = s32[] parameter(1)
// CHECK-NEXT: "xla_hlo.set_dimension_size"([[ARG]], [[SIZE]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x4xf32>, tensor<i32>) -> tensor<4x4xf32>
// CHECK-NEXT: "mhlo.set_dimension_size"([[ARG]], [[SIZE]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x4xf32>, tensor<i32>) -> tensor<4x4xf32>
ROOT %set-dimension-size.2 = f32[4,<=4] set-dimension-size(f32[4,4] %Arg_0.1, s32[] %Arg_1.2), dimensions={1}
}
@ -843,7 +843,7 @@ add {
%test_sine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
%arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
// CHECK-NEXT: "xla_hlo.sine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
// CHECK-NEXT: "mhlo.sine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
ROOT %sine.3 = f32[1,16,16,3]{3,2,1,0} sine(f32[1,16,16,3]{3,2,1,0} %arg0.1)
}
@ -860,10 +860,10 @@ add {
}
// CHECK-LABEL: func @test_sort
// CHECK-SAME: [[ARG:%.*]]: tensor<1024xf32>) -> tensor<1024xf32>
// CHECK: "xla_hlo.sort"([[ARG]]) ( {
// CHECK: "mhlo.sort"([[ARG]]) ( {
// CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
// CHECK: [[CMP:%.*]] = "xla_hlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "LT", name = "lt"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: "xla_hlo.return"([[CMP]]) : (tensor<i1>) -> ()
// CHECK: [[CMP:%.*]] = "mhlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "LT", name = "lt"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: "mhlo.return"([[CMP]]) : (tensor<i1>) -> ()
// CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<1024xf32>) -> tensor<1024xf32>
// CHECK-LABEL: func @test_subtract
@ -871,7 +871,7 @@ add {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: xla_hlo.subtract %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
// CHECK-NEXT: mhlo.subtract %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
ROOT %subtract.3 = f32[4] subtract(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
@ -879,7 +879,7 @@ add {
%test_tanh (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
%arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
// CHECK-NEXT: "xla_hlo.tanh"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
// CHECK-NEXT: "mhlo.tanh"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
ROOT %tanh.3 = f32[1,16,16,3]{3,2,1,0} tanh(f32[1,16,16,3]{3,2,1,0} %arg0.1), metadata={op_type="Tanh" op_name="embedded_inference/tanh_model/Tanh"}
}
@ -887,7 +887,7 @@ add {
%test_transpose {
%Arg_0.1 = s32[1,2,3,4] parameter(0)
// CHECK: "xla_hlo.transpose"(%arg0) {name = "{{.*}}", permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
// CHECK: "mhlo.transpose"(%arg0) {name = "{{.*}}", permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2}
}
@ -896,7 +896,7 @@ add {
%test_triangular_solve (Arg_0.1: f32[4,4], Arg_1.2: f32[4,3]) -> f32[4,3] {
%Arg_0.1 = f32[4,4] parameter(0)
%Arg_1.2 = f32[4,3] parameter(1)
// CHECK-NEXT: "xla_hlo.triangular_solve"([[ARG_A]], [[ARG_B]])
// CHECK-NEXT: "mhlo.triangular_solve"([[ARG_A]], [[ARG_B]])
// CHECK-SAME: left_side = true
// CHECK-SAME: lower = true
// CHECK-SAME: transpose_a = "NO_TRANSPOSE"
@ -909,10 +909,10 @@ add {
%Arg_0.1 = s32[1] parameter(0)
%Arg_1.2 = f32[1, 2] parameter(1)
// CHECK-NEXT: %0 = "xla_hlo.tuple"(%arg0) {name = "{{.*}}"} : (tensor<1xi32>) -> tuple<tensor<1xi32>>
// CHECK-NEXT: %0 = "mhlo.tuple"(%arg0) {name = "{{.*}}"} : (tensor<1xi32>) -> tuple<tensor<1xi32>>
%tuple.3 = (s32[1]) tuple(%Arg_0.1)
// CHECK: "xla_hlo.tuple"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<1xi32>, tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>>
// CHECK: "mhlo.tuple"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<1xi32>, tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>>
ROOT %tuple.4 = (s32[1], f32[1,2]) tuple(%Arg_0.1, %Arg_1.2)
}
@ -932,14 +932,14 @@ add {
// CHECK-LABEL: func @test_while(%arg0: tensor<i64>) -> tensor<i64>
%test_while (arg0.1: s64[]) -> s64[] {
%arg0.1 = s64[] parameter(0), metadata={op_name="HLO_Args"}
// CHECK-NEXT: "xla_hlo.while"(%arg0) ( {
// CHECK-NEXT: "mhlo.while"(%arg0) ( {
// CHECK-NEXT: ^bb0(%arg1: tensor<i64>): // no predecessors
// CHECK-NEXT: [[CMP:%.*]] = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "{{.*}}"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK-NEXT: "xla_hlo.return"([[CMP]]) : (tensor<i1>) -> ()
// CHECK-NEXT: [[CMP:%.*]] = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "{{.*}}"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK-NEXT: "mhlo.return"([[CMP]]) : (tensor<i1>) -> ()
// CHECK-NEXT: }, {
// CHECK-NEXT: ^bb0(%arg1: tensor<i64>): // no predecessors
// CHECK-NEXT: [[ADD:%.*]] = xla_hlo.add %arg1, %arg1 {name = "{{.*}}"} : tensor<i64>
// CHECK-NEXT: "xla_hlo.return"([[ADD]]) : (tensor<i64>) -> ()
// CHECK-NEXT: [[ADD:%.*]] = mhlo.add %arg1, %arg1 {name = "{{.*}}"} : tensor<i64>
// CHECK-NEXT: "mhlo.return"([[ADD]]) : (tensor<i64>) -> ()
// CHECK-NEXT: }) : (tensor<i64>) -> tensor<i64>
ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond
}
@ -950,7 +950,7 @@ add {
%Arg_0.1 = pred[4] parameter(0)
%Arg_1.2 = pred[4] parameter(1)
// CHECK: xla_hlo.xor [[VAL_0]], [[VAL_1]]
// CHECK: mhlo.xor [[VAL_0]], [[VAL_1]]
ROOT %xor.3 = pred[4] xor(pred[4] %Arg_0.1, pred[4] %Arg_1.2)
}
@ -960,7 +960,7 @@ add {
%Arg_0.1 = s32[4] parameter(0)
%Arg_1.2 = s32[4] parameter(1)
// CHECK: xla_hlo.shift_left [[VAL_0]], [[VAL_1]]
// CHECK: mhlo.shift_left [[VAL_0]], [[VAL_1]]
ROOT %shiftleft = s32[4] shift-left(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
}
@ -970,7 +970,7 @@ add {
%Arg_0.1 = s32[4] parameter(0)
%Arg_1.2 = s32[4] parameter(1)
// CHECK: xla_hlo.shift_right_arithmetic [[VAL_0]], [[VAL_1]]
// CHECK: mhlo.shift_right_arithmetic [[VAL_0]], [[VAL_1]]
ROOT %shiftright.arithmetic = s32[4] shift-right-arithmetic(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
}
@ -980,7 +980,7 @@ add {
%Arg_0.1 = s32[4] parameter(0)
%Arg_1.2 = s32[4] parameter(1)
// CHECK: xla_hlo.shift_right_logical [[VAL_0]], [[VAL_1]]
// CHECK: mhlo.shift_right_logical [[VAL_0]], [[VAL_1]]
ROOT %shiftright.logical = s32[4] shift-right-logical(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
}
@ -992,8 +992,8 @@ add {
%Arg_1.2 = c128[2] parameter(1)
%abs.4 = f64[2] abs(c128[2] %Arg_1.2)
// CHECK: "xla_hlo.abs"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// CHECK: "xla_hlo.abs"(%[[ARG1]]) {name = "{{.*}}"} : (tensor<2xcomplex<f64>>) -> tensor<2xf64>
// CHECK: "mhlo.abs"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// CHECK: "mhlo.abs"(%[[ARG1]]) {name = "{{.*}}"} : (tensor<2xcomplex<f64>>) -> tensor<2xf64>
ROOT %tuple.5 = (f32[2], f64[2]) tuple(f32[2] %abs.3, f64[2] %abs.4)
}
@ -1002,6 +1002,6 @@ add {
%unsigned_int(Arg_0.1: u16[4]) -> u16[4] {
%Arg_0.1 = u16[4] parameter(0)
// CHECK: "xla_hlo.not"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<4xui16>) -> tensor<4xui16>
// CHECK: "mhlo.not"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<4xui16>) -> tensor<4xui16>
ROOT %not.2 = u16[4] not(u16[4] %Arg_0.1)
}

View File

@ -6,7 +6,7 @@
// TUPLE-ARG-LABEL: ENTRY %main
// TUPLE-ARG: // OutputIndex {0} aliases with input 0 at {0}
func @main(%arg0: tensor<1xf32> {tf.aliasing_output = 0 : i64}) -> (tensor<1xf32>) {
%0 = xla_hlo.constant dense<4.200000e+01> : tensor<1xf32>
%1 = xla_hlo.add %arg0, %0 : tensor<1xf32>
%0 = mhlo.constant dense<4.200000e+01> : tensor<1xf32>
%1 = mhlo.add %arg0, %0 : tensor<1xf32>
return %1 : tensor<1xf32>
}

View File

@ -9,6 +9,6 @@
func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<1x2x3x4xi32>) {
// CHECK-NEXT: %Arg_0.1 = s32[4] parameter(0)
// CHECK-NEXT: %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] %Arg_0.1), dimensions={3}
%0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
%0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
return %arg0, %0 : tensor<4xi32>, tensor<1x2x3x4xi32>
}

View File

@ -139,8 +139,8 @@ dynamic_parameter_binding {
}
# CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<f32> {
# CHECK-NEXT: %0 = xla_hlo.add %arg0, %arg1 {name = "add.3"} : tensor<4xf32>
# CHECK-NEXT: %0 = mhlo.add %arg0, %arg1 {name = "add.3"} : tensor<4xf32>
# TODO(b/129709049) consider making this default precision config inferred.
# CHECK-NEXT: %1 = "xla_hlo.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
# CHECK-NEXT: %1 = "mhlo.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
# CHECK-NEXT: return %1 : tensor<f32>
# CHECK-NEXT: }

View File

@ -2,8 +2,8 @@
func @main(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%1 = "xla_hlo.dot"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%1 = "mhlo.dot"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %1 : tensor<4xf32>
}

View File

@ -16,14 +16,14 @@ HloModule foo
ENTRY %foo (arg0.1: s64[]) -> s64[] {
%arg0.1 = s64[] parameter(0), metadata={op_name="HLO_Args"}
// CHECK: "xla_hlo.while"(%arg0) ( {
// CHECK: "mhlo.while"(%arg0) ( {
// CHECK: ^bb0
// CHECK: "xla_hlo.compare"
// CHECK: "xla_hlo.return"
// CHECK: "mhlo.compare"
// CHECK: "mhlo.return"
// CHECK: }, {
// CHECK: ^bb0
// CHECK: xla_hlo.add
// CHECK: "xla_hlo.return"
// CHECK: mhlo.add
// CHECK: "mhlo.return"
// CHECK: }) : (tensor<i64>) -> tensor<i64>
ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond
}

View File

@ -2,7 +2,7 @@
module {
func @main(%arg0: tensor<i64>) -> tensor<i64> {
%0 = "xla_hlo.while"(%arg0) ( {
%0 = "mhlo.while"(%arg0) ( {
// CHECK: [[R0:%.+]] ([[A0:.+]]: s64[]) -> s64[] {
// CHECK: %[[A0]] = s64[] parameter(0)
// CHECK: ROOT %add.4 = s64[] add(s64[] %[[A0]], s64[] %[[A0]])
@ -10,12 +10,12 @@ module {
// CHECK: %[[A0]] = s64[] parameter(0)
// CHECK: ROOT %compare.7 = pred[] compare(s64[] %[[A0]], s64[] %[[A0]]), direction=LT
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
%1 = xla_hlo.add %arg1, %arg1 {name = "compare.0"} : tensor<i64>
"xla_hlo.return"(%1) : (tensor<i64>) -> ()
%1 = mhlo.add %arg1, %arg1 {name = "compare.0"} : tensor<i64>
"mhlo.return"(%1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
// CHECK: ENTRY %main.9 ([[A0:.+]]: s64[]) -> s64[] {

View File

@ -62,10 +62,10 @@ limitations under the License.
#include "tensorflow/core/util/tensor_format.h"
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
constexpr char kShardingAttr[] = "xla_hlo.sharding";
constexpr char kShardingAttr[] = "mhlo.sharding";
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
public:
@ -289,7 +289,7 @@ static ConstOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value,
return builder->create<ConstOp>(loc, xla::GetScalarOfType(ty, raw_value));
}
// Creates an xla_hlo::SliceOp where the major dimensions have full size, and
// Creates an mhlo::SliceOp where the major dimensions have full size, and
// the minor dimensions have the provided offsets and sizes.
static Value SliceInMinorDims(Location loc, Value v,
ArrayRef<int64_t> minor_starts,
@ -326,7 +326,7 @@ static llvm::SmallVector<Value, 4> CreateFullIndexVectorFromMinorIndices(
return indices;
}
// Creates an xla_hlo::DynamicSliceOp where the major dimensions have full size,
// Creates an mhlo::DynamicSliceOp where the major dimensions have full size,
// and the minor dimensions have the provided offsets and sizes.
static Value DynamicSliceInMinorDims(Location loc, Value v,
ArrayRef<Value> minor_starts,
@ -341,12 +341,12 @@ static Value DynamicSliceInMinorDims(Location loc, Value v,
std::copy(minor_sizes.begin(), minor_sizes.end(),
slice_sizes.begin() + major_dims);
auto slice_type = RankedTensorType::get(slice_sizes, type.getElementType());
return builder->create<xla_hlo::DynamicSliceOp>(
return builder->create<mhlo::DynamicSliceOp>(
loc, slice_type, v, slice_starts,
GetI64ElementsAttr(slice_sizes, builder));
}
// Creates an xla_hlo::DynamicUpdateSliceOp where the major dimensions have zero
// Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero
// offsets, and the minor dimensions have the provided offsets.
static Value DynamicUpdateSliceInMinorDims(Location loc, Value v, Value update,
ArrayRef<Value> minor_starts,
@ -359,7 +359,7 @@ static Value DynamicUpdateSliceInMinorDims(Location loc, Value v, Value update,
llvm::makeArrayRef(dus_starts));
}
// Creates an xla_hlo::DynamicUpdateSliceOp where the major dimensions have zero
// Creates an mhlo::DynamicUpdateSliceOp where the major dimensions have zero
// offsets, and the minor dimensions have the provided static offsets.
static Value UpdateSliceInMinorDims(Location loc, Value v, Value update,
ArrayRef<int64_t> minor_starts,
@ -540,7 +540,7 @@ static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to,
loc, to_type, input, result_extents, broadcast_dims);
}
// Creates a batch dot using xla_hlo::DotGeneralOp.
// Creates a batch dot using mhlo::DotGeneralOp.
Value BatchDot(Location loc, Value lhs, bool transpose_lhs, Value rhs,
bool transpose_rhs, int64_t num_batch_dims,
ArrayAttr precision_config, OpBuilder *builder) {
@ -605,23 +605,22 @@ static Value ApplyReduction(Location loc, Value input,
builder->getBoolAttr(false));
}
// Creates a xla_hlo.rng_uniform op with `builder` to generate `num_elements`
// Creates a mhlo.rng_uniform op with `builder` to generate `num_elements`
// 32-bit integer numbers in the range of [`lower_limit`, `upper_limit`).
static xla_hlo::RngUniformOp CreateRngUniform32(Location loc, int num_elements,
int lower_limit,
int upper_limit,
static mhlo::RngUniformOp CreateRngUniform32(Location loc, int num_elements,
int lower_limit, int upper_limit,
OpBuilder *builder) {
auto i32_type = builder->getIntegerType(32);
auto key_type = RankedTensorType::get({num_elements}, i32_type);
auto shape_tensor = builder->create<xla_hlo::ConstOp>(
auto shape_tensor = builder->create<mhlo::ConstOp>(
loc, GetI64ElementsAttr({num_elements}, builder));
auto lower = builder->create<xla_hlo::ConstOp>(
auto lower = builder->create<mhlo::ConstOp>(
loc, builder->getI32IntegerAttr(lower_limit));
auto upper = builder->create<xla_hlo::ConstOp>(
auto upper = builder->create<mhlo::ConstOp>(
loc, builder->getI32IntegerAttr(upper_limit));
return builder->create<xla_hlo::RngUniformOp>(loc, key_type, lower, upper,
return builder->create<mhlo::RngUniformOp>(loc, key_type, lower, upper,
shape_tensor);
}
@ -629,7 +628,7 @@ using WhileBodyFnType = llvm::function_ref<void(
Location loc, Value iteration, ArrayRef<Value> old_values,
SmallVectorImpl<Value> *new_values, OpBuilder *builder)>;
// Creates a xla_hlo.while op with `builder` to loop `num_interations` times,
// Creates a mhlo.while op with `builder` to loop `num_interations` times,
// each time calling the given `body_fn` on a set of values to generate a new
// set of values. Returns the final set of values via `final_values`. The
// initial set of values is passed in via `init_values`.
@ -659,16 +658,16 @@ static void CreateWhile32(Location loc, int num_iterations,
init_values_with_loop_iv.reserve(value_count);
// The initial value for the loop induction variable is 0.
init_values_with_loop_iv.push_back(
builder->create<xla_hlo::ConstOp>(loc, builder->getI32IntegerAttr(0)));
builder->create<mhlo::ConstOp>(loc, builder->getI32IntegerAttr(0)));
init_values_with_loop_iv.append(init_values.begin(), init_values.end());
// Prepare the initial tuple for the while op.
auto init_tuple =
builder->create<xla_hlo::TupleOp>(loc, init_values_with_loop_iv);
builder->create<mhlo::TupleOp>(loc, init_values_with_loop_iv);
auto tuple_type = init_tuple.getType();
// Create the while op.
auto while_op = builder->create<xla_hlo::WhileOp>(loc, init_tuple);
auto while_op = builder->create<mhlo::WhileOp>(loc, init_tuple);
{
OpBuilder::InsertionGuard guard(*builder);
@ -681,13 +680,13 @@ static void CreateWhile32(Location loc, int num_iterations,
// Get the loop induction variable and compare it against the upper limit.
auto loop_iv = builder->create<GetTupleElementOp>(loc, arg, 0);
auto upper_limit = builder->create<xla_hlo::ConstOp>(
auto upper_limit = builder->create<mhlo::ConstOp>(
loc, builder->getI32IntegerAttr(num_iterations));
StringAttr compare_direction = StringAttr::get("LT", builder->getContext());
Value compare = builder->create<xla_hlo::CompareOp>(
loc, loop_iv, upper_limit, compare_direction);
Value compare = builder->create<mhlo::CompareOp>(loc, loop_iv, upper_limit,
compare_direction);
builder->create<xla_hlo::ReturnOp>(loc, compare);
builder->create<mhlo::ReturnOp>(loc, compare);
}
{
@ -714,16 +713,16 @@ static void CreateWhile32(Location loc, int num_iterations,
// Increment the loop induction variable by one.
auto one =
builder->create<xla_hlo::ConstOp>(loc, builder->getI32IntegerAttr(1));
builder->create<mhlo::ConstOp>(loc, builder->getI32IntegerAttr(1));
auto scalar_broadcast_dims = GetI64ElementsAttr({}, builder);
auto plus_one = builder->create<xla_chlo::BroadcastAddOp>(
loc, old_values[0], one, scalar_broadcast_dims);
// Prepend with the updated loop induction variable.
new_values.insert(new_values.begin(), plus_one);
Value updated_tuple = builder->create<xla_hlo::TupleOp>(loc, new_values);
Value updated_tuple = builder->create<mhlo::TupleOp>(loc, new_values);
builder->create<xla_hlo::ReturnOp>(loc, updated_tuple);
builder->create<mhlo::ReturnOp>(loc, updated_tuple);
}
final_values->reserve(init_values.size());
@ -786,7 +785,7 @@ static Value CreateConvertOp(OpBuilder *builder, Location loc, Value input,
Value elem_type_tensor) {
auto element_type =
elem_type_tensor.getType().cast<TensorType>().getElementType();
return builder->create<xla_hlo::ConvertOp>(loc, input, element_type);
return builder->create<mhlo::ConvertOp>(loc, input, element_type);
}
//===----------------------------------------------------------------------===//
@ -1023,9 +1022,9 @@ static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes(
// Sort op utilities.
//===----------------------------------------------------------------------===//
// Builds the region `body` for xla_hlo.sort's comparator: for each type in
// Builds the region `body` for mhlo.sort's comparator: for each type in
// `element_types`, create two block arguments, one for lhs and one for rhs, and
// generates xla_hlo.compare op to compare them with the given `direction`.
// generates mhlo.compare op to compare them with the given `direction`.
//
// Note that this right now only does comparision on the first pair of block
// arguments.
@ -1044,10 +1043,10 @@ static void BuildSortComparisonBody(llvm::ArrayRef<Type> element_types,
Location loc = body->getLoc();
StringAttr compare_direction =
StringAttr::get(direction, builder->getContext());
Value compare = builder->create<xla_hlo::CompareOp>(
Value compare = builder->create<mhlo::CompareOp>(
loc, block->getArgument(0), block->getArgument(1), compare_direction);
builder->create<xla_hlo::ReturnOp>(loc, compare);
builder->create<mhlo::ReturnOp>(loc, compare);
}
//===----------------------------------------------------------------------===//
@ -1110,7 +1109,7 @@ class ConvertBiasAddOp : public OpRewritePattern<TF::BiasAddOp> {
//
// Sample result for Conv2D:
//
// %conv = "xla_hlo.convolution"(%input, %filter) {
// %conv = "mhlo.convolution"(%input, %filter) {
// strides = [1, 2],
// paddings = [[1, 0], [1, 1]],
// ...
@ -1235,7 +1234,7 @@ class ConvertConvOp : public OpRewritePattern<OpTy> {
new_shape.push_back(1);
new_shape.push_back(filter_shape[num_spatial_dims] *
filter_shape[num_spatial_dims + 1]);
operands[1] = rewriter.create<xla_hlo::ReshapeOp>(
operands[1] = rewriter.create<mhlo::ReshapeOp>(
op.getLoc(),
RankedTensorType::get(new_shape, filter_ty.getElementType()),
operands[1]);
@ -1319,16 +1318,16 @@ class ConvertBroadcastToOp : public OpRewritePattern<TF::BroadcastToOp> {
// Converts TensorFlow DiagPartOp to HLO ops using reduction on masked matrix.
// For a Rank-2 input, it creates the following ops:
// %1 = "xla_hlo.iota"() {iota_dimension = 0 : i64}
// %2 = "xla_hlo.iota"() {iota_dimension = 1 : i64}
// %3 = "xla_hlo.compare"(%1, %2) {comparison_direction = "EQ"}
// %4 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
// %5 = "xla_hlo.broadcast"(%4)
// %6 = "xla_hlo.select"(%3, %input, %5)
// %7 = "xla_hlo.reduce"(%6, %4) ( {
// %1 = "mhlo.iota"() {iota_dimension = 0 : i64}
// %2 = "mhlo.iota"() {iota_dimension = 1 : i64}
// %3 = "mhlo.compare"(%1, %2) {comparison_direction = "EQ"}
// %4 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// %5 = "mhlo.broadcast"(%4)
// %6 = "mhlo.select"(%3, %input, %5)
// %7 = "mhlo.reduce"(%6, %4) ( {
// ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
// %9 = xla_hlo.add %arg1, %arg2 : tensor<f32>
// "xla_hlo.return"(%9) : (tensor<f32>) -> ()
// %9 = mhlo.add %arg1, %arg2 : tensor<f32>
// "mhlo.return"(%9) : (tensor<f32>) -> ()
// }) {dimensions = dense<0> : tensor<1xi64>}
//
// If the input's rank N is greater than 2, we will reshape it to R2 first and
@ -1353,7 +1352,7 @@ class ConvertDiagPartOp : public OpRewritePattern<TF::DiagPartOp> {
new_size *= input_type.getDimSize(i);
new_dims.push_back(input_type.getDimSize(i));
}
Value reshaped_input = rewriter.create<xla_hlo::ReshapeOp>(
Value reshaped_input = rewriter.create<mhlo::ReshapeOp>(
op.getLoc(),
RankedTensorType::get({new_size, new_size},
input_type.getElementType()),
@ -1490,23 +1489,23 @@ class ConvertFusedBatchNormGradBase
Value scratch1 = rewriter.create<RsqrtOp>(loc, add_op);
// scratch2 = sum(y_backprop * (x - mean))
auto sub_op = rewriter.create<xla_hlo::SubOp>(
auto sub_op = rewriter.create<mhlo::SubOp>(
loc, act,
Broadcast1DToFeatureDim(loc, act, mean, feature_dim, rewriter));
auto weighted_grad = rewriter.create<xla_hlo::MulOp>(loc, grad, sub_op);
auto weighted_grad = rewriter.create<mhlo::MulOp>(loc, grad, sub_op);
Value scratch2 =
ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter);
// x_backprop = y_backprop * (scale * scratch1)
auto scaled_grad =
rewriter.create<xla_hlo::MulOp>(loc, op.scale(), scratch1);
x_backprop = rewriter.create<xla_hlo::MulOp>(
rewriter.create<mhlo::MulOp>(loc, op.scale(), scratch1);
x_backprop = rewriter.create<mhlo::MulOp>(
loc, grad,
Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim,
rewriter));
// scale_backprop = scratch2 * scratch1
scale_backprop = rewriter.create<xla_hlo::MulOp>(loc, scratch1, scratch2);
scale_backprop = rewriter.create<mhlo::MulOp>(loc, scratch1, scratch2);
// offset_backprop = sum(y_backprop)
offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter);
@ -1559,8 +1558,8 @@ class ConvertFusedBatchNormV3Op
// TODO(b/69928690): Support mixed precision in the XLA batch
// normalization operators. As a workaround, create a new x with the same
// element type as scale (which may be more precise than the input type).
Value bn_train_input = rewriter.create<xla_hlo::ConvertOp>(
op.getLoc(), op.x(), scale_element_type);
Value bn_train_input = rewriter.create<mhlo::ConvertOp>(op.getLoc(), op.x(),
scale_element_type);
TensorType bn_train_input_type_tensor =
bn_train_input.getType().cast<TensorType>();
@ -1579,17 +1578,17 @@ class ConvertFusedBatchNormV3Op
mean_var_type, mean_var_type};
Type result_type = TupleType::get(operand_types, rewriter.getContext());
auto bn_train_op = rewriter.create<xla_hlo::BatchNormTrainingOp>(
auto bn_train_op = rewriter.create<mhlo::BatchNormTrainingOp>(
op.getLoc(), result_type, bn_train_input, op.scale(), op.offset(),
op.epsilon(), feature_dim.getValue());
// HLO op outputs a tuple of tensors. Extract those results.
auto bn_train_op_result = bn_train_op.getResult();
Value y_out = rewriter.create<xla_hlo::GetTupleElementOp>(
Value y_out = rewriter.create<mhlo::GetTupleElementOp>(
op.getLoc(), bn_train_op_result, 0);
Value batch_mean = rewriter.create<xla_hlo::GetTupleElementOp>(
Value batch_mean = rewriter.create<mhlo::GetTupleElementOp>(
op.getLoc(), bn_train_op_result, 1);
Value reserve_space_1 = batch_mean;
Value batch_variance = rewriter.create<xla_hlo::GetTupleElementOp>(
Value batch_variance = rewriter.create<mhlo::GetTupleElementOp>(
op.getLoc(), bn_train_op_result, 2);
// Apply Bessel's correction on the variance.
@ -1599,7 +1598,7 @@ class ConvertFusedBatchNormV3Op
int sample_size_minus_one = std::max(1, sample_size - 1);
double factor = static_cast<double>(sample_size) /
static_cast<double>(sample_size_minus_one);
auto factor_const_op = rewriter.create<xla_hlo::ConstOp>(
auto factor_const_op = rewriter.create<mhlo::ConstOp>(
op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor));
Value corrected_variance = rewriter.create<xla_chlo::BroadcastMulOp>(
@ -1608,16 +1607,16 @@ class ConvertFusedBatchNormV3Op
// Convert back to input type to stay aligned with expected output type
// for TF op.
y_out = rewriter.create<xla_hlo::ConvertOp>(op.getLoc(), y_out,
y_out = rewriter.create<mhlo::ConvertOp>(op.getLoc(), y_out,
input_element_type);
float exponential_avg_factor =
op.exponential_avg_factor().convertToFloat();
if (exponential_avg_factor != 1.0f) {
auto alpha = rewriter.create<xla_hlo::ConstOp>(
auto alpha = rewriter.create<mhlo::ConstOp>(
op.getLoc(), rewriter.getFloatAttr(mean_element_type,
1.0f - exponential_avg_factor));
auto beta = rewriter.create<xla_hlo::ConstOp>(
auto beta = rewriter.create<mhlo::ConstOp>(
op.getLoc(),
rewriter.getFloatAttr(mean_element_type, exponential_avg_factor));
@ -1666,7 +1665,7 @@ class ConvertFusedBatchNormV3Op
// Convert back to input type to stay aligned with expected output type
// for TF op.
auto y_out = rewriter.create<xla_hlo::ConvertOp>(op.getLoc(), bn_train_op,
auto y_out = rewriter.create<mhlo::ConvertOp>(op.getLoc(), bn_train_op,
input_element_type);
// The mean, variance, and reserved space outputs of the batch norm op are
@ -1947,8 +1946,8 @@ class ConvertAvgPoolGradOp : public OpRewritePattern<OpTy> {
BuildReduceBody<AddOp>(element_type, &window_counts.body(), &rewriter);
// Divide `out_grad` by window counts.
out_grad_divided = rewriter.create<xla_hlo::DivOp>(
loc, out_grad_type, out_grad, window_counts);
out_grad_divided = rewriter.create<mhlo::DivOp>(loc, out_grad_type,
out_grad, window_counts);
}
// Get same padding as for original input.
@ -2053,7 +2052,7 @@ using ConvertAvgPool3DGradOp =
// Sample result for VALID padding mode:
//
// %init = constant dense<...> : tensor<i32>
// %max_pool = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.maximum"]
// %max_pool = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"]
// {window_dimensions = ..., window_strides = ... }
//
template <typename OpTy, int num_dims>
@ -2098,13 +2097,13 @@ using ConvertMaxPool3DOp = ConvertMaxPoolOp<TF::MaxPool3DOp, /*num_dims=*/5>;
//
// will be converted into:
//
// %pred = "xla_hlo.broadcast_in_dim"(%cond)
// %pred = "mhlo.broadcast_in_dim"(%cond)
// {broadcast_dimensions = dense<[0]> : tensor<1xi64>} :
// (tensor<1xi1>) -> tensor<2xi1>
// %on_false = "xla_hlo.broadcast_in_dim"(%e)
// %on_false = "mhlo.broadcast_in_dim"(%e)
// {broadcast_dimensions = dense<[0]> : tensor<1xi64>} :
// (tensor<1xi32>) -> tensor<2xi32>
// %select = "xla_hlo.select"(%pred, %t, %on_false) :
// %select = "mhlo.select"(%pred, %t, %on_false) :
// (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
class ConvertSelectV2Op : public OpRewritePattern<TF::SelectV2Op> {
public:
@ -2173,18 +2172,18 @@ class ConvertSelectV2Op : public OpRewritePattern<TF::SelectV2Op> {
// Sample result with 2-d f16 inputs with B batches of with N elements each.
//
// // Create an array of 0.5 the shape of the input array.
// %half = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
// %half_array = "xla_hlo.broadcast"(half)
// %half = mhlo.constant dense<5.000000e-01> : tensor<f32>
// %half_array = "mhlo.broadcast"(half)
// {broadcast_sizes = dense<2> : tensor<1xi64>}
// : (tensor<f32>) -> tensor<2xf32>
//
// // Compute Tanh of half the logits of the values.
// %halved_logits = xla_hlo.multiply %logits, %half_array : tensor<2xf32>
// %tanh = "xla_hlo.tanh"(%halved_logits) : (tensor<2xf32>) -> tensor<2xf32>
// %halved_logits = mhlo.multiply %logits, %half_array : tensor<2xf32>
// %tanh = "mhlo.tanh"(%halved_logits) : (tensor<2xf32>) -> tensor<2xf32>
//
// // Have the result of Tanh and add 0.5.
// %halved_tanh = xla_hlo.multiply %tanh, %half : tensor<2xf32>
// %sigmoid = xla_hlo.add %halved_tanh, %half : tensor<2xf32>
// %halved_tanh = mhlo.multiply %tanh, %half : tensor<2xf32>
// %sigmoid = mhlo.add %halved_tanh, %half : tensor<2xf32>
//
class ConvertSigmoidOp : public OpRewritePattern<TF::SigmoidOp> {
public:
@ -2227,15 +2226,15 @@ class ConvertSigmoidOp : public OpRewritePattern<TF::SigmoidOp> {
// // stability.
// %max = "tf.Max"(%input, %reduce_dim)
// : (tensor<BxNxf16>, tensor<1xi64>) -> tensor<Bxf16>
// %sub = "xla_hlo.subtract"(%inp, %max) {broadcast_dimensions = 0}
// %sub = "mhlo.subtract"(%inp, %max) {broadcast_dimensions = 0}
// : (tensor<BxNxf16>, tensor<Bxf16>) -> tensor<BxNxf16>
//
// %exp = "xla_hlo.exponential"(%sub) : (tensor<BxNxf16>) -> tensor<BxNxf16>
// %exp = "mhlo.exponential"(%sub) : (tensor<BxNxf16>) -> tensor<BxNxf16>
// %sum = "tf.Sum"(%exp, %reduce_dim)
// : (tensor<BxNxf32>, tensor<1xi64>) -> tensor<Bxf32>
//
// // Softmax computation:
// %softmax = "xla_hlo.divide"(%exp, %sum_f16) {broadcast_dimensions = 0}
// %softmax = "mhlo.divide"(%exp, %sum_f16) {broadcast_dimensions = 0}
// : (tensor<BxNxf16>, tensor<Bxf16>) -> tensor<BxNxf16>
template <typename OpTy, bool use_log = true>
class ConvertSoftmaxOp : public OpRewritePattern<OpTy> {
@ -2270,8 +2269,8 @@ class ConvertSoftmaxOp : public OpRewritePattern<OpTy> {
/*keep_dims=*/rewriter.getBoolAttr(false));
auto max_logits_broadcast =
CommonPrefixBroadcast(loc, logits, max_logits, rewriter);
auto shifted_logits = rewriter.create<xla_hlo::SubOp>(loc, type, logits,
max_logits_broadcast);
auto shifted_logits =
rewriter.create<mhlo::SubOp>(loc, type, logits, max_logits_broadcast);
// Exponentiate the inputs.
Value exp = rewriter.create<ExpOp>(loc, type, shifted_logits);
@ -2285,11 +2284,11 @@ class ConvertSoftmaxOp : public OpRewritePattern<OpTy> {
if (use_log) {
Value log = rewriter.create<LogOp>(loc, sum);
auto log_broadcast = CommonPrefixBroadcast(loc, logits, log, rewriter);
rewriter.replaceOpWithNewOp<xla_hlo::SubOp>(op, shifted_logits,
rewriter.replaceOpWithNewOp<mhlo::SubOp>(op, shifted_logits,
log_broadcast);
} else {
auto sum_broadcast = CommonPrefixBroadcast(loc, logits, sum, rewriter);
rewriter.replaceOpWithNewOp<xla_hlo::DivOp>(op, exp, sum_broadcast);
rewriter.replaceOpWithNewOp<mhlo::DivOp>(op, exp, sum_broadcast);
}
return success();
}
@ -2307,16 +2306,16 @@ class ConvertSoftmaxOp : public OpRewritePattern<OpTy> {
//
// will be converted into:
//
// %const = xla_hlo.constant dense<1> : tensor<i32>
// %dim_0 = "xla_hlo.get_dimension_size"(%input) {dimension = 0 : i32} :
// %const = mhlo.constant dense<1> : tensor<i32>
// %dim_0 = "mhlo.get_dimension_size"(%input) {dimension = 0 : i32} :
// (tensor<2x?x8xf32>) -> tensor<i32>
// %prod_0 = xla_hlo.multiply %const, %dim_0 : tensor<i32>
// %dim_1 = "xla_hlo.get_dimension_size"(%input) {dimension = 1 : i32} :
// %prod_0 = mhlo.multiply %const, %dim_0 : tensor<i32>
// %dim_1 = "mhlo.get_dimension_size"(%input) {dimension = 1 : i32} :
// (tensor<2x?x8xf32>) -> tensor<i32>
// %prod_1 = xla_hlo.multiply %prod_0, %dim_1 : tensor<i32>
// %dim_2 = "xla_hlo.get_dimension_size"(%input) {dimension = 2 : i32} :
// %prod_1 = mhlo.multiply %prod_0, %dim_1 : tensor<i32>
// %dim_2 = "mhlo.get_dimension_size"(%input) {dimension = 2 : i32} :
// (tensor<2x?x8xf32>) -> tensor<i32>
// %size = xla_hlo.multiply %prod_1, %dim_2 : tensor<i32>
// %size = mhlo.multiply %prod_1, %dim_2 : tensor<i32>
class ConvertSizeOp : public OpRewritePattern<TF::SizeOp> {
public:
using OpRewritePattern::OpRewritePattern;
@ -2470,17 +2469,17 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
//
// will be converted into:
//
// %0 = "xla_hlo.slice"(%input) {
// %0 = "mhlo.slice"(%input) {
// limit_indices = dense<[4, 2]> : tensor<2xi64>,
// start_indices = dense<0> : tensor<2xi64>,
// strides = dense<1> : tensor<2xi64>} :
// (tensor<4x6xf32>) -> tensor<4x2xf32>
// %1 = "xla_hlo.slice"(%input) {
// %1 = "mhlo.slice"(%input) {
// limit_indices = dense<4> : tensor<2xi64>,
// start_indices = dense<[0, 2]> : tensor<2xi64>,
// strides = dense<1> : tensor<2xi64>} :
// (tensor<4x6xf32>) -> tensor<4x2xf32>
// %2 = "xla_hlo.slice"(%input) {
// %2 = "mhlo.slice"(%input) {
// limit_indices = dense<[4, 6]> : tensor<2xi64>,
// start_indices = dense<[0, 4]> : tensor<2xi64>,
// strides = dense<1> : tensor<2xi64>} :
@ -2563,17 +2562,17 @@ class ConvertSplitOp : public OpRewritePattern<TF::SplitOp> {
// (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>)
//
// We will generate slices following slices:
// %0 = "xla_hlo.slice"(%input) {
// %0 = "mhlo.slice"(%input) {
// limit_indices = dense<[4, 1]> : tensor<2xi64>,
// start_indices = dense<0> : tensor<2xi64>,
// strides = dense<1> : tensor<2xi64>} :
// (tensor<4x6xf32>) -> tensor<4x1xf32>
// %1 = "xla_hlo.slice"(%input) {
// %1 = "mhlo.slice"(%input) {
// limit_indices = dense<[4, 3]> : tensor<2xi64>,
// start_indices = dense<[0, 1]> : tensor<2xi64>,
// strides = dense<1> : tensor<2xi64>} :
// (tensor<4x6xf32>) -> tensor<4x2xf32>
// %2 = "xla_hlo.slice"(%input) {
// %2 = "mhlo.slice"(%input) {
// limit_indices = dense<[4, 6]> : tensor<2xi64>,
// start_indices = dense<[0, 3]> : tensor<2xi64>,
// strides = dense<1> : tensor<2xi64>} :
@ -2645,7 +2644,7 @@ class ConvertSplitVOp : public OpRewritePattern<TF::SplitVOp> {
for (int i = 0; i < op.getNumResults(); ++i) {
end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i];
slices.push_back(rewriter.create<xla_hlo::SliceOp>(
slices.push_back(rewriter.create<mhlo::SliceOp>(
op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
GetI64ElementsAttr(end_indices, &rewriter),
GetI64ElementsAttr(strides, &rewriter)));
@ -2663,7 +2662,7 @@ class ConvertSplitVOp : public OpRewritePattern<TF::SplitVOp> {
// strides operands are converted to attributes with non-negative indexing.
//
// If the begin input is not a compile time constant, the begin input needs to
// be sliced and the slice needs to be lowered to xla_hlo.DynamicSlice. In this
// be sliced and the slice needs to be lowered to mhlo.DynamicSlice. In this
// case, strides must have a known value of 1 (otherwise we have insufficient
// information to conform to XLA's op semantics).
//
@ -2672,10 +2671,10 @@ class ConvertSplitVOp : public OpRewritePattern<TF::SplitVOp> {
// : tensor<AxBxf32> -> tensor<Pxf32>
//
// If the %begin input is constant, output would be:
// %reversed = "xla_hlo.Reverse" (%input) {dimensions = ...}
// %sliced = "xla_hlo.Slice" (%input)
// %reversed = "mhlo.Reverse" (%input) {dimensions = ...}
// %sliced = "mhlo.Slice" (%input)
// {start_indices = ..., limit_indices = ..., strides = ...}
// %output = "xla_hlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor<Pxf32>
// %output = "mhlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor<Pxf32>
//
class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> {
public:
@ -2940,7 +2939,7 @@ class ConvertStridedSliceGradOp
Type element_type = grad.getType().cast<ShapedType>().getElementType();
// Perform reshape to undo any new/shrink axes done by strided slice.
grad = rewriter.create<xla_hlo::ReshapeOp>(
grad = rewriter.create<mhlo::ReshapeOp>(
op.getLoc(), RankedTensorType::get(shape, element_type), grad);
SmallVector<int64_t, 4> padding_low, padding_high, padding_interm;
@ -2976,13 +2975,13 @@ class ConvertStridedSliceGradOp
}
if (!dims_to_reverse.empty()) {
grad = rewriter.create<xla_hlo::ReverseOp>(
grad = rewriter.create<mhlo::ReverseOp>(
op.getLoc(), grad.getType(), grad,
GetI64ElementsAttr(dims_to_reverse, &rewriter));
}
auto zero = GetScalarConstOfType(element_type, op.getLoc(), 0, &rewriter);
rewriter.replaceOpWithNewOp<xla_hlo::PadOp>(
rewriter.replaceOpWithNewOp<mhlo::PadOp>(
op, op.getType(), grad, zero,
GetI64ElementsAttr(padding_low, &rewriter),
GetI64ElementsAttr(padding_high, &rewriter),
@ -2991,7 +2990,7 @@ class ConvertStridedSliceGradOp
}
};
/// Converts the RangeOp tensorflow op to a xla_hlo.iota op with a scaling and
/// Converts the RangeOp tensorflow op to a mhlo.iota op with a scaling and
/// offset applied to generate the range values. The output tensor needs to
/// have a static shape.
///
@ -3000,11 +2999,11 @@ class ConvertStridedSliceGradOp
/// : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<5xf32>
///
/// Output would be:
/// %iota = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32>
/// %scaled = "xla_hlo.multiply"(%iota, %delta)
/// %iota = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32>
/// %scaled = "mhlo.multiply"(%iota, %delta)
/// {broadcast_dimensions = dense<[]> : tensor<0xi64>} :
/// (tensor<5xf32>, tensor<f32>) -> tensor<5xf32>
/// %result = "xla_hlo.add"(%scaled, %offset)
/// %result = "mhlo.add"(%scaled, %offset)
/// {broadcast_dimensions = dense<[]> : tensor<0xi64>} :
/// (tensor<5xf32>, tensor<f32>) -> tensor<5xf32>
///
@ -3071,23 +3070,23 @@ class ConvertDynamicRangeOp : public OpRewritePattern<TF::RangeOp> {
// some conversion to float for the operations.
//
// %size = ceil(abs((%limit - %start) / %delta))
auto range = rewriter.create<xla_hlo::SubOp>(op.getLoc(), limit, start);
auto abs = rewriter.create<xla_hlo::AbsOp>(op.getLoc(), range);
auto range = rewriter.create<mhlo::SubOp>(op.getLoc(), limit, start);
auto abs = rewriter.create<mhlo::AbsOp>(op.getLoc(), range);
// Delta is not necessarily the same type as start and limit.
auto abs_cast =
rewriter.create<xla_hlo::ConvertOp>(op.getLoc(), compute_type, abs);
rewriter.create<mhlo::ConvertOp>(op.getLoc(), compute_type, abs);
auto delta_cast =
rewriter.create<xla_hlo::ConvertOp>(op.getLoc(), compute_type, delta);
rewriter.create<mhlo::ConvertOp>(op.getLoc(), compute_type, delta);
// Compute the total number of integer steps and convert to the HLO
// dimension tensor.
auto normalized =
rewriter.create<xla_hlo::DivOp>(op.getLoc(), abs_cast, delta_cast);
auto ceil = rewriter.create<xla_hlo::CeilOp>(op.getLoc(), normalized);
auto steps = rewriter.create<xla_hlo::ConvertOp>(
rewriter.create<mhlo::DivOp>(op.getLoc(), abs_cast, delta_cast);
auto ceil = rewriter.create<mhlo::CeilOp>(op.getLoc(), normalized);
auto steps = rewriter.create<mhlo::ConvertOp>(
op.getLoc(), RankedTensorType::get({}, rewriter.getI64Type()), ceil);
auto reshape = rewriter.create<xla_hlo::ReshapeOp>(
auto reshape = rewriter.create<mhlo::ReshapeOp>(
op.getLoc(), RankedTensorType::get({1}, rewriter.getI64Type()), steps);
// Using the resulting length compute the correct range value:
@ -3095,10 +3094,10 @@ class ConvertDynamicRangeOp : public OpRewritePattern<TF::RangeOp> {
// %range = %start + %delta * iota(%size)
auto out_scalar_type =
RankedTensorType::get({}, getElementTypeOrSelf(result_type));
auto start_out_cast = rewriter.create<xla_hlo::ConvertOp>(
op.getLoc(), out_scalar_type, start);
auto delta_out_cast = rewriter.create<xla_hlo::ConvertOp>(
op.getLoc(), out_scalar_type, delta);
auto start_out_cast =
rewriter.create<mhlo::ConvertOp>(op.getLoc(), out_scalar_type, start);
auto delta_out_cast =
rewriter.create<mhlo::ConvertOp>(op.getLoc(), out_scalar_type, delta);
auto iota = rewriter.create<DynamicIotaOp>(
op.getLoc(), result_type, reshape, rewriter.getI64IntegerAttr(0));
@ -3127,7 +3126,7 @@ ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) {
return builder->getI64TensorAttr(axis);
}
/// Converts the LinSpace tensorflow op to a xla_hlo.iota op with a scaling
/// Converts the LinSpace tensorflow op to a mhlo.iota op with a scaling
/// and offset applied to generate the linspace values. The output tensor needs
/// to have a static shape. The implementation is defined in C++ because there
/// is no type inference for the iota op.
@ -3183,7 +3182,7 @@ class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> {
}
};
/// Converts a generic OpTy tensorflow op to a xla_hlo.reduce op over
/// Converts a generic OpTy tensorflow op to a mhlo.reduce op over
/// ReductionOp.
/// `is_accumulation` controls whether it uses higher precision for the actual
/// reduction. This is set to false for ops like max where there is no precision
@ -3272,10 +3271,10 @@ class GenericConvertReductionOp : public OpRewritePattern<OpTy> {
// Converts Mean op to HLO Reduce op.
//
// %init = constant dense<...> : tensor<T>
// %sum = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.add"]
// %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"]
// {dimensions = ...}
// %divisor = constant dense<...> : tensor<T>
// %mean = "xla_hlo.divide"(%sum, %divisor)
// %mean = "mhlo.divide"(%sum, %divisor)
class ConvertMeanOp
: public GenericConvertReductionOp<ConvertMeanOp, TF::MeanOp, AddOp> {
public:
@ -3289,7 +3288,7 @@ class ConvertMeanOp
// Converts Sum op to HLO Reduce op.
//
// %init = constant dense<...> : tensor<T>
// %sum = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.add"]
// %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"]
// {dimensions = ...}
class ConvertSumOp
: public GenericConvertReductionOp<ConvertSumOp, TF::SumOp, AddOp> {
@ -3305,7 +3304,7 @@ class ConvertSumOp
// Converts Max op to HLO Reduce op.
//
// %init = constant dense<...> : tensor<T>
// %max = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.maximum"]
// %max = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"]
// {dimensions = ...}
class ConvertMaxOp
: public GenericConvertReductionOp<ConvertMaxOp, TF::MaxOp, MaxOp,
@ -3322,7 +3321,7 @@ class ConvertMaxOp
// Converts Min op to HLO Reduce op.
//
// %init = constant dense<...> : tensor<T>
// %min = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.minimum"]
// %min = "mhlo.reduce"(%inp, %init) ["mhlo.minimum"]
// {dimensions = ...}
class ConvertMinOp
: public GenericConvertReductionOp<ConvertMinOp, TF::MinOp, MinOp,
@ -3339,7 +3338,7 @@ class ConvertMinOp
// Converts Prod op to HLO Reduce op.
//
// %init = constant dense<...> : tensor<T>
// %prod = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.multiply"]
// %prod = "mhlo.reduce"(%inp, %init) ["mhlo.multiply"]
// {dimensions = ...}
class ConvertProdOp
: public GenericConvertReductionOp<ConvertProdOp, TF::ProdOp, MulOp> {
@ -3355,7 +3354,7 @@ class ConvertProdOp
// Converts All op to HLO Reduce op.
//
// %init = constant dense<...> : tensor<T>
// %max = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.and"]
// %max = "mhlo.reduce"(%inp, %init) ["mhlo.and"]
// {dimensions = ...}
class ConvertAllOp
: public GenericConvertReductionOp<ConvertAllOp, TF::AllOp, AndOp> {
@ -3370,7 +3369,7 @@ class ConvertAllOp
// Converts Any op to HLO Reduce op.
//
// %init = constant dense<...> : tensor<T>
// %max = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.or"]
// %max = "mhlo.reduce"(%inp, %init) ["mhlo.or"]
// {dimensions = ...}
class ConvertAnyOp
: public GenericConvertReductionOp<ConvertAnyOp, TF::AnyOp, OrOp> {
@ -3382,7 +3381,7 @@ class ConvertAnyOp
}
};
// Converts tensorflow ArgMin or ArgMax op to xla_hlo operations that perform
// Converts tensorflow ArgMin or ArgMax op to mhlo operations that perform
// a reduction on the original input and the corresponding index. The reduction
// sub-computation selects the max (or min) value and the index for the value.
// Derived: is the resulting derived class of this class.
@ -3454,13 +3453,13 @@ class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> {
}
};
// Converts tensorflow ArgMax op to xla_hlo operations. The actual
// Converts tensorflow ArgMax op to mhlo operations. The actual
// implementation is in class ConvertArgMinMaxOp:
//
// %init_index = constant dense<...> : tensor<T>
// %init = constant dense<...> : tensor<T>
// %reduce = "xla_hlo.reduce"(%selected_input, %select_index, %init,
// %init_index) ["xla_hlo.arg_max"]
// %reduce = "mhlo.reduce"(%selected_input, %select_index, %init,
// %init_index) ["mhlo.arg_max"]
class ConvertArgMaxOp
: public ConvertArgMinMaxOp<ConvertArgMaxOp, TF::ArgMaxOp> {
public:
@ -3476,7 +3475,7 @@ class ConvertArgMaxOp
// Converts TF TensorScatterUpdate op into Scatter Op with assignment:
//
// %result = "xla_hlo.scatter"(%tensor, %indices, %updates)
// %result = "mhlo.scatter"(%tensor, %indices, %updates)
// { dimensions = ... }
//
class ConvertTensorScatterUpdateOp
@ -3534,10 +3533,10 @@ class ConvertTensorScatterUpdateOp
// For shape [S1, S2] and multiples [M1, M2],
// MS1 = M1 * S1; MS2 = M2 * S2
//
// %broadcast = xla_hlo.broadcast_in_dim(%input) {
// %broadcast = mhlo.broadcast_in_dim(%input) {
// broadcast_dimensions = [0, 2]
// }
// %result = "xla_hlo.reshape"(%broadcast) : (tensor<S1xM1xS2xM2xf32>)
// %result = "mhlo.reshape"(%broadcast) : (tensor<S1xM1xS2xM2xf32>)
// -> tensor<MS1xMS2xf32>
class ConvertTileOp : public OpRewritePattern<TF::TileOp> {
public:
@ -3657,8 +3656,8 @@ using ConvertMaxPool3DGradOp =
ConvertMaxPoolGradOp<TF::MaxPool3DGradOp, /*num_dims=*/5>;
// Converts tf.Conv?DBackpropInputOp into:
// %rev_filter = "xla_hlo.reverse"(%filter)
// %result = "xla_hlo.convolution"(%out_backprop, %rev_filter)
// %rev_filter = "mhlo.reverse"(%filter)
// %result = "mhlo.convolution"(%out_backprop, %rev_filter)
template <typename OpTy, int num_spatial_dims>
class ConvertConvBackpropInputOp : public OpRewritePattern<OpTy> {
public:
@ -3821,7 +3820,7 @@ using ConvertConv3DBackpropInputOp =
/*num_spatial_dims=*/3>;
// Converts tf.Conv?DBackpropFilterOp into:
// %result = "xla_hlo.convolution"(%input, %out_backprop)
// %result = "mhlo.convolution"(%input, %out_backprop)
template <typename OpTy, int num_spatial_dims>
class ConvertConvBackpropFilterOp : public OpRewritePattern<OpTy> {
public:
@ -4078,7 +4077,7 @@ class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
loc, index_type, op.indices(),
GetI64ElementsAttr(broadcast_dims, &rewriter));
Value compare = rewriter.create<xla_hlo::CompareOp>(
Value compare = rewriter.create<mhlo::CompareOp>(
loc, broadcast_indices, iota,
StringAttr::get("EQ", rewriter.getContext()));
Value on_value = rewriter.create<BroadcastOp>(
@ -4111,13 +4110,13 @@ class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
//
// would be lowered to
//
// %token = "xla_hlo.create_token"() : () -> !xla_hlo.token
// %data_and_token = "xla_hlo.infeed"(%token) {infeed_config = ""} :
// (!xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>,
// !xla_hlo.token>
// %data = "xla_hlo.get_tuple_element"(%data_and_token) {index = 0}
// %0#0 = "xla_hlo.get_tuple_element"(%data) {index = 0}
// %0#1 = "xla_hlo.get_tuple_element"(%data) {index = 1}
// %token = "mhlo.create_token"() : () -> !mhlo.token
// %data_and_token = "mhlo.infeed"(%token) {infeed_config = ""} :
// (!mhlo.token) -> tuple<tuple<tensor<3xi32>, tensor<4xf32>>,
// !mhlo.token>
// %data = "mhlo.get_tuple_element"(%data_and_token) {index = 0}
// %0#0 = "mhlo.get_tuple_element"(%data) {index = 0}
// %0#1 = "mhlo.get_tuple_element"(%data) {index = 1}
//
class ConvertInfeedDequeueTupleOp
: public OpRewritePattern<TF::InfeedDequeueTupleOp> {
@ -4133,7 +4132,7 @@ class ConvertInfeedDequeueTupleOp
// Infeed takes a single token operand. Generate the token using
// create_token op to pass to the infeed op.
auto token = rewriter.create<CreateTokenOp>(
op.getLoc(), xla_hlo::TokenType::get(rewriter.getContext()));
op.getLoc(), mhlo::TokenType::get(rewriter.getContext()));
// Emit infeed op.
// The result type of infeed is a tuple(tuple(result types), token type).
@ -4196,11 +4195,11 @@ class ConvertInfeedDequeueTupleOp
//
// would be lowered to
//
// %tuple = "xla_hlo.tuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
// %tuple = "mhlo.tuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
// tuple<tensor<3xi32>, tensor<4xf32>>
// %token = "xla_hlo.create_token"() : () -> !xla_hlo.token
// %outfeed_token = "xla_hlo.outfeed"(%tuple, %token) {outfeed_config = ""} :
// (tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token) -> !xla_hlo.token
// %token = "mhlo.create_token"() : () -> !mhlo.token
// %outfeed_token = "mhlo.outfeed"(%tuple, %token) {outfeed_config = ""} :
// (tuple<tensor<3xi32>, tensor<4xf32>>, !mhlo.token) -> !mhlo.token
//
class ConvertOutfeedEnqueueTupleOp
: public OpRewritePattern<TF::OutfeedEnqueueTupleOp> {
@ -4209,7 +4208,7 @@ class ConvertOutfeedEnqueueTupleOp
LogicalResult matchAndRewrite(TF::OutfeedEnqueueTupleOp op,
PatternRewriter &rewriter) const override {
auto token_type = xla_hlo::TokenType::get(rewriter.getContext());
auto token_type = mhlo::TokenType::get(rewriter.getContext());
auto tuple = rewriter.create<TupleOp>(op.getLoc(), op.inputs());
auto token = rewriter.create<CreateTokenOp>(op.getLoc(), token_type);
rewriter.create<OutfeedOp>(op.getLoc(), token_type, tuple, token,
@ -4235,20 +4234,20 @@ class ConvertOutfeedEnqueueTupleOp
//
// We will get:
//
// %1 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32>
// %2 = "xla_hlo.sort"(%input, %1) ( {
// %1 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<16x16xi32>
// %2 = "mhlo.sort"(%input, %1) ( {
// ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>,
// %arg3: tensor<i32>, %arg4: tensor<i32>):
// %7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ...
// "xla_hlo.return"(%7) : (tensor<i1>) -> ()
// %7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"}: ...
// "mhlo.return"(%7) : (tensor<i1>) -> ()
// }) {dimension = 1 : i64, is_stable = true} : ...
// %3 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : ...
// %4 = "xla_hlo.get_tuple_element"(%2) {index = 1 : i32} : ...
// %5 = "xla_hlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>,
// %3 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : ...
// %4 = "mhlo.get_tuple_element"(%2) {index = 1 : i32} : ...
// %5 = "mhlo.slice"(%3) {limit_indices = dense<[16, 8]> : tensor<2xi64>,
// start_indices dense<0> : tensor<2xi64>,
// strides = dense<1> : tensor<2xi64>} :
// (tensor<16x16xf32>) -> tensor<16x8xf32>
// %6 = "xla_hlo.slice"(%4) ...
// %6 = "mhlo.slice"(%4) ...
class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
public:
using OpRewritePattern::OpRewritePattern;
@ -4271,12 +4270,12 @@ class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
// Create an Itoa op for indices.
auto i32_type = rewriter.getIntegerType(32);
Type iota_type = RankedTensorType::get(input_type.getShape(), i32_type);
Value iota_op = rewriter.create<xla_hlo::IotaOp>(
Value iota_op = rewriter.create<mhlo::IotaOp>(
op.getLoc(), iota_type, rewriter.getI64IntegerAttr(last_dim_index));
// Create the sort op. It takes two inputs, one for the original input, the
// other for the indices.
auto sort_op = rewriter.create<xla_hlo::SortOp>(
auto sort_op = rewriter.create<mhlo::SortOp>(
op.getLoc(), llvm::ArrayRef<Value>{op.input(), iota_op}, last_dim_index,
/*is_stable=*/true);
BuildSortComparisonBody({input_type.getElementType(), i32_type},
@ -4285,9 +4284,9 @@ class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
// Get the sorted input and index tuple element.
auto tuple_first_element =
rewriter.create<xla_hlo::GetTupleElementOp>(op.getLoc(), sort_op, 0);
rewriter.create<mhlo::GetTupleElementOp>(op.getLoc(), sort_op, 0);
auto tuple_second_element =
rewriter.create<xla_hlo::GetTupleElementOp>(op.getLoc(), sort_op, 1);
rewriter.create<mhlo::GetTupleElementOp>(op.getLoc(), sort_op, 1);
SmallVector<int64_t, 4> begin_indices(input_rank, 0);
auto end_indices = llvm::to_vector<4>(input_type.getShape());
@ -4297,13 +4296,13 @@ class ConvertTopKV2Op : public OpRewritePattern<TF::TopKV2Op> {
// Get the slice for the top K elements.
Value values = rewriter.create<xla_hlo::SliceOp>(
Value values = rewriter.create<mhlo::SliceOp>(
op.getLoc(), tuple_first_element,
GetI64ElementsAttr(begin_indices, &rewriter),
GetI64ElementsAttr(end_indices, &rewriter),
GetI64ElementsAttr(strides, &rewriter));
Value indices = rewriter.create<xla_hlo::SliceOp>(
Value indices = rewriter.create<mhlo::SliceOp>(
op.getLoc(), tuple_second_element,
GetI64ElementsAttr(begin_indices, &rewriter),
GetI64ElementsAttr(end_indices, &rewriter),
@ -4346,12 +4345,12 @@ class ConvertUnpackOp : public OpRewritePattern<TF::UnpackOp> {
begin_indices[axis] = i;
end_indices[axis] = i + 1;
auto slice_op = rewriter.create<xla_hlo::SliceOp>(
auto slice_op = rewriter.create<mhlo::SliceOp>(
op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter),
GetI64ElementsAttr(end_indices, &rewriter),
GetI64ElementsAttr(strides, &rewriter));
// Reshape to drop the axis dimension.
auto reshape_op = rewriter.create<xla_hlo::ReshapeOp>(
auto reshape_op = rewriter.create<mhlo::ReshapeOp>(
op.getLoc(), op.getType(i), slice_op);
results.push_back(reshape_op);
}
@ -4410,7 +4409,7 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern<OpTy> {
// 'operand' parameter to scatter to for the final scatter op.
Value init = ConcreteClass::GetInitialValue(data_type.getElementType(),
op.getLoc(), &rewriter);
auto broadcasted_init = rewriter.create<xla_hlo::BroadcastOp>(
auto broadcasted_init = rewriter.create<mhlo::BroadcastOp>(
op.getLoc(), output_type, init,
GetI64ElementsAttr(output_shape, &rewriter));
@ -4565,7 +4564,7 @@ class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
auto keys =
CreateRngUniform32(op.getLoc(), num_elements, /*lower_limit=*/0,
/*upper_limit=*/u32_max, &rewriter);
auto sorted = rewriter.create<xla_hlo::SortOp>(
auto sorted = rewriter.create<mhlo::SortOp>(
op.getLoc(), llvm::ArrayRef<Value>{keys, current});
auto i32_type = rewriter.getIntegerType(32);
BuildSortComparisonBody({i32_type, input_type.getElementType()},
@ -4583,7 +4582,7 @@ class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
// Generate range(n) as the initial value for the indices to be swapped.
auto indices_type =
RankedTensorType::get({first_dim_size}, rewriter.getIntegerType(32));
Value indices = rewriter.create<xla_hlo::IotaOp>(
Value indices = rewriter.create<mhlo::IotaOp>(
op.getLoc(), indices_type, rewriter.getI64IntegerAttr(0));
// Generate random numbers to be used as swaps for the indices.
@ -4609,21 +4608,21 @@ class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
// We need to swap the indices[i] with indices[swaps[i]]. First get
// these index values.
Value source_index = builder->create<xla_hlo::DynamicSliceOp>(
Value source_index = builder->create<mhlo::DynamicSliceOp>(
loc, vec1_i32_type, indices, i, scalar_one);
Value swap_index = builder->create<xla_hlo::ReshapeOp>(
Value swap_index = builder->create<mhlo::ReshapeOp>(
loc, scalar_i32_type,
builder->create<xla_hlo::DynamicSliceOp>(loc, vec1_i32_type, swaps, i,
builder->create<mhlo::DynamicSliceOp>(loc, vec1_i32_type, swaps, i,
scalar_one));
Value target_index = builder->create<xla_hlo::DynamicSliceOp>(
Value target_index = builder->create<mhlo::DynamicSliceOp>(
loc, vec1_i32_type, indices, swap_index, scalar_one);
// Then perform the swap.
// indices[i] <- indices[swaps[i]]
indices = builder->create<xla_hlo::DynamicUpdateSliceOp>(
indices = builder->create<mhlo::DynamicUpdateSliceOp>(
loc, indices.getType(), indices, target_index, llvm::makeArrayRef(i));
// indices[swaps[i]] <- indices[i]
indices = builder->create<xla_hlo::DynamicUpdateSliceOp>(
indices = builder->create<mhlo::DynamicUpdateSliceOp>(
loc, indices.getType(), indices, source_index,
llvm::makeArrayRef(swap_index));
@ -4647,7 +4646,7 @@ class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
/*start_index_map=*/GetI64ElementsAttr({0}, &rewriter),
/*index_vector_dim=*/rewriter.getI64IntegerAttr(1),
rewriter.getContext());
rewriter.replaceOpWithNewOp<xla_hlo::GatherOp>(
rewriter.replaceOpWithNewOp<mhlo::GatherOp>(
op, op.getType(), op.value(), swaped_indices, dims_attr,
GetI64ElementsAttr(slice_sizes, &rewriter));
@ -4666,7 +4665,7 @@ class ConvertXlaShardingOp : public OpRewritePattern<TF::XlaShardingOp> {
// using a string.
if (!op._XlaSharding().hasValue()) return failure();
auto custom_call = rewriter.create<xla_hlo::CustomCallOp>(
auto custom_call = rewriter.create<mhlo::CustomCallOp>(
op.getLoc(), op.getType(), op.input(),
/*call_target_name=*/rewriter.getStringAttr("Sharding"),
/*has_side_effect=*/rewriter.getBoolAttr(false),
@ -4716,7 +4715,7 @@ class ConvertInplaceUpdateOp : public OpRewritePattern<TF::InplaceUpdateOp> {
updates_type.getElementType()));
auto cst =
rewriter.create<xla_hlo::ConstOp>(op.getLoc(), zero_attr).getResult();
rewriter.create<mhlo::ConstOp>(op.getLoc(), zero_attr).getResult();
auto split_updates = rewriter.create<TF::SplitOp>(
op.getLoc(), split_updates_type, cst, updates);
@ -4731,7 +4730,7 @@ class ConvertInplaceUpdateOp : public OpRewritePattern<TF::InplaceUpdateOp> {
for (auto pair :
llvm::zip(unpacked_indices.output(), split_updates.output())) {
input_indices.front() = std::get<0>(pair);
input = rewriter.create<xla_hlo::DynamicUpdateSliceOp>(
input = rewriter.create<mhlo::DynamicUpdateSliceOp>(
op.getLoc(), op.getType(), input, std::get<1>(pair), input_indices);
}
@ -4759,7 +4758,7 @@ class ConvertXlaDynamicUpdateSliceOp
auto unpacked_indices = rewriter.create<TF::UnpackOp>(
op.getLoc(), unpacked_indices_type, op.indices(),
IntegerAttr::get(rewriter.getIntegerType(64), 0));
rewriter.replaceOpWithNewOp<xla_hlo::DynamicUpdateSliceOp>(
rewriter.replaceOpWithNewOp<mhlo::DynamicUpdateSliceOp>(
op, op.getType(), op.input(), op.update(), unpacked_indices.output());
return success();
}
@ -5143,7 +5142,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
precision, builder);
vva = BatchDot(loc, v_broadcast, true, vva, false, num_batch_dims,
precision, builder);
auto tau_x_vva = StaticBinaryBroadcast<xla_hlo::MulOp>(
auto tau_x_vva = StaticBinaryBroadcast<mhlo::MulOp>(
loc, tau, vva, GetI64ElementsAttr(batch_dim_indices, builder),
*builder);
a = builder->create<SubOp>(loc, a, tau_x_vva);
@ -5476,7 +5475,7 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
target.addLegalOp<TensorCastOp>();
if (!allow_partial_conversion) {
// Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp.
// Fully qualify ReturnOp here as mhlo dialect also defines a ReturnOp.
target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ::mlir::ReturnOp>();
DenseSet<Operation *> nonlegalized_ops;
LogicalResult result =
@ -5498,5 +5497,5 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
return std::make_unique<LegalizeTF>(allow_partial_conversion, legalize_chlo);
}
} // end namespace xla_hlo
} // end namespace mhlo
} // end namespace mlir

View File

@ -48,7 +48,7 @@ limitations under the License.
using mlir::PassRegistration;
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
class LegalizeTFControlFlow
: public PassWrapper<LegalizeTFControlFlow, OperationPass<ModuleOp>> {
@ -67,7 +67,7 @@ namespace {
void Detuple(Value tuple, Operation::result_range replace, OpBuilder* builder) {
// De-tuple the results of the xla hlo if result.
for (auto result_it : llvm::enumerate(replace)) {
auto get_tuple_value = builder->create<xla_hlo::GetTupleElementOp>(
auto get_tuple_value = builder->create<mhlo::GetTupleElementOp>(
result_it.value().getLoc(), tuple, result_it.index());
result_it.value().replaceAllUsesWith(get_tuple_value);
}
@ -95,10 +95,10 @@ void ImportXlaRegion(mlir::FuncOp func, Region* dest_region, Location loc,
auto result = builder.create<CallOp>(loc, func, detupled_args).getResults();
if (!tuple_return) {
builder.create<xla_hlo::ReturnOp>(loc, result);
builder.create<mhlo::ReturnOp>(loc, result);
} else {
auto tuple_op = builder.create<TupleOp>(loc, result);
builder.create<xla_hlo::ReturnOp>(loc, tuple_op.getResult());
builder.create<mhlo::ReturnOp>(loc, tuple_op.getResult());
}
}
@ -109,11 +109,11 @@ void LowerIf(TF::IfOp op, ModuleOp module) {
// XLA prefers tuple arguments for control flow due to XLA not supporting
// multiple return values.
SmallVector<Value, 3> inputs(op.input());
auto tuple_input = builder.create<xla_hlo::TupleOp>(loc, inputs);
auto tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);
// Create the new if op with tuple inputs.
auto result_type = builder.getTupleType(op.getResultTypes());
auto if_op = builder.create<xla_hlo::IfOp>(loc, result_type, op.cond(),
auto if_op = builder.create<mhlo::IfOp>(loc, result_type, op.cond(),
tuple_input, tuple_input);
// Import the regions for both the true and false cases. These regions
@ -136,15 +136,15 @@ void LowerCase(TF::CaseOp op, ModuleOp module) {
// XLA requires one argument per branch so we create a tuple of inputs to pass
// to each branch.
SmallVector<Value, 4> inputs(op.input());
auto tuple_input = builder.create<xla_hlo::TupleOp>(loc, inputs);
auto tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);
// Create replica of input tuple for each branch
SmallVector<Value, 4> n_tuple_inputs(op.branches().size(), tuple_input);
// Create the new case op with tuple inputs.
auto case_op = builder.create<xla_hlo::CaseOp>(
loc, op.getResultTypes(), op.branch_index(), n_tuple_inputs,
op.branches().size());
auto case_op =
builder.create<mhlo::CaseOp>(loc, op.getResultTypes(), op.branch_index(),
n_tuple_inputs, op.branches().size());
// Import the regions for all branches.
for (unsigned i = 0; i < op.branches().size(); ++i) {
@ -166,10 +166,10 @@ void LowerWhile(TF::WhileOp op, ModuleOp module) {
// multiple return values.
SmallVector<Value, 3> inputs(op.input());
builder.setInsertionPoint(op);
Value tuple_input = builder.create<xla_hlo::TupleOp>(loc, inputs);
Value tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);
// Create the new while op with tuple inputs.
auto while_op = builder.create<xla_hlo::WhileOp>(
auto while_op = builder.create<mhlo::WhileOp>(
loc, builder.getTupleType(op.getResultTypes()), tuple_input);
// Import the regions for both the cond and body. These regions must be
@ -204,9 +204,9 @@ void LegalizeTFControlFlow::runOnOperation() {
}
});
}
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir
static PassRegistration<mlir::xla_hlo::LegalizeTFControlFlow> cfpass(
static PassRegistration<mlir::mhlo::LegalizeTFControlFlow> cfpass(
"xla-legalize-tf-control-flow",
"Legalize TensorFlow control flow to the XLA dialect");

View File

@ -366,7 +366,7 @@ class GetDimensionSizeFromEnd<string dimFromEnd>: NativeCodeCall<
// For now, this op needs to be created in C++ because the expected output type
// cannot be inferred.
class createIotaOp<string dim>: NativeCodeCall<
"$_builder.create<xla_hlo::IotaOp>($0.getOwner()->getLoc(), "
"$_builder.create<mhlo::IotaOp>($0.getOwner()->getLoc(), "
"Get2DTensorType($1), $_builder.getI64IntegerAttr(" # dim # "))">;
// This op needs to be created in C++ because the generated Convert Op has no

View File

@ -69,7 +69,7 @@ limitations under the License.
#include "tensorflow/stream_executor/stream_executor.h"
namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
template <typename T, size_t N>
@ -544,5 +544,5 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
return std::make_unique<LegalizeTF>(device_type);
}
} // end namespace xla_hlo
} // end namespace mhlo
} // end namespace mlir

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.h"
#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
#include <memory>
#include <tuple>

Some files were not shown because too many files have changed in this diff Show More