[XLA] Add support for convolutions with no spatial dimensions

PiperOrigin-RevId: 173126950
This commit is contained in:
David Majnemer 2017-10-23 09:36:32 -07:00 committed by TensorFlower Gardener
parent fc56349b7f
commit 46ab25e4de
6 changed files with 32 additions and 24 deletions

View File

@ -663,7 +663,7 @@ bool ComputationBuilder::VerifyConvolution(
return false;
}
int num_dims = ShapeUtil::Rank(lhs_shape);
if (num_dims < 3) {
if (num_dims < 2) {
NoteError(InvalidArgument(
"Convolution expects argument arrays with >= 3 dimensions. "
"Got: %s and %s",

View File

@ -201,17 +201,18 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
static bool Run(
HloComputation* computation, bool is_layout_sensitive,
AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback,
bool enable_dot_simplification);
bool enable_dot_simplification, bool enable_conv_simplification);
private:
explicit AlgebraicSimplifierVisitor(
HloComputation* computation, bool is_layout_sensitive,
AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback,
bool enable_dot_simplification)
bool enable_dot_simplification, bool enable_conv_simplification)
: computation_(computation),
is_layout_sensitive_(is_layout_sensitive),
valid_bitcast_callback_(std::move(valid_bitcast_callback)),
enable_dot_simplification_(enable_dot_simplification) {}
enable_dot_simplification_(enable_dot_simplification),
enable_conv_simplification_(enable_conv_simplification) {}
// Convenience method for replacing an instruction with a bitcast.
void ReplaceWithBitcast(HloInstruction* instruction);
@ -287,15 +288,18 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Disable dot simplication on platforms where it causes a slowdown.
bool enable_dot_simplification_;
// Disable convolution simplication on platforms where it causes a slowdown.
bool enable_conv_simplification_;
};
bool AlgebraicSimplifierVisitor::Run(
HloComputation* computation, bool is_layout_sensitive,
AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback,
bool enable_dot_simplification) {
AlgebraicSimplifierVisitor visitor(computation, is_layout_sensitive,
std::move(valid_bitcast_callback),
enable_dot_simplification);
bool enable_dot_simplification, bool enable_conv_simplification) {
AlgebraicSimplifierVisitor visitor(
computation, is_layout_sensitive, std::move(valid_bitcast_callback),
enable_dot_simplification, enable_conv_simplification);
TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_;
}
@ -1459,6 +1463,9 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
Status AlgebraicSimplifierVisitor::HandleConvolution(
HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs,
const Window& window) {
if (!enable_conv_simplification_) {
return Status::OK();
}
// HandleConvolution tries to replace a convolution with a DOT instruction.
//
// Only add when bitcasts can be used:
@ -1962,9 +1969,9 @@ StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
"AlgebraicSimplifier::Run(), before:\n" + module->ToString());
bool changed = false;
for (auto* comp : module->MakeNonfusionComputations()) {
if (AlgebraicSimplifierVisitor::Run(comp, is_layout_sensitive_,
valid_bitcast_callback_,
enable_dot_simplification_)) {
if (AlgebraicSimplifierVisitor::Run(
comp, is_layout_sensitive_, valid_bitcast_callback_,
enable_dot_simplification_, enable_conv_simplification_)) {
changed = true;
}
}

View File

@ -40,11 +40,13 @@ class AlgebraicSimplifier : public HloPassInterface {
// bitcasts.
AlgebraicSimplifier(bool is_layout_sensitive,
ValidBitcastCallback valid_bitcast_callback,
bool enable_dot_simplification = true)
bool enable_dot_simplification = true,
bool enable_conv_simplification = true)
: is_layout_sensitive_(is_layout_sensitive),
valid_bitcast_callback_(std::move(valid_bitcast_callback)),
enable_dot_simplification_(enable_dot_simplification) {}
~AlgebraicSimplifier() override {}
enable_dot_simplification_(enable_dot_simplification),
enable_conv_simplification_(enable_conv_simplification) {}
~AlgebraicSimplifier() override = default;
tensorflow::StringPiece name() const override { return "algsimp"; }
// Run algebraic simplification on the given computation. Returns whether the
@ -57,6 +59,9 @@ class AlgebraicSimplifier : public HloPassInterface {
// Enable dot simplication on platforms where it is profitable.
bool enable_dot_simplification_;
// Enable convolution simplication on platforms where it is profitable.
bool enable_conv_simplification_;
};
} // namespace xla

View File

@ -398,7 +398,9 @@ Status HloCostAnalysis::HandleConvolution(HloInstruction* convolution,
// For each output element, we do one fma per element in the kernel at some
// given output feature index.
const int64 fmas_per_output_element =
ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features;
output_features > 0
? ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features
: 0;
const int64 output_elements = ShapeUtil::ElementsIn(convolution->shape());
current_properties_[kFlopsKey] =
output_elements * fmas_per_output_element * kFmaFlops;

View File

@ -547,7 +547,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
const auto& dnums = conv->convolution_dimension_numbers();
const int64 num_spatial_dims = dnums.spatial_dimensions_size();
CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size());
CHECK_GE(num_spatial_dims, 1);
CHECK_GE(num_spatial_dims, 0);
CHECK_EQ(window.dimensions_size(), num_spatial_dims);
const auto lhs_rank = ShapeUtil::Rank(lhs_shape);

View File

@ -1385,14 +1385,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
"Window: %s",
window.DebugString().c_str());
}
int num_spatial_dims = dnums.spatial_dimensions_size();
if (num_spatial_dims < 1) {
return InvalidArgument(
"Convolution requires at least one spatial dimension.\n"
"Window: %s",
window.DebugString().c_str());
}
const int num_spatial_dims = dnums.spatial_dimensions_size();
if (window.dimensions_size() != num_spatial_dims) {
return InvalidArgument(
"Window must have same number of dimensions as dimension numbers.\n"
@ -1400,7 +1394,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
window.DebugString().c_str(), dnums.DebugString().c_str());
}
int num_dims = num_spatial_dims + 2;
const int num_dims = num_spatial_dims + 2;
if (ShapeUtil::Rank(lhs) != num_dims) {
return InvalidArgument(
"The LHS argument to a convolution should have rank %d.\n"