[XLA] Fix bug in ShapeUtil::ShapeIs that would lead to type inference errors.
PiperOrigin-RevId: 168323589
This commit is contained in:
parent
9f848734fc
commit
c0a4c7ffc2
@ -299,21 +299,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
|
||||
/* static */ bool ShapeUtil::ShapeIs(const Shape& shape,
|
||||
PrimitiveType element_type,
|
||||
std::initializer_list<int64> dimensions) {
|
||||
TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
|
||||
if (shape.element_type() != element_type) {
|
||||
return false;
|
||||
}
|
||||
if (shape.dimensions_size() != Rank(shape)) {
|
||||
return false;
|
||||
}
|
||||
int64 i = 0;
|
||||
for (int64 dimension : dimensions) {
|
||||
if (shape.dimensions(i) != dimension) {
|
||||
return false;
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
return true;
|
||||
return Equal(shape, MakeShape(element_type, dimensions));
|
||||
}
|
||||
|
||||
/* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
|
||||
|
@ -296,6 +296,8 @@ class ShapeUtil {
|
||||
|
||||
// Shorthand for testing whether a shape is of a given element type and
|
||||
// sequence of dimensions.
|
||||
//
|
||||
// DEPRECATED: Use Equal() instead.
|
||||
static bool ShapeIs(const Shape& shape, PrimitiveType element_type,
|
||||
std::initializer_list<int64> dimensions);
|
||||
|
||||
|
@ -491,6 +491,10 @@ TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) {
|
||||
ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2)));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, ShapeIs) {
|
||||
EXPECT_FALSE(ShapeUtil::ShapeIs(ShapeUtil::MakeShape(PRED, {2}), PRED, {}));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, ForEachIndex) {
|
||||
struct ShapeDimensionAndNumberInvocations {
|
||||
std::vector<int64> dimensions;
|
||||
|
Loading…
Reference in New Issue
Block a user