Avoid integer overflow in InferConvolveShape().
PiperOrigin-RevId: 218183332
This commit is contained in:
parent
2db1f885bd
commit
f0d7172a30
@ -1571,6 +1571,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
|||||||
TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
|
TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
|
||||||
TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
|
TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
|
||||||
|
|
||||||
|
if (feature_group_count <= 0) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"feature_group_count must be a positive number, got %d",
|
||||||
|
feature_group_count);
|
||||||
|
}
|
||||||
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
|
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"Convolution with different element types: %s and %s.",
|
"Convolution with different element types: %s and %s.",
|
||||||
@ -1684,14 +1689,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
|||||||
const int64 kernel_output_features =
|
const int64 kernel_output_features =
|
||||||
rhs.dimensions(dnums.kernel_output_feature_dimension());
|
rhs.dimensions(dnums.kernel_output_feature_dimension());
|
||||||
|
|
||||||
if (input_features != kernel_input_features * feature_group_count) {
|
if (input_features % feature_group_count != 0 ||
|
||||||
|
input_features / feature_group_count != kernel_input_features) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"Expected LHS feature dimension (value %d) to match RHS "
|
"Expected LHS feature dimension (value %d) to be a multiple of "
|
||||||
"input feature dimension * feature_group_count (value %d * %d = %d); "
|
"feature_group_count (value %d), and LHS feature dimension / "
|
||||||
|
"feature_group_count = RHS feature dimension (value %d); "
|
||||||
"got <conv>(%s, %s)\n"
|
"got <conv>(%s, %s)\n"
|
||||||
"Dimension numbers: {%s}.",
|
"Dimension numbers: {%s}.",
|
||||||
input_features, kernel_input_features, feature_group_count,
|
input_features, feature_group_count, kernel_input_features,
|
||||||
kernel_input_features * feature_group_count,
|
|
||||||
ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
|
ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
|
||||||
dnums.DebugString());
|
dnums.DebugString());
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user