[XLA] Enhancement to source tensor indexing.

Change ElementalIrEmitter::ElementwiseSourceIndex to use the target index as
a source index for the case where the two tensors have the same shape but
different element types.
This improves the implementation of fusion kernels by avoiding the calculation
of the dimensional indices from the linear index for the source tensors.

PiperOrigin-RevId: 177036769
This commit is contained in:
A. Unique TensorFlower 2017-11-27 10:25:14 -08:00 committed by TensorFlower Gardener
parent 4fbf63a8ba
commit b115a9fc73
4 changed files with 18 additions and 1 deletions

View File

@ -905,7 +905,7 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
// If no implicit broadcast is needed for this operand, returns the target
// index as the source index.
if (ShapeUtil::Compatible(operand_shape, hlo.shape())) {
if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape())) {
return target_index;
}

View File

@ -553,6 +553,16 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
return SameDimensions(lhs, rhs) && SameElementType(lhs, rhs);
}
/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
const Shape& rhs) {
if (lhs.element_type() == TUPLE) {
return rhs.element_type() == TUPLE &&
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
CompatibleIgnoringElementType);
}
return SameDimensions(lhs, rhs);
}
/* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
int64 dimension_number) {
return shape.dimensions(GetDimensionNumber(shape, dimension_number));

View File

@ -190,6 +190,11 @@ class ShapeUtil {
// compatibility.
static bool Compatible(const Shape& lhs, const Shape& rhs);
// Returns true if the rank and dimension sizes are identical. Element type
// and layout are ignored. Tuple elements are compared recursively for
// compatibility.
static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);
// Returns whether the lhs and rhs shapes are identical protobufs.
static bool Equal(const Shape& lhs, const Shape& rhs);

View File

@ -145,6 +145,7 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) {
Shape tuple2 = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})});
EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2));
EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2));
}
TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) {
@ -153,6 +154,7 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) {
Shape tuple2 = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(S32, {3, 2})});
EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2));
EXPECT_TRUE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2));
}
TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentDimensions) {