diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index f2c16c9dce2..2f8f092303e 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1571,6 +1571,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); + if (feature_group_count <= 0) { + return InvalidArgument( + "feature_group_count must be a positive number, got %d", + feature_group_count); + } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Convolution with different element types: %s and %s.", @@ -1684,14 +1689,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 kernel_output_features = rhs.dimensions(dnums.kernel_output_feature_dimension()); - if (input_features != kernel_input_features * feature_group_count) { + if (input_features % feature_group_count != 0 || + input_features / feature_group_count != kernel_input_features) { return InvalidArgument( - "Expected LHS feature dimension (value %d) to match RHS " - "input feature dimension * feature_group_count (value %d * %d = %d); " + "Expected LHS feature dimension (value %d) to be a multiple of " + "feature_group_count (value %d), and LHS feature dimension / " + "feature_group_count = RHS feature dimension (value %d); " "got (%s, %s)\n" "Dimension numbers: {%s}.", - input_features, kernel_input_features, feature_group_count, - kernel_input_features * feature_group_count, + input_features, feature_group_count, kernel_input_features, ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), dnums.DebugString()); }