[XLA] Add an optional preferred_element_type argument to Dot/Conv builder methods to enable generating HLOs that have a wider accumulation type than the default shape inference result.

PiperOrigin-RevId: 343546767
Change-Id: Iba19e8c9a9748d0eb0e738432f714f43b4575a38
Ce Zheng 2020-11-20 12:48:04 -08:00 committed by TensorFlower Gardener
parent 7fa0d80d06
commit 06f7a3cb81
24 changed files with 752 additions and 305 deletions
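
A minimal usage sketch of the new argument (illustrative only; it mirrors the DotWithPreferredElementType test added in this change and assumes the standard XLA client headers). Without preferred_element_type, shape inference would pick the higher-precision operand type (U16 here); passing U32 requests a wider accumulation type in the generated HLO.

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Sketch: dot of narrow unsigned-integer operands accumulated in U32.
xla::XlaBuilder b("preferred_element_type_example");
auto lhs = xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::U8, {2, 3}), "lhs");
auto rhs = xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::U16, {3, 2}), "rhs");
xla::DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(0);
// Result shape is u32[2, 2] instead of the default u16[2, 2].
xla::DotGeneral(lhs, rhs, dnums, /*precision_config=*/nullptr,
                /*preferred_element_type=*/xla::U32);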

@ -1266,7 +1266,8 @@ StatusOr<XlaOp> XlaBuilder::GetTupleElementInternal(const Shape& shape,
}
XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
const PrecisionConfig* precision_config) {
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
@ -1278,15 +1279,17 @@ XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
});
}
XlaOp XlaBuilder::DotGeneral(XlaOp lhs, XlaOp rhs,
const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config) {
XlaOp XlaBuilder::DotGeneral(
XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferDotOpShape(*lhs_shape, *rhs_shape,
dimension_numbers));
TF_ASSIGN_OR_RETURN(
Shape shape,
ShapeInference::InferDotOpShape(
*lhs_shape, *rhs_shape, dimension_numbers, preferred_element_type));
return DotGeneralInternal(shape, lhs, rhs, dimension_numbers,
precision_config);
});
@ -1353,28 +1356,33 @@ Status XlaBuilder::VerifyConvolution(
XlaOp XlaBuilder::Conv(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) {
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return ConvWithGeneralDimensions(
lhs, rhs, window_strides, padding,
CreateDefaultConvDimensionNumbers(window_strides.size()),
feature_group_count, batch_group_count, precision_config);
feature_group_count, batch_group_count, precision_config,
preferred_element_type);
}
XlaOp XlaBuilder::ConvWithGeneralPadding(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) {
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return ConvGeneral(lhs, rhs, window_strides, padding,
CreateDefaultConvDimensionNumbers(window_strides.size()),
feature_group_count, batch_group_count, precision_config);
feature_group_count, batch_group_count, precision_config,
preferred_element_type);
}
XlaOp XlaBuilder::ConvWithGeneralDimensions(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) {
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
@ -1402,7 +1410,8 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions(
MakePadding(base_area_dimensions, window_dimensions,
window_strides, padding),
dimension_numbers, feature_group_count,
batch_group_count, precision_config);
batch_group_count, precision_config,
preferred_element_type);
});
}
@ -1411,10 +1420,12 @@ XlaOp XlaBuilder::ConvGeneral(
absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) {
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
dimension_numbers, feature_group_count,
batch_group_count, precision_config);
batch_group_count, precision_config,
preferred_element_type);
}
XlaOp XlaBuilder::ConvGeneralDilated(
@ -1423,7 +1434,8 @@ XlaOp XlaBuilder::ConvGeneralDilated(
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) {
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
@ -1442,10 +1454,11 @@ XlaOp XlaBuilder::ConvGeneralDilated(
ShapeInference::InferWindowFromDimensions(
window_dimensions, window_strides, padding,
lhs_dilation, rhs_dilation));
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferConvolveShape(
*lhs_shape, *rhs_shape, feature_group_count,
batch_group_count, window, dimension_numbers));
TF_ASSIGN_OR_RETURN(
Shape shape,
ShapeInference::InferConvolveShape(
*lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
window, dimension_numbers, preferred_element_type));
return ConvGeneralDilatedInternal(shape, lhs, rhs, window, window_strides,
padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count,
@ -1459,7 +1472,8 @@ StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction(
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type) {
const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type) {
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
std::vector<int64> window_dimensions(
@ -1472,10 +1486,11 @@ StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction(
TF_ASSIGN_OR_RETURN(Window window, ShapeInference::InferWindowFromDimensions(
window_dimensions, window_strides,
padding, lhs_dilation, rhs_dilation));
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferConvolveShape(
*lhs_shape, *rhs_shape, feature_group_count,
batch_group_count, window, dimension_numbers));
TF_ASSIGN_OR_RETURN(
Shape shape,
ShapeInference::InferConvolveShape(
*lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
window, dimension_numbers, preferred_element_type));
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
@ -1499,14 +1514,15 @@ XlaOp XlaBuilder::DynamicConvInputGrad(
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type) {
const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(
HloInstructionProto instr,
DynamicConvInstruction(lhs, rhs, window_strides, padding, lhs_dilation,
rhs_dilation, dimension_numbers,
feature_group_count, batch_group_count,
precision_config, padding_type));
DynamicConvInstruction(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, batch_group_count,
precision_config, padding_type, preferred_element_type));
instr.set_custom_call_target("DynamicConvolutionInputGrad");
@ -1521,14 +1537,16 @@ XlaOp XlaBuilder::DynamicConvKernelGrad(
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type) {
const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(
HloInstructionProto instr,
DynamicConvInstruction(activations, gradients, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers,
feature_group_count, batch_group_count,
precision_config, padding_type));
precision_config, padding_type,
preferred_element_type));
instr.set_custom_call_target("DynamicConvolutionKernelGrad");
// The gradient of kernel has kernel shape and shouldn't have any dynamic
@ -1545,14 +1563,15 @@ XlaOp XlaBuilder::DynamicConvForward(
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type) {
const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(
HloInstructionProto instr,
DynamicConvInstruction(lhs, rhs, window_strides, padding, lhs_dilation,
rhs_dilation, dimension_numbers,
feature_group_count, batch_group_count,
precision_config, padding_type));
DynamicConvInstruction(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, batch_group_count,
precision_config, padding_type, preferred_element_type));
instr.set_custom_call_target("DynamicConvolutionForward");
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, {lhs, rhs});
@ -4074,44 +4093,49 @@ XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) {
}
XlaOp Dot(const XlaOp lhs, const XlaOp rhs,
const PrecisionConfig* precision_config) {
return lhs.builder()->Dot(lhs, rhs, precision_config);
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return lhs.builder()->Dot(lhs, rhs, precision_config, preferred_element_type);
}
XlaOp DotGeneral(const XlaOp lhs, const XlaOp rhs,
const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config) {
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers,
precision_config);
precision_config, preferred_element_type);
}
XlaOp Conv(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) {
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
feature_group_count, batch_group_count,
precision_config);
precision_config, preferred_element_type);
}
XlaOp ConvWithGeneralPadding(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) {
XlaOp ConvWithGeneralPadding(
const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return lhs.builder()->ConvWithGeneralPadding(
lhs, rhs, window_strides, padding, feature_group_count, batch_group_count,
precision_config);
precision_config, preferred_element_type);
}
XlaOp ConvWithGeneralDimensions(
const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) {
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return lhs.builder()->ConvWithGeneralDimensions(
lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
batch_group_count, precision_config);
batch_group_count, precision_config, preferred_element_type);
}
XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs,
@ -4119,10 +4143,11 @@ XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs,
absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) {
return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding,
dimension_numbers, feature_group_count,
batch_group_count, precision_config);
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return lhs.builder()->ConvGeneral(
lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
batch_group_count, precision_config, preferred_element_type);
}
XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
@ -4132,26 +4157,27 @@ XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) {
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return lhs.builder()->ConvGeneralDilated(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, batch_group_count,
precision_config);
precision_config, preferred_element_type);
}
XlaOp DynamicConvInputGrad(XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config,
PaddingType padding_type) {
XlaOp DynamicConvInputGrad(
XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type) {
return lhs.builder()->DynamicConvInputGrad(
input_sizes, lhs, rhs, window_strides, padding, lhs_dilation,
rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
precision_config, padding_type);
precision_config, padding_type, preferred_element_type);
}
XlaOp DynamicConvKernelGrad(
@ -4160,11 +4186,12 @@ XlaOp DynamicConvKernelGrad(
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type) {
const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type) {
return activations.builder()->DynamicConvKernelGrad(
activations, gradients, window_strides, padding, lhs_dilation,
rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
precision_config, padding_type);
precision_config, padding_type, preferred_element_type);
}
XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs,
@ -4175,11 +4202,12 @@ XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config,
PaddingType padding_type) {
PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type) {
return lhs.builder()->DynamicConvForward(
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
dimension_numbers, feature_group_count, batch_group_count,
precision_config, padding_type);
precision_config, padding_type, preferred_element_type);
}
XlaOp Fft(const XlaOp operand, FftType fft_type,

@ -521,56 +521,63 @@ class XlaBuilder {
XlaOp tuple_data,
int64 index);
XlaOp Dot(XlaOp lhs, XlaOp rhs,
const PrecisionConfig* precision_config = nullptr);
XlaOp Dot(
XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config = nullptr);
XlaOp DotGeneral(
XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
Padding padding, int64 feature_group_count = 1,
int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr);
XlaOp Conv(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
Padding padding, int64 feature_group_count = 1,
int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp ConvWithGeneralPadding(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count = 1, int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr);
const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp ConvWithGeneralDimensions(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1, int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr);
const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1, int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr);
XlaOp ConvGeneral(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1, int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp ConvGeneralDilated(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1,
int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr);
XlaOp ConvGeneralDilated(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1, int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp DynamicConvForward(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config,
PaddingType padding_type);
XlaOp DynamicConvForward(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp DynamicConvInputGrad(
XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
@ -580,7 +587,8 @@ class XlaBuilder {
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type);
const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp DynamicConvKernelGrad(
XlaOp activations, XlaOp gradients,
@ -590,7 +598,8 @@ class XlaBuilder {
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type);
const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
StatusOr<HloInstructionProto> DynamicConvInstruction(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
@ -599,7 +608,8 @@ class XlaBuilder {
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type);
const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
virtual StatusOr<XlaOp> ConvGeneralDilatedInternal(
const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
@ -1098,10 +1108,12 @@ class XlaBuilder {
ComparisonDirection direction,
Comparison::Type compare_type);
friend XlaOp Dot(XlaOp lhs, XlaOp rhs,
const PrecisionConfig* precision_config);
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type);
friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
const DotDimensionNumbers& dimension_number,
const PrecisionConfig* precision_config);
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type);
virtual StatusOr<XlaOp> DotGeneralInternal(
const Shape& shape, XlaOp lhs, XlaOp rhs,
const DotDimensionNumbers& dimension_number,
@ -1109,23 +1121,27 @@ class XlaBuilder {
friend XlaOp Conv(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config);
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type);
friend XlaOp ConvWithGeneralPadding(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config);
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type);
friend XlaOp ConvWithGeneralDimensions(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_confige);
friend XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config);
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type);
friend XlaOp ConvGeneral(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type);
friend XlaOp DynamicConvForward(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
@ -1133,7 +1149,8 @@ class XlaBuilder {
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type);
const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type);
friend XlaOp DynamicConvKernelGrad(
XlaOp activations, XlaOp gradients,
absl::Span<const int64> window_strides,
@ -1142,7 +1159,8 @@ class XlaBuilder {
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type);
const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type);
friend XlaOp DynamicConvInputGrad(
XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides,
@ -1151,7 +1169,8 @@ class XlaBuilder {
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type);
const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type);
friend XlaOp ConvKernelGrad(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
@ -1160,7 +1179,8 @@ class XlaBuilder {
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config);
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type);
friend XlaOp ConvGeneralDilated(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
@ -1169,7 +1189,8 @@ class XlaBuilder {
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config);
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type);
friend XlaOp Fft(XlaOp operand, FftType fft_type,
absl::Span<const int64> fft_length);
friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
@ -1813,28 +1834,31 @@ XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction);
// Enqueues a dot instruction onto the computation.
XlaOp Dot(XlaOp lhs, XlaOp rhs,
const PrecisionConfig* precision_config = nullptr);
const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
// Enqueues a general dot instruction onto the computation.
XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config = nullptr);
XlaOp DotGeneral(
XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
Padding padding, int64 feature_group_count = 1,
int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr);
XlaOp Conv(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
Padding padding, int64 feature_group_count = 1, int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
XlaOp ConvWithGeneralPadding(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count = 1,
int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr);
XlaOp ConvWithGeneralPadding(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
int64 feature_group_count = 1, int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
@ -1842,47 +1866,48 @@ XlaOp ConvWithGeneralDimensions(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1, int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr);
const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1, int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr);
XlaOp ConvGeneral(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1, int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
XlaOp ConvGeneralDilated(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1,
int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr);
XlaOp ConvGeneralDilated(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count = 1, int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp DynamicConvForward(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config,
PaddingType padding_type);
XlaOp DynamicConvForward(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp DynamicConvInputGrad(XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config,
PaddingType padding_type);
XlaOp DynamicConvInputGrad(
XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
XlaOp DynamicConvKernelGrad(
XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
@ -1890,7 +1915,8 @@ XlaOp DynamicConvKernelGrad(
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config, PaddingType padding_type);
const PrecisionConfig* precision_config, PaddingType padding_type,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.

@ -1066,6 +1066,56 @@ TEST_F(XlaBuilderTest, DynamicTranspose) {
<< result_shape;
}
TEST_F(XlaBuilderTest, DotWithPreferredElementType) {
XlaBuilder b(TestName());
Shape p0_shape = ShapeUtil::MakeShape(U8, {2, 3});
Shape p1_shape = ShapeUtil::MakeShape(U16, {3, 2});
auto p0 = Parameter(&b, 0, p0_shape, "p0");
auto p1 = Parameter(&b, 1, p1_shape, "p1");
DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(0);
DotGeneral(p0, p1, dnums, /*precision_config=*/nullptr,
/*preferred_element_type=*/U32);
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
const Shape& result_shape =
module->entry_computation()->root_instruction()->shape();
ASSERT_TRUE(
ShapeUtil::Equal(ShapeUtil::MakeShape(U32, {2, 2}), result_shape));
}
TEST_F(XlaBuilderTest, ConvolutionWithPreferredElementType) {
XlaBuilder b(TestName());
Shape p0_shape = ShapeUtil::MakeShape(S16, {1, 2, 2, 128});
Shape p1_shape = ShapeUtil::MakeShape(S8, {2, 2, 128, 8});
auto p0 = Parameter(&b, 0, p0_shape, "p0");
auto p1 = Parameter(&b, 1, p1_shape, "p1");
ConvolutionDimensionNumbers dnums;
dnums.set_input_batch_dimension(0);
dnums.set_output_batch_dimension(0);
dnums.add_input_spatial_dimensions(1);
dnums.add_output_spatial_dimensions(1);
dnums.add_input_spatial_dimensions(2);
dnums.add_output_spatial_dimensions(2);
dnums.set_input_feature_dimension(3);
dnums.set_output_feature_dimension(3);
dnums.add_kernel_spatial_dimensions(0);
dnums.add_kernel_spatial_dimensions(1);
dnums.set_kernel_input_feature_dimension(2);
dnums.set_kernel_output_feature_dimension(3);
ConvWithGeneralDimensions(p0, p1, {1, 1}, Padding::kValid, dnums,
/*feature_group_count=*/1, /*batch_group_count=*/1,
/*precision_config=*/nullptr,
/*preferred_element_type=*/S32);
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
const Shape& result_shape =
module->entry_computation()->root_instruction()->shape();
ASSERT_TRUE(
ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {1, 1, 1, 8}), result_shape));
}
TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) {
XlaBuilder b(TestName());
AfterAll(&b, {CreateToken(&b), ConstantR0<float>(&b, 1.0)});

@ -108,7 +108,8 @@ void BuildOpsSubmodule(py::module* m) {
py::arg("lhs_dilation"), py::arg("rhs_dilation"),
py::arg("dimension_numbers"), py::arg("feature_group_count") = 1,
py::arg("batch_group_count") = 1,
py::arg("precision_config") = nullptr);
py::arg("precision_config") = nullptr,
py::arg("preferred_element_type") = absl::nullopt);
ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"),
py::arg("new_element_type"));
ops.def(
@ -136,9 +137,11 @@ void BuildOpsSubmodule(py::module* m) {
py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false);
ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
py::arg("precision_config") = nullptr);
py::arg("precision_config") = nullptr,
py::arg("preferred_element_type") = absl::nullopt);
ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),
py::arg("dimension_numbers"), py::arg("precision_config") = nullptr);
py::arg("dimension_numbers"), py::arg("precision_config") = nullptr,
py::arg("preferred_element_type") = absl::nullopt);
ops.def("DynamicSlice",
static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
absl::Span<const int64>)>(&DynamicSlice),

@ -465,7 +465,8 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
const Shape& shape =
ShapeInference::InferConvolveShape(
lhs_literal.shape(), rhs_literal.shape(),
/*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums)
/*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
/*preferred_element_type=*/absl::nullopt)
.ConsumeValueOrDie();
HloInstruction* lhs_instruction =

@ -1798,10 +1798,10 @@ StatusOr<bool> AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot(
ShapeUtil::DropDegenerateDimensions(rhs_shape),
dot->mutable_operand(1)))
: dot->mutable_operand(1);
TF_ASSIGN_OR_RETURN(auto new_dot, MakeDotHlo(new_lhs, new_rhs, new_dnums,
dot->precision_config()));
// TODO(b/165824019): Add an optional preferred element type to MakeDotHlo.
new_dot->mutable_shape()->set_element_type(dot->shape().element_type());
TF_ASSIGN_OR_RETURN(
auto new_dot,
MakeDotHlo(new_lhs, new_rhs, new_dnums, dot->precision_config(),
/*preferred_element_type=*/dot->shape().element_type()));
if (ShapeUtil::Compatible(dot->shape(), new_dot->shape())) {
TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_dot));
} else {
@ -4678,10 +4678,10 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
}
}
TF_ASSIGN_OR_RETURN(
auto new_dot, MakeDotHlo(lhs, rhs, new_dnums, dot->precision_config()));
auto new_dot,
MakeDotHlo(lhs, rhs, new_dnums, dot->precision_config(),
/*preferred_element_type=*/dot->shape().element_type()));
dot->SetupDerivedInstruction(new_dot);
// TODO(b/165824019): Add an optional preferred element type to MakeDotHlo.
new_dot->mutable_shape()->set_element_type(dot->shape().element_type());
if (reduce_dims.empty()) {
return ReplaceInstruction(hlo, new_dot);
}
@ -5312,10 +5312,13 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
if (!reverse_dimensions.empty()) {
TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions));
}
TF_ASSIGN_OR_RETURN(HloInstruction * new_convolution,
MakeConvolveHlo(kernel, input, /*feature_group_count=*/1,
/*batch_group_count=*/1, swapped_window,
swapped_dnums, precision_config));
TF_ASSIGN_OR_RETURN(
HloInstruction * new_convolution,
MakeConvolveHlo(
kernel, input, /*feature_group_count=*/1,
/*batch_group_count=*/1, swapped_window, swapped_dnums,
precision_config,
/*preferred_element_type=*/convolution->shape().element_type()));
convolution->SetupDerivedInstruction(new_convolution);
TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, new_convolution));

@ -3785,9 +3785,11 @@ TEST_P(ConvInputPaddingTest, DoTest) {
ParseWindow(absl::StrCat("size=3x3 ", testcase.orig_conv_window))
.ValueOrDie();
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(),
/*feature_group_count=*/1,
/*batch_group_count=*/1, window, dnums)
ShapeInference::InferConvolveShape(
lhs_pad->shape(), filter->shape(),
/*feature_group_count=*/1,
/*batch_group_count=*/1, window, dnums,
/*preferred_element_type=*/absl::nullopt)
.ValueOrDie(),
lhs_pad, filter, /*feature_group_count=*/1, /*batch_group_count=*/1,
window, dnums, DefaultPrecisionConfig(2)));
@ -3902,9 +3904,11 @@ TEST_P(ConvFilterPaddingTest, DoIt) {
precision_config.add_operand_precision(PrecisionConfig::HIGHEST);
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(),
/*feature_group_count=*/1,
/*batch_group_count=*/1, window, dnums)
ShapeInference::InferConvolveShape(
input->shape(), rhs_pad->shape(),
/*feature_group_count=*/1,
/*batch_group_count=*/1, window, dnums,
/*preferred_element_type=*/absl::nullopt)
.ValueOrDie(),
input, rhs_pad, /*feature_group_count=*/1, /*batch_group_count=*/1,
window, dnums, precision_config));
@ -4050,7 +4054,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));
Shape out_shape = ShapeInference::InferConvolveShape(
in_shape, f_shape, /*feature_group_count=*/1,
/*batch_group_count=*/1, window, dnums)
/*batch_group_count=*/1, window, dnums,
/*preferred_element_type=*/absl::nullopt)
.ValueOrDie();
if (options.output_minor_to_major_layout) {
out_shape = ShapeUtil::MakeShapeWithLayout(F32, out_shape.dimensions(),

@ -87,9 +87,11 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot(
0,
new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size());
TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
MakeDotHlo(new_lhs, new_rhs, new_dim_numbers,
batch_dot->precision_config()));
TF_ASSIGN_OR_RETURN(
HloInstruction * new_dot,
MakeDotHlo(new_lhs, new_rhs, new_dim_numbers,
batch_dot->precision_config(),
/*preferred_element_type=*/batch_dot->shape().element_type()));
TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped,
MakeReshapeHlo(batch_dot->shape(), new_dot));

@ -299,9 +299,11 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
window_dim->set_window_reversal(false);
window_dim->set_window_dilation(1);
HloInstruction* new_convolution =
MakeConvolveHlo(activation, filter, convolution->feature_group_count(),
/*batch_group_count=*/1, window, dim_numbers,
convolution->precision_config())
MakeConvolveHlo(
activation, filter, convolution->feature_group_count(),
/*batch_group_count=*/1, window, dim_numbers,
convolution->precision_config(),
/*preferred_element_type=*/convolution->shape().element_type())
.ValueOrDie();
convolution->SetupDerivedInstruction(new_convolution);
TF_CHECK_OK(computation_->ReplaceInstruction(
@ -650,9 +652,11 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
window_dim->set_window_reversal(false);
window_dim->set_window_dilation(1);
HloInstruction* new_convolution =
MakeConvolveHlo(activation, filter, /*feature_group_count=*/1,
/*batch_group_count=*/1, window, dim_numbers,
convolution->precision_config())
MakeConvolveHlo(
activation, filter, /*feature_group_count=*/1,
/*batch_group_count=*/1, window, dim_numbers,
convolution->precision_config(),
/*preferred_element_type=*/convolution->shape().element_type())
.ValueOrDie();
convolution->SetupDerivedInstruction(new_convolution);
changed_ = true;

@ -142,7 +142,8 @@ CreateShardedConvForDotGeneralConvolution(
ShapeInference::InferConvolveShape(
sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
/*feature_group_count=*/conv.feature_group_count(),
/*batch_group_count=*/conv.batch_group_count(), window, conv_dnums));
/*batch_group_count=*/conv.batch_group_count(), window, conv_dnums,
/*preferred_element_type=*/conv.shape().element_type()));
*sharded_conv_shape.mutable_layout() = conv.shape().layout();
return HloInstruction::CreateConvolve(
sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo,

View File

@ -110,7 +110,8 @@ TEST_F(GpuConvRewriterTest, BackwardFilterConvolve) {
ShapeInference::InferConvolveShape(
activations->shape(), gradients->shape(), /*feature_group_count=*/1,
/*batch_group_count=*/1, conv_window,
tf_default_dnums_for_backward_filter_)
tf_default_dnums_for_backward_filter_,
/*preferred_element_type=*/absl::nullopt)
.ConsumeValueOrDie(),
activations, gradients, /*feature_group_count=*/1,
/*batch_group_count=*/1, conv_window,
@ -150,7 +151,8 @@ TEST_F(GpuConvRewriterTest,
ShapeInference::InferConvolveShape(
activations->shape(), gradients->shape(), /*feature_group_count=*/1,
/*batch_group_count=*/1, conv_window,
tf_default_dnums_for_backward_filter_)
tf_default_dnums_for_backward_filter_,
/*preferred_element_type=*/absl::nullopt)
.ConsumeValueOrDie(),
activations, gradients, /*feature_group_count=*/1,
/*batch_group_count=*/1, conv_window,
@ -292,11 +294,12 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveEvenPadding) {
DefaultPrecisionConfig(2)));
// Verify the convolution's shape is consistent with ShapeInference.
CHECK(ShapeUtil::Compatible(
conv->shape(), ShapeInference::InferConvolveShape(
output->shape(), reverse_kernel->shape(),
/*feature_group_count=*/1, /*batch_group_count=*/1,
conv_window, conv_dnums)
.ValueOrDie()));
conv->shape(),
ShapeInference::InferConvolveShape(
output->shape(), reverse_kernel->shape(),
/*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
conv_dnums, /*preferred_element_type=*/absl::nullopt)
.ValueOrDie()));
auto module = CreateNewVerifiedModule();
HloComputation* entry_computation =
@ -337,10 +340,12 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolve1x1Filter) {
conv_window.mutable_dimensions(1)->set_base_dilation(2);
builder.AddInstruction(HloInstruction::CreateConvolve(
ShapeInference::InferConvolveShape(output->shape(), kernel->shape(),
/*feature_group_count=*/1,
/*batch_group_count=*/1, conv_window,
tf_default_dnums_for_backward_input_)
ShapeInference::InferConvolveShape(
output->shape(), kernel->shape(),
/*feature_group_count=*/1,
/*batch_group_count=*/1, conv_window,
tf_default_dnums_for_backward_input_,
/*preferred_element_type=*/absl::nullopt)
.ConsumeValueOrDie(),
/*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
/*batch_group_count=*/1, conv_window,
@ -374,7 +379,8 @@ TEST_F(GpuConvRewriterTest,
ShapeInference::InferConvolveShape(
output->shape(), kernel->shape(), /*feature_group_count=*/1,
/*batch_group_count=*/1, default_conv_window_,
tf_default_dnums_for_backward_input_)
tf_default_dnums_for_backward_input_,
/*preferred_element_type=*/absl::nullopt)
.ConsumeValueOrDie(),
/*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
/*batch_group_count=*/1, default_conv_window_,
@ -431,7 +437,8 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) {
conv->shape(), ShapeInference::InferConvolveShape(
output->shape(), reverse_kernel->shape(),
/*feature_group_count=*/1, /*batch_group_count=*/1,
conv_window, tf_default_dnums_for_backward_input_)
conv_window, tf_default_dnums_for_backward_input_,
/*preferred_element_type=*/absl::nullopt)
.ValueOrDie()));
auto module = CreateNewVerifiedModule();
@ -481,7 +488,8 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
conv->shape(), ShapeInference::InferConvolveShape(
output->shape(), reverse_kernel->shape(),
/*feature_group_count=*/1, /*batch_group_count=*/1,
conv_window, tf_default_dnums_for_backward_input_)
conv_window, tf_default_dnums_for_backward_input_,
/*preferred_element_type=*/absl::nullopt)
.ValueOrDie()));
auto module = CreateNewVerifiedModule();
@ -535,7 +543,8 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) {
conv->shape(), ShapeInference::InferConvolveShape(
output->shape(), reverse_kernel->shape(),
/*feature_group_count=*/1, /*batch_group_count=*/1,
conv_window, tf_default_dnums_for_backward_input_)
conv_window, tf_default_dnums_for_backward_input_,
/*preferred_element_type=*/absl::nullopt)
.ValueOrDie()));
auto module = CreateNewVerifiedModule();
@ -590,7 +599,8 @@ TEST_F(GpuConvRewriterTest,
conv->shape(), ShapeInference::InferConvolveShape(
output->shape(), reverse_kernel->shape(),
/*feature_group_count=*/1, /*batch_group_count=*/1,
conv_window, tf_default_dnums_for_backward_input_)
conv_window, tf_default_dnums_for_backward_input_,
/*preferred_element_type=*/absl::nullopt)
.ValueOrDie()));
auto module = CreateNewVerifiedModule();

@ -94,13 +94,15 @@ StatusOr<HloInstruction*> MakeConvolveHlo(
HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
int64 batch_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers,
const PrecisionConfig& precision_config) {
const PrecisionConfig& precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
HloComputation* computation = lhs->parent();
CHECK_EQ(computation, rhs->parent());
TF_ASSIGN_OR_RETURN(Shape convolve_shape,
ShapeInference::InferConvolveShape(
lhs->shape(), rhs->shape(), feature_group_count,
batch_group_count, window, dimension_numbers));
TF_ASSIGN_OR_RETURN(
Shape convolve_shape,
ShapeInference::InferConvolveShape(
lhs->shape(), rhs->shape(), feature_group_count, batch_group_count,
window, dimension_numbers, preferred_element_type));
return computation->AddInstruction(HloInstruction::CreateConvolve(
convolve_shape, lhs, rhs, feature_group_count, batch_group_count, window,
dimension_numbers, precision_config));
@ -281,14 +283,17 @@ HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
HloInstruction::CreateIota(shape, iota_dimension));
}
StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dim_numbers,
const PrecisionConfig& precision_config) {
StatusOr<HloInstruction*> MakeDotHlo(
HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dim_numbers,
const PrecisionConfig& precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
HloComputation* computation = lhs->parent();
CHECK_EQ(computation, rhs->parent());
TF_ASSIGN_OR_RETURN(
Shape dot_shape,
ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers));
ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers,
preferred_element_type));
return computation->AddInstruction(HloInstruction::CreateDot(
dot_shape, lhs, rhs, dim_numbers, precision_config));
}

@ -59,11 +59,14 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
// Creates a convolution HLO instruction and adds it to the computation
// containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
// If the result shape has integral element type, an optional
// preferred_element_type can be specified to override the element type.
StatusOr<HloInstruction*> MakeConvolveHlo(
HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
int64 batch_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers,
const PrecisionConfig& precision_config);
const PrecisionConfig& precision_config,
absl::optional<PrimitiveType> preferred_element_type);
// Creates a transpose HLO instruction and adds it to the computation containing
// `operand`.
@ -128,10 +131,14 @@ HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
int64 iota_dimension);
// Creates a Dot HLO instruction and adds it to the computation containing `lhs`
// and `rhs` (both must be in the same computation).
StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dim_numbers,
const PrecisionConfig& precision_config);
// and `rhs` (both must be in the same computation). If the result shape has
// integral element type, an optional preferred_element_type can be specified to
// override the element type.
StatusOr<HloInstruction*> MakeDotHlo(
HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dim_numbers,
const PrecisionConfig& precision_config,
absl::optional<PrimitiveType> preferred_element_type);
// Creates a Map HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation.

@ -388,9 +388,10 @@ StatusOr<Literal> HloEvaluator::EvaluateDotOp(
std::unique_ptr<HloInstruction> rhs_instr =
HloInstruction::CreateConstant(rhs.Clone());
TF_ASSIGN_OR_RETURN(
Shape dot_shape,
ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers));
TF_ASSIGN_OR_RETURN(Shape dot_shape,
ShapeInference::InferDotOpShape(
lhs.shape(), rhs.shape(), dim_numbers,
/*preferred_element_type=*/absl::nullopt));
std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),

@ -1290,7 +1290,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
ShapeInference::InferConvolveShape(
lhs_shape, rhs_shape, conv->feature_group_count(),
conv->batch_group_count(), window, dnums));
conv->batch_group_count(), window, dnums,
/*preferred_element_type=*/absl::nullopt));
CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
<< "return shape set to: " << ShapeUtil::HumanString(result_shape)
<< " but is inferred to be: "

@ -136,13 +136,12 @@ Status ShapeVerifier::HandleCopy(HloInstruction* copy) {
}
Status ShapeVerifier::HandleDot(HloInstruction* dot) {
TF_ASSIGN_OR_RETURN(Shape expected,
ShapeInference::InferDotOpShape(
dot->operand(0)->shape(), dot->operand(1)->shape(),
dot->dot_dimension_numbers()));
if (ShapeUtil::CanUpcastIntegral(expected, dot->shape())) {
expected.set_element_type(dot->shape().element_type());
}
TF_ASSIGN_OR_RETURN(
const Shape expected,
ShapeInference::InferDotOpShape(
dot->operand(0)->shape(), dot->operand(1)->shape(),
dot->dot_dimension_numbers(),
/*preferred_element_type=*/dot->shape().element_type()));
return CheckShape(dot, expected);
}
@ -152,10 +151,8 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
ShapeInference::InferConvolveShape(
convolution->operand(0)->shape(), convolution->operand(1)->shape(),
convolution->feature_group_count(), convolution->batch_group_count(),
convolution->window(), convolution->convolution_dimension_numbers()));
if (ShapeUtil::CanUpcastIntegral(expected, convolution->shape())) {
expected.set_element_type(convolution->shape().element_type());
}
convolution->window(), convolution->convolution_dimension_numbers(),
/*preferred_element_type=*/convolution->shape().element_type()));
return CheckShape(convolution, expected);
}

@ -26,12 +26,14 @@ StatusOr<absl::optional<Shape>> MaybeInferShape(
case HloOpcode::kDot:
return ShapeInference::InferDotOpShape(
instruction->operand(0)->shape(), instruction->operand(1)->shape(),
instruction->dot_dimension_numbers());
instruction->dot_dimension_numbers(),
/*preferred_element_type=*/absl::nullopt);
case HloOpcode::kConvolution:
return ShapeInference::InferConvolveShape(
instruction->operand(0)->shape(), instruction->operand(1)->shape(),
instruction->feature_group_count(), instruction->batch_group_count(),
instruction->window(), instruction->convolution_dimension_numbers());
instruction->window(), instruction->convolution_dimension_numbers(),
/*preferred_element_type=*/absl::nullopt);
default:
return absl::make_optional<Shape>();
}

@ -214,6 +214,37 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
output_is_dynamic);
}
StatusOr<PrimitiveType> MaybeUpcast(
PrimitiveType from_type,
absl::optional<PrimitiveType> preferred_element_type) {
if (!preferred_element_type.has_value() ||
*preferred_element_type == from_type) {
return from_type;
}
if (primitive_util::IsIntegralType(from_type) !=
primitive_util::IsIntegralType(*preferred_element_type)) {
return InvalidArgument(
"`preferred_element_type` and the original type must both be integral "
"or both be floating point.");
}
if (!primitive_util::IsSignedIntegralType(from_type) !=
!primitive_util::IsSignedIntegralType(*preferred_element_type)) {
return InvalidArgument(
"`preferred_element_type` must have the same signedness as the "
"original type.");
}
if (primitive_util::BitWidth(*preferred_element_type) <
primitive_util::BitWidth(from_type)) {
if (primitive_util::IsFloatingPointType(from_type)) {
return from_type;
}
return InvalidArgument(
"`preferred_element_type` must not be narrower than the original "
"type.");
}
return *preferred_element_type;
}
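// For illustration (assumed behavior, derived from the checks above; not part
// of this change):
//   MaybeUpcast(S8, S32)  -> S32   (wider integral upcast is honored)
//   MaybeUpcast(U8, S32)  -> error (signedness mismatch)
//   MaybeUpcast(S32, F32) -> error (integral vs. floating point)
//   MaybeUpcast(F32, F16) -> F32   (narrower floating-point preference is ignored)
//   MaybeUpcast(S32, S8)  -> error (narrower than the original integral type)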
} // namespace
/* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
@ -622,7 +653,8 @@ Status ValidateDotDimensionNumbers(
/* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
const Shape& lhs, const Shape& rhs,
const DotDimensionNumbers& dimension_numbers) {
const DotDimensionNumbers& dimension_numbers,
absl::optional<PrimitiveType> preferred_element_type) {
TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot"));
TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot"));
@ -700,8 +732,11 @@ Status ValidateDotDimensionNumbers(
is_dynamic.push_back(rhs.is_dynamic_dimension(i));
}
}
Shape result = ShapeUtil::MakeShape(
ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions, is_dynamic);
TF_ASSIGN_OR_RETURN(
PrimitiveType type,
MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
preferred_element_type));
Shape result = ShapeUtil::MakeShape(type, dimensions, is_dynamic);
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
@ -1586,7 +1621,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
const Shape& lhs, const Shape& rhs, int64 feature_group_count,
int64 batch_group_count, const Window& window,
const ConvolutionDimensionNumbers& dnums) {
const ConvolutionDimensionNumbers& dnums,
absl::optional<PrimitiveType> preferred_element_type) {
TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
@ -1833,8 +1869,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
}
}
return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
dimensions, is_dynamic);
TF_ASSIGN_OR_RETURN(
PrimitiveType type,
MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
preferred_element_type));
return ShapeUtil::MakeShape(type, dimensions, is_dynamic);
}
/* static */ StatusOr<Shape> ShapeInference::InferFftShape(

View File

@ -105,12 +105,14 @@ class ShapeInference {
const Shape& output_grad_shape,
int64 feature_index);
// Infers the shape produced by applying the given convolutional
// filter (rhs) to lhs in the way specified by the fields on window.
// Infers the shape produced by applying the given convolutional filter (rhs)
// to lhs in the way specified by the fields on window. An optional
// preferred_element_type can be specified to upcast the result element type.
static StatusOr<Shape> InferConvolveShape(
const Shape& lhs, const Shape& rhs, int64 feature_group_count,
int64 batch_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers);
const ConvolutionDimensionNumbers& dimension_numbers,
absl::optional<PrimitiveType> preferred_element_type);
// Infers the shape produced by the given FFT type on the given operand.
static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type,
@ -298,10 +300,12 @@ class ShapeInference {
absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply);
// Helper that infers the shape produced by performing a dot operation with
// the given LHS and RHS shapes.
// the given LHS and RHS shapes. An optional preferred_element_type can be
// specified to upcast the result element type.
static StatusOr<Shape> InferDotOpShape(
const Shape& lhs, const Shape& rhs,
const DotDimensionNumbers& dimension_numbers);
const DotDimensionNumbers& dimension_numbers,
absl::optional<PrimitiveType> preferred_element_type);
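
An illustrative sketch of the new parameter from the caller's side, assuming the surrounding XLA headers and namespaces as in the tests later in this change (shapes borrowed from those tests; a sketch, not a definitive usage guide): absl::nullopt reproduces the previous inference result, and a floating-point preference narrower than the inferred type is ignored rather than narrowing the result.

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(0);
// bf16[32,32] dot f32[32,32] infers f32[32,32]; nullopt changes nothing.
StatusOr<Shape> unchanged = ShapeInference::InferDotOpShape(
    ShapeUtil::MakeShape(BF16, {32, 32}), ShapeUtil::MakeShape(F32, {32, 32}),
    dnums, /*preferred_element_type=*/absl::nullopt);
// Requesting BF16 still yields f32[32,32]: a narrower floating-point
// preference is ignored instead of shrinking the result type.
StatusOr<Shape> still_f32 = ShapeInference::InferDotOpShape(
    ShapeUtil::MakeShape(BF16, {32, 32}), ShapeUtil::MakeShape(F32, {32, 32}),
    dnums, /*preferred_element_type=*/BF16);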
// Helper that infers the shape of the tensor produced by a gather operation
// with the given input shape, gather indices shape and gather dimension

View File

@ -437,7 +437,7 @@ TEST_F(ShapeInferenceTest, Convolve) {
dim1->set_base_dilation(1);
auto inferred_status = ShapeInference::InferConvolveShape(
lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
window, dnums);
window, dnums, /*preferred_element_type=*/absl::nullopt);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
@ -483,7 +483,7 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
dim1->set_base_dilation(1);
auto inferred_status = ShapeInference::InferConvolveShape(
lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
window, dnums);
window, dnums, /*preferred_element_type=*/absl::nullopt);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}),
@ -529,7 +529,7 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
dim1->set_base_dilation(2);
auto inferred_status = ShapeInference::InferConvolveShape(
lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
window, dnums);
window, dnums, /*preferred_element_type=*/absl::nullopt);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}),
@ -568,7 +568,7 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
dim1->set_padding_high(1);
auto inferred_status = ShapeInference::InferConvolveShape(
lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
window, dnums);
window, dnums, /*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
HasSubstr("each dimension exactly once"));
@ -605,12 +605,150 @@ TEST_F(ShapeInferenceTest, ConvolveBatchGroupCountUnequalOutputFeature) {
dim1->set_window_dilation(2);
auto inferred_status = ShapeInference::InferConvolveShape(
lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/6,
window, dnums);
window, dnums, /*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
HasSubstr("to be a multiple of batch group count"));
}
struct ConvolveArgs {
Shape lhs_shape;
Shape rhs_shape;
ConvolutionDimensionNumbers dnums;
Window window;
};
ConvolveArgs MakeConvolveArgs(PrimitiveType lhs_type, PrimitiveType rhs_type) {
ConvolveArgs args;
ConvolutionDimensionNumbers& dnums = args.dnums;
// Dimension order: batch, feature, x0, x1
args.lhs_shape = ShapeUtil::MakeShape(lhs_type, {10, 11, 3, 4});
dnums.set_input_batch_dimension(0);
dnums.set_output_batch_dimension(0);
dnums.set_input_feature_dimension(1);
dnums.set_output_feature_dimension(1);
dnums.add_input_spatial_dimensions(2);
dnums.add_output_spatial_dimensions(2);
dnums.add_input_spatial_dimensions(3);
dnums.add_output_spatial_dimensions(3);
// Dimension order: x1, batch, feature, x0
args.rhs_shape = ShapeUtil::MakeShape(rhs_type, {2, 12, 11, 3});
dnums.set_kernel_input_feature_dimension(2);
dnums.set_kernel_output_feature_dimension(1);
dnums.add_kernel_spatial_dimensions(3);
dnums.add_kernel_spatial_dimensions(0);
auto dim0 = args.window.add_dimensions();
auto dim1 = args.window.add_dimensions();
dim0->set_size(3);
dim0->set_stride(2);
dim0->set_padding_low(1);
dim0->set_padding_high(1);
dim0->set_window_dilation(1);
dim0->set_base_dilation(1);
dim1->set_size(2);
dim1->set_stride(1);
dim1->set_padding_low(0);
dim1->set_padding_high(0);
dim1->set_window_dilation(1);
dim1->set_base_dilation(1);
return args;
}
TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementType) {
ConvolveArgs args = MakeConvolveArgs(S8, S16);
TF_ASSERT_OK_AND_ASSIGN(
Shape inferred_shape,
ShapeInference::InferConvolveShape(
args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
/*batch_group_count=*/1, args.window, args.dnums,
/*preferred_element_type=*/S32));
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
inferred_shape));
}
TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementTypeSameAsInferredType) {
ConvolveArgs args = MakeConvolveArgs(S8, S16);
TF_ASSERT_OK_AND_ASSIGN(
Shape inferred_shape,
ShapeInference::InferConvolveShape(
args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
/*batch_group_count=*/1, args.window, args.dnums,
/*preferred_element_type=*/S16));
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S16, {10, 12, 2, 3}),
inferred_shape));
}
TEST_F(ShapeInferenceTest,
FloatingPointConvolveWithNarrowerPreferredElementType) {
ConvolveArgs args = MakeConvolveArgs(F32, F32);
TF_ASSERT_OK_AND_ASSIGN(
Shape inferred_shape,
ShapeInference::InferConvolveShape(
args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
/*batch_group_count=*/1, args.window, args.dnums,
/*preferred_element_type=*/BF16));
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
inferred_shape));
}
TEST_F(ShapeInferenceTest,
FloatingPointConvolveWithInvalidPreferredElementType) {
ConvolveArgs args = MakeConvolveArgs(BF16, BF16);
auto inferred_status =
ShapeInference::InferConvolveShape(
args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
/*batch_group_count=*/1, args.window, args.dnums,
/*preferred_element_type=*/S32)
.status();
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.error_message(),
HasSubstr("must both be integral or both be floating point"));
}
TEST_F(ShapeInferenceTest,
IntegralConvolveWithFloatingPointPreferredElementType) {
ConvolveArgs args = MakeConvolveArgs(S8, S16);
auto inferred_status =
ShapeInference::InferConvolveShape(
args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
/*batch_group_count=*/1, args.window, args.dnums,
/*preferred_element_type=*/F32)
.status();
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.error_message(),
HasSubstr("must both be integral or both be floating point"));
}
TEST_F(ShapeInferenceTest,
ConvolveWithPreferredElementTypeWithDifferentSignedness) {
ConvolveArgs args = MakeConvolveArgs(S8, S16);
auto inferred_status =
ShapeInference::InferConvolveShape(
args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
/*batch_group_count=*/1, args.window, args.dnums,
/*preferred_element_type=*/U32)
.status();
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.error_message(),
HasSubstr("must have the same signedness as the original type"));
}
TEST_F(ShapeInferenceTest, ConvolveWithNarrowerPreferredElementType) {
ConvolveArgs args = MakeConvolveArgs(S8, S16);
auto inferred_status =
ShapeInference::InferConvolveShape(
args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
/*batch_group_count=*/1, args.window, args.dnums,
/*preferred_element_type=*/S8)
.status();
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.error_message(),
HasSubstr("must not be narrower than the original type"));
}
namespace fft {
static const char* unsupported_rank = "only supports ranks 1-3";
@ -1282,8 +1420,8 @@ TEST_F(ShapeInferenceTest, BroadcastScalar) {
// scalar <dot> vector: ok
TEST_F(ShapeInferenceTest, ScalarDotVector) {
DotDimensionNumbers dot_dnums;
auto inferred_status =
ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums);
auto inferred_status = ShapeInference::InferDotOpShape(
f32_, vector_32_, dot_dnums, /*preferred_element_type=*/absl::nullopt);
EXPECT_TRUE(inferred_status.ok());
EXPECT_EQ(inferred_status.ValueOrDie(), vector_32_);
}
@ -1294,7 +1432,8 @@ TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) {
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status = ShapeInference::InferDotOpShape(
ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums);
ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
EXPECT_TRUE(inferred_status.ok());
EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(),
ShapeUtil::MakeShape(F32, {32, 32, 64})));
@ -1306,11 +1445,13 @@ TEST_F(ShapeInferenceTest, VectorDotVector) {
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status =
ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums);
ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
auto inferred_status_mismatch =
ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums);
ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status_mismatch.ok());
}
@ -1320,11 +1461,13 @@ TEST_F(ShapeInferenceTest, MatrixDotVector) {
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status =
ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums);
ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_));
auto inferred_status_mismatch =
ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums);
ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status_mismatch.ok());
}
@ -1334,11 +1477,13 @@ TEST_F(ShapeInferenceTest, VectorDotMatrix) {
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status =
ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums);
ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_));
auto inferred_status_mismatch =
ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums);
ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status_mismatch.ok());
}
@ -1348,7 +1493,8 @@ TEST_F(ShapeInferenceTest, MatrixDotMatrix) {
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status_match =
ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums);
ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(
ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_))
@ -1356,7 +1502,8 @@ TEST_F(ShapeInferenceTest, MatrixDotMatrix) {
<< ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
<< " expected: " << ShapeUtil::HumanString(matrix_64_48_);
auto inferred_status_mismatch =
ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums);
ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status_mismatch.ok());
}
@ -1376,7 +1523,8 @@ TEST_F(ShapeInferenceTest, DotGeneral) {
dot_dnums.add_rhs_batch_dimensions(1);
auto inferred_status_match =
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(
ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape))
@ -1399,7 +1547,8 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) {
dot_dnums.add_rhs_batch_dimensions(0);
auto inferred_status =
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
HasSubstr("Must specify the same number of contracting "
@ -1421,7 +1570,8 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) {
dot_dnums.add_rhs_batch_dimensions(0);
auto inferred_status =
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
EXPECT_TRUE(inferred_status.ok());
EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape));
}
@ -1461,7 +1611,8 @@ TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) {
dot_dnums.add_rhs_batch_dimensions(0);
auto inferred_status =
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
HasSubstr("Batch dimension sizes must match"));
@ -1480,7 +1631,8 @@ TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimNumbersPasses) {
dot_dnums.add_rhs_batch_dimensions(1);
auto inferred_status =
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_TRUE(inferred_status.ok());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(),
ShapeUtil::MakeShape(F32, {2, 11, 14})));
@ -1499,7 +1651,8 @@ TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) {
dot_dnums.add_rhs_batch_dimensions(1);
auto inferred_status =
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
HasSubstr("A dimension number is out of range"));
@ -1518,12 +1671,108 @@ TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) {
dot_dnums.add_rhs_batch_dimensions(1);
auto inferred_status =
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
HasSubstr("A dimension number is not unique"));
}
TEST_F(ShapeInferenceTest, DotWithIntegralPreferredElementType) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape,
ShapeInference::InferDotOpShape(
ShapeUtil::MakeShape(S8, {32, 32}),
ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
/*preferred_element_type=*/S32));
EXPECT_TRUE(
ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(S32, {32, 32})));
}
TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeSameAsInferredType) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape,
ShapeInference::InferDotOpShape(
ShapeUtil::MakeShape(BF16, {32, 32}),
ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums,
/*preferred_element_type=*/F32));
EXPECT_TRUE(
ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(F32, {32, 32})));
}
TEST_F(ShapeInferenceTest, FloatingPointDotWithNarrowerPreferredElementType) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape,
ShapeInference::InferDotOpShape(
ShapeUtil::MakeShape(BF16, {32, 32}),
ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums,
/*preferred_element_type=*/BF16));
EXPECT_TRUE(
ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(F32, {32, 32})));
}
TEST_F(ShapeInferenceTest, FloatingPointDotWithInvalidPreferredElementType) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status = ShapeInference::InferDotOpShape(
ShapeUtil::MakeShape(BF16, {32, 32}),
ShapeUtil::MakeShape(BF16, {32, 32}), dot_dnums,
/*preferred_element_type=*/S32)
.status();
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.error_message(),
HasSubstr("must both be integral or both be floating point"));
}
TEST_F(ShapeInferenceTest, IntegralDotWithFloatingPointPreferredElementType) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status = ShapeInference::InferDotOpShape(
ShapeUtil::MakeShape(S8, {32, 32}),
ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
/*preferred_element_type=*/F32)
.status();
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.error_message(),
HasSubstr("must both be integral or both be floating point"));
}
TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeWithDifferentSignedness) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status = ShapeInference::InferDotOpShape(
ShapeUtil::MakeShape(S8, {32, 32}),
ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
/*preferred_element_type=*/U32)
.status();
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.error_message(),
HasSubstr("must have the same signedness as the original type"));
}
TEST_F(ShapeInferenceTest, DotWithNarrowerPreferredElementType) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status = ShapeInference::InferDotOpShape(
ShapeUtil::MakeShape(S8, {32, 32}),
ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
/*preferred_element_type=*/S8)
.status();
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.error_message(),
HasSubstr("must not be narrower than the original type"));
}
TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) {
// Test variations of broadcasting a vector for a binary add with a
// matrix.

View File

@ -1520,10 +1520,11 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
->set_padding_low(0);
TF_ASSIGN_OR_RETURN(
HloInstruction * new_conv,
MakeConvolveHlo(activations_new, /*rhs=*/convolution->mutable_operand(1),
convolution->feature_group_count(),
convolution->batch_group_count(), new_window,
new_dim_numbers, convolution->precision_config()));
MakeConvolveHlo(
activations_new, /*rhs=*/convolution->mutable_operand(1),
convolution->feature_group_count(), convolution->batch_group_count(),
new_window, new_dim_numbers, convolution->precision_config(),
/*preferred_element_type=*/convolution->shape().element_type()));
convolution->SetupDerivedInstruction(new_conv);
old_to_new_instrs_[convolution] = new_conv;
@ -1800,10 +1801,11 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(
TF_ASSIGN_OR_RETURN(
HloInstruction * new_conv,
MakeConvolveHlo(activations_new, kernel_new,
convolution->feature_group_count(),
convolution->batch_group_count(), new_window,
new_dim_numbers, convolution->precision_config()));
MakeConvolveHlo(
activations_new, kernel_new, convolution->feature_group_count(),
convolution->batch_group_count(), new_window, new_dim_numbers,
convolution->precision_config(),
/*preferred_element_type=*/convolution->shape().element_type()));
convolution->SetupDerivedInstruction(new_conv);
std::vector<int64> output_sizes(new_conv->shape().dimensions().begin(),
@ -2043,10 +2045,11 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
->set_padding_low(0);
TF_ASSIGN_OR_RETURN(
HloInstruction * new_conv,
MakeConvolveHlo(activations, /*rhs=*/convolution->mutable_operand(1),
convolution->feature_group_count(),
convolution->batch_group_count(), new_window,
new_dim_numbers, convolution->precision_config()));
MakeConvolveHlo(
activations, /*rhs=*/convolution->mutable_operand(1),
convolution->feature_group_count(), convolution->batch_group_count(),
new_window, new_dim_numbers, convolution->precision_config(),
/*preferred_element_type=*/convolution->shape().element_type()));
convolution->SetupDerivedInstruction(new_conv);
VLOG(1) << "Space-to-batched convolution " << new_conv->ToString();

View File

@ -950,7 +950,8 @@ StatusOr<std::unique_ptr<HloInstruction>> CreateShardedConvConvolution(
Shape sharded_conv_shape,
ShapeInference::InferConvolveShape(
sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
feature_group_count, batch_group_count, window, conv_dnums));
feature_group_count, batch_group_count, window, conv_dnums,
/*preferred_element_type=*/conv.shape().element_type()));
*sharded_conv_shape.mutable_layout() = conv.shape().layout();
return HloInstruction::CreateConvolve(
sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo, feature_group_count,

View File

@ -80,8 +80,9 @@ Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) {
const Window& conv_window) -> StatusOr<HloInstruction*> {
TF_ASSIGN_OR_RETURN(
auto sharded_dot_shape,
ShapeInference::InferDotOpShape(l->shape(), r->shape(),
hlo->dot_dimension_numbers()));
ShapeInference::InferDotOpShape(
l->shape(), r->shape(), hlo->dot_dimension_numbers(),
/*preferred_element_type=*/hlo->shape().element_type()));
return b->AddInstruction(HloInstruction::CreateDot(
sharded_dot_shape, l, r, hlo->dot_dimension_numbers(),
hlo->precision_config()));

View File

@ -242,7 +242,8 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) {
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
x->shape(), transpose_y->shape(), /*feature_group_count=*/1,
/*batch_group_count=*/1, window, dnums);
/*batch_group_count=*/1, window, dnums,
/*preferred_element_type=*/absl::nullopt);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
conv_shape.ValueOrDie(), x, transpose_y,
@ -298,7 +299,8 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
x->shape(), transpose_y->shape(), /*feature_group_count=*/1,
/*batch_group_count=*/1, window, dnums);
/*batch_group_count=*/1, window, dnums,
/*preferred_element_type=*/absl::nullopt);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
conv_shape.ValueOrDie(), x, transpose_y,
@ -359,7 +361,8 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
transpose_x->shape(), y->shape(), /*feature_group_count=*/1,
/*batch_group_count=*/1, window, dnums);
/*batch_group_count=*/1, window, dnums,
/*preferred_element_type=*/absl::nullopt);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
conv_shape.ValueOrDie(), transpose_x, y,
@ -426,7 +429,8 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) {
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
transpose_x->shape(), y->shape(), /*feature_group_count=*/1,
/*batch_group_count=*/1, window, dnums);
/*batch_group_count=*/1, window, dnums,
/*preferred_element_type=*/absl::nullopt);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
conv_shape.ValueOrDie(), transpose_x, y,