[MLIR/XLA] Fix two bugs:

* Correctly handle trivial (size-1) dimensions when recovering a layout from strides in TypeToShape.
* Derive the default comparison type in XlaBuilder::Compare from the operand element type rather than from the PRED result shape.

PiperOrigin-RevId: 336957760
Change-Id: I5e23797584cf670a994c63e7f3885e356e6bf053
Tim Shen authored 2020-10-13 14:35:49 -07:00; committed by TensorFlower Gardener
commit c0b57b23be (parent 1edb83449e)
5 changed files with 38 additions and 4 deletions

@@ -139,7 +139,8 @@ Shape TypeToShape(mlir::Type type) {
     for (const auto& e : llvm::enumerate(strides)) {
       strides_with_indices.push_back({e.value(), e.index()});
     }
-    std::sort(strides_with_indices.begin(), strides_with_indices.end());
+    std::stable_sort(strides_with_indices.begin(),
+                     strides_with_indices.end());
 
     llvm::SmallVector<int64, 4> minor_to_major;
     int64_t stride = 1;
@@ -148,7 +149,7 @@ Shape TypeToShape(mlir::Type type) {
       // Either the affine map is not perfectly strided, or the dimensions
       // recovered from strides don't match the actual dimensions in shapes.
-      if (stride != pr.first) return {};
+      if (stride != pr.first && m.getShape()[pr.second] != 1) return {};
       stride *= m.getShape()[pr.second];
     }
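
Editorial note on the hunks above (not part of the commit): the change lets a dimension of size 1 skip the running-stride check, since a trivial dimension contributes nothing to addressing and may carry an arbitrary stride. The sketch below illustrates the idea in standalone C++; RecoverMinorToMajor and its types are hypothetical stand-ins, not the real TypeToShape code. It sorts (stride, dimension) pairs from the minor-most stride outward and only requires non-trivial dimensions to match the accumulated stride product.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-in for the recovery done in TypeToShape: given dimension
// sizes and per-dimension strides, recover a minor-to-major order, or return
// nullopt if the strides do not describe a dense, perfectly strided layout.
std::optional<std::vector<int64_t>> RecoverMinorToMajor(
    const std::vector<int64_t>& sizes, const std::vector<int64_t>& strides) {
  std::vector<std::pair<int64_t, int64_t>> stride_and_dim;
  for (size_t i = 0; i < strides.size(); ++i) {
    stride_and_dim.push_back({strides[i], static_cast<int64_t>(i)});
  }
  // Sort from the smallest (minor-most) stride outward. Ties on stride, which
  // for a dense layout only arise with size-1 dimensions, are broken by the
  // dimension index in the pair comparison.
  std::stable_sort(stride_and_dim.begin(), stride_and_dim.end());

  std::vector<int64_t> minor_to_major;
  int64_t expected_stride = 1;
  for (const auto& [stride, dim] : stride_and_dim) {
    minor_to_major.push_back(dim);
    // A size-1 dimension may carry any stride without changing addressing,
    // so only non-trivial dimensions must match the running product.
    if (stride != expected_stride && sizes[dim] != 1) return std::nullopt;
    expected_stride *= sizes[dim];
  }
  return minor_to_major;
}

int main() {
  // Shape {2, 1, 3}: dimension 1 has size 1 and an "odd" stride of 3. Without
  // the size-1 exception the layout would be rejected; with it the recovered
  // minor-to-major order is {2, 0, 1} (dimension 2 is minor-most).
  if (auto m2m = RecoverMinorToMajor({2, 1, 3}, {3, 3, 1})) {
    for (int64_t d : *m2m) std::cout << d << ' ';  // prints: 2 0 1
    std::cout << '\n';
  }
  return 0;
}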

@@ -196,5 +196,22 @@ TEST(TypeToShapeTest, ConvertMemRefToShape) {
   EXPECT_TRUE(ShapeUtil::Equal(converted, shape));
 }
 
+TEST(TypeToShapeTest, ConvertMemRefToShape2) {
+  Shape shape = ShapeUtil::MakeShapeWithLayout(PrimitiveType::C64, {2, 4, 3, 3},
+                                               {2, 3, 1, 0});
+  MLIRContext context;
+  mlir::Builder builder(&context);
+
+  StatusOr<mlir::Type> mlir_type =
+      ConvertShapeToType<MemRefType>(shape, builder);
+  ASSERT_TRUE(mlir_type.ok());
+  mlir::Type type = mlir_type.ConsumeValueOrDie();
+  Shape converted = TypeToShape(type);
+  EXPECT_TRUE(ShapeUtil::Equal(
+      converted, ShapeUtil::MakeShapeWithLayout(PrimitiveType::C64,
+                                                {2, 4, 3, 3}, {2, 3, 1, 0})));
+  EXPECT_TRUE(ShapeUtil::Equal(converted, shape));
+}
+
 }  // namespace
 }  // namespace xla
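
For reference, an editorial worked example (not text from the commit) of why the shape in ConvertMemRefToShape2 round-trips exactly: the strides implied by minor_to_major {2, 3, 1, 0} over dimensions {2, 4, 3, 3} sort back into the same order.

#include <cstdint>

// Strides accumulated from the minor-most dimension outward for dimensions
// {2, 4, 3, 3} with minor_to_major {2, 3, 1, 0}.
constexpr int64_t kStrideDim2 = 1;                // size 3, minor-most
constexpr int64_t kStrideDim3 = kStrideDim2 * 3;  // size 3
constexpr int64_t kStrideDim1 = kStrideDim3 * 3;  // size 4
constexpr int64_t kStrideDim0 = kStrideDim1 * 4;  // size 2, major-most

// Sorting the (stride, dim) pairs ascending gives (1,2), (3,3), (9,1), (36,0),
// i.e. minor_to_major {2, 3, 1, 0} again, so the memref -> Shape round trip in
// the test is expected to be exact and both ShapeUtil::Equal checks to pass.
static_assert(kStrideDim3 == 3 && kStrideDim1 == 9 && kStrideDim0 == 36, "");

int main() { return 0; }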

@@ -257,6 +257,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_casting_utils",
         "//tensorflow/compiler/xla/service:hlo_matchers",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:test",

@@ -691,8 +691,10 @@ XlaOp XlaBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
 
 StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                     ComparisonDirection direction) {
-  return Compare(shape, lhs, rhs, direction,
-                 Comparison::DefaultComparisonType(shape.element_type()));
+  TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(lhs));
+  return Compare(
+      shape, lhs, rhs, direction,
+      Comparison::DefaultComparisonType(operand_shape.element_type()));
 }
 
 StatusOr<XlaOp> XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
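
Editorial note on this hunk (not part of the commit): the shape argument to this overload is the comparison's inferred result shape, and HLO comparisons produce PRED elements, so taking the default comparison type from it ignores the operands. The fix asks the builder for the shape of lhs and uses its element type. The sketch below is a self-contained illustration of that distinction; the enums and the DefaultComparisonType function here are hypothetical stand-ins, not XLA's real definitions.

#include <iostream>

// Hypothetical stand-ins for illustration only; XLA's PrimitiveType and
// Comparison::Type enums are richer than this.
enum class ElementType { kPred, kS32, kU32, kF32 };
enum class ComparisonType { kSigned, kUnsigned, kFloat };

// Sketch of a "default comparison type" rule keyed on an element type.
ComparisonType DefaultComparisonType(ElementType type) {
  switch (type) {
    case ElementType::kS32:
      return ComparisonType::kSigned;
    case ElementType::kPred:
    case ElementType::kU32:
      return ComparisonType::kUnsigned;
    case ElementType::kF32:
      return ComparisonType::kFloat;
  }
  return ComparisonType::kUnsigned;
}

int main() {
  // Before the fix: the comparison's *result* element type (PRED) was used,
  // which can never yield a signed comparison.
  std::cout << (DefaultComparisonType(ElementType::kPred) ==
                ComparisonType::kSigned)
            << '\n';  // 0
  // After the fix: the *operand* element type (here S32) is used, which is
  // what the new ComparisonType test later in this commit asserts (kSigned
  // for an int32 Le).
  std::cout << (DefaultComparisonType(ElementType::kS32) ==
                ComparisonType::kSigned)
            << '\n';  // 1
  return 0;
}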

@@ -19,6 +19,8 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/debug_options_flags.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -1203,5 +1205,16 @@ TEST_F(XlaBuilderTest, AddFrontendAttribute) {
   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
   ExpectInstructionsAttributesMatch(*module, expected);
 }
+
+TEST_F(XlaBuilderTest, ComparisonType) {
+  XlaBuilder b(TestName());
+  (void)Le(ConstantR0<int32>(&b, 1), ConstantR0<int32>(&b, 2));
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  auto root = module->entry_computation()->root_instruction();
+  ASSERT_THAT(root, op::Compare(op::Constant(), op::Constant()));
+  EXPECT_EQ(Comparison::Type::kSigned,
+            DynCast<HloCompareInstruction>(root)->type());
+}
+
 }  // namespace
 }  // namespace xla