Rename xla_hlo dialect to mhlo

This is part of the current refactoring of the HLO-related dialects. `xla_hlo` will be reintroduced in a new form later.

PiperOrigin-RevId: 319916753
Change-Id: I2c1b426b8a293927af5569bd35990a54b6b0743e
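For context, the user-visible effect of the rename is a change of operation prefix only; op semantics are unchanged. A minimal before/after sketch in MLIR assembly (the `@add` function is a hypothetical example, not part of this commit):

```mlir
// Before this commit: ops are spelled with the "xla_hlo" prefix.
func @add(%lhs: tensor<4xf32>, %rhs: tensor<4xf32>) -> tensor<4xf32> {
  %0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}

// After this commit: the same IR under the renamed "mhlo" prefix.
func @add(%lhs: tensor<4xf32>, %rhs: tensor<4xf32>) -> tensor<4xf32> {
  %0 = "mhlo.add"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  return %0 : tensor<4xf32>
}
```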
parent f6ab4daebc
commit bafd347479
@@ -482,8 +482,8 @@ tf_cc_test(
 )

 cc_library(
-    name = "xla_hlo_fusion",
-    srcs = ["lib/Dialect/mhlo/transforms/xla_hlo_fusion.cc"],
+    name = "mhlo_fusion",
+    srcs = ["lib/Dialect/mhlo/transforms/mhlo_fusion.cc"],
     deps = [
         ":cycle_detector",
         ":hlo",
@@ -696,7 +696,7 @@ cc_library(
         ":lhlo_legalize_to_affine",
         ":lhlo_legalize_to_gpu",
         ":lhlo_legalize_to_parallel_loops",
-        ":xla_hlo_fusion",
+        ":mhlo_fusion",
        ":xla_legalize_control_flow",
        ":xla_legalize_tanh_to_approximation",
        ":xla_legalize_to_linalg",
@@ -17,12 +17,12 @@ limitations under the License.
 // These ops are not necessarily orthogonal or optimized for transformation but
 // for ease of expression in certain cases deemed important for client
 // libraries (i.e. implicit broadcasting, helper ops, etc).
-// This dialect is considered to exist in addition to augment the xla_hlo
+// This dialect is considered to exist in addition to augment the mhlo
 // dialect for ergonomic needs, not duplicate/replace it.
 //
 // The typical use of this dialect is for client libraries to be able to emit
 // less constrained ops and rely on the conversion framework to lower any
-// xla_chlo ops to canonical xla_hlo ops.
+// xla_chlo ops to canonical mhlo ops.
 //
 // See: https://www.tensorflow.org/xla/operation_semantics

@@ -44,7 +44,7 @@ def HLOClient_Dialect : Dialect {
   let description = [{
     This dialect contains ops that align closely with the API surface area
     of the XlaBuilder C++ API, where such ops have semantics that go beyond
-    what exists in the lower level dialects (such as `xla_hlo`). Essentially,
+    what exists in the lower level dialects (such as `mhlo`). Essentially,
     whenever the client library uses syntactic sugar or composition
     of multiple ops for an API call, this dialect tries to model the API call
     and provide conversion patterns to fully materialize into lower level
@@ -65,7 +65,7 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
 // broadcasting (via the broadcast_dimensions attribute) and implicit degenerate
 // shape broadcasting.
 //
-// These correspond to operations in the xla_hlo dialect without the
+// These correspond to operations in the mhlo dialect without the
 // "broadcast_" prefix, except that those ops require same-shaped operands and
 // results.
 //
@@ -37,12 +37,12 @@ class OpBuilder;

 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.h.inc"

-namespace xla_hlo {
+namespace mhlo {

 class XlaHloDialect : public Dialect {
  public:
   explicit XlaHloDialect(MLIRContext *context);
-  static StringRef getDialectNamespace() { return "xla_hlo"; }
+  static StringRef getDialectNamespace() { return "mhlo"; }

   // Registered hook to materialize a constant operation from a given attribute
   // value with the desired resultant type.
@@ -82,7 +82,7 @@ class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
 //   %1 = index_cast %0 : index to i64
 //   %2 = dim %arg0, 1 : memref<?x?xf32>
 //   %3 = index_cast %2 : index to i64
-//   %4 = "xla_hlo.scalars_to_dimension_tensor"(%1, %3)
+//   %4 = "mhlo.scalars_to_dimension_tensor"(%1, %3)
 //          : (i64, i64) -> tensor<2xi64>
 //
 // and returns %4 as the shape value.
@@ -93,7 +93,7 @@ LogicalResult deriveShapeFromFirstOperand(
 #define GET_OP_CLASSES
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"

-} // end namespace xla_hlo
+} // end namespace mhlo
 } // end namespace mlir

 #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
@@ -29,8 +29,8 @@ include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
 include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"

 def HLO_Dialect : Dialect {
-  let name = "xla_hlo";
-  let cppNamespace = "xla_hlo";
+  let name = "mhlo";
+  let cppNamespace = "mhlo";
 }

 class HLO_Op<string mnemonic, list<OpTrait> traits> :
@@ -22,7 +22,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"

 namespace mlir {
-namespace xla_hlo {
+namespace mhlo {

 template <typename HloOpTy>
 struct HloToLhloOpImpl {
@@ -31,10 +31,10 @@ struct HloToLhloOpImpl {
 template <typename HloOpTy>
 using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;

-#define MAP_HLO_TO_LHLO(OpName)             \
-  template <>                               \
-  struct HloToLhloOpImpl<xla_hlo::OpName> { \
-    using Type = xla_lhlo::OpName;          \
+#define MAP_HLO_TO_LHLO(OpName)          \
+  template <>                            \
+  struct HloToLhloOpImpl<mhlo::OpName> { \
+    using Type = xla_lhlo::OpName;       \
   }

 MAP_HLO_TO_LHLO(AbsOp);
@@ -74,7 +74,7 @@ MAP_HLO_TO_LHLO(TanhOp);

 #undef MAP_HLO_TO_LHLO

-} // namespace xla_hlo
+} // namespace mhlo
 } // namespace mlir

 #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
@@ -464,7 +464,7 @@ struct XlaOpToStdScalarOp {
   template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
             typename = std::enable_if_t<
                 !std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
-                std::is_same<typename xla_hlo::HloToLhloOp<LhloOpTy>,
+                std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
                              std::false_type>::value>>
   static Value map(XlaOpTy op, ArrayRef<Type> result_types,
                    ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
@@ -472,8 +472,8 @@ struct XlaOpToStdScalarOp {
                                                     args, b);
   }

-  // Implementation for HLO ops except xla_hlo::CompareOp.
-  template <typename XlaOpTy, typename LhloOpTy = xla_hlo::HloToLhloOp<XlaOpTy>,
+  // Implementation for HLO ops except mhlo::CompareOp.
+  template <typename XlaOpTy, typename LhloOpTy = mhlo::HloToLhloOp<XlaOpTy>,
             typename = std::enable_if_t<
                 !std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
                 !std::is_same<LhloOpTy, std::false_type>::value>>
@@ -493,10 +493,11 @@ struct XlaOpToStdScalarOp {
         op.getLoc(), comparison_direction, result_types, args, b);
   }

-  // Implementation for xla_hlo::CompareOp.
-  template <typename HloOpTy, typename = std::enable_if_t<std::is_same<
-                                  HloOpTy, xla_hlo::CompareOp>::value>>
-  static Value map(xla_hlo::CompareOp op, ArrayRef<Type> result_types,
+  // Implementation for mhlo::CompareOp.
+  template <typename HloOpTy,
+            typename =
+                std::enable_if_t<std::is_same<HloOpTy, mhlo::CompareOp>::value>>
+  static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types,
                    ArrayRef<Value> args, OpBuilder* b) {
     auto comparison_direction = op.comparison_direction();
     return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
@@ -29,7 +29,7 @@ template <typename T>
 class OperationPass;
 class Pass;

-namespace xla_hlo {
+namespace mhlo {

 /// Lowers HLO control flow ops to the Standard dialect.
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
@@ -55,10 +55,10 @@ std::unique_ptr<OperationPass<FuncOp>> createTransformUnrankedHloPass();
 // necessary to export to XLA.
 std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();

-// fuse xla_hlo ops to kLoop/kInput fusion patterns
+// fuse mhlo ops to kLoop/kInput fusion patterns
 std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass();

-} // namespace xla_hlo
+} // namespace mhlo

 namespace xla_lhlo {

@@ -27,7 +27,7 @@ class LLVMTypeConverter;
 class LowerToLLVMOptions;
 class OwningRewritePatternList;
 class BufferAssignmentPlacer;
-namespace xla_hlo {
+namespace mhlo {

 // Collection of rewrite patterns for lowering a general dot product.
 void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
@@ -73,7 +73,7 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
 void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
                                      OwningRewritePatternList *patterns);

-} // namespace xla_hlo
+} // namespace mhlo

 namespace xla_lhlo {

@@ -18,7 +18,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"

 // Static initialization for XLA dialect registration.
-static mlir::DialectRegistration<mlir::xla_hlo::XlaHloDialect> xla_hlo_ops;
+static mlir::DialectRegistration<mlir::mhlo::XlaHloDialect> mhlo_ops;
 static mlir::DialectRegistration<mlir::xla_chlo::XlaHloClientDialect>
     xla_chlo_ops;
 static mlir::DialectRegistration<mlir::xla_lhlo::XlaLhloDialect> xla_lhlo_ops;
@@ -60,7 +60,7 @@ limitations under the License.

 namespace mlir {
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc"
-namespace xla_hlo {
+namespace mhlo {

 Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
                                               Attribute value, Type type,
@@ -68,8 +68,7 @@ Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
   // HLO dialect constants only support ElementsAttr unlike standard dialect
   // constant which supports all attributes.
   if (value.isa<ElementsAttr>())
-    return builder.create<xla_hlo::ConstOp>(loc, type,
-                                            value.cast<ElementsAttr>());
+    return builder.create<mhlo::ConstOp>(loc, type, value.cast<ElementsAttr>());
   return nullptr;
 }

@@ -167,7 +166,7 @@ void ConstOp::build(OpBuilder& builder, OperationState& result,
   }

   // TODO: support other XLA specific types.
-  assert(type && "unsupported attribute type for building xla_hlo.constant");
+  assert(type && "unsupported attribute type for building mhlo.constant");
   result.types.push_back(type);
   result.addAttribute("value", value);
 }
@@ -387,7 +386,7 @@ static LogicalResult Verify(GetTupleElementOp op) {

 OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
   if (auto tupleOp =
-          dyn_cast_or_null<xla_hlo::TupleOp>(getOperand().getDefiningOp())) {
+          dyn_cast_or_null<mhlo::TupleOp>(getOperand().getDefiningOp())) {
     return tupleOp.getOperand(index().getLimitedValue());
   }

@@ -693,10 +692,8 @@ void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs,
 }

 OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
-  auto real_op =
-      dyn_cast_or_null<xla_hlo::RealOp>(getOperand(0).getDefiningOp());
-  auto imag_op =
-      dyn_cast_or_null<xla_hlo::ImagOp>(getOperand(1).getDefiningOp());
+  auto real_op = dyn_cast_or_null<mhlo::RealOp>(getOperand(0).getDefiningOp());
+  auto imag_op = dyn_cast_or_null<mhlo::ImagOp>(getOperand(1).getDefiningOp());
   if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) {
     return real_op.getOperand();
   }
@@ -727,7 +724,7 @@ void ImagOp::build(OpBuilder& builder, OperationState& state, Value val) {

 OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
   if (auto complex_op =
-          dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand().getDefiningOp())) {
+          dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
     return complex_op.getOperand(1);
   }

@@ -740,7 +737,7 @@ void RealOp::build(OpBuilder& builder, OperationState& state, Value val) {

 OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
   if (auto complex_op =
-          dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand().getDefiningOp())) {
+          dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
     return complex_op.getOperand(0);
   }

@@ -1148,7 +1145,7 @@ static LogicalResult Verify(MapOp op) {
 // RecvOp
 //===----------------------------------------------------------------------===//

-// Checks that the result type is of the form `tuple<any_type, xla_hlo::token>`
+// Checks that the result type is of the form `tuple<any_type, mhlo::token>`
 static LogicalResult Verify(RecvOp op) {
   auto result_ty = op.getResult().getType().cast<TupleType>();
   auto subtypes = result_ty.getTypes();
@@ -2020,7 +2017,7 @@ void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"

 //===----------------------------------------------------------------------===//
-// xla_hlo Dialect Interfaces
+// mhlo Dialect Interfaces
 //===----------------------------------------------------------------------===//

 namespace {
@@ -2032,7 +2029,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface {
                        BlockAndValueMapping& valueMapping) const final {
     return true;
   }
-  // Operations in xla_hlo dialect are always legal to inline since they are
+  // Operations in mhlo dialect are always legal to inline since they are
   // pure.
   bool isLegalToInline(Operation*, Region*, BlockAndValueMapping&) const final {
     return true;
@@ -2041,7 +2038,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface {
 } // end anonymous namespace

 //===----------------------------------------------------------------------===//
-// xla_hlo Dialect Constructor
+// mhlo Dialect Constructor
 //===----------------------------------------------------------------------===//

 XlaHloDialect::XlaHloDialect(MLIRContext* context)
@@ -2061,8 +2058,7 @@ Type XlaHloDialect::parseType(DialectAsmParser& parser) const {
   if (parser.parseKeyword(&data_type)) return Type();

   if (data_type == "token") return TokenType::get(getContext());
-  parser.emitError(parser.getNameLoc())
-      << "unknown xla_hlo type: " << data_type;
+  parser.emitError(parser.getNameLoc()) << "unknown mhlo type: " << data_type;
   return nullptr;
 }

@@ -2071,7 +2067,7 @@ void XlaHloDialect::printType(Type type, DialectAsmPrinter& os) const {
     os << "token";
     return;
   }
-  os << "<unknown xla_hlo type>";
+  os << "<unknown mhlo type>";
 }

 //===----------------------------------------------------------------------===//
@@ -2106,5 +2102,5 @@ LogicalResult deriveShapeFromFirstOperand(
   return success();
 }

-} // namespace xla_hlo
+} // namespace mhlo
 } // namespace mlir
@@ -30,7 +30,7 @@ namespace xla_chlo {
 namespace {

 // Converts binary ops that statically are determined to not broadcast directly
-// to the corresponding xla_hlo non-broadcasting op.
+// to the corresponding mhlo non-broadcasting op.
 template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
 struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
   using OpRewritePattern<ChloOpTy>::OpRewritePattern;
@@ -63,7 +63,7 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
 };

 // Converts a binary op with ranked broadcasting operands to explicitly
-// broadcast and invoke the corresponding xla_hlo non-broadcasting op.
+// broadcast and invoke the corresponding mhlo non-broadcasting op.
 // Note that dynamic broadcasting supported by this pattern is only valid for
 // "numpy" broadcasting semantics as defined here:
 // https://docs.scipy.org/doc/numpy/reference/ufuncs.html
@@ -136,7 +136,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
     // properly.
     auto lhs_broadcast_dimensions = llvm::to_vector<4>(
         llvm::seq<int64_t>(result_rank - lhs_type.getRank(), result_rank));
-    Value broadcasted_lhs = rewriter.create<xla_hlo::DynamicBroadcastInDimOp>(
+    Value broadcasted_lhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
         loc,
         RankedTensorType::get(result_type.getShape(),
                               lhs_type.getElementType()),
@@ -144,7 +144,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
         rewriter.getI64TensorAttr(lhs_broadcast_dimensions));
     auto rhs_broadcast_dimensions = llvm::to_vector<4>(
         llvm::seq<int64_t>(result_rank - rhs_type.getRank(), result_rank));
-    Value broadcasted_rhs = rewriter.create<xla_hlo::DynamicBroadcastInDimOp>(
+    Value broadcasted_rhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
         loc,
         RankedTensorType::get(result_type.getShape(),
                               rhs_type.getElementType()),
@@ -182,23 +182,21 @@ struct HloBinaryElementwiseAdaptor {
 };

 struct HloComplexAdaptor {
-  static xla_hlo::ComplexOp CreateOp(BroadcastComplexOp from_op,
-                                     Type result_type, Value broadcasted_lhs,
-                                     Value broadcasted_rhs,
-                                     OpBuilder &builder) {
-    return builder.create<xla_hlo::ComplexOp>(from_op.getLoc(), result_type,
-                                              broadcasted_lhs, broadcasted_rhs);
+  static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
+                                  Value broadcasted_lhs, Value broadcasted_rhs,
+                                  OpBuilder &builder) {
+    return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
+                                           broadcasted_lhs, broadcasted_rhs);
   }
 };

 struct HloCompareAdaptor {
-  static xla_hlo::CompareOp CreateOp(BroadcastCompareOp from_op,
-                                     Type result_type, Value broadcasted_lhs,
-                                     Value broadcasted_rhs,
-                                     OpBuilder &builder) {
-    return builder.create<xla_hlo::CompareOp>(from_op.getLoc(), result_type,
-                                              broadcasted_lhs, broadcasted_rhs,
-                                              from_op.comparison_direction());
+  static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
+                                  Value broadcasted_lhs, Value broadcasted_rhs,
+                                  OpBuilder &builder) {
+    return builder.create<mhlo::CompareOp>(from_op.getLoc(), result_type,
+                                           broadcasted_lhs, broadcasted_rhs,
+                                           from_op.comparison_direction());
   }
 };

@@ -214,28 +212,27 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
                       HloBinaryElementwiseAdaptor<ChloOp, HloOp>>(context, \
                                                                   patterns);

-  POPULATE_BCAST(BroadcastAddOp, xla_hlo::AddOp);
-  POPULATE_BCAST(BroadcastAndOp, xla_hlo::AndOp);
-  POPULATE_BCAST(BroadcastAtan2Op, xla_hlo::Atan2Op);
-  POPULATE_BCAST(BroadcastDivOp, xla_hlo::DivOp);
-  POPULATE_BCAST(BroadcastMaxOp, xla_hlo::MaxOp);
-  POPULATE_BCAST(BroadcastMinOp, xla_hlo::MinOp);
-  POPULATE_BCAST(BroadcastMulOp, xla_hlo::MulOp);
-  POPULATE_BCAST(BroadcastOrOp, xla_hlo::OrOp);
-  POPULATE_BCAST(BroadcastPowOp, xla_hlo::PowOp);
-  POPULATE_BCAST(BroadcastRemOp, xla_hlo::RemOp);
-  POPULATE_BCAST(BroadcastShiftLeftOp, xla_hlo::ShiftLeftOp);
-  POPULATE_BCAST(BroadcastShiftRightArithmeticOp,
-                 xla_hlo::ShiftRightArithmeticOp);
-  POPULATE_BCAST(BroadcastShiftRightLogicalOp, xla_hlo::ShiftRightLogicalOp);
-  POPULATE_BCAST(BroadcastSubOp, xla_hlo::SubOp);
-  POPULATE_BCAST(BroadcastXorOp, xla_hlo::XorOp);
+  POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
+  POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
+  POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
+  POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
+  POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
+  POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
+  POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
+  POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
+  POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
+  POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
+  POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
+  POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
+  POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
+  POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
+  POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);

   // Broadcasting ops requiring special construction.
-  PopulateForBinaryOp<BroadcastComplexOp, xla_hlo::ComplexOp,
-                      HloComplexAdaptor>(context, patterns);
-  PopulateForBinaryOp<BroadcastCompareOp, xla_hlo::CompareOp,
-                      HloCompareAdaptor>(context, patterns);
+  PopulateForBinaryOp<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>(
+      context, patterns);
+  PopulateForBinaryOp<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>(
+      context, patterns);
 }

 } // namespace xla_chlo
@@ -32,8 +32,8 @@ struct TestChloLegalizeToHloPass
     OwningRewritePatternList conversionPatterns;

     conversionTarget.addIllegalDialect<XlaHloClientDialect>();
-    // Consider the xla_hlo dialect legal for tests.
-    conversionTarget.addLegalDialect<xla_hlo::XlaHloDialect>();
+    // Consider the mhlo dialect legal for tests.
+    conversionTarget.addLegalDialect<mhlo::XlaHloDialect>();
     // The conversion uses helpers from the Standard dialect.
     conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
     conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
@@ -37,7 +37,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

 namespace mlir {
-namespace xla_hlo {
+namespace mhlo {
 namespace {

 template <typename T>
@@ -128,20 +128,20 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
             op->getLoc(), result.value(), results_shape.front(), &rewriter));
       }
     }
-    rewriter.create<xla_hlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
-                                                   buffer_args, op->getAttrs());
+    rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
+                                                buffer_args, op->getAttrs());
     rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
     return success();
   }
 };

 struct HloToLhloDynamicBroadcastInDimOpConverter
-    : public BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp> {
+    : public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
  public:
-  using BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp>::BaseOpConversion;
+  using BaseOpConversion<mhlo::DynamicBroadcastInDimOp>::BaseOpConversion;

   LogicalResult matchAndRewrite(
-      xla_hlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
+      mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
       ConversionPatternRewriter& rewriter) const final {
     auto loc = op.getLoc();
     Value resultBuffer = InsertDynamicAllocAndDealloc(
@@ -162,7 +162,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
   // and size of the target dimension if size-1 dimension expansion is
   // necessary.
   xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
-      xla_hlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
+      mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
     auto loc = op.getLoc();
     auto operand_type = operand.getType().cast<MemRefType>();
     auto operand_shape = operand_type.getShape();
@@ -220,12 +220,12 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
   }
 };

-struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {
+struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
  public:
-  using BaseOpConversion<xla_hlo::ReduceOp>::BaseOpConversion;
+  using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;

   LogicalResult matchAndRewrite(
-      xla_hlo::ReduceOp op, ArrayRef<Value> operands,
+      mhlo::ReduceOp op, ArrayRef<Value> operands,
       ConversionPatternRewriter& rewriter) const final {
     auto loc = op.getLoc();
     // TODO(b/137624192) Implement variadic reduce.
@@ -314,10 +314,10 @@ class HloToLhloTensorStoreOpConverter
 // "xla_lhlo.fusion"() ({
 //   %0 = tensor_load %arg1 : memref<2x2xf32>
 //   %1 = tensor_load %arg2 : memref<2x2xf32>
-//   %2 = "xla_hlo.add"(%0, %1) :
+//   %2 = "mhlo.add"(%0, %1) :
 //        (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
 //   %3 = tensor_load %arg0 : memref<2x2xf32>
-//   %4 = "xla_hlo.multiply"(%2, %3) :
+//   %4 = "mhlo.multiply"(%2, %3) :
 //        (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
 //   tensor_store %4, %arg3 : memref<2x2xf32>
 //   "xla_lhlo.terminator"() : () -> ()
@@ -344,8 +344,8 @@ class HloToLhloTensorStoreOpConverter
 // FuncOp signature conversion example:
 //
 // func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
-//   %0 = "xla_hlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) ->
-//   tensor<4xf32> %1 = "xla_hlo.add"(%arg0, %0) : (tensor<4xf32>,
+//   %0 = "mhlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) ->
+//   tensor<4xf32> %1 = "mhlo.add"(%arg0, %0) : (tensor<4xf32>,
 //   tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32>
 // }
 //
@@ -388,7 +388,7 @@ struct HloLegalizeToLhlo
     target.addIllegalOp<mlir::TensorStoreOp>();
     target.addLegalOp<ModuleTerminatorOp>();
     target.addLegalOp<TensorFromElementsOp>();
-    target.addIllegalDialect<xla_hlo::XlaHloDialect>();
+    target.addIllegalDialect<mhlo::XlaHloDialect>();

     BufferAssignmentTypeConverter converter;
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
@@ -442,38 +442,38 @@ void populateHLOToLHLOConversionPattern(
   // clang-format off
   patterns->insert<
       HloToLhloDynamicBroadcastInDimOpConverter,
-      HloToLhloOpConverter<xla_hlo::AbsOp>,
-      HloToLhloOpConverter<xla_hlo::AddOp>,
-      HloToLhloOpConverter<xla_hlo::AndOp>,
-      HloToLhloOpConverter<xla_hlo::BroadcastInDimOp>,
-      HloToLhloOpConverter<xla_hlo::CeilOp>,
-      HloToLhloOpConverter<xla_hlo::CompareOp>,
-      HloToLhloOpConverter<xla_hlo::ComplexOp>,
-      HloToLhloOpConverter<xla_hlo::ConstOp>,
-      HloToLhloOpConverter<xla_hlo::ConvOp>,
-      HloToLhloOpConverter<xla_hlo::ConvertOp>,
-      HloToLhloOpConverter<xla_hlo::CopyOp>,
-      HloToLhloOpConverter<xla_hlo::CosOp>,
-      HloToLhloOpConverter<xla_hlo::DivOp>,
-      HloToLhloOpConverter<xla_hlo::DotOp>,
-      HloToLhloOpConverter<xla_hlo::ExpOp>,
-      HloToLhloOpConverter<xla_hlo::GatherOp>,
-      HloToLhloOpConverter<xla_hlo::ImagOp>,
-      HloToLhloOpConverter<xla_hlo::IotaOp>,
-      HloToLhloOpConverter<xla_hlo::LogOp>,
-      HloToLhloOpConverter<xla_hlo::MaxOp>,
-      HloToLhloOpConverter<xla_hlo::MinOp>,
-      HloToLhloOpConverter<xla_hlo::MulOp>,
-      HloToLhloOpConverter<xla_hlo::NegOp>,
-      HloToLhloOpConverter<xla_hlo::RealOp>,
-      HloToLhloOpConverter<xla_hlo::RemOp>,
-      HloToLhloOpConverter<xla_hlo::RsqrtOp>,
-      HloToLhloOpConverter<xla_hlo::ReshapeOp>,
-      HloToLhloOpConverter<xla_hlo::SelectOp>,
-      HloToLhloOpConverter<xla_hlo::SignOp>,
-      HloToLhloOpConverter<xla_hlo::SqrtOp>,
-      HloToLhloOpConverter<xla_hlo::SubOp>,
-      HloToLhloOpConverter<xla_hlo::TanhOp>,
+      HloToLhloOpConverter<mhlo::AbsOp>,
+      HloToLhloOpConverter<mhlo::AddOp>,
+      HloToLhloOpConverter<mhlo::AndOp>,
+      HloToLhloOpConverter<mhlo::BroadcastInDimOp>,
+      HloToLhloOpConverter<mhlo::CeilOp>,
+      HloToLhloOpConverter<mhlo::CompareOp>,
+      HloToLhloOpConverter<mhlo::ComplexOp>,
+      HloToLhloOpConverter<mhlo::ConstOp>,
+      HloToLhloOpConverter<mhlo::ConvOp>,
+      HloToLhloOpConverter<mhlo::ConvertOp>,
+      HloToLhloOpConverter<mhlo::CopyOp>,
+      HloToLhloOpConverter<mhlo::CosOp>,
+      HloToLhloOpConverter<mhlo::DivOp>,
+      HloToLhloOpConverter<mhlo::DotOp>,
+      HloToLhloOpConverter<mhlo::ExpOp>,
+      HloToLhloOpConverter<mhlo::GatherOp>,
+      HloToLhloOpConverter<mhlo::ImagOp>,
+      HloToLhloOpConverter<mhlo::IotaOp>,
+      HloToLhloOpConverter<mhlo::LogOp>,
+      HloToLhloOpConverter<mhlo::MaxOp>,
+      HloToLhloOpConverter<mhlo::MinOp>,
+      HloToLhloOpConverter<mhlo::MulOp>,
+      HloToLhloOpConverter<mhlo::NegOp>,
+      HloToLhloOpConverter<mhlo::RealOp>,
+      HloToLhloOpConverter<mhlo::RemOp>,
+      HloToLhloOpConverter<mhlo::RsqrtOp>,
+      HloToLhloOpConverter<mhlo::ReshapeOp>,
+      HloToLhloOpConverter<mhlo::SelectOp>,
+      HloToLhloOpConverter<mhlo::SignOp>,
+      HloToLhloOpConverter<mhlo::SqrtOp>,
+      HloToLhloOpConverter<mhlo::SubOp>,
+      HloToLhloOpConverter<mhlo::TanhOp>,
       HloToLhloReduceOpConverter,
       HloToLhloTensorLoadOpConverter,
       HloToLhloTensorStoreOpConverter
@@ -489,5 +489,5 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
 static PassRegistration<HloLegalizeToLhlo> legalize_pass(
     "hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect");

-} // namespace xla_hlo
+} // namespace mhlo
 } // namespace mlir
@@ -35,7 +35,7 @@ limitations under the License.
 using mlir::PassRegistration;

 namespace mlir {
-namespace xla_hlo {
+namespace mhlo {
 namespace {
 struct LegalizeControlFlow
     : public mlir::PassWrapper<LegalizeControlFlow, FunctionPass> {
@@ -51,7 +51,7 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block,
                                  OpBuilder* builder) {
   for (auto& old_block : region->getBlocks()) {
     Block* block = mapper.lookup(&old_block);
-    auto return_op = dyn_cast<xla_hlo::ReturnOp>(block->getTerminator());
+    auto return_op = dyn_cast<mhlo::ReturnOp>(block->getTerminator());
     if (!return_op) continue;
     builder->setInsertionPointToEnd(block);
     builder->create<mlir::BranchOp>(loc, target_block, return_op.getOperands());
@@ -61,7 +61,7 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block,
   return success();
 }

-LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) {
+LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) {
   Operation* op_inst = if_op.getOperation();
   mlir::OpBuilder builder(if_op);
   auto orig_block = op_inst->getBlock();
@@ -106,13 +106,13 @@ LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) {
   return success();
 }

-LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
+LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
   // Converts an XLA while loop into control flow. This generates a set of MLIR
   // blocks and branches, along with inlining the regions provided by the XLA
   // while loop. The structure should be similar to below:
   //
   //   <prior operations>
-  //   %0 = "xla_hlo.while"(%arg0) {^cond(...){...}, ^body(...){...}}
+  //   %0 = "mhlo.while"(%arg0) {^cond(...){...}, ^body(...){...}}
   //   <post operations>
   auto* op_inst = while_op.getOperation();
   mlir::OpBuilder builder(while_op);
@@ -147,7 +147,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
   // extract_element and conditional branch. This changes the block below:
   //   ^cond(%0):
   //     <inlined conditional region>
-  //    "xla_hlo".return(%1)
+  //    "mhlo".return(%1)
   //
   //  Into:
   //   ^cond(%0):
@@ -156,14 +156,14 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
   //     cond_br %2, ^body(%0), ^tail(%0) // Branch.
   builder.setInsertionPointToStart(cond_block);

-  // Replace the xla_hlo::ReturnOp with a branch back to the condition block.
-  // This is required as the xla_hlo::ReturnOp is used to mark the end of a
+  // Replace the mhlo::ReturnOp with a branch back to the condition block.
+  // This is required as the mhlo::ReturnOp is used to mark the end of a
   // block for regions nested inside of a operations (MLIR ReturnOp cannot be
   // nested within an non-function region).
   for (auto& block : while_op.cond()) {
     auto new_block = mapper.lookup(&block);

-    auto return_op = dyn_cast<xla_hlo::ReturnOp>(new_block->getTerminator());
+    auto return_op = dyn_cast<mhlo::ReturnOp>(new_block->getTerminator());
     if (!return_op) continue;
     builder.setInsertionPointToEnd(new_block);

@@ -183,7 +183,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
   // conditional block. This changes the block below:
   //   ^body(%0):
   //     <inlined body block>
-  //    "xla_hlo".return(%1)
+  //    "mhlo".return(%1)
   //
   //  Into:
   //   ^body(%0):
@@ -191,8 +191,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
   //     br ^cond(%0) // Branch.
   for (auto& block : while_op.body()) {
     auto new_block = mapper.lookup(&block);
-    auto return_op =
-        dyn_cast<mlir::xla_hlo::ReturnOp>(new_block->getTerminator());
+    auto return_op = dyn_cast<mlir::mhlo::ReturnOp>(new_block->getTerminator());
     if (!return_op) continue;
     builder.setInsertionPointToEnd(new_block);
     builder.create<mlir::BranchOp>(loc, cond_block, return_op.getOperands());
@@ -224,14 +223,14 @@ void LegalizeControlFlow::runOnFunction() {
   }
 }
 } // namespace
-} // namespace xla_hlo
+} // namespace mhlo
 } // namespace mlir

 std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
-mlir::xla_hlo::createLegalizeControlFlowPass() {
+mlir::mhlo::createLegalizeControlFlowPass() {
   return std::make_unique<LegalizeControlFlow>();
 }

-static PassRegistration<mlir::xla_hlo::LegalizeControlFlow> legalize_cf_pass(
+static PassRegistration<mlir::mhlo::LegalizeControlFlow> legalize_cf_pass(
     "xla-legalize-control-flow",
     "Legalize from XLA control flow to MLIR control flow");
@@ -28,14 +28,14 @@ namespace mlir {
 namespace {
 #include "tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc"
 } // end anonymous namespace
-namespace xla_hlo {
+namespace mhlo {
 namespace {

-class CompareIConvert : public OpRewritePattern<xla_hlo::CompareOp> {
+class CompareIConvert : public OpRewritePattern<mhlo::CompareOp> {
  public:
   using OpRewritePattern::OpRewritePattern;

-  LogicalResult matchAndRewrite(xla_hlo::CompareOp op,
+  LogicalResult matchAndRewrite(mhlo::CompareOp op,
                                 PatternRewriter &rewriter) const override {
     auto lhs = op.lhs();
     auto rhs = op.rhs();
@@ -68,11 +68,11 @@ class CompareIConvert : public OpRewritePattern<xla_hlo::CompareOp> {
   }
 };

-class CompareFConvert : public OpRewritePattern<xla_hlo::CompareOp> {
+class CompareFConvert : public OpRewritePattern<mhlo::CompareOp> {
  public:
   using OpRewritePattern::OpRewritePattern;

-  LogicalResult matchAndRewrite(xla_hlo::CompareOp op,
+  LogicalResult matchAndRewrite(mhlo::CompareOp op,
                                 PatternRewriter &rewriter) const override {
     auto lhs = op.lhs();
     auto rhs = op.rhs();
@@ -109,11 +109,11 @@ class CompareFConvert : public OpRewritePattern<xla_hlo::CompareOp> {
 // convert the integer constant to iota result type. For complex types, the real
 // part is replaced with the generated constant and the imaginary part is
 // replaced with zero tensor.
-class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
+class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
  public:
   using OpRewritePattern::OpRewritePattern;

-  LogicalResult matchAndRewrite(xla_hlo::IotaOp op,
+  LogicalResult matchAndRewrite(mhlo::IotaOp op,
                                 PatternRewriter &rewriter) const override {
     auto output_type = op.getType().cast<ShapedType>();
     auto output_size = output_type.getNumElements();
@@ -168,8 +168,7 @@ class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
         loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0)));
     auto imag_zeroes =
         rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, zeroes);
-    rewriter.replaceOpWithNewOp<xla_hlo::ComplexOp>(op, iota_const,
-                                                    imag_zeroes);
+    rewriter.replaceOpWithNewOp<mhlo::ComplexOp>(op, iota_const, imag_zeroes);
     return success();
   }
 };
@@ -197,12 +196,12 @@ void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
 /// Perform the lowering to standard dialect.
 void LegalizeToStandard::runOnFunction() {
   OwningRewritePatternList patterns;
-  mlir::xla_hlo::PopulateXlaToStdPatterns(&patterns, &getContext());
+  mlir::mhlo::PopulateXlaToStdPatterns(&patterns, &getContext());
   applyPatternsAndFoldGreedily(getFunction(), patterns);
 }

 static PassRegistration<LegalizeToStandard> legalize_pass(
     "xla-legalize-to-std", "Legalize from XLA dialect to standard dialect");

-} // end namespace xla_hlo
+} // end namespace mhlo
 } // end namespace mlir
|
@ -84,14 +84,14 @@ Value TransposeReshape(Value arg, mlir::Location loc,
|
||||
transposed_shape.push_back(arg_shape[val]);
|
||||
}
|
||||
auto transpose_type = RankedTensorType::get(transposed_shape, element_type);
|
||||
auto transpose_result = rewriter->create<mlir::xla_hlo::TransposeOp>(
|
||||
auto transpose_result = rewriter->create<mlir::mhlo::TransposeOp>(
|
||||
loc, transpose_type, arg, transpose_permutation_attr);
|
||||
|
||||
// Return the final result.
|
||||
auto reshaped_type =
|
||||
RankedTensorType::get({left_size, right_size}, element_type);
|
||||
return rewriter->create<mlir::xla_hlo::ReshapeOp>(loc, reshaped_type,
|
||||
transpose_result);
|
||||
return rewriter->create<mlir::mhlo::ReshapeOp>(loc, reshaped_type,
|
||||
transpose_result);
|
||||
}
|
||||
|
||||
Value ProcessDotArg(Value arg, mlir::Location loc,
|
||||
@@ -125,8 +125,7 @@ Value ProcessDotArg(Value arg, mlir::Location loc,
   return TransposeReshape(arg, loc, contract_dims, outer_dims, shape, rewriter);
 }

-struct GeneralDotConvert
-    : public OpRewritePattern<mlir::xla_hlo::DotGeneralOp> {
+struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
   // Attempts to lower a General Dot operator to a standard Dot operator.
   // General dots include batching dimensions and can have collapsing
   // dimensions along any axis. Inserting correctly arrange transpose and
@@ -138,7 +137,7 @@ struct GeneralDotConvert
   explicit GeneralDotConvert(MLIRContext *context)
       : OpRewritePattern(context) {}

-  LogicalResult matchAndRewrite(mlir::xla_hlo::DotGeneralOp op,
+  LogicalResult matchAndRewrite(mlir::mhlo::DotGeneralOp op,
                                 PatternRewriter &rewriter) const override {
     auto dot_element_type = mlir::getElementTypeOrSelf(op);

@@ -162,11 +161,11 @@ struct GeneralDotConvert
     auto new_dot_type =
         RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);

-    auto new_dot_op = rewriter.create<mlir::xla_hlo::DotOp>(
+    auto new_dot_op = rewriter.create<mlir::mhlo::DotOp>(
         op.getLoc(), new_dot_type, lhs, rhs, *(op.precision_config()));

-    rewriter.replaceOpWithNewOp<mlir::xla_hlo::ReshapeOp>(op, op.getType(),
-                                                          new_dot_op);
+    rewriter.replaceOpWithNewOp<mlir::mhlo::ReshapeOp>(op, op.getType(),
+                                                       new_dot_op);
     return success();
   }
 };
@@ -176,15 +175,14 @@ struct LegalizeGeneralDot
   /// Lower all general dots that can be represented as a non-batched matmul.
   void runOnFunction() override {
     OwningRewritePatternList patterns;
-    mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns,
-                                                        &getContext());
+    mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext());
     applyPatternsAndFoldGreedily(getFunction(), patterns);
   }
 };

 } // namespace

-void mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(
+void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(
     OwningRewritePatternList *patterns, MLIRContext *ctx) {
   patterns->insert<GeneralDotConvert>(ctx);
 }
@@ -23,7 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

 namespace mlir {
-namespace xla_hlo {
+namespace mhlo {

 namespace {

@@ -86,5 +86,5 @@ void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
   patterns->insert<ClampWithBroadcastConvert>(context);
 }

-} // namespace xla_hlo
+} // namespace mhlo
 } // namespace mlir
@@ -23,7 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

 namespace mlir {
-namespace xla_hlo {
+namespace mhlo {

 namespace {

@@ -33,7 +33,7 @@ struct TestMaterializeBroadcastsPass
     ConversionTarget conversionTarget(getContext());
     OwningRewritePatternList conversionPatterns;

-    // Consider the xla_hlo dialect legal for tests.
+    // Consider the mhlo dialect legal for tests.
     conversionTarget.addLegalDialect<XlaHloDialect>();
     // The conversion uses helpers from the Standard dialect.
     conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
@@ -50,9 +50,9 @@ struct TestMaterializeBroadcastsPass

 } // namespace

-} // namespace xla_hlo
+} // namespace mhlo
 } // namespace mlir

-static mlir::PassRegistration<mlir::xla_hlo::TestMaterializeBroadcastsPass>
-    pass("test-xla-materialize-broadcasts",
-         "Test pass for materializing 'broadcast_dimensions' attributes");
+static mlir::PassRegistration<mlir::mhlo::TestMaterializeBroadcastsPass> pass(
+    "test-xla-materialize-broadcasts",
+    "Test pass for materializing 'broadcast_dimensions' attributes");
@@ -60,7 +60,7 @@ limitations under the License.
 // shape dialect once it is ready.

 namespace mlir {
-namespace xla_hlo {
+namespace mhlo {
 namespace {

 using llvm::EquivalenceClasses;
@@ -544,7 +544,7 @@ struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
     }

     FusionOp fusion =
-        b.create<xla_hlo::FusionOp>(fused_loc, output_types, inputs);
+        b.create<mhlo::FusionOp>(fused_loc, output_types, inputs);
     Region& region = fusion.fused_computation();
     region.push_back(new Block);
     Block& block = region.front();
@@ -552,7 +552,7 @@ struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
       op->moveBefore(&block, block.end());
     }
     b.setInsertionPoint(&block, block.end());
-    b.create<xla_hlo::ReturnOp>(fused_loc, outputs);
+    b.create<mhlo::ReturnOp>(fused_loc, outputs);

     for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) {
       Value output = std::get<0>(output_and_result);
@@ -572,8 +572,8 @@ std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusion() {
   return std::make_unique<XlaHloFusion>();
 }

-static PassRegistration<XlaHloFusion> xla_hlo_fusion_pass(
-    "xla-hlo-fusion", "fuse xla_hlo ops to kLoop/kInput fusion patterns.");
+static PassRegistration<XlaHloFusion> mhlo_fusion_pass(
+    "xla-hlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns.");

-} // namespace xla_hlo
+} // namespace mhlo
 } // namespace mlir
@@ -23,7 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

 namespace mlir {
-namespace xla_hlo {
+namespace mhlo {

 namespace {

@@ -81,5 +81,5 @@ std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() {
   return std::make_unique<SinkConstantsToControlFlow>();
 }

-} // namespace xla_hlo
+} // namespace mhlo
 } // namespace mlir
@@ -25,7 +25,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

 namespace mlir {
-namespace xla_hlo {
+namespace mhlo {

 namespace {

@@ -40,12 +40,12 @@ Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type,
   auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64));
   auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim});
   if (shape_value) {
-    return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
+    return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
         loc, result_type, value_1d, shape_value, dims);
   }
   assert(result_type.hasStaticShape());
-  return rewriter.create<xla_hlo::BroadcastInDimOp>(loc, result_type, value_1d,
-                                                    dims);
+  return rewriter.create<mhlo::BroadcastInDimOp>(loc, result_type, value_1d,
+                                                 dims);
 }

 // Calculate the shape value of operand, assuming it is a dynamic shape with
@@ -89,25 +89,25 @@ Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr,
   auto epsilon_tensor_attr =
       DenseElementsAttr::get(scalar_type, {epsilon_attr.cast<Attribute>()});
   Value epsilon =
-      rewriter.create<xla_hlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
+      rewriter.create<mhlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
   auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64));
   auto dims = DenseIntElementsAttr::get(dims_type, SmallVector<int64_t, 1>{});
   if (broadcast_to_type.hasStaticShape()) {
-    return rewriter.create<xla_hlo::BroadcastInDimOp>(
+    return rewriter.create<mhlo::BroadcastInDimOp>(
         op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims);
   }
   Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter);
-  return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
+  return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
       op->getLoc(), broadcast_to_type, epsilon, shape_value,
       /*broadcast_dims=*/dims);
 }

 class UnfuseBatchNormInferencePattern
-    : public OpRewritePattern<xla_hlo::BatchNormInferenceOp> {
+    : public OpRewritePattern<mhlo::BatchNormInferenceOp> {
  public:
-  using OpRewritePattern<xla_hlo::BatchNormInferenceOp>::OpRewritePattern;
+  using OpRewritePattern<mhlo::BatchNormInferenceOp>::OpRewritePattern;

-  LogicalResult matchAndRewrite(xla_hlo::BatchNormInferenceOp bn_op,
+  LogicalResult matchAndRewrite(mhlo::BatchNormInferenceOp bn_op,
                                 PatternRewriter& rewriter) const override {
     // Enforce type invariants.
     // Note that we deduce the actual element type from the variance,
@@ -132,9 +132,9 @@ class UnfuseBatchNormInferencePattern
     if (!epsilon) {
       return failure();
     }
-    Value stddev = rewriter.create<xla_hlo::AddOp>(bn_op.getLoc(),
-                                                   bn_op.variance(), epsilon);
-    stddev = rewriter.create<xla_hlo::SqrtOp>(bn_op.getLoc(), stddev);
+    Value stddev =
+        rewriter.create<mhlo::AddOp>(bn_op.getLoc(), bn_op.variance(), epsilon);
+    stddev = rewriter.create<mhlo::SqrtOp>(bn_op.getLoc(), stddev);

     // Broadcast all terms.
     Value shape_value;
@@ -156,14 +156,13 @@ class UnfuseBatchNormInferencePattern

     // Compute:
     //   scale * (input - mean) / stddev + offset
-    Value result = rewriter.create<xla_hlo::SubOp>(
-        bn_op.getLoc(), bn_op.operand(), broadcast_mean);
-    result = rewriter.create<xla_hlo::MulOp>(bn_op.getLoc(), result,
-                                             broadcast_scale);
-    result = rewriter.create<xla_hlo::DivOp>(bn_op.getLoc(), result,
-                                             broadcast_stddev);
-    rewriter.replaceOpWithNewOp<xla_hlo::AddOp>(bn_op, result,
-                                                broadcast_offset);
+    Value result = rewriter.create<mhlo::SubOp>(bn_op.getLoc(), bn_op.operand(),
+                                                broadcast_mean);
+    result =
+        rewriter.create<mhlo::MulOp>(bn_op.getLoc(), result, broadcast_scale);
+    result =
+        rewriter.create<mhlo::DivOp>(bn_op.getLoc(), result, broadcast_stddev);
+    rewriter.replaceOpWithNewOp<mhlo::AddOp>(bn_op, result, broadcast_offset);

     return success();
   }
@@ -180,5 +179,5 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext* context,
   patterns->insert<UnfuseBatchNormInferencePattern>(context);
 }

-} // namespace xla_hlo
+} // namespace mhlo
 } // namespace mlir
@@ -23,7 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

 namespace mlir {
-namespace xla_hlo {
+namespace mhlo {

 namespace {

@@ -38,9 +38,9 @@ struct TestUnfuseBatchNormPass

 } // namespace

-} // namespace xla_hlo
+} // namespace mhlo
 } // namespace mlir

-static mlir::PassRegistration<mlir::xla_hlo::TestUnfuseBatchNormPass> pass(
+static mlir::PassRegistration<mlir::mhlo::TestUnfuseBatchNormPass> pass(
     "test-xla-unfuse-batch-norm",
     "Test pass for materializing 'broadcast_dimensions' attributes");
@@ -182,7 +182,7 @@ struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
   using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern;

   // This code has been adapted from IREE's
-  // (https://github.com/google/iree/) xla_hlo -> linalg conversion.
+  // (https://github.com/google/iree/) mhlo -> linalg conversion.
   LogicalResult matchAndRewrite(
       xla_lhlo::ConvOp op, ArrayRef<Value> args,
       ConversionPatternRewriter& rewriter) const final {
@@ -348,14 +348,14 @@ class BroadcastConverter

 class HloBroadcastInDimConverter
     : public DataMovementOpConverter<HloBroadcastInDimConverter,
-                                     xla_hlo::BroadcastInDimOp, false> {
+                                     mhlo::BroadcastInDimOp, false> {
  public:
   using DataMovementOpConverter<HloBroadcastInDimConverter,
-                                xla_hlo::BroadcastInDimOp,
+                                mhlo::BroadcastInDimOp,
                                 false>::DataMovementOpConverter;

   static SmallVector<AffineMap, 2> getIndexingMaps(
-      xla_hlo::BroadcastInDimOp broadcastOp, Builder* b) {
+      mhlo::BroadcastInDimOp broadcastOp, Builder* b) {
     auto resultType = getXLAOpResultType<false>(broadcastOp);
     auto operandType =
         broadcastOp.operand().getType().template cast<ShapedType>();
@@ -845,7 +845,7 @@ struct HloLegalizeToLinalg
     target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();

     auto func = getFunction();
-    xla_hlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
+    mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
     if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
       signalPassFailure();
     }
@@ -863,40 +863,40 @@ static PassRegistration<LhloLegalizeToLinalg> legalize_lhlo_pass(
     "lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect");
 } // namespace xla_lhlo

-namespace xla_hlo {
+namespace mhlo {

 void populateHLOToLinalgConversionPattern(MLIRContext* context,
                                           OwningRewritePatternList* patterns) {
-  patterns->insert<BroadcastConverter<xla_hlo::BroadcastOp, false>,
+  patterns->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
                    HloBroadcastInDimConverter,
-                   PointwiseToLinalgConverter<xla_hlo::AbsOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::AddOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::AndOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::CeilOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::CompareOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::ComplexOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::ConvertOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::CopyOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::CosOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::DivOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::ExpOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::ImagOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::LogOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::MaxOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::MinOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::MulOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::NegOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::RealOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::RemOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::RsqrtOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::SelectOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::SinOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::SqrtOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::SubOp, false>,
-                   PointwiseToLinalgConverter<xla_hlo::TanhOp, false>,
-                   ReshapeOpConverter<xla_hlo::ReshapeOp, false>,
-                   ReverseConverter<xla_hlo::ReverseOp, false>,
-                   TransposeConverter<xla_hlo::TransposeOp, false>>(context);
+                   PointwiseToLinalgConverter<mhlo::AbsOp, false>,
+                   PointwiseToLinalgConverter<mhlo::AddOp, false>,
+                   PointwiseToLinalgConverter<mhlo::AndOp, false>,
+                   PointwiseToLinalgConverter<mhlo::CeilOp, false>,
+                   PointwiseToLinalgConverter<mhlo::CompareOp, false>,
+                   PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
+                   PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
+                   PointwiseToLinalgConverter<mhlo::CopyOp, false>,
+                   PointwiseToLinalgConverter<mhlo::CosOp, false>,
+                   PointwiseToLinalgConverter<mhlo::DivOp, false>,
+                   PointwiseToLinalgConverter<mhlo::ExpOp, false>,
+                   PointwiseToLinalgConverter<mhlo::ImagOp, false>,
+                   PointwiseToLinalgConverter<mhlo::LogOp, false>,
+                   PointwiseToLinalgConverter<mhlo::MaxOp, false>,
+                   PointwiseToLinalgConverter<mhlo::MinOp, false>,
+                   PointwiseToLinalgConverter<mhlo::MulOp, false>,
+                   PointwiseToLinalgConverter<mhlo::NegOp, false>,
+                   PointwiseToLinalgConverter<mhlo::RealOp, false>,
+                   PointwiseToLinalgConverter<mhlo::RemOp, false>,
+                   PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
+                   PointwiseToLinalgConverter<mhlo::SelectOp, false>,
+                   PointwiseToLinalgConverter<mhlo::SinOp, false>,
+                   PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
+                   PointwiseToLinalgConverter<mhlo::SubOp, false>,
+                   PointwiseToLinalgConverter<mhlo::TanhOp, false>,
+                   ReshapeOpConverter<mhlo::ReshapeOp, false>,
+                   ReverseConverter<mhlo::ReverseOp, false>,
+                   TransposeConverter<mhlo::TransposeOp, false>>(context);
 }

 std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
@@ -905,5 +905,5 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {

 static PassRegistration<HloLegalizeToLinalg> legalize_hlo_pass(
     "hlo-legalize-to-linalg", "Legalize from HLO dialect to Linalg dialect");
-} // namespace xla_hlo
+} // namespace mhlo
 } // namespace mlir
@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {

// TODO(frgossen): Make it variadic.
@ -69,7 +69,7 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex);
auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
operandTy.getElementType());
Value flatOperand = rewriter.create<xla_hlo::DynamicReshapeOp>(
Value flatOperand = rewriter.create<mhlo::DynamicReshapeOp>(
loc, flatTensorTy, operand, flatShapeAsDimTensor);

// Generate IR for the actual operation.
@ -80,7 +80,7 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
rewriter.getIndexType());
Value shapeAsExtentTensor =
rewriter.create<shape::ToExtentTensorOp>(loc, extentTensorTy, shape);
Value result = rewriter.create<xla_hlo::DynamicReshapeOp>(
Value result = rewriter.create<mhlo::DynamicReshapeOp>(
loc, operandTy, flatResult, shapeAsExtentTensor);
rewriter.replaceOp(op, result);

@ -184,5 +184,5 @@ static PassRegistration<TransformUnrankedHloPass> transform_unranked_hlo_pass(
"transform-unranked-hlo",
"Realize element-wise operations on ranked tensors where possible");

} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

@ -2,107 +2,107 @@

// CHECK-LABEL: add_fold
func @add_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
%1 = xla_hlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<[6, 8, 10, 12]>
%2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
%1 = mhlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
// CHECK: mhlo.constant dense<[6, 8, 10, 12]>
%2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}

// CHECK-LABEL: add_scalar_fold
func @add_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<1> : tensor<4xi64>
%1 = xla_hlo.constant dense<5> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<6>
%2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<1> : tensor<4xi64>
%1 = mhlo.constant dense<5> : tensor<4xi64>
// CHECK: mhlo.constant dense<6>
%2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}

// CHECK-LABEL: add_fold_float
func @add_fold_float() -> tensor<4xf64> {
%0 = xla_hlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64>
%1 = xla_hlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64>
// CHECK: xla_hlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
%2 = "xla_hlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
%0 = mhlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
%2 = "mhlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}

// CHECK-LABEL: sub_scalar_fold
func @sub_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<5> : tensor<4xi64>
%1 = xla_hlo.constant dense<1> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<4>
%2 = "xla_hlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<5> : tensor<4xi64>
%1 = mhlo.constant dense<1> : tensor<4xi64>
// CHECK: mhlo.constant dense<4>
%2 = "mhlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}

// CHECK-LABEL: multiply_scalar_fold
func @multiply_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<5> : tensor<4xi64>
%1 = xla_hlo.constant dense<3> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<15>
%2 = "xla_hlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<5> : tensor<4xi64>
%1 = mhlo.constant dense<3> : tensor<4xi64>
// CHECK: mhlo.constant dense<15>
%2 = "mhlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}

// CHECK-LABEL: divide_scalar_fold
func @divide_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
%1 = xla_hlo.constant dense<5> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<1>
%2 = "xla_hlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<7> : tensor<4xi64>
%1 = mhlo.constant dense<5> : tensor<4xi64>
// CHECK: mhlo.constant dense<1>
%2 = "mhlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}

// CHECK-LABEL: divide_fold_float
func @divide_fold_float() -> tensor<4xf64> {
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: xla_hlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]>
%2 = "xla_hlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
%0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]>
%2 = "mhlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}

// CHECK-LABEL: max_scalar_fold
func @max_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
%1 = xla_hlo.constant dense<5> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<7>
%2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<7> : tensor<4xi64>
%1 = mhlo.constant dense<5> : tensor<4xi64>
// CHECK: mhlo.constant dense<7>
%2 = "mhlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}

// CHECK-LABEL: max_fold_float
func @max_fold_float() -> tensor<4xf64> {
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: xla_hlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]>
%2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
%0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]>
%2 = "mhlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}

// CHECK-LABEL: min_scalar_fold
func @min_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
%1 = xla_hlo.constant dense<-5> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<-5>
%2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
%0 = mhlo.constant dense<7> : tensor<4xi64>
%1 = mhlo.constant dense<-5> : tensor<4xi64>
// CHECK: mhlo.constant dense<-5>
%2 = "mhlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}

// CHECK-LABEL: min_fold_float
func @min_fold_float() -> tensor<4xf64> {
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: xla_hlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]>
%2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
%0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]>
%2 = "mhlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}

// CHECK-LABEL: concatenate_noop
func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-SAME: [[ARG:%.+]]: tensor<4xi32>
%0 = "xla_hlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32>
%0 = "mhlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32>

// CHECK: return [[ARG]]
return %0 : tensor<4xi32>
@ -112,7 +112,7 @@ func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> tensor<4xi32> {
// CHECK-SAME: [[ARG0:%.+]]: tensor<4xi32>
// CHECK-SAME: [[ARG1:%.+]]: tensor<0xi32>
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32>
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32>

// CHECK: return [[ARG0]]
return %0 : tensor<4xi32>
@ -120,34 +120,34 @@ func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) ->

// CHECK-LABEL: concatenate_empty_bool
func @concatenate_empty_bool(%arg0: tensor<0xi1>, %arg1: tensor<0xi1>) -> tensor<0xi1> {
// CHECK: xla_hlo.constant
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>
// CHECK: mhlo.constant
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>

return %0 : tensor<0xi1>
}

// CHECK-LABEL: concatenate_empty_int
func @concatenate_empty_int(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<0xi32> {
// CHECK: xla_hlo.constant
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32>
// CHECK: mhlo.constant
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32>

return %0 : tensor<0xi32>
}

// CHECK-LABEL: concatenate_empty_float
func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
// CHECK: xla_hlo.constant
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
// CHECK: mhlo.constant
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>

return %0 : tensor<0xf32>
}

// CHECK-LABEL: concatenate_const_1D
func @concatenate_const_1D() -> tensor<4xi32> {
// CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[0, 1, 2, 3]>
%0 = xla_hlo.constant dense<[0, 1]> : tensor<2xi32>
%1 = xla_hlo.constant dense<[2, 3]> : tensor<2xi32>
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>
// CHECK: [[VAL:%.+]]= mhlo.constant dense<[0, 1, 2, 3]>
%0 = mhlo.constant dense<[0, 1]> : tensor<2xi32>
%1 = mhlo.constant dense<[2, 3]> : tensor<2xi32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>

// CHECK: return [[VAL]]
return %2 : tensor<4xi32>
@ -155,11 +155,11 @@ func @concatenate_const_1D() -> tensor<4xi32> {

// CHECK-LABEL: concatenate_const_1D_float
func @concatenate_const_1D_float() -> tensor<4xf32> {
// CHECK: [[VAL:%.+]] = xla_hlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
// CHECK: [[VAL:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>

%0 = xla_hlo.constant dense<[0.0, 1.0]> : tensor<2xf32>
%1 = xla_hlo.constant dense<[2.0, 3.0]> : tensor<2xf32>
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32>
%0 = mhlo.constant dense<[0.0, 1.0]> : tensor<2xf32>
%1 = mhlo.constant dense<[2.0, 3.0]> : tensor<2xf32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32>

// CHECK: return [[VAL]]
return %2 : tensor<4xf32>
@ -167,12 +167,12 @@ func @concatenate_const_1D_float() -> tensor<4xf32> {

// CHECK-LABEL: concatenate_const_2D_vertical
func @concatenate_const_2D_vertical() -> tensor<2x2xi32> {
// CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[
// CHECK: [[VAL:%.+]]= mhlo.constant dense<[
// CHECK-SAME: [0, 1], [2, 3]
// CHECK-SAME: ]>
%0 = xla_hlo.constant dense<[[0, 1]]> : tensor<1x2xi32>
%1 = xla_hlo.constant dense<[[2, 3]]> : tensor<1x2xi32>
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>
%0 = mhlo.constant dense<[[0, 1]]> : tensor<1x2xi32>
%1 = mhlo.constant dense<[[2, 3]]> : tensor<1x2xi32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>

// CHECK: return [[VAL]]
return %2 : tensor<2x2xi32>
@ -180,12 +180,12 @@ func @concatenate_const_2D_vertical() -> tensor<2x2xi32> {

// CHECK-LABEL: concatenate_const_2D_horizontal
func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
// CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[
// CHECK: [[VAL:%.+]]= mhlo.constant dense<[
// CHECK-SAME: [0, 2], [1, 3]
// CHECK-SAME: ]>
%0 = xla_hlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
%1 = xla_hlo.constant dense<[[2], [3]]> : tensor<2x1xi32>
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
%0 = mhlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
%1 = mhlo.constant dense<[[2], [3]]> : tensor<2x1xi32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>

// CHECK: return [[VAL]]
return %2 : tensor<2x2xi32>
@ -193,40 +193,40 @@ func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {

// CHECK-LABEL: dynamic_slice_variable_start
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
// CHECK: "xla_hlo.dynamic-slice"
%1 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
// CHECK: "mhlo.dynamic-slice"
%1 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
return %1 : tensor<1x4xi32>
}

// CHECK-LABEL: dynamic_slice_constant_start
func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
// CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0)
// CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0)
// CHECK-DAG-SAME: limit_indices = dense<3> : tensor<1xi64>
// CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>}
// CHECK: return %[[RESULT]] : tensor<2xi32>
%0 = xla_hlo.constant dense<1> : tensor<i64>
%1 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
%0 = mhlo.constant dense<1> : tensor<i64>
%1 = "mhlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
return %1 : tensor<2xi32>
}

// CHECK-LABEL: dynamic_slice_constant_start_dynamic_shape
func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<?x4xi32> {
// CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0)
// CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0)
// CHECK-DAG-SAME: limit_indices = dense<[2, 4]> : tensor<2xi64>
// CHECK-DAG-SAME: start_indices = dense<[1, 0]> : tensor<2xi64>
// CHECK-DAG-SAME: strides = dense<1> : tensor<2xi64>
// CHECK: return %[[RESULT]] : tensor<?x4xi32>
%0 = xla_hlo.constant dense<1> : tensor<i64>
%1 = xla_hlo.constant dense<0> : tensor<i64>
%2 = "xla_hlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<?x4xi32>
%0 = mhlo.constant dense<1> : tensor<i64>
%1 = mhlo.constant dense<0> : tensor<i64>
%2 = "mhlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<?x4xi32>
return %2 : tensor<?x4xi32>
}

// CHECK-LABEL: slice_2D_noop
// CHECK-SAME: [[ARG:%.+]]: tensor<2x2xi64>
func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> {
%0 = "xla_hlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>)
%0 = "mhlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>)

// CHECK-NEXT: return [[ARG]]
return %0 : tensor<2x2xi64>
@ -234,80 +234,80 @@ func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> {

// CHECK-LABEL: slice_1D_fold
func @slice_1D_fold() -> tensor<2xi64> {
%0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<[7, 9]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
%0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
// CHECK: mhlo.constant dense<[7, 9]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
return %1 : tensor<2xi64>
}

// CHECK-LABEL: slice_1D_fp
func @slice_1D_fp() -> tensor<2xf32> {
%0 = xla_hlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32>
// CHECK: xla_hlo.constant dense<[7.000000e+00, 9.000000e+00]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>)
%0 = mhlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32>
// CHECK: mhlo.constant dense<[7.000000e+00, 9.000000e+00]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>)
return %1 : tensor<2xf32>
}

// CHECK-LABEL: slice_1D_strided_fold
func @slice_1D_strided_fold() -> tensor<2xi64> {
%0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<[7, 10]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
%0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
// CHECK: mhlo.constant dense<[7, 10]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
return %1 : tensor<2xi64>
}

// CHECK-LABEL: slice_2D_fold
func @slice_2D_fold() -> tensor<2x2xi64> {
%0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: xla_hlo.constant dense<[
%0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: mhlo.constant dense<[
// CHECK-SAME: [6, 7],
// CHECK-SAME: [10, 11]
// CHECK-SAME: ]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>)
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>)
return %1 : tensor<2x2xi64>
}

// CHECK-LABEL: slice_2D_fold_horizontal
func @slice_2D_fold_horizontal() -> tensor<1x4xi64> {
%0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: xla_hlo.constant dense<[
%0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: mhlo.constant dense<[
// CHECK-SAME: [0, 1, 2, 3]
// CHECK-SAME: ]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>)
%1 = "mhlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>)
return %1 : tensor<1x4xi64>
}

// CHECK-LABEL: slice_2D_fold_vertical
func @slice_2D_fold_vertical() -> tensor<4x1xi64> {
%0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: xla_hlo.constant dense<[
%0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: mhlo.constant dense<[
// CHECK-SAME: [2], [6], [10], [14]
// CHECK-SAME: ]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>)
%1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>)
return %1 : tensor<4x1xi64>
}

// CHECK-LABEL: slice_concat_fold_first
func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
// CHECK: return %arg0
return %1 : tensor<1x5xf32>
}

// CHECK-LABEL: slice_concat_fold_second
func @slice_concat_fold_second(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
// CHECK: return %arg1
return %1 : tensor<1x5xf32>
}

// CHECK-LABEL: slice_concat_fold_second_with_slice
func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x4xf32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
// CHECK: [[SLICE:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x5xf32>) -> tensor<1x4xf32>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x4xf32>)
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
// CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x5xf32>) -> tensor<1x4xf32>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x4xf32>)

// CHECK: return [[SLICE]]
return %1 : tensor<1x4xf32>
@ -315,9 +315,9 @@ func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor<

// CHECK-LABEL: slice_concat_fold_middle
func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
// CHECK: [[SLICE:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<1x5xf32>)
%0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
// CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<1x5xf32>)

// CHECK: return [[SLICE]]
return %1 : tensor<1x5xf32>
@ -325,11 +325,11 @@ func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %

// CHECK-LABEL: slice_concat_fold_two
func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<2x5xf32> {
// CHECK: [[CONCAT:%.+]] = "xla_hlo.concatenate"(%arg1, %arg2) {dimension = 0 : i64}
%0 = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
// CHECK: [[CONCAT:%.+]] = "mhlo.concatenate"(%arg1, %arg2) {dimension = 0 : i64}
%0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>

// CHECK: [[SLICE:%.+]] = "xla_hlo.slice"([[CONCAT]]) {limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<2x5xf32>)
// CHECK: [[SLICE:%.+]] = "mhlo.slice"([[CONCAT]]) {limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
%1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<2x5xf32>)

// CHECK: return [[SLICE]]
return %1 : tensor<2x5xf32>
@ -338,72 +338,72 @@ func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg
// CHECK-LABEL: func @broadcast_in_dim_identity
func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
// CHECK: return %arg0
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
return %0 : tensor<2x3x4xf32>
}

// CHECK-LABEL: func @broadcast_in_dim_not_identity_because_it_actually_broadcasts
func @broadcast_in_dim_not_identity_because_it_actually_broadcasts(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> {
// CHECK: xla_hlo.broadcast_in_dim
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<2x2xf32>
// CHECK: mhlo.broadcast_in_dim
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}

// CHECK-LABEL: func @broadcast_in_dim_not_identity_permutation
func @broadcast_in_dim_not_identity_permutation(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: xla_hlo.broadcast_in_dim
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: mhlo.broadcast_in_dim
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}


// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic
func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> {
// CHECK: %[[RESULT:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32>
%0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32>
// CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32>
%0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32>
// CHECK: return %[[RESULT]] : tensor<5x4xf32>
return %0 : tensor<5x4xf32>
}

// CHECK-LABEL: @complex_expand_fold
func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex<f32>>)
%1 = "xla_hlo.real"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%2 = "xla_hlo.imag"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex<f32>>)
%1 = "mhlo.real"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%2 = "mhlo.imag"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
// CHECK: return %arg0, %arg1
return %1, %2 : tensor<4xf32>, tensor<4xf32>
}

// CHECK-LABEL: @complex_collapse_fold
func @complex_collapse_fold(%arg0: tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>> {
%0 = "xla_hlo.real"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%1 = "xla_hlo.imag"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%2 = "xla_hlo.complex"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
%0 = "mhlo.real"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%1 = "mhlo.imag"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%2 = "mhlo.complex"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
// CHECK: return %arg0
return %2 : tensor<4xcomplex<f32>>
}

// CHECK-LABEL: @dynamic_iota_is_static
func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> {
// CHECK: [[RESULT:%.*]] = "xla_hlo.iota"
// CHECK: [[RESULT:%.*]] = "mhlo.iota"
// CHECK: return [[RESULT]]
%0 = "xla_hlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<4xi32>
%0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}

// CHECK-LABEL: @iota_not_lowered_to_constant
func @iota_not_lowered_to_constant() -> tensor<4xi32> {
// CHECK: [[RESULT:%.*]] = "xla_hlo.iota"
// CHECK: [[RESULT:%.*]] = "mhlo.iota"
// CHECK: return [[RESULT]]
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
return %0 : tensor<4xi32>
}

// CHECK-LABEL: @unary_einsum
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
// CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: "xla_hlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"}
%0 = "xla_hlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
// CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: "mhlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"}
%0 = "mhlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}

@ -411,30 +411,30 @@ func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK: return [[ARG]]
%0 = "xla_hlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32>
%0 = "mhlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}

// CHECK-LABEL: func @dynamic_reshape_not_actually_dynamic
func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<2xindex>) -> tensor<4x1xf32> {
// CHECK: xla_hlo.reshape
%0 = "xla_hlo.dynamic_reshape"(%arg0, %shape) : (tensor<4xf32>, tensor<2xindex>) -> tensor<4x1xf32>
// CHECK: mhlo.reshape
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<4xf32>, tensor<2xindex>) -> tensor<4x1xf32>
return %0 : tensor<4x1xf32>
}

// CHECK-LABEL: do_not_dce_while_with_outfeed
func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK: xla_hlo.while
%0 = "xla_hlo.while"(%arg0) ( {
// CHECK: mhlo.while
%0 = "mhlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.create_token"() : () -> !xla_hlo.token
%1 = "mhlo.create_token"() : () -> !mhlo.token
// Side-effecting op outfeed present inside while.
%2 = "xla_hlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor<i64>, !xla_hlo.token) -> !xla_hlo.token
"xla_hlo.return"(%arg1) : (tensor<i64>) -> ()
%2 = "mhlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor<i64>, !mhlo.token) -> !mhlo.token
"mhlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>

return %arg0 : tensor<i64>
@ -442,15 +442,15 @@ func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {

// CHECK-LABEL: dce_while_without_side_effect
func @dce_while_without_side_effect(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-NOT: xla_hlo.while
%0 = "xla_hlo.while"(%arg0) ( {
// CHECK-NOT: mhlo.while
%0 = "mhlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.create_token"() : () -> !xla_hlo.token
"xla_hlo.return"(%arg1) : (tensor<i64>) -> ()
%1 = "mhlo.create_token"() : () -> !mhlo.token
"mhlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>

return %arg0 : tensor<i64>

@ -4,7 +4,7 @@
// representative op for detailed broadcast semantics.
// CHECK-LABEL: @addWithoutBroadcast
func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.add %arg0, %arg1
// CHECK: mhlo.add %arg0, %arg1
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -20,9 +20,9 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = xla_hlo.add %[[ARG0_B]], %[[ARG1_B]]
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[ARG0_B]], %[[ARG1_B]]
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xf32>
@ -41,9 +41,9 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-NEXT: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT:.+]] = "xla_hlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT:.+]] = "mhlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xcomplex<f32>>
@ -62,9 +62,9 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.+]] = "xla_hlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.+]] = "mhlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK: return %[[FINAL_RESULT]] : tensor<?x?xi1>
@ -76,7 +76,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// Verifies that broadcast_dimensions validity checks are valid.
// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions
func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK: xla_hlo.add
// CHECK: mhlo.add
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
@ -85,7 +85,7 @@ func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<
// Verifies that broadcast_dimensions validity checks are valid.
// CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions
func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<f32>) -> tensor<1x4xf32> {
// CHECK: xla_hlo.add
// CHECK: mhlo.add
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
@ -113,7 +113,7 @@ func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1:
// expansions. Tests below merely verify that the op has an expansion.
// CHECK-LABEL: @andWithoutBroadcast
func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: xla_hlo.and %arg0, %arg1
// CHECK: mhlo.and %arg0, %arg1
%0 = xla_chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
@ -121,7 +121,7 @@ func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
// -----
// CHECK-LABEL: @atan2WithoutBroadcast
func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.atan2 %arg0, %arg1
// CHECK: mhlo.atan2 %arg0, %arg1
%0 = xla_chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -129,7 +129,7 @@ func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso
// -----
// CHECK-LABEL: @compareWithoutBroadcast
func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> {
// CHECK: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
@ -137,7 +137,7 @@ func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// -----
// CHECK-LABEL: @complexWithoutBroadcast
func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex<f32>> {
// CHECK: "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
// CHECK: "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
return %0 : tensor<4xcomplex<f32>>
}
@ -145,7 +145,7 @@ func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// -----
// CHECK-LABEL: @divideWithoutBroadcast
func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.divide %arg0, %arg1
// CHECK: mhlo.divide %arg0, %arg1
%0 = xla_chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -153,7 +153,7 @@ func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tens
// -----
// CHECK-LABEL: @maximumWithoutBroadcast
func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.maximum %arg0, %arg1
// CHECK: mhlo.maximum %arg0, %arg1
%0 = xla_chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -161,7 +161,7 @@ func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// -----
// CHECK-LABEL: @minimumWithoutBroadcast
func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.minimum %arg0, %arg1
// CHECK: mhlo.minimum %arg0, %arg1
%0 = xla_chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -169,7 +169,7 @@ func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// -----
// CHECK-LABEL: @multiplyWithoutBroadcast
func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.multiply %arg0, %arg1
// CHECK: mhlo.multiply %arg0, %arg1
%0 = xla_chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -177,7 +177,7 @@ func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te
// -----
// CHECK-LABEL: @orWithoutBroadcast
func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: xla_hlo.or %arg0, %arg1
// CHECK: mhlo.or %arg0, %arg1
%0 = xla_chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
@ -185,7 +185,7 @@ func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi
// -----
// CHECK-LABEL: @powerWithoutBroadcast
func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.power %arg0, %arg1
// CHECK: mhlo.power %arg0, %arg1
%0 = xla_chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -193,7 +193,7 @@ func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso
// -----
// CHECK-LABEL: @remainderWithoutBroadcast
func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.remainder %arg0, %arg1
// CHECK: mhlo.remainder %arg0, %arg1
%0 = xla_chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -201,7 +201,7 @@ func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> t
// -----
// CHECK-LABEL: @shift_leftWithoutBroadcast
func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.shift_left %arg0, %arg1
// CHECK: mhlo.shift_left %arg0, %arg1
%0 = xla_chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -209,7 +209,7 @@ func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) ->
// -----
// CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast
func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1
// CHECK: mhlo.shift_right_arithmetic %arg0, %arg1
%0 = xla_chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -217,7 +217,7 @@ func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor
// -----
// CHECK-LABEL: @shift_right_logicalWithoutBroadcast
func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.shift_right_logical %arg0, %arg1
// CHECK: mhlo.shift_right_logical %arg0, %arg1
%0 = xla_chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -225,7 +225,7 @@ func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4x
// -----
// CHECK-LABEL: @subWithoutBroadcast
func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.subtract %arg0, %arg1
// CHECK: mhlo.subtract %arg0, %arg1
%0 = xla_chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -233,7 +233,7 @@ func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<
// -----
// CHECK-LABEL: @xorWithoutBroadcast
func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: xla_hlo.xor %arg0, %arg1
// CHECK: mhlo.xor %arg0, %arg1
%0 = xla_chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}

@ -3,7 +3,7 @@
// CHECK-LABEL: func @single_operand
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @single_operand(%arg: tensor<1x2xf32>) -> tensor<1x2xf32> {
%0 = "xla_hlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32>
%0 = "mhlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32>
// CHECK-NEXT: return [[ARG]]
return %0 : tensor<1x2xf32>
}

@ -5,7 +5,7 @@
// CHECK-LABEL: func @same_type
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @same_type(%arg: tensor<f32>) -> tensor<f32> {
%0 = "xla_hlo.convert"(%arg) : (tensor<f32>) -> tensor<f32>
%0 = "mhlo.convert"(%arg) : (tensor<f32>) -> tensor<f32>
// CHECK-NEXT: return [[ARG]]
return %0 : tensor<f32>
}
@ -15,8 +15,8 @@ func @same_type(%arg: tensor<f32>) -> tensor<f32> {
// CHECK-LABEL: func @int_widening
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @int_widening(%arg: tensor<i32>) -> tensor<i64> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i64>
%0 = "xla_hlo.convert"(%arg) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i64>
%0 = "mhlo.convert"(%arg) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<i64>
}
@ -26,8 +26,8 @@ func @int_widening(%arg: tensor<i32>) -> tensor<i64> {
// CHECK-LABEL: func @int_narrowing
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @int_narrowing(%arg: tensor<i32>) -> tensor<i16> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i16>
%0 = "xla_hlo.convert"(%arg) : (tensor<i32>) -> tensor<i16>
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i16>
%0 = "mhlo.convert"(%arg) : (tensor<i32>) -> tensor<i16>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<i16>
}
@ -37,8 +37,8 @@ func @int_narrowing(%arg: tensor<i32>) -> tensor<i16> {
// CHECK-LABEL: func @float_int
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @float_int(%arg: tensor<f32>) -> tensor<i32> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<f32>) -> tensor<i32>
%0 = "xla_hlo.convert"(%arg) : (tensor<f32>) -> tensor<i32>
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<f32>) -> tensor<i32>
%0 = "mhlo.convert"(%arg) : (tensor<f32>) -> tensor<i32>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<i32>
}
@ -48,8 +48,8 @@ func @float_int(%arg: tensor<f32>) -> tensor<i32> {
// CHECK-LABEL: func @int_float
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @int_float(%arg: tensor<i32>) -> tensor<f32> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<f32>
%0 = "xla_hlo.convert"(%arg) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<f32>
%0 = "mhlo.convert"(%arg) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<f32>
}
@ -59,8 +59,8 @@ func @int_float(%arg: tensor<i32>) -> tensor<f32> {
// CHECK-LABEL: func @high_rank_tensor
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @high_rank_tensor(%arg: tensor<2x3xi32>) -> tensor<2x3xf32> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<2x3xi32>) -> tensor<2x3xf32>
%0 = "xla_hlo.convert"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xf32>
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<2x3xi32>) -> tensor<2x3xf32>
%0 = "mhlo.convert"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xf32>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<2x3xf32>
}
@ -70,9 +70,9 @@ func @high_rank_tensor(%arg: tensor<2x3xi32>) -> tensor<2x3xf32> {

// CHECK-LABEL: func @const_same_type
func @const_same_type() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
%cst = xla_hlo.constant dense<42> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<i32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
%cst = mhlo.constant dense<42> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}
@ -81,9 +81,9 @@ func @const_same_type() -> tensor<i32> {

// CHECK-LABEL: func @const_float_int
func @const_float_int() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
%cst = xla_hlo.constant dense<42.0> : tensor<f32>
%0 = "xla_hlo.convert"(%cst) : (tensor<f32>) -> tensor<i32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
%cst = mhlo.constant dense<42.0> : tensor<f32>
%0 = "mhlo.convert"(%cst) : (tensor<f32>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}
@ -92,9 +92,9 @@ func @const_float_int() -> tensor<i32> {

// CHECK-LABEL: func @const_int_float
func @const_int_float() -> tensor<f32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.{{0*}}e+00> : tensor<f32>
%cst = xla_hlo.constant dense<4> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.{{0*}}e+00> : tensor<f32>
%cst = mhlo.constant dense<4> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<f32>
}
@ -103,9 +103,9 @@ func @const_int_float() -> tensor<f32> {

// CHECK-LABEL: func @const_negative_int_float
|
||||
func @const_negative_int_float() -> tensor<f32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<-4.{{0*}}e+00> : tensor<f32>
|
||||
%cst = xla_hlo.constant dense<-4> : tensor<i32>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-4.{{0*}}e+00> : tensor<f32>
|
||||
%cst = mhlo.constant dense<-4> : tensor<i32>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
@ -114,9 +114,9 @@ func @const_negative_int_float() -> tensor<f32> {
|
||||
|
||||
// CHECK-LABEL: func @const_int_bf16
|
||||
func @const_int_bf16() -> tensor<bf16> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.{{0*}}e+00> : tensor<bf16>
|
||||
%cst = xla_hlo.constant dense<4> : tensor<i32>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<bf16>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.{{0*}}e+00> : tensor<bf16>
|
||||
%cst = mhlo.constant dense<4> : tensor<i32>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<bf16>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<bf16>
|
||||
}
|
||||
@ -125,9 +125,9 @@ func @const_int_bf16() -> tensor<bf16> {
|
||||
|
||||
// CHECK-LABEL: func @const_bf16_int
|
||||
func @const_bf16_int() -> tensor<i16> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i16>
|
||||
%cst = xla_hlo.constant dense<42.0> : tensor<bf16>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<i16>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i16>
|
||||
%cst = mhlo.constant dense<42.0> : tensor<bf16>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<bf16>) -> tensor<i16>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<i16>
|
||||
}
|
||||
@ -136,9 +136,9 @@ func @const_bf16_int() -> tensor<i16> {
|
||||
|
||||
// CHECK-LABEL: func @const_int_narrowing
|
||||
func @const_int_narrowing() -> tensor<i32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
|
||||
%cst = xla_hlo.constant dense<42> : tensor<i64>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<i64>) -> tensor<i32>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
|
||||
%cst = mhlo.constant dense<42> : tensor<i64>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<i64>) -> tensor<i32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
@ -147,9 +147,9 @@ func @const_int_narrowing() -> tensor<i32> {
|
||||
|
||||
// CHECK-LABEL: func @const_int_widening
|
||||
func @const_int_widening() -> tensor<i64> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i64>
|
||||
%cst = xla_hlo.constant dense<42> : tensor<i32>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
|
||||
%cst = mhlo.constant dense<42> : tensor<i32>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<i64>
|
||||
}
|
||||
@ -158,9 +158,9 @@ func @const_int_widening() -> tensor<i64> {
|
||||
|
||||
// CHECK-LABEL: func @const_negative_int_widening
|
||||
func @const_negative_int_widening() -> tensor<i64> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<-42> : tensor<i64>
|
||||
%cst = xla_hlo.constant dense<-42> : tensor<i32>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor<i64>
|
||||
%cst = mhlo.constant dense<-42> : tensor<i32>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<i64>
|
||||
}
|
||||
@ -169,9 +169,9 @@ func @const_negative_int_widening() -> tensor<i64> {
|
||||
|
||||
// CHECK-LABEL: func @const_float_narrowing
|
||||
func @const_float_narrowing() -> tensor<f32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor<f32>
|
||||
%cst = xla_hlo.constant dense<4.2> : tensor<f64>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<f64>) -> tensor<f32>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+00> : tensor<f32>
|
||||
%cst = mhlo.constant dense<4.2> : tensor<f64>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<f64>) -> tensor<f32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
@ -180,9 +180,9 @@ func @const_float_narrowing() -> tensor<f32> {
|
||||
|
||||
// CHECK-LABEL: func @const_f32_bf16
|
||||
func @const_f32_bf16() -> tensor<bf16> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+01> : tensor<bf16>
|
||||
%cst = xla_hlo.constant dense<42.0> : tensor<f32>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<f32>) -> tensor<bf16>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+01> : tensor<bf16>
|
||||
%cst = mhlo.constant dense<42.0> : tensor<f32>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<f32>) -> tensor<bf16>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<bf16>
|
||||
}
|
||||
@ -191,9 +191,9 @@ func @const_f32_bf16() -> tensor<bf16> {
|
||||
|
||||
// CHECK-LABEL: func @const_bf16_f64
|
||||
func @const_bf16_f64() -> tensor<f64> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.187500e+00> : tensor<f64>
|
||||
%cst = xla_hlo.constant dense<4.2> : tensor<bf16>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.187500e+00> : tensor<f64>
|
||||
%cst = mhlo.constant dense<4.2> : tensor<bf16>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<f64>
|
||||
}
|
||||
@ -202,9 +202,9 @@ func @const_bf16_f64() -> tensor<f64> {
|
||||
|
||||
// CHECK-LABEL: func @const_bf16_int
|
||||
func @const_bf16_int() -> tensor<i64> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i64>
|
||||
%cst = xla_hlo.constant dense<42.0> : tensor<bf16>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<i64>
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
|
||||
%cst = mhlo.constant dense<42.0> : tensor<bf16>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<bf16>) -> tensor<i64>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<i64>
|
||||
}
|
||||
@ -214,11 +214,11 @@ func @const_bf16_int() -> tensor<i64> {
|
||||
|
||||
// CHECK-LABEL: func @const_high_rank_tensor
|
||||
func @const_high_rank_tensor() -> tensor<2x3xi32> {
|
||||
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[
|
||||
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[
|
||||
// CHECK-SAME: [1, 2, 3], [4, 5, 6]
|
||||
// CHECK-SAME: ]> : tensor<2x3xi32>
|
||||
%cst = xla_hlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
|
||||
%0 = "xla_hlo.convert"(%cst) : (tensor<2x3xf32>) -> tensor<2x3xi32>
|
||||
%cst = mhlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
|
||||
%0 = "mhlo.convert"(%cst) : (tensor<2x3xf32>) -> tensor<2x3xi32>
|
||||
// CHECK-NEXT: return [[CST]]
|
||||
return %0 : tensor<2x3xi32>
|
||||
}
@ -4,7 +4,7 @@
// BOTH-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.exponential"(%tensor_operand)
  %tensor_result = "mhlo.exponential"(%tensor_operand)
      {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
@ -28,11 +28,11 @@ func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {

// BOTH-LABEL: func @func_op_long
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
  %1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
  %2 = xla_hlo.add %arg0, %1 : tensor<4xf32>
  %3 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32>
  %4 = xla_hlo.subtract %arg1, %3 : tensor<4xf32>
  %5 = xla_hlo.multiply %2, %4 : tensor<4xf32>
  %1 = mhlo.maximum %arg0, %arg1 : tensor<4xf32>
  %2 = mhlo.add %arg0, %1 : tensor<4xf32>
  %3 = mhlo.minimum %arg0, %arg1 : tensor<4xf32>
  %4 = mhlo.subtract %arg1, %3 : tensor<4xf32>
  %5 = mhlo.multiply %2, %4 : tensor<4xf32>
  return %5 : tensor<4xf32>
}
// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
@ -65,12 +65,12 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
  // BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
  %tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
  %tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
  %sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2)
  %sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
  // BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
  %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
  %tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier)
  %tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
  // BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
@ -86,7 +86,7 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
// BOTH-LABEL: func @copy
func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.copy"(%tensor_operand)
  %tensor_result = "mhlo.copy"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -98,7 +98,7 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @exp
func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.exponential"(%tensor_operand)
  %tensor_result = "mhlo.exponential"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -110,7 +110,7 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @log
func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.log"(%tensor_operand)
  %tensor_result = "mhlo.log"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -125,7 +125,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
  %tensor_pred = tensor_load %pred : memref<2x2xi1>
  %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
  %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
  %tensor_result = "xla_hlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
  %tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
      : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -138,7 +138,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) {
  %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
  %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
  %tensor_result = "xla_hlo.compare"(%tensor_lhs, %tensor_rhs)
  %tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs)
      {comparison_direction = "EQ"}
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
  // BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
@ -151,7 +151,7 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x
// BOTH-LABEL: func @broadcast
func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
  %tensor_operand = tensor_load %operand : memref<5xf32>
  %tensor_result = "xla_hlo.broadcast_in_dim"(%tensor_operand)
  %tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand)
      {broadcast_dimensions = dense<1> : tensor<1xi64>}
      : (tensor<5xf32>) -> tensor<10x5xf32>
  // BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
@ -170,7 +170,7 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
  // BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
  %tensor_operand = tensor_load %operand : memref<?x?xf32>
  %shape = call @external_func() : () -> tensor<3xi64>
  %tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
  %tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
    broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
  } : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
  // BOTH: %[[SHAPE:.*]] = call @external_func()
@ -226,7 +226,7 @@ func @complex(%real: memref<2x2xf32>,
              %result: memref<2x2xcomplex<f32>>) {
  %tensor_real = tensor_load %real : memref<2x2xf32>
  %tensor_imag = tensor_load %imag : memref<2x2xf32>
  %tensor_result = "xla_hlo.complex"(%tensor_real, %tensor_imag)
  %tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
  // BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>>
@ -238,7 +238,7 @@ func @complex(%real: memref<2x2xf32>,
// BOTH-LABEL: func @real
func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
  %tensor_result = "xla_hlo.real"(%tensor_operand)
  %tensor_result = "mhlo.real"(%tensor_operand)
      : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -250,7 +250,7 @@ func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @imag
func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
  %tensor_result = "xla_hlo.imag"(%tensor_operand)
  %tensor_result = "mhlo.imag"(%tensor_operand)
      : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -261,7 +261,7 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {

// BOTH-LABEL: func @iota
func @iota(%result: memref<10xi32>) {
  %tensor_result = "xla_hlo.iota"()
  %tensor_result = "mhlo.iota"()
      {iota_dimension = 0 : i64} : () -> tensor<10xi32>
  // BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
  tensor_store %tensor_result, %result : memref<10xi32>
@ -273,7 +273,7 @@ func @iota(%result: memref<10xi32>) {
// BOTH-LABEL: func @abs
func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.abs"(%tensor_operand)
  %tensor_result = "mhlo.abs"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -285,7 +285,7 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @ceil
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.ceil"(%tensor_operand)
  %tensor_result = "mhlo.ceil"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -297,7 +297,7 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @convert
func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.convert"(%tensor_operand)
  %tensor_result = "mhlo.convert"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
  // BOTH-NOT: tensor_store
@ -310,7 +310,7 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @cos
func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.cosine"(%tensor_operand)
  %tensor_result = "mhlo.cosine"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -322,7 +322,7 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @neg
func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.negate"(%tensor_operand)
  %tensor_result = "mhlo.negate"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -334,7 +334,7 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @rsqrt
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.rsqrt"(%tensor_operand)
  %tensor_result = "mhlo.rsqrt"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -346,7 +346,7 @@ func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @sign
func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.sign"(%tensor_operand)
  %tensor_result = "mhlo.sign"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -358,7 +358,7 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @sqrt
func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.sqrt"(%tensor_operand)
  %tensor_result = "mhlo.sqrt"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -370,7 +370,7 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @tanh
func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.tanh"(%tensor_operand)
  %tensor_result = "mhlo.tanh"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -383,7 +383,7 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
  %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
  %tensor_result = "xla_hlo.remainder"(%tensor_lhs, %tensor_rhs)
  %tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -395,7 +395,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
// Dynamic shape binary element-wise operation.
// BOTH-LABEL: func @add_dyn
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
  %result = "xla_hlo.add"(%lhs, %rhs)
  %result = "mhlo.add"(%lhs, %rhs)
      : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  // BOTH: %[[C0:.*]] = constant 0 : index
  // BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
@ -420,7 +420,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
// Dynamic shape unary element-wise operation.
// BOTH-LABEL: func @tanh_dyn
func @tanh_dyn(%arg0: tensor<?x?xf32>) {
  %result = "xla_hlo.tanh"(%arg0)
  %result = "mhlo.tanh"(%arg0)
      : (tensor<?x?xf32>) -> tensor<?x?xf32>
  // BOTH: %[[C0:.*]] = constant 0 : index
  // BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
@ -448,7 +448,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
  // ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
  // BOTH-NEXT: %[[ALLOC:.*]] = alloc
  // BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
  %dot = "xla_hlo.dot"(%arg0, %arg0)
  %dot = "mhlo.dot"(%arg0, %arg0)
      : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
  // PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]])
  // ESC: return %[[ALLOC]]
@ -466,7 +466,7 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor
  // BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
  // BOTH-SAME: rhs_dilation = dense<[1, 2]>
  // BOTH-SAME: window_strides = dense<[2, 1]>
  %out = "xla_hlo.convolution"(%filter, %input) {
  %out = "mhlo.convolution"(%filter, %input) {
    batch_group_count = 1 : i64,
    dimension_numbers = {
      input_batch_dimension = 0 : i64,

@ -10,7 +10,7 @@ func @float_add(%lhs: tensor<2x2xf32>,
  // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32
  // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = addf %[[ARG0]], %[[ARG1]]
  // CHECK: linalg.yield %[[RESULT]]
  %0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xf32>,
  %0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xf32>,
      tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}
@ -22,7 +22,7 @@ func @integer_add(%lhs: tensor<2x2xi32>,
                  %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
  // CHECK: linalg.generic
  // CHECK: addi
  %0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xi32>,
  %0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xi32>,
      tensor<2x2xi32>) -> tensor<2x2xi32>
  return %0 : tensor<2x2xi32>
}
@ -34,7 +34,7 @@ func @float_mul(%lhs: tensor<2x2xf32>,
                %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: mulf
  %0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>,
  %0 = "mhlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>,
      tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}
@ -46,7 +46,7 @@ func @integer_mul(%lhs: tensor<2x2xi32>,
                  %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
  // CHECK: linalg.generic
  // CHECK: muli
  %0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>,
  %0 = "mhlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>,
      tensor<2x2xi32>) -> tensor<2x2xi32>
  return %0 : tensor<2x2xi32>
}
@ -58,7 +58,7 @@ func @float_remainder(%lhs: tensor<2x2xf32>,
                      %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: remf
  %0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>,
  %0 = "mhlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>,
      tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}
@ -70,7 +70,7 @@ func @integer_remainder(%lhs: tensor<2x2xi32>,
                        %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
  // CHECK: linalg.generic
  // CHECK: remi_signed
  %0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>,
  %0 = "mhlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>,
      tensor<2x2xi32>) -> tensor<2x2xi32>
  return %0 : tensor<2x2xi32>
}
@ -79,7 +79,7 @@ func @integer_remainder(%lhs: tensor<2x2xi32>,

// CHECK-LABEL: func @float_rsqrt
func @float_rsqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %tensor_result = "xla_hlo.rsqrt"(%operand)
  %tensor_result = "mhlo.rsqrt"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: linalg.generic
  // CHECK: rsqrt
@ -93,7 +93,7 @@ func @float_sub(%lhs: tensor<2x2xf32>,
                %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: subf
  %0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>,
  %0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>,
      tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}
@ -105,7 +105,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
                  %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
  // CHECK: linalg.generic
  // CHECK: subi
  %0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>,
  %0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>,
      tensor<2x2xi32>) -> tensor<2x2xi32>
  return %0 : tensor<2x2xi32>
}
@ -116,7 +116,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: absf
  %0 = "xla_hlo.abs"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %0 = "mhlo.abs"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

@ -126,7 +126,7 @@ func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: exp
  %0 = "xla_hlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %0 = "mhlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

@ -136,7 +136,7 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: log
  %0 = "xla_hlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %0 = "mhlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

@ -146,7 +146,7 @@ func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: ceilf
  %0 = "xla_hlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %0 = "mhlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

@ -156,7 +156,7 @@ func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: negf
  %0 = "xla_hlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %0 = "mhlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

@ -166,7 +166,7 @@ func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: tanh
  %0 = "xla_hlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %0 = "mhlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

@ -177,7 +177,7 @@ func @integer_and(%lhs: tensor<2x2xi32>,
                  %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
  // CHECK: linalg.generic
  // CHECK: and
  %0 = "xla_hlo.and"(%lhs, %rhs) : (tensor<2x2xi32>,
  %0 = "mhlo.and"(%lhs, %rhs) : (tensor<2x2xi32>,
      tensor<2x2xi32>) -> tensor<2x2xi32>
  return %0 : tensor<2x2xi32>
}
@ -187,7 +187,7 @@ func @integer_and(%lhs: tensor<2x2xi32>,
// CHECK-LABEL: func @float_cmp
func @float_cmp(%lhs: tensor<2x2xf32>,
                %rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
  %0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"}
  %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"}
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
  return %0 : tensor<2x2xi1>
}
@ -201,7 +201,7 @@ func @float_cmp(%lhs: tensor<2x2xf32>,
// CHECK-LABEL: func @int_cmp
func @int_cmp(%lhs: tensor<2x2xi32>,
              %rhs: tensor<2x2xi32>) -> tensor<2x2xi1> {
  %0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "LT"}
  %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "LT"}
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> (tensor<2x2xi1>)
  return %0 : tensor<2x2xi1>
}
@ -216,7 +216,7 @@ func @int_cmp(%lhs: tensor<2x2xi32>,
func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: cos
  %0 = "xla_hlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %0 = "mhlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

@ -226,7 +226,7 @@ func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: sin
  %0 = "xla_hlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %0 = "mhlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

@ -235,7 +235,7 @@ func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK-LABEL: func @copy
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
  %0 = "xla_hlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>)
  %0 = "mhlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>)
  return %0 : tensor<2x4x8xf32>
}
// CHECK: return [[ARG]] : tensor<2x4x8xf32>
@ -245,7 +245,7 @@ func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
// CHECK-LABEL: func @select
func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
             %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %0 = "xla_hlo.select"(%pred, %lhs, %rhs)
  %0 = "mhlo.select"(%pred, %lhs, %rhs)
      : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
  return %0 : tensor<2x2xf32>
}
@ -260,7 +260,7 @@ func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_scalar
func @broadcast_scalar(%arg: tensor<f32>) -> tensor<4x2x1xf32> {
  %0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<f32>) -> tensor<4x2x1xf32>
  %0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<f32>) -> tensor<4x2x1xf32>
  return %0: tensor<4x2x1xf32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
@ -273,7 +273,7 @@ func @broadcast_scalar(%arg: tensor<f32>) -> tensor<4x2x1xf32> {
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
// CHECK-LABEL: func @broadcast
func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
  %0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32>
  %0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32>
  return %0: tensor<4x2x1x4x?x16xf32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
@ -286,7 +286,7 @@ func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func @broadcast_in_dim
func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> {
  %0 = "xla_hlo.broadcast_in_dim"(%operand)
  %0 = "mhlo.broadcast_in_dim"(%operand)
      {broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>}
      : (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32>
  return %0 : tensor<7x10x6x4x5xf32>
@ -302,7 +302,7 @@ func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> {
// CHECK-LABEL: func @broadcast_in_dim_with_one_to_one
func @broadcast_in_dim_with_one_to_one(
    %operand: tensor<1xf32>) -> tensor<1x5xf32> {
  %0 = "xla_hlo.broadcast_in_dim"(%operand)
  %0 = "mhlo.broadcast_in_dim"(%operand)
      {broadcast_dimensions = dense<[0]> : tensor<1xi64>}
      : (tensor<1xf32>) -> tensor<1x5xf32>
  return %0 : tensor<1x5xf32>
@ -317,7 +317,7 @@ func @broadcast_in_dim_with_one_to_one(
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_scalar
func @broadcast_scalar(%operand: tensor<f32>) -> tensor<7x10x6xf32> {
  %0 = "xla_hlo.broadcast_in_dim"(%operand)
  %0 = "mhlo.broadcast_in_dim"(%operand)
      {broadcast_dimensions = dense<[]> : tensor<0xi64>}
      : (tensor<f32>) -> tensor<7x10x6xf32>
  return %0 : tensor<7x10x6xf32>
@ -332,7 +332,7 @@ func @broadcast_scalar(%operand: tensor<f32>) -> tensor<7x10x6xf32> {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func @transpose
func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
  %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}
  %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}
      : (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
  return %0 : tensor<3x2x5x9xi32>
}
@ -344,7 +344,7 @@ func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-LABEL: func @reshape_3D_2D
func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
  %0 = "xla_hlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>
  %0 = "mhlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>
  return %0 : tensor<12x42xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
@ -355,7 +355,7 @@ func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-LABEL: func @reshape_4D_2D
func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> {
  %0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32>
  %0 = "mhlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32>
  return %0 : tensor<12x42xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
@ -366,7 +366,7 @@ func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> {
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-LABEL: func @reshape_2D_4D
func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
  %0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32>
  %0 = "mhlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32>
  return %0 : tensor<12x1x42x1xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
@ -375,7 +375,7 @@ func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {

// CHECK-LABEL: func @minf
func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %0 = "xla_hlo.minimum"(%lhs, %rhs)
  %0 = "mhlo.minimum"(%lhs, %rhs)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}
@ -389,7 +389,7 @@ func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {

// CHECK-LABEL: func @maxi
func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
  %0 = "xla_hlo.maximum"(%lhs, %rhs)
  %0 = "mhlo.maximum"(%lhs, %rhs)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  return %0 : tensor<2x2xi32>
}
@ -404,7 +404,7 @@ func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK-DAG: #[[MAP:.*]] = affine_map<() -> ()>
// CHECK-LABEL: func @add_scalar
func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
  %0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  %0 = "mhlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}
// CHECK: linalg.generic
@ -417,7 +417,7 @@ func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {

func @reshape_collapse_single_dim
    (%arg0: tensor<1x28x28x1xf32>) -> tensor<1x784xf32> {
  %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>
  %0 = "mhlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>
  return %0 : tensor<1x784xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
@ -428,7 +428,7 @@ func @reshape_collapse_single_dim
// -----

func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> {
  %0 = "xla_hlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32>
  %0 = "mhlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32>
  return %0 : tensor<2x4x3xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
@ -440,7 +440,7 @@ func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> {
// -----

func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> {
  %0 = "xla_hlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32>
  %0 = "mhlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32>
  return %0 : tensor<2x4x2xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0)>
@ -451,7 +451,7 @@ func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> {
// -----

func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> {
  %0 = "xla_hlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32>
  %0 = "mhlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32>
  return %0 : tensor<1x4x2xf32>
}
// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
@ -462,7 +462,7 @@ func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> {

func @reshape_multiple_collapse
    (%arg0 : tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> {
  %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32>
  %0 = "mhlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32>
  return %0 : tensor<1x4x5x6xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)>
@ -476,7 +476,7 @@ func @reshape_multiple_collapse

// CHECK-LABEL: func @convert_i32_to_f32
func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {
  %result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32>
  %result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32>
  return %result : tensor<2x2xf32>
}
// CHECK: linalg.generic
@ -488,7 +488,7 @@ func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {

// CHECK-LABEL: func @convert_i16_to_i32
func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> {
  %result = "xla_hlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32>
  %result = "mhlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32>
  return %result : tensor<2x2xi32>
}
// CHECK: linalg.generic
@ -500,7 +500,7 @@ func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> {

// CHECK-LABEL: func @convert_i32_to_i16
func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> {
  %result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16>
  %result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16>
  return %result : tensor<2x2xi16>
}
// CHECK: linalg.generic
@ -512,7 +512,7 @@ func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> {

// CHECK-LABEL: func @convert_f32_to_f64
func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> {
  %result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64>
  %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64>
  return %result : tensor<2x2xf64>
}
// CHECK: linalg.generic
@ -524,7 +524,7 @@ func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> {

// CHECK-LABEL: func @convert_f64_to_f32
func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> {
  %result = "xla_hlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32>
  %result = "mhlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32>
  return %result : tensor<2x2xf32>
}
// CHECK: linalg.generic
@ -536,7 +536,7 @@ func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> {

// CHECK-LABEL: func @convert_f32_to_i32
func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
  %result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32>
  %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32>
  return %result : tensor<2x2xi32>
}
// CHECK: linalg.generic
@ -550,7 +550,7 @@ func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @reverse
func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %result = "xla_hlo.reverse"(%input) {
  %result = "mhlo.reverse"(%input) {
    dimensions = dense<1> : tensor<1xi64>
  } : (tensor<2x3xf32>) -> tensor<2x3xf32>
  return %result : tensor<2x3xf32>

@ -1,28 +1,28 @@
// RUN: mlir-hlo-opt %s -inline | FileCheck %s

// Test case: Basic test of inlining into xla_hlo.while.
// Test case: Basic test of inlining into mhlo.while.

// CHECK-LABEL: func @caller
// CHECK: "xla_hlo.while"{{.*}}( {
// CHECK: "mhlo.while"{{.*}}( {
// CHECK: }, {
// CHECK: "xla_hlo.exponential"
// CHECK: "mhlo.exponential"
// CHECK: })
// CHECK-LABEL: func @callee

func @caller(%arg0: tensor<f32>, %pred: tensor<i1>) -> tensor<f32> {
  %0 = "xla_hlo.while"(%arg0) ( {
  %0 = "mhlo.while"(%arg0) ( {
  ^entry(%unused: tensor<f32>):
    "xla_hlo.return"(%pred) : (tensor<i1>) -> ()
    "mhlo.return"(%pred) : (tensor<i1>) -> ()
  }, {
  ^entry(%0: tensor<f32>):
    %1 = call @callee(%0) : (tensor<f32>) -> (tensor<f32>)
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
    "mhlo.return"(%1) : (tensor<f32>) -> ()
  } ) : (tensor<f32>) -> (tensor<f32>)
  return %0 : tensor<f32>
}


func @callee(%arg0: tensor<f32>) -> tensor<f32> {
  %0 = "xla_hlo.exponential"(%arg0) : (tensor<f32>) -> tensor<f32>
  %0 = "mhlo.exponential"(%arg0) : (tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}

@ -4,21 +4,21 @@
func @while(%arg0: tensor<i64>) -> tensor<i64> {
  //CHECK: br ^bb1(%arg0 : tensor<i64>)
  //CHECK: ^bb1([[VAL0:%.+]]: tensor<i64>):
  //CHECK: [[VAL1:%.+]] = "xla_hlo.compare"([[VAL0]], [[VAL0]])
  //CHECK: [[VAL1:%.+]] = "mhlo.compare"([[VAL0]], [[VAL0]])
  //CHECK: [[VAL2:%.+]] = extract_element [[VAL1]][] : tensor<i1>
  //CHECK: cond_br [[VAL2]], ^bb2([[VAL0]] : tensor<i64>), ^bb3([[VAL0]] : tensor<i64>)
  //CHECK: ^bb2([[VAL3:%.+]]: tensor<i64>):
  //CHECK: [[VAL4:%.+]] = xla_hlo.add [[VAL3]], [[VAL3]]
  //CHECK: [[VAL4:%.+]] = mhlo.add [[VAL3]], [[VAL3]]
  //CHECK: br ^bb1([[VAL4]] : tensor<i64>)
  //CHECK: ^bb3([[VAL5:%.+]]: tensor<i64>):
  %0 = "xla_hlo.while"(%arg0) ( {
  %0 = "mhlo.while"(%arg0) ( {
  ^bb0(%arg1: tensor<i64>):
    %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "xla_hlo.return"(%1) : (tensor<i1>) -> ()
    %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "mhlo.return"(%1) : (tensor<i1>) -> ()
  }, {
  ^bb0(%arg1: tensor<i64>):
    %1 = xla_hlo.add %arg1, %arg1 {name = "compare.0"} : tensor<i64>
    "xla_hlo.return"(%1) : (tensor<i64>) -> ()
    %1 = mhlo.add %arg1, %arg1 {name = "compare.0"} : tensor<i64>
    "mhlo.return"(%1) : (tensor<i64>) -> ()
  }) : (tensor<i64>) -> tensor<i64>

  // CHECK-NEXT: return [[VAL5]]
@ -30,27 +30,27 @@ func @conditional(%arg0: tensor<f32>) -> tensor<f32> {
  // CHECK: [[C0:%.+]] = constant dense<1.000000e+01> : tensor<f32>
  %cst = constant dense<1.000000e+01> : tensor<f32>

  // CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  %0 = "xla_hlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  // CHECK: [[VAL0:%.+]] = "mhlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  %0 = "mhlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>

  // CHECK: [[VAL1:%.+]] = extract_element [[VAL0]][] : tensor<i1>
  // CHECK: cond_br [[VAL1]], ^bb1(%arg0 : tensor<f32>), ^bb2(%arg0 : tensor<f32>)
  %1 = "xla_hlo.if"(%0, %arg0, %arg0) ( {
  %1 = "mhlo.if"(%0, %arg0, %arg0) ( {

  ^bb0(%arg1: tensor<f32>):
    // CHECK: ^bb1([[VAL2:%.+]]: tensor<f32>):
    // CHECK: [[VAL3:%.+]] = "xla_hlo.log"([[VAL2]]) : (tensor<f32>) -> tensor<f32>
    // CHECK: [[VAL3:%.+]] = "mhlo.log"([[VAL2]]) : (tensor<f32>) -> tensor<f32>
    // CHECK: br ^bb3([[VAL3]] : tensor<f32>)
    %2 = "xla_hlo.log"(%arg1) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%2) : (tensor<f32>) -> ()
    %2 = "mhlo.log"(%arg1) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%2) : (tensor<f32>) -> ()
  }, {

  ^bb0(%arg1: tensor<f32>):
    // CHECK: ^bb2([[VAL4:%.+]]: tensor<f32>):
    // CHECK: [[VAL5:%.+]] = "xla_hlo.exponential"([[VAL4]]) : (tensor<f32>) -> tensor<f32>
    // CHECK: [[VAL5:%.+]] = "mhlo.exponential"([[VAL4]]) : (tensor<f32>) -> tensor<f32>
    // CHECK: br ^bb3([[VAL5]] : tensor<f32>)
    %2 = "xla_hlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%2) : (tensor<f32>) -> ()
    %2 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%2) : (tensor<f32>) -> ()
  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>

  // CHECK: ^bb3([[VAL6:%.+]]: tensor<f32>):
@ -62,27 +62,27 @@ func @conditional(%arg0: tensor<f32>) -> tensor<f32> {
func @while_with_multiple_blocks_in_body(%arg0: tensor<i64>) -> tensor<i64> {
  // CHECK: br ^[[COND_ENTRY:.+]](%arg0 : tensor<i64>)
  // CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
  // CHECK: %1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK: %1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK: %2 = extract_element %1[] : tensor<i1>
  // CHECK: cond_br %2, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
  // CHECK: ^[[BODY_ENTRY]](%3: tensor<i64>):
  // CHECK: br ^[[BODY_SUCC:.+]](%3 : tensor<i64>)
  // CHECK: ^[[BODY_SUCC]](%4: tensor<i64>):
  // CHECK: %5 = xla_hlo.add %4, %4 : tensor<i64>
  // CHECK: %5 = mhlo.add %4, %4 : tensor<i64>
  // CHECK: br ^[[COND_ENTRY]](%5 : tensor<i64>)
  // CHECK: ^[[EXIT]](%6: tensor<i64>):
  // CHECK: return %6 : tensor<i64>
  // CHECK: }
  %0 = "xla_hlo.while"(%arg0) ( {
  %0 = "mhlo.while"(%arg0) ( {
  ^cond_entry(%arg1: tensor<i64>):
    %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "xla_hlo.return"(%1) : (tensor<i1>) -> ()
    %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "mhlo.return"(%1) : (tensor<i1>) -> ()
  }, {
  ^body_entry(%arg1: tensor<i64>):
    br ^body_succ(%arg1: tensor<i64>)
  ^body_succ(%0: tensor<i64>):
    %1 = xla_hlo.add %0, %0 : tensor<i64>
    "xla_hlo.return"(%1) : (tensor<i64>) -> ()
    %1 = mhlo.add %0, %0 : tensor<i64>
    "mhlo.return"(%1) : (tensor<i64>) -> ()
  }) : (tensor<i64>) -> tensor<i64>

  return %0 : tensor<i64>
@ -94,7 +94,7 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
  // CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
  // CHECK: br ^[[COND_SUCC:.+]](%0 : tensor<i64>)
  // CHECK: ^[[COND_SUCC]](%1: tensor<i64>):
  // CHECK: %2 = "xla_hlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK: %2 = "mhlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK: %3 = extract_element %2[] : tensor<i1>
  // CHECK: cond_br %3, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
  // CHECK: ^[[BODY_ENTRY]](%4: tensor<i64>):
@ -102,15 +102,15 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
  // CHECK: ^[[EXIT]](%5: tensor<i64>):
  // CHECK: return %5 : tensor<i64>
  // CHECK: }
  %0 = "xla_hlo.while"(%arg0) ( {
  %0 = "mhlo.while"(%arg0) ( {
  ^cond_entry(%arg1: tensor<i64>):
    br ^cond_succ(%arg1: tensor<i64>)
  ^cond_succ(%0: tensor<i64>):
    %1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "xla_hlo.return"(%1) : (tensor<i1>) -> ()
    %1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "mhlo.return"(%1) : (tensor<i1>) -> ()
  }, {
  ^body_entry(%arg1: tensor<i64>):
    "xla_hlo.return"(%arg1) : (tensor<i64>) -> ()
    "mhlo.return"(%arg1) : (tensor<i64>) -> ()
  }) : (tensor<i64>) -> tensor<i64>

  return %0 : tensor<i64>
@ -123,24 +123,24 @@ func @conditional_with_multiple_blocks(%arg0: tensor<f32>, %arg1: tensor<f32>, %
  // CHECK: ^[[THEN_ENTRY]](%1: tensor<f32>):
  // CHECK: br ^[[THEN_SUCC:.+]](%1 : tensor<f32>)
  // CHECK: ^[[THEN_SUCC]](%2: tensor<f32>):
  // CHECK: %3 = "xla_hlo.log"(%2) : (tensor<f32>) -> tensor<f32>
  // CHECK: %3 = "mhlo.log"(%2) : (tensor<f32>) -> tensor<f32>
  // CHECK: br ^[[EXIT:.+]](%3 : tensor<f32>)
  // CHECK: ^[[ELSE_ENTRY]](%4: tensor<f32>):
  // CHECK: %5 = "xla_hlo.exponential"(%4) : (tensor<f32>) -> tensor<f32>
  // CHECK: %5 = "mhlo.exponential"(%4) : (tensor<f32>) -> tensor<f32>
  // CHECK: br ^[[EXIT]](%5 : tensor<f32>)
  // CHECK: ^[[EXIT]](%6: tensor<f32>):
  // CHECK: return %6 : tensor<f32>
  // CHECK: }
  %1 = "xla_hlo.if"(%pred, %arg0, %arg1) ( {
  %1 = "mhlo.if"(%pred, %arg0, %arg1) ( {
  ^then_entry(%arg2: tensor<f32>):
    br ^then_succ(%arg2: tensor<f32>)
  ^then_succ(%0: tensor<f32>):
    %2 = "xla_hlo.log"(%0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%2) : (tensor<f32>) -> ()
    %2 = "mhlo.log"(%0) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%2) : (tensor<f32>) -> ()
  }, {
  ^else_entry(%arg2: tensor<f32>):
    %2 = "xla_hlo.exponential"(%arg2) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%2) : (tensor<f32>) -> ()
    %2 = "mhlo.exponential"(%arg2) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%2) : (tensor<f32>) -> ()
  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
  return %1 : tensor<f32>
}

|
||||
|
@ -3,19 +3,19 @@
// CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: %0 = addf %arg0, %arg1 : tensor<4xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "mhlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

// CHECK-NEXT: %1 = mulf %0, %arg1 : tensor<4xf32>
%1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%1 = "mhlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

// CHECK-NEXT: %2 = subf %1, %arg1 : tensor<4xf32>
%2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%2 = "mhlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

// CHECK-NEXT: %3 = divf %2, %arg1 : tensor<4xf32>
%3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%3 = "mhlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

// CHECK-NEXT: %4 = remf %3, %arg1 : tensor<4xf32>
%4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%4 = "mhlo.remainder"(%3, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

// CHECK-NEXT: return %4 : tensor<4xf32>
return %4 : tensor<4xf32>
@ -24,19 +24,19 @@ func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf
// CHECK-LABEL: func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-NEXT: %0 = addi %arg0, %arg1 : tensor<4xi32>
%0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%0 = "mhlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>

// CHECK-NEXT: %1 = muli %0, %arg1 : tensor<4xi32>
%1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%1 = "mhlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>

// CHECK-NEXT: %2 = subi %1, %arg1 : tensor<4xi32>
%2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%2 = "mhlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>

// CHECK-NEXT: %3 = divi_signed %2, %arg1 : tensor<4xi32>
%3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%3 = "mhlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>

// CHECK-NEXT: %4 = remi_signed %3, %arg1 : tensor<4xi32>
%4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%4 = "mhlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>

// CHECK-NEXT: return %4 : tensor<4xi32>
return %4 : tensor<4xi32>
@ -45,17 +45,17 @@ func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32
// CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
// CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32>
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg0 : tensor<4xi32>
%1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg0 : tensor<4xi32>
%2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %3 = cmpi "sle", %arg0, %arg0 : tensor<4xi32>
%3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg0 : tensor<4xi32>
%4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg0 : tensor<4xi32>
%5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
}
@ -63,28 +63,28 @@ func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi
// CHECK-LABEL: func @compare_float
func @compare_float(%arg0: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
// CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg0 : tensor<4xf32>
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %1 = cmpf "une", %arg0, %arg0 : tensor<4xf32>
%1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg0 : tensor<4xf32>
%2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg0 : tensor<4xf32>
%3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg0 : tensor<4xf32>
%4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg0 : tensor<4xf32>
%5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
return %0, %1, %2, %3, %4, %5: tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
}
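The two compare tests above pin down how the comparison_direction strings map onto standard-dialect predicates: integer compares become signed cmpi predicates, while float compares become ordered cmpf predicates, except NE, which is unordered so that NaN compares as not-equal. A minimal Python sketch of that mapping, assembled from the CHECK lines above as an illustrative reference table (not code from the legalization pass itself):

    CMPI_PREDICATE = {  # integer compares lower to signed predicates
        "EQ": "eq", "NE": "ne", "LT": "slt", "LE": "sle", "GT": "sgt", "GE": "sge",
    }
    CMPF_PREDICATE = {  # float compares are ordered, except the unordered "une"
        "EQ": "oeq", "NE": "une", "LT": "olt", "LE": "ole", "GT": "ogt", "GE": "oge",
    }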

// CHECK-LABEL: func @int_constant
func @int_constant() -> (tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>) {
// CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor<i32>
%0 = "xla_hlo.constant"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>)
%0 = "mhlo.constant"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>)
// CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xi32>
%1 = "xla_hlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
%1 = "mhlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
// CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xi32>
%2 = "xla_hlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
%2 = "mhlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
// CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>
return %0, %1, %2: tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>
}
@ -92,11 +92,11 @@ func @int_constant() -> (tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>) {
// CHECK-LABEL: func @float_constant
func @float_constant() -> (tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>) {
// CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor<f32>
%0 = "xla_hlo.constant"() {value = dense<0.0> : tensor<f32>} : () -> (tensor<f32>)
%0 = "mhlo.constant"() {value = dense<0.0> : tensor<f32>} : () -> (tensor<f32>)
// CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xf32>
%1 = "xla_hlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
%1 = "mhlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
// CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xf32>
%2 = "xla_hlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
%2 = "mhlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
// CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
return %0, %1, %2: tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
}
@ -105,7 +105,7 @@ func @float_constant() -> (tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>) {
// CHECK-LABEL: func @iota.const.1() -> tensor<4xi32> {
func @iota.const.1() -> tensor<4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<4xi32>
return %0 : tensor<4xi32>
}
@ -113,7 +113,7 @@ func @iota.const.1() -> tensor<4xi32> {
// CHECK-LABEL: func @iota.const.2() -> tensor<2x4xi32> {
func @iota.const.2() -> tensor<2x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1]]> : tensor<2x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
return %0 : tensor<2x4xi32>
}
@ -121,7 +121,7 @@ func @iota.const.2() -> tensor<2x4xi32> {
// CHECK-LABEL: func @iota.const.3() -> tensor<2x4xi32> {
func @iota.const.3() -> tensor<2x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<2x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
return %0 : tensor<2x4xi32>
}
@ -129,7 +129,7 @@ func @iota.const.3() -> tensor<2x4xi32> {
// CHECK-LABEL: func @iota.const.4() -> tensor<2x3x4xi32> {
func @iota.const.4() -> tensor<2x3x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0{{\]\]}}, {{\[\[}}1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]> : tensor<2x3x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
return %0 : tensor<2x3x4xi32>
}
@ -137,7 +137,7 @@ func @iota.const.4() -> tensor<2x3x4xi32> {
// CHECK-LABEL: func @iota.const.5() -> tensor<2x3x4xi32> {
func @iota.const.5() -> tensor<2x3x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2{{\]\]}}, {{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]]> : tensor<2x3x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
return %0 : tensor<2x3x4xi32>
}
@ -145,7 +145,7 @@ func @iota.const.5() -> tensor<2x3x4xi32> {
// CHECK-LABEL: func @iota.const.6() -> tensor<2x3x4xi32> {
func @iota.const.6() -> tensor<2x3x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3{{\]\]}}, {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x3x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
return %0 : tensor<2x3x4xi32>
}
@ -153,7 +153,7 @@ func @iota.const.6() -> tensor<2x3x4xi32> {
// CHECK-LABEL: func @iota.const.f32
func @iota.const.f32() -> tensor<4xf32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
// CHECK-NEXT: return %[[CST]] : tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -161,7 +161,7 @@ func @iota.const.f32() -> tensor<4xf32> {
// CHECK-LABEL: func @iota.const.f64
func @iota.const.f64() -> tensor<4xf64> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64>
// CHECK-NEXT: return %[[CST]] : tensor<4xf64>
return %0 : tensor<4xf64>
}
@ -169,7 +169,7 @@ func @iota.const.f64() -> tensor<4xf64> {
// CHECK-LABEL: func @iota.const.bf16
func @iota.const.bf16() -> tensor<4xbf16> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xbf16>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16>
// CHECK-NEXT: return %[[CST]] : tensor<4xbf16>
return %0 : tensor<4xbf16>
}
@ -178,8 +178,8 @@ func @iota.const.bf16() -> tensor<4xbf16> {
func @iota.const.complex.f32() -> tensor<4xcomplex<f32>> {
// CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
// CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf32>
// CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f32>>
// CHECK-NEXT: [[COMPLEX:%.*]] = "mhlo.complex"([[REAL]], [[IMAG]])
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f32>>
// CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f32>>
return %0 : tensor<4xcomplex<f32>>
}
@ -188,8 +188,8 @@ func @iota.const.complex.f32() -> tensor<4xcomplex<f32>> {
func @iota.const.complex.f64() -> tensor<4xcomplex<f64>> {
// CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
// CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf64>
// CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f64>>
// CHECK-NEXT: [[COMPLEX:%.*]] = "mhlo.complex"([[REAL]], [[IMAG]])
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f64>>
// CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f64>>
return %0 : tensor<4xcomplex<f64>>
}
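All of the iota foldings above follow one rule: the folded constant counts 0, 1, 2, ... along iota_dimension and is constant along every other dimension (for complex element types the count lands in the real component, with a zero imaginary part). A small numpy sketch that reproduces the folded values; illustrative only, not the folding code itself:

    import numpy as np

    def iota(shape, iota_dimension, dtype=np.int32):
        # Count up along iota_dimension; broadcast everywhere else.
        view = [1] * len(shape)
        view[iota_dimension] = shape[iota_dimension]
        counting = np.arange(shape[iota_dimension], dtype=dtype).reshape(view)
        return np.broadcast_to(counting, shape)

    # Matches the CHECK lines of @iota.const.2 and @iota.const.3:
    assert iota((2, 4), 0).tolist() == [[0, 0, 0, 0], [1, 1, 1, 1]]
    assert iota((2, 4), 1).tolist() == [[0, 1, 2, 3], [0, 1, 2, 3]]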

@ -396,9 +396,9 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m
"xla_lhlo.fusion"() ( {
%0 = tensor_load %input1 : memref<10xf32>
%1 = tensor_load %input2 : memref<10xf32>
%2 = "xla_hlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%3 = tensor_load %input3 : memref<10xf32>
%4 = "xla_hlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
tensor_store %4, %out : memref<10xf32>
"xla_lhlo.terminator"() : () -> ()
} ) : () -> ()
@ -803,15 +803,15 @@ func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %a
func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () {
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
"mhlo.return"(%max) : (tensor<f32>) -> ()
})
{ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> ()

"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
"mhlo.return"(%max) : (tensor<f32>) -> ()
})
{
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
@ -958,8 +958,8 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32
%updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () {
"xla_lhlo.scatter" (%input, %indices, %updates, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
%add = xla_hlo.add %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%add) : (tensor<f32>) -> ()
%add = mhlo.add %lhs, %rhs : tensor<f32>
"mhlo.return"(%add) : (tensor<f32>) -> ()
}) {
scatter_dimension_numbers = {
update_window_dims = dense<[1]> : tensor<1xi64>,
@ -979,8 +979,8 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
%c = xla_hlo.add %a, %b : tensor<f32>
"xla_hlo.return"(%c) : (tensor<f32>) -> ()
%c = mhlo.add %a, %b : tensor<f32>
"mhlo.return"(%c) : (tensor<f32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<20xf32>) -> ()
return
}
@ -991,8 +991,8 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref
// expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
%c = xla_hlo.add %a, %b : tensor<f32>
"xla_hlo.return"(%c) : (tensor<f32>) -> ()
%c = mhlo.add %a, %b : tensor<f32>
"mhlo.return"(%c) : (tensor<f32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<10xf32>) -> ()
return
}
@ -1012,8 +1012,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
return
}
@ -1025,8 +1025,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
return
}
@ -1038,8 +1038,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
return
}

@ -2,14 +2,14 @@

// CHECK-LABEL: @add
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3
%4 = "xla_hlo.add"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.add %arg1, %arg3
%4 = "mhlo.add"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)

// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
@ -17,14 +17,14 @@ func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %

// CHECK-LABEL: @add_unranked
func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3
%4 = "xla_hlo.add"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.add %arg1, %arg3
%4 = "mhlo.add"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)

// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
@ -32,14 +32,14 @@ func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<

// CHECK-LABEL: @sub
func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.subtract %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.subtract %arg1, %arg3
%4 = "xla_hlo.subtract"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.subtract %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.subtract %arg1, %arg3
%4 = "mhlo.subtract"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)

// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
@ -47,14 +47,14 @@ func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %

// CHECK-LABEL: @sub_unranked
func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.subtract %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.subtract %arg1, %arg3
%4 = "xla_hlo.subtract"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.subtract %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.subtract %arg1, %arg3
%4 = "mhlo.subtract"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)

// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
@ -62,18 +62,18 @@ func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<

// CHECK-LABEL: @mul
func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]]
%4 = "xla_hlo.multiply"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = mhlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.add [[VAL3]], [[VAL4]]
%4 = "mhlo.multiply"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)

// CHECK: return %2, %5 : tensor<2xf32>, tensor<2xf32>
return %5, %6 : tensor<2xf32>, tensor<2xf32>
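The CHECK-DAG lines above are the textbook expansion of a complex product spelled out on the real and imaginary operands; for reference, the standard identity (not taken from the pass itself) in LaTeX:

    $(a + bi)(c + di) = (ac - bd) + (ad + bc)\,i$

with $a, b$ taken from %arg0, %arg1 and $c, d$ from %arg2, %arg3, so [[VAL2]] is the real component and [[VAL5]] the imaginary one.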
@ -81,18 +81,18 @@ func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %

// CHECK-LABEL: @mul_unranked
func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]]
%4 = "xla_hlo.multiply"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = mhlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.add [[VAL3]], [[VAL4]]
%4 = "mhlo.multiply"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)

// CHECK: return %2, %5 : tensor<*xf32>, tensor<*xf32>
return %5, %6 : tensor<*xf32>, tensor<*xf32>
@ -100,36 +100,36 @@ func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<

// CHECK-LABEL: @div
func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3)
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.negate"(%arg3)

// Compute the numerator's real component:
// numerator.real = lhs.real * rhs.real + lhs.imag * rhs.imag
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.multiply %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]]
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = mhlo.multiply %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.subtract [[VAL1]], [[VAL2]]

// Compute the real valued denominator as rhs * conj(rhs):
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]]
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]]

// Compute the numerator's imaginary component:
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
// CHECK-DAG: [[VAL7:%.+]] = xla_hlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = xla_hlo.multiply %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]]
// CHECK-DAG: [[VAL7:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = mhlo.multiply %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = mhlo.add [[VAL8]], [[VAL7]]

// Divide the numerator by the real valued denominator.
// CHECK-DAG: [[VAL10:%.+]] = xla_hlo.divide [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = xla_hlo.divide [[VAL9]], [[VAL6]]
%4 = "xla_hlo.divide"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]]
%4 = "mhlo.divide"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)

%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)

// CHECK: return [[VAL10]], [[VAL11]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
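For reference, the division above is realized by multiplying through by the conjugate of the divisor, which is why a single negate of rhs.imag ([[VAL0]]) feeds all three sub-computations. The standard identity (stated here for reference, not taken from the pass) in LaTeX:

    $\frac{a + bi}{c + di} = \frac{(a + bi)(c - di)}{(c + di)(c - di)} = \frac{(ac + bd) + (bc - ad)\,i}{c^{2} + d^{2}}$

[[VAL3]] and [[VAL9]] are the numerator's real and imaginary components, [[VAL6]] is the real-valued denominator, and the two divides produce the final components.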
@ -139,36 +139,36 @@ func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %

// CHECK-LABEL: @div_unranked
func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3)
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.negate"(%arg3)

// Compute the numerator's real component:
// numerator.real = lhs.real * rhs.real + lhs.imag * rhs.imag
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.multiply %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]]
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = mhlo.multiply %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.subtract [[VAL1]], [[VAL2]]

// Compute the real valued denominator as rhs * conj(rhs):
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]]
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]]

// Compute the numerator's imaginary component:
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
// CHECK-DAG: [[VAL7:%.+]] = xla_hlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = xla_hlo.multiply %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]]
// CHECK-DAG: [[VAL7:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = mhlo.multiply %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = mhlo.add [[VAL8]], [[VAL7]]

// Divide the numerator by the real valued denominator.
// CHECK-DAG: [[VAL10:%.+]] = xla_hlo.divide [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = xla_hlo.divide [[VAL9]], [[VAL6]]
%4 = "xla_hlo.divide"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]]
%4 = "mhlo.divide"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)

%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)

// CHECK: return [[VAL10]], [[VAL11]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
@ -176,14 +176,14 @@ func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<

// CHECK-LABEL: @abs
func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg0
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg1
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.add [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.sqrt"([[VAL2]])
%1 = "xla_hlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "xla_hlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg0
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg1
// CHECK-DAG: [[VAL2:%.+]] = mhlo.add [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = "mhlo.sqrt"([[VAL2]])
%1 = "mhlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)

// CHECK: return [[VAL3]]
return %2 : tensor<2xf32>
@ -191,16 +191,16 @@ func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {

// CHECK-LABEL: @exp
func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]]
%1 = "xla_hlo.exponential"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "xla_hlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%3 = "xla_hlo.imag"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]]
%1 = "mhlo.exponential"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%3 = "mhlo.imag"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)

// CHECK: return [[VAL3]], [[VAL4]]
return %2, %3 : tensor<2xf32>, tensor<2xf32>
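The lowering above is Euler's formula applied componentwise; the standard identity (for reference) in LaTeX:

    $e^{a + bi} = e^{a}\cos b + i\,e^{a}\sin b$

so [[VAL0]] $= e^{a}$ is shared by the real product [[VAL3]] and the imaginary product [[VAL4]].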
@ -208,16 +208,16 @@ func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tenso

// CHECK-LABEL: @exp_unranked
func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]]
%1 = "xla_hlo.exponential"(%0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%2 = "xla_hlo.real"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%3 = "xla_hlo.imag"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]]
%1 = "mhlo.exponential"(%0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%2 = "mhlo.real"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%3 = "mhlo.imag"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)

// CHECK: return [[VAL3]], [[VAL4]]
return %2, %3 : tensor<*xf32>, tensor<*xf32>

@ -2,10 +2,10 @@

// CHECK-LABEL: @testDebatch1
func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> {
// CHECK-DAG: [[R0:%.+]] = "xla_hlo.reshape"(%arg0) : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
// CHECK-DAG: [[R1:%.+]] = "xla_hlo.dot"([[R0]], %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
// CHECK: [[R2:%.+]] = "xla_hlo.reshape"([[R1]]) : (tensor<1x3xf32>) -> tensor<1x1x3xf32>
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32>
// CHECK-DAG: [[R0:%.+]] = "mhlo.reshape"(%arg0) : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
// CHECK-DAG: [[R1:%.+]] = "mhlo.dot"([[R0]], %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
// CHECK: [[R2:%.+]] = "mhlo.reshape"([[R1]]) : (tensor<1x3xf32>) -> tensor<1x1x3xf32>
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32>

return %0 : tensor<1x1x3xf32>
}
@ -14,13 +14,13 @@ func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1

// CHECK-LABEL: @testDebatch2
func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<3x1x1xf32> {
// CHECK-DAG: [[R0:%.+]] = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
// CHECK-DAG: [[R1:%.+]] = "xla_hlo.transpose"(%arg1) {permutation = dense<[2, 0, 1]> : tensor<3xi64>} : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32>
// CHECK-DAG: [[R2:%.+]] = "xla_hlo.reshape"([[R1]]) : (tensor<2x1x1xf32>) -> tensor<2x1xf32>
// CHECK-DAG: [[R3:%.+]] = "xla_hlo.dot"([[R0]], [[R2]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32>
// CHECK: [[R4:%.+]] = "xla_hlo.reshape"([[R3]]) : (tensor<3x1xf32>) -> tensor<3x1x1xf32>
// CHECK-DAG: [[R0:%.+]] = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
// CHECK-DAG: [[R1:%.+]] = "mhlo.transpose"(%arg1) {permutation = dense<[2, 0, 1]> : tensor<3xi64>} : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32>
// CHECK-DAG: [[R2:%.+]] = "mhlo.reshape"([[R1]]) : (tensor<2x1x1xf32>) -> tensor<2x1xf32>
// CHECK-DAG: [[R3:%.+]] = "mhlo.dot"([[R0]], [[R2]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32>
// CHECK: [[R4:%.+]] = "mhlo.reshape"([[R3]]) : (tensor<3x1xf32>) -> tensor<3x1x1xf32>

%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32>
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32>
return %0 : tensor<3x1x1xf32>
}

@ -28,8 +28,8 @@ func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<3

// CHECK-LABEL: @testBatchPassthrough
func @testBatchPassthrough(%arg0: tensor<2x2x3xf32>, %arg1: tensor<2x1x2xf32>) -> tensor<3x2x1xf32> {
// CHECK-NEXT: "xla_hlo.dot_general"(%arg0, %arg1)
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0]> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<[0]> : tensor<1xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<3x2x1xf32>
// CHECK-NEXT: "mhlo.dot_general"(%arg0, %arg1)
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0]> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<[0]> : tensor<1xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<3x2x1xf32>
return %0 : tensor<3x2x1xf32>
}
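A hedged numpy sketch of the rewrite @testDebatch1 checks for: when no batch dimensions remain, dot_general collapses the free dimensions, performs a plain 2-D dot, and reshapes back (function name and shapes here are illustrative only):

    import numpy as np

    def debatched_dot_general(lhs, rhs):
        # "mhlo.reshape": flatten the 1x1 free dims of the 1x1x2 operand.
        lhs2d = lhs.reshape(1, 2)
        # "mhlo.dot": an ordinary 2-D matrix product.
        out2d = lhs2d @ rhs
        # "mhlo.reshape": restore the original free dims around the result.
        return out2d.reshape(1, 1, 3)

    lhs = np.arange(2, dtype=np.float32).reshape(1, 1, 2)
    rhs = np.ones((2, 3), dtype=np.float32)
    assert debatched_dot_general(lhs, rhs).shape == (1, 1, 3)

@testBatchPassthrough shows the complementary case: with real batch dimensions present, the op is left as a dot_general.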

@ -3,9 +3,9 @@
// CHECK-LABEL: @clampBroadcast
// CHECK-SAME: (%[[MIN:.+]]: tensor<f32>, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor<f32>)
func @clampBroadcast(%min: tensor<f32>, %value: tensor<4xf32>, %max: tensor<f32>) -> tensor<4xf32> {
// CHECK-DAG: %[[MIN_BC:.+]] = "xla_hlo.broadcast"(%[[MIN]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK-DAG: %[[MAX_BC:.+]] = "xla_hlo.broadcast"(%[[MAX]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK: "xla_hlo.clamp"(%[[MIN_BC]], %[[VAL]], %[[MAX_BC]]) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "xla_hlo.clamp"(%min, %value, %max) : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK-DAG: %[[MIN_BC:.+]] = "mhlo.broadcast"(%[[MIN]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK-DAG: %[[MAX_BC:.+]] = "mhlo.broadcast"(%[[MAX]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK: "mhlo.clamp"(%[[MIN_BC]], %[[VAL]], %[[MAX_BC]]) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "mhlo.clamp"(%min, %value, %max) : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

@ -4,11 +4,11 @@
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x8xf32>)
// CHECK: return %[[ARG0]]
func @noop(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%2 = "xla_hlo.reduce"(%arg0, %0) ( {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%2 = "mhlo.reduce"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%4 = xla_hlo.add %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
%4 = mhlo.add %arg1, %arg2 : tensor<f32>
"mhlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
return %2 : tensor<4x8xf32>
}

@ -2,9 +2,9 @@

// CHECK-LABEL: func @const_fold_collapse_to_scalar
func @const_fold_collapse_to_scalar() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
%cst = xla_hlo.constant dense<42> : tensor<1x1xi32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<1x1xi32>) -> tensor<i32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
%cst = mhlo.constant dense<42> : tensor<1x1xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<1x1xi32>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}
@ -13,9 +13,9 @@ func @const_fold_collapse_to_scalar() -> tensor<i32> {

// CHECK-LABEL: func @const_fold_collapse_to_tensor
func @const_fold_collapse_to_tensor() -> tensor<2xi32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<2xi32>
%cst = xla_hlo.constant dense<42> : tensor<1x2xi32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<1x2xi32>) -> tensor<2xi32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<2xi32>
%cst = mhlo.constant dense<42> : tensor<1x2xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<1x2xi32>) -> tensor<2xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2xi32>
}
@ -24,9 +24,9 @@ func @const_fold_collapse_to_tensor() -> tensor<2xi32> {

// CHECK-LABEL: func @const_fold_expand
func @const_fold_expand() -> tensor<1xi32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<1xi32>
%cst = xla_hlo.constant dense<42> : tensor<i32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<i32>) -> tensor<1xi32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<1xi32>
%cst = mhlo.constant dense<42> : tensor<i32>
%0 = "mhlo.reshape"(%cst) : (tensor<i32>) -> tensor<1xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<1xi32>
}
@ -35,9 +35,9 @@ func @const_fold_expand() -> tensor<1xi32> {

// CHECK-LABEL: func @const_fold_nontrivial
func @const_fold_nontrivial() -> tensor<16xi64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<16xi64>
%cst = xla_hlo.constant dense<42> : tensor<4x4xi64>
%0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<16xi64>
%cst = mhlo.constant dense<42> : tensor<4x4xi64>
%0 = "mhlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<16xi64>
}
@ -46,9 +46,9 @@ func @const_fold_nontrivial() -> tensor<16xi64> {

// CHECK-LABEL: func @const_fold_flatten
func @const_fold_flatten() -> tensor<16xi64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<16xi64>
%cst = xla_hlo.constant dense<42> : tensor<4x4xi64>
%0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<16xi64>
%cst = mhlo.constant dense<42> : tensor<4x4xi64>
%0 = "mhlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<16xi64>
}
@ -57,9 +57,9 @@ func @const_fold_flatten() -> tensor<16xi64> {

// CHECK-LABEL: func @const_fold_6
func @const_fold_6() -> tensor<6xi32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
%cst = xla_hlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
%cst = mhlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<6xi32>
}
@ -68,11 +68,11 @@ func @const_fold_6() -> tensor<6xi32> {

// CHECK-LABEL: func @const_fold_same_shape
func @const_fold_same_shape() -> tensor<2x3xi32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[
// CHECK-SAME: [1, 2, 3], [4, 5, 6]
// CHECK-SAME: ]> : tensor<2x3xi32>
%cst = xla_hlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32>
%cst = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2x3xi32>
}
@ -81,9 +81,9 @@ func @const_fold_same_shape() -> tensor<2x3xi32> {

// CHECK-LABEL: func @const_fold_float
func @const_fold_float() -> tensor<16xf64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor<16xf64>
%cst = xla_hlo.constant dense<4.2> : tensor<4x4xf64>
%0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64>
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+00> : tensor<16xf64>
%cst = mhlo.constant dense<4.2> : tensor<4x4xf64>
%0 = "mhlo.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<16xf64>
}
|
||||
@ -94,7 +94,7 @@ func @const_fold_float() -> tensor<16xf64> {
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @non_const_same_shape(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
|
||||
// CHECK-NEXT: return [[ARG]]
|
||||
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
return %0 : tensor<2x3xi32>
|
||||
}
|
||||
|
||||
@ -103,10 +103,10 @@ func @non_const_same_shape(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
|
||||
// CHECK-LABEL: func @non_const_chained_reshape
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @non_const_chained_reshape(%arg : tensor<2x3xi32>) -> (tensor<3x2xi32>, tensor<6xi32>) {
|
||||
// CHECK-NEXT: "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<3x2xi32>
|
||||
// CHECK-NEXT: "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
|
||||
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
|
||||
%1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
|
||||
// CHECK-NEXT: "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<3x2xi32>
|
||||
// CHECK-NEXT: "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
|
||||
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
|
||||
%1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
|
||||
return %0, %1 : tensor<3x2xi32>, tensor<6xi32> // return both so nothing is removed
|
||||
}
|
||||
|
||||
@ -115,9 +115,9 @@ func @non_const_chained_reshape(%arg : tensor<2x3xi32>) -> (tensor<3x2xi32>, ten
|
||||
// CHECK-LABEL: func @non_const_chained_reshape_unused_parent
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @non_const_chained_reshape_unused_parent(%arg : tensor<2x3xi32>) -> tensor<6xi32> {
|
||||
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
|
||||
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
|
||||
%1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
|
||||
// CHECK-NEXT: [[RES:%.+]] = "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
|
||||
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
|
||||
%1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
|
||||
// CHECK-NEXT: return [[RES]]
|
||||
return %1 : tensor<6xi32>
|
||||
}
|
||||
@ -127,8 +127,8 @@ func @non_const_chained_reshape_unused_parent(%arg : tensor<2x3xi32>) -> tensor<
|
||||
// CHECK-LABEL: func @non_const_chained_reshape_becomes_noop
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @non_const_chained_reshape_becomes_noop(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
|
||||
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
|
||||
%1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<2x3xi32>
|
||||
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
|
||||
%1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<2x3xi32>
|
||||
// CHECK-NEXT: return [[ARG]]
|
||||
return %1 : tensor<2x3xi32>
|
||||
}
|
||||
@ -138,12 +138,12 @@ func @non_const_chained_reshape_becomes_noop(%arg : tensor<2x3xi32>) -> tensor<2
|
||||
// CHECK-LABEL: func @non_const_many_chained_reshapes
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @non_const_many_chained_reshapes(%arg : tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> {
|
||||
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.reshape"([[ARG]]) : (tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32>
|
||||
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3x4xi32>) -> tensor<4x3x2xi32>
|
||||
%1 = "xla_hlo.reshape"(%0) : (tensor<4x3x2xi32>) -> tensor<12x2xi32>
|
||||
%2 = "xla_hlo.reshape"(%1) : (tensor<12x2xi32>) -> tensor<2x12xi32>
|
||||
%3 = "xla_hlo.reshape"(%2) : (tensor<2x12xi32>) -> tensor<24xi32>
|
||||
%4 = "xla_hlo.reshape"(%3) : (tensor<24xi32>) -> tensor<1x2x4x3xi32>
|
||||
// CHECK-NEXT: [[RES:%.+]] = "mhlo.reshape"([[ARG]]) : (tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32>
|
||||
%0 = "mhlo.reshape"(%arg) : (tensor<2x3x4xi32>) -> tensor<4x3x2xi32>
|
||||
%1 = "mhlo.reshape"(%0) : (tensor<4x3x2xi32>) -> tensor<12x2xi32>
|
||||
%2 = "mhlo.reshape"(%1) : (tensor<12x2xi32>) -> tensor<2x12xi32>
|
||||
%3 = "mhlo.reshape"(%2) : (tensor<2x12xi32>) -> tensor<24xi32>
|
||||
%4 = "mhlo.reshape"(%3) : (tensor<24xi32>) -> tensor<1x2x4x3xi32>
|
||||
// CHECK-NEXT: return [[RES]]
|
||||
return %4 : tensor<1x2x4x3xi32>
|
||||
}
|
||||
|
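Note (illustration, not from the diff): the reshape tests above cover two rename-independent rewrites: a reshape of a constant folds into a constant of the new type, and a chain of reshapes collapses to a single reshape of the original operand (or to nothing when the final type equals the input type). The chained case canonicalizes to:

func @collapsed(%arg : tensor<2x3xi32>) -> tensor<6xi32> {
  // Intermediate reshapes are bypassed; one mhlo.reshape of %arg remains.
  %0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<6xi32>
  return %0 : tensor<6xi32>
}
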
@@ -3,7 +3,7 @@
// CHECK-LABEL: func @noop
// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x2xf32>)
func @noop(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
%0 = "xla_hlo.reverse"(%arg0) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
%0 = "mhlo.reverse"(%arg0) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
// CHECK: return %[[ARG0]]
return %0 : tensor<1x2xf32>
}

@@ -4,27 +4,27 @@

// CHECK-LABEL: func @sink_const_to_while
func @sink_const_to_while(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-NEXT: xla_hlo.while
%c0 = xla_hlo.constant dense<1> : tensor<i64>
%c1 = xla_hlo.constant dense<2> : tensor<i64>
%0 = "xla_hlo.while"(%arg0) ( {
// CHECK-NEXT: mhlo.while
%c0 = mhlo.constant dense<1> : tensor<i64>
%c1 = mhlo.constant dense<2> : tensor<i64>
%0 = "mhlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
// CHECK: %[[ARG1A:.+]]: tensor<i64>
// CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor<i64>
// CHECK: "xla_hlo.compare"(%[[C0]], %[[ARG1A]])
%1 = "xla_hlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
// CHECK: %[[C0:.+]] = mhlo.constant dense<1> : tensor<i64>
// CHECK: "mhlo.compare"(%[[C0]], %[[ARG1A]])
%1 = "mhlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
// CHECK: %[[ARG1B:.+]]: tensor<i64>
// CHECK-DAG: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor<i64>
// CHECK-DAG: %[[ADD0:.+]] = xla_hlo.add %[[ARG1B]], %[[ARG1B]]
%2 = xla_hlo.add %arg1, %arg1 : tensor<i64>
// CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], %[[ADD0]]
%3 = xla_hlo.add %c1, %2 : tensor<i64>
// CHECK: %[[ADD2:.+]] = xla_hlo.add %[[C1]], %[[ADD1]]
%4 = xla_hlo.add %c1, %3 : tensor<i64>
"xla_hlo.return"(%4) : (tensor<i64>) -> ()
// CHECK-DAG: %[[C1:.+]] = mhlo.constant dense<2> : tensor<i64>
// CHECK-DAG: %[[ADD0:.+]] = mhlo.add %[[ARG1B]], %[[ARG1B]]
%2 = mhlo.add %arg1, %arg1 : tensor<i64>
// CHECK: %[[ADD1:.+]] = mhlo.add %[[C1]], %[[ADD0]]
%3 = mhlo.add %c1, %2 : tensor<i64>
// CHECK: %[[ADD2:.+]] = mhlo.add %[[C1]], %[[ADD1]]
%4 = mhlo.add %c1, %3 : tensor<i64>
"mhlo.return"(%4) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %0 : tensor<i64>
}
@@ -33,28 +33,28 @@ func @sink_const_to_while(%arg0: tensor<i64>) -> tensor<i64> {

// CHECK-LABEL: func @sink_const_to_conditional
func @sink_const_to_conditional(%arg0: tensor<i64>) -> tensor<i64> {
%c0 = xla_hlo.constant dense<1> : tensor<i64>
%c1 = xla_hlo.constant dense<2> : tensor<i64>
%0 = "xla_hlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
%1 = "xla_hlo.tuple"(%arg0) : (tensor<i64>) -> tuple<tensor<i64>>
// CHECK: xla_hlo.if
%2 = "xla_hlo.if"(%0, %1, %1) ( {
%c0 = mhlo.constant dense<1> : tensor<i64>
%c1 = mhlo.constant dense<2> : tensor<i64>
%0 = "mhlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
%1 = "mhlo.tuple"(%arg0) : (tensor<i64>) -> tuple<tensor<i64>>
// CHECK: mhlo.if
%2 = "mhlo.if"(%0, %1, %1) ( {
^bb0(%arg1: tuple<tensor<i64>>):
// CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor<i64>
%3 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD0:.+]] = xla_hlo.add %[[C0]],
%4 = xla_hlo.add %c0, %3 : tensor<i64>
%5 = "xla_hlo.tuple"(%4) : (tensor<i64>) -> tuple<tensor<i64>>
"xla_hlo.return"(%5) : (tuple<tensor<i64>>) -> ()
// CHECK: %[[C0:.+]] = mhlo.constant dense<1> : tensor<i64>
%3 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD0:.+]] = mhlo.add %[[C0]],
%4 = mhlo.add %c0, %3 : tensor<i64>
%5 = "mhlo.tuple"(%4) : (tensor<i64>) -> tuple<tensor<i64>>
"mhlo.return"(%5) : (tuple<tensor<i64>>) -> ()
}, {
^bb0(%arg1: tuple<tensor<i64>>):
// CHECK: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor<i64>
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]],
%7 = xla_hlo.add %c1, %6 : tensor<i64>
%8 = "xla_hlo.tuple"(%7) : (tensor<i64>) -> tuple<tensor<i64>>
"xla_hlo.return"(%8) : (tuple<tensor<i64>>) -> ()
// CHECK: %[[C1:.+]] = mhlo.constant dense<2> : tensor<i64>
%6 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD1:.+]] = mhlo.add %[[C1]],
%7 = mhlo.add %c1, %6 : tensor<i64>
%8 = "mhlo.tuple"(%7) : (tensor<i64>) -> tuple<tensor<i64>>
"mhlo.return"(%8) : (tuple<tensor<i64>>) -> ()
}) : (tensor<i1>, tuple<tensor<i64>>, tuple<tensor<i64>>) -> tuple<tensor<i64>>
%9 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
%9 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
return %9 : tensor<i64>
}

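Note (illustration, not from the diff): these tests exercise constant sinking, which clones a constant defined outside a while or conditional into every region that uses it, so each region only references values it defines itself. A minimal sketch of the input IR, with the post-pass placement described in a comment (the function name is made up for illustration):

func @sink_sketch(%arg0: tensor<i64>) -> tensor<i64> {
  %c = mhlo.constant dense<10> : tensor<i64>
  %0 = "mhlo.while"(%arg0) ( {
  ^bb0(%iter: tensor<i64>):
    // After the pass, a clone of %c is materialized here so that the
    // condition region is self-contained.
    %cond = "mhlo.compare"(%iter, %c) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "mhlo.return"(%cond) : (tensor<i1>) -> ()
  }, {
  ^bb0(%iter: tensor<i64>):
    %next = mhlo.add %iter, %iter : tensor<i64>
    "mhlo.return"(%next) : (tensor<i64>) -> ()
  }) : (tensor<i64>) -> tensor<i64>
  return %0 : tensor<i64>
}
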
@@ -3,7 +3,7 @@
// CHECK-LABEL: func @remove_noop
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> {
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32>
%0 = "mhlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32>
// CHECK-NEXT: return [[ARG]]
return %0 : tensor<2x3x9x5xi32>
}
@@ -13,8 +13,8 @@ func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> {
// CHECK-LABEL: func @keep_real_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// CHECK-NEXT: "xla_hlo.transpose"([[ARG]])
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
// CHECK-NEXT: "mhlo.transpose"([[ARG]])
%0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
return %0 : tensor<3x2x5x9xi32>
}

@@ -23,7 +23,7 @@ func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// CHECK-LABEL: func @keep_same_shape_real_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @keep_same_shape_real_transpose(%arg : tensor<4x4xi32>) -> tensor<4x4xi32> {
// CHECK-NEXT: "xla_hlo.transpose"([[ARG]])
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32>
// CHECK-NEXT: "mhlo.transpose"([[ARG]])
%0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32>
return %0 : tensor<4x4xi32>
}

@@ -4,7 +4,7 @@
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @fold_access(%arg : tensor<i32>) -> tensor<i32> {
// CHECK-NEXT: return [[ARG]]
%tuple = "xla_hlo.tuple"(%arg) : (tensor<i32>) -> tuple<tensor<i32>>
%element = "xla_hlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
%tuple = "mhlo.tuple"(%arg) : (tensor<i32>) -> tuple<tensor<i32>>
%element = "mhlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
return %element : tensor<i32>
}

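Note (illustration, not from the diff): this is the usual tuple/get_tuple_element cancellation; extracting index i from a tuple constructed in the same function folds directly to the i-th operand:

%tuple = "mhlo.tuple"(%arg) : (tensor<i32>) -> tuple<tensor<i32>>
%element = "mhlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
// %element folds to %arg; both ops become dead and are erased.
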
@@ -10,19 +10,19 @@ func @batchNormInference_2D_inner_features(
%x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>,
%mean: tensor<256xf32>, %variance: tensor<256xf32>)
-> (tensor<4x256xf32>) {
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.001000e-05> : tensor<f32>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<256xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32>
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.001000e-05> : tensor<f32>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<256xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32>
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} :
(tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>,
tensor<256xf32>) -> tensor<4x256xf32>
@@ -36,12 +36,12 @@ func @batchNormInference_2D_inner_features(
// the verifier to enforce the rest.
// CHECK-SAME: %[[X:[^:]+]]
// CHECK-SAME: %[[SCALE:[^:]+]]
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32>
func @batchNormInference_4D_middle_features(
%x: tensor<3x4x256x6xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>,
%mean: tensor<256xf32>, %variance: tensor<256xf32>)
-> (tensor<3x4x256x6xf32>) {
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.001000e-05 : f32, feature_index = 2 : i64} :
(tensor<3x4x256x6xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>,
tensor<256xf32>) -> tensor<3x4x256x6xf32>
@@ -51,12 +51,12 @@ func @batchNormInference_4D_middle_features(
// -----
// CHECK-LABEL: @batchNormInference_f64
// Validate that epsilon is properly promoted to f64
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<f64>
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<f64>
func @batchNormInference_f64(
%x: tensor<4x256xf64>, %scale: tensor<256xf64>, %offset: tensor<256xf64>,
%mean: tensor<256xf64>, %variance: tensor<256xf64>)
-> (tensor<4x256xf64>) {
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.0 : f32, feature_index = 1 : i64} :
(tensor<4x256xf64>, tensor<256xf64>, tensor<256xf64>, tensor<256xf64>,
tensor<256xf64>) -> tensor<4x256xf64>
@@ -66,12 +66,12 @@ func @batchNormInference_f64(
// -----
// CHECK-LABEL: @batchNormInference_f16
// Validate that epsilon is properly converted to f16
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<f16>
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<f16>
func @batchNormInference_f16(
%x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>,
%mean: tensor<256xf16>, %variance: tensor<256xf16>)
-> (tensor<4x256xf16>) {
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.0 : f32, feature_index = 1 : i64} :
(tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>,
tensor<256xf16>) -> tensor<4x256xf16>
@@ -85,7 +85,7 @@ func @batchNormInference_f16_overflow(
%mean: tensor<256xf16>, %variance: tensor<256xf16>)
-> (tensor<4x256xf16>) {
// expected-warning @+1 {{Could not convert batch_norm epsilon to target fp type: opStatus = 24}}
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 0.00000001 : f32, feature_index = 1 : i64} :
(tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>,
tensor<256xf16>) -> tensor<4x256xf16>
@@ -108,26 +108,26 @@ func @batchNormInference_dynamic_shape(
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[C3:.*]] = constant 3 : index
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
// CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements(%[[DIM]]) : tensor<1xindex>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : tensor<4xindex>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<?x?x?x?xf32>
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<?x?x?x?xf32>
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 0.001 : f32, feature_index = 1 : i64} :
(tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>,
tensor<?xf32>) -> tensor<?x?x?x?xf32>

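Note (illustration, not from the diff): the unfused sequence these tests check is the standard batch-norm inference formula, result = (x - mean) * scale / sqrt(variance + epsilon) + offset, with each 1-D statistic broadcast along feature_index up to the shape of x. The dynamic-shape variant builds the broadcast shapes at runtime from dim and tensor_from_elements instead of static broadcast attributes, and the f64/f16 variants check that the epsilon attribute (always an f32) is first converted to the element type of the operands.
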
@@ -2,14 +2,14 @@

// CHECK-LABEL: func @multi_outputs_same
func @multi_outputs_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "xla_hlo.add"(%1, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:2 = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.subtract
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.return
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "mhlo.add"(%1, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:2 = "mhlo.fusion"
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.subtract
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.return
return %1, %2 : tensor<?x?xf32>, tensor<?x?xf32>
}

@@ -17,18 +17,18 @@ func @multi_outputs_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (ten

// CHECK-LABEL: func @multi_outputs_same_2
func @multi_outputs_same_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "xla_hlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.abs"(%arg1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "xla_hlo.add"(%0, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%3 = "xla_hlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%4 = "xla_hlo.abs"(%1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:3 = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.return
%0 = "mhlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.abs"(%arg1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "mhlo.add"(%0, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%3 = "mhlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%4 = "mhlo.abs"(%1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:3 = "mhlo.fusion"
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.return
return %2, %3, %4 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
}

@@ -36,9 +36,9 @@ func @multi_outputs_same_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (t

// CHECK-LABEL: func @multi_outputs_not_sure_same
func @multi_outputs_not_sure_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: xla_hlo.fusion
%1 = "xla_hlo.subtract"(%arg1, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%0 = "mhlo.add"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: mhlo.fusion
%1 = "mhlo.subtract"(%arg1, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
}

@@ -46,25 +46,25 @@ func @multi_outputs_not_sure_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>

// CHECK-LABEL: func @reduce
func @reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET0:.*]] = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.subtract
// CHECK-NEXT: xla_hlo.return
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET0:.*]] = "mhlo.fusion"
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.subtract
// CHECK-NEXT: mhlo.return
// Currently we do not support fusing arguments and ops without a direct producer-consumer
// relationship, so the reduce op below should not be fused with the two ops above.

%2 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "xla_hlo.reduce"(%arg0, %2) ( {
%2 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "mhlo.reduce"(%arg0, %2) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%4 = "xla_hlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
%4 = "mhlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"mhlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
%4 = "xla_hlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%4 = "mhlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// The two ops above should not be fused since a reduce op cannot be
// fused with its consumer.
// CHECK-NOT: xla_hlo.fusion
// CHECK-NOT: mhlo.fusion

return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
}
@@ -73,25 +73,25 @@ func @reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>

// CHECK-LABEL: func @reduce_2
func @reduce_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>

%2 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "xla_hlo.reduce"(%1, %2) ( {
%2 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "mhlo.reduce"(%1, %2) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%4 = "xla_hlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
%4 = "mhlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"mhlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
// CHECK: %[[RET0:.*]]:2 = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.subtract
// CHECK-NEXT: xla_hlo.constant
// CHECK-NEXT: xla_hlo.reduce
// CHECK: xla_hlo.return
// CHECK: %[[RET0:.*]]:2 = "mhlo.fusion"
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.subtract
// CHECK-NEXT: mhlo.constant
// CHECK-NEXT: mhlo.reduce
// CHECK: mhlo.return

// The following op should not be fused with the ops above since a reduce op cannot be
// fused with its consumer.
// CHECK-NOT: xla_hlo.fusion
%4 = "xla_hlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK-NOT: mhlo.fusion
%4 = "mhlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
}

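Note (illustration, not from the diff): these tests pin down the fusion heuristic: only ops linked by direct producer-consumer edges are grouped, multi-output groups are allowed, and a reduce may end a group but never fuses with its own consumer. Schematically (a sketch; the exact operand and region conventions are whatever FusionOp defines), a fused group is wrapped as:

%fused = "mhlo.fusion"(%lhs, %rhs) ( {
  %s = mhlo.add %lhs, %rhs : tensor<?x?xf32>
  "mhlo.return"(%s) : (tensor<?x?xf32>) -> ()
}) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
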
@@ -9,15 +9,15 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
%num_elements = shape.num_elements %shape
%num_elements_as_index = shape.size_to_index %num_elements
%flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex>
%flat_a = "xla_hlo.dynamic_reshape"(%a, %flat_shape)
%flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape)
: (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>

// Apply operation.
%flat_b = "xla_hlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
%flat_b = "mhlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>

// Restore original shape.
%shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor<?xindex>
%b = "xla_hlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
%b = "mhlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
: (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>

return %b : tensor<*xf32>
@@ -33,12 +33,12 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
// CHECK-NEXT: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
// CHECK-NEXT: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLAT_B:.*]] = "xla_hlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: return %[[B]] : tensor<*xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
%b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
return %b : tensor<*xf32>
}

@@ -48,9 +48,9 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-LABEL: @sqrt_ranked
// CHECK-SAME: (%[[A:.*]]: tensor<3x?xf32>)
func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> {
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32>
// CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32>
// CHECK-NEXT: return %[[B]] : tensor<3x?xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32>
%b = "mhlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32>
return %b : tensor<3x?xf32>
}

@@ -60,9 +60,9 @@ func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> {
// CHECK-LABEL: @sqrt_static
// CHECK-SAME: (%[[A:.*]]: tensor<2x3xf32>)
func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: return %[[B]] : tensor<2x3xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%b = "mhlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %b : tensor<2x3xf32>
}

@@ -77,12 +77,12 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
// CHECK: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
// CHECK: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_B:.*]] = "xla_hlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_RESULT:.*]] = xla_hlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
// CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESULT]] : tensor<*xf32>
%result = xla_hlo.add %a, %b : tensor<*xf32>
%result = mhlo.add %a, %b : tensor<*xf32>
return %result : tensor<*xf32>
}

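Note (illustration, not from the diff): the unranked tests above all follow one bridge-through-rank-1 recipe: compute the element count with the shape dialect, dynamic_reshape the tensor<*xf32> down to tensor<?xf32>, apply the element-wise op at fixed rank, then dynamic_reshape back to the original extents. Ranked inputs (sqrt_ranked, sqrt_static) are deliberately left untouched, since they need no bridging.
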
@@ -19,7 +19,7 @@ module attributes {tf.versions = {producer = 888 : i32}} {

// CHECK-LABEL: func @_func
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xi32>,
// CHECK-SAME: %[[ARG1:.*]]: tensor<?xi32> {xla_hlo.is_same_data_across_replicas}
// CHECK-SAME: %[[ARG1:.*]]: tensor<?xi32> {mhlo.is_same_data_across_replicas}
// CHECK-SAME: %[[ARG2:.*]]: tensor<?xi32>)
func @_func(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>, %arg2: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf._D"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
@@ -54,9 +54,9 @@ module attributes {tf.versions = {producer = 888 : i32}} {
}

// CHECK-LABEL: func @_func
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xi32> {xla_hlo.is_same_data_across_replicas},
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xi32> {mhlo.is_same_data_across_replicas},
// CHECK-SAME: %[[ARG1:.*]]: tensor<?xi32>,
// CHECK-SAME: %[[ARG2:.*]]: tensor<!tf.resource<tensor<?xi32>>> {xla_hlo.is_same_data_across_replicas}
// CHECK-SAME: %[[ARG2:.*]]: tensor<!tf.resource<tensor<?xi32>>> {mhlo.is_same_data_across_replicas}
func @_func(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>, %arg2: tensor<!tf.resource<tensor<?xi32>>>) -> tensor<?xi32> {
%0 = "tf._D"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
@@ -78,7 +78,7 @@ module attributes {tf.versions = {producer = 888 : i32}} {
}

// CHECK-LABEL: func @_func
// CHECK-NOT: xla_hlo.is_same_data_across_replicas
// CHECK-NOT: mhlo.is_same_data_across_replicas
func @_func(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf._D"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>

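Note (illustration, not from the diff): the rename also covers dialect-namespaced function argument attributes, as these tests show for the parameter-replication annotation. The attribute is attached as a unit attribute on each replicated argument:

func @replicated(%arg0: tensor<?xi32> {mhlo.is_same_data_across_replicas}) -> tensor<?xi32> {
  return %arg0 : tensor<?xi32>
}
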
@@ -17,8 +17,8 @@ func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tenso
}

func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = xla_hlo.add %arg0, %arg0 : tensor<2xi32>
%1 = xla_hlo.add %0, %arg0 : tensor<2xi32>
%0 = mhlo.add %arg0, %arg0 : tensor<2xi32>
%1 = mhlo.add %0, %arg0 : tensor<2xi32>
return %1 : tensor<2xi32>
}

@@ -33,7 +33,7 @@ func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi3
}

func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32>
%0 = mhlo.divide %arg0, %arg0 : tensor<2xi32>
return %0 : tensor<2xi32>
}

@@ -43,7 +43,7 @@ func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
}

func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
%0 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
%0 = mhlo.shift_left %arg0, %arg1 : tensor<4xi32>
return %0 : tensor<4xi32>
}

@@ -53,17 +53,17 @@ func @div_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3
}

func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%0 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
%0 = mhlo.maximum %arg0, %arg1 : tensor<4xf32>
return %0 : tensor<4xf32>
}

func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%0 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32>
%0 = mhlo.minimum %arg0, %arg1 : tensor<4xf32>
return %0 : tensor<4xf32>
}

func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32>
%0 = mhlo.multiply %arg0, %arg0 : tensor<2xi32>
return %0 : tensor<2xi32>
}

@@ -73,7 +73,7 @@ func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
}

func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32>
%0 = mhlo.divide %arg0, %arg0 : tensor<2xi32>
return %0 : tensor<2xi32>
}

@@ -83,7 +83,7 @@ func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor
}

func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32>
%0 = mhlo.subtract %arg0, %arg0 : tensor<2xi32>
return %0 : tensor<2xi32>
}

@@ -93,7 +93,7 @@ func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
}

func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
%0 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
%0 = mhlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
return %0 : tensor<4xi32>
}

@@ -103,7 +103,7 @@ func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> ten
}

func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> {
%0 = xla_hlo.and %arg0, %arg0 : tensor<2xi1>
%0 = mhlo.and %arg0, %arg0 : tensor<2xi1>
return %0 : tensor<2xi1>
}

@@ -118,7 +118,7 @@ func @and_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
}

func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> {
%0 = xla_hlo.or %arg0, %arg0 : tensor<2xi1>
%0 = mhlo.or %arg0, %arg0 : tensor<2xi1>
return %0 : tensor<2xi1>
}

@@ -133,7 +133,7 @@ func @or_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
}

func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
%0 = xla_hlo.or %arg0, %arg1 : tensor<4xi32>
%0 = mhlo.or %arg0, %arg1 : tensor<4xi32>
return %0 : tensor<4xi32>
}

@@ -148,7 +148,7 @@ func @bitwise_or_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?
}

func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
%0 = xla_hlo.and %arg0, %arg1 : tensor<4xi32>
%0 = mhlo.and %arg0, %arg1 : tensor<4xi32>
return %0 : tensor<4xi32>
}

@@ -163,69 +163,69 @@ func @bitwise_and_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<
}

func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = xla_hlo.power %arg0, %arg0 : tensor<2xf32>
%0 = mhlo.power %arg0, %arg0 : tensor<2xf32>
return %0 : tensor<2xf32>
}

func @pow_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = xla_hlo.power %arg0, %arg0 : tensor<?xf32>
%0 = mhlo.power %arg0, %arg0 : tensor<?xf32>
return %0 : tensor<?xf32>
}

func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
%0 = xla_hlo.constant dense<0> : tensor<2x3xi32>
%0 = mhlo.constant dense<0> : tensor<2x3xi32>
%1 = "xla_chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
%2 = xla_hlo.constant dense<0> : tensor<3xi32>
%2 = mhlo.constant dense<0> : tensor<3xi32>
%3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
%4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
%5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%6 = "xla_hlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%7 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
%8 = xla_hlo.constant dense<1> : tensor<3xi32>
%9 = xla_hlo.subtract %7, %8 : tensor<3xi32>
%6 = "mhlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%7 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
%8 = mhlo.constant dense<1> : tensor<3xi32>
%9 = mhlo.subtract %7, %8 : tensor<3xi32>
%10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
%11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
%13 = "xla_chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%14 = "mhlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
return %14 : tensor<2x3xi32>
}

func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
%0 = xla_hlo.constant dense<0> : tensor<3xi32>
%1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
%2 = xla_hlo.constant dense<0> : tensor<2x3xi32>
%0 = mhlo.constant dense<0> : tensor<3xi32>
%1 = "mhlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
%2 = mhlo.constant dense<0> : tensor<2x3xi32>
%3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
%4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
%5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%6 = "xla_hlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
%7 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%8 = xla_hlo.constant dense<1> : tensor<2x3xi32>
%9 = xla_hlo.subtract %7, %8 : tensor<2x3xi32>
%6 = "mhlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
%7 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%8 = mhlo.constant dense<1> : tensor<2x3xi32>
%9 = mhlo.subtract %7, %8 : tensor<2x3xi32>
%10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%13 = xla_hlo.divide %11, %12 : tensor<2x3xi32>
%14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%13 = mhlo.divide %11, %12 : tensor<2x3xi32>
%14 = "mhlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
return %14 : tensor<2x3xi32>
}

func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = xla_hlo.divide %arg0, %arg0 : tensor<2xf32>
%1 = xla_hlo.divide %arg0, %arg0 : tensor<2xf32>
%2 = "xla_hlo.floor"(%1) : (tensor<2xf32>) -> tensor<2xf32>
%0 = mhlo.divide %arg0, %arg0 : tensor<2xf32>
%1 = mhlo.divide %arg0, %arg0 : tensor<2xf32>
%2 = "mhlo.floor"(%1) : (tensor<2xf32>) -> tensor<2xf32>
return %2 : tensor<2xf32>
}

func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> {
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
%1 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
%2 = "xla_hlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16>
%2 = "mhlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16>
return %2 : tensor<2x3xf16>
}

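Note (illustration, not from the diff): the integer floordiv expansions above implement floor division on top of truncating division: %4 tests whether the operand signs agree, in which case trunc(x/y) already equals floor(x/y) and the plain divide %5 is selected; otherwise the select takes -((|x| + |y| - 1) / |y|), which rounds the quotient's magnitude up before negating, i.e. floor division for mixed signs. The floating-point variants simply divide and apply mhlo.floor.
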
func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0 : tensor<2xi1>
|
||||
}
|
||||
|
||||
@ -250,7 +250,7 @@ func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor
|
||||
}
|
||||
|
||||
func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0 : tensor<2xi1>
|
||||
}
|
||||
|
||||
@ -270,7 +270,7 @@ func @notequal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: ten
|
||||
}
|
||||
|
||||
func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0 : tensor<2xi1>
|
||||
}
|
||||
|
||||
@ -280,7 +280,7 @@ func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<
|
||||
}
|
||||
|
||||
func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0 : tensor<2xi1>
|
||||
}
|
||||
|
||||
@ -290,7 +290,7 @@ func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> t
|
||||
}
|
||||
|
||||
func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0 : tensor<2xi1>
|
||||
}
|
||||
|
||||
@ -300,7 +300,7 @@ func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2
|
||||
}
|
||||
|
||||
func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0 : tensor<2xi1>
|
||||
}
|
||||
|
||||
@ -310,426 +310,426 @@ func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tens
|
||||
}
|
||||
|
||||
func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> {
|
||||
%2 = "xla_hlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>
|
||||
%2 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>
|
||||
return %2 : tensor<6x3xf32>
|
||||
}
|
||||
|
||||
func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> {
|
||||
%2 = "xla_hlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32>
|
||||
%2 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32>
|
||||
return %2 : tensor<3x6xf32>
|
||||
}
|
||||
|
||||
func @const() -> tensor<2xi32> {
|
||||
%0 = xla_hlo.constant dense<0> : tensor<2xi32>
|
||||
%0 = mhlo.constant dense<0> : tensor<2xi32>
|
||||
return %0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
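// Note (editorial comment, not part of the original commit): relu(x) is
// expressed here as a broadcasted maximum of the operand against a scalar
// zero constant.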
%0 = xla_hlo.constant dense<0> : tensor<i32>
%0 = mhlo.constant dense<0> : tensor<i32>
%1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
return %1 : tensor<1xi32>
}

func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = xla_hlo.constant dense<0> : tensor<i32>
%0 = mhlo.constant dense<0> : tensor<i32>
%1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
return %1 : tensor<?xi32>
}

func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = xla_hlo.constant dense<6> : tensor<i32>
%0 = mhlo.constant dense<0> : tensor<i32>
%1 = mhlo.constant dense<6> : tensor<i32>
%2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
%3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
return %3 : tensor<1xi32>
}

func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = xla_hlo.constant dense<6> : tensor<i32>
%0 = mhlo.constant dense<0> : tensor<i32>
%1 = mhlo.constant dense<6> : tensor<i32>
%2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
%3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
return %3 : tensor<?xi32>
}

func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor<?x?xf32>) -> tensor<4x8xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "xla_chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32>
%3 = "xla_hlo.select"(%1, %arg0, %2) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
%2 = mhlo.constant dense<0.000000e+00> : tensor<4x8xf32>
%3 = "mhlo.select"(%1, %arg0, %2) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
return %3 : tensor<4x8xf32>
}

func @select(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}

func @select_float(%arg0: tensor<2xi1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @select_multidimensional(%arg0: tensor<3x2xi1>, %arg1: tensor<3x2xi32>, %arg2: tensor<3x2xi32>) -> tensor<3x2xi32> {
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
%0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
return %0 : tensor<3x2xi32>
}

func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}

func @selectv2_pred_scalar(%arg0: tensor<i1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}

func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
%0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
%1 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
%2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
%0 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
%1 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
%2 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
return %2 : tensor<3x2xf32>
}

func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
%0 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi32>
%1 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi64>
%2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
%0 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi32>
%1 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64>
%2 = "mhlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
return %2 : tensor<3x2x1xf32>
}

func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
%0 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi64>
%1 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi64>
%2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
%0 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64>
%1 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64>
%2 = "mhlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
return %2 : tensor<3x2x1xf32>
}

func @transpose_dynamic_2d(%arg0: tensor<?x4xf32>) -> tensor<4x?xf32> {
%0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
%1 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
%2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<?x4xf32>) -> tensor<4x?xf32>
%0 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
%1 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
%2 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<?x4xf32>) -> tensor<4x?xf32>
return %2 : tensor<4x?xf32>
}

func @transpose_unranked_2d(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
%1 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
%2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32>
%0 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
%1 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
%2 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32>
return %2 : tensor<*xf32>
}

func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

func @ceil_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func @complex_abs(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
%0 = "xla_hlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
%0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.cosine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.cosine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.exponential"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.exponential"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

func @floor_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> {
%0 = "xla_hlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1>
%0 = "mhlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1>
return %0 : tensor<2xi1>
}

func @is_finite_dynamic(%arg0: tensor<?xf32>) -> tensor<?xi1> {
%0 = "xla_hlo.is_finite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1>
%0 = "mhlo.is_finite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1>
return %0 : tensor<?xi1>
}

func @is_finite_unranked(%arg0: tensor<*xf32>) -> tensor<*xi1> {
%0 = "xla_hlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1>
%0 = "mhlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1>
return %0 : tensor<*xi1>
}

func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @log_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

func @log_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @log1p_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.log_plus_one"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.log_plus_one"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.negate"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.negate"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
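// Note (editorial comment, not part of the original commit): this lowering
// uses the identity sigmoid(x) = 0.5 + 0.5 * tanh(0.5 * x), which is why the
// body multiplies by a 0.5 splat, applies tanh, then scales and shifts by the
// same constant.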
%0 = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
%1 = xla_hlo.constant dense<2> : tensor<1xi64>
%2 = xla_hlo.constant dense<5.000000e-01> : tensor<2xf32>
%3 = xla_hlo.multiply %arg0, %2 : tensor<2xf32>
%4 = "xla_hlo.tanh"(%3) : (tensor<2xf32>) -> tensor<2xf32>
%5 = xla_hlo.multiply %4, %2 : tensor<2xf32>
%6 = xla_hlo.add %5, %2 : tensor<2xf32>
%0 = mhlo.constant dense<5.000000e-01> : tensor<f32>
%1 = mhlo.constant dense<2> : tensor<1xi64>
%2 = mhlo.constant dense<5.000000e-01> : tensor<2xf32>
%3 = mhlo.multiply %arg0, %2 : tensor<2xf32>
%4 = "mhlo.tanh"(%3) : (tensor<2xf32>) -> tensor<2xf32>
%5 = mhlo.multiply %4, %2 : tensor<2xf32>
%6 = mhlo.add %5, %2 : tensor<2xf32>
return %6 : tensor<2xf32>
}

func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @sin_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.sine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.sine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

func @rsqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @sqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

func @sqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @bitcast_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
%0 = "mhlo.bitcast_convert"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}

func @bitcast_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
%0 = "mhlo.bitcast_convert"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> {
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32>
%0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}

func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
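// Note (editorial comment, not part of the original commit): x != x holds
// only for NaN, so the compare/select pairs below route NaN inputs to 0.0
// before taking the sign.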
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
%1 = xla_hlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
%2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
%3 = xla_hlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
%4 = "xla_hlo.sign"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
%5 = "xla_hlo.select"(%2, %3, %4) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
%6 = "xla_hlo.select"(%0, %1, %5) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
%1 = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
%2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
%3 = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
%4 = "mhlo.sign"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
%5 = "mhlo.select"(%2, %3, %4) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
%6 = "mhlo.select"(%0, %1, %5) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
return %6 : tensor<1x2x3x4xf32>
}

func @size_rank_one_i32(%arg0: tensor<f32>) -> tensor<i32> {
%0 = xla_hlo.constant dense<1> : tensor<i32>
%0 = mhlo.constant dense<1> : tensor<i32>
return %0 : tensor<i32>
}

func @size_rank_one_i64(%arg0: tensor<f32>) -> tensor<i64> {
%0 = xla_hlo.constant dense<1> : tensor<i64>
%0 = mhlo.constant dense<1> : tensor<i64>
return %0 : tensor<i64>
}

func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex<f32>> {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
return %0 : tensor<3xcomplex<f32>>
}

func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> {
%0 = "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
%0 = "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

func @convert_slice(%arg0: tensor<1x4672xf32>) -> tensor<1x519xf32> {
%0 = "xla_hlo.slice"(%arg0) {limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x4672xf32>) -> tensor<1x519xf32>
%0 = "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x4672xf32>) -> tensor<1x519xf32>
return %0 : tensor<1x519xf32>
}

func @reshape(%arg0: tensor<4x6xf32>) -> tensor<2x2x6xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<4x6xf32>) -> tensor<2x2x6xf32>
%0 = "mhlo.reshape"(%arg0) : (tensor<4x6xf32>) -> tensor<2x2x6xf32>
return %0 : tensor<2x2x6xf32>

}

func @convert_dot_1d_2d(%arg0: tensor<256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1xf32> {
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256x1xf32>) -> tensor<1xf32>
%0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256x1xf32>) -> tensor<1xf32>
return %0 : tensor<1xf32>
}

func @convert_dot_2d_1d(%arg0: tensor<1x256xf32>, %arg1: tensor<256xf32>) -> tensor<1xf32> {
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256xf32>) -> tensor<1xf32>
%0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256xf32>) -> tensor<1xf32>
return %0 : tensor<1xf32>
}

func @convert_dot_1d_1d(%arg0: tensor<256xf32>, %arg1: tensor<256xf32>) -> tensor<f32> {
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32>
%0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32>
return %0 : tensor<f32>
}

func @convert_dot_2d_2d(%arg0: tensor<1x256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1x1xf32> {
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32>
%0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32>
return %0 : tensor<1x1xf32>
}

func @broadcast_in_dim_tf_style(%arg0: tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32>
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32>
return %0 : tensor<3x8x8x16xf32>
}

func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32>
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32>
return %0 : tensor<3x8x8x16xf32>
}

func @convert_dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> {
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[1, 2]> : tensor<2xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<[1, 3]> : tensor<2xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2x6x5x1xf32>, tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32>
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[1, 2]> : tensor<2xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<[1, 3]> : tensor<2xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2x6x5x1xf32>, tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32>
return %0 : tensor<3x5x1x4xf32>
}

func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
%0 = "xla_hlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
%0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
{input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
(tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
@ -737,7 +737,7 @@ func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>
}

func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
%0 = "xla_hlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
%0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
{input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
feature_group_count = 207 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
(tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
@ -745,7 +745,7 @@ func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x2
}

func @convert_conv2d_valid_padding(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
%0 = "xla_hlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
%0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
{input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
(tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
@ -753,22 +753,22 @@ func @convert_conv2d_valid_padding(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3
}

func @convert_reduce_to_sum(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
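// Note (editorial comment, not part of the original commit): a reduce with an
// add body and a 0.0 init value over dimension 1 computes a row-wise sum.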
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "xla_hlo.reduce"(%arg0, %0) ( {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "mhlo.reduce"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%2 = xla_hlo.add %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
%2 = mhlo.add %arg1, %arg2 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
return %1 : tensor<1xf32>
}

func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
// "0xFF800000" represents -INF for f32.
%0 = xla_hlo.constant dense<0xFF800000> : tensor<f32>
%1 = "xla_hlo.reduce"(%arg0, %0) ( {
%0 = mhlo.constant dense<0xFF800000> : tensor<f32>
%1 = "mhlo.reduce"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%2 = xla_hlo.maximum %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
%2 = mhlo.maximum %arg1, %arg2 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
return %1 : tensor<1xf32>
}
@ -776,11 +776,11 @@ func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {

func @convert_reduce_to_min(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
// "0x7F800000" represents INF for f32.
%0 = xla_hlo.constant dense<0x7F800000> : tensor<f32>
%1 = "xla_hlo.reduce"(%arg0, %0) ( {
%0 = mhlo.constant dense<0x7F800000> : tensor<f32>
%1 = "mhlo.reduce"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%2 = xla_hlo.minimum %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
%2 = mhlo.minimum %arg1, %arg2 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
return %1 : tensor<1xf32>
}

@ -17,7 +17,7 @@ func @single_arg_single_shape(%arg0: tensor<i1>) {
}

// CHECK-LABEL: func @func0
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [1 : i32], shape_indices = [2 : i32]}}, %{{[a-z0-9]+}}: tensor<i1>)
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1> {mhlo.padding_map = {padding_arg_indices = [1 : i32], shape_indices = [2 : i32]}}, %{{[a-z0-9]+}}: tensor<i1>)
func @func0(%arg0: tensor<i1>, %arg1: tensor<i1>) {
return
}
@ -44,7 +44,7 @@ func @single_arg_multiple_shapes(%arg0: tensor<i1>) {
}

// CHECK-LABEL: func @func1
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [1 : i32, 2 : i32], shape_indices = [2 : i32, 3 : i32]}}, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>)
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1> {mhlo.padding_map = {padding_arg_indices = [1 : i32, 2 : i32], shape_indices = [2 : i32, 3 : i32]}}, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>)
func @func1(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
return
}
@ -76,7 +76,7 @@ func @multiple_args(%arg0: tensor<i1>) {
}

// CHECK-LABEL: func @func2
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [1 : i32, 2 : i32], shape_indices = [2 : i32, 3 : i32]}}, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [3 : i32], shape_indices = [1 : i32]}})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1> {mhlo.padding_map = {padding_arg_indices = [1 : i32, 2 : i32], shape_indices = [2 : i32, 3 : i32]}}, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {mhlo.padding_map = {padding_arg_indices = [3 : i32], shape_indices = [1 : i32]}})
func @func2(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>, %arg3: tensor<i1>, %arg4: tensor<i1>) {
return
}
@ -97,7 +97,7 @@ func @remap_indices(%arg0: tensor<i1>) {
}

// CHECK-LABEL: func @func3
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [0 : i32], shape_indices = [2 : i32]}})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {mhlo.padding_map = {padding_arg_indices = [0 : i32], shape_indices = [2 : i32]}})
func @func3(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
return
}
@ -196,7 +196,7 @@ func @missing_padding_arg(%arg0: tensor<i1>) {
}

// CHECK-LABEL: func @func8
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [2 : i32], shape_indices = [2 : i32]}}, %{{[a-z0-9]+}}: tensor<i1>)
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {mhlo.padding_map = {padding_arg_indices = [2 : i32], shape_indices = [2 : i32]}}, %{{[a-z0-9]+}}: tensor<i1>)
func @func8(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
return
}
@ -218,7 +218,7 @@ func @missing_replicated_input_indices(%arg0: tensor<i1>) {
}

// CHECK-LABEL: func @func9
// CHECK-NOT: xla_hlo.padding_map
// CHECK-NOT: mhlo.padding_map
func @func9(%arg0: tensor<i1>, %arg1: tensor<i1>) {
return
}
@ -240,7 +240,7 @@ func @non_contigous_indices(%arg0: tensor<i1>) {
}

// CHECK-LABEL: func @func10
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [0 : i32], shape_indices = [6 : i32]}})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {mhlo.padding_map = {padding_arg_indices = [0 : i32], shape_indices = [6 : i32]}})
func @func10(%arg0: tensor<i1>, %arg1: tensor<i1>) {
return
}

@ -30,8 +30,8 @@ func @check_default_sharding_for_block_arg_inputs_outputs(%arg0: tensor<*xi32>)
}

// CHECK-LABEL: func @func_without_sharding
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"})
// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"})
func @func_without_sharding(%arg0: tensor<*xi32>) -> tensor<*xi32> {
return %arg0 : tensor<*xi32>
}
@ -51,8 +51,8 @@ func @check_default_sharding_for_inputs_outputs(%arg0: tensor<*xi32>) {
}

// CHECK-LABEL: func @func_without_sharding
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"})
// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"})
func @func_without_sharding(%arg0: tensor<*xi32>) -> tensor<*xi32> {
%0 = "tf.A"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
return %0 : tensor<*xi32>
@ -72,8 +72,8 @@ func @check_sharding_for_input_correctly_identified(%arg0: tensor<*xi32>) {
}

// CHECK-LABEL: func @inputs_with_sharding_func
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"})
// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"})
func @inputs_with_sharding_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
%0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.A"(%0) : (tensor<*xi32>) -> (tensor<*xi32>)
@ -94,8 +94,8 @@ return
}

// CHECK-LABEL: func @func_with_sharding
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {xla_hlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {mhlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"})
func @func_with_sharding(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
%0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
@ -119,8 +119,8 @@ func @check_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
}

// CHECK-LABEL: func @func_with_sharding_after_identity
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {xla_hlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {mhlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"})
func @func_with_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
%0 = "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
@ -145,8 +145,8 @@ func @check_sharding_after_read_variable(%arg0: tensor<*xi32>, %arg1: tensor<*xi
}

// CHECK-LABEL: func @func_with_sharding_after_read_variable
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*x!tf.resource<tensor<32xf32>>> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*x!tf.resource<tensor<32xf32>>> {xla_hlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*x!tf.resource<tensor<32xf32>>> {mhlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*x!tf.resource<tensor<32xf32>>> {mhlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"})
func @func_with_sharding_after_read_variable(%arg0: tensor<*x!tf.resource<tensor<32xf32>>>, %arg1: tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<*xi32>, tensor<*xi1>) {
%0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
%1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<32xf32>) -> tensor<32xf32>
@ -173,8 +173,8 @@ func @check_sharding_after_cast_op(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
}

// CHECK-LABEL: func @func_with_sharding_after_cast
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {xla_hlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {mhlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"})
func @func_with_sharding_after_cast(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
%0 = "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.Cast"(%0) : (tensor<*xi32>) -> tensor<*xi1>
@ -200,8 +200,8 @@ func @check_sharding_inside_functional_op(%arg0: tensor<*xi32>, %arg1: tensor<*x
}

// CHECK-LABEL: func @func_with_device_training_loop
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {xla_hlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {xla_hlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {xla_hlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {xla_hlo.sharding = "\0D\0E\0F"})
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<*xi32> {mhlo.sharding = "\01\02\03"}, %{{[a-z0-9]+}}: tensor<*xi1> {mhlo.sharding = "\04\05\06"})
// CHECK-SAME: -> (tensor<*xi32> {mhlo.sharding = "\0A\0B\0C"}, tensor<*xi1> {mhlo.sharding = "\0D\0E\0F"})
func @func_with_device_training_loop(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
%1:2 = "tf.StatefulPartitionedCall"(%arg0){f= @func_body, config="", config_proto="", executor_type=""}
: (tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>)

@ -45,8 +45,8 @@ module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0"
return %10 : tensor<i1>
}
// CHECK-LABEL: func @_func
// CHECK-SAME: [[FUNCINPUT0:.*]]: tensor<2x112x112x12xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT1:%.*]]: tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT2:%.*]]: tensor<f32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[VAL_59:%.*]]: tensor<i64> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<i64> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
func @_func(%arg0: tensor<2x224x224x3xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<f32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<i64> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<i64> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
// CHECK-SAME: [[FUNCINPUT0:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT1:%.*]]: tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT2:%.*]]: tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[VAL_59:%.*]]: tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
func @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
%0 = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
%1 = "tf.Const"() {value = dense<0> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
%2 = "tf.Const"() {value = dense<[7, 7, 3, 64]> : tensor<4xi32>} : () -> tensor<4xi32>

@ -33,7 +33,7 @@ namespace TFDevice {

namespace {

constexpr char kReplicationAttr[] = "xla_hlo.is_same_data_across_replicas";
constexpr char kReplicationAttr[] = "mhlo.is_same_data_across_replicas";
constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";

// Analyzes the inputs to ClusterFuncOps in the module, and annotates their

@ -47,14 +47,14 @@ namespace mlir {
namespace TF {
namespace {

using xla_hlo::DotDimensionNumbers;
using mhlo::DotDimensionNumbers;

class ConvertConvOp : public OpConversionPattern<xla_hlo::ConvOp> {
class ConvertConvOp : public OpConversionPattern<mhlo::ConvOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      xla_hlo::ConvOp conv_op, ArrayRef<Value> args,
      mhlo::ConvOp conv_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    if (!IsSupportedConvOp(conv_op)) {
      return failure();
@ -120,7 +120,7 @@ class ConvertConvOp : public OpConversionPattern<xla_hlo::ConvOp> {
  };

 private:
  bool IsSamePadding(xla_hlo::ConvOp conv_op, int num_spatial_dims,
  bool IsSamePadding(mhlo::ConvOp conv_op, int num_spatial_dims,
                     ArrayRef<int64_t> strides, ArrayRef<int64_t> dilation,
                     ArrayRef<int64_t> padding_array) const {
    for (auto i : llvm::seq<int>(0, num_spatial_dims)) {
@ -142,7 +142,7 @@ class ConvertConvOp : public OpConversionPattern<xla_hlo::ConvOp> {
    return true;
  }

  void CreateConvOp(xla_hlo::ConvOp conv_op, ArrayRef<int64_t> strides,
  void CreateConvOp(mhlo::ConvOp conv_op, ArrayRef<int64_t> strides,
                    StringRef padding, ArrayRef<int64_t> dilation,
                    bool is_depthwise_conv,
                    ConversionPatternRewriter &rewriter) const {
@ -167,13 +167,13 @@ class ConvertConvOp : public OpConversionPattern<xla_hlo::ConvOp> {
    }
  }

  bool IsSupportedConvOp(xla_hlo::ConvOp conv_op) const {
  bool IsSupportedConvOp(mhlo::ConvOp conv_op) const {
    if (!conv_op.lhs().getType().cast<ShapedType>().hasStaticShape() ||
        !conv_op.rhs().getType().cast<ShapedType>().hasStaticShape() ||
        !conv_op.getType().cast<ShapedType>().hasStaticShape())
      return false;

    // All ones in "lhs_dilation" means this "xla_hlo.conv" op should be
    // All ones in "lhs_dilation" means this "mhlo.conv" op should be
    // converted to "tf.Conv2D" or "tf.DepthwiseConv2dNativeOp".
    if (conv_op.lhs_dilation().hasValue()) {
      auto lhs_dilation = conv_op.lhs_dilation().getValue();
@ -236,15 +236,15 @@ class ConvertConvOp : public OpConversionPattern<xla_hlo::ConvOp> {
  }
};

class ConvertSliceOp : public OpConversionPattern<xla_hlo::SliceOp> {
class ConvertSliceOp : public OpConversionPattern<mhlo::SliceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      xla_hlo::SliceOp slice_op, ArrayRef<Value> args,
      mhlo::SliceOp slice_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    DenseIntElementsAttr strides = slice_op.strides();
    // Strides must be 1 otherwise we cannot legalize this `xla_hlo.slice` op.
    // Strides must be 1 otherwise we cannot legalize this `mhlo.slice` op.
    if (!strides.isSplat() ||
        strides.getSplatValue().cast<IntegerAttr>().getInt() != 1)
      return failure();
@ -374,10 +374,10 @@ class DotDimensionsInfo {
  DimensionSetVector out_dimensions_;
};

// Converts xla_hlo.dot to tf.BatchMatMul. Reshape or Transpose ops will also be
// Converts mhlo.dot to tf.BatchMatMul. Reshape or Transpose ops will also be
// inserted to convert to well-formed matrix multiply.
Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
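  // Note (editorial comment, not part of the original commit): as visible
  // below, each operand is transposed into (batch, free, contracting) /
  // (batch, contracting, free) order, flattened to rank 3, multiplied with a
  // single batch matmul, and reshaped back to the original result type.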
  auto dot_general_op = cast<xla_hlo::DotGeneralOp>(old_op);
  auto dot_general_op = cast<mhlo::DotGeneralOp>(old_op);
  auto lhs_type = dot_general_op.lhs().getType().cast<ShapedType>();
  auto rhs_type = dot_general_op.rhs().getType().cast<ShapedType>();
  auto result_type = dot_general_op.getResult().getType().cast<ShapedType>();
@ -405,7 +405,7 @@ Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
      lhs_dot_dimensions_info.batch_dimensions().SizesArray(),
      lhs_dot_dimensions_info.out_dimensions().SizesArray(),
      lhs_dot_dimensions_info.contracting_dimensions().SizesArray());
  auto lhs_transposed = rewriter.create<xla_hlo::TransposeOp>(
  auto lhs_transposed = rewriter.create<mhlo::TransposeOp>(
      loc,
      RankedTensorType::get(lhs_transposed_shape, lhs_type.getElementType()),
      dot_general_op.lhs(),
@ -423,7 +423,7 @@ Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
      rhs_dot_dimensions_info.batch_dimensions().SizesArray(),
      rhs_dot_dimensions_info.contracting_dimensions().SizesArray(),
      rhs_dot_dimensions_info.out_dimensions().SizesArray());
  auto rhs_transposed = rewriter.create<xla_hlo::TransposeOp>(
  auto rhs_transposed = rewriter.create<mhlo::TransposeOp>(
      loc,
      RankedTensorType::get(rhs_transposed_shape, rhs_type.getElementType()),
      dot_general_op.rhs(),
@ -438,7 +438,7 @@ Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
       lhs_dot_dimensions_info.FlattenedOutDimensionSize()},
      llvm::ArrayRef<int64_t>{
          lhs_dot_dimensions_info.FlattenedContractingDimensionSize()});
  auto lhs_flattend = rewriter.create<xla_hlo::ReshapeOp>(
  auto lhs_flattend = rewriter.create<mhlo::ReshapeOp>(
      loc,
      RankedTensorType::get(lhs_flattened_shape, lhs_type.getElementType()),
      lhs_transposed.getResult());
@ -450,7 +450,7 @@ Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
       rhs_dot_dimensions_info.FlattenedContractingDimensionSize()},
      llvm::ArrayRef<int64_t>{
          rhs_dot_dimensions_info.FlattenedOutDimensionSize()});
  auto rhs_flattend = rewriter.create<xla_hlo::ReshapeOp>(
  auto rhs_flattend = rewriter.create<mhlo::ReshapeOp>(
      loc,
      RankedTensorType::get(rhs_flattened_shape, rhs_type.getElementType()),
      rhs_transposed.getResult());
@ -466,14 +466,14 @@ Value ConvertDotGeneralOp(PatternRewriter &rewriter, Operation *old_op) {
      loc, RankedTensorType::get(matmul_shape, result_type.getElementType()),
      lhs_flattend.getResult(), rhs_flattend.getResult());
  auto reshaped =
      rewriter.create<xla_hlo::ReshapeOp>(loc, result_type, matmul.getResult());
      rewriter.create<mhlo::ReshapeOp>(loc, result_type, matmul.getResult());
  return reshaped.getResult();
}

// This function tries to match that the "xla_hlo::ReduceOp" only has one
// input, one init_value and one result. Also "xla_hlo::ReduceOp" has two ops
// This function tries to match that the "mhlo::ReduceOp" only has one
// input, one init_value and one result. Also "mhlo::ReduceOp" has two ops
// in the region, and the last one is return op.
LogicalResult MatchReduceOpInput(xla_hlo::ReduceOp reduce_op) {
LogicalResult MatchReduceOpInput(mhlo::ReduceOp reduce_op) {
  if (reduce_op.operands().size() != 1 || reduce_op.init_values().size() != 1 ||
      reduce_op.getResults().size() != 1)
    return failure();
@ -489,23 +489,23 @@ LogicalResult MatchReduceOpInput(xla_hlo::ReduceOp reduce_op) {
  return success();
}

// TODO(jingpu): This "xla_hlo::ReduceOp" can correspond to many TF ops
// TODO(jingpu): This "mhlo::ReduceOp" can correspond to many TF ops
// with different ops in reduce_op.body. Now we only match to "tf.Max", "tf.Min"
// and "tf.Sum".
class ConvertReduceOpToTfSum : public OpConversionPattern<xla_hlo::ReduceOp> {
class ConvertReduceOpToTfSum : public OpConversionPattern<mhlo::ReduceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      xla_hlo::ReduceOp reduce_op, ArrayRef<Value> args,
      mhlo::ReduceOp reduce_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    if (failed(MatchReduceOpInput(reduce_op))) return failure();

    Operation *first_op = &reduce_op.body().front().front();
    if (!llvm::isa<xla_hlo::AddOp>(first_op)) return failure();
    if (!llvm::isa<mhlo::AddOp>(first_op)) return failure();

    // In `MatchReduceOpInput` function, we already match that the
    // "xla_hlo::ReduceOp" only has one input, one init_value and one result.
    // "mhlo::ReduceOp" only has one input, one init_value and one result.
    auto input = reduce_op.operands()[0];
    // Get reduction dimension.
    DenseIntElementsAttr dimension = reduce_op.dimensions();
@ -531,20 +531,20 @@ class ConvertReduceOpToTfSum : public OpConversionPattern<xla_hlo::ReduceOp> {
  };
};

class ConvertReduceOpToTfMax : public OpConversionPattern<xla_hlo::ReduceOp> {
class ConvertReduceOpToTfMax : public OpConversionPattern<mhlo::ReduceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      xla_hlo::ReduceOp reduce_op, ArrayRef<Value> args,
      mhlo::ReduceOp reduce_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    if (failed(MatchReduceOpInput(reduce_op))) return failure();

    Operation *first_op = &reduce_op.body().front().front();
    if (!llvm::isa<xla_hlo::MaxOp>(first_op)) return failure();
    if (!llvm::isa<mhlo::MaxOp>(first_op)) return failure();

    // In `MatchReduceOpInput` function, we already match that the
    // "xla_hlo::ReduceOp" only has one input, one init_value and one result.
    // "mhlo::ReduceOp" only has one input, one init_value and one result.
    auto input = reduce_op.operands()[0];
    // Get reduction dimension.
    DenseIntElementsAttr dimension = reduce_op.dimensions();
@ -572,20 +572,20 @@ class ConvertReduceOpToTfMax : public OpConversionPattern<xla_hlo::ReduceOp> {
  };
};

class ConvertReduceOpToTfMin : public OpConversionPattern<xla_hlo::ReduceOp> {
class ConvertReduceOpToTfMin : public OpConversionPattern<mhlo::ReduceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      xla_hlo::ReduceOp reduce_op, ArrayRef<Value> args,
      mhlo::ReduceOp reduce_op, ArrayRef<Value> args,
      ConversionPatternRewriter &rewriter) const final {
    if (failed(MatchReduceOpInput(reduce_op))) return failure();

    Operation *first_op = &reduce_op.body().front().front();
    if (!llvm::isa<xla_hlo::MinOp>(first_op)) return failure();
    if (!llvm::isa<mhlo::MinOp>(first_op)) return failure();

    // In `MatchReduceOpInput` function, we already match that the
    // "xla_hlo::ReduceOp" only has one input, one init_value and one result.
    // "mhlo::ReduceOp" only has one input, one init_value and one result.
    Value input = reduce_op.operands()[0];
    // Get reduction dimension.
    DenseIntElementsAttr dimension = reduce_op.dimensions();
@ -645,10 +645,10 @@ ConstantOp ShapeToConst(PatternRewriter &rewriter, Value value) {
  return rewriter.create<ConstantOp>(value.getLoc(), attr_type, attr);
}

// Converts xla_hlo.dot to tf.MatMul. Reshape ops will be inserted when
// Converts mhlo.dot to tf.MatMul. Reshape ops will be inserted when
// necessary.
Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) {
|
||||
auto dot_op = cast<xla_hlo::DotOp>(old_op);
|
||||
auto dot_op = cast<mhlo::DotOp>(old_op);
|
||||
const mlir::Location loc = dot_op.getLoc();
|
||||
// Normalizes a ShapedType to 2d if the ShapedType is less than 2d by
|
||||
// inserting dummy 1-element dimensions in the begining. Does nothing if the
|
||||
@ -677,7 +677,7 @@ Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) {
|
||||
return input;
|
||||
}
|
||||
|
||||
auto reshape = rewriter.create<xla_hlo::ReshapeOp>(
|
||||
auto reshape = rewriter.create<mhlo::ReshapeOp>(
|
||||
loc, normalize_rank(input_type), input);
|
||||
return reshape.getResult();
|
||||
};
|
||||
@ -694,7 +694,7 @@ Value ConvertDotOp(PatternRewriter &rewriter, Operation *old_op) {
|
||||
loc, normalize_rank(output_type), a, b,
|
||||
/*transpose_a=*/rewriter.getBoolAttr(false), transpose_b);
|
||||
auto reshape =
|
||||
rewriter.create<xla_hlo::ReshapeOp>(loc, output_type, matmul.product());
|
||||
rewriter.create<mhlo::ReshapeOp>(loc, output_type, matmul.product());
|
||||
return reshape.getResult();
|
||||
}
|
||||
|
||||
@ -752,7 +752,7 @@ void LegalizeHloToTf::runOnFunction() {
|
||||
target.addLegalDialect<TensorFlowDialect>();
|
||||
target.addLegalOp<CallOp, ConstantOp>();
|
||||
if (failed(applyPartialConversion(getFunction(), target, patterns))) {
|
||||
getFunction().emitError("xla_hlo to TF legalization failed.");
|
||||
getFunction().emitError("mhlo to TF legalization failed.");
|
||||
signalPassFailure();
|
||||
}
|
||||
}
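For orientation, a hedged sketch of how conversion patterns like ConvertReduceOpToTfSum are typically wired into this pass. The registration call is not shown in the hunks above, so the function name here is an assumption based on the usual MLIR idiom of the period:

// Assumed registration idiom; only the pattern class names come from the diff.
void PopulateLegalizeHloToTfPatterns(mlir::OwningRewritePatternList* patterns,
                                     mlir::MLIRContext* context) {
  patterns->insert<ConvertReduceOpToTfSum, ConvertReduceOpToTfMax,
                   ConvertReduceOpToTfMin>(context);
}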

@ -151,7 +151,7 @@ LogicalResult GetRemappedPaddings(

// Inserts padding maps for relevant arguments as argument attributes on the
// encapsulated function. The padding maps will be in the form of:
// %arg0 : type {xla_hlo.padding_map = {shape_indices = [...],
// %arg0 : type {mhlo.padding_map = {shape_indices = [...],
// padding_arg_indices = [...]}}
void AnnotateFunctionArgumentsWithPaddings(
FuncOp func,
@ -174,7 +174,7 @@ void AnnotateFunctionArgumentsWithPaddings(
"padding_arg_indices",
builder.getI32ArrayAttr(padding.getSecond().second));
func.setArgAttr(
padding.getFirst(), "xla_hlo.padding_map",
padding.getFirst(), "mhlo.padding_map",
builder.getDictionaryAttr({shape_indices, padding_arg_indices}));
}
}
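The dictionary attribute written here can be read back with the getArgAttrOfType helper that also appears later in this change; a minimal sketch (the loop bounds and variable names are assumptions, only the attribute name and shape come from the hunk above):

// Reading the annotation back; mirrors the mhlo.sharding lookup further down.
for (unsigned i = 0, e = func.getNumArguments(); i < e; ++i) {
  if (auto padding_map = func.getArgAttrOfType<mlir::DictionaryAttr>(
          i, "mhlo.padding_map")) {
    auto shape_indices = padding_map.get("shape_indices");
    (void)shape_indices;  // An ArrayAttr of i32 indices, per the comment above.
  }
}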

@ -39,7 +39,7 @@ namespace mlir {
namespace TFTPU {
namespace {

constexpr char kShardingAttr[] = "xla_hlo.sharding";
constexpr char kShardingAttr[] = "mhlo.sharding";

struct TPUShardingIdentificationPass
: public PassWrapper<TPUShardingIdentificationPass,

@ -108,7 +108,7 @@ Status GetXlaInputShapes(

// Rewrite layout with sharding, if sharding is set.
auto sharding =
main_func.getArgAttrOfType<mlir::StringAttr>(i, "xla_hlo.sharding");
main_func.getArgAttrOfType<mlir::StringAttr>(i, "mhlo.sharding");
if (!sharding) continue;

absl::optional<xla::HloSharding> arg_sharding;
@ -253,7 +253,7 @@ static void RegisterDialects() {
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
mlir::registerDialect<mlir::shape::ShapeDialect>();
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
mlir::registerDialect<mlir::mhlo::XlaHloDialect>();
return true;
}();
(void)init_once;
@ -279,9 +279,9 @@ Status ConvertMLIRToXlaComputation(
// LegalizeTFControlFlow encapsulates arguments for control flow operations
// with a tuple argument which break the assumption of resource lifting
// inside PromoteResourcesToArgs.
tf2xla.addPass(mlir::xla_hlo::createLegalizeTFControlFlowPass());
tf2xla.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());

tf2xla.addNestedPass<mlir::FuncOp>(mlir::xla_hlo::createLegalizeTFPass(true));
tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(true));
for (auto& target_pass : custom_legalization_passes) {
tf2xla.addNestedPass<mlir::FuncOp>(std::move(target_pass));
}
@ -290,7 +290,7 @@ Status ConvertMLIRToXlaComputation(

// Leverage tf2xla kernels for ops that didn't get lowered in the previous
// legalization pass.
tf2xla.addPass(mlir::xla_hlo::createLegalizeTfWithTf2XlaPass(device_type));
tf2xla.addPass(mlir::mhlo::createLegalizeTfWithTf2XlaPass(device_type));
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());

// Run shape inference pass to propagate shapes through tensor_cast operations
@ -303,12 +303,11 @@ Status ConvertMLIRToXlaComputation(
// expose more graph pruning and canonicalization opportunities that are
// necessary for the second LegalizeTFPass(allow_partial_conversion=false)
// invocation.
tf2xla.addNestedPass<mlir::FuncOp>(
mlir::xla_hlo::createLegalizeTFPass(false));
tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(false));
// In order to export to XLA, we must sink constants to control flow regions,
// since XLA uses functional control flow.
tf2xla.addNestedPass<mlir::FuncOp>(
mlir::xla_hlo::createSinkConstantsToControlFlowPass());
mlir::mhlo::createSinkConstantsToControlFlowPass());

if (VLOG_IS_ON(1)) {
// Print the whole module after each pass which requires disabling
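Putting the hunks above together, the pass ordering can be reproduced with a small driver. This is a hedged sketch: the PassManager setup, module variable, and error message are assumptions; only the pass-creation calls and their order come from the diff:

mlir::PassManager tf2xla(module.getContext());
tf2xla.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());
tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(true));
tf2xla.addPass(mlir::mhlo::createLegalizeTfWithTf2XlaPass(device_type));
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
tf2xla.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(false));
tf2xla.addNestedPass<mlir::FuncOp>(
    mlir::mhlo::createSinkConstantsToControlFlowPass());
if (mlir::failed(tf2xla.run(module)))
  return errors::Internal("MLIR TF to XLA legalization failed");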

@ -184,7 +184,7 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
// only be lowered when tf.Shape is folded into a constant.
constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {xla_hlo.is_same_data_across_replicas}) -> tensor<10x19xf32> {
func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {mhlo.is_same_data_across_replicas}) -> tensor<10x19xf32> {
%0 = "tf.Shape"(%arg0) : (tensor<10x19xf32>) -> tensor<2xi64>
%1 = "tf.Reshape"(%arg1, %0) : (tensor<19x10xf32>, tensor<2xi64>) -> tensor<10x19xf32>
return %1 : tensor<10x19xf32>
@ -344,7 +344,7 @@ ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) {
TEST(CompileSerializedMlirToXlaHloTest, ArgumentSharding) {
constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<128x10xf32> {xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {xla_hlo.sharding = ""}) {
func @main(%arg0: tensor<128x10xf32> {mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {mhlo.sharding = ""}) {
return
}
}
@ -383,7 +383,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[128,10], f32[10,1024], f32[128,1024])) -> () {
TEST(CompileSerializedMlirToXlaHloTest, BadArgumentSharding) {
constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<128x10xf32> {xla_hlo.sharding = "bad_sharding"}) {
func @main(%arg0: tensor<128x10xf32> {mhlo.sharding = "bad_sharding"}) {
return
}
}
@ -403,7 +403,7 @@ module attributes {tf.versions = {producer = 179 : i32}} {
TEST(CompileSerializedMlirToXlaHloTest, ResultSharding) {
constexpr char mlir_module[] = R"(
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 351 : i32}} {
func @main(%arg0: tensor<128x10xf32>, %arg1: tensor<10x1024xf32>, %arg2: tensor<128x1024xf32>) -> (tensor<128x10xf32> {xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, tensor<10x1024xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<128x1024xf32> {xla_hlo.sharding = ""}) {
func @main(%arg0: tensor<128x10xf32>, %arg1: tensor<10x1024xf32>, %arg2: tensor<128x1024xf32>) -> (tensor<128x10xf32> {mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, tensor<10x1024xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<128x1024xf32> {mhlo.sharding = ""}) {
return %arg0, %arg1, %arg2 : tensor<128x10xf32>, tensor<10x1024xf32>, tensor<128x1024xf32>
}
}
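These tests pin the new attribute names. On the C++ side the same strings are consumed through lookups like the one below, a hedged sketch restating the GetXlaInputShapes hunk above (the proto parsing step is an assumption about how the "bad_sharding" case fails):

// Sharding is stored as a serialized xla::OpSharding proto string.
auto sharding =
    main_func.getArgAttrOfType<mlir::StringAttr>(i, "mhlo.sharding");
if (sharding) {
  xla::OpSharding proto;
  // A parse failure here is what the BadArgumentSharding test exercises.
  if (!proto.ParseFromString(sharding.getValue().str())) { /* report error */ }
}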

@ -87,15 +87,15 @@ struct MaterializeBroadcastsPass
mlir::ConversionTarget conversionTarget(getContext());
mlir::OwningRewritePatternList conversionPatterns;

// Consider the xla_hlo dialect legal for tests.
conversionTarget.addLegalDialect<mlir::xla_hlo::XlaHloDialect>();
// Consider the mhlo dialect legal for tests.
conversionTarget.addLegalDialect<mlir::mhlo::XlaHloDialect>();
// The conversion uses helpers from the Standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();

mlir::xla_hlo::SetupMaterializeBroadcastsLegality(&getContext(),
&conversionTarget);
mlir::xla_hlo::PopulateMaterializeBroadcastsPatterns(&getContext(),
&conversionPatterns);
mlir::mhlo::SetupMaterializeBroadcastsLegality(&getContext(),
&conversionTarget);
mlir::mhlo::PopulateMaterializeBroadcastsPatterns(&getContext(),
&conversionPatterns);

if (failed(applyPartialConversion(getFunction(), conversionTarget,
conversionPatterns))) {
@ -108,7 +108,7 @@ struct UnfuseBatchNormPass
: public mlir::PassWrapper<UnfuseBatchNormPass, mlir::FunctionPass> {
void runOnFunction() override {
mlir::OwningRewritePatternList patterns;
mlir::xla_hlo::PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
mlir::mhlo::PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
mlir::applyPatternsAndFoldGreedily(getOperation(), patterns);
}
};
@ -122,11 +122,11 @@ Status LowerTfOpToLhloWithDynamicShapes(mlir::ModuleOp module) {
/*shouldPrintAfterPass=*/enable_if_vlog_is_on,
/*printModuleScope=*/false,
/*printAfterOnlyOnChange=*/false, llvm::dbgs());
pm.addNestedPass<mlir::FuncOp>(mlir::xla_hlo::createLegalizeTFPass(false));
pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(false));
pm.addNestedPass<mlir::FuncOp>(
absl::make_unique<MaterializeBroadcastsPass>());
pm.addNestedPass<mlir::FuncOp>(absl::make_unique<UnfuseBatchNormPass>());
pm.addPass(mlir::xla_hlo::createLegalizeToLhloPass(
pm.addPass(mlir::mhlo::createLegalizeToLhloPass(
/*results_escape_functions=*/true));
pm.addNestedPass<mlir::FuncOp>(mlir::xla_lhlo::createLhloCopyRemovalPass());

@ -115,9 +115,9 @@ cc_library(
)

cc_library(
name = "xla_hlo_to_lhlo_with_xla",
srcs = ["transforms/xla_hlo_to_lhlo_with_xla.cc"],
hdrs = ["transforms/xla_hlo_to_lhlo_with_xla.h"],
name = "mhlo_to_lhlo_with_xla",
srcs = ["transforms/mhlo_to_lhlo_with_xla.cc"],
hdrs = ["transforms/mhlo_to_lhlo_with_xla.h"],
deps = [
":hlo_module_importer",
":hlo_utils",
@ -363,7 +363,7 @@ cc_library(
"//tensorflow/compiler/mlir:__subpackages__",
],
deps = [
":xla_hlo_to_lhlo_with_xla",
":mhlo_to_lhlo_with_xla",
":xla_legalize_tf",
":xla_legalize_tf_with_tf2xla",
"//tensorflow/compiler/mlir/hlo",
@ -376,7 +376,7 @@ cc_library(
"//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_affine",
"//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_gpu",
"//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_parallel_loops",
"//tensorflow/compiler/mlir/hlo:xla_hlo_fusion",
"//tensorflow/compiler/mlir/hlo:mhlo_fusion",
"//tensorflow/compiler/mlir/hlo:xla_legalize_control_flow",
"//tensorflow/compiler/mlir/hlo:xla_legalize_tanh_to_approximation",
"//tensorflow/compiler/mlir/hlo:xla_legalize_to_linalg",

@ -42,7 +42,7 @@ mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
}

// Converts the gather dimensions to attributes.
mlir::xla_hlo::GatherDimensionNumbers ConvertGatherDimensionNumbers(
mlir::mhlo::GatherDimensionNumbers ConvertGatherDimensionNumbers(
const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder) {
std::vector<int64_t> offset_dims(dnums.offset_dims().begin(),
dnums.offset_dims().end());
@ -50,14 +50,14 @@ mlir::xla_hlo::GatherDimensionNumbers ConvertGatherDimensionNumbers(
dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end());
std::vector<int64_t> start_index_map(dnums.start_index_map().begin(),
dnums.start_index_map().end());
return mlir::xla_hlo::GatherDimensionNumbers::get(
return mlir::mhlo::GatherDimensionNumbers::get(
Convert(offset_dims, builder), Convert(collapsed_slice_dims, builder),
Convert(start_index_map, builder),
builder->getI64IntegerAttr(dnums.index_vector_dim()),
builder->getContext());
}

mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
mlir::mhlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder) {
std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
dnums.update_window_dims().end());
@ -66,7 +66,7 @@ mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
std::vector<int64_t> scatter_dims_to_operand_dims(
dnums.scatter_dims_to_operand_dims().begin(),
dnums.scatter_dims_to_operand_dims().end());
return mlir::xla_hlo::ScatterDimensionNumbers::get(
return mlir::mhlo::ScatterDimensionNumbers::get(
Convert(update_window_dims, builder),
Convert(inserted_window_dims, builder),
Convert(scatter_dims_to_operand_dims, builder),
@ -74,7 +74,7 @@ mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
builder->getContext());
}

mlir::xla_hlo::DotDimensionNumbers ConvertDotDimensionNumbers(
mlir::mhlo::DotDimensionNumbers ConvertDotDimensionNumbers(
const DotDimensionNumbers& dnums, mlir::Builder* builder) {
std::vector<int64_t> rhs_contracting_dimensions(
dnums.rhs_contracting_dimensions().begin(),
@ -93,12 +93,12 @@ mlir::xla_hlo::DotDimensionNumbers ConvertDotDimensionNumbers(
auto lhs_contracting_dims_attr = Convert(lhs_contracting_dimensions, builder);
auto rhs_contracting_dims_attr = Convert(rhs_contracting_dimensions, builder);

return mlir::xla_hlo::DotDimensionNumbers::get(
return mlir::mhlo::DotDimensionNumbers::get(
lhs_batch_dims_attr, rhs_batch_dims_attr, lhs_contracting_dims_attr,
rhs_contracting_dims_attr, builder->getContext());
}

mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
mlir::mhlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder) {
llvm::SmallVector<int64_t, 4> input_spatial_dims(
dnums.input_spatial_dimensions().begin(),
@ -109,7 +109,7 @@ mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
llvm::SmallVector<int64_t, 4> output_spatial_dims(
dnums.output_spatial_dimensions().begin(),
dnums.output_spatial_dimensions().end());
return mlir::xla_hlo::ConvDimensionNumbers::get(
return mlir::mhlo::ConvDimensionNumbers::get(
builder->getI64IntegerAttr(dnums.input_batch_dimension()),
builder->getI64IntegerAttr(dnums.input_feature_dimension()),
Convert(input_spatial_dims, builder),
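A hedged usage sketch for these converters. The proto field values and the `context` variable are illustrative assumptions; only the converter signature comes from the code above:

xla::GatherDimensionNumbers dnums;
dnums.add_offset_dims(1);
dnums.add_collapsed_slice_dims(0);
dnums.add_start_index_map(0);
dnums.set_index_vector_dim(1);
mlir::Builder builder(&context);  // `context` is an assumed mlir::MLIRContext.
mlir::mhlo::GatherDimensionNumbers attr =
    xla::ConvertGatherDimensionNumbers(dnums, &builder);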

@ -29,19 +29,19 @@ mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
mlir::Builder* builder);

// Converts the gather dimensions to attributes.
mlir::xla_hlo::GatherDimensionNumbers ConvertGatherDimensionNumbers(
mlir::mhlo::GatherDimensionNumbers ConvertGatherDimensionNumbers(
const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder);

// Converts the scatter dimensions to attributes.
mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
mlir::mhlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder);

// Converts the dot dimensions to attributes.
mlir::xla_hlo::DotDimensionNumbers ConvertDotDimensionNumbers(
mlir::mhlo::DotDimensionNumbers ConvertDotDimensionNumbers(
const DotDimensionNumbers& dnums, mlir::Builder* builder);

// Converts the conv dimensions to attributes.
mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
mlir::mhlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder);

} // namespace xla

@ -171,7 +171,7 @@ tensorflow::Status HloFunctionImporter::ImportInstructions(
if (llvm::isa<FuncOp>(block->getParentOp())) {
builder.create<mlir::ReturnOp>(loc, result);
} else {
builder.create<mlir::xla_hlo::ReturnOp>(loc, result);
builder.create<mlir::mhlo::ReturnOp>(loc, result);
}
return tensorflow::Status::OK();
}
@ -202,18 +202,18 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
}
case HloOpcode::kIota: {
return func_builder
->create<mlir::xla_hlo::IotaOp>(
->create<mlir::mhlo::IotaOp>(
loc, result_type,
func_builder->getI64IntegerAttr(
Cast<HloIotaInstruction>(instruction)->iota_dimension()))
.getOperation();
}
#define MakeAndReturn(mlir_op) \
{ \
mlir::Operation* new_operation = \
func_builder->create<mlir::xla_hlo::mlir_op>(loc, result_type, \
operands, attributes); \
return new_operation; \
#define MakeAndReturn(mlir_op) \
{ \
mlir::Operation* new_operation = \
func_builder->create<mlir::mhlo::mlir_op>(loc, result_type, operands, \
attributes); \
return new_operation; \
}
case HloOpcode::kBroadcast: {
// Note that the HLO broadcast is more powerful than the XLA broadcast op.
@ -314,14 +314,14 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
instruction->dynamic_slice_sizes().begin(),
instruction->dynamic_slice_sizes().end());
return func_builder
->create<mlir::xla_hlo::DynamicSliceOp>(
->create<mlir::mhlo::DynamicSliceOp>(
loc, result_type, operands[0],
makeArrayRef(operands).drop_front(), Convert(slice_sizes))
.getOperation();
}
case HloOpcode::kDynamicUpdateSlice: {
return func_builder
->create<mlir::xla_hlo::DynamicUpdateSliceOp>(
->create<mlir::mhlo::DynamicUpdateSliceOp>(
loc, result_type, operands[0], operands[1],
llvm::ArrayRef<Value>(operands.begin() + 2, operands.end()))
.getOperation();
@ -354,10 +354,10 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
}

return func_builder
->create<mlir::xla_hlo::PadOp>(loc, result_type, operands[0],
operands[1], Convert(edge_padding_low),
Convert(edge_padding_high),
Convert(interior_padding))
->create<mlir::mhlo::PadOp>(loc, result_type, operands[0],
operands[1], Convert(edge_padding_low),
Convert(edge_padding_high),
Convert(interior_padding))
.getOperation();
}
case HloOpcode::kScatter: {
@ -372,7 +372,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
attributes.push_back(builder_->getNamedAttr(
"unique_indices", builder_->getBoolAttr(scatter->unique_indices())));

auto scatter_op = func_builder->create<mlir::xla_hlo::ScatterOp>(
auto scatter_op = func_builder->create<mlir::mhlo::ScatterOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(),
&scatter_op.update_computation()));
@ -394,7 +394,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
Convert(window_dimensions)));
attributes.push_back(ConvertPadding(padding));
auto select_scatter_op =
func_builder->create<mlir::xla_hlo::SelectAndScatterOp>(
func_builder->create<mlir::mhlo::SelectAndScatterOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->select(),
&select_scatter_op.select()));
@ -410,7 +410,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
}
case HloOpcode::kSlice: {
return func_builder
->create<mlir::xla_hlo::SliceOp>(
->create<mlir::mhlo::SliceOp>(
loc, result_type, operands[0],
ConvertDimensions(instruction->slice_starts()),
ConvertDimensions(instruction->slice_limits()),
@ -419,7 +419,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
}
case HloOpcode::kSort: {
auto sort_instruction = Cast<HloSortInstruction>(instruction);
auto sort_op = func_builder->create<mlir::xla_hlo::SortOp>(
auto sort_op = func_builder->create<mlir::mhlo::SortOp>(
loc, result_type, operands,
builder_->getI64IntegerAttr(sort_instruction->sort_dimension()),
builder_->getBoolAttr(sort_instruction->is_stable()));
@ -437,8 +437,8 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
TF_RETURN_IF_ERROR(GetMlirTypes(
{instruction->true_computation()->root_instruction()}, &rets));

auto op = func_builder->create<mlir::xla_hlo::IfOp>(loc, rets, operands,
attributes);
auto op = func_builder->create<mlir::mhlo::IfOp>(loc, rets, operands,
attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->true_computation(),
&op.true_branch()));
TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->false_computation(),
@ -451,7 +451,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
{instruction->branch_computation(0)->root_instruction()}, &rets));

int num_branches = instruction->branch_count();
auto op = func_builder->create<mlir::xla_hlo::CaseOp>(
auto op = func_builder->create<mlir::mhlo::CaseOp>(
loc, rets, operands, attributes, num_branches);
for (auto index_and_computation :
llvm::enumerate(instruction->branch_computations())) {
@ -465,7 +465,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
// TODO(b/132057942): Support taking an uint64_t instead of an IntegerAttr
// for concatenate dimension.
return func_builder
->create<mlir::xla_hlo::ConcatenateOp>(
->create<mlir::mhlo::ConcatenateOp>(
loc, result_type, operands,
builder_->getI64IntegerAttr(instruction->concatenate_dimension()))
.getOperation();
@ -474,7 +474,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
attributes.push_back(ConvertReplicaGroups(all_reduce->replica_groups()));
attributes.push_back(ConvertChannelHandle(all_reduce->channel_id()));
auto all_reduce_op = func_builder->create<mlir::xla_hlo::AllReduceOp>(
auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
&all_reduce_op.computation()));
@ -484,7 +484,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
// Operands in the first half are reduction inputs and the remaining
// operands are corresponding initial values.
size_t num_inputs = operands.size() / 2;
auto reduce = func_builder->create<mlir::xla_hlo::ReduceOp>(
auto reduce = func_builder->create<mlir::mhlo::ReduceOp>(
loc, result_type, llvm::makeArrayRef(operands).take_front(num_inputs),
llvm::makeArrayRef(operands).drop_front(num_inputs),
ConvertDimensions(instruction->dimensions()));
@ -494,7 +494,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
}
case HloOpcode::kReverse: {
return func_builder
->create<mlir::xla_hlo::ReverseOp>(
->create<mlir::mhlo::ReverseOp>(
loc, result_type, operands[0],
ConvertDimensions(instruction->dimensions()))
.getOperation();
@ -505,14 +505,14 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
switch (instruction->random_distribution()) {
case xla::RNG_UNIFORM:
return func_builder
->create<mlir::xla_hlo::RngUniformOp>(
loc, result_type, operands[0], operands[1], shape)
->create<mlir::mhlo::RngUniformOp>(loc, result_type, operands[0],
operands[1], shape)
.getOperation();

case xla::RNG_NORMAL:
return func_builder
->create<mlir::xla_hlo::RngNormalOp>(
loc, result_type, operands[0], operands[1], shape)
->create<mlir::mhlo::RngNormalOp>(loc, result_type, operands[0],
operands[1], shape)
.getOperation();

default:
@ -522,7 +522,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
}
}
case HloOpcode::kWhile: {
auto op = func_builder->create<mlir::xla_hlo::WhileOp>(
auto op = func_builder->create<mlir::mhlo::WhileOp>(
loc, operands[0].getType(), operands[0]);
TF_RETURN_IF_ERROR(
ImportAsRegion(*instruction->while_condition(), &op.cond()));
@ -585,14 +585,14 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
attributes.push_back(builder_->getNamedAttr(
"window_dilations", ConvertDimensions(win_dilations)));
attributes.push_back(ConvertPadding(padding));
auto reduce = func_builder->create<mlir::xla_hlo::ReduceWindowOp>(
auto reduce = func_builder->create<mlir::mhlo::ReduceWindowOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(
ImportAsRegion(*instruction->to_apply(), &reduce.body()));
return reduce.getOperation();
}
case HloOpcode::kMap: {
auto op = func_builder->create<mlir::xla_hlo::MapOp>(
auto op = func_builder->create<mlir::mhlo::MapOp>(
loc, result_type, operands,
ConvertDimensions(instruction->dimensions()));
TF_RETURN_IF_ERROR(
@ -714,7 +714,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
// is not mentioned in xla client anywhere or in the hlo of our sample
// models.
default: {
mlir::OperationState result(loc, "xla_hlo.unknown");
mlir::OperationState result(loc, "mhlo.unknown");
result.addOperands(operands);
result.addTypes(result_type);
for (auto attr : attributes) {
@ -840,7 +840,7 @@ mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
const xla::ChannelHandle& channel) {
return builder_->getNamedAttr(
"channel_handle",
mlir::xla_hlo::ChannelHandle::get(
mlir::mhlo::ChannelHandle::get(
builder_->getI64IntegerAttr(channel.handle()),
builder_->getI64IntegerAttr(channel.type()), context_));
}
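The default case above keeps unsupported instructions as an unregistered "mhlo.unknown" op. A minimal hedged sketch of that OperationState idiom, with the loop body (truncated in the hunk) filled in as an assumption:

mlir::OperationState result(loc, "mhlo.unknown");
result.addOperands(operands);
result.addTypes(result_type);
// Assumed completion of the truncated attribute loop above.
for (auto attr : attributes) result.addAttribute(attr.first, attr.second);
mlir::Operation* op = func_builder->createOperation(result);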

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_FUNCTION_IMPORTER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_FUNCTION_IMPORTER_H_
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_

#include <unordered_map>

@ -143,4 +143,4 @@ class HloFunctionImporter {

} // namespace xla

#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_FUNCTION_IMPORTER_H_
#endif // TENSORFLOW_COMPILER_MLIR_XLA_FUNCTION_IMPORTER_H_

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_MODULE_IMPORTER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_MODULE_IMPORTER_H_
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_MODULE_IMPORTER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_MODULE_IMPORTER_H_

#include <unordered_map>

@ -59,4 +59,4 @@ class HloModuleImporter {

} // namespace xla

#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_MODULE_IMPORTER_H_
#endif // TENSORFLOW_COMPILER_MLIR_XLA_MODULE_IMPORTER_H_

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_TO_MLIR_HLO_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_TO_MLIR_HLO_H_
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TO_MLIR_HLO_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TO_MLIR_HLO_H_

#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/xla/status.h"
@ -36,4 +36,4 @@ Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModule* hlo_module);

} // namespace xla

#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_TO_MLIR_HLO_H_
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TO_MLIR_HLO_H_

@ -197,7 +197,7 @@ StatusOr<mlir::Type> ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,
}
}

mlir::xla_hlo::GatherDimensionNumbers CreateGatherDimensionNumbers(
mlir::mhlo::GatherDimensionNumbers CreateGatherDimensionNumbers(
const GatherDimensionNumbers& input, mlir::Builder builder) {
auto offset_dims = CreateDenseIntElementsAttrFromVector(
llvm::SmallVector<int64, 4>{input.offset_dims().begin(),
@ -215,7 +215,7 @@ mlir::xla_hlo::GatherDimensionNumbers CreateGatherDimensionNumbers(
mlir::IntegerAttr index_vector_dim =
builder.getI64IntegerAttr(input.index_vector_dim());

return mlir::xla_hlo::GatherDimensionNumbers::get(
return mlir::mhlo::GatherDimensionNumbers::get(
offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim,
builder.getContext());
}

@ -15,8 +15,8 @@ limitations under the License.

// This file defines helpers useful when creating or manipulating lhlo/hlo.

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_UTILS_H_
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_UTILS_H_

#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
@ -39,7 +39,7 @@ mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
StatusOr<mlir::Type> ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,
mlir::Builder builder);

mlir::xla_hlo::GatherDimensionNumbers CreateGatherDimensionNumbers(
mlir::mhlo::GatherDimensionNumbers CreateGatherDimensionNumbers(
const GatherDimensionNumbers& input, mlir::Builder builder);

template <typename TypeT>
@ -77,11 +77,11 @@ static StatusOr<mlir::Type> ConvertShapeToType(const Shape& shape,
return builder.getTupleType(contents);
}
if (shape.IsToken()) {
return mlir::xla_hlo::TokenType::get(builder.getContext());
return mlir::mhlo::TokenType::get(builder.getContext());
}
return ConvertTensorShapeToType<TypeT>(shape, builder);
}

} // namespace xla

#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_UTILS_H_
#endif // TENSORFLOW_COMPILER_MLIR_XLA_UTILS_H_
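A hedged usage example for ConvertShapeToType, whose token branch changed above (the shape construction calls are illustrative, not part of the patch):

mlir::MLIRContext context;
mlir::Builder builder(&context);
// A token shape now maps to mlir::mhlo::TokenType.
auto token_ty = xla::ConvertShapeToType<mlir::RankedTensorType>(
    xla::ShapeUtil::MakeTokenShape(), builder);
// A ranked tensor shape maps to a RankedTensorType, e.g. tensor<2x3xf32>.
auto tensor_ty = xla::ConvertShapeToType<mlir::RankedTensorType>(
    xla::ShapeUtil::MakeShape(xla::F32, {2, 3}), builder);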

@ -33,8 +33,7 @@ namespace xla {
static std::string GetMlirOpName(HloOpcode opcode) {
std::string op_name = HloOpcodeString(opcode);
absl::c_replace(op_name, '-', '_');
return mlir::xla_hlo::XlaHloDialect::getDialectNamespace().str() + "." +
op_name;
return mlir::mhlo::XlaHloDialect::getDialectNamespace().str() + "." + op_name;
}
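After the rename, getDialectNamespace() returns "mhlo", so this helper produces op names like the ones in the lowering tests further down; a small sketch of the expected behavior:

// e.g. "add" stays as-is, "dynamic-update-slice" has its dashes replaced:
// GetMlirOpName(HloOpcode::kAdd)                == "mhlo.add"
// GetMlirOpName(HloOpcode::kDynamicUpdateSlice) == "mhlo.dynamic_update_slice"
assert(GetMlirOpName(HloOpcode::kAdd) == "mhlo.add");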

static std::string ToString(mlir::Type ty) {
@ -90,7 +89,7 @@ XlaOp MlirHloBuilder::ConstantLiteral(const LiteralSlice& literal) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(mlir::DenseElementsAttr attr,
CreateDenseElementsAttrFromLiteral(literal, builder_));
auto op = builder_.create<mlir::xla_hlo::ConstOp>(loc_, attr);
auto op = builder_.create<mlir::mhlo::ConstOp>(loc_, attr);
return MakeXlaOp(op);
});
}
@ -108,7 +107,7 @@ StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal(
mlir::ArrayAttr config_attr;
if (precision_config)
config_attr = ConvertPrecisionConfig(precision_config, &builder_);
auto op = builder_.create<mlir::xla_hlo::ConvOp>(
auto op = builder_.create<mlir::mhlo::ConvOp>(
loc_, ty, GetValue(lhs), GetValue(rhs),
GetI64ElementsAttr(window_strides, &builder_),
ConvertPadding(padding, &builder_),
@ -125,7 +124,7 @@ StatusOr<XlaOp> MlirHloBuilder::FftInternal(
absl::Span<const int64> fft_length) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::FftOp>(
auto op = builder_.create<mlir::mhlo::FftOp>(
loc_, ty, GetValue(operand),
builder_.getStringAttr(FftType_Name(fft_type)),
GetI64ElementsAttr(fft_length, &builder_));
@ -141,7 +140,7 @@ StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
"CustomCall doesn't support operands shapes with layout");
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::CustomCallOp>(
auto op = builder_.create<mlir::mhlo::CustomCallOp>(
loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
/*has_side_effect=*/builder_.getBoolAttr(false),
builder_.getStringAttr(opaque));
@ -155,13 +154,13 @@ StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
// Reduce takes two set of variadic operands inputs and init_values.
// all_operands contains both of these so split operands into two parts.
int64_t num_args = all_operands.size() / 2;
auto op = builder_.create<mlir::xla_hlo::ReduceOp>(
auto op = builder_.create<mlir::mhlo::ReduceOp>(
loc_, GetValues(all_operands.first(num_args)),
GetValues(all_operands.subspan(num_args)),
GetI64ElementsAttr(dimensions_to_reduce, &builder_));
TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()))
if (op.getNumResults() == 1) return MakeXlaOp(op.getResult(0));
auto tuple = builder_.create<mlir::xla_hlo::TupleOp>(loc_, op.getResults());
auto tuple = builder_.create<mlir::mhlo::TupleOp>(loc_, op.getResults());
return MakeXlaOp(tuple);
}

@ -183,7 +182,7 @@ StatusOr<XlaOp> MlirHloBuilder::ReduceWindowInternal(
auto padding_ty =
mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
builder_.getIntegerType(64));
auto op = builder_.create<mlir::xla_hlo::ReduceWindowOp>(
auto op = builder_.create<mlir::mhlo::ReduceWindowOp>(
loc_, ty, GetValue(operand), GetValue(init_value),
GetI64ElementsAttr(sizes, &builder_),
GetI64ElementsAttr(strides, &builder_),
@ -199,7 +198,7 @@ XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) {
TF_ASSIGN_OR_RETURN(
mlir::Type ty,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
auto op = builder_.create<mlir::xla_hlo::IotaOp>(
auto op = builder_.create<mlir::mhlo::IotaOp>(
loc_, ty,
builder_.getIntegerAttr(builder_.getI64Type(), iota_dimension));
return MakeXlaOp(op);
@ -210,7 +209,7 @@ StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::TransposeOp>(
auto op = builder_.create<mlir::mhlo::TransposeOp>(
loc_, ty, GetValue(operand), GetI64ElementsAttr(permutation, &builder_));
return MakeXlaOp(op);
}
@ -219,7 +218,7 @@ StatusOr<XlaOp> MlirHloBuilder::RevInternal(
const Shape& shape, XlaOp operand, absl::Span<const int64> dimensions) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::ReverseOp>(
auto op = builder_.create<mlir::mhlo::ReverseOp>(
loc_, ty, GetValue(operand), GetI64ElementsAttr(dimensions, &builder_));
return MakeXlaOp(op);
}
@ -230,7 +229,7 @@ StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::GatherOp>(
auto op = builder_.create<mlir::mhlo::GatherOp>(
loc_, ty, GetValue(input), GetValue(start_indices),
ConvertGatherDimensionNumbers(dimension_numbers, &builder_),
GetI64ElementsAttr(slice_sizes, &builder_));
@ -244,7 +243,7 @@ StatusOr<XlaOp> MlirHloBuilder::ScatterInternal(
bool unique_indices) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::ScatterOp>(
auto op = builder_.create<mlir::mhlo::ScatterOp>(
loc_, ty, GetValue(input), GetValue(scatter_indices), GetValue(updates),
ConvertScatterDimensionNumbers(dimension_numbers, &builder_),
builder_.getBoolAttr(indices_are_sorted),
@ -262,11 +261,11 @@ StatusOr<XlaOp> MlirHloBuilder::RngOpInternal(
// and RngNormal can be mapped to the new op.
std::string op_name;
if (distribution == xla::RandomDistribution::RNG_UNIFORM) {
op_name = "xla_hlo.rng_uniform";
op_name = "mhlo.rng_uniform";
} else {
TF_RET_CHECK(distribution == xla::RandomDistribution::RNG_NORMAL)
<< "Unexpected distribution: " << distribution;
op_name = "xla_hlo.rng_normal";
op_name = "mhlo.rng_normal";
}

if (shape.is_dynamic())
@ -288,7 +287,7 @@ StatusOr<XlaOp> MlirHloBuilder::ReshapeInternal(const Shape& shape,
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
mlir::Value value = GetValue(operand);
auto op = builder_.create<mlir::xla_hlo::ReshapeOp>(loc_, ty, value);
auto op = builder_.create<mlir::mhlo::ReshapeOp>(loc_, ty, value);
return MakeXlaOp(op.getResult());
}

@ -298,7 +297,7 @@ StatusOr<XlaOp> MlirHloBuilder::DotGeneralInternal(
const PrecisionConfig* precision_config) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::DotGeneralOp>(
auto op = builder_.create<mlir::mhlo::DotGeneralOp>(
loc_, ty, GetValue(lhs), GetValue(rhs),
ConvertDotDimensionNumbers(dimension_number, &builder_),
ConvertPrecisionConfig(precision_config, &builder_));
@ -312,7 +311,7 @@ StatusOr<XlaOp> MlirHloBuilder::InDimBroadcast(
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
mlir::Value value = GetValue(operand);
auto op = builder_.create<mlir::xla_hlo::BroadcastInDimOp>(
auto op = builder_.create<mlir::mhlo::BroadcastInDimOp>(
loc_, ty, value, GetI64ElementsAttr(broadcast_dimensions, &builder_));
return MakeXlaOp(op.getResult());
}
@ -322,7 +321,7 @@ StatusOr<XlaOp> MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs,
ComparisonDirection direction) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::CompareOp>(
auto op = builder_.create<mlir::mhlo::CompareOp>(
loc_, ty, GetValue(lhs), GetValue(rhs),
builder_.getStringAttr(ComparisonDirectionToString(direction)));
return MakeXlaOp(op.getResult());
@ -343,8 +342,8 @@ StatusOr<XlaOp> MlirHloBuilder::AddOpWithShape(

XlaOp MlirHloBuilder::CreateToken() {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
return MakeXlaOp(builder_.create<mlir::xla_hlo::CreateTokenOp>(
loc_, mlir::xla_hlo::TokenType::get(builder_.getContext())));
return MakeXlaOp(builder_.create<mlir::mhlo::CreateTokenOp>(
loc_, mlir::mhlo::TokenType::get(builder_.getContext())));
});
}

@ -353,16 +352,16 @@ StatusOr<XlaOp> MlirHloBuilder::InfeedWithTokenInternal(
TF_ASSIGN_OR_RETURN(mlir::Type result_type,
ConvertShapeToType<mlir::RankedTensorType>(
infeed_instruction_shape, builder_));
return MakeXlaOp(builder_.create<mlir::xla_hlo::InfeedOp>(
loc_, result_type, GetValue(token),
/*infeed_config=*/config));
return MakeXlaOp(
builder_.create<mlir::mhlo::InfeedOp>(loc_, result_type, GetValue(token),
/*infeed_config=*/config));
}

StatusOr<XlaOp> MlirHloBuilder::OutfeedWithTokenInternal(
XlaOp operand, XlaOp token, const Shape& shape_with_layout,
const string& outfeed_config) {
auto token_type = mlir::xla_hlo::TokenType::get(builder_.getContext());
return MakeXlaOp(builder_.create<mlir::xla_hlo::OutfeedOp>(
auto token_type = mlir::mhlo::TokenType::get(builder_.getContext());
return MakeXlaOp(builder_.create<mlir::mhlo::OutfeedOp>(
loc_, token_type, GetValue(operand), GetValue(token), outfeed_config));
}

@ -372,7 +371,7 @@ StatusOr<XlaOp> MlirHloBuilder::ConcatInDimInternal(
mlir::Type result_type,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
auto mlir_operands = GetValues(operands);
return MakeXlaOp(builder_.create<mlir::xla_hlo::ConcatenateOp>(
return MakeXlaOp(builder_.create<mlir::mhlo::ConcatenateOp>(
loc_, result_type, mlir_operands, builder_.getI64IntegerAttr(dimension)));
}

@ -382,7 +381,7 @@ StatusOr<XlaOp> MlirHloBuilder::GetTupleElementInternal(const Shape& shape,
TF_ASSIGN_OR_RETURN(
mlir::Type result_type,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
return MakeXlaOp(builder_.create<mlir::xla_hlo::GetTupleElementOp>(
return MakeXlaOp(builder_.create<mlir::mhlo::GetTupleElementOp>(
loc_, result_type, GetValue(tuple_data),
builder_.getI32IntegerAttr(index)));
}
@ -390,7 +389,7 @@ StatusOr<XlaOp> MlirHloBuilder::GetTupleElementInternal(const Shape& shape,
StatusOr<XlaOp> MlirHloBuilder::SliceInternal(
const Shape& shape, XlaOp operand, absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
return MakeXlaOp(builder_.create<mlir::xla_hlo::SliceOp>(
return MakeXlaOp(builder_.create<mlir::mhlo::SliceOp>(
loc_, GetValue(operand), GetI64ElementsAttr(start_indices, &builder_),
GetI64ElementsAttr(limit_indices, &builder_),
GetI64ElementsAttr(strides, &builder_)));
@ -402,7 +401,7 @@ StatusOr<XlaOp> MlirHloBuilder::DynamicSliceInternal(
TF_ASSIGN_OR_RETURN(
mlir::Type result_ty,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
return MakeXlaOp(builder_.create<mlir::xla_hlo::DynamicSliceOp>(
return MakeXlaOp(builder_.create<mlir::mhlo::DynamicSliceOp>(
loc_, result_ty, GetValue(operand), GetValues(start_indices),
GetI64ElementsAttr(slice_sizes, &builder_)));
}
@ -413,7 +412,7 @@ StatusOr<XlaOp> MlirHloBuilder::DynamicUpdateSliceInternal(
TF_ASSIGN_OR_RETURN(
mlir::Type result_ty,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
return MakeXlaOp(builder_.create<mlir::xla_hlo::DynamicUpdateSliceOp>(
return MakeXlaOp(builder_.create<mlir::mhlo::DynamicUpdateSliceOp>(
loc_, result_ty, GetValue(operand), GetValue(update),
GetValues(start_indices)));
}
@ -432,7 +431,7 @@ StatusOr<XlaOp> MlirHloBuilder::PadInternal(
high.push_back(dimension.edge_padding_high());
internal.push_back(dimension.interior_padding());
}
return MakeXlaOp(builder_.create<mlir::xla_hlo::PadOp>(
return MakeXlaOp(builder_.create<mlir::mhlo::PadOp>(
loc_, result_type, GetValue(operand), GetValue(padding_value),
GetI64ElementsAttr(low, &builder_), GetI64ElementsAttr(high, &builder_),
GetI64ElementsAttr(internal, &builder_)));
@ -444,7 +443,7 @@ StatusOr<XlaOp> MlirHloBuilder::TupleInternal(
for (auto& element : elements) {
operands.push_back(GetValue(element));
}
return MakeXlaOp(builder_.create<mlir::xla_hlo::TupleOp>(loc_, operands));
return MakeXlaOp(builder_.create<mlir::mhlo::TupleOp>(loc_, operands));
}

StatusOr<XlaOp> MlirHloBuilder::CreateOp(
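Every builder method above follows the same three-step shape-to-op-to-XlaOp pattern; a condensed hedged sketch of that recurring structure (the method name and the choice of AbsOp are hypothetical):

StatusOr<XlaOp> MlirHloBuilder::SomeOpInternal(const Shape& shape, XlaOp x) {
  // 1. Convert the XLA shape to an MLIR type.
  TF_ASSIGN_OR_RETURN(mlir::Type ty,
                      ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
  // 2. Create the mhlo op at the builder's insertion point.
  auto op = builder_.create<mlir::mhlo::AbsOp>(loc_, ty, GetValue(x));
  // 3. Wrap the MLIR value back into an XlaOp handle.
  return MakeXlaOp(op);
}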

@ -34,7 +34,7 @@ limitations under the License.

namespace xla {

// Provides a way to construct xla_hlo dialect ops in MLIR using XlaBuilder
// Provides a way to construct mhlo dialect ops in MLIR using XlaBuilder
// interface.
//
// Requires that all XlaOp arguments are either returned by any of the builder

@ -69,12 +69,12 @@ using ::tensorflow::uint32;
using ::tensorflow::uint64;
using ::tensorflow::uint8;

constexpr char kPaddingMapAttr[] = "xla_hlo.padding_map";
constexpr char kPaddingMapAttr[] = "mhlo.padding_map";
constexpr char kShapeIndicesAttr[] = "shape_indices";
constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices";
constexpr char kShardingAttr[] = "xla_hlo.sharding";
constexpr char kFrontendAttributesAttr[] = "xla_hlo.frontend_attributes";
constexpr char kRepicationAttr[] = "xla_hlo.is_same_data_across_replicas";
constexpr char kShardingAttr[] = "mhlo.sharding";
constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes";
constexpr char kRepicationAttr[] = "mhlo.is_same_data_across_replicas";

// Passes through everything except for unique_ptr, on which it calls get().
// This exists to allow the generated code to call XLA functions that take a raw
@ -247,7 +247,7 @@ static std::unique_ptr<xla::PrecisionConfig> Convert_precision_config(
}

static xla::DotDimensionNumbers Convert_dot_dimension_numbers(
mlir::xla_hlo::DotDimensionNumbers dot_dimension_numbers_attr) {
mlir::mhlo::DotDimensionNumbers dot_dimension_numbers_attr) {
xla::DotDimensionNumbers dot_dimension_numbers;

auto rhs_contracting_dimensions =
@ -282,7 +282,7 @@ static xla::DotDimensionNumbers Convert_dot_dimension_numbers(
}

static xla::ConvolutionDimensionNumbers Convert_dimension_numbers(
mlir::xla_hlo::ConvDimensionNumbers input) {
mlir::mhlo::ConvDimensionNumbers input) {
xla::ConvolutionDimensionNumbers output;

output.set_input_batch_dimension(
@ -315,7 +315,7 @@ static xla::ConvolutionDimensionNumbers Convert_dimension_numbers(
return output;
}

xla::ChannelHandle Convert_channel_handle(mlir::xla_hlo::ChannelHandle attr) {
xla::ChannelHandle Convert_channel_handle(mlir::mhlo::ChannelHandle attr) {
xla::ChannelHandle channel_handle;
channel_handle.set_handle(ConvertAPInt(attr.handle().getValue()));
channel_handle.set_type(static_cast<xla::ChannelHandle::ChannelType>(
@ -333,7 +333,7 @@ static xla::ComparisonDirection Convert_comparison_direction(
}

static xla::GatherDimensionNumbers Convert_dimension_numbers(
mlir::xla_hlo::GatherDimensionNumbers input) {
mlir::mhlo::GatherDimensionNumbers input) {
xla::GatherDimensionNumbers output;

auto offset_dims = ConvertDenseIntAttr(input.offset_dims());
@ -357,7 +357,7 @@ static xla::GatherDimensionNumbers Convert_dimension_numbers(
}

static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
mlir::xla_hlo::ScatterDimensionNumbers input) {
mlir::mhlo::ScatterDimensionNumbers input) {
xla::ScatterDimensionNumbers output;

auto update_window_dims = ConvertDenseIntAttr(input.update_window_dims());
@ -574,7 +574,7 @@ llvm::SmallVector<xla::XlaOp, 4> GetTuple(mlir::Operation::operand_range values,
} // namespace

namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {

LogicalResult ExportXlaOp(AllReduceOp op, OpLoweringContext ctx) {
@ -829,7 +829,7 @@ LogicalResult ExportXlaOp(ReshapeOp op, OpLoweringContext ctx) {
}

LogicalResult ExportXlaOp(ReturnOp op, OpLoweringContext ctx) {
// Failure on purpose because `xla_hlo::ReturnOp` will be handled by
// Failure on purpose because `mhlo::ReturnOp` will be handled by
// special purpose logic in `ConvertToHloModule::Lower`.
return failure();
}
@ -943,7 +943,7 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
}

} // namespace
} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

#include "tensorflow/compiler/mlir/xla/operator_writers.inc"
@ -1060,7 +1060,7 @@ LogicalResult ConvertToHloModule::Lower(
return success();
}

if (isa<xla_hlo::ReturnOp, mlir::ReturnOp>(inst)) {
if (isa<mhlo::ReturnOp, mlir::ReturnOp>(inst)) {
// Construct the return value for the function. If there are multiple
// values returned, then create a tuple, else return value directly.
xla::XlaOp return_value;
@ -1405,7 +1405,7 @@ void AddDynamicParameterBindingEntry(xla::DynamicParameterBindingProto* binding,
}

// Validates and populates dynamic parameter bindings from a module's entry
// function `xla_hlo.padding_map` argument attributes to a `xla::HloModuleProto`
// function `mhlo.padding_map` argument attributes to a `xla::HloModuleProto`
// `DynamicParameterBindingProto`.
LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
xla::HloModuleProto* hlo_module_proto,

@ -73,8 +73,8 @@ static StringRef GetClientBuilder(const Operator& op) {
}

static void BuildOperator(const Operator& op, raw_ostream& os) {
os << "mlir::LogicalResult ExportXlaOp(mlir::xla_hlo::"
<< op.getCppClassName() << " op, OpLoweringContext ctx) {\n"
os << "mlir::LogicalResult ExportXlaOp(mlir::mhlo::" << op.getCppClassName()
<< " op, OpLoweringContext ctx) {\n"
<< " auto& value_map = *ctx.values;\n"
<< " auto result = op.getResult();\n";

@ -164,12 +164,12 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
Operator op(def);

// Cast to the current operation and build the exporter.
os << " if (auto xla_op = llvm::dyn_cast<mlir::xla_hlo::"
os << " if (auto xla_op = llvm::dyn_cast<mlir::mhlo::"
<< op.getCppClassName() << ">(op)) {\n";
os << " return ";
// The autogenerated converters aren't in the same namespace.
// TODO(jpienaar): Reconsider this.
if (def->getValueAsBit("hasCustomHLOConverter")) os << "mlir::xla_hlo::";
if (def->getValueAsBit("hasCustomHLOConverter")) os << "mlir::mhlo::";
os << "ExportXlaOp(xla_op, lowering_context);\n";
os << " }\n";
}
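For a concrete picture, the dispatch emitted by the loop above looks roughly like this for each op in the generated operator_writers.inc (reconstructed from the os << strings; AddOp is an illustrative instance):

// Excerpt of the generated dispatch (reconstructed):
if (auto xla_op = llvm::dyn_cast<mlir::mhlo::AddOp>(op)) {
  return ExportXlaOp(xla_op, lowering_context);
}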
|
||||
|
@ -7,7 +7,7 @@ func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.abs
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%abs = "xla_hlo.abs"(%value) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%abs = "mhlo.abs"(%value) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %abs : tensor<2x2xf32>
}

@ -22,7 +22,7 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
// CHECK: lhlo.add
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.add"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.add"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}

@ -37,7 +37,7 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32
// CHECK: lhlo.and
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.and"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
%res = "mhlo.and"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %res : tensor<2x2xi32>
}

@ -50,7 +50,7 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.ceil
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.ceil"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.ceil"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}

@ -65,7 +65,7 @@ func @main(%value0: tensor<1x2xf32>, %value1: tensor<1x2xf32>) -> tensor<1x2xcom
// CHECK: lhlo.complex
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.complex"(%value0, %value1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
%res = "mhlo.complex"(%value0, %value1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
return %res : tensor<1x2xcomplex<f32>>
}

@ -79,7 +79,7 @@ func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xcomplex<f32>> {
// CHECK: lhlo.cosine
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.cosine"(%value0) : (tensor<1x2xcomplex<f32>>) -> tensor<1x2xcomplex<f32>>
%res = "mhlo.cosine"(%value0) : (tensor<1x2xcomplex<f32>>) -> tensor<1x2xcomplex<f32>>
return %res : tensor<1x2xcomplex<f32>>
}

@ -94,7 +94,7 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
// CHECK: lhlo.divide
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.divide"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.divide"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}

@ -107,7 +107,7 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.exponential
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.exponential"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.exponential"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}

@ -120,7 +120,7 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.log
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.log"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.log"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}

@ -135,7 +135,7 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
// CHECK: lhlo.maximum
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.maximum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.maximum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}

@ -150,7 +150,7 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
// CHECK: lhlo.minimum
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.minimum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.minimum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}

@ -165,7 +165,7 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
// CHECK: lhlo.multiply
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.multiply"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.multiply"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}

@ -178,7 +178,7 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.negate
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.negate"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.negate"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}

@ -191,7 +191,7 @@ func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<1x2xf32>
// CHECK: lhlo.real
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.real"(%value0) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
%res = "mhlo.real"(%value0) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
return %res : tensor<1x2xf32>
}

@ -204,7 +204,7 @@ func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<1x2xf32>
// CHECK: lhlo.imag
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.imag"(%value0) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
%res = "mhlo.imag"(%value0) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
return %res : tensor<1x2xf32>
}

@ -219,7 +219,7 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32
// CHECK: lhlo.remainder
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.remainder"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
%res = "mhlo.remainder"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %res : tensor<2x2xi32>
}

@ -232,7 +232,7 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.rsqrt
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.rsqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.rsqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}

@ -248,7 +248,7 @@ func @main(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>)
// CHECK: lhlo.select
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[VIEW]]
// CHECK-NEXT: return
%0 = "xla_hlo.select"(%pred, %lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
%0 = "mhlo.select"(%pred, %lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
return %0 : tensor<2x2xf32>
}

@ -261,7 +261,7 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.sign
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.sign"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.sign"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}

@ -274,7 +274,7 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.sqrt
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.sqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.sqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}

@ -289,7 +289,7 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32
// CHECK: lhlo.subtract
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "xla_hlo.subtract"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
%res = "mhlo.subtract"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %res : tensor<2x2xi32>
}

@ -302,7 +302,7 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.tanh
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "xla_hlo.tanh"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
%res = "mhlo.tanh"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}

@ -317,10 +317,10 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW1:.*]] = std.view %[[ARG3]]{{.*}} : memref<100xi8> to memref<5x5xf32>
// CHECK: "xla_lhlo.sort"(%[[ARG0]], %[[ARG1]], %[[VIEW0]], %[[VIEW1]])
func @main(%key: tensor<5x5xi32>, %value: tensor<5x5xf32>) -> tuple<tensor<5x5xi32>, tensor<5x5xf32>> {
%res = "xla_hlo.sort"(%key, %value) ({
%res = "mhlo.sort"(%key, %value) ({
^bb0(%a: tensor<i32>, %b: tensor<i32>, %c: tensor<f32>, %d: tensor<f32>):
%ret = "xla_hlo.compare"(%c, %d) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%ret) : (tensor<i1>) -> ()
%ret = "mhlo.compare"(%c, %d) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%ret) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true}: (tensor<5x5xi32>, tensor<5x5xf32>) -> tuple<tensor<5x5xi32>, tensor<5x5xf32>>

return %res : tuple<tensor<5x5xi32>, tensor<5x5xf32>>

@ -15,11 +15,11 @@ func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) ->
// CHECK: [[BCASTHEAD:%.*]] = "shape.broadcast"([[LHSHEAD]], [[RHSHEAD]]) : (!shape.shape, !shape.shape) -> !shape.shape
// CHECK: [[LHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[LHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape
// CHECK: [[LHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[LHSBCASTSHAPE]] : tensor<3xindex>
// CHECK: [[LHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32>
// CHECK: [[LHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32>
// CHECK: [[RHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[RHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape
// CHECK: [[RHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[RHSBCASTSHAPE]] : tensor<3xindex>
// CHECK: [[RHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[RHS]], [[RHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, tensor<3xindex>) -> tensor<3x2x4xf32>
// CHECK: [[RESULT:%.*]] = "xla_hlo.dot_general"([[LHSBCAST]], [[RHSBCAST]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
// CHECK: [[RHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[RHS]], [[RHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, tensor<3xindex>) -> tensor<3x2x4xf32>
// CHECK: [[RESULT:%.*]] = "mhlo.dot_general"([[LHSBCAST]], [[RHSBCAST]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
// CHECK: return [[RESULT]] : tensor<3x4x4xf32>
// CHECK: }

@ -29,9 +29,9 @@ func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) ->

func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x4xf32> {
// CHECK-LABEL: func @batchmatmulv2_lhs_batch
// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>,
@ -42,9 +42,9 @@ func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>)

func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> {
// CHECK-LABEL: func @batchmatmulv2_rhs_batch
// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>,
@ -55,7 +55,7 @@ func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>)

func @batchmatmulv2_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
// CHECK-LABEL: func @batchmatmulv2_dynamic
// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>,
@ -66,7 +66,7 @@ func @batchmatmulv2_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>)

func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<5x4xf32> {
// CHECK-LABEL: func @batchmatmulv2_adj_real
// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
// CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
@ -78,14 +78,14 @@ func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) ->
func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex<f32>>, %arg1: tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
// CHECK-LABEL: func @batchmatmulv2_adj_complex(
// CHECK-SAME: [[LHS:%.*]]: tensor<5x2xcomplex<f32>>, [[RHS:%.*]]: tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
// CHECK: [[LHSRE:%.*]] = "xla_hlo.real"([[LHS]])
// CHECK: [[LHSIM:%.*]] = "xla_hlo.imag"([[LHS]])
// CHECK: [[LHSIMNEG:%.*]] = "xla_hlo.negate"([[LHSIM]])
// CHECK: [[LHSCONJ:%.*]] = "xla_hlo.complex"([[LHSRE]], [[LHSIMNEG]])
// CHECK: [[RHSRE:%.*]] = "xla_hlo.real"([[RHS]])
// CHECK: [[RHSIM:%.*]] = "xla_hlo.imag"([[RHS]])
// CHECK: [[RHSIMNEG:%.*]] = "xla_hlo.negate"([[RHSIM]])
// CHECK: [[RHSCONJ:%.*]] = "xla_hlo.complex"([[RHSRE]], [[RHSIMNEG]])
// CHECK: [[LHSRE:%.*]] = "mhlo.real"([[LHS]])
// CHECK: [[LHSIM:%.*]] = "mhlo.imag"([[LHS]])
// CHECK: [[LHSIMNEG:%.*]] = "mhlo.negate"([[LHSIM]])
// CHECK: [[LHSCONJ:%.*]] = "mhlo.complex"([[LHSRE]], [[LHSIMNEG]])
// CHECK: [[RHSRE:%.*]] = "mhlo.real"([[RHS]])
// CHECK: [[RHSIM:%.*]] = "mhlo.imag"([[RHS]])
// CHECK: [[RHSIMNEG:%.*]] = "mhlo.negate"([[RHSIM]])
// CHECK: [[RHSCONJ:%.*]] = "mhlo.complex"([[RHSRE]], [[RHSIMNEG]])
// CHECK: shape.shape_of [[LHSCONJ]]
// CHECK: shape.shape_of [[RHSCONJ]]
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex<f32>>, tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
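The adjoint case above conjugates each complex operand before the dot_general, exactly as the CHECK lines spell out (real, imag, negate, complex). A minimal C++ sketch of that rewrite step, assuming a builder-based helper; the function name and the exact create() overloads are illustrative assumptions, not the verbatim pattern code from legalize_tf:

// Conjugate a complex tensor value: conj(v) = complex(real(v), -imag(v)).
static mlir::Value ConjugateValue(mlir::OpBuilder& builder,
                                  mlir::Location loc, mlir::Value v) {
  mlir::Value re = builder.create<mlir::mhlo::RealOp>(loc, v);
  mlir::Value im = builder.create<mlir::mhlo::ImagOp>(loc, v);
  mlir::Value neg_im = builder.create<mlir::mhlo::NegOp>(loc, im);
  return builder.create<mlir::mhlo::ComplexOp>(loc, re, neg_im);
}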
@ -11,8 +11,8 @@

// CHECK-LABEL: func @add
func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: %[[SUM0:.*]] = xla_hlo.add %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: %[[SUM1:.*]] = xla_hlo.add %[[SUM0]], %arg0 : tensor<2xi32>
// CHECK-NEXT: %[[SUM0:.*]] = mhlo.add %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: %[[SUM1:.*]] = mhlo.add %[[SUM0]], %arg0 : tensor<2xi32>
// CHECK-NEXT: return %[[SUM1]] : tensor<2xi32>
%0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
%1 = "tf.AddV2"(%0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
@ -24,8 +24,8 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// patterns unambiguous and more interesting (once broadcastable trait is
// fixed upstream).
func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-NEXT: xla_hlo.add %[[LHS_BCAST]], %arg1
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-NEXT: mhlo.add %[[LHS_BCAST]], %arg1
%0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
return %0: tensor<1x2xi32>
}
@ -34,8 +34,8 @@ func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
// TODO(laurenzo): Change this to a (4x1x1 + 1x4x4x4) shaped add once upstream
// broadcastable bug is fixed (helps make the CHECK matching unambiguous)
func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> {
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}
// CHECK-NEXT: xla_hlo.add %[[LHS_BCAST]], %arg1
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}
// CHECK-NEXT: mhlo.add %[[LHS_BCAST]], %arg1
%0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
return %0: tensor<4x4x4x4xi32>
}
@ -50,9 +50,9 @@ func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1
// CHECK-NEXT: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-NEXT: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] : tensor<?x?xi32>
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-NEXT: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] : tensor<?x?xi32>
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
%0 = "tf.Add"(%arg0, %arg1) : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
return %0: tensor<?x?xi32>
@ -60,7 +60,7 @@ func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3

// CHECK-LABEL: func @div
func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: return %0 : tensor<2xi32>
%0 = "tf.Div"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
@ -68,7 +68,7 @@ func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {

// CHECK-LABEL: func @shift_left
func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK: xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
// CHECK: mhlo.shift_left %arg0, %arg1 : tensor<4xi32>
%0 = "tf.LeftShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
@ -82,21 +82,21 @@ func @div_unranked(%arg0: tensor<*xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi

// CHECK-LABEL: func @maximum
func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
// CHECK-NEXT: mhlo.maximum %arg0, %arg1 : tensor<4xf32>
%0 = "tf.Maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

// CHECK-LABEL: func @minimum
func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: xla_hlo.minimum %arg0, %arg1 : tensor<4xf32>
// CHECK-NEXT: mhlo.minimum %arg0, %arg1 : tensor<4xf32>
%0 = "tf.Minimum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

// CHECK-LABEL: func @mul
func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: %0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: %0 = mhlo.multiply %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: return %0 : tensor<2xi32>
%0 = "tf.Mul"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
@ -104,14 +104,14 @@ func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> {

// CHECK-LABEL: func @real_div
func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32>
%0 = "tf.RealDiv"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
}

// CHECK-LABEL: func @sub
func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: %0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: %0 = mhlo.subtract %arg0, %arg0 : tensor<2xi32>
// CHECK-NEXT: return %0 : tensor<2xi32>
%0 = "tf.Sub"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
@ -119,7 +119,7 @@ func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> {

// CHECK-LABEL: func @shift_right
func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
// CHECK: mhlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
%0 = "tf.RightShift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
@ -140,7 +140,7 @@ func @broadcast_shift_right_unsigned(%arg0: tensor<4xui8>, %arg1: tensor<2x4xui8

// CHECK-LABEL: func @and
func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> {
// CHECK-NEXT: xla_hlo.and
// CHECK-NEXT: mhlo.and
%0 = "tf.LogicalAnd"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
@ -154,28 +154,28 @@ func @and_unranked(%arg0: tensor<*xi1>, %arg1: tensor<*xi1>) -> tensor<*xi1> {

// CHECK-LABEL: func @or
func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> {
// CHECK-NEXT: xla_hlo.or
// CHECK-NEXT: mhlo.or
%0 = "tf.LogicalOr"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1>
return %0: tensor<2xi1>
}

// CHECK-LABEL: func @bitwise_or
func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-NEXT: xla_hlo.or
// CHECK-NEXT: mhlo.or
%0 = "tf.BitwiseOr"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0: tensor<4xi32>
}

// CHECK-LABEL: func @bitwise_and
func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-NEXT: xla_hlo.and
// CHECK-NEXT: mhlo.and
%0 = "tf.BitwiseAnd"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0: tensor<4xi32>
}

// CHECK-LABEL: func @pow
func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-NEXT: xla_hlo.power
// CHECK-NEXT: mhlo.power
%0 = "tf.Pow"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %0: tensor<2xf32>
}
@ -188,7 +188,7 @@ func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> {

// CHECK-LABEL: func @equal
func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"}
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"}
%0 = "tf.Equal"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
@ -202,9 +202,9 @@ func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1>
// CHECK-DAG: %[[LHS_SHAPE1:.+]] = shape.shape_of %arg0
// CHECK-NEXT: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE1]], %[[RHS_SHAPE]])
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"}
// CHECK-DAG: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = "mhlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"}
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
return %0: tensor<?xi1>
@ -212,8 +212,8 @@ func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1>

// CHECK-LABEL: func @equal_broadcast
func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-NEXT: "xla_hlo.compare"(%[[LHS_BCAST]], %arg1) {comparison_direction = "EQ"}
// CHECK-DAG: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-NEXT: "mhlo.compare"(%[[LHS_BCAST]], %arg1) {comparison_direction = "EQ"}
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0: tensor<1x2xi1>
}
@ -255,7 +255,7 @@ func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1>

// CHECK-LABEL: func @notequal
func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"}
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"}
%0 = "tf.NotEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
@ -268,15 +268,15 @@ func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {

// CHECK-LABEL: func @greater
func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"}
// CHECK: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"}
%0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}

// CHECK-LABEL: func @broadcast_greater
func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-NEXT: "xla_hlo.compare"(%[[LHS_BCAST]], %arg1) {comparison_direction = "GT"}
// CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-NEXT: "mhlo.compare"(%[[LHS_BCAST]], %arg1) {comparison_direction = "GT"}
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0: tensor<1x2xi1>
}
@ -291,9 +291,9 @@ func @greater_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi1
// CHECK-DAG: %[[RHS_SHAPE1:.+]] = shape.shape_of %arg1
// CHECK-NEXT: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE1]], %[[RHS_SHAPE1]])
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-NEXT: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"}
// CHECK-DAG: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-NEXT: "mhlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"}
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi1>
return %0: tensor<?xi1>
}
@ -307,21 +307,21 @@ func @greater_uranked(%arg0: tensor<*xi32>) -> tensor<*xi1> {

// CHECK-LABEL: func @greater_equal
func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"}
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"}
%0 = "tf.GreaterEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}

// CHECK-LABEL: func @less
func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"}
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"}
%0 = "tf.Less"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}

// CHECK-LABEL: func @less_equal
func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"}
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"}
%0 = "tf.LessEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
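Most of the binary-op checks above are 1:1 renames: a tf op maps onto the identically-shaped mhlo op. For reference, a hand-written C++ pattern doing one such mapping would look roughly like this; the real lowering in legalize_tf is table-generated (DRR), so the struct name, headers, and accessors here are illustrative assumptions only:

#include "mlir/IR/PatternMatch.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

// Hypothetical direct lowering for same-shaped tf.AddV2 -> mhlo.add.
struct LowerAddV2 : public mlir::OpRewritePattern<mlir::TF::AddV2Op> {
  using OpRewritePattern::OpRewritePattern;
  mlir::LogicalResult matchAndRewrite(
      mlir::TF::AddV2Op op, mlir::PatternRewriter& rewriter) const override {
    // Same-shaped operands map directly onto the renamed dialect op.
    rewriter.replaceOpWithNewOp<mlir::mhlo::AddOp>(op, op.x(), op.y());
    return mlir::success();
  }
};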
@ -3,40 +3,40 @@
// CHECK-LABEL: @if
func @if(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>)
attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} {
// CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
%0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: [[VAL1:%.+]] = "xla_hlo.tuple"(%arg0, %arg1)
// CHECK: [[VAL2:%.+]] = "xla_hlo.if"([[VAL0]], [[VAL1]], [[VAL1]]) ( {
// CHECK: [[VAL0:%.+]] = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
%0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: [[VAL1:%.+]] = "mhlo.tuple"(%arg0, %arg1)
// CHECK: [[VAL2:%.+]] = "mhlo.if"([[VAL0]], [[VAL1]], [[VAL1]]) ( {
// CHECK: ^bb0(%arg2: tuple<tensor<f32>, tensor<f32>>):
// CHECK: [[VAL4:%.+]] = "xla_hlo.get_tuple_element"(%arg2) {index = 0 : i32}
// CHECK: [[VAL5:%.+]] = "xla_hlo.get_tuple_element"(%arg2) {index = 1 : i32}
// CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32}
// CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"(%arg2) {index = 1 : i32}
// CHECK: [[VAL6:%.+]] = call @cond_true([[VAL4]], [[VAL5]])
// CHECK: [[VAL7:%.+]] = "xla_hlo.tuple"([[VAL6]])
// CHECK: "xla_hlo.return"([[VAL7]]) : (tuple<tensor<f32>>) -> ()
// CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL6]])
// CHECK: "mhlo.return"([[VAL7]]) : (tuple<tensor<f32>>) -> ()
// CHECK: }, {
// CHECK: ^bb0(%arg2: tuple<tensor<f32>, tensor<f32>>)
// CHECK: [[VAL4:%.+]] = "xla_hlo.get_tuple_element"(%arg2) {index = 0 : i32}
// CHECK: [[VAL5:%.+]] = "xla_hlo.get_tuple_element"(%arg2) {index = 1 : i32}
// CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32}
// CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"(%arg2) {index = 1 : i32}
// CHECK: [[VAL6:%.+]] = call @cond_false([[VAL4]], [[VAL5]])
// CHECK: [[VAL7:%.+]] = "xla_hlo.tuple"([[VAL6]])
// CHECK: "xla_hlo.return"([[VAL7]]) : (tuple<tensor<f32>>) -> ()
// CHECK: [[VAL7:%.+]] = "mhlo.tuple"([[VAL6]])
// CHECK: "mhlo.return"([[VAL7]]) : (tuple<tensor<f32>>) -> ()
// CHECK: })
%1 = "tf.If"(%0, %arg0, %arg1) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _lower_using_switch_merge = true, _output_shapes = ["tfshape$"], device = "", else_branch = @cond_false, is_stateless = true, name = "cond", output_shapes = [#tf.shape<>], then_branch = @cond_true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>

// CHECK: [[VAL3:%.+]] = "xla_hlo.get_tuple_element"([[VAL2]]) {index = 0 : i32}
// CHECK: [[VAL3:%.+]] = "mhlo.get_tuple_element"([[VAL2]]) {index = 0 : i32}
// CHECK: return [[VAL3]]
return %1 : tensor<f32>
}

func @cond_false(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32>
attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} {
%0 = "xla_hlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
%0 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}

func @cond_true(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32>
attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} {
%0 = "xla_hlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
%0 = "mhlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}

@ -45,42 +45,42 @@ attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} {
// CHECK-SAME: %[[BRANCH_INDEX:.*]]: tensor<i32>, %[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>) -> (tensor<f32>, tensor<f32>)
func @case(%index: tensor<i32>, %arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
%0:2 = "tf.Case"(%index, %arg0, %arg1) {branches = [@exponential, @log, @floor]} : (tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// CHECK: %[[TUPLE_INPUT:.*]] = "xla_hlo.tuple"(%[[ARG0]], %[[ARG1]]) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
// CHECK: %[[CASE:.*]]:2 = "xla_hlo.case"(%[[BRANCH_INDEX]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]]) ( {
// CHECK: %[[TUPLE_INPUT:.*]] = "mhlo.tuple"(%[[ARG0]], %[[ARG1]]) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
// CHECK: %[[CASE:.*]]:2 = "mhlo.case"(%[[BRANCH_INDEX]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]]) ( {
// CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple<tensor<f32>, tensor<f32>>):
// CHECK: %[[TUPLE_ELEMENT_0:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[TUPLE_ELEMENT_1:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[TUPLE_ELEMENT_0:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[TUPLE_ELEMENT_1:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[CALL_EXP:.*]]:2 = call @exponential(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// CHECK: "xla_hlo.return"(%[[CALL_EXP]]#0, %[[CALL_EXP]]#1) : (tensor<f32>, tensor<f32>) -> ()
// CHECK: "mhlo.return"(%[[CALL_EXP]]#0, %[[CALL_EXP]]#1) : (tensor<f32>, tensor<f32>) -> ()
// CHECK: }, {
// CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple<tensor<f32>, tensor<f32>>):
// CHECK: %[[TUPLE_ELEMENT_0:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[TUPLE_ELEMENT_1:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[TUPLE_ELEMENT_0:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[TUPLE_ELEMENT_1:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[CALL_LOG:.*]]:2 = call @log(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// CHECK: "xla_hlo.return"(%[[CALL_LOG]]#0, %[[CALL_LOG]]#1) : (tensor<f32>, tensor<f32>) -> ()
// CHECK: "mhlo.return"(%[[CALL_LOG]]#0, %[[CALL_LOG]]#1) : (tensor<f32>, tensor<f32>) -> ()
// CHECK: }, {
// CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple<tensor<f32>, tensor<f32>>):
// CHECK: %[[TUPLE_ELEMENT_0:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[TUPLE_ELEMENT_1:.*]] = "xla_hlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[TUPLE_ELEMENT_0:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[TUPLE_ELEMENT_1:.*]] = "mhlo.get_tuple_element"(%[[TUPLE_ARG]]) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>>) -> tensor<f32>
// CHECK: %[[CALL_FLOOR:.*]]:2 = call @floor(%[[TUPLE_ELEMENT_0]], %[[TUPLE_ELEMENT_1]]) : (tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
// CHECK: "xla_hlo.return"(%[[CALL_FLOOR]]#0, %[[CALL_FLOOR]]#1) : (tensor<f32>, tensor<f32>) -> ()
// CHECK: "mhlo.return"(%[[CALL_FLOOR]]#0, %[[CALL_FLOOR]]#1) : (tensor<f32>, tensor<f32>) -> ()
// CHECK: }) : (tensor<i32>, tuple<tensor<f32>, tensor<f32>>, tuple<tensor<f32>, tensor<f32>>, tuple<tensor<f32>, tensor<f32>>) -> (tensor<f32>, tensor<f32>)
return %0#0, %0#1 : tensor<f32>, tensor<f32>
// CHECK: return %[[CASE]]#0, %[[CASE]]#1 : tensor<f32>, tensor<f32>
}

func @exponential(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
%0 = "xla_hlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
%0 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
return %0, %arg1 : tensor<f32>, tensor<f32>
}

func @log(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
%0 = "xla_hlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
%0 = "mhlo.log"(%arg0) : (tensor<f32>) -> tensor<f32>
return %0, %arg1 : tensor<f32>, tensor<f32>
}

func @floor(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
%0 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
%0 = "mhlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
return %0, %arg1 : tensor<f32>, tensor<f32>
}

@ -88,44 +88,44 @@ func @floor(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>
// CHECK-LABEL: func @while
func @while(%arg0: tensor<f32> {tf_saved_model.index_path = [0]}) -> (tensor<i32> {tf_saved_model.index_path = []})
attributes {tf._input_shapes = ["tfshape$"]} {
// CHECK: [[VAL0:%.+]] = xla_hlo.constant dense<0>
// CHECK: [[VAL1:%.+]] = xla_hlo.constant dense<-1>
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = xla_hlo.constant dense<-1> : tensor<i32>
// CHECK: [[VAL2:%.+]] = "xla_hlo.tuple"([[VAL0]], [[VAL1]], [[VAL0]])
// CHECK: [[VAL3:%.+]] = "xla_hlo.while"([[VAL2]]) ( {
// CHECK: [[VAL0:%.+]] = mhlo.constant dense<0>
// CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1>
%0 = mhlo.constant dense<0> : tensor<i32>
%1 = mhlo.constant dense<-1> : tensor<i32>
// CHECK: [[VAL2:%.+]] = "mhlo.tuple"([[VAL0]], [[VAL1]], [[VAL0]])
// CHECK: [[VAL3:%.+]] = "mhlo.while"([[VAL2]]) ( {
// CHECK: ^bb0(%arg1: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
// CHECK: [[VAL7:%.+]] = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32}
// CHECK: [[VAL8:%.+]] = "xla_hlo.get_tuple_element"(%arg1) {index = 1 : i32}
// CHECK: [[VAL9:%.+]] = "xla_hlo.get_tuple_element"(%arg1) {index = 2 : i32}
// CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32}
// CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 1 : i32}
// CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 2 : i32}
// CHECK: [[VAL10:%.+]] = call @while_cond([[VAL7]], [[VAL8]], [[VAL9]])
// CHECK: "xla_hlo.return"([[VAL10]])
// CHECK: "mhlo.return"([[VAL10]])
// CHECK: }, {
// CHECK: ^bb0(%arg1: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):
// CHECK: [[VAL7:%.+]] = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32}
// CHECK: [[VAL8:%.+]] = "xla_hlo.get_tuple_element"(%arg1) {index = 1 : i32}
// CHECK: [[VAL9:%.+]] = "xla_hlo.get_tuple_element"(%arg1) {index = 2 : i32}
// CHECK: [[VAL7:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32}
// CHECK: [[VAL8:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 1 : i32}
// CHECK: [[VAL9:%.+]] = "mhlo.get_tuple_element"(%arg1) {index = 2 : i32}
// CHECK: [[VAL10:%.+]]:3 = call @while_body([[VAL7]], [[VAL8]], [[VAL9]])
// CHECK: [[VAL11:%.+]] = "xla_hlo.tuple"([[VAL10]]#0, [[VAL10]]#1, [[VAL10]]#2)
// CHECK: "xla_hlo.return"([[VAL11]])
// CHECK: [[VAL11:%.+]] = "mhlo.tuple"([[VAL10]]#0, [[VAL10]]#1, [[VAL10]]#2)
// CHECK: "mhlo.return"([[VAL11]])
// CHECK: }) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
// CHECK: [[VAL4:%.+]] = "xla_hlo.get_tuple_element"([[VAL3]]) {index = 0 : i32}
// CHECK: [[VAL5:%.+]] = "xla_hlo.get_tuple_element"([[VAL3]]) {index = 1 : i32}
// CHECK: [[VAL6:%.+]] = "xla_hlo.get_tuple_element"([[VAL3]]) {index = 2 : i32}
// CHECK: [[VAL4:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 0 : i32}
// CHECK: [[VAL5:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 1 : i32}
// CHECK: [[VAL6:%.+]] = "mhlo.get_tuple_element"([[VAL3]]) {index = 2 : i32}
// CHECK: return [[VAL6]]
%2:3 = "tf.While"(%0, %1, %0) {T = ["tfdtype$DT_INT32", "tfdtype$DT_INT32", "tfdtype$DT_INT32"], _lower_using_switch_merge = true, _num_original_outputs = 3 : i64, _output_shapes = ["tfshape$", "tfshape$", "tfshape$"], body = @while_body, cond = @while_cond, device = "", is_stateless = true, name = "while", output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], parallel_iterations = 10 : i64} : (tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
return %2#2 : tensor<i32>
}
func @while_cond(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<i1>
attributes {tf._input_shapes = ["tfshape$", "tfshape$", "tfshape$"]} {
%0 = xla_hlo.constant dense<10> : tensor<i32>
%1 = "xla_hlo.compare"(%arg2, %0) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
%0 = mhlo.constant dense<10> : tensor<i32>
%1 = "mhlo.compare"(%arg2, %0) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
return %1 : tensor<i1>
}
func @while_body(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>)
attributes {tf._input_shapes = ["tfshape$", "tfshape$", "tfshape$"]} {
%0 = xla_hlo.constant dense<1> : tensor<i32>
%1 = xla_hlo.add %arg2, %0 : tensor<i32>
%2 = xla_hlo.add %arg0, %0 : tensor<i32>
%0 = mhlo.constant dense<1> : tensor<i32>
%1 = mhlo.add %arg2, %0 : tensor<i32>
%2 = mhlo.add %arg0, %0 : tensor<i32>
return %2, %arg1, %1 : tensor<i32>, tensor<i32>, tensor<i32>
}
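The control-flow checks above all follow the same tupling discipline: HLO's if, case, and while carry a single tuple-typed value per region, so the lowering packs the live values into one mhlo.tuple and unpacks them with mhlo.get_tuple_element inside each region. A hedged C++ sketch of that packing step; the builder overloads shown are approximations for illustration, not the exact pass code:

// Pack the region operands into one tuple operand, mirroring the
// "mhlo.tuple"(...) lines in the CHECK output above. `operands`,
// `block_arg`, `builder`, and `loc` are assumed to be in scope.
mlir::Value tuple_input =
    builder.create<mlir::mhlo::TupleOp>(loc, operands);
// ... create the mhlo.while / mhlo.if / mhlo.case with tuple_input ...
// Inside each region, unpack the tuple again:
mlir::Value element0 = builder.create<mlir::mhlo::GetTupleElementOp>(
    loc, block_arg, /*index=*/0);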
@ -7,7 +7,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// CHECK-LABEL: abs
// expected-error@+1 {{unsupported device}}
func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: %[[RESULT:.*]] = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
%0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>

// return %[[RESULT]]
@ -54,7 +54,7 @@ func @dynamic_operand(%arg0: tensor<?xf32>) -> tensor<?xf32> {
func @tuple_type(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
// Verifies that the pass can handle operands of non-tensor type like tuple
// from non TensorFlow ops.
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
%0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
return %0 : tensor<f32>
}

@ -69,9 +69,9 @@ func @unsupported_dtype(%arg0: tensor<2x!tf.variant>) -> tensor<2x!tf.variant> {

// CHECK-LABEL: multiple_dialect_ops
func @multiple_dialect_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: xla_hlo.negate
%0 = "xla_hlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: xla_hlo.abs
// CHECK: mhlo.negate
%0 = "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: mhlo.abs
%1 = "tf.Abs"(%0) : (tensor<2xf32>) -> tensor<2xf32>

return %1 : tensor<2xf32>
@ -79,21 +79,21 @@ func @multiple_dialect_ops(%arg0: tensor<2xf32>) -> tensor<2xf32> {

// CHECK-LABEL: binary_op
func @binary_op(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: xla_hlo.atan2 %arg0, %arg1 : tensor<2xf32>
// CHECK: mhlo.atan2 %arg0, %arg1 : tensor<2xf32>
%0 = "tf.Atan2"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

// CHECK-LABEL: binary_op_broadcast
func @binary_op_broadcast(%arg0: tensor<4x1xf32>, %arg1: tensor<4x1x4xf32>) -> tensor<4x4x4xf32> {
// CHECK: %[[BROADCAST0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x1xf32>) -> tensor<4x4x1xf32>
// CHECK: %[[RESHAPE0:.*]] = "xla_hlo.reshape"(%[[BROADCAST0]]) : (tensor<4x4x1xf32>) -> tensor<4x4xf32>
// CHECK: %[[UPDATED_ARG0:.*]] = "xla_hlo.broadcast_in_dim"(%[[RESHAPE0]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4x4xf32>
// CHECK: %[[BROADCAST0:.*]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x1xf32>) -> tensor<4x4x1xf32>
// CHECK: %[[RESHAPE0:.*]] = "mhlo.reshape"(%[[BROADCAST0]]) : (tensor<4x4x1xf32>) -> tensor<4x4xf32>
// CHECK: %[[UPDATED_ARG0:.*]] = "mhlo.broadcast_in_dim"(%[[RESHAPE0]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4x4xf32>

// CHECK: %[[RESHAPE1:.*]] = "xla_hlo.reshape"(%arg1) : (tensor<4x1x4xf32>) -> tensor<4x4xf32>
// CHECK: %[[UPDATED_ARG1:.*]] = "xla_hlo.broadcast_in_dim"(%[[RESHAPE1]]) {broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4x4xf32>
// CHECK: %[[RESHAPE1:.*]] = "mhlo.reshape"(%arg1) : (tensor<4x1x4xf32>) -> tensor<4x4xf32>
// CHECK: %[[UPDATED_ARG1:.*]] = "mhlo.broadcast_in_dim"(%[[RESHAPE1]]) {broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4x4xf32>

// CHECK: %[[RESULT:.*]] = xla_hlo.atan2 %[[UPDATED_ARG0]], %[[UPDATED_ARG1]] : tensor<4x4x4xf32>
// CHECK: %[[RESULT:.*]] = mhlo.atan2 %[[UPDATED_ARG0]], %[[UPDATED_ARG1]] : tensor<4x4x4xf32>
// CHECK: return %[[RESULT]] : tensor<4x4x4xf32>

%0 = "tf.Atan2"(%arg0, %arg1) : (tensor<4x1xf32>, tensor<4x1x4xf32>) -> tensor<4x4x4xf32>
@ -102,23 +102,23 @@ func @binary_op_broadcast(%arg0: tensor<4x1xf32>, %arg1: tensor<4x1x4xf32>) -> t

// CHECK-LABEL: func @ternary_op
func @ternary_op(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
// CHECK: "xla_hlo.select"(%arg0, %arg1, %arg2)
// CHECK: "mhlo.select"(%arg0, %arg1, %arg2)
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
}

// CHECK-LABEL: func @convert
func @convert(%arg0: tensor<2xi32>) -> tensor<2xf32> {
// CHECK: "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
// CHECK: "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
%0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<2xi32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

// CHECK-LABEL: func @constant
func @constant(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: %[[SCALAR_ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[ONE:.*]] = "xla_hlo.broadcast_in_dim"(%[[SCALAR_ONE]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<2xf32>
// CHECK: %[[RESULT:.*]] = xla_hlo.divide %[[ONE]], %arg0 : tensor<2xf32>
// CHECK: %[[SCALAR_ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[ONE:.*]] = "mhlo.broadcast_in_dim"(%[[SCALAR_ONE]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<2xf32>
// CHECK: %[[RESULT:.*]] = mhlo.divide %[[ONE]], %arg0 : tensor<2xf32>
// CHECK: return %[[RESULT]]

%0 = "tf.Inv"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
@ -127,7 +127,7 @@ func @constant(%arg0: tensor<2xf32>) -> tensor<2xf32> {

// CHECK-LABEL: func @greater
func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"}
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"}
%0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}
@ -136,14 +136,14 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x2xf64>, %[[ARG1:.*]]: tensor<f64>,
func @const_inputs(%arg0: tensor<2x2xf64>, %arg1: tensor<f64>, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>, %arg4: tensor<2xi32>) -> tensor<6x5xf64> {

// CHECK: "xla_hlo.pad"(%[[ARG0]], %[[ARG1]])
// CHECK: "mhlo.pad"(%[[ARG0]], %[[ARG1]])
// CHECK-SAME-DAG: edge_padding_high = dense<[1, 2]> : tensor<2xi64>
// CHECK-SAME-DAG: edge_padding_low = dense<[2, 1]> : tensor<2xi64>
// CHECK-SAME-DAG: interior_padding = dense<[1, 0]> : tensor<2xi64>

%0 = xla_hlo.constant dense<[2, 1]> : tensor<2xi32>
%1 = xla_hlo.constant dense<[1, 2]> : tensor<2xi32>
%2 = xla_hlo.constant dense<[1, 0]> : tensor<2xi32>
%0 = mhlo.constant dense<[2, 1]> : tensor<2xi32>
%1 = mhlo.constant dense<[1, 2]> : tensor<2xi32>
%2 = mhlo.constant dense<[1, 0]> : tensor<2xi32>
%3 = "tf.XlaPad"(%arg0, %arg1, %0, %1, %2) : (tensor<2x2xf64>, tensor<f64>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<6x5xf64>
return %3 : tensor<6x5xf64>
}
@ -156,7 +156,7 @@ func @non_const_inputs(%arg0: tensor<2x2xf64>, %arg1: tensor<f64>, %arg2: tensor

// CHECK-LABEL: dynamic_result_type
func @dynamic_result_type(%arg0: tensor<2xf32>) -> tensor<*xf32> {
// CHECK: %[[RESULT:.*]] = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: tensor_cast %0 : tensor<2xf32> to tensor<*xf32>
%0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<*xf32>

@ -166,7 +166,7 @@ func @dynamic_result_type(%arg0: tensor<2xf32>) -> tensor<*xf32> {

func @truncated_normal() -> tensor<2x2xf32> {
// CHECK-NOT: tf.TruncatedNormal
%0 = xla_hlo.constant dense<[2, 2]> : tensor<2xi32>
%0 = mhlo.constant dense<[2, 2]> : tensor<2xi32>
%1 = "tf.TruncatedNormal"(%0) {T = i32, device = "", dtype = f32, seed = 0 : i64, seed2 = 1950157571 : i64} : (tensor<2xi32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
}
@ -175,21 +175,21 @@ func @truncated_normal() -> tensor<2x2xf32> {
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xi32>, %[[ARG1:.*]]: tensor<2x2xi32>, %[[ARG2:.*]]: tensor<2xi32>
func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2: tensor<2xi32>) -> tensor<3x4xi32> {

// CHECK: %[[SLICE0:.*]] = "xla_hlo.slice"(%[[ARG2]])
// CHECK: %[[SLICE0:.*]] = "mhlo.slice"(%[[ARG2]])
// CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>
// CHECK-DAG-SAME: limit_indices = dense<1> : tensor<1xi64>
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>
// CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32>
// CHECK: %[[DIM0:.*]] = "xla_hlo.reshape"(%[[SLICE0]]) : (tensor<1xi32>) -> tensor<i32>
// CHECK: %[[DIM0:.*]] = "mhlo.reshape"(%[[SLICE0]]) : (tensor<1xi32>) -> tensor<i32>

// CHECK: %[[SLICE1:.*]] = "xla_hlo.slice"(%[[ARG2]])
// CHECK: %[[SLICE1:.*]] = "mhlo.slice"(%[[ARG2]])
// CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>
// CHECK-DAG-SAME: limit_indices = dense<2> : tensor<1xi64>
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>
// CHECK-SAME: (tensor<2xi32>) -> tensor<1xi32>
// CHECK: %[[DIM1:.*]] = "xla_hlo.reshape"(%[[SLICE1]]) : (tensor<1xi32>) -> tensor<i32>
// CHECK: %[[DIM1:.*]] = "mhlo.reshape"(%[[SLICE1]]) : (tensor<1xi32>) -> tensor<i32>

// CHECK: "xla_hlo.dynamic-update-slice"(%[[ARG0]], %[[ARG1]], %[[DIM0]], %[[DIM1]])
// CHECK: "mhlo.dynamic-update-slice"(%[[ARG0]], %[[ARG1]], %[[DIM0]], %[[DIM1]])

%0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<3x4xi32>, tensor<2x2xi32>, tensor<2xi32>) -> tensor<3x4xi32>
return %0: tensor<3x4xi32>
@ -199,12 +199,12 @@ func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xi32>, %[[ARG1:.*]]: tensor<3xf32>, %[[ARG2:.*]]: tensor<f32>)
func @sparse_to_dense(%arg0: tensor<3x2xi32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>) -> tensor<3x3xf32> {

// CHECK: %[[CST:.*]] = xla_hlo.constant dense<3> : tensor<2xi32>
// CHECK: %[[DEFAULT:.*]] = "xla_hlo.broadcast_in_dim"(%[[ARG2]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<3x3xf32>
// CHECK: %[[CST:.*]] = mhlo.constant dense<3> : tensor<2xi32>
// CHECK: %[[DEFAULT:.*]] = "mhlo.broadcast_in_dim"(%[[ARG2]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<3x3xf32>

// CHECK: %[[RESULT:.*]] = "xla_hlo.scatter"(%[[DEFAULT]], %[[ARG0]], %[[ARG1]]) ( {
// CHECK: %[[RESULT:.*]] = "mhlo.scatter"(%[[DEFAULT]], %[[ARG0]], %[[ARG1]]) ( {
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>): // no predecessors
// CHECK: "xla_hlo.return"(%[[ARG4]]) : (tensor<f32>) -> ()
// CHECK: "mhlo.return"(%[[ARG4]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: indices_are_sorted = false
// CHECK-SAME: scatter_dimension_numbers
@ -217,14 +217,14 @@ func @sparse_to_dense(%arg0: tensor<3x2xi32>, %arg1: tensor<3xf32>, %arg2: tenso

// return %[[RESULT]] : tensor<3x3xf32>

%cst = xla_hlo.constant dense<3> : tensor<2xi32>
%cst = mhlo.constant dense<3> : tensor<2xi32>
%0 = "tf.SparseToDense"(%arg0, %cst, %arg1, %arg2) {validate_indices = true}: (tensor<3x2xi32>, tensor<2xi32>, tensor<3xf32>, tensor<f32>) -> tensor<3x3xf32>
|
||||
return %0 : tensor<3x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fft
|
||||
func @fft(%arg0: tensor<3x5x8xcomplex<f32>>) -> tensor<3x5x8xcomplex<f32>> {
|
||||
// CHECK: "xla_hlo.fft"(%arg0)
|
||||
// CHECK: "mhlo.fft"(%arg0)
|
||||
%0 = "tf.FFT"(%arg0) : (tensor<3x5x8xcomplex<f32>>) -> tensor<3x5x8xcomplex<f32>>
|
||||
return %0 : tensor<3x5x8xcomplex<f32>>
|
||||
}
|
||||
@ -238,7 +238,7 @@ func @reverse_sequence(%arg0: tensor<4x2x3x1x1xi32>, %arg1: tensor<3xi32>) -> te
|
||||
|
||||
// CHECK-LABEL: mirror_pad
|
||||
func @mirror_pad(%arg0: tensor<2x3xcomplex<f64>>) -> tensor<4x7xcomplex<f64>> {
|
||||
%0 = xla_hlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32>
|
||||
%0 = mhlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32>
|
||||
// CHECK-NOT: tf.MirrorPad
|
||||
%1 = "tf.MirrorPad"(%arg0, %0) {mode = "SYMMETRIC"} : (tensor<2x3xcomplex<f64>>, tensor<2x2xi32>) -> tensor<4x7xcomplex<f64>>
|
||||
return %1 : tensor<4x7xcomplex<f64>>
|
||||
@ -254,7 +254,7 @@ func @bucketize(%arg0: tensor<2x5xf32>) -> tensor<2x5xi32> {
|
||||
// CHECK-LABEL: arg_min
|
||||
func @arg_min(%arg0: tensor<6xf64>) -> tensor<i32> {
|
||||
// CHECK-NOT: ArgMin
|
||||
%0 = xla_hlo.constant dense<0> : tensor<i32>
|
||||
%0 = mhlo.constant dense<0> : tensor<i32>
|
||||
%1 = "tf.ArgMin"(%arg0, %0) : (tensor<6xf64>, tensor<i32>) -> tensor<i32>
|
||||
return %1 : tensor<i32>
|
||||
}
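The paired CHECK blocks above show the rename mechanically: the expected lowering of tf.Inv is unchanged except for the dialect prefix. As a standalone sketch of the post-rename output (function name hypothetical; op forms copied from the CHECK lines above):

// Hypothetical standalone example: 1/x built as a splatted one divided by the input.
func @inv_sketch(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %one = mhlo.constant dense<1.000000e+00> : tensor<f32>
  %splat = "mhlo.broadcast_in_dim"(%one) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<2xf32>
  %recip = mhlo.divide %splat, %arg0 : tensor<2xf32>
  return %recip : tensor<2xf32>
}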
File diff suppressed because it is too large
@ -48,7 +48,7 @@ class XlaBuilderTest : public ::testing::Test {
xla_builder_(name_, builder_, module_->getLoc()) {}

string SetupTest() {
mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
mlir::registerDialect<mlir::mhlo::XlaHloDialect>();
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
}

@ -75,7 +75,7 @@ TEST_F(XlaBuilderTest, CreateToken) {
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());

ExpectHasSubstr(GetMlirOpString(token),
R"("xla_hlo.create_token"() : () -> !xla_hlo.token)");
R"("mhlo.create_token"() : () -> !mhlo.token)");
}

TEST_F(XlaBuilderTest, Infeed) {
@ -85,7 +85,7 @@ TEST_F(XlaBuilderTest, Infeed) {
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(infeed),
R"("xla_hlo.infeed"(%0) {infeed_config = ""} : (!xla_hlo.token) -> tuple<tensor<4x8xf32>, !xla_hlo.token>)");
R"("mhlo.infeed"(%0) {infeed_config = ""} : (!mhlo.token) -> tuple<tensor<4x8xf32>, !mhlo.token>)");
}

TEST_F(XlaBuilderTest, Outfeed) {
@ -99,7 +99,7 @@ TEST_F(XlaBuilderTest, Outfeed) {
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(outfeed),
R"("xla_hlo.outfeed"(%0, %1) {outfeed_config = ""} : (tensor<4x8xf32>, !xla_hlo.token) -> !xla_hlo.token)");
R"("mhlo.outfeed"(%0, %1) {outfeed_config = ""} : (tensor<4x8xf32>, !mhlo.token) -> !mhlo.token)");
}

TEST_F(XlaBuilderTest, ConcatInDim) {
@ -112,7 +112,7 @@ TEST_F(XlaBuilderTest, ConcatInDim) {
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(concat),
R"("xla_hlo.concatenate"(%0, %1) {dimension = 1 : i64} : (tensor<2x4x5xf32>, tensor<2x6x5xf32>) -> tensor<2x10x5xf32>)");
R"("mhlo.concatenate"(%0, %1) {dimension = 1 : i64} : (tensor<2x4x5xf32>, tensor<2x6x5xf32>) -> tensor<2x10x5xf32>)");
}

TEST_F(XlaBuilderTest, Tuple) {
@ -125,7 +125,7 @@ TEST_F(XlaBuilderTest, Tuple) {
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(tuple),
R"("xla_hlo.tuple"(%0, %1) : (tensor<3x7xf32>, tensor<f32>) -> tuple<tensor<3x7xf32>, tensor<f32>>)");
R"("mhlo.tuple"(%0, %1) : (tensor<3x7xf32>, tensor<f32>) -> tuple<tensor<3x7xf32>, tensor<f32>>)");
}

TEST_F(XlaBuilderTest, GetTupleElement) {
@ -139,7 +139,7 @@ TEST_F(XlaBuilderTest, GetTupleElement) {
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(gte),
R"("xla_hlo.get_tuple_element"(%2) {index = 1 : i32} : (tuple<tensor<3x7xf32>, tensor<f32>>) -> tensor<f32>)");
R"("mhlo.get_tuple_element"(%2) {index = 1 : i32} : (tuple<tensor<3x7xf32>, tensor<f32>>) -> tensor<f32>)");
}

TEST_F(XlaBuilderTest, Slice) {
@ -150,7 +150,7 @@ TEST_F(XlaBuilderTest, Slice) {
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(slice),
R"("xla_hlo.slice"(%0) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x7xf32>) -> tensor<2x4xf32>)");
R"("mhlo.slice"(%0) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x7xf32>) -> tensor<2x4xf32>)");
}

TEST_F(XlaBuilderTest, Pad) {
@ -172,7 +172,7 @@ TEST_F(XlaBuilderTest, Pad) {
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(pad),
R"("xla_hlo.pad"(%0, %1) {edge_padding_high = dense<[2, 0]> : tensor<2xi64>, edge_padding_low = dense<[1, 3]> : tensor<2xi64>, interior_padding = dense<[0, 1]> : tensor<2xi64>} : (tensor<3x7xf32>, tensor<f32>) -> tensor<6x16xf32>)");
R"("mhlo.pad"(%0, %1) {edge_padding_high = dense<[2, 0]> : tensor<2xi64>, edge_padding_low = dense<[1, 3]> : tensor<2xi64>, interior_padding = dense<[0, 1]> : tensor<2xi64>} : (tensor<3x7xf32>, tensor<f32>) -> tensor<6x16xf32>)");
}

} // namespace
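The strings these tests assert are the generic printed form of the renamed ops. As a cross-reference, a minimal hand-written MLIR function in that form (function name and shapes are illustrative, taken from the Infeed/Outfeed assertions above):

// Hypothetical sketch: infeed yields (data, token) as a tuple; the token
// threads side-effect ordering into the outfeed.
func @token_plumbing(%tok: !mhlo.token, %data: tensor<4x8xf32>) -> !mhlo.token {
  %in = "mhlo.infeed"(%tok) {infeed_config = ""} : (!mhlo.token) -> tuple<tensor<4x8xf32>, !mhlo.token>
  %t = "mhlo.get_tuple_element"(%in) {index = 1 : i32} : (tuple<tensor<4x8xf32>, !mhlo.token>) -> !mhlo.token
  %out = "mhlo.outfeed"(%data, %t) {outfeed_config = ""} : (tensor<4x8xf32>, !mhlo.token) -> !mhlo.token
  return %out : !mhlo.token
}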
@ -12,9 +12,9 @@ func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1)

// CHECK-NEXT: %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

// CHECK-NEXT: ROOT %add.4 = f32[4] add(f32[4] %add.3, f32[4] %Arg_1.2)
%1 = "xla_hlo.add"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%1 = "mhlo.add"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %1 : tensor<4xf32>
}

@ -5,18 +5,18 @@ func @main() -> tensor<f32> {
%cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor<f32>
%cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor<f32>
%cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor<f32>
%0 = "xla_hlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( {
%0 = "mhlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( {
^bb0(%arg0: tensor<f32>):
%1 = "xla_hlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%1) : (tensor<f32>) -> ()
%1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}, {
^bb0(%arg0: tensor<f32>):
%1 = "xla_hlo.copy"(%arg0) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%1) : (tensor<f32>) -> ()
%1 = "mhlo.copy"(%arg0) : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}, {
^bb0(%arg0: tensor<f32>):
%1 = "xla_hlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%1) : (tensor<f32>) -> ()
%1 = "mhlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {name = "conditional"} : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
@ -52,18 +52,18 @@ func @main() -> (tensor<f32>, tensor<f32>) {
%cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor<f32>
%cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor<f32>
%cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor<f32>
%0:2 = "xla_hlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( {
%0:2 = "mhlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( {
^bb0(%arg0: tensor<f32>):
%1 = "xla_hlo.negate"(%arg0) {name = "negate"} : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
%1 = "mhlo.negate"(%arg0) {name = "negate"} : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
}, {
^bb0(%arg0: tensor<f32>):
%1 = "xla_hlo.copy"(%arg0) {name = "copy"} : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
%1 = "mhlo.copy"(%arg0) {name = "copy"} : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
}, {
^bb0(%arg0: tensor<f32>):
%1 = "xla_hlo.floor"(%arg0) {name = "floor"} : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
%1 = "mhlo.floor"(%arg0) {name = "floor"} : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
}) {name = "conditional"} : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
return %0#0, %0#1 : tensor<f32>, tensor<f32>
}

@ -30,17 +30,17 @@ ENTRY %indexed_conditional () -> f32[] {
// CHECK: %[[OPERAND_1:.*]] = constant {name = "{{.*}}"} dense<5.600000e+01> : tensor<f32>
// CHECK: %[[OPERAND_2:.*]] = constant {name = "{{.*}}"} dense<1.200000e+01> : tensor<f32>
// CHECK: %[[OPERAND_3:.*]] = constant {name = "{{.*}}"} dense<1.300000e+01> : tensor<f32>
// CHECK: %[[RESULT:.*]] = "xla_hlo.case"(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]) ( {
// CHECK: %[[RESULT:.*]] = "mhlo.case"(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]) ( {
// CHECK: ^bb0(%[[ARG_1:.*]]: tensor<f32>):
// CHECK: %[[RES_1:.*]] = "xla_hlo.negate"(%[[ARG_1]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
// CHECK: "xla_hlo.return"(%[[RES_1]]) : (tensor<f32>) -> ()
// CHECK: %[[RES_1:.*]] = "mhlo.negate"(%[[ARG_1]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
// CHECK: "mhlo.return"(%[[RES_1]]) : (tensor<f32>) -> ()
// CHECK: }, {
// CHECK: ^bb0(%[[ARG_2:.*]]: tensor<f32>):
// CHECK: %[[RES_2:.*]] = "xla_hlo.copy"(%[[ARG_2]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
// CHECK: "xla_hlo.return"(%[[RES_2]]) : (tensor<f32>) -> ()
// CHECK: %[[RES_2:.*]] = "mhlo.copy"(%[[ARG_2]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
// CHECK: "mhlo.return"(%[[RES_2]]) : (tensor<f32>) -> ()
// CHECK: }, {
// CHECK: ^bb0(%[[ARG_3:.*]]: tensor<f32>):
// CHECK: %[[RES_3:.*]] = "xla_hlo.floor"(%[[ARG_3]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
// CHECK: "xla_hlo.return"(%[[RES_3]]) : (tensor<f32>) -> ()
// CHECK: %[[RES_3:.*]] = "mhlo.floor"(%[[ARG_3]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
// CHECK: "mhlo.return"(%[[RES_3]]) : (tensor<f32>) -> ()
// CHECK: }) {name = "{{.*}}"} : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: return %[[RESULT]] : tensor<f32>

@ -19,7 +19,7 @@ func @main(%arg0: tensor<10xf32>, %arg1: tensor<i32>) {

// Test entry function with single dynamic parameter binding on an argument.

func @main(%arg0: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [0 : i32], padding_arg_indices = [1 : i32]}}, %arg1: tensor<i32>) {
func @main(%arg0: tensor<10xf32> {mhlo.padding_map = {shape_indices = [0 : i32], padding_arg_indices = [1 : i32]}}, %arg1: tensor<i32>) {
return
}

@ -42,7 +42,7 @@ func @main(%arg0: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [0 : i3

// Test entry function with multiple dynamic parameter bindings on an argument.

func @main(%arg0: tensor<8x10xf32> {xla_hlo.padding_map = {shape_indices = [0 : i32, 1 : i32], padding_arg_indices = [1 : i32, 2 : i32]}}, %arg1: tensor<i32>, %arg2: tensor<i32>) {
func @main(%arg0: tensor<8x10xf32> {mhlo.padding_map = {shape_indices = [0 : i32, 1 : i32], padding_arg_indices = [1 : i32, 2 : i32]}}, %arg1: tensor<i32>, %arg2: tensor<i32>) {
return
}

@ -75,7 +75,7 @@ func @main(%arg0: tensor<8x10xf32> {xla_hlo.padding_map = {shape_indices = [0 :
// Test entry function with multiple dynamic parameter bindings on multiple
// arguments.

func @main(%arg0: tensor<8x10xf32> {xla_hlo.padding_map = {shape_indices = [0 : i32, 1 : i32], padding_arg_indices = [1 : i32, 2 : i32]}}, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10x8x6xi32> {xla_hlo.padding_map = {shape_indices = [2 : i32], padding_arg_indices = [4 : i32]}}, %arg4: tensor<i32>) {
func @main(%arg0: tensor<8x10xf32> {mhlo.padding_map = {shape_indices = [0 : i32, 1 : i32], padding_arg_indices = [1 : i32, 2 : i32]}}, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10x8x6xi32> {mhlo.padding_map = {shape_indices = [2 : i32], padding_arg_indices = [4 : i32]}}, %arg4: tensor<i32>) {
return
}
@ -1,148 +1,148 @@
// RUN: not tf-mlir-translate -split-input-file -mlir-hlo-to-hlo %s -o - 2>&1 | FileCheck %s

// Test bad `xla_hlo.padding_map` attribute type.
// Test bad `mhlo.padding_map` attribute type.

func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = ""}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = ""}) {
return
}

// CHECK: requires 'xla_hlo.padding_map' dict attribute at arg 1
// CHECK: requires 'mhlo.padding_map' dict attribute at arg 1

// -----

// Test missing `shape_indices` attribute in `xla_hlo.padding_map`.
// Test missing `shape_indices` attribute in `mhlo.padding_map`.

func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {}}) {
return
}

// CHECK: requires 'shape_indices' array attribute in 'xla_hlo.padding_map' dict at arg 1
// CHECK: requires 'shape_indices' array attribute in 'mhlo.padding_map' dict at arg 1

// -----

// Test bad `shape_indices` attribute type in `xla_hlo.padding_map`.
// Test bad `shape_indices` attribute type in `mhlo.padding_map`.

func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = ""}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = ""}}) {
return
}

// CHECK: requires 'shape_indices' array attribute in 'xla_hlo.padding_map' dict at arg 1
// CHECK: requires 'shape_indices' array attribute in 'mhlo.padding_map' dict at arg 1

// -----

// Test missing `padding_arg_indices` attribute in `xla_hlo.padding_map`.
// Test missing `padding_arg_indices` attribute in `mhlo.padding_map`.

func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = []}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = []}}) {
return
}

// CHECK: requires 'padding_arg_indices' array attribute in 'xla_hlo.padding_map' dict at arg 1
// CHECK: requires 'padding_arg_indices' array attribute in 'mhlo.padding_map' dict at arg 1

// -----

// Test bad `padding_arg_indices` attribute type in `xla_hlo.padding_map`.
// Test bad `padding_arg_indices` attribute type in `mhlo.padding_map`.

func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [], padding_arg_indices = ""}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [], padding_arg_indices = ""}}) {
return
}

// CHECK: requires 'padding_arg_indices' array attribute in 'xla_hlo.padding_map' dict at arg 1
// CHECK: requires 'padding_arg_indices' array attribute in 'mhlo.padding_map' dict at arg 1

// -----

// Test mismatched `shape_indices` and `padding_arg_indices` lengths.

func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 0: i32, 0 : i32 ]}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 0: i32, 0 : i32 ]}}) {
return
}

// CHECK: requires 'shape_indices' and 'padding_arg_indices' array attributes in 'xla_hlo.padding_map' dic at arg 1 to be of the same size, got sizes 1 and 2
// CHECK: requires 'shape_indices' and 'padding_arg_indices' array attributes in 'mhlo.padding_map' dic at arg 1 to be of the same size, got sizes 1 and 2

// -----

// Test non integer attribute in `shape_indices`.

func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32, 0.0: f32 ], padding_arg_indices = [ 0: i32, 0: i32 ]}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32, 0.0: f32 ], padding_arg_indices = [ 0: i32, 0: i32 ]}}) {
return
}

// CHECK: requires element 1 in 'shape_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be an int attribute
// CHECK: requires element 1 in 'shape_indices' array of 'mhlo.padding_map' dict at arg 1 to be an int attribute

// -----

// Test non integer attribute in `padding_arg_indices`.

func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32, 0: i32 ], padding_arg_indices = [ 0: i32, 0.0: f32 ]}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32, 0: i32 ], padding_arg_indices = [ 0: i32, 0.0: f32 ]}}) {
return
}

// CHECK: requires element 1 in 'padding_arg_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be an int attribute
// CHECK: requires element 1 in 'padding_arg_indices' array of 'mhlo.padding_map' dict at arg 1 to be an int attribute

// -----

// Test negative out of range shape index in `shape_indices`.

func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ -1: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ -1: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
return
}

// CHECK: requires element 0 in 'shape_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be in range [0, 1), got -1
// CHECK: requires element 0 in 'shape_indices' array of 'mhlo.padding_map' dict at arg 1 to be in range [0, 1), got -1

// -----

// Test positive out of range shape index in `shape_indices`.

func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 1: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 1: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
return
}

// CHECK: requires element 0 in 'shape_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be in range [0, 1), got 1
// CHECK: requires element 0 in 'shape_indices' array of 'mhlo.padding_map' dict at arg 1 to be in range [0, 1), got 1

// -----

// Test negative shape index in `shape_indices` for unranked argument.

func @main(%arg0: tensor<i32>, %arg1: tensor<*xf32> {xla_hlo.padding_map = {shape_indices = [ -1: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<*xf32> {mhlo.padding_map = {shape_indices = [ -1: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
return
}

// CHECK: requires element 0 in 'shape_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be non-negative, got -1
// CHECK: requires element 0 in 'shape_indices' array of 'mhlo.padding_map' dict at arg 1 to be non-negative, got -1

// -----

// Test duplicate shape indices in `shape_indices`.

func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32, 0: i32 ], padding_arg_indices = [ 0: i32, 0: i32 ]}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32, 0: i32 ], padding_arg_indices = [ 0: i32, 0: i32 ]}}) {
return
}

// CHECK: requires elements in 'shape_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be unique, got duplicate element 0 at index 1
// CHECK: requires elements in 'shape_indices' array of 'mhlo.padding_map' dict at arg 1 to be unique, got duplicate element 0 at index 1

// -----

// Test negative out of range shape index in `padding_arg_indices`.

func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ -1: i32 ]}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ -1: i32 ]}}) {
return
}

// CHECK: requires element 0 in 'padding_arg_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be in range [0, 2), got -1
// CHECK: requires element 0 in 'padding_arg_indices' array of 'mhlo.padding_map' dict at arg 1 to be in range [0, 2), got -1

// -----

// Test positive out of range shape index in `padding_arg_indices`.

func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 2: i32 ]}}) {
func @main(%arg0: tensor<i32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 2: i32 ]}}) {
return
}

// CHECK: requires element 0 in 'padding_arg_indices' array of 'xla_hlo.padding_map' dict at arg 1 to be in range [0, 2), got 2
// CHECK: requires element 0 in 'padding_arg_indices' array of 'mhlo.padding_map' dict at arg 1 to be in range [0, 2), got 2

// -----

// Test non scalar padding argument.

func @main(%arg0: tensor<8xi32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
func @main(%arg0: tensor<8xi32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
return
}

@ -152,7 +152,7 @@ func @main(%arg0: tensor<8xi32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {s

// Test non integer type padding argument.

func @main(%arg0: tensor<f32>, %arg1: tensor<10xf32> {xla_hlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
func @main(%arg0: tensor<f32>, %arg1: tensor<10xf32> {mhlo.padding_map = {shape_indices = [ 0: i32 ], padding_arg_indices = [ 0: i32 ]}}) {
return
}
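Taken together, the diagnostics above pin down what a valid binding looks like. A minimal well-formed sketch, modeled on the passing tests earlier in this commit (argument shapes are illustrative): every shape index names a dimension of the annotated argument, every padding arg index names a scalar integer argument, the two arrays have equal length, and indices are unique and in range.

// Dimensions 0 and 1 of %arg0 are dynamically sized; %arg1 and %arg2 carry
// the real sizes of those dimensions.
func @padding_map_sketch(%arg0: tensor<8x10xf32> {mhlo.padding_map = {shape_indices = [0 : i32, 1 : i32], padding_arg_indices = [1 : i32, 2 : i32]}}, %arg1: tensor<i32>, %arg2: tensor<i32>) {
  return
}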
@ -1,9 +1,9 @@
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s

// CHECK: HloModule
func @main(%arg0: !xla_hlo.token, %arg1: !xla_hlo.token) -> !xla_hlo.token {
%0 = "xla_hlo.after_all"(%arg0, %arg1) : (!xla_hlo.token, !xla_hlo.token) -> !xla_hlo.token
return %0 : !xla_hlo.token
func @main(%arg0: !mhlo.token, %arg1: !mhlo.token) -> !mhlo.token {
%0 = "mhlo.after_all"(%arg0, %arg1) : (!mhlo.token, !mhlo.token) -> !mhlo.token
return %0 : !mhlo.token
}

// CHECK: ENTRY
@ -15,11 +15,11 @@ func @main(%arg0: !xla_hlo.token, %arg1: !xla_hlo.token) -> !xla_hlo.token {

// CHECK: HloModule
func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
%0 = "xla_hlo.all_reduce"(%arg0) ({
%0 = "mhlo.all_reduce"(%arg0) ({
// Perform max reduction inside the region
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
"mhlo.return"(%max) : (tensor<f32>) -> ()
})
{
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
@ -43,7 +43,7 @@ func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {

// CHECK: HloModule
func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tensor<2xf32>, %variance: tensor<2xf32>, %grad_output: tensor<2x2x2x2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>> {
%0 = "xla_hlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
%0 = "mhlo.batch_norm_grad" (%input, %scale, %mean, %variance, %grad_output) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2x2x2x2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
return %0 : tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
}

@ -60,7 +60,7 @@ func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf

// CHECK: HloModule
func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %offset: tensor<2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>> {
%0 = "xla_hlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 3 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
%0 = "mhlo.batch_norm_training" (%input, %scale, %offset) {epsilon = 0.001 : f32, feature_index = 3 : i64} : (tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
return %0 : tuple<tensor<2x2x2x2xf32>, tensor<2xf32>, tensor<2xf32>>
}

@ -78,16 +78,16 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor
// CHECK: [[VAL_1:%.*]] = s32[4] parameter(0)
// CHECK: [[VAL_2:%.*]] = s32[4] parameter(1)
// CHECK: [[ATAN2:%.*]] = s32[4] atan2(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%0 = xla_hlo.atan2 %arg0, %arg1 : tensor<4xi32>
%0 = mhlo.atan2 %arg0, %arg1 : tensor<4xi32>

// CHECK: [[SHL:%.*]] = s32[4] shift-left(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%1 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
%1 = mhlo.shift_left %arg0, %arg1 : tensor<4xi32>

// CHECK: [[SHRA:%.*]] = s32[4] shift-right-arithmetic(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%2 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
%2 = mhlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>

// CHECK: [[SHRL:%.*]] = s32[4] shift-right-logical(s32[4] [[VAL_1]], s32[4] [[VAL_2]])
%3 = xla_hlo.shift_right_logical %arg0, %arg1 : tensor<4xi32>
%3 = mhlo.shift_right_logical %arg0, %arg1 : tensor<4xi32>

// CHECK: ROOT
// CHECK-SAME: [[VAL_7:%.*]] = (s32[4], s32[4], s32[4], s32[4]) tuple(s32[4] [[ATAN2]], s32[4] [[SHL]], s32[4] [[SHRA]], s32[4] [[SHRL]])
@ -98,7 +98,7 @@ func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor

// CHECK: HloModule
func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
%0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

@ -112,7 +112,7 @@ func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {
func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> {
// CHECK: [[ARG:%.*]] = s32[4] parameter(0)
// CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] [[ARG]]), dimensions={3}
%0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
%0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
return %0 : tensor<1x2x3x4xi32>
}

@ -120,7 +120,7 @@ func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> {

// CHECK: HloModule
func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> {
%result = "xla_hlo.broadcast_in_dim"(%arg0) {
%result = "mhlo.broadcast_in_dim"(%arg0) {
broadcast_dimensions = dense<0> : tensor<1xi64>
} : (tensor<1xf32>) -> tensor<1x10xf32>
return %result : tensor<1x10xf32>
@ -133,9 +133,9 @@ func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> {
// -----

// CHECK: HloModule
func @main() -> !xla_hlo.token {
%0 = "xla_hlo.create_token"() : () -> !xla_hlo.token
return %0 : !xla_hlo.token
func @main() -> !mhlo.token {
%0 = "mhlo.create_token"() : () -> !mhlo.token
return %0 : !mhlo.token
}

// CHECK: ROOT [[TOKEN:%.*]] = token[] after-all()
@ -150,7 +150,7 @@ func @main(%arg0: tensor<4xi32>) -> tensor<4xi32> {
}

func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}

@ -181,8 +181,8 @@ func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
}

func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>) {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%1 = "xla_hlo.multiply"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%1 = "mhlo.multiply"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
return %0, %1 : tensor<4xi32>, tensor<4xi32>
}

@ -202,7 +202,7 @@ func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tens

// CHECK: HloModule
func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> {
%0 = "xla_hlo.collective_permute"(%arg0) {
%0 = "mhlo.collective_permute"(%arg0) {
source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
} : (tensor<128x32xf32>) -> tensor<128x32xf32>
return %0 : tensor<128x32xf32>
@ -217,7 +217,7 @@ func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> {
func @main(%arg0 : tensor<5x2xf32>,
%arg1 : tensor<5x5xf32>,
%arg2 : tensor<5x7xf32>) -> tensor<5x14xf32> {
%result = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) {
%result = "mhlo.concatenate"(%arg0, %arg1, %arg2) {
dimension = 1 : i64
} : (tensor<5x2xf32>, tensor<5x5xf32>, tensor<5x7xf32>) -> tensor<5x14xf32>
return %result : tensor<5x14xf32>
@ -279,7 +279,7 @@ func @main() {

// CHECK: HloModule
func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> tensor<100x28x28x1xf32> {
%result = "xla_hlo.convolution"(%arg0, %arg1) {
%result = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
@ -312,7 +312,7 @@ func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) -> te

// CHECK: HloModule
func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {
%0 = "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
%0 = "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

@ -324,7 +324,7 @@ func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {

// CHECK: HloModule
func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = "xla_hlo.copy"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
%0 = "mhlo.copy"(%arg0) : (tensor<2xi32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}

@ -336,8 +336,8 @@ func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> {

// CHECK: HloModule
func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
%0 = xla_hlo.constant dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
%1 = "xla_hlo.cross-replica-sum"(%arg0) {replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>} : (tensor<10xf32>) -> tensor<10xf32>
%0 = mhlo.constant dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
%1 = "mhlo.cross-replica-sum"(%arg0) {replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>} : (tensor<10xf32>) -> tensor<10xf32>
return %1 : tensor<10xf32>
}

@ -354,7 +354,7 @@ func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {

// CHECK: HloModule
func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<5x5xf32>) -> tensor<1x2x3xf32> {
%0 = "xla_hlo.custom_call"(%arg0, %arg1) {backend_config = "bar", call_target_name = "foo"} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32>
%0 = "mhlo.custom_call"(%arg0, %arg1) {backend_config = "bar", call_target_name = "foo"} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32>
return %0 : tensor<1x2x3xf32>
}

@ -369,7 +369,7 @@ func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<5x5xf32>) -> tensor<1x2x3xf32>
// CHECK: HloModule
func @main(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> {

%0 = "xla_hlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x64xbf16>
%0 = "mhlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x64xbf16>
return %0 : tensor<16x64xbf16>
}
@ -388,7 +388,7 @@ func @main(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> {
// CHECK: HloModule
func @main(%arg: tensor<16x16xi32>) -> tensor<16x32xbf16> {

%0 = "xla_hlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false, is_16bits = true} : (tensor<16x16xi32>) -> tensor<16x32xbf16>
%0 = "mhlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false, is_16bits = true} : (tensor<16x16xi32>) -> tensor<16x32xbf16>
return %0 : tensor<16x32xbf16>
}

@ -408,7 +408,7 @@ func @main(%arg: tensor<16x16xi32>) -> tensor<16x32xbf16> {
func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> {
// Simple einsum is lowered to HLO dot op.
// CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0}
%0 = "xla_hlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32>
%0 = "mhlo.einsum"(%arg0, %arg1) {einsum_config = "ab,bc->ac"} : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32>
return %0 : tensor<3x5xi32>
}

@ -416,7 +416,7 @@ func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> {

// CHECK: HloModule
func @main(%arg0: tensor<3x9xf32>) -> tensor<3x5xcomplex<f32>> {
%0 = "xla_hlo.fft"(%arg0) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<3x9xf32>) -> tensor<3x5xcomplex<f32>>
%0 = "mhlo.fft"(%arg0) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (tensor<3x9xf32>) -> tensor<3x5xcomplex<f32>>
return %0 : tensor<3x5xcomplex<f32>>
}

@ -437,7 +437,7 @@ func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10
// CHECK-SAME: index_vector_dim=1
// CHECK-SAME: slice_sizes={1,1,300}
// CHECK-SAME: indices_are_sorted=true
%0 = "xla_hlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<[0, 1]> : tensor<2xi64>}, indices_are_sorted = true, name = "gather", slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>} : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32>
%0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<[0, 1]> : tensor<2xi64>}, indices_are_sorted = true, name = "gather", slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>} : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32>
return %0 : tensor<10x300xf32>
}

@ -445,8 +445,8 @@ func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10

// CHECK: HloModule
func @main(%arg: tensor<4x2xf32>, %size: tensor<i32>) -> tensor<i32> {
%0 = "xla_hlo.set_dimension_size"(%arg, %size) {dimension = 1 : i32} : (tensor<4x2xf32>, tensor<i32>) -> tensor<4x2xf32>
%1 = "xla_hlo.get_dimension_size"(%0) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor<i32>
%0 = "mhlo.set_dimension_size"(%arg, %size) {dimension = 1 : i32} : (tensor<4x2xf32>, tensor<i32>) -> tensor<4x2xf32>
%1 = "mhlo.get_dimension_size"(%0) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor<i32>
return %1 : tensor<i32>
}

@ -461,7 +461,7 @@ func @main(%arg: tensor<4x2xf32>, %size: tensor<i32>) -> tensor<i32> {

// CHECK: HloModule
func @main(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
%0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
return %0 : tensor<f32>
}

@ -472,9 +472,9 @@ func @main(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
// -----

// CHECK: HloModule
func @main(%arg0: !xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token> {
%0 = "xla_hlo.infeed"(%arg0) {infeed_config = "foobar"} : (!xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token>
return %0 : tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token>
func @main(%arg0: !mhlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !mhlo.token> {
%0 = "mhlo.infeed"(%arg0) {infeed_config = "foobar"} : (!mhlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !mhlo.token>
return %0 : tuple<tuple<tensor<3xi32>, tensor<i1>>, !mhlo.token>
}

// CHECK: ENTRY
@ -485,7 +485,7 @@ func @main(%arg0: !xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !xl

// CHECK: HloModule
func @main() -> tensor<1x10xf32> {
%result = "xla_hlo.iota"() {
%result = "mhlo.iota"() {
iota_dimension = 1 : i64
} : () -> tensor<1x10xf32>
return %result : tensor<1x10xf32>
@ -498,10 +498,10 @@ func @main() -> tensor<1x10xf32> {

// CHECK: HloModule
func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%0 = "xla_hlo.map"(%arg0, %arg1) ( {
%0 = "mhlo.map"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): // no predecessors
%1 = xla_hlo.add %arg2, %arg3 {name = "add"} : tensor<f32>
"xla_hlo.return"(%1) : (tensor<f32>) -> ()
%1 = mhlo.add %arg2, %arg3 {name = "add"} : tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -522,9 +522,9 @@ func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// -----

// CHECK: HloModule
func @main(%data: tensor<3xi32>, %token: !xla_hlo.token) -> !xla_hlo.token {
%0 = "xla_hlo.outfeed"(%data, %token) {outfeed_config = "foobar"} : (tensor<3xi32>, !xla_hlo.token) -> !xla_hlo.token
return %0 : !xla_hlo.token
func @main(%data: tensor<3xi32>, %token: !mhlo.token) -> !mhlo.token {
%0 = "mhlo.outfeed"(%data, %token) {outfeed_config = "foobar"} : (tensor<3xi32>, !mhlo.token) -> !mhlo.token
return %0 : !mhlo.token
}

// CHECK: ENTRY
@ -536,7 +536,7 @@ func @main(%data: tensor<3xi32>, %token: !xla_hlo.token) -> !xla_hlo.token {

// CHECK: HloModule
func @main(%arg: tensor<4x6xf32>, %pad: tensor<f32>) -> tensor<13x19xf32> {
%0 = "xla_hlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<13x19xf32>
%0 = "mhlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<13x19xf32>
return %0 : tensor<13x19xf32>
}

@ -549,15 +549,15 @@ func @main(%arg: tensor<4x6xf32>, %pad: tensor<f32>) -> tensor<13x19xf32> {
// -----

// CHECK: HloModule
func @main(%token: !xla_hlo.token) -> tuple<tensor<3x4xi32>, !xla_hlo.token> {
%0 = "xla_hlo.recv"(%token) {
func @main(%token: !mhlo.token) -> tuple<tensor<3x4xi32>, !mhlo.token> {
%0 = "mhlo.recv"(%token) {
channel_id = {
handle = 5 : i64,
type = 3 : i64 // Host to device channel
},
is_host_transfer = true
} : (!xla_hlo.token) -> tuple<tensor<3x4xi32>, !xla_hlo.token>
return %0 : tuple<tensor<3x4xi32>, !xla_hlo.token>
} : (!mhlo.token) -> tuple<tensor<3x4xi32>, !mhlo.token>
return %0 : tuple<tensor<3x4xi32>, !mhlo.token>
}

// CHECK: ENTRY
@ -569,15 +569,15 @@ func @main(%token: !xla_hlo.token) -> tuple<tensor<3x4xi32>, !xla_hlo.token> {
// -----

// CHECK: HloModule
func @main(%token: !xla_hlo.token) -> tuple<tensor<3x4xi32>, !xla_hlo.token> {
%0 = "xla_hlo.recv"(%token) {
func @main(%token: !mhlo.token) -> tuple<tensor<3x4xi32>, !mhlo.token> {
%0 = "mhlo.recv"(%token) {
channel_id = {
handle = 5 : i64,
type = 1 : i64 // Device to device channel
},
is_host_transfer = false
} : (!xla_hlo.token) -> tuple<tensor<3x4xi32>, !xla_hlo.token>
return %0 : tuple<tensor<3x4xi32>, !xla_hlo.token>
} : (!mhlo.token) -> tuple<tensor<3x4xi32>, !mhlo.token>
return %0 : tuple<tensor<3x4xi32>, !mhlo.token>
}

// CHECK: ENTRY
@ -591,11 +591,11 @@ func @main(%token: !xla_hlo.token) -> tuple<tensor<3x4xi32>, !xla_hlo.token> {

// CHECK: HloModule
func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor<f32>, %arg3 : tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>) {
%result0, %result1 = "xla_hlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( {
%result0, %result1 = "mhlo.reduce"(%arg0, %arg1, %arg2, %arg3) ( {
^bb0(%fa: tensor<f32>, %ia : tensor<i32>, %fb: tensor<f32>, %ib: tensor<i32>): // no predecessors
%fmax = "xla_hlo.maximum"(%fa, %fb) {} : (tensor<f32>, tensor<f32>) -> tensor<f32>
%imax = "xla_hlo.maximum"(%ia, %ib) {} : (tensor<i32>, tensor<i32>) -> tensor<i32>
"xla_hlo.return"(%fmax, %imax) : (tensor<f32>, tensor<i32>) -> ()
%fmax = "mhlo.maximum"(%fa, %fb) {} : (tensor<f32>, tensor<f32>) -> tensor<f32>
%imax = "mhlo.maximum"(%ia, %ib) {} : (tensor<i32>, tensor<i32>) -> tensor<i32>
"mhlo.return"(%fmax, %imax) : (tensor<f32>, tensor<i32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x10xf32>, tensor<1x10xi32>, tensor<f32>, tensor<i32>) -> (tensor<1xf32>, tensor<1xi32>)
return %result0, %result1 : tensor<1xf32>, tensor<1xi32>
}
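The reduce above exercises the variadic (two-result) form of the op. For contrast, a minimal single-result mhlo.reduce in the same generic style, summing over dimension 0 (function name and shapes are hypothetical, not part of this commit):

// Hypothetical sketch: reduce a 4x8 tensor over dim 0 to an 8-element vector.
func @reduce_sum_sketch(%arg0: tensor<4x8xf32>, %init: tensor<f32>) -> tensor<8xf32> {
  %0 = "mhlo.reduce"(%arg0, %init) ( {
  ^bb0(%a: tensor<f32>, %b: tensor<f32>): // no predecessors
    %sum = mhlo.add %a, %b : tensor<f32>
    "mhlo.return"(%sum) : (tensor<f32>) -> ()
  }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<8xf32>
  return %0 : tensor<8xf32>
}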
@ -617,11 +617,11 @@ func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor<f3

// CHECK: HloModule
func @main(%arg0: tensor<2x17x31x7xi32>) -> tensor<2x3x5x7xi32> {
%0 = xla_hlo.constant dense<-2147483648> : tensor<i32>
%1 = "xla_hlo.reduce_window"(%arg0, %0) ( {
%0 = mhlo.constant dense<-2147483648> : tensor<i32>
%1 = "mhlo.reduce_window"(%arg0, %0) ( {
^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>): // no predecessors
%2 = xla_hlo.maximum %arg1, %arg2 : tensor<i32>
"xla_hlo.return"(%2) : (tensor<i32>) -> ()
%2 = mhlo.maximum %arg1, %arg2 : tensor<i32>
"mhlo.return"(%2) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>,
@ -646,7 +646,7 @@ func @main(%arg0: tensor<2x17x31x7xi32>) -> tensor<2x3x5x7xi32> {

// CHECK: HloModule
func @main(%arg0: tensor<2xf32>) -> tensor<1x2xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x2xf32>
%0 = "mhlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x2xf32>
return %0 : tensor<1x2xf32>
}

@ -658,7 +658,7 @@ func @main(%arg0: tensor<2xf32>) -> tensor<1x2xf32> {

// CHECK: HloModule
func @main(%arg0 : tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {
%result = "xla_hlo.reverse"(%arg0) {
%result = "mhlo.reverse"(%arg0) {
dimensions = dense<[1,2]> : tensor<2xi64>
} : (tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32>
return %result : tensor<10x11x12x13xf32>
@ -672,8 +672,8 @@ func @main(%arg0 : tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {

// CHECK: HloModule
func @main(%mu: tensor<f32>, %sigma: tensor<f32>) -> tensor<2x3x5xf32> {
%shape = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64>
%0 = "xla_hlo.rng_normal"(%mu, %sigma, %shape) : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
%shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64>
%0 = "mhlo.rng_normal"(%mu, %sigma, %shape) : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
return %0 : tensor<2x3x5xf32>
}

@ -686,10 +686,10 @@ func @main(%mu: tensor<f32>, %sigma: tensor<f32>) -> tensor<2x3x5xf32> {

// CHECK: HloModule
func @main() -> tensor<2x3x5xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%1 = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
%2 = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64>
%3 = "xla_hlo.rng_uniform"(%0, %1, %2) : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%1 = mhlo.constant dense<1.000000e+00> : tensor<f32>
%2 = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64>
%3 = "mhlo.rng_uniform"(%0, %1, %2) : (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<2x3x5xf32>
return %3 : tensor<2x3x5xf32>
}

@ -702,10 +702,10 @@ func @main() -> tensor<2x3x5xf32> {

// CHECK: HloModule
func @main(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> tensor<200x100x300xf32> {
%0 = "xla_hlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
%0 = "mhlo.scatter" (%input_tensor, %scatter_indices, %updates) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
%add = xla_hlo.add %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%add) : (tensor<f32>) -> ()
%add = mhlo.add %lhs, %rhs : tensor<f32>
"mhlo.return"(%add) : (tensor<f32>) -> ()
}) {
scatter_dimension_numbers = {
update_window_dims = dense<[1]> : tensor<1xi64>,
@ -737,7 +737,7 @@ func @main(%arg0: tensor<i1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) ->
// CHECK: %[[ARG2:.*]] = s32[2,3] parameter(2)

// CHECK: ROOT %[[RES:.*]] = s32[2,3] select(pred[2,3] %[[COND]], s32[2,3] %[[ARG1]], s32[2,3] %[[ARG2]])
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<i1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%0 = "mhlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<i1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
return %0 : tensor<2x3xi32>
}

@ -745,15 +745,15 @@ func @main(%arg0: tensor<i1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) ->

// CHECK: HloModule
func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "xla_hlo.select_and_scatter"(%arg0, %arg1, %0) ( {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ( {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
%2 = "xla_hlo.compare"(%arg3, %arg4) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%2) : (tensor<i1>) -> ()
%2 = "mhlo.compare"(%arg3, %arg4) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
%2 = xla_hlo.add %arg3, %arg4 : tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
@ -780,15 +780,15 @@ func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> te
// -----

// CHECK: HloModule
func @main(%arg: tensor<3x4xi32>, %token: !xla_hlo.token) -> !xla_hlo.token {
%0 = "xla_hlo.send"(%arg, %token) {
func @main(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token {
%0 = "mhlo.send"(%arg, %token) {
channel_id = {
handle = 5 : i64,
type = 2 : i64 // Device to host channel
},
is_host_transfer = true
} : (tensor<3x4xi32>, !xla_hlo.token) -> !xla_hlo.token
return %0 : !xla_hlo.token
} : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.token
return %0 : !mhlo.token
}

// CHECK: ENTRY
@ -801,15 +801,15 @@ func @main(%arg: tensor<3x4xi32>, %token: !xla_hlo.token) -> !xla_hlo.token {
// -----

// CHECK: HloModule
func @main(%arg: tensor<3x4xi32>, %token: !xla_hlo.token) -> !xla_hlo.token {
%0 = "xla_hlo.send"(%arg, %token) {
func @main(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token {
%0 = "mhlo.send"(%arg, %token) {
channel_id = {
handle = 5 : i64,
type = 1 : i64 // Device to device channel
},
is_host_transfer = false
} : (tensor<3x4xi32>, !xla_hlo.token) -> !xla_hlo.token
return %0 : !xla_hlo.token
} : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.token
return %0 : !mhlo.token
}

// CHECK: ENTRY
@ -823,7 +823,7 @@ func @main(%arg: tensor<3x4xi32>, %token: !xla_hlo.token) -> !xla_hlo.token {

// CHECK: HloModule
func @main(%arg: tensor<4x4xf32>, %size: tensor<i32>) -> tensor<4x4xf32> {
%0 = "xla_hlo.set_dimension_size"(%arg, %size) {dimension = 1 : i32} : (tensor<4x4xf32>, tensor<i32>) -> tensor<4x4xf32>
%0 = "mhlo.set_dimension_size"(%arg, %size) {dimension = 1 : i32} : (tensor<4x4xf32>, tensor<i32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}

@ -837,7 +837,7 @@ func @main(%arg: tensor<4x4xf32>, %size: tensor<i32>) -> tensor<4x4xf32> {

// CHECK: HloModule
func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> {
%0 = "xla_hlo.slice"(%arg) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
%0 = "mhlo.slice"(%arg) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}

@ -850,7 +850,7 @@ func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> {

// CHECK: HloModule
func @main(%arg: tensor<3x4xi32>, %start1: tensor<i64>, %start2: tensor<i64>) -> tensor<1x4xi32> {
%0 = "xla_hlo.dynamic-slice"(%arg, %start1, %start2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
%0 = "mhlo.dynamic-slice"(%arg, %start1, %start2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}

@ -865,7 +865,7 @@ func @main(%arg: tensor<3x4xi32>, %start1: tensor<i64>, %start2: tensor<i64>) ->

// CHECK: HloModule
func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> {
"xla_hlo.trace"(%arg0) {tag = "This is a random test"} : (tensor<2xi32>) -> ()
"mhlo.trace"(%arg0) {tag = "This is a random test"} : (tensor<2xi32>) -> ()
return %arg0: tensor<2xi32>
}

@ -880,7 +880,7 @@ func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
// CHECK: [[ARG:%.*]] = s32[1,2,3,4] parameter(0)

// CHECK-NEXT: ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] [[ARG]]), dimensions={1,0,3,2}
%0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
%0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
return %0 : tensor<2x1x4x3xi32>
}

@ -888,7 +888,7 @@ func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {

// CHECK: HloModule
func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> {
%0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32>
%0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32>
return %0 : tensor<4x3xf32>
}

@ -901,7 +901,7 @@ func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> {

// CHECK: HloModule
func @main(%arg0: tensor<f32>, %arg1 : tensor<i32>) -> tuple<tensor<f32>, tensor<i32>> {
%result = "xla_hlo.tuple"(%arg0, %arg1) {} : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>>
%result = "mhlo.tuple"(%arg0, %arg1) {} : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>>
return %result : tuple<tensor<f32>, tensor<i32>>
}
@ -916,17 +916,17 @@ func @main(%arg0: tensor<f32>, %arg1 : tensor<i32>) -> tuple<tensor<f32>, tensor
func @main(%arg_f32: tensor<4xf32>, %arg_i32: tensor<4xi32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>) {
// CHECK: [[ARG_F32:%.*]] = f32[4] parameter(0)
// CHECK: [[EXPM1:%.*]] = f32[4] exponential-minus-one(f32[4] [[ARG_F32]])
%expm1 = "xla_hlo.exponential_minus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32>
%expm1 = "mhlo.exponential_minus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32>

// CHECK: [[LOG1P:%.*]] = f32[4] log-plus-one(f32[4] [[ARG_F32]])
%log1p = "xla_hlo.log_plus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32>
%log1p = "mhlo.log_plus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32>

// CHECK: [[ARG_I32:%.*]] = s32[4] parameter(1)
// CHECK: [[NOT:%.*]] = s32[4] not(s32[4] [[ARG_I32]])
%not = "xla_hlo.not"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32>
%not = "mhlo.not"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32>

// CHECK: [[POPCNT:%.*]] = s32[4] popcnt(s32[4] [[ARG_I32]])
%popcnt = "xla_hlo.popcnt"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32>
%popcnt = "mhlo.popcnt"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32>

return %expm1, %log1p, %not, %popcnt : tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>
}
@ -937,7 +937,7 @@ func @main(%arg_f32: tensor<4xf32>, %arg_i32: tensor<4xi32>) -> (tensor<4xf32>,
func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: [[VAL_1:%.*]] = pred[4] parameter(0)
// CHECK: [[VAL_2:%.*]] = pred[4] parameter(1)
%0 = xla_hlo.xor %arg0, %arg1 : tensor<4xi1>
%0 = mhlo.xor %arg0, %arg1 : tensor<4xi1>
// CHECK: ROOT [[VAL_3:%.*]] = pred[4] xor(pred[4] [[VAL_1]], pred[4] [[VAL_2]])
return %0 : tensor<4xi1>
}
@ -946,10 +946,10 @@ func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {

// CHECK: HloModule
func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
%0 = "xla_hlo.sort"(%input0, %input1) ( {
%0 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
return
}
@ -975,7 +975,7 @@ func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {

// CHECK: HloModule
func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> {
%0 = "xla_hlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"} : (tensor<16x16xf32>) -> tensor<16x16xf32>
%0 = "mhlo.custom_call"(%arg0) {backend_config = "", call_target_name = "Sharding", mhlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"} : (tensor<16x16xf32>) -> tensor<16x16xf32>
return %0 : tensor<16x16xf32>
}

@ -988,8 +988,8 @@ func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> {
// Tests that the exported HLO module keeps parameter replication annotation.

// CHECK: HloModule
func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {xla_hlo.is_same_data_across_replicas}) -> tensor<16x16xf32> {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32>
func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {mhlo.is_same_data_across_replicas}) -> tensor<16x16xf32> {
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32>
return %0 : tensor<16x16xf32>
}

@ -1003,8 +1003,8 @@ func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {xla_hlo.is_same_d

// CHECK: HloModule
func @main(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f64>>) -> (tensor<2xf32>, tensor<2xf64>) {
%0 = "xla_hlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%1 = "xla_hlo.abs"(%arg1) : (tensor<2xcomplex<f64>>) -> (tensor<2xf64>)
%0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%1 = "mhlo.abs"(%arg1) : (tensor<2xcomplex<f64>>) -> (tensor<2xf64>)
return %0, %1 : tensor<2xf32>, tensor<2xf64>
}

@ -1019,7 +1019,7 @@ func @main(%arg0: tensor<2xcomplex<f32>>, %arg1: tensor<2xcomplex<f64>>) -> (ten

// CHECK: HloModule
func @main(%arg0: tensor<4xui8>) -> tensor<4xui8> {
%0 = "xla_hlo.not"(%arg0) : (tensor<4xui8>) -> tensor<4xui8>
%0 = "mhlo.not"(%arg0) : (tensor<4xui8>) -> tensor<4xui8>
return %0 : tensor<4xui8>
}

@ -1031,7 +1031,7 @@ func @main(%arg0: tensor<4xui8>) -> tensor<4xui8> {

// CHECK: HloModule
func @main(%arg0: tensor<4xi32>) -> tensor<*xi32> {
%0 = "xla_hlo.not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
%0 = "mhlo.not"(%arg0) : (tensor<4xi32>) -> tensor<4xi32>
%1 = tensor_cast %0 : tensor<4xi32> to tensor<*xi32>
return %1 : tensor<*xi32>
}
@ -1046,10 +1046,10 @@ func @main(%arg0: tensor<4xi32>) -> tensor<*xi32> {
// correctly in HloModule as frontend_attributes.

// CHECK: HloModule
func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> tuple<tensor<3x4xf32>, !xla_hlo.token> {
%0 = "xla_hlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true, xla_hlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "channel_dtoh_0"}} : (tensor<3x4xf32>, !xla_hlo.token) -> !xla_hlo.token
%1 = "xla_hlo.recv"(%0) {channel_id = {handle = 2 : i64, type = 3 : i64}, is_host_transfer = true, xla_hlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "channel_htod_0"}} : (!xla_hlo.token) -> tuple<tensor<3x4xf32>, !xla_hlo.token>
return %1 : tuple<tensor<3x4xf32>, !xla_hlo.token>
func @main(%arg: tensor<3x4xf32>, %token: !mhlo.token) -> tuple<tensor<3x4xf32>, !mhlo.token> {
%0 = "mhlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true, mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "channel_dtoh_0"}} : (tensor<3x4xf32>, !mhlo.token) -> !mhlo.token
%1 = "mhlo.recv"(%0) {channel_id = {handle = 2 : i64, type = 3 : i64}, is_host_transfer = true, mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "channel_htod_0"}} : (!mhlo.token) -> tuple<tensor<3x4xf32>, !mhlo.token>
return %1 : tuple<tensor<3x4xf32>, !mhlo.token>
}

// CHECK: ENTRY
@ -1068,9 +1068,9 @@ func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> tuple<tensor<3x4xf3
// populated in HloModule.

// CHECK: HloModule
func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> !xla_hlo.token {
%0 = "xla_hlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true, xla_hlo.frontend_attributes = {}} : (tensor<3x4xf32>, !xla_hlo.token) -> !xla_hlo.token
return %0 : !xla_hlo.token
func @main(%arg: tensor<3x4xf32>, %token: !mhlo.token) -> !mhlo.token {
%0 = "mhlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true, mhlo.frontend_attributes = {}} : (tensor<3x4xf32>, !mhlo.token) -> !mhlo.token
return %0 : !mhlo.token
}

// CHECK-NOT: frontend_attributes
@ -1081,9 +1081,9 @@ func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> !xla_hlo.token {
// populated in HloModule.

// CHECK: HloModule
func @main(%arg: tensor<3x4xf32>, %token: !xla_hlo.token) -> !xla_hlo.token {
%0 = "xla_hlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true} : (tensor<3x4xf32>, !xla_hlo.token) -> !xla_hlo.token
return %0 : !xla_hlo.token
func @main(%arg: tensor<3x4xf32>, %token: !mhlo.token) -> !mhlo.token {
%0 = "mhlo.send"(%arg, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true} : (tensor<3x4xf32>, !mhlo.token) -> !mhlo.token
return %0 : !mhlo.token
}

// CHECK-NOT: frontend_attributes

@ -9,95 +9,95 @@ ENTRY %tfcompile.48 {
%arg0.1 = f32[1,300] parameter(0)
%arg1.2 = f32[1,300,3,1] parameter(1)

// CHECK-NEXT: %0 = "xla_hlo.reshape"(%arg0) {name = "reshape.3"} : (tensor<1x300xf32>) -> tensor<1x300xf32>
// CHECK-NEXT: %0 = "mhlo.reshape"(%arg0) {name = "reshape.3"} : (tensor<1x300xf32>) -> tensor<1x300xf32>
%reshape.3 = f32[1,300] reshape(%arg0.1)

// CHECK-NEXT: %1 = "xla_hlo.transpose"(%0) {name = "transpose.27", permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32>
// CHECK-NEXT: %1 = "mhlo.transpose"(%0) {name = "transpose.27", permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32>
%transpose.27 = f32[300,1] transpose(%reshape.3), dimensions={1,0}

// CHECK-NEXT: %2 = "xla_hlo.reshape"(%1) {name = "reshape.28"} : (tensor<300x1xf32>) -> tensor<300x1x1xf32>
// CHECK-NEXT: %2 = "mhlo.reshape"(%1) {name = "reshape.28"} : (tensor<300x1xf32>) -> tensor<300x1x1xf32>
%reshape.28 = f32[300,1,1] reshape(%transpose.27)

// CHECK-NEXT: %3 = "xla_hlo.reshape"(%2) {name = "reshape.29"} : (tensor<300x1x1xf32>) -> tensor<300x1xf32>
// CHECK-NEXT: %3 = "mhlo.reshape"(%2) {name = "reshape.29"} : (tensor<300x1x1xf32>) -> tensor<300x1xf32>
%reshape.29 = f32[300,1] reshape(%reshape.28)

// CHECK-NEXT: %4 = "xla_hlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.30"} : (tensor<300x1xf32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.30"} : (tensor<300x1xf32>) -> tensor<300x1x5xf32>
%broadcast.30 = f32[300,1,5] broadcast(%reshape.29), dimensions={0,1}

// CHECK-NEXT: %cst = constant {name = "constant.8"} dense<1.000000e+00> : tensor<f32>
%constant.8 = f32[] constant(1)

// CHECK-NEXT: %5 = "xla_hlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.9"} : (tensor<f32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %5 = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.9"} : (tensor<f32>) -> tensor<300x1x5xf32>
%broadcast.9 = f32[300,1,5] broadcast(%constant.8), dimensions={}

// CHECK-NEXT: %6 = xla_hlo.multiply %4, %5 {name = "multiply.31"} : tensor<300x1x5xf32>
// CHECK-NEXT: %6 = mhlo.multiply %4, %5 {name = "multiply.31"} : tensor<300x1x5xf32>
%multiply.31 = f32[300,1,5] multiply(%broadcast.30, %broadcast.9)

// CHECK-NEXT: %cst_0 = constant {name = "constant.32"} dense<0.000000e+00> : tensor<f32>
%constant.32 = f32[] constant(0)

// CHECK-NEXT: %7 = "xla_hlo.broadcast_in_dim"(%cst_0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.33"} : (tensor<f32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %7 = "mhlo.broadcast_in_dim"(%cst_0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.33"} : (tensor<f32>) -> tensor<300x1x5xf32>
%broadcast.33 = f32[300,1,5] broadcast(%constant.32), dimensions={}

// CHECK-NEXT: %8 = "xla_hlo.compare"(%6, %7) {comparison_direction = "GT", name = "compare.34"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1>
// CHECK-NEXT: %8 = "mhlo.compare"(%6, %7) {comparison_direction = "GT", name = "compare.34"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1>
%compare.34 = pred[300,1,5] compare(%multiply.31, %broadcast.33), direction=GT

// CHECK-NEXT: %cst_1 = constant {name = "constant.10"} dense<0.000000e+00> : tensor<f32>
%constant.10 = f32[] constant(0)

// CHECK-NEXT: %9 = "xla_hlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.11"} : (tensor<f32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %9 = "mhlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.11"} : (tensor<f32>) -> tensor<300x1x5xf32>
%broadcast.11 = f32[300,1,5] broadcast(%constant.10), dimensions={}

// CHECK-NEXT: %cst_2 = constant {name = "constant.40"} dense<0.000000e+00> : tensor<f32>
%constant.40 = f32[] constant(0)

// CHECK-NEXT: %10 = "xla_hlo.broadcast_in_dim"(%cst_2) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.41"} : (tensor<f32>) -> tensor<300x5xf32>
// CHECK-NEXT: %10 = "mhlo.broadcast_in_dim"(%cst_2) {broadcast_dimensions = dense<[]> : tensor<0xi64>, name = "broadcast.41"} : (tensor<f32>) -> tensor<300x5xf32>
%broadcast.41 = f32[300,5] broadcast(%constant.40), dimensions={}

// CHECK-NEXT: %11 = "xla_hlo.copy"(%arg1) {name = "copy.1"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
// CHECK-NEXT: %11 = "mhlo.copy"(%arg1) {name = "copy.1"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
%copy.1 = f32[1,300,3,1] copy(%arg1.2)

// CHECK-NEXT: %12 = "xla_hlo.reshape"(%11) {name = "reshape.4"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
// CHECK-NEXT: %12 = "mhlo.reshape"(%11) {name = "reshape.4"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
%reshape.4 = f32[1,300,3,1] reshape(%copy.1)

// CHECK-NEXT: %13 = "xla_hlo.reshape"(%12) {name = "reshape.24"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32>
// CHECK-NEXT: %13 = "mhlo.reshape"(%12) {name = "reshape.24"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32>
%reshape.24 = f32[1,300,3] reshape(%reshape.4)

// CHECK-NEXT: %14 = "xla_hlo.transpose"(%13) {name = "transpose.25", permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32>
// CHECK-NEXT: %14 = "mhlo.transpose"(%13) {name = "transpose.25", permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32>
%transpose.25 = f32[300,1,3] transpose(%reshape.24), dimensions={1,0,2}

// CHECK-NEXT: %15 = "xla_hlo.reshape"(%14) {name = "reshape.26"} : (tensor<300x1x3xf32>) -> tensor<300x3xf32>
// CHECK-NEXT: %15 = "mhlo.reshape"(%14) {name = "reshape.26"} : (tensor<300x1x3xf32>) -> tensor<300x3xf32>
%reshape.26 = f32[300,3] reshape(%transpose.25)

// CHECK-NEXT: %cst_3 = constant {name = "constant.35"} dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32>
%constant.35 = f32[3,5] constant({ { -0.106023, 0.121505, 0.800239, -0.768885, 0.0966113 }, { 0.689014, -0.407056, -0.797853, 0.00378925, -0.208881 }, { -0.608529, 0.0276617, 0.268557, 0.577401, -0.428437 } })

// TODO(b/129709049) consider making this default precision config implied.
// CHECK-NEXT: %16 = "xla_hlo.dot"(%15, %cst_3) {name = "dot.36", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
// CHECK-NEXT: %16 = "mhlo.dot"(%15, %cst_3) {name = "dot.36", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
%dot.36 = f32[300,5] dot(%reshape.26, %constant.35), lhs_contracting_dims={1}, rhs_contracting_dims={0}

// CHECK-NEXT: %cst_4 = constant {name = "constant.37"} dense<0.000000e+00> : tensor<5xf32>
%constant.37 = f32[5]{0} constant({0, 0, 0, 0, 0})

// CHECK-NEXT: %17 = "xla_hlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.38"} : (tensor<5xf32>) -> tensor<300x5xf32>
// CHECK-NEXT: %17 = "mhlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.38"} : (tensor<5xf32>) -> tensor<300x5xf32>
%broadcast.38 = f32[300,5] broadcast(%constant.37), dimensions={1}

// CHECK-NEXT: %18 = xla_hlo.add %16, %17 {name = "add.39"} : tensor<300x5xf32>
// CHECK-NEXT: %18 = mhlo.add %16, %17 {name = "add.39"} : tensor<300x5xf32>
%add.39 = f32[300,5] add(%dot.36, %broadcast.38)

// CHECK-NEXT: %19 = xla_hlo.maximum %10, %18 {name = "maximum.42"} : tensor<300x5xf32>
// CHECK-NEXT: %19 = mhlo.maximum %10, %18 {name = "maximum.42"} : tensor<300x5xf32>
%maximum.42 = f32[300,5] maximum(%broadcast.41, %add.39)

// CHECK-NEXT: %20 = "xla_hlo.reshape"(%19) {name = "reshape.44"} : (tensor<300x5xf32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %20 = "mhlo.reshape"(%19) {name = "reshape.44"} : (tensor<300x5xf32>) -> tensor<300x1x5xf32>
%reshape.44 = f32[300,1,5] reshape(%maximum.42)

// CHECK-NEXT: %21 = "xla_hlo.select"(%8, %9, %20) {name = "select.45"} : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %21 = "mhlo.select"(%8, %9, %20) {name = "select.45"} : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
%select.45 = f32[300,1,5] select(%compare.34, %broadcast.11, %reshape.44)

// CHECK-NEXT: %22 = "xla_hlo.reshape"(%21) {name = "reshape.46"} : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
// CHECK-NEXT: %22 = "mhlo.reshape"(%21) {name = "reshape.46"} : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
%reshape.46 = f32[300,1,5] reshape(%select.45)

// CHECK-NEXT: %23 = "xla_hlo.tuple"(%22) {name = "tuple.47"} : (tensor<300x1x5xf32>) -> tuple<tensor<300x1x5xf32>>
// CHECK-NEXT: %23 = "mhlo.tuple"(%22) {name = "tuple.47"} : (tensor<300x1x5xf32>) -> tuple<tensor<300x1x5xf32>>
// CHECK-NEXT: return %23 : tuple<tensor<300x1x5xf32>>
ROOT %tuple.47 = (f32[300,1,5]) tuple(%reshape.46)
}

@ -4,13 +4,13 @@
// CHECK: %[[A0]] = (f32[]) parameter(0)
func @then_branch(%arg0: tuple<tensor<f32>>) -> tuple<tensor<f32>> {
// CHECK: %[[VAL0:.+]] = f32[] get-tuple-element((f32[]) %[[A0]]), index=0
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>

// CHECK: %[[VAL1:.+]] = f32[] log(f32[] %[[VAL0]])
%1 = "xla_hlo.log"(%0) : (tensor<f32>) -> tensor<f32>
%1 = "mhlo.log"(%0) : (tensor<f32>) -> tensor<f32>

// CHECK: ROOT %[[VAl2:.+]] = (f32[]) tuple(f32[] %[[VAL1]])
%2 = "xla_hlo.tuple"(%1) : (tensor<f32>) -> tuple<tensor<f32>>
%2 = "mhlo.tuple"(%1) : (tensor<f32>) -> tuple<tensor<f32>>
return %2 : tuple<tensor<f32>>
}

@ -18,13 +18,13 @@ func @then_branch(%arg0: tuple<tensor<f32>>) -> tuple<tensor<f32>> {
// CHECK: %[[A0]] = (f32[]) parameter(0)
func @else_branch(%arg0: tuple<tensor<f32>>) -> tuple<tensor<f32>> {
// CHECK: %[[VAL0:.+]] = f32[] get-tuple-element((f32[]) %[[A0]]), index=0
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>

// CHECK: %[[VAL1:.+]] = f32[] exponential(f32[] %[[VAL0]])
%1 = "xla_hlo.exponential"(%0) : (tensor<f32>) -> tensor<f32>
%1 = "mhlo.exponential"(%0) : (tensor<f32>) -> tensor<f32>

// CHECK: ROOT %[[VAL2:.+]] = (f32[]) tuple(f32[] %[[VAL1]])
%2 = "xla_hlo.tuple"(%1) : (tensor<f32>) -> tuple<tensor<f32>>
%2 = "mhlo.tuple"(%1) : (tensor<f32>) -> tuple<tensor<f32>>
return %2 : tuple<tensor<f32>>
}

@ -35,30 +35,30 @@ func @main(%arg0: tensor<f32>) -> tuple<tensor<f32>> {
%cst = constant dense<1.000000e+01> : tensor<f32>

// CHECK: %[[VAL1:.+]] = pred[] compare(f32[] %[[A0]], f32[] %[[VAL0]]), direction=LT
%0 = "xla_hlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
%0 = "mhlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>

// CHECK: %[[VAL2:.+]] = (f32[]) tuple(f32[] %[[A0]])
%1 = "xla_hlo.tuple"(%arg0) : (tensor<f32>) -> tuple<tensor<f32>>
%1 = "mhlo.tuple"(%arg0) : (tensor<f32>) -> tuple<tensor<f32>>

// CHECK: %[[VAL3:.+]] = (f32[]) conditional(pred[] %[[VAL1]], (f32[]) %[[VAL2]], (f32[]) %[[VAL2]]), true_computation=[[R0]], false_computation=[[R1]]
%2 = "xla_hlo.if"(%0, %1, %1) ( {
%2 = "mhlo.if"(%0, %1, %1) ( {
^bb0(%arg1: tuple<tensor<f32>>):
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%7 = "xla_hlo.log"(%6) : (tensor<f32>) -> tensor<f32>
%8 = "xla_hlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>
"xla_hlo.return"(%8) : (tuple<tensor<f32>>) -> ()
%6 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%7 = "mhlo.log"(%6) : (tensor<f32>) -> tensor<f32>
%8 = "mhlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>
"mhlo.return"(%8) : (tuple<tensor<f32>>) -> ()
}, {
^bb0(%arg1: tuple<tensor<f32>>):
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%7 = "xla_hlo.exponential"(%6) : (tensor<f32>) -> tensor<f32>
%8 = "xla_hlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>
"xla_hlo.return"(%8) : (tuple<tensor<f32>>) -> ()
%6 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%7 = "mhlo.exponential"(%6) : (tensor<f32>) -> tensor<f32>
%8 = "mhlo.tuple"(%7) : (tensor<f32>) -> tuple<tensor<f32>>
"mhlo.return"(%8) : (tuple<tensor<f32>>) -> ()
}) : (tensor<i1>, tuple<tensor<f32>>, tuple<tensor<f32>>) -> tuple<tensor<f32>>

// CHECK: %[[VAL4:.+]] = f32[] get-tuple-element((f32[]) %[[VAL3]]), index=0
%3 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>
%3 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<f32>>) -> tensor<f32>

// CHECK: ROOT %[[VAL5:.+]] = (f32[]) tuple(f32[] %[[VAL4]])
%4 = "xla_hlo.tuple"(%3) : (tensor<f32>) -> tuple<tensor<f32>>
%4 = "mhlo.tuple"(%3) : (tensor<f32>) -> tuple<tensor<f32>>
return %4 : tuple<tensor<f32>>
}

@ -23,31 +23,31 @@ ENTRY %tfcompile.20 {
// CHECK: [[C0:%.+]] = constant
%constant.3 = f32[] constant(10), metadata={op_type="Less" op_name="Less"}

// CHECK: [[R1:%.+]] = "xla_hlo.compare"([[A0]], [[C0]])
// CHECK: [[R1:%.+]] = "mhlo.compare"([[A0]], [[C0]])
%compare.4 = pred[] compare(%arg0.1, %constant.3), direction=LT, metadata={op_type="Less" op_name="Less"}

// CHECK: [[R2:%.+]] = "xla_hlo.tuple"([[A0]])
// CHECK: [[R2:%.+]] = "mhlo.tuple"([[A0]])
%tuple.5 = (f32[]) tuple(%arg0.1), metadata={op_type="If" op_name="cond/Merge_if"}

// CHECK: [[R3:%.+]] = "xla_hlo.if"([[R1]], [[R2]], [[R2]]) ( {
// CHECK: [[R3:%.+]] = "mhlo.if"([[R1]], [[R2]], [[R2]]) ( {
// CHECK: ^bb0([[A1:%.+]]: tuple<tensor<f32>>):
// CHECK: [[R7:%.+]] = "xla_hlo.get_tuple_element"([[A1]])
// CHECK: [[R8:%.+]] = "xla_hlo.log"([[R7]])
// CHECK: [[R9:%.+]] = "xla_hlo.tuple"([[R8]])
// CHECK: "xla_hlo.return"([[R9]])
// CHECK: [[R7:%.+]] = "mhlo.get_tuple_element"([[A1]])
// CHECK: [[R8:%.+]] = "mhlo.log"([[R7]])
// CHECK: [[R9:%.+]] = "mhlo.tuple"([[R8]])
// CHECK: "mhlo.return"([[R9]])
// CHECK: }, {
// CHECK: ^bb0([[A1:%.+]]: tuple<tensor<f32>>):
// CHECK: [[R7:%.+]] = "xla_hlo.get_tuple_element"([[A1]])
// CHECK: [[R8:%.+]] = "xla_hlo.exponential"([[R7]])
// CHECK: [[R9:%.+]] = "xla_hlo.tuple"([[R8]])
// CHECK: "xla_hlo.return"([[R9]])
// CHECK: [[R7:%.+]] = "mhlo.get_tuple_element"([[A1]])
// CHECK: [[R8:%.+]] = "mhlo.exponential"([[R7]])
// CHECK: [[R9:%.+]] = "mhlo.tuple"([[R8]])
// CHECK: "mhlo.return"([[R9]])
// CHECK: })
%conditional.16 = (f32[]) conditional(%compare.4, %tuple.5, %tuple.5), true_computation=%then_branch, false_computation=%else_branch, metadata={op_type="If" op_name="cond/Merge_if"}

// CHECK: [[R4:%.+]] = "xla_hlo.get_tuple_element"([[R3]])
// CHECK: [[R4:%.+]] = "mhlo.get_tuple_element"([[R3]])
%get-tuple-element.17 = f32[] get-tuple-element(%conditional.16), index=0, metadata={op_type="If" op_name="cond/Merge_if"}

// CHECK: [[R5:%.+]] = "xla_hlo.tuple"([[R4]])
// CHECK: [[R5:%.+]] = "mhlo.tuple"([[R4]])
// CHECK: return [[R5]]
ROOT %tuple.19 = (f32[]) tuple(%get-tuple-element.17), metadata={op_name="XLA_Retvals"}
}

@ -13,20 +13,20 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
%Arg_0.1 = f32[4]{0} parameter(0)
%Arg_1.2 = f32[4]{0} parameter(1)

// CHECK-NEXT: xla_hlo.add %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
// CHECK-NEXT: mhlo.add %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
%add.3 = f32[4]{0} add(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2)

// TODO(b/129709049) consider making this default precision config inferred.
// CHECK-NEXT: "xla_hlo.dot"(%0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
// CHECK-NEXT: "mhlo.dot"(%0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
ROOT %dot.4 = f32[] dot(f32[4]{0} %add.3, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}
}

// CHECK-LABEL: func @test_after_all
// CHECK-SAME: ([[VAL_0:%.*]]: !xla_hlo.token, [[VAL_1:%.*]]: !xla_hlo.token) -> !xla_hlo.token [[PRIVATE]]
// CHECK-SAME: ([[VAL_0:%.*]]: !mhlo.token, [[VAL_1:%.*]]: !mhlo.token) -> !mhlo.token [[PRIVATE]]
%test_after_all (token0: token[], token1: token[] ) -> token[] {
token0 = token[] parameter(0)
token1 = token[] parameter(1)
// CHECK-NEXT: "xla_hlo.after_all"([[VAL_0]], [[VAL_1]]) {name = "{{.*}}"} : (!xla_hlo.token, !xla_hlo.token) -> !xla_hlo.token
// CHECK-NEXT: "mhlo.after_all"([[VAL_0]], [[VAL_1]]) {name = "{{.*}}"} : (!mhlo.token, !mhlo.token) -> !mhlo.token
ROOT after-all = token[] after-all(token0, token1)
}

@ -41,10 +41,10 @@ add {
// CHECK-SAME: ([[INPUT:%.*]]: tensor<8xf32>)
%test_all_reduce {
input = f32[8] parameter(0)
// CHECK-NEXT: "xla_hlo.all_reduce"([[INPUT]])
// CHECK-NEXT: "mhlo.all_reduce"([[INPUT]])
// CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = xla_hlo.add [[ARG0]], [[ARG1]]
// CHECK: "xla_hlo.return"([[ADD]]) : (tensor<f32>) -> ()
// CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
// CHECK: "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
// CHECK: }) {
// CHECK-SAME: channel_handle = {handle = 1 : i64, type = 0 : i64}
// CHECK-SAME: replica_groups = dense<{{\[\[}}0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>
@ -57,7 +57,7 @@ add {
%Arg_0.1 = pred[4] parameter(0)
%Arg_1.2 = pred[4] parameter(1)

// CHECK-NEXT: xla_hlo.and %arg0, %arg1
// CHECK-NEXT: mhlo.and %arg0, %arg1
ROOT %and.3 = pred[4] and(pred[4] %Arg_0.1, pred[4] %Arg_1.2)
}

@ -67,7 +67,7 @@ add {
%Arg_0.1 = s32[4] parameter(0)
%Arg_1.2 = s32[4] parameter(1)

// CHECK: xla_hlo.atan2 [[VAL_0]], [[VAL_1]]
// CHECK: mhlo.atan2 [[VAL_0]], [[VAL_1]]
ROOT %atan2 = s32[4] atan2(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
}

@ -75,10 +75,10 @@ add {
%test_broadcast_in_dim {
%Arg_0.1 = f32[1, 2] parameter(0)

// CHECK-NEXT: "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<1x2x3xf32>
// CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<1x2x3xf32>
%broadcast.2 = f32[1,2,3] broadcast(%Arg_0.1), dimensions={0,1}

// CHECK-NEXT: "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<3x1x2xf32>
// CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<3x1x2xf32>
ROOT broadcast.4 = f32[3,1,2] broadcast(%Arg_0.1), dimensions={1, 2}
}

@ -90,7 +90,7 @@ add {
%variance = f32[2] parameter(3)
%grad_output = f32[2,2,2,2] parameter(4)

// CHECK: "xla_hlo.batch_norm_grad"
// CHECK: "mhlo.batch_norm_grad"
// CHECK-SAME: epsilon = 1.000000e-03 : f32
// CHECK-SAME: feature_index = 1 : i64
ROOT %batch-norm-grad = (f32[2,2,2,2], f32[2], f32[2]) batch-norm-grad(f32[2,2,2,2] %input, f32[2] %scale, f32[2] %mean, f32[2] %variance, f32[2,2,2,2] %grad_output), epsilon=0.001, feature_index=1
@ -113,7 +113,7 @@ add {
// CHECK-SAME: ([[ARG:%.*]]: tensor<1x291x291xf32>) -> tensor<1x291x291xf32>
%test_cholesky (a: f32[1,291,291]) -> f32[1,291,291] {
%a = f32[1,291,291] parameter(0)
// CHECK-NEXT: "xla_hlo.cholesky"([[ARG]]) {lower = true, name = {{.*}}} : (tensor<1x291x291xf32>) -> tensor<1x291x291xf32>
// CHECK-NEXT: "mhlo.cholesky"([[ARG]]) {lower = true, name = {{.*}}} : (tensor<1x291x291xf32>) -> tensor<1x291x291xf32>
ROOT %out = f32[1,291,291] cholesky(f32[1,291,291] %a), lower=true
}

@ -124,7 +124,7 @@ add {
%Arg_1.2 = f32[4] parameter(1)
%Arg_2.3 = f32[] parameter(2)

// CHECK-NEXT: "xla_hlo.clamp"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK-NEXT: "mhlo.clamp"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
ROOT %clamp.3 = f32[4] clamp(f32[] %Arg_0.1, f32[4] %Arg_1.2, f32[] %Arg_2.3)
}

@ -132,7 +132,7 @@ add {
// CHECK-SAME: ([[ARG:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32>
%test_collective_permute (input: f32[128,32]) -> f32[128,32] {
%input = f32[128,32]{0,1} parameter(0)
// CHECK-NEXT: "xla_hlo.collective_permute"([[ARG]]) {name = {{.*}}, source_target_pairs = dense<{{\[\[}}0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>} : (tensor<128x32xf32>) -> tensor<128x32xf32>
// CHECK-NEXT: "mhlo.collective_permute"([[ARG]]) {name = {{.*}}, source_target_pairs = dense<{{\[\[}}0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>} : (tensor<128x32xf32>) -> tensor<128x32xf32>
ROOT root = f32[128,32]{0,1} collective-permute(%input), source_target_pairs={{0,1},{1,2},{2,3}}
}

@ -143,14 +143,14 @@ add {
%Arg_1.2 = f32[3] parameter(1)
%Arg_2.3 = f32[3] parameter(2)

// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
%compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ

// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "LE", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
%compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE

// Requires broadcast of compatible tensors.
// CHECK-NEXT: "xla_hlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
ROOT %compare.6 = pred[3] compare(Arg_0.1, Arg_2.3), direction=GT
}

@ -159,7 +159,7 @@ add {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)

// CHECK-NEXT: "xla_hlo.complex"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
// CHECK-NEXT: "mhlo.complex"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
ROOT %complex.3 = c64[4] complex(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}

@ -168,7 +168,7 @@ add {
%Arg_0.1 = f32[4, 1] parameter(0)
%Arg_1.2 = f32[4, 2] parameter(1)

// CHECK-NEXT: "xla_hlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<4x1xf32>, tensor<4x2xf32>) -> tensor<4x3xf32>
// CHECK-NEXT: "mhlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<4x1xf32>, tensor<4x2xf32>) -> tensor<4x3xf32>
ROOT %concatenate.3 = f32[4, 3] concatenate(f32[4, 1] %Arg_0.1, f32[4, 2] %Arg_1.2), dimensions={1}
}

@ -206,10 +206,10 @@ add {
%test_conv {
%arg0.1 = f32[256,32,32,6]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}

// CHECK-NEXT: %0 = "xla_hlo.copy"(%arg0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
// CHECK-NEXT: %0 = "mhlo.copy"(%arg0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
%copy.1 = f32[256,32,32,6]{2,1,3,0} copy(%arg0.1), metadata={op_name="HLO_Args"}

// CHECK-NEXT: %1 = "xla_hlo.reshape"(%0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
// CHECK-NEXT: %1 = "mhlo.reshape"(%0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
%reshape.2 = f32[256,32,32,6]{2,1,3,0} reshape(%copy.1)

// Note that double brackets "[[" have to be escaped as they denote variables
@ -217,7 +217,7 @@ add {
// CHECK-NEXT: %cst = constant {name = "{{.*}}"} dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32>
%constant.3 = f32[2,2,1,1]{3,2,1,0} constant({{{{0.5}}, {{-0.6}}}, {{{0.3}}, {{-0.1}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}

// CHECK-NEXT: %2 = "xla_hlo.convolution"(%1, %cst) {
// CHECK-NEXT: %2 = "mhlo.convolution"(%1, %cst) {
// CHECK-SAME: batch_group_count = 1 : i64
// CHECK-SAME: dimension_numbers = {
// CHECK-SAME: input_batch_dimension = 0 : i64
@ -241,10 +241,10 @@ add {

%convolution.4 = f32[16,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}

// CHECK-NEXT: %3 = "xla_hlo.reshape"(%2) {name = "{{.*}}"} : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32>
// CHECK-NEXT: %3 = "mhlo.reshape"(%2) {name = "{{.*}}"} : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32>
%reshape.5 = f32[256,30,30,16]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"}

// CHECK-NEXT: "xla_hlo.tuple"(%3) {name = "{{.*}}"} : (tensor<256x30x30x16xf32>) -> tuple<tensor<256x30x30x16xf32>>
// CHECK-NEXT: "mhlo.tuple"(%3) {name = "{{.*}}"} : (tensor<256x30x30x16xf32>) -> tuple<tensor<256x30x30x16xf32>>
ROOT %tuple.6 = (f32[256,30,30,16]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="HLO_Retvals"}
}

@ -253,7 +253,7 @@ add {
%test_convolve1D_padding (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,5,1] {
%input = f32[1,2,1] parameter(0)
%filter = f32[1,1,1] parameter(1)
// CHECK: "xla_hlo.convolution"
// CHECK: "mhlo.convolution"
// CHECK-SAME: padding = dense<{{\[\[}}1, 2]]> : tensor<1x2xi64>
ROOT %convolution = f32[1,5,1] convolution(f32[1,2,1] %input, f32[1,1,1] %filter), feature_group_count=1, dim_labels=b0f_0io->b0f, window={pad=1_2 size=1}
}
@ -263,13 +263,13 @@ add {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)

// CHECK-NEXT: %0 = "xla_hlo.convert"(%arg0) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64>
// CHECK-NEXT: %0 = "mhlo.convert"(%arg0) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64>
%convert.3 = f64[4] convert(f32[4] %Arg_0.1)

// CHECK-NEXT: %1 = "xla_hlo.convert"(%arg1) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64>
// CHECK-NEXT: %1 = "mhlo.convert"(%arg1) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64>
%convert.4 = f64[4] convert(f32[4] %Arg_1.2)

// CHECK-NEXT: xla_hlo.add %0, %1
// CHECK-NEXT: mhlo.add %0, %1
ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[4] %convert.4)
}

@ -277,7 +277,7 @@ add {
%test_cosine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
%arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}

// CHECK-NEXT: "xla_hlo.cosine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
// CHECK-NEXT: "mhlo.cosine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
ROOT %cosine.3 = f32[1,16,16,3]{3,2,1,0} cosine(f32[1,16,16,3]{3,2,1,0} %arg0.1)
}

@ -286,7 +286,7 @@ add {
%test_custom_call (arg1: f32[2,3], arg2: f32[5,5]) -> f32[1,2,3] {
%arg1 = f32[2,3] parameter(0)
%arg2 = f32[5,5] parameter(1)
// CHECK: "xla_hlo.custom_call"([[ARG_0]], [[ARG_1]]) {backend_config = "bar", call_target_name = "foo", has_side_effect = true, name = {{.*}}} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32>
// CHECK: "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {backend_config = "bar", call_target_name = "foo", has_side_effect = true, name = {{.*}}} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32>
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[2,3] %arg1, f32[5,5] %arg2), custom_call_target="foo", backend_config="bar", custom_call_has_side_effect=true
}

@ -295,7 +295,7 @@ add {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)

// CHECK-NEXT: xla_hlo.divide %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
// CHECK-NEXT: mhlo.divide %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
ROOT %divide.3 = f32[4] divide(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}

@ -304,17 +304,17 @@ add {
%Arg_0.1 = f32[1, 4] parameter(0)
%Arg_1.2 = f32[4, 1] parameter(1)

// CHECK-NEXT: %0 = "xla_hlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
// CHECK-NEXT: %0 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={high,highest}

// CHECK-NEXT: %1 = "xla_hlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
// CHECK-NEXT: %1 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,default}

// CHECK-NEXT: %2 = "xla_hlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
// CHECK-NEXT: %2 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
%dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={default,default}

// TODO(b/129709049) consider making this default precision config inferred.
// CHECK-NEXT: "xla_hlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
// CHECK-NEXT: "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}

@ -325,17 +325,17 @@ add {
%Arg_0.1 = f32[4, 1] parameter(0)
%Arg_1.2 = f32[1, 4] parameter(1)

// CHECK-NEXT: [[R0:%.+]] = "xla_hlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]}
// CHECK-NEXT: [[R0:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]}
dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={high,highest}

// CHECK-NEXT: [[R1:%.+]] = "xla_hlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]}
// CHECK-NEXT: [[R1:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]}
dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={highest,default}

// CHECK-NEXT: [[R2:%.+]] = "xla_hlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]}
// CHECK-NEXT: [[R2:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]}
%dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={default,default}

// TODO(b/129709049) consider making this default precision config inferred.
// CHECK-NEXT: "xla_hlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]}
// CHECK-NEXT: "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]}
ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}
}

@ -346,7 +346,7 @@ add {
%start_idx_1 = s32[] parameter(1)
%start_idx_2 = s32[] parameter(2)
%start_idx_3 = s32[] parameter(3)
// CHECK: "xla_hlo.dynamic-slice"([[OPERAND]], [[START_IDX_1]], [[START_IDX_2]], [[START_IDX_3]])
// CHECK: "mhlo.dynamic-slice"([[OPERAND]], [[START_IDX_1]], [[START_IDX_2]], [[START_IDX_3]])
// CHECK-SAME: slice_sizes = dense<[1, 1, 32]> : tensor<3xi64>
ROOT %dynamic-slice = s32[1,1,32] dynamic-slice(s32[2,2,258] %operand, s32[] %start_idx_1, s32[] %start_idx_2, s32[] %start_idx_3), dynamic_slice_sizes={1,1,32}
}
@ -358,7 +358,7 @@ add {
%Arg_2.3 = s32[] parameter(2)
%Arg_3.4 = s32[] parameter(3)

// CHECK-NEXT: "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor<i32>, tensor<i32>) -> tensor<4x4xf32>
// CHECK-NEXT: "mhlo.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor<i32>, tensor<i32>) -> tensor<4x4xf32>
ROOT %dynamic-update-slice.5 = f32[4, 4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4)
}

@ -368,7 +368,7 @@ add {
%Arg_1.2 = f32[2] parameter(1)
%Arg_2.3 = s32[] parameter(2)

// CHECK-NEXT: "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<i32>) -> tensor<4xf32>
// CHECK-NEXT: "mhlo.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<i32>) -> tensor<4xf32>
ROOT %dynamic-update-slice.5 = f32[4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3)
}

@ -376,7 +376,7 @@ add {
%test_exponential (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)

// CHECK-NEXT: "xla_hlo.exponential"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
// CHECK-NEXT: "mhlo.exponential"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
ROOT %exp.2 = f32[16] exponential(f32[16] %arg0.1)
}

@ -384,14 +384,14 @@ add {
%test_expm1 (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)

// CHECK: "xla_hlo.exponential_minus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
// CHECK: "mhlo.exponential_minus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
ROOT %expm1.2 = f32[16] exponential-minus-one(f32[16] %arg0.1)
}

// CHECK-LABEL: func @test_fft(%arg0: tensor<3x9xf32>) -> tensor<3x5xcomplex<f32>>
%test_fft {
%arg0.1 = f32[3,9]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
// CHECK: "xla_hlo.fft"(%arg0) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"
// CHECK: "mhlo.fft"(%arg0) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"
ROOT %fft.2 = c64[3,5]{1,0} fft(%arg0.1), fft_type=RFFT, fft_length={9}, metadata={op_type="RFFT" op_name="rfft"}
}

@ -400,7 +400,7 @@ add {
%test_floor (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)

// CHECK-NEXT: "xla_hlo.floor"([[A0]]) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
// CHECK-NEXT: "mhlo.floor"([[A0]]) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
ROOT %floor.2 = f32[16] floor(f32[16] %arg0.1)
}

@ -409,7 +409,7 @@ add {
%test_gather (arg.0: f32[200,100,300], arg.1: s32[10,2]) -> f32[10,300] {
%arg.0 = f32[200,100,300] parameter(0)
%arg.1 = s32[10,2] parameter(1)
// CHECK: "xla_hlo.gather"([[ARG0]], [[ARG1]])
// CHECK: "mhlo.gather"([[ARG0]], [[ARG1]])
// CHECK-SAME: dimension_numbers
// CHECK-SAME: collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>
// CHECK-SAME: index_vector_dim = 1 : i64
@ -430,7 +430,7 @@ add {
// CHECK-SAME: ([[ARG:%.*]]: tensor<4x2xf32>)
%test_get_dimension_size (Arg_0.1: f32[4,2]) -> s32[] {
%Arg_0.1 = f32[4,2] parameter(0)
// CHECK-NEXT: "xla_hlo.get_dimension_size"([[ARG]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x2xf32>) -> tensor<i32>
// CHECK-NEXT: "mhlo.get_dimension_size"([[ARG]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x2xf32>) -> tensor<i32>
ROOT %get-dimension-size.2 = s32[] get-dimension-size(f32[4,2] %Arg_0.1), dimensions={1}
}

@ -438,15 +438,15 @@ add {
%test_imag (Arg_0.1: c64[4]) -> f32[4] {
%Arg_0.1 = c64[4] parameter(0)

// CHECK-NEXT: "xla_hlo.imag"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
// CHECK-NEXT: "mhlo.imag"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
ROOT %imag.3 = f32[4] imag(c64[4] %Arg_0.1)
}

// CHECK-LABEL: func @test_infeed
// CHECK-SAME: ([[TOKEN:%.*]]: !xla_hlo.token) -> tuple<tensor<3xi32>, !xla_hlo.token>
// CHECK-SAME: ([[TOKEN:%.*]]: !mhlo.token) -> tuple<tensor<3xi32>, !mhlo.token>
%test_infeed (token0: token[]) -> (s32[3], token[]) {
%token0 = token[] parameter(0)
// CHECK-NEXT: "xla_hlo.infeed"([[TOKEN]])
// CHECK-NEXT: "mhlo.infeed"([[TOKEN]])
// CHECK-SAME: infeed_config = "foobar"
ROOT %infeed = (s32[3], token[]) infeed(token[] %token0), infeed_config="foobar"
}
@ -454,13 +454,13 @@ add {
|
||||
|
||||
// CHECK-LABEL: func @test_iota_1() -> tensor<4xf32>
|
||||
%test_iota_1 () -> f32[4] {
|
||||
// CHECK-NEXT: "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
|
||||
// CHECK-NEXT: "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
|
||||
ROOT %iota.0 = f32[4] iota(), iota_dimension=0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_iota_2() -> tensor<4x5xf32>
|
||||
%test_iota_2 () -> f32[4, 5] {
|
||||
// CHECK-NEXT: "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4x5xf32>
|
||||
// CHECK-NEXT: "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4x5xf32>
|
||||
ROOT %iota.0 = f32[4, 5] iota(), iota_dimension=1
|
||||
}
|
||||
|
||||
@@ -468,7 +468,7 @@ add {
 %test_log (arg0.1: f32[16]) -> f32[16] {
 %arg0.1 = f32[16] parameter(0)

-// CHECK-NEXT: "xla_hlo.log"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
+// CHECK-NEXT: "mhlo.log"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
 ROOT %log.2 = f32[16] log(f32[16] %arg0.1)
 }

@@ -476,11 +476,11 @@ add {
 %test_log1p (arg0.1: f32[16]) -> f32[16] {
 %arg0.1 = f32[16] parameter(0)

-// CHECK: "xla_hlo.log_plus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
+// CHECK: "mhlo.log_plus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
 ROOT %log1p.2 = f32[16] log-plus-one(f32[16] %arg0.1)
 }

-// Test xla_hlo.map
+// Test mhlo.map
 %map_computation {
 lhs = f32[] parameter(0)
 rhs = f32[] parameter(1)
@@ -492,10 +492,10 @@ add {
 %test_map {
 param0 = f32[4]{0} parameter(0)
 param1 = f32[4]{0} parameter(1)
-// CHECK: "xla_hlo.map"([[ARG_0]], [[ARG_1]]) ( {
+// CHECK: "mhlo.map"([[ARG_0]], [[ARG_1]]) ( {
 // CHECK: ^bb0([[ARG_2:%.*]]: tensor<f32>, [[ARG_3:%.*]]: tensor<f32>):
-// CHECK: [[ADD:%.*]] = xla_hlo.add [[ARG_2]], [[ARG_3]]
-// CHECK: "xla_hlo.return"([[ADD]]) : (tensor<f32>) -> ()
+// CHECK: [[ADD:%.*]] = mhlo.add [[ARG_2]], [[ARG_3]]
+// CHECK: "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
 // CHECK: }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
 ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=%map_computation
 }
@@ -507,7 +507,7 @@ add {
 %Arg_0.1 = f32[4] parameter(0)
 %Arg_1.2 = f32[4] parameter(1)

-// CHECK-NEXT: xla_hlo.maximum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
+// CHECK-NEXT: mhlo.maximum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
 ROOT %maximum.3 = f32[4] maximum(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
 }

@@ -516,7 +516,7 @@ add {
 %Arg_0.1 = f32[4] parameter(0)
 %Arg_1.2 = f32[4] parameter(1)

-// CHECK-NEXT: xla_hlo.minimum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
+// CHECK-NEXT: mhlo.minimum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
 ROOT %minimum.3 = f32[4] minimum(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
 }

@@ -525,7 +525,7 @@ add {
 %Arg_0.1 = f32[4] parameter(0)
 %Arg_1.2 = f32[4] parameter(1)

-// CHECK-NEXT: %0 = xla_hlo.multiply %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
+// CHECK-NEXT: %0 = mhlo.multiply %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
 ROOT %multiply.3 = f32[4] multiply(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
 }

@@ -533,7 +533,7 @@ add {
 %test_negate (arg0.1: f32[16]) -> f32[16] {
 %arg0.1 = f32[16] parameter(0)

-// CHECK-NEXT: "xla_hlo.negate"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
+// CHECK-NEXT: "mhlo.negate"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
 ROOT %negate.2 = f32[16] negate(f32[16] %arg0.1)
 }

@@ -541,7 +541,7 @@ add {
 %test_not (arg0.1: pred[16]) -> pred[16] {
 %arg0.1 = pred[16] parameter(0)

-// CHECK: "xla_hlo.not"(%arg0) {name = "{{.*}}"} : (tensor<16xi1>) -> tensor<16xi1>
+// CHECK: "mhlo.not"(%arg0) {name = "{{.*}}"} : (tensor<16xi1>) -> tensor<16xi1>
 ROOT %not.2 = pred[16] not(pred[16] %arg0.1)
 }

@@ -550,16 +550,16 @@ add {
 %Arg_0.1 = pred[4] parameter(0)
 %Arg_1.2 = pred[4] parameter(1)

-// CHECK-NEXT: xla_hlo.or %arg0, %arg1
+// CHECK-NEXT: mhlo.or %arg0, %arg1
 ROOT %or.3 = pred[4] or(pred[4] %Arg_0.1, pred[4] %Arg_1.2)
 }

 // CHECK-LABEL: func @test_outfeed
-// CHECK-SAME: ([[DATA:%.*]]: tensor<3xi32>, [[TOKEN:%.*]]: !xla_hlo.token) -> !xla_hlo.token
+// CHECK-SAME: ([[DATA:%.*]]: tensor<3xi32>, [[TOKEN:%.*]]: !mhlo.token) -> !mhlo.token
 %test_outfeed (Arg_0.1: s32[3], Arg_1.2: token[]) -> token[] {
 %Arg_0.1 = s32[3] parameter(0)
 %Arg_1.2 = token[] parameter(1)
-// CHECK-NEXT: "xla_hlo.outfeed"([[DATA]], [[TOKEN]])
+// CHECK-NEXT: "mhlo.outfeed"([[DATA]], [[TOKEN]])
 // CHECK-SAME: outfeed_config = "foobar"
 ROOT %outfeed.3 = token[] outfeed(s32[3] %Arg_0.1, token[] %Arg_1.2), outfeed_config="foobar"
 }
@@ -569,7 +569,7 @@ add {
 %Arg_0.1 = f32[4] parameter(0)
 %Arg_1.2 = f32[] parameter(1)

-// CHECK-NEXT: "xla_hlo.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
+// CHECK-NEXT: "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
 ROOT %pad.3 = f32[4] pad(%Arg_0.1, %Arg_1.2), padding=0_0_0
 }

@@ -578,7 +578,7 @@ add {
 %Arg_0.1 = f32[4, 4, 4] parameter(0)
 %Arg_1.2 = f32[] parameter(1)

-// CHECK-NEXT: "xla_hlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[2, 4, 6]> : tensor<3xi64>, edge_padding_low = dense<[1, 3, 5]> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<4x4x4xf32>, tensor<f32>) -> tensor<7x11x15xf32>
+// CHECK-NEXT: "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[2, 4, 6]> : tensor<3xi64>, edge_padding_low = dense<[1, 3, 5]> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<4x4x4xf32>, tensor<f32>) -> tensor<7x11x15xf32>
 ROOT %pad.3 = f32[7, 11, 15] pad(%Arg_0.1, %Arg_1.2), padding=1_2x3_4x5_6
 }

@@ -587,7 +587,7 @@ add {
 %Arg_0.1 = f32[4] parameter(0)
 %Arg_1.2 = f32[] parameter(1)

-// CHECK-NEXT: "xla_hlo.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<2> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<10xf32>
+// CHECK-NEXT: "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<2> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<10xf32>
 ROOT %pad.3 = f32[10] pad(%Arg_0.1, %Arg_1.2), padding=0_0_2
 }

@@ -595,7 +595,7 @@ add {
 %test_popcnt (arg0.1: s32[16]) -> s32[16] {
 %arg0.1 = s32[16] parameter(0)

-// CHECK: "xla_hlo.popcnt"(%arg0) {name = "{{.*}}"} : (tensor<16xi32>) -> tensor<16xi32>
+// CHECK: "mhlo.popcnt"(%arg0) {name = "{{.*}}"} : (tensor<16xi32>) -> tensor<16xi32>
 ROOT %popcnt.2 = s32[16] popcnt(s32[16] %arg0.1)
 }

@@ -604,7 +604,7 @@ add {
 %Arg_0.1 = f32[4] parameter(0)
 %Arg_1.2 = f32[4] parameter(1)

-// CHECK-NEXT: xla_hlo.power %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
+// CHECK-NEXT: mhlo.power %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
 ROOT %power.3 = f32[4] power(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
 }

@@ -614,7 +614,7 @@ add {
 %Arg_0.1 = f32[] parameter(0)
 %Arg_1.2 = f32[] parameter(1)
 // CHECK: [[CST:%.*]] = constant dense<[2, 3, 5]> : tensor<3xi64>
-// CHECK: "xla_hlo.rng_normal"([[ARG0]], [[ARG1]], [[CST]])
+// CHECK: "mhlo.rng_normal"([[ARG0]], [[ARG1]], [[CST]])
 ROOT %rng.4 = f32[2,3,5] rng(f32[] %Arg_0.1, f32[] %Arg_1.2), distribution=rng_normal
 }

@@ -624,7 +624,7 @@ add {
 %Arg_0.1 = f32[] parameter(0)
 %Arg_1.2 = f32[] parameter(1)
 // CHECK: [[CST:%.*]] = constant dense<[2, 3, 5]> : tensor<3xi64>
-// CHECK: "xla_hlo.rng_uniform"([[ARG0]], [[ARG1]], [[CST]])
+// CHECK: "mhlo.rng_uniform"([[ARG0]], [[ARG1]], [[CST]])
 ROOT %rng.4 = f32[2,3,5] rng(f32[] %Arg_0.1, f32[] %Arg_1.2), distribution=rng_uniform
 }

@@ -632,7 +632,7 @@ add {
 %test_real (Arg_0.1: c64[4]) -> f32[4] {
 %Arg_0.1 = c64[4] parameter(0)

-// CHECK-NEXT: "xla_hlo.real"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
+// CHECK-NEXT: "mhlo.real"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
 ROOT %real.3 = f32[4] real(c64[4] %Arg_0.1)
 }

@@ -666,28 +666,28 @@ add {
 %Arg_1.2 = f32[4] parameter(1)
 %Arg_2.3 = f32[] parameter(2)

-// CHECK: "xla_hlo.reduce"([[ARG0]], [[ARG0]], [[ARG2]], [[ARG2]])
-// CHECK: xla_hlo.add{{.*}} : tensor<f32>
-// CHECK: xla_hlo.add{{.*}} : tensor<f32>
+// CHECK: "mhlo.reduce"([[ARG0]], [[ARG0]], [[ARG2]], [[ARG2]])
+// CHECK: mhlo.add{{.*}} : tensor<f32>
+// CHECK: mhlo.add{{.*}} : tensor<f32>
 // CHECK: {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
 %reduce.1 = (f32[], f32[]) reduce(%Arg_0.1, %Arg_0.1, %Arg_2.3, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.1

-// CHECK: [[VAL2:%.*]] = "xla_hlo.reduce"([[ARG0]], [[ARG2]])
-// CHECK: xla_hlo.add{{.*}} : tensor<f32>
+// CHECK: [[VAL2:%.*]] = "mhlo.reduce"([[ARG0]], [[ARG2]])
+// CHECK: mhlo.add{{.*}} : tensor<f32>
 // CHECK: {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor<f32>) -> tensor<f32>
 %reduce.3 = f32[] reduce(%Arg_0.1, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.3

-// CHECK: [[VAL3:%.*]] = "xla_hlo.reduce"([[ARG0]], [[ARG1]])
-// CHECK: xla_hlo.add{{.*}} : tensor<4xf32>
+// CHECK: [[VAL3:%.*]] = "mhlo.reduce"([[ARG0]], [[ARG1]])
+// CHECK: mhlo.add{{.*}} : tensor<4xf32>
 // CHECK: {dimensions = dense<0> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4xf32>
 %reduce.2 = f32[4] reduce(%Arg_0.1, %Arg_1.2), dimensions={0}, to_apply=%reduce_helper.2

-// CHECK: [[VAL4:%.*]] = "xla_hlo.reduce"([[VAL3]], [[ARG2]])
-// CHECK: xla_hlo.add{{.*}} : tensor<f32>
+// CHECK: [[VAL4:%.*]] = "mhlo.reduce"([[VAL3]], [[ARG2]])
+// CHECK: mhlo.add{{.*}} : tensor<f32>
 // CHECK: {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<f32>
 %reduce.4 = f32[] reduce(%reduce.2, %Arg_2.3), dimensions={0}, to_apply=%reduce_helper.3

-// CHECK: %4 = xla_hlo.subtract [[VAL2]], [[VAL4]] {name = "{{.*}}"} : tensor<f32>
+// CHECK: %4 = mhlo.subtract [[VAL2]], [[VAL4]] {name = "{{.*}}"} : tensor<f32>
 %sub.5 = f32[] subtract(%reduce.3, %reduce.4)

 ROOT %tuple.6 = ((f32[], f32[]), f32[]) tuple(%reduce.1, %sub.5)
@@ -699,8 +699,8 @@ add {
 %Arg_0.1 = f32[2,17,31,7] parameter(0)
 %Arg_1.2 = f32[] parameter(1)

-// CHECK: "xla_hlo.reduce_window"([[ARG0]], [[ARG1]]) ( {
-// CHECK: xla_hlo.add {{.*}} : tensor<f32>
+// CHECK: "mhlo.reduce_window"([[ARG0]], [[ARG1]]) ( {
+// CHECK: mhlo.add {{.*}} : tensor<f32>
 // CHECK: }) {
 // CHECK-SAME: base_dilations = dense<1> : tensor<4xi64>
 // CHECK-SAME: padding = dense<{{\[\[}}0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64>
@@ -716,7 +716,7 @@ add {
 %test_remainder (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
 %Arg_0.1 = f32[4] parameter(0)
 %Arg_1.2 = f32[4] parameter(1)
-// CHECK: xla_hlo.remainder [[VAL_0]], [[VAL_1]]
+// CHECK: mhlo.remainder [[VAL_0]], [[VAL_1]]
 ROOT %remainder.3 = f32[4] remainder(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
 }

@@ -724,7 +724,7 @@ add {
 %test_reverse_1d (Arg_0.1: f32[4]) -> f32[4] {
 %Arg_0.1 = f32[4] parameter(0)

-// CHECK-NEXT: "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
+// CHECK-NEXT: "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
 ROOT reverse.2 = f32[4] reverse(%Arg_0.1), dimensions={0}
 }

@@ -732,7 +732,7 @@ add {
 %test_reverse_2d (Arg_0.1: f32[4, 4]) -> f32[4, 4] {
 %Arg_0.1 = f32[4, 4] parameter(0)

-// CHECK-NEXT: "xla_hlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4xf32>
+// CHECK-NEXT: "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4xf32>
 ROOT reverse.2 = f32[4, 4] reverse(%Arg_0.1), dimensions={0, 1}
 }

@@ -741,7 +741,7 @@ add {
 %test_rsqrt (arg0.1: f32[16]) -> f32[16] {
 %arg0.1 = f32[16] parameter(0)

-// CHECK: "xla_hlo.rsqrt"([[ARG0]]) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
+// CHECK: "mhlo.rsqrt"([[ARG0]]) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
 ROOT %rsqrt.2 = f32[16] rsqrt(f32[16] %arg0.1)
 }

@@ -767,10 +767,10 @@ add {

 // CHECK-LABEL: func @test_scatter
 // CHECK-SAME: [[ARG_0:%.*]]: tensor<200x100x300xf32>, [[ARG_1:%.*]]: tensor<10x2xi64>, [[ARG_2:%.*]]: tensor<10x300xf32>) -> tensor<200x100x300xf32>
-// CHECK: "xla_hlo.scatter"([[ARG_0]], [[ARG_1]], [[ARG_2]]) ( {
+// CHECK: "mhlo.scatter"([[ARG_0]], [[ARG_1]], [[ARG_2]]) ( {
 // CHECK: ^bb0([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
-// CHECK: [[ADD:%.*]] = xla_hlo.add [[LHS]], [[RHS]]
-// CHECK: "xla_hlo.return"([[ADD]]) : (tensor<f32>) -> ()
+// CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]]
+// CHECK: "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
 // CHECK: })
 // CHECK-SAME: indices_are_sorted = false
 // CHECK-SAME: scatter_dimension_numbers = {
@@ -788,7 +788,7 @@ add {
 %Arg_1.2 = s32[2,3] parameter(1)
 %Arg_2.3 = s32[2,3] parameter(2)

-// CHECK: "xla_hlo.select"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+// CHECK: "mhlo.select"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
 ROOT %select.4 = s32[2,3] select(pred[2,3] %Arg_0.1, s32[2,3] %Arg_1.2, s32[2,3] %Arg_2.3)
 }

@@ -814,14 +814,14 @@ add {
 ROOT %select-and-scatter = f32[4,5] select-and-scatter(f32[4,5] %input, f32[2,2] %source, f32[] %init_value), window={size=2x3 stride=2x3 pad=0_0x0_1}, select=%ge_select, scatter=%add_gather
 }

-// CHECK: [[RESULT:%.*]] = "xla_hlo.select_and_scatter"([[INPUT]], [[SOURCE]], [[INIT_VAL]]) ( {
+// CHECK: [[RESULT:%.*]] = "mhlo.select_and_scatter"([[INPUT]], [[SOURCE]], [[INIT_VAL]]) ( {
 // CHECK: ^bb0([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
-// CHECK: [[CMP:%.*]] = "xla_hlo.compare"([[LHS]], [[RHS]])
-// CHECK: "xla_hlo.return"([[CMP]]) : (tensor<i1>) -> ()
+// CHECK: [[CMP:%.*]] = "mhlo.compare"([[LHS]], [[RHS]])
+// CHECK: "mhlo.return"([[CMP]]) : (tensor<i1>) -> ()
 // CHECK: }, {
 // CHECK: ^bb0([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
-// CHECK: [[ADD:%.*]] = xla_hlo.add [[LHS]], [[RHS]]
-// CHECK: "xla_hlo.return"([[ADD]]) : (tensor<f32>) -> ()
+// CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]]
+// CHECK: "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
 // CHECK: }) {
 // CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 1]]> : tensor<2x2xi64>
 // CHECK-SAME: window_dimensions = dense<[2, 3]> : tensor<2xi64>
@@ -835,7 +835,7 @@ add {
 %test_set_dimension_size (Arg_0.1: f32[4,4], Arg_1.2: s32[]) -> f32[4,<=4] {
 %Arg_0.1 = f32[4,4] parameter(0)
 %Arg_1.2 = s32[] parameter(1)
-// CHECK-NEXT: "xla_hlo.set_dimension_size"([[ARG]], [[SIZE]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x4xf32>, tensor<i32>) -> tensor<4x4xf32>
+// CHECK-NEXT: "mhlo.set_dimension_size"([[ARG]], [[SIZE]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x4xf32>, tensor<i32>) -> tensor<4x4xf32>
 ROOT %set-dimension-size.2 = f32[4,<=4] set-dimension-size(f32[4,4] %Arg_0.1, s32[] %Arg_1.2), dimensions={1}
 }

@@ -843,7 +843,7 @@ add {
 %test_sine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
 %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}

-// CHECK-NEXT: "xla_hlo.sine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
+// CHECK-NEXT: "mhlo.sine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
 ROOT %sine.3 = f32[1,16,16,3]{3,2,1,0} sine(f32[1,16,16,3]{3,2,1,0} %arg0.1)
 }

@@ -860,10 +860,10 @@ add {
 }
 // CHECK-LABEL: func @test_sort
 // CHECK-SAME: [[ARG:%.*]]: tensor<1024xf32>) -> tensor<1024xf32>
-// CHECK: "xla_hlo.sort"([[ARG]]) ( {
+// CHECK: "mhlo.sort"([[ARG]]) ( {
 // CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
-// CHECK: [[CMP:%.*]] = "xla_hlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "LT", name = "lt"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
-// CHECK: "xla_hlo.return"([[CMP]]) : (tensor<i1>) -> ()
+// CHECK: [[CMP:%.*]] = "mhlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "LT", name = "lt"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK: "mhlo.return"([[CMP]]) : (tensor<i1>) -> ()
 // CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<1024xf32>) -> tensor<1024xf32>

 // CHECK-LABEL: func @test_subtract
@@ -871,7 +871,7 @@ add {
 %Arg_0.1 = f32[4] parameter(0)
 %Arg_1.2 = f32[4] parameter(1)

-// CHECK-NEXT: xla_hlo.subtract %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
+// CHECK-NEXT: mhlo.subtract %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
 ROOT %subtract.3 = f32[4] subtract(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
 }

@@ -879,7 +879,7 @@ add {
 %test_tanh (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
 %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}

-// CHECK-NEXT: "xla_hlo.tanh"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
+// CHECK-NEXT: "mhlo.tanh"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
 ROOT %tanh.3 = f32[1,16,16,3]{3,2,1,0} tanh(f32[1,16,16,3]{3,2,1,0} %arg0.1), metadata={op_type="Tanh" op_name="embedded_inference/tanh_model/Tanh"}
 }

@@ -887,7 +887,7 @@ add {
 %test_transpose {
 %Arg_0.1 = s32[1,2,3,4] parameter(0)

-// CHECK: "xla_hlo.transpose"(%arg0) {name = "{{.*}}", permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
+// CHECK: "mhlo.transpose"(%arg0) {name = "{{.*}}", permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
 ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2}
 }

@@ -896,7 +896,7 @@ add {
 %test_triangular_solve (Arg_0.1: f32[4,4], Arg_1.2: f32[4,3]) -> f32[4,3] {
 %Arg_0.1 = f32[4,4] parameter(0)
 %Arg_1.2 = f32[4,3] parameter(1)
-// CHECK-NEXT: "xla_hlo.triangular_solve"([[ARG_A]], [[ARG_B]])
+// CHECK-NEXT: "mhlo.triangular_solve"([[ARG_A]], [[ARG_B]])
 // CHECK-SAME: left_side = true
 // CHECK-SAME: lower = true
 // CHECK-SAME: transpose_a = "NO_TRANSPOSE"
@@ -909,10 +909,10 @@ add {
 %Arg_0.1 = s32[1] parameter(0)
 %Arg_1.2 = f32[1, 2] parameter(1)

-// CHECK-NEXT: %0 = "xla_hlo.tuple"(%arg0) {name = "{{.*}}"} : (tensor<1xi32>) -> tuple<tensor<1xi32>>
+// CHECK-NEXT: %0 = "mhlo.tuple"(%arg0) {name = "{{.*}}"} : (tensor<1xi32>) -> tuple<tensor<1xi32>>
 %tuple.3 = (s32[1]) tuple(%Arg_0.1)

-// CHECK: "xla_hlo.tuple"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<1xi32>, tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>>
+// CHECK: "mhlo.tuple"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<1xi32>, tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>>
 ROOT %tuple.4 = (s32[1], f32[1,2]) tuple(%Arg_0.1, %Arg_1.2)
 }

@@ -932,14 +932,14 @@ add {
 // CHECK-LABEL: func @test_while(%arg0: tensor<i64>) -> tensor<i64>
 %test_while (arg0.1: s64[]) -> s64[] {
 %arg0.1 = s64[] parameter(0), metadata={op_name="HLO_Args"}
-// CHECK-NEXT: "xla_hlo.while"(%arg0) ( {
+// CHECK-NEXT: "mhlo.while"(%arg0) ( {
 // CHECK-NEXT: ^bb0(%arg1: tensor<i64>): // no predecessors
-// CHECK-NEXT: [[CMP:%.*]] = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "{{.*}}"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
-// CHECK-NEXT: "xla_hlo.return"([[CMP]]) : (tensor<i1>) -> ()
+// CHECK-NEXT: [[CMP:%.*]] = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "{{.*}}"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+// CHECK-NEXT: "mhlo.return"([[CMP]]) : (tensor<i1>) -> ()
 // CHECK-NEXT: }, {
 // CHECK-NEXT: ^bb0(%arg1: tensor<i64>): // no predecessors
-// CHECK-NEXT: [[ADD:%.*]] = xla_hlo.add %arg1, %arg1 {name = "{{.*}}"} : tensor<i64>
-// CHECK-NEXT: "xla_hlo.return"([[ADD]]) : (tensor<i64>) -> ()
+// CHECK-NEXT: [[ADD:%.*]] = mhlo.add %arg1, %arg1 {name = "{{.*}}"} : tensor<i64>
+// CHECK-NEXT: "mhlo.return"([[ADD]]) : (tensor<i64>) -> ()
 // CHECK-NEXT: }) : (tensor<i64>) -> tensor<i64>
 ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond
 }
@@ -950,7 +950,7 @@ add {
 %Arg_0.1 = pred[4] parameter(0)
 %Arg_1.2 = pred[4] parameter(1)

-// CHECK: xla_hlo.xor [[VAL_0]], [[VAL_1]]
+// CHECK: mhlo.xor [[VAL_0]], [[VAL_1]]
 ROOT %xor.3 = pred[4] xor(pred[4] %Arg_0.1, pred[4] %Arg_1.2)
 }

@@ -960,7 +960,7 @@ add {
 %Arg_0.1 = s32[4] parameter(0)
 %Arg_1.2 = s32[4] parameter(1)

-// CHECK: xla_hlo.shift_left [[VAL_0]], [[VAL_1]]
+// CHECK: mhlo.shift_left [[VAL_0]], [[VAL_1]]
 ROOT %shiftleft = s32[4] shift-left(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
 }

@@ -970,7 +970,7 @@ add {
 %Arg_0.1 = s32[4] parameter(0)
 %Arg_1.2 = s32[4] parameter(1)

-// CHECK: xla_hlo.shift_right_arithmetic [[VAL_0]], [[VAL_1]]
+// CHECK: mhlo.shift_right_arithmetic [[VAL_0]], [[VAL_1]]
 ROOT %shiftright.arithmetic = s32[4] shift-right-arithmetic(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
 }

@@ -980,7 +980,7 @@ add {
 %Arg_0.1 = s32[4] parameter(0)
 %Arg_1.2 = s32[4] parameter(1)

-// CHECK: xla_hlo.shift_right_logical [[VAL_0]], [[VAL_1]]
+// CHECK: mhlo.shift_right_logical [[VAL_0]], [[VAL_1]]
 ROOT %shiftright.logical = s32[4] shift-right-logical(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
 }

@@ -992,8 +992,8 @@ add {
 %Arg_1.2 = c128[2] parameter(1)
 %abs.4 = f64[2] abs(c128[2] %Arg_1.2)

-// CHECK: "xla_hlo.abs"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
-// CHECK: "xla_hlo.abs"(%[[ARG1]]) {name = "{{.*}}"} : (tensor<2xcomplex<f64>>) -> tensor<2xf64>
+// CHECK: "mhlo.abs"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
+// CHECK: "mhlo.abs"(%[[ARG1]]) {name = "{{.*}}"} : (tensor<2xcomplex<f64>>) -> tensor<2xf64>
 ROOT %tuple.5 = (f32[2], f64[2]) tuple(f32[2] %abs.3, f64[2] %abs.4)
 }

@@ -1002,6 +1002,6 @@ add {
 %unsigned_int(Arg_0.1: u16[4]) -> u16[4] {
 %Arg_0.1 = u16[4] parameter(0)

-// CHECK: "xla_hlo.not"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<4xui16>) -> tensor<4xui16>
+// CHECK: "mhlo.not"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<4xui16>) -> tensor<4xui16>
 ROOT %not.2 = u16[4] not(u16[4] %Arg_0.1)
 }

@@ -6,7 +6,7 @@
 // TUPLE-ARG-LABEL: ENTRY %main
 // TUPLE-ARG: // OutputIndex {0} aliases with input 0 at {0}
 func @main(%arg0: tensor<1xf32> {tf.aliasing_output = 0 : i64}) -> (tensor<1xf32>) {
-%0 = xla_hlo.constant dense<4.200000e+01> : tensor<1xf32>
-%1 = xla_hlo.add %arg0, %0 : tensor<1xf32>
+%0 = mhlo.constant dense<4.200000e+01> : tensor<1xf32>
+%1 = mhlo.add %arg0, %0 : tensor<1xf32>
 return %1 : tensor<1xf32>
 }

@@ -9,6 +9,6 @@
 func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<1x2x3x4xi32>) {
 // CHECK-NEXT: %Arg_0.1 = s32[4] parameter(0)
 // CHECK-NEXT: %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] %Arg_0.1), dimensions={3}
-%0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
+%0 = "mhlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
 return %arg0, %0 : tensor<4xi32>, tensor<1x2x3x4xi32>
 }

@@ -139,8 +139,8 @@ dynamic_parameter_binding {
 }

 # CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<f32> {
-# CHECK-NEXT: %0 = xla_hlo.add %arg0, %arg1 {name = "add.3"} : tensor<4xf32>
+# CHECK-NEXT: %0 = mhlo.add %arg0, %arg1 {name = "add.3"} : tensor<4xf32>
 # TODO(b/129709049) consider making this default precision config inferred.
-# CHECK-NEXT: %1 = "xla_hlo.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
+# CHECK-NEXT: %1 = "mhlo.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
 # CHECK-NEXT: return %1 : tensor<f32>
 # CHECK-NEXT: }

@@ -2,8 +2,8 @@

 func @main(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> {
 ^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
-%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
-%1 = "xla_hlo.dot"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+%0 = "mhlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+%1 = "mhlo.dot"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
 return %1 : tensor<4xf32>
 }

@@ -16,14 +16,14 @@ HloModule foo
 ENTRY %foo (arg0.1: s64[]) -> s64[] {
 %arg0.1 = s64[] parameter(0), metadata={op_name="HLO_Args"}

-// CHECK: "xla_hlo.while"(%arg0) ( {
+// CHECK: "mhlo.while"(%arg0) ( {
 // CHECK: ^bb0
-// CHECK: "xla_hlo.compare"
-// CHECK: "xla_hlo.return"
+// CHECK: "mhlo.compare"
+// CHECK: "mhlo.return"
 // CHECK: }, {
 // CHECK: ^bb0
-// CHECK: xla_hlo.add
-// CHECK: "xla_hlo.return"
+// CHECK: mhlo.add
+// CHECK: "mhlo.return"
 // CHECK: }) : (tensor<i64>) -> tensor<i64>
 ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond
 }
@@ -2,7 +2,7 @@

 module {
 func @main(%arg0: tensor<i64>) -> tensor<i64> {
-%0 = "xla_hlo.while"(%arg0) ( {
+%0 = "mhlo.while"(%arg0) ( {
 // CHECK: [[R0:%.+]] ([[A0:.+]]: s64[]) -> s64[] {
 // CHECK: %[[A0]] = s64[] parameter(0)
 // CHECK: ROOT %add.4 = s64[] add(s64[] %[[A0]], s64[] %[[A0]])
@@ -10,12 +10,12 @@ module {
 // CHECK: %[[A0]] = s64[] parameter(0)
 // CHECK: ROOT %compare.7 = pred[] compare(s64[] %[[A0]], s64[] %[[A0]]), direction=LT
 ^bb0(%arg1: tensor<i64>):
-%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
-"xla_hlo.return"(%1) : (tensor<i1>) -> ()
+%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+"mhlo.return"(%1) : (tensor<i1>) -> ()
 }, {
 ^bb0(%arg1: tensor<i64>):
-%1 = xla_hlo.add %arg1, %arg1 {name = "compare.0"} : tensor<i64>
-"xla_hlo.return"(%1) : (tensor<i64>) -> ()
+%1 = mhlo.add %arg1, %arg1 {name = "compare.0"} : tensor<i64>
+"mhlo.return"(%1) : (tensor<i64>) -> ()
 }) : (tensor<i64>) -> tensor<i64>

 // CHECK: ENTRY %main.9 ([[A0:.+]]: s64[]) -> s64[] {
File diff suppressed because it is too large
@@ -48,7 +48,7 @@ limitations under the License.
 using mlir::PassRegistration;

 namespace mlir {
-namespace xla_hlo {
+namespace mhlo {
 namespace {
 class LegalizeTFControlFlow
     : public PassWrapper<LegalizeTFControlFlow, OperationPass<ModuleOp>> {
@@ -67,7 +67,7 @@ namespace {
 void Detuple(Value tuple, Operation::result_range replace, OpBuilder* builder) {
   // De-tuple the results of the xla hlo if result.
   for (auto result_it : llvm::enumerate(replace)) {
-    auto get_tuple_value = builder->create<xla_hlo::GetTupleElementOp>(
+    auto get_tuple_value = builder->create<mhlo::GetTupleElementOp>(
        result_it.value().getLoc(), tuple, result_it.index());
     result_it.value().replaceAllUsesWith(get_tuple_value);
   }
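The rename leaves this helper's contract unchanged: it still projects each element of a tuple value back onto the results it replaces. A minimal sketch of a caller, assuming a hypothetical wrapper function (only Detuple and the op names come from the hunks above):

// Hypothetical caller of the Detuple() helper shown above: once a lowering
// has produced a single tuple-typed value, each result of the old op is
// rewired to a mhlo::GetTupleElementOp projection, then the op is erased.
void ReplaceWithTupleElements(Operation* old_op, Value new_tuple,
                              OpBuilder* builder) {
  Detuple(new_tuple, old_op->getResults(), builder);
  old_op->erase();  // all uses now flow through the tuple projections
}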
@@ -95,10 +95,10 @@ void ImportXlaRegion(mlir::FuncOp func, Region* dest_region, Location loc,

   auto result = builder.create<CallOp>(loc, func, detupled_args).getResults();
   if (!tuple_return) {
-    builder.create<xla_hlo::ReturnOp>(loc, result);
+    builder.create<mhlo::ReturnOp>(loc, result);
   } else {
     auto tuple_op = builder.create<TupleOp>(loc, result);
-    builder.create<xla_hlo::ReturnOp>(loc, tuple_op.getResult());
+    builder.create<mhlo::ReturnOp>(loc, tuple_op.getResult());
   }
 }

@@ -109,12 +109,12 @@ void LowerIf(TF::IfOp op, ModuleOp module) {
   // XLA prefers tuple arguments for control flow due to XLA not supporting
   // multiple return values.
   SmallVector<Value, 3> inputs(op.input());
-  auto tuple_input = builder.create<xla_hlo::TupleOp>(loc, inputs);
+  auto tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);

   // Create the new if op with tuple inputs.
   auto result_type = builder.getTupleType(op.getResultTypes());
-  auto if_op = builder.create<xla_hlo::IfOp>(loc, result_type, op.cond(),
-                                             tuple_input, tuple_input);
+  auto if_op = builder.create<mhlo::IfOp>(loc, result_type, op.cond(),
+                                          tuple_input, tuple_input);

   // Import the regions for both the true and false cases. These regions
   // must be updated to tuple the return results together and use the xla hlo
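Assembled from the calls in this hunk plus the Detuple helper above, the lowering follows a pack-lower-unpack shape, and the same shape recurs in the LowerCase and LowerWhile hunks below. A sketch, not the literal pass code; region import is elided:

// Sketch of the tuple-based lowering pattern used throughout this pass.
void LowerIfSketch(TF::IfOp op, OpBuilder& builder, Location loc) {
  // 1. XLA control flow takes a single operand, so bundle all inputs.
  SmallVector<Value, 3> inputs(op.input());
  auto tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);

  // 2. Build the mhlo.if whose branches each receive the packed tuple.
  auto result_type = builder.getTupleType(op.getResultTypes());
  auto if_op = builder.create<mhlo::IfOp>(loc, result_type, op.cond(),
                                          tuple_input, tuple_input);

  // 3. Branch bodies are imported into the op's regions (elided here);
  //    then each original result is replaced by a tuple projection.
  Detuple(if_op.getResult(), op.getResults(), &builder);
  op.erase();
}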
@@ -136,15 +136,15 @@ void LowerCase(TF::CaseOp op, ModuleOp module) {
   // XLA requires one argument per branch so we create a tuple of inputs to pass
   // to each branch.
   SmallVector<Value, 4> inputs(op.input());
-  auto tuple_input = builder.create<xla_hlo::TupleOp>(loc, inputs);
+  auto tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);

   // Create replica of input tuple for each branch
   SmallVector<Value, 4> n_tuple_inputs(op.branches().size(), tuple_input);

   // Create the new case op with tuple inputs.
-  auto case_op = builder.create<xla_hlo::CaseOp>(
-      loc, op.getResultTypes(), op.branch_index(), n_tuple_inputs,
-      op.branches().size());
+  auto case_op =
+      builder.create<mhlo::CaseOp>(loc, op.getResultTypes(), op.branch_index(),
+                                   n_tuple_inputs, op.branches().size());

   // Import the regions for all branches.
   for (unsigned i = 0; i < op.branches().size(); ++i) {
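The "replica of input tuple" line relies on SmallVector's count-plus-value constructor, which is what hands every branch the same packed operand. A standalone, hypothetical illustration:

// SmallVector's (Count, Value) constructor fills the vector with Count
// copies of Value, so each case branch receives the same tuple input.
llvm::SmallVector<int, 4> per_branch(/*Size=*/3, /*Value=*/7);
// per_branch now holds {7, 7, 7}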
@@ -166,10 +166,10 @@ void LowerWhile(TF::WhileOp op, ModuleOp module) {
   // multiple return values.
   SmallVector<Value, 3> inputs(op.input());
   builder.setInsertionPoint(op);
-  Value tuple_input = builder.create<xla_hlo::TupleOp>(loc, inputs);
+  Value tuple_input = builder.create<mhlo::TupleOp>(loc, inputs);

   // Create the new while op with tuple inputs.
-  auto while_op = builder.create<xla_hlo::WhileOp>(
+  auto while_op = builder.create<mhlo::WhileOp>(
       loc, builder.getTupleType(op.getResultTypes()), tuple_input);

   // Import the regions for both the cond and body. These regions must be
@@ -204,9 +204,9 @@ void LegalizeTFControlFlow::runOnOperation() {
     }
   });
 }
-} // namespace xla_hlo
+} // namespace mhlo
 } // namespace mlir

-static PassRegistration<mlir::xla_hlo::LegalizeTFControlFlow> cfpass(
+static PassRegistration<mlir::mhlo::LegalizeTFControlFlow> cfpass(
     "xla-legalize-tf-control-flow",
     "Legalize TensorFlow control flow to the XLA dialect");

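Downstream code that spelled out the old namespace has to follow the rename; a hypothetical out-of-tree registration, sketched to show that only the namespace qualifier changes:

// Before this commit (no longer compiles after the rename):
//   static PassRegistration<mlir::xla_hlo::LegalizeTFControlFlow> pass(...);
// After: only the namespace changes; the registration arguments stay put.
static PassRegistration<mlir::mhlo::LegalizeTFControlFlow> my_cfpass(
    "my-xla-legalize-tf-control-flow",  // hypothetical flag name
    "Example downstream registration of the renamed pass");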
@@ -366,7 +366,7 @@ class GetDimensionSizeFromEnd<string dimFromEnd>: NativeCodeCall<
 // For now, this op needs to be created in C++ because the expected output type
 // cannot be inferred.
 class createIotaOp<string dim>: NativeCodeCall<
-  "$_builder.create<xla_hlo::IotaOp>($0.getOwner()->getLoc(), "
+  "$_builder.create<mhlo::IotaOp>($0.getOwner()->getLoc(), "
   "Get2DTensorType($1), $_builder.getI64IntegerAttr(" # dim # "))">;

 // This op needs to be created in C++ because the generated Convert Op has no
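For readers less familiar with DRR: the string in createIotaOp is spliced verbatim into the generated C++ rewriter, with `$_builder`, `$0`, and `$1` substituted by the pattern driver. Roughly, as a sketch with hypothetical local names (Get2DTensorType is the helper referenced by that pattern file):

// Approximate expansion of createIotaOp<"1"> inside generated
// matchAndRewrite() code: the result type is passed explicitly because,
// as the comment above notes, it cannot be inferred.
auto iota = rewriter.create<mhlo::IotaOp>(
    operand.getOwner()->getLoc(),    // $0.getOwner()->getLoc()
    Get2DTensorType(value),          // Get2DTensorType($1)
    rewriter.getI64IntegerAttr(1));  // dim = "1"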
@@ -69,7 +69,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/stream_executor.h"

 namespace mlir {
-namespace xla_hlo {
+namespace mhlo {
 namespace {

 template <typename T, size_t N>
@@ -544,5 +544,5 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
   return std::make_unique<LegalizeTF>(device_type);
 }

-} // end namespace xla_hlo
+} // end namespace mhlo
 } // end namespace mlir

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

-#include "tensorflow/compiler/mlir/xla/transforms/xla_hlo_to_lhlo_with_xla.h"
+#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"

 #include <memory>
 #include <tuple>
Some files were not shown because too many files have changed in this diff