[XLA] Add an optional preferred_element_type argument to the Dot/Conv builder methods to enable generating HLOs that have a wider accumulation type than the default shape-inference result.
PiperOrigin-RevId: 343546767
Change-Id: Iba19e8c9a9748d0eb0e738432f714f43b4575a38
commit 06f7a3cb81 (parent 7fa0d80d06)
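For orientation, here is a minimal sketch of how a client could request a wider accumulation type through the new argument. It mirrors the DotWithPreferredElementType test added later in this commit; the builder name is illustrative, and the argument may be left as absl::nullopt to keep the previous behavior:

    XlaBuilder b("preferred_element_type_example");
    // u8[2,3] dot u16[3,2]; ask shape inference for a u32 result instead of
    // the default inferred element type.
    XlaOp p0 = Parameter(&b, 0, ShapeUtil::MakeShape(U8, {2, 3}), "p0");
    XlaOp p1 = Parameter(&b, 1, ShapeUtil::MakeShape(U16, {3, 2}), "p1");
    DotDimensionNumbers dnums;
    dnums.add_lhs_contracting_dimensions(1);
    dnums.add_rhs_contracting_dimensions(0);
    DotGeneral(p0, p1, dnums, /*precision_config=*/nullptr,
               /*preferred_element_type=*/U32);  // result shape: u32[2,2]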
@@ -1266,7 +1266,8 @@ StatusOr<XlaOp> XlaBuilder::GetTupleElementInternal(const Shape& shape,
}

XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
const PrecisionConfig* precision_config) {
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
@@ -1278,15 +1279,17 @@ XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
});
}

XlaOp XlaBuilder::DotGeneral(XlaOp lhs, XlaOp rhs,
const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config) {
XlaOp XlaBuilder::DotGeneral(
XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferDotOpShape(*lhs_shape, *rhs_shape,
dimension_numbers));
TF_ASSIGN_OR_RETURN(
Shape shape,
ShapeInference::InferDotOpShape(
*lhs_shape, *rhs_shape, dimension_numbers, preferred_element_type));
return DotGeneralInternal(shape, lhs, rhs, dimension_numbers,
precision_config);
});
@@ -1353,28 +1356,33 @@ Status XlaBuilder::VerifyConvolution(
XlaOp XlaBuilder::Conv(XlaOp lhs, XlaOp rhs,
absl::Span<const int64> window_strides, Padding padding,
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) {
const PrecisionConfig* precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
return ConvWithGeneralDimensions(
lhs, rhs, window_strides, padding,
CreateDefaultConvDimensionNumbers(window_strides.size()),
feature_group_count, batch_group_count, precision_config);
feature_group_count, batch_group_count, precision_config,
preferred_element_type);
}
|
||||
|
||||
XlaOp XlaBuilder::ConvWithGeneralPadding(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config) {
|
||||
const PrecisionConfig* precision_config,
|
||||
absl::optional<PrimitiveType> preferred_element_type) {
|
||||
return ConvGeneral(lhs, rhs, window_strides, padding,
|
||||
CreateDefaultConvDimensionNumbers(window_strides.size()),
|
||||
feature_group_count, batch_group_count, precision_config);
|
||||
feature_group_count, batch_group_count, precision_config,
|
||||
preferred_element_type);
|
||||
}
|
||||
|
||||
XlaOp XlaBuilder::ConvWithGeneralDimensions(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config) {
|
||||
const PrecisionConfig* precision_config,
|
||||
absl::optional<PrimitiveType> preferred_element_type) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
|
||||
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
|
||||
@ -1402,7 +1410,8 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions(
|
||||
MakePadding(base_area_dimensions, window_dimensions,
|
||||
window_strides, padding),
|
||||
dimension_numbers, feature_group_count,
|
||||
batch_group_count, precision_config);
|
||||
batch_group_count, precision_config,
|
||||
preferred_element_type);
|
||||
});
|
||||
}
|
||||
|
||||
@ -1411,10 +1420,12 @@ XlaOp XlaBuilder::ConvGeneral(
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config) {
|
||||
const PrecisionConfig* precision_config,
|
||||
absl::optional<PrimitiveType> preferred_element_type) {
|
||||
return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
|
||||
dimension_numbers, feature_group_count,
|
||||
batch_group_count, precision_config);
|
||||
batch_group_count, precision_config,
|
||||
preferred_element_type);
|
||||
}
|
||||
|
||||
XlaOp XlaBuilder::ConvGeneralDilated(
|
||||
@ -1423,7 +1434,8 @@ XlaOp XlaBuilder::ConvGeneralDilated(
|
||||
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config) {
|
||||
const PrecisionConfig* precision_config,
|
||||
absl::optional<PrimitiveType> preferred_element_type) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
|
||||
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
|
||||
@ -1442,10 +1454,11 @@ XlaOp XlaBuilder::ConvGeneralDilated(
|
||||
ShapeInference::InferWindowFromDimensions(
|
||||
window_dimensions, window_strides, padding,
|
||||
lhs_dilation, rhs_dilation));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||
ShapeInference::InferConvolveShape(
|
||||
*lhs_shape, *rhs_shape, feature_group_count,
|
||||
batch_group_count, window, dimension_numbers));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape shape,
|
||||
ShapeInference::InferConvolveShape(
|
||||
*lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
|
||||
window, dimension_numbers, preferred_element_type));
|
||||
return ConvGeneralDilatedInternal(shape, lhs, rhs, window, window_strides,
|
||||
padding, lhs_dilation, rhs_dilation,
|
||||
dimension_numbers, feature_group_count,
|
||||
@ -1459,7 +1472,8 @@ StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction(
|
||||
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type) {
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type,
|
||||
absl::optional<PrimitiveType> preferred_element_type) {
|
||||
TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs));
|
||||
TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs));
|
||||
std::vector<int64> window_dimensions(
|
||||
@ -1472,10 +1486,11 @@ StatusOr<HloInstructionProto> XlaBuilder::DynamicConvInstruction(
|
||||
TF_ASSIGN_OR_RETURN(Window window, ShapeInference::InferWindowFromDimensions(
|
||||
window_dimensions, window_strides,
|
||||
padding, lhs_dilation, rhs_dilation));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||
ShapeInference::InferConvolveShape(
|
||||
*lhs_shape, *rhs_shape, feature_group_count,
|
||||
batch_group_count, window, dimension_numbers));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape shape,
|
||||
ShapeInference::InferConvolveShape(
|
||||
*lhs_shape, *rhs_shape, feature_group_count, batch_group_count,
|
||||
window, dimension_numbers, preferred_element_type));
|
||||
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
@ -1499,14 +1514,15 @@ XlaOp XlaBuilder::DynamicConvInputGrad(
|
||||
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type) {
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type,
|
||||
absl::optional<PrimitiveType> preferred_element_type) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloInstructionProto instr,
|
||||
DynamicConvInstruction(lhs, rhs, window_strides, padding, lhs_dilation,
|
||||
rhs_dilation, dimension_numbers,
|
||||
feature_group_count, batch_group_count,
|
||||
precision_config, padding_type));
|
||||
DynamicConvInstruction(
|
||||
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
|
||||
dimension_numbers, feature_group_count, batch_group_count,
|
||||
precision_config, padding_type, preferred_element_type));
|
||||
|
||||
instr.set_custom_call_target("DynamicConvolutionInputGrad");
|
||||
|
||||
@ -1521,14 +1537,16 @@ XlaOp XlaBuilder::DynamicConvKernelGrad(
|
||||
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type) {
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type,
|
||||
absl::optional<PrimitiveType> preferred_element_type) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloInstructionProto instr,
|
||||
DynamicConvInstruction(activations, gradients, window_strides, padding,
|
||||
lhs_dilation, rhs_dilation, dimension_numbers,
|
||||
feature_group_count, batch_group_count,
|
||||
precision_config, padding_type));
|
||||
precision_config, padding_type,
|
||||
preferred_element_type));
|
||||
|
||||
instr.set_custom_call_target("DynamicConvolutionKernelGrad");
|
||||
// The gradient of kernel has kernel shape and shouldn't have any dynamic
|
||||
@ -1545,14 +1563,15 @@ XlaOp XlaBuilder::DynamicConvForward(
|
||||
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type) {
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type,
|
||||
absl::optional<PrimitiveType> preferred_element_type) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloInstructionProto instr,
|
||||
DynamicConvInstruction(lhs, rhs, window_strides, padding, lhs_dilation,
|
||||
rhs_dilation, dimension_numbers,
|
||||
feature_group_count, batch_group_count,
|
||||
precision_config, padding_type));
|
||||
DynamicConvInstruction(
|
||||
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
|
||||
dimension_numbers, feature_group_count, batch_group_count,
|
||||
precision_config, padding_type, preferred_element_type));
|
||||
instr.set_custom_call_target("DynamicConvolutionForward");
|
||||
|
||||
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, {lhs, rhs});
|
||||
@ -4074,44 +4093,49 @@ XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) {
|
||||
}
|
||||
|
||||
XlaOp Dot(const XlaOp lhs, const XlaOp rhs,
|
||||
const PrecisionConfig* precision_config) {
|
||||
return lhs.builder()->Dot(lhs, rhs, precision_config);
|
||||
const PrecisionConfig* precision_config,
|
||||
absl::optional<PrimitiveType> preferred_element_type) {
|
||||
return lhs.builder()->Dot(lhs, rhs, precision_config, preferred_element_type);
|
||||
}
|
||||
|
||||
XlaOp DotGeneral(const XlaOp lhs, const XlaOp rhs,
|
||||
const DotDimensionNumbers& dimension_numbers,
|
||||
const PrecisionConfig* precision_config) {
|
||||
const PrecisionConfig* precision_config,
|
||||
absl::optional<PrimitiveType> preferred_element_type) {
|
||||
return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers,
|
||||
precision_config);
|
||||
precision_config, preferred_element_type);
|
||||
}
|
||||
|
||||
XlaOp Conv(const XlaOp lhs, const XlaOp rhs,
|
||||
absl::Span<const int64> window_strides, Padding padding,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config) {
|
||||
const PrecisionConfig* precision_config,
|
||||
absl::optional<PrimitiveType> preferred_element_type) {
|
||||
return lhs.builder()->Conv(lhs, rhs, window_strides, padding,
|
||||
feature_group_count, batch_group_count,
|
||||
precision_config);
|
||||
precision_config, preferred_element_type);
|
||||
}
|
||||
|
||||
XlaOp ConvWithGeneralPadding(const XlaOp lhs, const XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config) {
|
||||
XlaOp ConvWithGeneralPadding(
|
||||
const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config,
|
||||
absl::optional<PrimitiveType> preferred_element_type) {
|
||||
return lhs.builder()->ConvWithGeneralPadding(
|
||||
lhs, rhs, window_strides, padding, feature_group_count, batch_group_count,
|
||||
precision_config);
|
||||
precision_config, preferred_element_type);
|
||||
}
|
||||
|
||||
XlaOp ConvWithGeneralDimensions(
|
||||
const XlaOp lhs, const XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config) {
|
||||
const PrecisionConfig* precision_config,
|
||||
absl::optional<PrimitiveType> preferred_element_type) {
|
||||
return lhs.builder()->ConvWithGeneralDimensions(
|
||||
lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
|
||||
batch_group_count, precision_config);
|
||||
batch_group_count, precision_config, preferred_element_type);
|
||||
}
|
||||
|
||||
XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs,
|
||||
@ -4119,10 +4143,11 @@ XlaOp ConvGeneral(const XlaOp lhs, const XlaOp rhs,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config) {
|
||||
return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding,
|
||||
dimension_numbers, feature_group_count,
|
||||
batch_group_count, precision_config);
|
||||
const PrecisionConfig* precision_config,
|
||||
absl::optional<PrimitiveType> preferred_element_type) {
|
||||
return lhs.builder()->ConvGeneral(
|
||||
lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count,
|
||||
batch_group_count, precision_config, preferred_element_type);
|
||||
}
|
||||
|
||||
XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
|
||||
@ -4132,26 +4157,27 @@ XlaOp ConvGeneralDilated(const XlaOp lhs, const XlaOp rhs,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config) {
|
||||
const PrecisionConfig* precision_config,
|
||||
absl::optional<PrimitiveType> preferred_element_type) {
|
||||
return lhs.builder()->ConvGeneralDilated(
|
||||
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
|
||||
dimension_numbers, feature_group_count, batch_group_count,
|
||||
precision_config);
|
||||
precision_config, preferred_element_type);
|
||||
}
|
||||
|
||||
XlaOp DynamicConvInputGrad(XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config,
|
||||
PaddingType padding_type) {
|
||||
XlaOp DynamicConvInputGrad(
|
||||
XlaOp input_sizes, const XlaOp lhs, const XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type,
|
||||
absl::optional<PrimitiveType> preferred_element_type) {
|
||||
return lhs.builder()->DynamicConvInputGrad(
|
||||
input_sizes, lhs, rhs, window_strides, padding, lhs_dilation,
|
||||
rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
|
||||
precision_config, padding_type);
|
||||
precision_config, padding_type, preferred_element_type);
|
||||
}
|
||||
|
||||
XlaOp DynamicConvKernelGrad(
|
||||
@ -4160,11 +4186,12 @@ XlaOp DynamicConvKernelGrad(
|
||||
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type) {
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type,
|
||||
absl::optional<PrimitiveType> preferred_element_type) {
|
||||
return activations.builder()->DynamicConvKernelGrad(
|
||||
activations, gradients, window_strides, padding, lhs_dilation,
|
||||
rhs_dilation, dimension_numbers, feature_group_count, batch_group_count,
|
||||
precision_config, padding_type);
|
||||
precision_config, padding_type, preferred_element_type);
|
||||
}
|
||||
|
||||
XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs,
|
||||
@ -4175,11 +4202,12 @@ XlaOp DynamicConvForward(const XlaOp lhs, const XlaOp rhs,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config,
|
||||
PaddingType padding_type) {
|
||||
PaddingType padding_type,
|
||||
absl::optional<PrimitiveType> preferred_element_type) {
|
||||
return lhs.builder()->DynamicConvForward(
|
||||
lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
|
||||
dimension_numbers, feature_group_count, batch_group_count,
|
||||
precision_config, padding_type);
|
||||
precision_config, padding_type, preferred_element_type);
|
||||
}
|
||||
|
||||
XlaOp Fft(const XlaOp operand, FftType fft_type,
|
||||
|
@@ -521,56 +521,63 @@ class XlaBuilder {
XlaOp tuple_data,
int64 index);

XlaOp Dot(XlaOp lhs, XlaOp rhs,
const PrecisionConfig* precision_config = nullptr);
XlaOp Dot(
XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config = nullptr);
XlaOp DotGeneral(
XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||
XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
Padding padding, int64 feature_group_count = 1,
|
||||
int64 batch_group_count = 1,
|
||||
const PrecisionConfig* precision_config = nullptr);
|
||||
XlaOp Conv(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
Padding padding, int64 feature_group_count = 1,
|
||||
int64 batch_group_count = 1,
|
||||
const PrecisionConfig* precision_config = nullptr,
|
||||
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||
|
||||
XlaOp ConvWithGeneralPadding(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
int64 feature_group_count = 1, int64 batch_group_count = 1,
|
||||
const PrecisionConfig* precision_config = nullptr);
|
||||
const PrecisionConfig* precision_config = nullptr,
|
||||
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||
|
||||
XlaOp ConvWithGeneralDimensions(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count = 1, int64 batch_group_count = 1,
|
||||
const PrecisionConfig* precision_config = nullptr);
|
||||
const PrecisionConfig* precision_config = nullptr,
|
||||
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||
|
||||
XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count = 1, int64 batch_group_count = 1,
|
||||
const PrecisionConfig* precision_config = nullptr);
|
||||
XlaOp ConvGeneral(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count = 1, int64 batch_group_count = 1,
|
||||
const PrecisionConfig* precision_config = nullptr,
|
||||
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||
|
||||
XlaOp ConvGeneralDilated(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count = 1,
|
||||
int64 batch_group_count = 1,
|
||||
const PrecisionConfig* precision_config = nullptr);
|
||||
XlaOp ConvGeneralDilated(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count = 1, int64 batch_group_count = 1,
|
||||
const PrecisionConfig* precision_config = nullptr,
|
||||
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||
|
||||
XlaOp DynamicConvForward(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config,
|
||||
PaddingType padding_type);
|
||||
XlaOp DynamicConvForward(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type,
|
||||
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||
|
||||
XlaOp DynamicConvInputGrad(
|
||||
XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
|
||||
@ -580,7 +587,8 @@ class XlaBuilder {
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type);
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type,
|
||||
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||
|
||||
XlaOp DynamicConvKernelGrad(
|
||||
XlaOp activations, XlaOp gradients,
|
||||
@ -590,7 +598,8 @@ class XlaBuilder {
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type);
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type,
|
||||
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||
|
||||
StatusOr<HloInstructionProto> DynamicConvInstruction(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
@ -599,7 +608,8 @@ class XlaBuilder {
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type);
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type,
|
||||
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||
|
||||
virtual StatusOr<XlaOp> ConvGeneralDilatedInternal(
|
||||
const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
|
||||
@ -1098,10 +1108,12 @@ class XlaBuilder {
|
||||
ComparisonDirection direction,
|
||||
Comparison::Type compare_type);
|
||||
friend XlaOp Dot(XlaOp lhs, XlaOp rhs,
|
||||
const PrecisionConfig* precision_config);
|
||||
const PrecisionConfig* precision_config,
|
||||
absl::optional<PrimitiveType> preferred_element_type);
|
||||
friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
|
||||
const DotDimensionNumbers& dimension_number,
|
||||
const PrecisionConfig* precision_config);
|
||||
const PrecisionConfig* precision_config,
|
||||
absl::optional<PrimitiveType> preferred_element_type);
|
||||
virtual StatusOr<XlaOp> DotGeneralInternal(
|
||||
const Shape& shape, XlaOp lhs, XlaOp rhs,
|
||||
const DotDimensionNumbers& dimension_number,
|
||||
@ -1109,23 +1121,27 @@ class XlaBuilder {
|
||||
friend XlaOp Conv(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> window_strides, Padding padding,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config);
|
||||
const PrecisionConfig* precision_config,
|
||||
absl::optional<PrimitiveType> preferred_element_type);
|
||||
friend XlaOp ConvWithGeneralPadding(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config);
|
||||
const PrecisionConfig* precision_config,
|
||||
absl::optional<PrimitiveType> preferred_element_type);
|
||||
friend XlaOp ConvWithGeneralDimensions(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_confige);
|
||||
friend XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config);
|
||||
const PrecisionConfig* precision_config,
|
||||
absl::optional<PrimitiveType> preferred_element_type);
|
||||
friend XlaOp ConvGeneral(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config,
|
||||
absl::optional<PrimitiveType> preferred_element_type);
|
||||
friend XlaOp DynamicConvForward(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
@ -1133,7 +1149,8 @@ class XlaBuilder {
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type);
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type,
|
||||
absl::optional<PrimitiveType> preferred_element_type);
|
||||
friend XlaOp DynamicConvKernelGrad(
|
||||
XlaOp activations, XlaOp gradients,
|
||||
absl::Span<const int64> window_strides,
|
||||
@ -1142,7 +1159,8 @@ class XlaBuilder {
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type);
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type,
|
||||
absl::optional<PrimitiveType> preferred_element_type);
|
||||
friend XlaOp DynamicConvInputGrad(
|
||||
XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
@ -1151,7 +1169,8 @@ class XlaBuilder {
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type);
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type,
|
||||
absl::optional<PrimitiveType> preferred_element_type);
|
||||
|
||||
friend XlaOp ConvKernelGrad(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
@ -1160,7 +1179,8 @@ class XlaBuilder {
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config);
|
||||
const PrecisionConfig* precision_config,
|
||||
absl::optional<PrimitiveType> preferred_element_type);
|
||||
|
||||
friend XlaOp ConvGeneralDilated(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
@ -1169,7 +1189,8 @@ class XlaBuilder {
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config);
|
||||
const PrecisionConfig* precision_config,
|
||||
absl::optional<PrimitiveType> preferred_element_type);
|
||||
friend XlaOp Fft(XlaOp operand, FftType fft_type,
|
||||
absl::Span<const int64> fft_length);
|
||||
friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
|
||||
@@ -1813,28 +1834,31 @@ XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction);

// Enqueues a dot instruction onto the computation.
XlaOp Dot(XlaOp lhs, XlaOp rhs,
const PrecisionConfig* precision_config = nullptr);
const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

// Enqueues a general dot instruction onto the computation.
XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config = nullptr);
XlaOp DotGeneral(
XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
Padding padding, int64 feature_group_count = 1,
int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr);
XlaOp Conv(
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
Padding padding, int64 feature_group_count = 1, int64 batch_group_count = 1,
const PrecisionConfig* precision_config = nullptr,
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||
// Enqueues a convolution instruction onto the computation, with the caller
|
||||
// provided padding configuration in the format returned by MakePadding().
|
||||
XlaOp ConvWithGeneralPadding(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
int64 feature_group_count = 1,
|
||||
int64 batch_group_count = 1,
|
||||
const PrecisionConfig* precision_config = nullptr);
|
||||
XlaOp ConvWithGeneralPadding(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
int64 feature_group_count = 1, int64 batch_group_count = 1,
|
||||
const PrecisionConfig* precision_config = nullptr,
|
||||
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||
|
||||
// Enqueues a convolution instruction onto the computation, with the caller
|
||||
// provided dimension numbers configuration.
|
||||
@ -1842,47 +1866,48 @@ XlaOp ConvWithGeneralDimensions(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count = 1, int64 batch_group_count = 1,
|
||||
const PrecisionConfig* precision_config = nullptr);
|
||||
const PrecisionConfig* precision_config = nullptr,
|
||||
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||
|
||||
// Enqueues a convolution instruction onto the computation, with the caller
|
||||
// provided padding configuration as well as the dimension numbers.
|
||||
XlaOp ConvGeneral(XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count = 1, int64 batch_group_count = 1,
|
||||
const PrecisionConfig* precision_config = nullptr);
|
||||
XlaOp ConvGeneral(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count = 1, int64 batch_group_count = 1,
|
||||
const PrecisionConfig* precision_config = nullptr,
|
||||
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||
|
||||
// Enqueues a convolution instruction onto the computation, with the caller
|
||||
// provided padding configuration, dilation factors and dimension numbers.
|
||||
XlaOp ConvGeneralDilated(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count = 1,
|
||||
int64 batch_group_count = 1,
|
||||
const PrecisionConfig* precision_config = nullptr);
|
||||
XlaOp ConvGeneralDilated(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count = 1, int64 batch_group_count = 1,
|
||||
const PrecisionConfig* precision_config = nullptr,
|
||||
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||
|
||||
XlaOp DynamicConvForward(XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config,
|
||||
PaddingType padding_type);
|
||||
XlaOp DynamicConvForward(
|
||||
XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type,
|
||||
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||
|
||||
XlaOp DynamicConvInputGrad(XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation,
|
||||
absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config,
|
||||
PaddingType padding_type);
|
||||
XlaOp DynamicConvInputGrad(
|
||||
XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
|
||||
absl::Span<const int64> window_strides,
|
||||
absl::Span<const std::pair<int64, int64>> padding,
|
||||
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type,
|
||||
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||
|
||||
XlaOp DynamicConvKernelGrad(
|
||||
XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
|
||||
@ -1890,7 +1915,8 @@ XlaOp DynamicConvKernelGrad(
|
||||
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
|
||||
const ConvolutionDimensionNumbers& dimension_numbers,
|
||||
int64 feature_group_count, int64 batch_group_count,
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type);
|
||||
const PrecisionConfig* precision_config, PaddingType padding_type,
|
||||
absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
|
||||
|
||||
// Enqueues an FFT instruction onto the computation, of the given type and
|
||||
// with the given FFT length.
|
||||
|
@@ -1066,6 +1066,56 @@ TEST_F(XlaBuilderTest, DynamicTranspose) {
<< result_shape;
}

TEST_F(XlaBuilderTest, DotWithPreferredElementType) {
XlaBuilder b(TestName());
Shape p0_shape = ShapeUtil::MakeShape(U8, {2, 3});
Shape p1_shape = ShapeUtil::MakeShape(U16, {3, 2});
auto p0 = Parameter(&b, 0, p0_shape, "p0");
auto p1 = Parameter(&b, 1, p1_shape, "p1");

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(0);
DotGeneral(p0, p1, dnums, /*precision_config=*/nullptr,
/*preferred_element_type=*/U32);
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
const Shape& result_shape =
module->entry_computation()->root_instruction()->shape();
ASSERT_TRUE(
ShapeUtil::Equal(ShapeUtil::MakeShape(U32, {2, 2}), result_shape));
}

TEST_F(XlaBuilderTest, ConvolutionWithPreferredElementType) {
XlaBuilder b(TestName());
Shape p0_shape = ShapeUtil::MakeShape(S16, {1, 2, 2, 128});
Shape p1_shape = ShapeUtil::MakeShape(S8, {2, 2, 128, 8});
auto p0 = Parameter(&b, 0, p0_shape, "p0");
auto p1 = Parameter(&b, 1, p1_shape, "p1");

ConvolutionDimensionNumbers dnums;
dnums.set_input_batch_dimension(0);
dnums.set_output_batch_dimension(0);
dnums.add_input_spatial_dimensions(1);
dnums.add_output_spatial_dimensions(1);
dnums.add_input_spatial_dimensions(2);
dnums.add_output_spatial_dimensions(2);
dnums.set_input_feature_dimension(3);
dnums.set_output_feature_dimension(3);
dnums.add_kernel_spatial_dimensions(0);
dnums.add_kernel_spatial_dimensions(1);
dnums.set_kernel_input_feature_dimension(2);
dnums.set_kernel_output_feature_dimension(3);
ConvWithGeneralDimensions(p0, p1, {1, 1}, Padding::kValid, dnums,
/*feature_group_count=*/1, /*batch_group_count=*/1,
/*precision_config=*/nullptr,
/*preferred_element_type=*/S32);
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
const Shape& result_shape =
module->entry_computation()->root_instruction()->shape();
ASSERT_TRUE(
ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {1, 1, 1, 8}), result_shape));
}

TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) {
XlaBuilder b(TestName());
AfterAll(&b, {CreateToken(&b), ConstantR0<float>(&b, 1.0)});
@@ -108,7 +108,8 @@ void BuildOpsSubmodule(py::module* m) {
py::arg("lhs_dilation"), py::arg("rhs_dilation"),
py::arg("dimension_numbers"), py::arg("feature_group_count") = 1,
py::arg("batch_group_count") = 1,
py::arg("precision_config") = nullptr);
py::arg("precision_config") = nullptr,
py::arg("preferred_element_type") = absl::nullopt);
ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"),
py::arg("new_element_type"));
ops.def(
@@ -136,9 +137,11 @@ void BuildOpsSubmodule(py::module* m) {
py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"),
py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false);
ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"),
py::arg("precision_config") = nullptr);
py::arg("precision_config") = nullptr,
py::arg("preferred_element_type") = absl::nullopt);
ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"),
py::arg("dimension_numbers"), py::arg("precision_config") = nullptr);
py::arg("dimension_numbers"), py::arg("precision_config") = nullptr,
py::arg("preferred_element_type") = absl::nullopt);
ops.def("DynamicSlice",
static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>,
absl::Span<const int64>)>(&DynamicSlice),
@@ -465,7 +465,8 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
const Shape& shape =
ShapeInference::InferConvolveShape(
lhs_literal.shape(), rhs_literal.shape(),
/*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums)
/*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
/*preferred_element_type=*/absl::nullopt)
.ConsumeValueOrDie();

HloInstruction* lhs_instruction =
@@ -1798,10 +1798,10 @@ StatusOr<bool> AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot(
ShapeUtil::DropDegenerateDimensions(rhs_shape),
dot->mutable_operand(1)))
: dot->mutable_operand(1);
TF_ASSIGN_OR_RETURN(auto new_dot, MakeDotHlo(new_lhs, new_rhs, new_dnums,
dot->precision_config()));
// TODO(b/165824019): Add an optional preferred element type to MakeDotHlo.
new_dot->mutable_shape()->set_element_type(dot->shape().element_type());
TF_ASSIGN_OR_RETURN(
auto new_dot,
MakeDotHlo(new_lhs, new_rhs, new_dnums, dot->precision_config(),
/*preferred_element_type=*/dot->shape().element_type()));
if (ShapeUtil::Compatible(dot->shape(), new_dot->shape())) {
TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_dot));
} else {
@ -4678,10 +4678,10 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
|
||||
}
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto new_dot, MakeDotHlo(lhs, rhs, new_dnums, dot->precision_config()));
|
||||
auto new_dot,
|
||||
MakeDotHlo(lhs, rhs, new_dnums, dot->precision_config(),
|
||||
/*preferred_element_type=*/dot->shape().element_type()));
|
||||
dot->SetupDerivedInstruction(new_dot);
|
||||
// TODO(b/165824019): Add an optional preferred element type to MakeDotHlo.
|
||||
new_dot->mutable_shape()->set_element_type(dot->shape().element_type());
|
||||
if (reduce_dims.empty()) {
|
||||
return ReplaceInstruction(hlo, new_dot);
|
||||
}
|
||||
@ -5312,10 +5312,13 @@ StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
|
||||
if (!reverse_dimensions.empty()) {
|
||||
TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions));
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * new_convolution,
|
||||
MakeConvolveHlo(kernel, input, /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, swapped_window,
|
||||
swapped_dnums, precision_config));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloInstruction * new_convolution,
|
||||
MakeConvolveHlo(
|
||||
kernel, input, /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, swapped_window, swapped_dnums,
|
||||
precision_config,
|
||||
/*preferred_element_type=*/convolution->shape().element_type()));
|
||||
|
||||
convolution->SetupDerivedInstruction(new_convolution);
|
||||
TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, new_convolution));
|
||||
|
@ -3785,9 +3785,11 @@ TEST_P(ConvInputPaddingTest, DoTest) {
|
||||
ParseWindow(absl::StrCat("size=3x3 ", testcase.orig_conv_window))
|
||||
.ValueOrDie();
|
||||
builder.AddInstruction(HloInstruction::CreateConvolve(
|
||||
ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(),
|
||||
/*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, window, dnums)
|
||||
ShapeInference::InferConvolveShape(
|
||||
lhs_pad->shape(), filter->shape(),
|
||||
/*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, window, dnums,
|
||||
/*preferred_element_type=*/absl::nullopt)
|
||||
.ValueOrDie(),
|
||||
lhs_pad, filter, /*feature_group_count=*/1, /*batch_group_count=*/1,
|
||||
window, dnums, DefaultPrecisionConfig(2)));
|
||||
@ -3902,9 +3904,11 @@ TEST_P(ConvFilterPaddingTest, DoIt) {
|
||||
precision_config.add_operand_precision(PrecisionConfig::HIGHEST);
|
||||
|
||||
builder.AddInstruction(HloInstruction::CreateConvolve(
|
||||
ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(),
|
||||
/*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, window, dnums)
|
||||
ShapeInference::InferConvolveShape(
|
||||
input->shape(), rhs_pad->shape(),
|
||||
/*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, window, dnums,
|
||||
/*preferred_element_type=*/absl::nullopt)
|
||||
.ValueOrDie(),
|
||||
input, rhs_pad, /*feature_group_count=*/1, /*batch_group_count=*/1,
|
||||
window, dnums, precision_config));
|
||||
@ -4050,7 +4054,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
|
||||
b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));
|
||||
Shape out_shape = ShapeInference::InferConvolveShape(
|
||||
in_shape, f_shape, /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, window, dnums)
|
||||
/*batch_group_count=*/1, window, dnums,
|
||||
/*preferred_element_type=*/absl::nullopt)
|
||||
.ValueOrDie();
|
||||
if (options.output_minor_to_major_layout) {
|
||||
out_shape = ShapeUtil::MakeShapeWithLayout(F32, out_shape.dimensions(),
|
||||
|
@ -87,9 +87,11 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot(
|
||||
0,
|
||||
new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size());
|
||||
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
|
||||
MakeDotHlo(new_lhs, new_rhs, new_dim_numbers,
|
||||
batch_dot->precision_config()));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloInstruction * new_dot,
|
||||
MakeDotHlo(new_lhs, new_rhs, new_dim_numbers,
|
||||
batch_dot->precision_config(),
|
||||
/*preferred_element_type=*/batch_dot->shape().element_type()));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped,
|
||||
MakeReshapeHlo(batch_dot->shape(), new_dot));
|
||||
|
@ -299,9 +299,11 @@ Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
|
||||
window_dim->set_window_reversal(false);
|
||||
window_dim->set_window_dilation(1);
|
||||
HloInstruction* new_convolution =
|
||||
MakeConvolveHlo(activation, filter, convolution->feature_group_count(),
|
||||
/*batch_group_count=*/1, window, dim_numbers,
|
||||
convolution->precision_config())
|
||||
MakeConvolveHlo(
|
||||
activation, filter, convolution->feature_group_count(),
|
||||
/*batch_group_count=*/1, window, dim_numbers,
|
||||
convolution->precision_config(),
|
||||
/*preferred_element_type=*/convolution->shape().element_type())
|
||||
.ValueOrDie();
|
||||
convolution->SetupDerivedInstruction(new_convolution);
|
||||
TF_CHECK_OK(computation_->ReplaceInstruction(
|
||||
@ -650,9 +652,11 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
|
||||
window_dim->set_window_reversal(false);
|
||||
window_dim->set_window_dilation(1);
|
||||
HloInstruction* new_convolution =
|
||||
MakeConvolveHlo(activation, filter, /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, window, dim_numbers,
|
||||
convolution->precision_config())
|
||||
MakeConvolveHlo(
|
||||
activation, filter, /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, window, dim_numbers,
|
||||
convolution->precision_config(),
|
||||
/*preferred_element_type=*/convolution->shape().element_type())
|
||||
.ValueOrDie();
|
||||
convolution->SetupDerivedInstruction(new_convolution);
|
||||
changed_ = true;
|
||||
|
@ -142,7 +142,8 @@ CreateShardedConvForDotGeneralConvolution(
|
||||
ShapeInference::InferConvolveShape(
|
||||
sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
|
||||
/*feature_group_count=*/conv.feature_group_count(),
|
||||
/*batch_group_count=*/conv.batch_group_count(), window, conv_dnums));
|
||||
/*batch_group_count=*/conv.batch_group_count(), window, conv_dnums,
|
||||
/*preferred_element_type=*/conv.shape().element_type()));
|
||||
*sharded_conv_shape.mutable_layout() = conv.shape().layout();
|
||||
return HloInstruction::CreateConvolve(
|
||||
sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo,
|
||||
|
@ -110,7 +110,8 @@ TEST_F(GpuConvRewriterTest, BackwardFilterConvolve) {
|
||||
ShapeInference::InferConvolveShape(
|
||||
activations->shape(), gradients->shape(), /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, conv_window,
|
||||
tf_default_dnums_for_backward_filter_)
|
||||
tf_default_dnums_for_backward_filter_,
|
||||
/*preferred_element_type=*/absl::nullopt)
|
||||
.ConsumeValueOrDie(),
|
||||
activations, gradients, /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, conv_window,
|
||||
@ -150,7 +151,8 @@ TEST_F(GpuConvRewriterTest,
|
||||
ShapeInference::InferConvolveShape(
|
||||
activations->shape(), gradients->shape(), /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, conv_window,
|
||||
tf_default_dnums_for_backward_filter_)
|
||||
tf_default_dnums_for_backward_filter_,
|
||||
/*preferred_element_type=*/absl::nullopt)
|
||||
.ConsumeValueOrDie(),
|
||||
activations, gradients, /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, conv_window,
|
||||
@ -292,11 +294,12 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveEvenPadding) {
|
||||
DefaultPrecisionConfig(2)));
|
||||
// Verify the convolution's shape is consistent with ShapeInference.
|
||||
CHECK(ShapeUtil::Compatible(
|
||||
conv->shape(), ShapeInference::InferConvolveShape(
|
||||
output->shape(), reverse_kernel->shape(),
|
||||
/*feature_group_count=*/1, /*batch_group_count=*/1,
|
||||
conv_window, conv_dnums)
|
||||
.ValueOrDie()));
|
||||
conv->shape(),
|
||||
ShapeInference::InferConvolveShape(
|
||||
output->shape(), reverse_kernel->shape(),
|
||||
/*feature_group_count=*/1, /*batch_group_count=*/1, conv_window,
|
||||
conv_dnums, /*preferred_element_type=*/absl::nullopt)
|
||||
.ValueOrDie()));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
HloComputation* entry_computation =
|
||||
@ -337,10 +340,12 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolve1x1Filter) {
|
||||
conv_window.mutable_dimensions(1)->set_base_dilation(2);
|
||||
|
||||
builder.AddInstruction(HloInstruction::CreateConvolve(
|
||||
ShapeInference::InferConvolveShape(output->shape(), kernel->shape(),
|
||||
/*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, conv_window,
|
||||
tf_default_dnums_for_backward_input_)
|
||||
ShapeInference::InferConvolveShape(
|
||||
output->shape(), kernel->shape(),
|
||||
/*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, conv_window,
|
||||
tf_default_dnums_for_backward_input_,
|
||||
/*preferred_element_type=*/absl::nullopt)
|
||||
.ConsumeValueOrDie(),
|
||||
/*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, conv_window,
|
||||
@ -374,7 +379,8 @@ TEST_F(GpuConvRewriterTest,
|
||||
ShapeInference::InferConvolveShape(
|
||||
output->shape(), kernel->shape(), /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, default_conv_window_,
|
||||
tf_default_dnums_for_backward_input_)
|
||||
tf_default_dnums_for_backward_input_,
|
||||
/*preferred_element_type=*/absl::nullopt)
|
||||
.ConsumeValueOrDie(),
|
||||
/*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1,
|
||||
/*batch_group_count=*/1, default_conv_window_,
|
||||
@ -431,7 +437,8 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) {
|
||||
conv->shape(), ShapeInference::InferConvolveShape(
|
||||
output->shape(), reverse_kernel->shape(),
|
||||
/*feature_group_count=*/1, /*batch_group_count=*/1,
|
||||
conv_window, tf_default_dnums_for_backward_input_)
|
||||
conv_window, tf_default_dnums_for_backward_input_,
|
||||
/*preferred_element_type=*/absl::nullopt)
|
||||
.ValueOrDie()));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
@ -481,7 +488,8 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
|
||||
conv->shape(), ShapeInference::InferConvolveShape(
|
||||
output->shape(), reverse_kernel->shape(),
|
||||
/*feature_group_count=*/1, /*batch_group_count=*/1,
|
||||
conv_window, tf_default_dnums_for_backward_input_)
|
||||
conv_window, tf_default_dnums_for_backward_input_,
|
||||
/*preferred_element_type=*/absl::nullopt)
|
||||
.ValueOrDie()));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
@ -535,7 +543,8 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) {
|
||||
conv->shape(), ShapeInference::InferConvolveShape(
|
||||
output->shape(), reverse_kernel->shape(),
|
||||
/*feature_group_count=*/1, /*batch_group_count=*/1,
|
||||
conv_window, tf_default_dnums_for_backward_input_)
|
||||
conv_window, tf_default_dnums_for_backward_input_,
|
||||
/*preferred_element_type=*/absl::nullopt)
|
||||
.ValueOrDie()));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
@ -590,7 +599,8 @@ TEST_F(GpuConvRewriterTest,
|
||||
conv->shape(), ShapeInference::InferConvolveShape(
|
||||
output->shape(), reverse_kernel->shape(),
|
||||
/*feature_group_count=*/1, /*batch_group_count=*/1,
|
||||
conv_window, tf_default_dnums_for_backward_input_)
|
||||
conv_window, tf_default_dnums_for_backward_input_,
|
||||
/*preferred_element_type=*/absl::nullopt)
|
||||
.ValueOrDie()));
|
||||
|
||||
auto module = CreateNewVerifiedModule();
|
||||
|
@ -94,13 +94,15 @@ StatusOr<HloInstruction*> MakeConvolveHlo(
HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
int64 batch_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers,
const PrecisionConfig& precision_config) {
const PrecisionConfig& precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
HloComputation* computation = lhs->parent();
CHECK_EQ(computation, rhs->parent());
TF_ASSIGN_OR_RETURN(Shape convolve_shape,
ShapeInference::InferConvolveShape(
lhs->shape(), rhs->shape(), feature_group_count,
batch_group_count, window, dimension_numbers));
TF_ASSIGN_OR_RETURN(
Shape convolve_shape,
ShapeInference::InferConvolveShape(
lhs->shape(), rhs->shape(), feature_group_count, batch_group_count,
window, dimension_numbers, preferred_element_type));
return computation->AddInstruction(HloInstruction::CreateConvolve(
convolve_shape, lhs, rhs, feature_group_count, batch_group_count, window,
dimension_numbers, precision_config));
@ -281,14 +283,17 @@ HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
HloInstruction::CreateIota(shape, iota_dimension));
}

StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dim_numbers,
const PrecisionConfig& precision_config) {
StatusOr<HloInstruction*> MakeDotHlo(
HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dim_numbers,
const PrecisionConfig& precision_config,
absl::optional<PrimitiveType> preferred_element_type) {
HloComputation* computation = lhs->parent();
CHECK_EQ(computation, rhs->parent());
TF_ASSIGN_OR_RETURN(
Shape dot_shape,
ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers));
ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers,
preferred_element_type));
return computation->AddInstruction(HloInstruction::CreateDot(
dot_shape, lhs, rhs, dim_numbers, precision_config));
}
@ -59,11 +59,14 @@ StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,

// Creates a convolution HLO instruction and adds it to the computation
// containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
// If the result shape has integral element type, an optional
// preferred_element_type can be specified to override the element type.
StatusOr<HloInstruction*> MakeConvolveHlo(
HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
int64 batch_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers,
const PrecisionConfig& precision_config);
const PrecisionConfig& precision_config,
absl::optional<PrimitiveType> preferred_element_type);

// Creates a transpose HLO instruction and adds it to the computation containing
// `operand`.
@ -128,10 +131,14 @@ HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
int64 iota_dimension);

// Creates a Dot HLO instruction and adds it to the computation containing `lhs`
// and `rhs` (both must be in the same computation).
StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dim_numbers,
const PrecisionConfig& precision_config);
// and `rhs` (both must be in the same computation). If the result shape has
// integral element type, an optional preferred_element_type can be specified to
// override the element type.
StatusOr<HloInstruction*> MakeDotHlo(
HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dim_numbers,
const PrecisionConfig& precision_config,
absl::optional<PrimitiveType> preferred_element_type);

// Creates a Map HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation.
@ -388,9 +388,10 @@ StatusOr<Literal> HloEvaluator::EvaluateDotOp(
std::unique_ptr<HloInstruction> rhs_instr =
HloInstruction::CreateConstant(rhs.Clone());

TF_ASSIGN_OR_RETURN(
Shape dot_shape,
ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers));
TF_ASSIGN_OR_RETURN(Shape dot_shape,
ShapeInference::InferDotOpShape(
lhs.shape(), rhs.shape(), dim_numbers,
/*preferred_element_type=*/absl::nullopt));

std::unique_ptr<HloInstruction> cloned_instruction =
HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
@ -1290,7 +1290,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
TF_ASSIGN_OR_RETURN(auto inferred_return_shape,
ShapeInference::InferConvolveShape(
lhs_shape, rhs_shape, conv->feature_group_count(),
conv->batch_group_count(), window, dnums));
conv->batch_group_count(), window, dnums,
/*preferred_element_type=*/absl::nullopt));
CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
<< "return shape set to: " << ShapeUtil::HumanString(result_shape)
<< " but is inferred to be: "
@ -136,13 +136,12 @@ Status ShapeVerifier::HandleCopy(HloInstruction* copy) {
}

Status ShapeVerifier::HandleDot(HloInstruction* dot) {
TF_ASSIGN_OR_RETURN(Shape expected,
ShapeInference::InferDotOpShape(
dot->operand(0)->shape(), dot->operand(1)->shape(),
dot->dot_dimension_numbers()));
if (ShapeUtil::CanUpcastIntegral(expected, dot->shape())) {
expected.set_element_type(dot->shape().element_type());
}
TF_ASSIGN_OR_RETURN(
const Shape expected,
ShapeInference::InferDotOpShape(
dot->operand(0)->shape(), dot->operand(1)->shape(),
dot->dot_dimension_numbers(),
/*preferred_element_type=*/dot->shape().element_type()));
return CheckShape(dot, expected);
}

@ -152,10 +151,8 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
ShapeInference::InferConvolveShape(
convolution->operand(0)->shape(), convolution->operand(1)->shape(),
convolution->feature_group_count(), convolution->batch_group_count(),
convolution->window(), convolution->convolution_dimension_numbers()));
if (ShapeUtil::CanUpcastIntegral(expected, convolution->shape())) {
expected.set_element_type(convolution->shape().element_type());
}
convolution->window(), convolution->convolution_dimension_numbers(),
/*preferred_element_type=*/convolution->shape().element_type()));
return CheckShape(convolution, expected);
}
@ -26,12 +26,14 @@ StatusOr<absl::optional<Shape>> MaybeInferShape(
case HloOpcode::kDot:
return ShapeInference::InferDotOpShape(
instruction->operand(0)->shape(), instruction->operand(1)->shape(),
instruction->dot_dimension_numbers());
instruction->dot_dimension_numbers(),
/*preferred_element_type=*/absl::nullopt);
case HloOpcode::kConvolution:
return ShapeInference::InferConvolveShape(
instruction->operand(0)->shape(), instruction->operand(1)->shape(),
instruction->feature_group_count(), instruction->batch_group_count(),
instruction->window(), instruction->convolution_dimension_numbers());
instruction->window(), instruction->convolution_dimension_numbers(),
/*preferred_element_type=*/absl::nullopt);
default:
return absl::make_optional<Shape>();
}
@ -214,6 +214,37 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
output_is_dynamic);
}

StatusOr<PrimitiveType> MaybeUpcast(
PrimitiveType from_type,
absl::optional<PrimitiveType> preferred_element_type) {
if (!preferred_element_type.has_value() ||
*preferred_element_type == from_type) {
return from_type;
}
if (primitive_util::IsIntegralType(from_type) !=
primitive_util::IsIntegralType(*preferred_element_type)) {
return InvalidArgument(
"`preferred_element_type` and the original type must both be integral "
"or both be floating point.");
}
if (!primitive_util::IsSignedIntegralType(from_type) !=
!primitive_util::IsSignedIntegralType(*preferred_element_type)) {
return InvalidArgument(
"`preferred_element_type` must have the same signedness as the "
"original type.");
}
if (primitive_util::BitWidth(*preferred_element_type) <
primitive_util::BitWidth(from_type)) {
if (primitive_util::IsFloatingPointType(from_type)) {
return from_type;
}
return InvalidArgument(
"`preferred_element_type` must not be narrower than the original "
"type.");
}
return *preferred_element_type;
}

} // namespace

/* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
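For illustration, a sketch of what MaybeUpcast accepts, expressed through the public InferDotOpShape entry point; it mirrors the tests added later in this change and uses 2x2 operand shapes purely for brevity.

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(0);
// S8 x S16 with preferred S32: the inferred element type is upcast to S32.
StatusOr<Shape> upcast = ShapeInference::InferDotOpShape(
    ShapeUtil::MakeShape(S8, {2, 2}), ShapeUtil::MakeShape(S16, {2, 2}),
    dnums, /*preferred_element_type=*/S32);
// Rejected with InvalidArgument: a narrower integral preference (S8), a
// signedness mismatch (U32), or an integral vs. floating-point mismatch (F32).
// Ignored rather than rejected: a narrower floating-point preference
// (e.g. BF16 on an F32 result), which keeps the inferred F32 type.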
@ -622,7 +653,8 @@ Status ValidateDotDimensionNumbers(

/* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
const Shape& lhs, const Shape& rhs,
const DotDimensionNumbers& dimension_numbers) {
const DotDimensionNumbers& dimension_numbers,
absl::optional<PrimitiveType> preferred_element_type) {
TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot"));
TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot"));

@ -700,8 +732,11 @@ Status ValidateDotDimensionNumbers(
is_dynamic.push_back(rhs.is_dynamic_dimension(i));
}
}
Shape result = ShapeUtil::MakeShape(
ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions, is_dynamic);
TF_ASSIGN_OR_RETURN(
PrimitiveType type,
MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
preferred_element_type));
Shape result = ShapeUtil::MakeShape(type, dimensions, is_dynamic);

TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
@ -1586,7 +1621,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
const Shape& lhs, const Shape& rhs, int64 feature_group_count,
int64 batch_group_count, const Window& window,
const ConvolutionDimensionNumbers& dnums) {
const ConvolutionDimensionNumbers& dnums,
absl::optional<PrimitiveType> preferred_element_type) {
TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));

@ -1833,8 +1869,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
}
}
return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
dimensions, is_dynamic);
TF_ASSIGN_OR_RETURN(
PrimitiveType type,
MaybeUpcast(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
preferred_element_type));
return ShapeUtil::MakeShape(type, dimensions, is_dynamic);
}

/* static */ StatusOr<Shape> ShapeInference::InferFftShape(
@ -105,12 +105,14 @@ class ShapeInference {
const Shape& output_grad_shape,
int64 feature_index);

// Infers the shape produced by applying the given convolutional
// filter (rhs) to lhs in the way specified by the fields on window.
// Infers the shape produced by applying the given convolutional filter (rhs)
// to lhs in the way specified by the fields on window. An optional
// preferred_element_type can be specified to upcast the element type.
static StatusOr<Shape> InferConvolveShape(
const Shape& lhs, const Shape& rhs, int64 feature_group_count,
int64 batch_group_count, const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers);
const ConvolutionDimensionNumbers& dimension_numbers,
absl::optional<PrimitiveType> preferred_element_type);

// Infers the shape produced by the given FFT type on the given operand.
static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type,
@ -298,10 +300,12 @@ class ShapeInference {
absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply);

// Helper that infers the shape produced by performing a dot operation with
// the given LHS and RHS shapes.
// the given LHS and RHS shapes. An optional preferred_element_type can be
// specified to upcast the element type.
static StatusOr<Shape> InferDotOpShape(
const Shape& lhs, const Shape& rhs,
const DotDimensionNumbers& dimension_numbers);
const DotDimensionNumbers& dimension_numbers,
absl::optional<PrimitiveType> preferred_element_type);

// Helper that infers the shape of the tensor produced by a gather operation
// with the given input shape, gather indices shape and gather dimension
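For illustration, a minimal sketch of the backward-compatible path (assumed shapes and dimension numbers, written as it might appear in a test body): passing absl::nullopt leaves the default inference untouched, so existing callers only need to thread the extra argument through.

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(0);
// With no preference, S8 x S8 still infers an S8 result as before.
TF_ASSERT_OK_AND_ASSIGN(
    Shape unchanged,
    ShapeInference::InferDotOpShape(
        ShapeUtil::MakeShape(S8, {4, 8}), ShapeUtil::MakeShape(S8, {8, 4}),
        dnums, /*preferred_element_type=*/absl::nullopt));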
@ -437,7 +437,7 @@ TEST_F(ShapeInferenceTest, Convolve) {
dim1->set_base_dilation(1);
auto inferred_status = ShapeInference::InferConvolveShape(
lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
window, dnums);
window, dnums, /*preferred_element_type=*/absl::nullopt);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
@ -483,7 +483,7 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
dim1->set_base_dilation(1);
auto inferred_status = ShapeInference::InferConvolveShape(
lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
window, dnums);
window, dnums, /*preferred_element_type=*/absl::nullopt);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}),
@ -529,7 +529,7 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
dim1->set_base_dilation(2);
auto inferred_status = ShapeInference::InferConvolveShape(
lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
window, dnums);
window, dnums, /*preferred_element_type=*/absl::nullopt);
ASSERT_IS_OK(inferred_status.status());
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}),
@ -568,7 +568,7 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
dim1->set_padding_high(1);
auto inferred_status = ShapeInference::InferConvolveShape(
lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
window, dnums);
window, dnums, /*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
HasSubstr("each dimension exactly once"));
@ -605,12 +605,150 @@ TEST_F(ShapeInferenceTest, ConvolveBatchGroupCountUnequalOutputFeature) {
dim1->set_window_dilation(2);
auto inferred_status = ShapeInference::InferConvolveShape(
lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/6,
window, dnums);
window, dnums, /*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
HasSubstr("to be a multiple of batch group count"));
}

struct ConvolveArgs {
Shape lhs_shape;
Shape rhs_shape;
ConvolutionDimensionNumbers dnums;
Window window;
};

ConvolveArgs MakeConvolveArgs(PrimitiveType lhs_type, PrimitiveType rhs_type) {
ConvolveArgs args;
ConvolutionDimensionNumbers& dnums = args.dnums;

// Dimension order: batch, feature, x0, x1
args.lhs_shape = ShapeUtil::MakeShape(lhs_type, {10, 11, 3, 4});
dnums.set_input_batch_dimension(0);
dnums.set_output_batch_dimension(0);
dnums.set_input_feature_dimension(1);
dnums.set_output_feature_dimension(1);
dnums.add_input_spatial_dimensions(2);
dnums.add_output_spatial_dimensions(2);
dnums.add_input_spatial_dimensions(3);
dnums.add_output_spatial_dimensions(3);

// Dimension order: x1, batch, feature, x0
args.rhs_shape = ShapeUtil::MakeShape(rhs_type, {2, 12, 11, 3});
dnums.set_kernel_input_feature_dimension(2);
dnums.set_kernel_output_feature_dimension(1);
dnums.add_kernel_spatial_dimensions(3);
dnums.add_kernel_spatial_dimensions(0);

auto dim0 = args.window.add_dimensions();
auto dim1 = args.window.add_dimensions();
dim0->set_size(3);
dim0->set_stride(2);
dim0->set_padding_low(1);
dim0->set_padding_high(1);
dim0->set_window_dilation(1);
dim0->set_base_dilation(1);
dim1->set_size(2);
dim1->set_stride(1);
dim1->set_padding_low(0);
dim1->set_padding_high(0);
dim1->set_window_dilation(1);
dim1->set_base_dilation(1);
return args;
}
TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementType) {
ConvolveArgs args = MakeConvolveArgs(S8, S16);
TF_ASSERT_OK_AND_ASSIGN(
Shape inferred_shape,
ShapeInference::InferConvolveShape(
args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
/*batch_group_count=*/1, args.window, args.dnums,
/*preferred_element_type=*/S16))
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S16, {10, 12, 2, 3}),
inferred_shape));
}

TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementTypeSameAsInferredType) {
ConvolveArgs args = MakeConvolveArgs(S8, S16);
TF_ASSERT_OK_AND_ASSIGN(
Shape inferred_shape,
ShapeInference::InferConvolveShape(
args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
/*batch_group_count=*/1, args.window, args.dnums,
/*preferred_element_type=*/S32))
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
inferred_shape));
}

TEST_F(ShapeInferenceTest,
FloatingPointConvolveWithNarrowerPreferredElementType) {
ConvolveArgs args = MakeConvolveArgs(F32, F32);
TF_ASSERT_OK_AND_ASSIGN(
Shape inferred_shape,
ShapeInference::InferConvolveShape(
args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
/*batch_group_count=*/1, args.window, args.dnums,
/*preferred_element_type=*/BF16))
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
inferred_shape));
}

TEST_F(ShapeInferenceTest,
FloatingPointConvolveWithInvalidPreferredElementType) {
ConvolveArgs args = MakeConvolveArgs(BF16, BF16);
auto inferred_status =
ShapeInference::InferConvolveShape(
args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
/*batch_group_count=*/1, args.window, args.dnums,
/*preferred_element_type=*/S32)
.status();
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.error_message(),
HasSubstr("must both be integral or both be floating point"));
}

TEST_F(ShapeInferenceTest,
IntegralConvolveWithFloatingPointPreferredElementType) {
ConvolveArgs args = MakeConvolveArgs(S8, S16);
auto inferred_status =
ShapeInference::InferConvolveShape(
args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
/*batch_group_count=*/1, args.window, args.dnums,
/*preferred_element_type=*/F32)
.status();
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.error_message(),
HasSubstr("must both be integral or both be floating point"));
}

TEST_F(ShapeInferenceTest,
ConvolveWithPreferredElementTypeWithDifferentSignedness) {
ConvolveArgs args = MakeConvolveArgs(S8, S16);
auto inferred_status =
ShapeInference::InferConvolveShape(
args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
/*batch_group_count=*/1, args.window, args.dnums,
/*preferred_element_type=*/U32)
.status();
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.error_message(),
HasSubstr("must have the same signedness as the original type"));
}

TEST_F(ShapeInferenceTest, ConvolveWithNarrowerPreferredElementType) {
ConvolveArgs args = MakeConvolveArgs(S8, S16);
auto inferred_status =
ShapeInference::InferConvolveShape(
args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
/*batch_group_count=*/1, args.window, args.dnums,
/*preferred_element_type=*/S8)
.status();
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.error_message(),
HasSubstr("must not be narrower than the original type"));
}
namespace fft {

static const char* unsupported_rank = "only supports ranks 1-3";
@ -1282,8 +1420,8 @@ TEST_F(ShapeInferenceTest, BroadcastScalar) {
// scalar <dot> vector: ok
TEST_F(ShapeInferenceTest, ScalarDotVector) {
DotDimensionNumbers dot_dnums;
auto inferred_status =
ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums);
auto inferred_status = ShapeInference::InferDotOpShape(
f32_, vector_32_, dot_dnums, /*preferred_element_type=*/absl::nullopt);
EXPECT_TRUE(inferred_status.ok());
EXPECT_EQ(inferred_status.ValueOrDie(), vector_32_);
}
@ -1294,7 +1432,8 @@ TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) {
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status = ShapeInference::InferDotOpShape(
ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums);
ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
EXPECT_TRUE(inferred_status.ok());
EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(),
ShapeUtil::MakeShape(F32, {32, 32, 64})));
@ -1306,11 +1445,13 @@ TEST_F(ShapeInferenceTest, VectorDotVector) {
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status =
ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums);
ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
auto inferred_status_mismatch =
ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums);
ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status_mismatch.ok());
}

@ -1320,11 +1461,13 @@ TEST_F(ShapeInferenceTest, MatrixDotVector) {
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status =
ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums);
ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_));
auto inferred_status_mismatch =
ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums);
ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status_mismatch.ok());
}

@ -1334,11 +1477,13 @@ TEST_F(ShapeInferenceTest, VectorDotMatrix) {
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status =
ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums);
ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_));
auto inferred_status_mismatch =
ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums);
ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status_mismatch.ok());
}

@ -1348,7 +1493,8 @@ TEST_F(ShapeInferenceTest, MatrixDotMatrix) {
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status_match =
ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums);
ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(
ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_))
@ -1356,7 +1502,8 @@ TEST_F(ShapeInferenceTest, MatrixDotMatrix) {
<< ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
<< " expected: " << ShapeUtil::HumanString(matrix_64_48_);
auto inferred_status_mismatch =
ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums);
ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status_mismatch.ok());
}

@ -1376,7 +1523,8 @@ TEST_F(ShapeInferenceTest, DotGeneral) {
dot_dnums.add_rhs_batch_dimensions(1);

auto inferred_status_match =
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(
ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape))
@ -1399,7 +1547,8 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) {
dot_dnums.add_rhs_batch_dimensions(0);

auto inferred_status =
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
HasSubstr("Must specify the same number of contracting "
@ -1421,7 +1570,8 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) {
dot_dnums.add_rhs_batch_dimensions(0);

auto inferred_status =
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
EXPECT_TRUE(inferred_status.ok());
EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape));
}
@ -1461,7 +1611,8 @@ TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) {
dot_dnums.add_rhs_batch_dimensions(0);

auto inferred_status =
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
HasSubstr("Batch dimension sizes must match"));
@ -1480,7 +1631,8 @@ TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimNumbersPasses) {
dot_dnums.add_rhs_batch_dimensions(1);

auto inferred_status =
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_TRUE(inferred_status.ok());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(),
ShapeUtil::MakeShape(F32, {2, 11, 14})));
@ -1499,7 +1651,8 @@ TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) {
dot_dnums.add_rhs_batch_dimensions(1);

auto inferred_status =
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
HasSubstr("A dimension number is out of range"));
@ -1518,12 +1671,108 @@ TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) {
dot_dnums.add_rhs_batch_dimensions(1);

auto inferred_status =
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
/*preferred_element_type=*/absl::nullopt);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
HasSubstr("A dimension number is not unique"));
}
TEST_F(ShapeInferenceTest, DotWithIntegralPreferredElementType) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape,
ShapeInference::InferDotOpShape(
ShapeUtil::MakeShape(S8, {32, 32}),
ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
/*preferred_element_type=*/S32));
EXPECT_TRUE(
ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(S32, {32, 32})));
}

TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeSameAsInferredType) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape,
ShapeInference::InferDotOpShape(
ShapeUtil::MakeShape(BF16, {32, 32}),
ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums,
/*preferred_element_type=*/F32));
EXPECT_TRUE(
ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(F32, {32, 32})));
}

TEST_F(ShapeInferenceTest, FloatingPointDotWithNarrowerPreferredElementType) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape,
ShapeInference::InferDotOpShape(
ShapeUtil::MakeShape(BF16, {32, 32}),
ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums,
/*preferred_element_type=*/BF16));
EXPECT_TRUE(
ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(F32, {32, 32})));
}

TEST_F(ShapeInferenceTest, FloatingPointDotWithInvalidPreferredElementType) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status = ShapeInference::InferDotOpShape(
ShapeUtil::MakeShape(BF16, {32, 32}),
ShapeUtil::MakeShape(BF16, {32, 32}), dot_dnums,
/*preferred_element_type=*/S32)
.status();
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.error_message(),
HasSubstr("must both be integral or both be floating point"));
}

TEST_F(ShapeInferenceTest, IntegralDotWithFloatingPointPreferredElementType) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status = ShapeInference::InferDotOpShape(
ShapeUtil::MakeShape(S8, {32, 32}),
ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
/*preferred_element_type=*/F32)
.status();
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.error_message(),
HasSubstr("must both be integral or both be floating point"));
}

TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeWithDifferentSignedness) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status = ShapeInference::InferDotOpShape(
ShapeUtil::MakeShape(S8, {32, 32}),
ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
/*preferred_element_type=*/U32)
.status();
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.error_message(),
HasSubstr("must have the same signedness as the original type"));
}

TEST_F(ShapeInferenceTest, DotWithNarrowerPreferredElementType) {
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status = ShapeInference::InferDotOpShape(
ShapeUtil::MakeShape(S8, {32, 32}),
ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
/*preferred_element_type=*/S8)
.status();
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.error_message(),
HasSubstr("must not be narrower than the original type"));
}

TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) {
// Test variations of broadcasting a vector for a binary add with a
// matrix.
@ -1520,10 +1520,11 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
->set_padding_low(0);
TF_ASSIGN_OR_RETURN(
HloInstruction * new_conv,
MakeConvolveHlo(activations_new, /*rhs=*/convolution->mutable_operand(1),
convolution->feature_group_count(),
convolution->batch_group_count(), new_window,
new_dim_numbers, convolution->precision_config()));
MakeConvolveHlo(
activations_new, /*rhs=*/convolution->mutable_operand(1),
convolution->feature_group_count(), convolution->batch_group_count(),
new_window, new_dim_numbers, convolution->precision_config(),
/*preferred_element_type=*/convolution->shape().element_type()));
convolution->SetupDerivedInstruction(new_conv);

old_to_new_instrs_[convolution] = new_conv;
@ -1800,10 +1801,11 @@ Status ConvolutionVisitor::PropagateOnBackpropFilterConv(

TF_ASSIGN_OR_RETURN(
HloInstruction * new_conv,
MakeConvolveHlo(activations_new, kernel_new,
convolution->feature_group_count(),
convolution->batch_group_count(), new_window,
new_dim_numbers, convolution->precision_config()));
MakeConvolveHlo(
activations_new, kernel_new, convolution->feature_group_count(),
convolution->batch_group_count(), new_window, new_dim_numbers,
convolution->precision_config(),
/*preferred_element_type=*/convolution->shape().element_type()));
convolution->SetupDerivedInstruction(new_conv);

std::vector<int64> output_sizes(new_conv->shape().dimensions().begin(),
@ -2043,10 +2045,11 @@ Status ConvolutionVisitor::PerformSpaceToBatchOnConvolution(
->set_padding_low(0);
TF_ASSIGN_OR_RETURN(
HloInstruction * new_conv,
MakeConvolveHlo(activations, /*rhs=*/convolution->mutable_operand(1),
convolution->feature_group_count(),
convolution->batch_group_count(), new_window,
new_dim_numbers, convolution->precision_config()));
MakeConvolveHlo(
activations, /*rhs=*/convolution->mutable_operand(1),
convolution->feature_group_count(), convolution->batch_group_count(),
new_window, new_dim_numbers, convolution->precision_config(),
/*preferred_element_type=*/convolution->shape().element_type()));
convolution->SetupDerivedInstruction(new_conv);

VLOG(1) << "Space-to-batched convolution " << new_conv->ToString();
@ -950,7 +950,8 @@ StatusOr<std::unique_ptr<HloInstruction>> CreateShardedConvConvolution(
Shape sharded_conv_shape,
ShapeInference::InferConvolveShape(
sharded_lhs_hlo->shape(), sharded_rhs_hlo->shape(),
feature_group_count, batch_group_count, window, conv_dnums));
feature_group_count, batch_group_count, window, conv_dnums,
/*preferred_element_type=*/conv.shape().element_type()));
*sharded_conv_shape.mutable_layout() = conv.shape().layout();
return HloInstruction::CreateConvolve(
sharded_conv_shape, sharded_lhs_hlo, sharded_rhs_hlo, feature_group_count,
@ -80,8 +80,9 @@ Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) {
const Window& conv_window) -> StatusOr<HloInstruction*> {
TF_ASSIGN_OR_RETURN(
auto sharded_dot_shape,
ShapeInference::InferDotOpShape(l->shape(), r->shape(),
hlo->dot_dimension_numbers()));
ShapeInference::InferDotOpShape(
l->shape(), r->shape(), hlo->dot_dimension_numbers(),
/*preferred_element_type=*/hlo->shape().element_type()));
return b->AddInstruction(HloInstruction::CreateDot(
sharded_dot_shape, l, r, hlo->dot_dimension_numbers(),
hlo->precision_config()));
@ -242,7 +242,8 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) {
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
x->shape(), transpose_y->shape(), /*feature_group_count=*/1,
/*batch_group_count=*/1, window, dnums);
/*batch_group_count=*/1, window, dnums,
/*preferred_element_type=*/absl::nullopt);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
conv_shape.ValueOrDie(), x, transpose_y,
@ -298,7 +299,8 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
x->shape(), transpose_y->shape(), /*feature_group_count=*/1,
/*batch_group_count=*/1, window, dnums);
/*batch_group_count=*/1, window, dnums,
/*preferred_element_type=*/absl::nullopt);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
conv_shape.ValueOrDie(), x, transpose_y,
@ -359,7 +361,8 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
transpose_x->shape(), y->shape(), /*feature_group_count=*/1,
/*batch_group_count=*/1, window, dnums);
/*batch_group_count=*/1, window, dnums,
/*preferred_element_type=*/absl::nullopt);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
conv_shape.ValueOrDie(), transpose_x, y,
@ -426,7 +429,8 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) {
}
StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
transpose_x->shape(), y->shape(), /*feature_group_count=*/1,
/*batch_group_count=*/1, window, dnums);
/*batch_group_count=*/1, window, dnums,
/*preferred_element_type=*/absl::nullopt);
EXPECT_IS_OK(conv_shape);
HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
conv_shape.ValueOrDie(), transpose_x, y,