[XLA] Fix bug in ShapeUtil::ShapeIs that would lead to type inference errors.

PiperOrigin-RevId: 168323589
This commit is contained in:
Chris Leary 2017-09-11 20:10:28 -07:00 committed by TensorFlower Gardener
parent 9f848734fc
commit c0a4c7ffc2
3 changed files with 7 additions and 15 deletions

View File

@ -299,21 +299,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
/* static */ bool ShapeUtil::ShapeIs(const Shape& shape,
PrimitiveType element_type,
std::initializer_list<int64> dimensions) {
TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
if (shape.element_type() != element_type) {
return false;
}
if (shape.dimensions_size() != Rank(shape)) {
return false;
}
int64 i = 0;
for (int64 dimension : dimensions) {
if (shape.dimensions(i) != dimension) {
return false;
}
i += 1;
}
return true;
return Equal(shape, MakeShape(element_type, dimensions));
}
/* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {

View File

@ -296,6 +296,8 @@ class ShapeUtil {
// Shorthand for testing whether a shape is of a given element type and
// sequence of dimensions.
//
// DEPRECATED: Use Equal() instead.
static bool ShapeIs(const Shape& shape, PrimitiveType element_type,
std::initializer_list<int64> dimensions);

View File

@ -491,6 +491,10 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) {
ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2)));
}
TEST(ShapeUtilTest, ShapeIs) {
EXPECT_FALSE(ShapeUtil::ShapeIs(ShapeUtil::MakeShape(PRED, {2}), PRED, {}));
}
TEST(ShapeUtilTest, ForEachIndex) {
struct ShapeDimensionAndNumberInvocations {
std::vector<int64> dimensions;