[MLIR/XLA] Fix two bugs:
* Correctly handle trivial dimensions in TypeToShape. * Correctly generate default comparison type in the XLA builder. PiperOrigin-RevId: 336957760 Change-Id: I5e23797584cf670a994c63e7f3885e356e6bf053
This commit is contained in:
parent
1edb83449e
commit
c0b57b23be
@ -139,7 +139,8 @@ Shape TypeToShape(mlir::Type type) {
|
||||
for (const auto& e : llvm::enumerate(strides)) {
|
||||
strides_with_indices.push_back({e.value(), e.index()});
|
||||
}
|
||||
std::sort(strides_with_indices.begin(), strides_with_indices.end());
|
||||
std::stable_sort(strides_with_indices.begin(),
|
||||
strides_with_indices.end());
|
||||
|
||||
llvm::SmallVector<int64, 4> minor_to_major;
|
||||
int64_t stride = 1;
|
||||
@ -148,7 +149,7 @@ Shape TypeToShape(mlir::Type type) {
|
||||
|
||||
// Either the affine map is not perfectly strided, or the dimensions
|
||||
// recovered from strides don't match the actual dimensions in shapes.
|
||||
if (stride != pr.first) return {};
|
||||
if (stride != pr.first && m.getShape()[pr.second] != 1) return {};
|
||||
|
||||
stride *= m.getShape()[pr.second];
|
||||
}
|
||||
|
@ -196,5 +196,22 @@ TEST(TypeToShapeTest, ConvertMemRefToShape) {
|
||||
EXPECT_TRUE(ShapeUtil::Equal(converted, shape));
|
||||
}
|
||||
|
||||
TEST(TypeToShapeTest, ConvertMemRefToShape2) {
|
||||
Shape shape = ShapeUtil::MakeShapeWithLayout(PrimitiveType::C64, {2, 4, 3, 3},
|
||||
{2, 3, 1, 0});
|
||||
MLIRContext context;
|
||||
mlir::Builder builder(&context);
|
||||
|
||||
StatusOr<mlir::Type> mlir_type =
|
||||
ConvertShapeToType<MemRefType>(shape, builder);
|
||||
ASSERT_TRUE(mlir_type.ok());
|
||||
mlir::Type type = mlir_type.ConsumeValueOrDie();
|
||||
Shape converted = TypeToShape(type);
|
||||
EXPECT_TRUE(ShapeUtil::Equal(
|
||||
converted, ShapeUtil::MakeShapeWithLayout(PrimitiveType::C64,
|
||||
{2, 4, 3, 3}, {2, 3, 1, 0})));
|
||||
EXPECT_TRUE(ShapeUtil::Equal(converted, shape));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -257,6 +257,7 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_casting_utils",
|
||||
"//tensorflow/compiler/xla/service:hlo_matchers",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
|
@ -691,8 +691,10 @@ XlaOp XlaBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
|
||||
|
||||
StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
|
||||
ComparisonDirection direction) {
|
||||
return Compare(shape, lhs, rhs, direction,
|
||||
Comparison::DefaultComparisonType(shape.element_type()));
|
||||
TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(lhs));
|
||||
return Compare(
|
||||
shape, lhs, rhs, direction,
|
||||
Comparison::DefaultComparisonType(operand_shape.element_type()));
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
|
||||
|
@ -19,6 +19,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
@ -1203,5 +1205,16 @@ TEST_F(XlaBuilderTest, AddFrontendAttribute) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
|
||||
ExpectInstructionsAttributesMatch(*module, expected);
|
||||
}
|
||||
|
||||
TEST_F(XlaBuilderTest, ComparisonType) {
|
||||
XlaBuilder b(TestName());
|
||||
(void)Le(ConstantR0<int32>(&b, 1), ConstantR0<int32>(&b, 2));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
|
||||
auto root = module->entry_computation()->root_instruction();
|
||||
ASSERT_THAT(root, op::Compare(op::Constant(), op::Constant()));
|
||||
EXPECT_EQ(Comparison::Type::kSigned,
|
||||
DynCast<HloCompareInstruction>(root)->type());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user