[XLA] Add support for convolutions with no spatial dimensions
PiperOrigin-RevId: 173126950
commit 46ab25e4de (parent fc56349b7f)
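For orientation before the hunks below: a convolution operand's rank is the number of spatial dimensions plus two (one batch dimension and one feature dimension), so a convolution with no spatial dimensions operates on rank-2 operands and is effectively a per-batch matrix multiply between a [batch, in_features] input and an [in_features, out_features] kernel. The following standalone sketch is illustrative only (plain C++, not XLA code; the function name and nested-vector types are assumptions) and shows the arithmetic such a convolution performs.

#include <cstddef>
#include <vector>

// Illustrative only: out[b, of] = sum over if of in[b, if] * kernel[if, of].
std::vector<std::vector<float>> ConvNoSpatialDims(
    const std::vector<std::vector<float>>& input,     // [batch, in_features]
    const std::vector<std::vector<float>>& kernel) {  // [in_features, out_features]
  const std::size_t batch = input.size();
  const std::size_t in_features = kernel.size();
  const std::size_t out_features = kernel.empty() ? 0 : kernel[0].size();
  std::vector<std::vector<float>> output(
      batch, std::vector<float>(out_features, 0.0f));
  for (std::size_t b = 0; b < batch; ++b) {
    for (std::size_t of = 0; of < out_features; ++of) {
      for (std::size_t i = 0; i < in_features; ++i) {
        output[b][of] += input[b][i] * kernel[i][of];
      }
    }
  }
  return output;
}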
@@ -663,7 +663,7 @@ bool ComputationBuilder::VerifyConvolution(
     return false;
   }
   int num_dims = ShapeUtil::Rank(lhs_shape);
-  if (num_dims < 3) {
+  if (num_dims < 2) {
     NoteError(InvalidArgument(
         "Convolution expects argument arrays with >= 3 dimensions. "
         "Got: %s and %s",
@@ -201,17 +201,18 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
   static bool Run(
       HloComputation* computation, bool is_layout_sensitive,
       AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback,
-      bool enable_dot_simplification);
+      bool enable_dot_simplification, bool enable_conv_simplification);
 
  private:
   explicit AlgebraicSimplifierVisitor(
       HloComputation* computation, bool is_layout_sensitive,
       AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback,
-      bool enable_dot_simplification)
+      bool enable_dot_simplification, bool enable_conv_simplification)
       : computation_(computation),
         is_layout_sensitive_(is_layout_sensitive),
         valid_bitcast_callback_(std::move(valid_bitcast_callback)),
-        enable_dot_simplification_(enable_dot_simplification) {}
+        enable_dot_simplification_(enable_dot_simplification),
+        enable_conv_simplification_(enable_conv_simplification) {}
 
   // Convenience method for replacing an instruction with a bitcast.
   void ReplaceWithBitcast(HloInstruction* instruction);
@@ -287,15 +288,18 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
 
   // Disable dot simplication on platforms where it causes a slowdown.
   bool enable_dot_simplification_;
+
+  // Disable convolution simplication on platforms where it causes a slowdown.
+  bool enable_conv_simplification_;
 };
 
 bool AlgebraicSimplifierVisitor::Run(
     HloComputation* computation, bool is_layout_sensitive,
     AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback,
-    bool enable_dot_simplification) {
-  AlgebraicSimplifierVisitor visitor(computation, is_layout_sensitive,
-                                     std::move(valid_bitcast_callback),
-                                     enable_dot_simplification);
+    bool enable_dot_simplification, bool enable_conv_simplification) {
+  AlgebraicSimplifierVisitor visitor(
+      computation, is_layout_sensitive, std::move(valid_bitcast_callback),
+      enable_dot_simplification, enable_conv_simplification);
   TF_CHECK_OK(computation->Accept(&visitor));
   return visitor.changed_;
 }
@@ -1459,6 +1463,9 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
 Status AlgebraicSimplifierVisitor::HandleConvolution(
     HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs,
     const Window& window) {
+  if (!enable_conv_simplification_) {
+    return Status::OK();
+  }
   // HandleConvolution tries to replace a convolution with a DOT instruction.
   //
   // Only add when bitcasts can be used:
@@ -1962,9 +1969,9 @@ StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
       "AlgebraicSimplifier::Run(), before:\n" + module->ToString());
   bool changed = false;
   for (auto* comp : module->MakeNonfusionComputations()) {
-    if (AlgebraicSimplifierVisitor::Run(comp, is_layout_sensitive_,
-                                        valid_bitcast_callback_,
-                                        enable_dot_simplification_)) {
+    if (AlgebraicSimplifierVisitor::Run(
+            comp, is_layout_sensitive_, valid_bitcast_callback_,
+            enable_dot_simplification_, enable_conv_simplification_)) {
       changed = true;
     }
   }
@@ -40,11 +40,13 @@ class AlgebraicSimplifier : public HloPassInterface {
   // bitcasts.
   AlgebraicSimplifier(bool is_layout_sensitive,
                       ValidBitcastCallback valid_bitcast_callback,
-                      bool enable_dot_simplification = true)
+                      bool enable_dot_simplification = true,
+                      bool enable_conv_simplification = true)
       : is_layout_sensitive_(is_layout_sensitive),
         valid_bitcast_callback_(std::move(valid_bitcast_callback)),
-        enable_dot_simplification_(enable_dot_simplification) {}
-  ~AlgebraicSimplifier() override {}
+        enable_dot_simplification_(enable_dot_simplification),
+        enable_conv_simplification_(enable_conv_simplification) {}
+  ~AlgebraicSimplifier() override = default;
   tensorflow::StringPiece name() const override { return "algsimp"; }
 
   // Run algebraic simplification on the given computation. Returns whether the
@@ -57,6 +59,9 @@ class AlgebraicSimplifier : public HloPassInterface {
 
   // Enable dot simplication on platforms where it is profitable.
   bool enable_dot_simplification_;
+
+  // Enable convolution simplication on platforms where it is profitable.
+  bool enable_conv_simplification_;
 };
 
 }  // namespace xla
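With the header change above, a backend can opt out of convolution simplification the same way it already could for dot simplification. A hypothetical call site might look like the sketch below; only the constructor arguments and the StatusOr<bool> Run(HloModule*) entry point come from this diff, while the surrounding code and the valid_bitcast_callback variable are assumptions.

// Hypothetical call site (not from this commit): disable the conv->dot rewrite
// on a platform where it is a slowdown. `valid_bitcast_callback` is assumed to
// be an AlgebraicSimplifier::ValidBitcastCallback obtained elsewhere, and
// `module` an HloModule*.
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                               std::move(valid_bitcast_callback),
                               /*enable_dot_simplification=*/true,
                               /*enable_conv_simplification=*/false);
TF_ASSIGN_OR_RETURN(bool changed, simplifier.Run(module));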
@@ -398,7 +398,9 @@ Status HloCostAnalysis::HandleConvolution(HloInstruction* convolution,
   // For each output element, we do one fma per element in the kernel at some
   // given output feature index.
   const int64 fmas_per_output_element =
-      ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features;
+      output_features > 0
+          ? ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features
+          : 0;
   const int64 output_elements = ShapeUtil::ElementsIn(convolution->shape());
   current_properties_[kFlopsKey] =
       output_elements * fmas_per_output_element * kFmaFlops;
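The cost-analysis change above guards the per-output-element FMA count so a zero output_features count no longer divides by zero. A self-contained restatement of that computation (plain C++; the function names and the FMA-to-flop factor of 2 are illustrative assumptions):

#include <cstdint>

// Per-output-element FMA count, guarded as in the diff: zero when there are
// no output features, which avoids a division by zero.
int64_t FmasPerOutputElement(int64_t kernel_elements, int64_t output_features) {
  return output_features > 0 ? kernel_elements / output_features : 0;
}

// Total flop estimate. The factor of 2 stands in for an assumed kFmaFlops
// value (one multiply plus one add per fused multiply-add).
int64_t ConvolutionFlops(int64_t output_elements, int64_t kernel_elements,
                         int64_t output_features) {
  return output_elements *
         FmasPerOutputElement(kernel_elements, output_features) * 2;
}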
@@ -547,7 +547,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     const auto& dnums = conv->convolution_dimension_numbers();
     const int64 num_spatial_dims = dnums.spatial_dimensions_size();
     CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size());
-    CHECK_GE(num_spatial_dims, 1);
+    CHECK_GE(num_spatial_dims, 0);
     CHECK_EQ(window.dimensions_size(), num_spatial_dims);
 
     const auto lhs_rank = ShapeUtil::Rank(lhs_shape);
@@ -1385,14 +1385,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
         "Window: %s",
         window.DebugString().c_str());
   }
-  int num_spatial_dims = dnums.spatial_dimensions_size();
-  if (num_spatial_dims < 1) {
-    return InvalidArgument(
-        "Convolution requires at least one spatial dimension.\n"
-        "Window: %s",
-        window.DebugString().c_str());
-  }
 
+  const int num_spatial_dims = dnums.spatial_dimensions_size();
   if (window.dimensions_size() != num_spatial_dims) {
     return InvalidArgument(
         "Window must have same number of dimensions as dimension numbers.\n"
@@ -1400,7 +1394,7 @@
         window.DebugString().c_str(), dnums.DebugString().c_str());
   }
 
-  int num_dims = num_spatial_dims + 2;
+  const int num_dims = num_spatial_dims + 2;
   if (ShapeUtil::Rank(lhs) != num_dims) {
     return InvalidArgument(
         "The LHS argument to a convolution should have rank %d.\n"
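As the shape-inference hunks above encode, the expected operand rank is num_spatial_dims + 2 (batch plus feature plus one per spatial dimension), so removing the "at least one spatial dimension" requirement makes rank-2 operands legal. A tiny standalone illustration of that relation (plain C++; names are illustrative, not XLA's):

#include <cstdio>

// Expected operand rank for a convolution: one batch dimension, one feature
// dimension, and one dimension per spatial dimension.
int ExpectedConvOperandRank(int num_spatial_dims) { return num_spatial_dims + 2; }

int main() {
  // With no spatial dimensions the operands are rank 2, e.g. [batch, features].
  std::printf("%d\n", ExpectedConvOperandRank(0));  // prints 2
  std::printf("%d\n", ExpectedConvOperandRank(2));  // prints 4 (e.g. an NHWC operand)
  return 0;
}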