Fix confusing way of specifying filter_expansion parameter.
Currently, the constructor of ConvolutionGroupConverter allows to specify canonicalize_depthwise_filter without specifying what it does (assigning it to a variable called filter_expansion). Then this is used to initialize the filter_expansion variable of visitor with !filter_expansion. So essentially one needed to pass 'false' if the filter expansion should be done. This CL makes it clearer by using always the variable name filter_expansion and avoids the intermediate negation step. No functional change. PiperOrigin-RevId: 288647349 Change-Id: I0f2984f403d6b4e0cc88405bd05234cb6a7b8a92
This commit is contained in:
parent
ef7f47792d
commit
583b7418a6
@ -56,8 +56,7 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault {
|
|||||||
// Runs the visitor on a computation.
|
// Runs the visitor on a computation.
|
||||||
static bool Run(HloComputation* computation,
|
static bool Run(HloComputation* computation,
|
||||||
std::function<bool(HloInstruction*)> is_cost_viable,
|
std::function<bool(HloInstruction*)> is_cost_viable,
|
||||||
bool convert_batch_groups_only,
|
bool convert_batch_groups_only, bool filter_expansion);
|
||||||
bool canonicalize_depthwise_filter);
|
|
||||||
|
|
||||||
// Returns whether any convolution ops were rewritten.
|
// Returns whether any convolution ops were rewritten.
|
||||||
const bool changed() const { return changed_; }
|
const bool changed() const { return changed_; }
|
||||||
@ -68,10 +67,9 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault {
|
|||||||
explicit ConvolutionVisitor(
|
explicit ConvolutionVisitor(
|
||||||
HloComputation* computation,
|
HloComputation* computation,
|
||||||
std::function<bool(HloInstruction*)> is_cost_viable,
|
std::function<bool(HloInstruction*)> is_cost_viable,
|
||||||
bool convert_batch_groups_only,
|
bool convert_batch_groups_only, bool filter_expansion)
|
||||||
bool canonicalize_depthwise_filter = false)
|
|
||||||
: computation_(computation),
|
: computation_(computation),
|
||||||
filter_expansion_(!canonicalize_depthwise_filter),
|
filter_expansion_(filter_expansion),
|
||||||
convert_batch_groups_only_(convert_batch_groups_only),
|
convert_batch_groups_only_(convert_batch_groups_only),
|
||||||
is_cost_viable_(is_cost_viable) {}
|
is_cost_viable_(is_cost_viable) {}
|
||||||
|
|
||||||
@ -94,10 +92,9 @@ class ConvolutionVisitor : public DfsHloVisitorWithDefault {
|
|||||||
bool ConvolutionVisitor::Run(
|
bool ConvolutionVisitor::Run(
|
||||||
HloComputation* computation,
|
HloComputation* computation,
|
||||||
std::function<bool(HloInstruction*)> is_cost_viable,
|
std::function<bool(HloInstruction*)> is_cost_viable,
|
||||||
bool convert_batch_groups_only, bool canonicalize_depthwise_filter) {
|
bool convert_batch_groups_only, bool filter_expansion) {
|
||||||
ConvolutionVisitor visitor(computation, is_cost_viable,
|
ConvolutionVisitor visitor(computation, is_cost_viable,
|
||||||
convert_batch_groups_only,
|
convert_batch_groups_only, filter_expansion);
|
||||||
canonicalize_depthwise_filter);
|
|
||||||
TF_CHECK_OK(computation->Accept(&visitor));
|
TF_CHECK_OK(computation->Accept(&visitor));
|
||||||
return visitor.changed_;
|
return visitor.changed_;
|
||||||
}
|
}
|
||||||
|
@ -29,10 +29,10 @@ class ConvolutionGroupConverter : public HloModulePass {
|
|||||||
public:
|
public:
|
||||||
ConvolutionGroupConverter(std::function<bool(HloInstruction*)> is_cost_viable,
|
ConvolutionGroupConverter(std::function<bool(HloInstruction*)> is_cost_viable,
|
||||||
bool convert_batch_groups_only,
|
bool convert_batch_groups_only,
|
||||||
bool canonicalize_depthwise_filter = false)
|
bool filter_expansion = true)
|
||||||
: is_cost_viable_(is_cost_viable),
|
: is_cost_viable_(is_cost_viable),
|
||||||
convert_batch_groups_only_(convert_batch_groups_only),
|
convert_batch_groups_only_(convert_batch_groups_only),
|
||||||
filter_expansion_(canonicalize_depthwise_filter) {}
|
filter_expansion_(filter_expansion) {}
|
||||||
|
|
||||||
absl::string_view name() const override {
|
absl::string_view name() const override {
|
||||||
return "convolution-group-converter";
|
return "convolution-group-converter";
|
||||||
|
@ -152,7 +152,7 @@ Status GpuCompiler::OptimizeHloModule(
|
|||||||
pipeline.AddPass<ConvolutionGroupConverter>(
|
pipeline.AddPass<ConvolutionGroupConverter>(
|
||||||
batch_group_cost_model,
|
batch_group_cost_model,
|
||||||
/*convert_batch_groups_only=*/true,
|
/*convert_batch_groups_only=*/true,
|
||||||
/*canonicalize_depthwise_filter=*/false);
|
/*filter_expansion=*/true);
|
||||||
|
|
||||||
auto cost_model = [](HloInstruction* conv) {
|
auto cost_model = [](HloInstruction* conv) {
|
||||||
// We need a cost model for GPUs. Currently, do nothing.
|
// We need a cost model for GPUs. Currently, do nothing.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user