[XLA] Enhancement to source tensor indexing.
Change ElementalIrEmitter::ElementwiseSourceIndex to use the target index as a source index for the case where the two tensors have the same shape but different element types. This improves the implementation of fusion kernels by avoiding the calculation of the dimensional indices from the linear index for the source tensors. PiperOrigin-RevId: 177036769
This commit is contained in:
parent
4fbf63a8ba
commit
b115a9fc73
@ -905,7 +905,7 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
|
||||
|
||||
// If no implicit broadcast is needed for this operand, returns the target
|
||||
// index as the source index.
|
||||
if (ShapeUtil::Compatible(operand_shape, hlo.shape())) {
|
||||
if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape())) {
|
||||
return target_index;
|
||||
}
|
||||
|
||||
|
@ -553,6 +553,16 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
|
||||
return SameDimensions(lhs, rhs) && SameElementType(lhs, rhs);
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
|
||||
const Shape& rhs) {
|
||||
if (lhs.element_type() == TUPLE) {
|
||||
return rhs.element_type() == TUPLE &&
|
||||
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
|
||||
CompatibleIgnoringElementType);
|
||||
}
|
||||
return SameDimensions(lhs, rhs);
|
||||
}
|
||||
|
||||
/* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
|
||||
int64 dimension_number) {
|
||||
return shape.dimensions(GetDimensionNumber(shape, dimension_number));
|
||||
|
@ -190,6 +190,11 @@ class ShapeUtil {
|
||||
// compatibility.
|
||||
static bool Compatible(const Shape& lhs, const Shape& rhs);
|
||||
|
||||
// Returns true if the rank and dimension sizes are identical. Element type
|
||||
// and layout are ignored. Tuple elements are compared recursively for
|
||||
// compatibility.
|
||||
static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);
|
||||
|
||||
// Returns whether the lhs and rhs shapes are identical protobufs.
|
||||
static bool Equal(const Shape& lhs, const Shape& rhs);
|
||||
|
||||
|
@ -145,6 +145,7 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) {
|
||||
Shape tuple2 = ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})});
|
||||
EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2));
|
||||
EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) {
|
||||
@ -153,6 +154,7 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) {
|
||||
Shape tuple2 = ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(S32, {3, 2})});
|
||||
EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2));
|
||||
EXPECT_TRUE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentDimensions) {
|
||||
|
Loading…
Reference in New Issue
Block a user