[XLA:GPU] Rename cudnn convolution passes.
Make them shorter and more consistent.

 - CudnnConvolutionFoo -> CudnnConvFoo
 - PadInsertion -> CudnnConvPaddingLegalization
 - PadForTensorCores -> CudnnConvPadForSpeed (padding channel dimensions from
   3 -> 4 is not a tensor-cores-related optimization and ideally should be run
   on P100s as well).

PiperOrigin-RevId: 216618934
parent 128903381b
commit 0be7b32fa4
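For orientation, here is a condensed sketch of how the renamed passes line up in the conv-canonicalization pipeline after this change. It mirrors the nvptx_compiler.cc hunk further down; the free-standing helper function and its `is_volta_or_later` parameter are illustrative stand-ins (the real code calls IsVoltaOrLater(*stream_exec) inline), not part of this change.

// Sketch only: condenses the pass registrations from nvptx_compiler.cc below.
#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_speed.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h"
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"

namespace xla {
namespace gpu {

// Registers the renamed conv passes; `is_volta_or_later` is a hypothetical
// parameter standing in for IsVoltaOrLater(*stream_exec) in the real code.
void AddConvCanonicalizationPasses(HloPassPipeline* pipeline,
                                   bool is_volta_or_later) {
  pipeline->AddPass<CudnnConvRewriter>();             // was CudnnConvolutionRewriter
  pipeline->AddPass<CudnnFusedConvRewriter>();        // was CudnnFusedConvolutionRewriter
  pipeline->AddPass<CudnnConvPaddingLegalization>();  // was PadInsertion
  if (is_volta_or_later) {
    pipeline->AddPass<CudnnConvPadForSpeed>();        // was PadForTensorCores
    // CudnnConvPadForSpeed leaves behind tuple/get-tuple-element pairs that
    // TupleSimplifier cleans up.
    pipeline->AddPass<TupleSimplifier>();
  }
  // The rewriters above may add instructions that constant folding simplifies.
  pipeline->AddPass<HloConstantFolding>();
}

}  // namespace gpu
}  // namespace xla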
tensorflow/compiler/xla/service/gpu/BUILD

@@ -154,7 +154,7 @@ cc_library(
     deps = [
         ":backend_configs",
         ":buffer_allocations",
-        ":cudnn_convolution_runner",
+        ":cudnn_conv_runner",
         ":elemental_ir_emitter",
         ":gpu_constants",
         ":gpu_executable",
@@ -323,7 +323,7 @@ cc_library(
     ],
     deps = [
         ":buffer_allocations",
-        ":cudnn_convolution_runner",
+        ":cudnn_conv_runner",
         ":hlo_execution_profiler",
         ":infeed_manager",
         ":ir_emission_utils",
@@ -385,13 +385,13 @@ cc_library(
 )
 
 cc_library(
-    name = "cudnn_convolution_algorithm_picker",
-    srcs = ["cudnn_convolution_algorithm_picker.cc"],
-    hdrs = ["cudnn_convolution_algorithm_picker.h"],
+    name = "cudnn_conv_algorithm_picker",
+    srcs = ["cudnn_conv_algorithm_picker.cc"],
+    hdrs = ["cudnn_conv_algorithm_picker.h"],
     deps = [
         ":backend_configs",
         ":buffer_comparator",
-        ":cudnn_convolution_runner",
+        ":cudnn_conv_runner",
         ":gpu_executable",
         ":ir_emission_utils",
         "//tensorflow/compiler/xla:literal_util",
@@ -410,9 +410,9 @@ cc_library(
 )
 
 cc_library(
-    name = "cudnn_convolution_runner",
-    srcs = ["cudnn_convolution_runner.cc"],
-    hdrs = ["cudnn_convolution_runner.h"],
+    name = "cudnn_conv_runner",
+    srcs = ["cudnn_conv_runner.cc"],
+    hdrs = ["cudnn_conv_runner.h"],
     deps = [
         ":backend_configs",
         ":ir_emission_utils",
@@ -432,9 +432,9 @@ cc_library(
 )
 
 cc_library(
-    name = "cudnn_convolution_rewriter",
-    srcs = ["cudnn_convolution_rewriter.cc"],
-    hdrs = ["cudnn_convolution_rewriter.h"],
+    name = "cudnn_conv_rewriter",
+    srcs = ["cudnn_conv_rewriter.cc"],
+    hdrs = ["cudnn_conv_rewriter.h"],
     deps = [
         ":backend_configs",
         ":ir_emission_utils",
@@ -449,10 +449,10 @@ cc_library(
 )
 
 tf_cc_test(
-    name = "cudnn_convolution_rewriter_test",
-    srcs = ["cudnn_convolution_rewriter_test.cc"],
+    name = "cudnn_conv_rewriter_test",
+    srcs = ["cudnn_conv_rewriter_test.cc"],
     deps = [
-        ":cudnn_convolution_rewriter",
+        ":cudnn_conv_rewriter",
         ":ir_emission_utils",
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:test_helpers",
@@ -581,9 +581,9 @@ tf_cc_test(
 )
 
 cc_library(
-    name = "pad_insertion",
-    srcs = ["pad_insertion.cc"],
-    hdrs = ["pad_insertion.h"],
+    name = "cudnn_conv_padding_legalization",
+    srcs = ["cudnn_conv_padding_legalization.cc"],
+    hdrs = ["cudnn_conv_padding_legalization.h"],
     deps = [
         ":ir_emission_utils",
         "//tensorflow/compiler/xla:literal",
@@ -600,9 +600,9 @@ cc_library(
 )
 
 cc_library(
-    name = "pad_for_tensor_cores",
-    srcs = ["pad_for_tensor_cores.cc"],
-    hdrs = ["pad_for_tensor_cores.h"],
+    name = "cudnn_conv_pad_for_speed",
+    srcs = ["cudnn_conv_pad_for_speed.cc"],
+    hdrs = ["cudnn_conv_pad_for_speed.h"],
     deps = [
         ":ir_emission_utils",
         "//tensorflow/compiler/xla:literal_util",
@@ -614,11 +614,11 @@ cc_library(
 )
 
 tf_cc_test(
-    name = "pad_for_tensor_cores_test",
-    srcs = ["pad_for_tensor_cores_test.cc"],
+    name = "cudnn_conv_pad_for_speed_test",
+    srcs = ["cudnn_conv_pad_for_speed_test.cc"],
     deps = [
+        ":cudnn_conv_pad_for_speed",
         ":ir_emission_utils",
-        ":pad_for_tensor_cores",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/service:hlo_matchers",
@@ -660,9 +660,11 @@ cc_library(
     srcs = ["nvptx_compiler.cc"],
     hdrs = ["nvptx_compiler.h"],
     deps = [
-        ":cudnn_convolution_algorithm_picker",
-        ":cudnn_convolution_rewriter",
-        ":cudnn_fused_convolution_rewriter",
+        ":cudnn_conv_algorithm_picker",
+        ":cudnn_conv_pad_for_speed",
+        ":cudnn_conv_padding_legalization",
+        ":cudnn_conv_rewriter",
+        ":cudnn_fused_conv_rewriter",
         ":fusion_merger",
         ":gpu_constants",
         ":gpu_copy_insertion",
@@ -674,8 +676,6 @@ cc_library(
         ":ir_emission_utils",
         ":ir_emitter",
         ":multi_output_fusion",
-        ":pad_for_tensor_cores",
-        ":pad_insertion",
         ":partition_assignment",
         ":stream_assignment",
         ":stream_executor_util",
@@ -966,9 +966,9 @@ tf_cc_test(
 )
 
 cc_library(
-    name = "cudnn_fused_convolution_rewriter",
-    srcs = ["cudnn_fused_convolution_rewriter.cc"],
-    hdrs = ["cudnn_fused_convolution_rewriter.h"],
+    name = "cudnn_fused_conv_rewriter",
+    srcs = ["cudnn_fused_conv_rewriter.cc"],
+    hdrs = ["cudnn_fused_conv_rewriter.h"],
     deps = [
         ":backend_configs",
         ":ir_emission_utils",
tensorflow/compiler/xla/service/gpu/convolution_thunk.cc

@@ -18,7 +18,7 @@ limitations under the License.
 #include <string>
 
 #include "absl/strings/str_cat.h"
-#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -56,9 +56,9 @@ Status ConvolutionThunk::ExecuteOnStream(
       buffer_allocations.GetDeviceAddress(scratch_buffer_);
 
   auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
-  TF_RETURN_IF_ERROR(RunCudnnConvolution(cudnn_call_,
-                                         absl::MakeSpan(operand_se_buffers),
-                                         result_buffer, scratch, stream));
+  TF_RETURN_IF_ERROR(RunCudnnConv(cudnn_call_,
+                                  absl::MakeSpan(operand_se_buffers),
+                                  result_buffer, scratch, stream));
 
   void* ptrs[] = {result_buffer.opaque(), scratch.opaque()};
   se::DeviceMemory<void*> tuple_addr(
tensorflow/compiler/xla/service/gpu/convolution_thunk.h

@@ -19,7 +19,7 @@ limitations under the License.
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
-#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
tensorflow/compiler/xla/service/gpu/{cudnn_convolution_algorithm_picker.cc → cudnn_conv_algorithm_picker.cc}

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
 #include "absl/types/optional.h"
@@ -145,9 +145,8 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
 // cache misses and doing extra work. Overall, caching doesn't seem worth the
 // trouble, but we may want to revisit this if we ever find a model where
 // caching would speed up compilation a lot.
-StatusOr<CudnnConvolutionAlgorithmPicker::AutotuneResult>
-CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
-    HloCustomCallInstruction* instr) {
+StatusOr<CudnnConvAlgorithmPicker::AutotuneResult>
+CudnnConvAlgorithmPicker::PickBestAlgorithm(HloCustomCallInstruction* instr) {
   // TODO(timshen): for now only check fp16. It can be expanded to other types,
   // with some work on the HLO routines.
   const bool cross_check_enabled =
@@ -253,10 +252,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
     backend_config.set_algorithm(alg.algo_id());
     backend_config.set_tensor_ops_enabled(alg.tensor_ops_enabled());
     TF_RETURN_IF_ERROR(instr->set_backend_config(backend_config));
-    bool launch_ok = RunCudnnConvolution(instr, absl::MakeSpan(operand_buffers),
-                                         result_buffer, &scratch_allocator,
-                                         &stream, &profile_result)
-                         .ok();
+    bool launch_ok =
+        RunCudnnConv(instr, absl::MakeSpan(operand_buffers), result_buffer,
+                     &scratch_allocator, &stream, &profile_result)
+            .ok();
 
     if (launch_ok && profile_result.is_valid()) {
       const bool crash_on_checking_failure =
@@ -328,7 +327,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
                        instr->ToString());
 }
 
-StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
+StatusOr<bool> CudnnConvAlgorithmPicker::RunOnInstruction(
     HloInstruction* instr) {
   CHECK(IsCustomCallToDnnConvolution(*instr));
 
@@ -378,7 +377,7 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
   return true;
 }
 
-StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnComputation(
+StatusOr<bool> CudnnConvAlgorithmPicker::RunOnComputation(
     HloComputation* computation) {
   std::vector<HloInstruction*> convs;
   for (auto* instr : computation->instructions()) {
@@ -395,7 +394,7 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnComputation(
   return changed;
 }
 
-StatusOr<bool> CudnnConvolutionAlgorithmPicker::Run(HloModule* module) {
+StatusOr<bool> CudnnConvAlgorithmPicker::Run(HloModule* module) {
   bool changed = false;
   for (HloComputation* computation : module->MakeNonfusionComputations()) {
     TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
tensorflow/compiler/xla/service/gpu/{cudnn_convolution_algorithm_picker.h → cudnn_conv_algorithm_picker.h}

@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_
 
 #include "absl/time/time.h"
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/service/device_memory_allocator.h"
-#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
@@ -31,18 +31,17 @@ namespace gpu {
 
 // Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for
 // each and adding explicit scratch space to the CustomCalls.
-class CudnnConvolutionAlgorithmPicker : public HloModulePass {
+class CudnnConvAlgorithmPicker : public HloModulePass {
  public:
   // If the `allocator` parameter is not null, we will use it to allocate temp
   // memory while timing the various convolution algorithms. If it's null,
   // we'll use the default allocator on the StreamExecutor.
-  CudnnConvolutionAlgorithmPicker(se::StreamExecutor* stream_exec,
-                                  DeviceMemoryAllocator* allocator,
-                                  Compiler* compiler)
+  CudnnConvAlgorithmPicker(se::StreamExecutor* stream_exec,
+                           DeviceMemoryAllocator* allocator, Compiler* compiler)
       : stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {}
 
   absl::string_view name() const override {
-    return "cudnn-convolution-algorithm-picker";
+    return "cudnn-conv-algorithm-picker";
   }
 
   StatusOr<bool> Run(HloModule* module) override;
@@ -67,4 +66,4 @@ class CudnnConvolutionAlgorithmPicker : public HloModulePass {
 }  // namespace gpu
 }  // namespace xla
 
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_
tensorflow/compiler/xla/service/gpu/{pad_for_tensor_cores.cc → cudnn_conv_pad_for_speed.cc}

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_speed.h"
 
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -108,7 +108,7 @@ static HloInstruction* PadInstruction(HloInstruction* instr,
 static StatusOr<bool> PadFeaturesDims(HloCustomCallInstruction* conv) {
   CHECK_EQ(0, conv->shape().tuple_shapes(1).dimensions(0))
       << "conv must use 0 scratch bytes, i.e. this pass must be run "
-         "before CudnnConvolutionAlgorithmPicker.";
+         "before CudnnConvAlgorithmPicker.";
 
   TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(conv));
   const auto& dnums = conv->convolution_dimension_numbers();
@@ -252,7 +252,7 @@ static std::vector<HloCustomCallInstruction*> GetRelevantConvs(
   return convs;
 }
 
-StatusOr<bool> PadForTensorCores::Run(HloModule* module) {
+StatusOr<bool> CudnnConvPadForSpeed::Run(HloModule* module) {
   bool changed = false;
   for (HloComputation* comp : module->MakeNonfusionComputations()) {
     for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
tensorflow/compiler/xla/service/gpu/{pad_for_tensor_cores.h → cudnn_conv_pad_for_speed.h}

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PAD_FOR_SPEED_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PAD_FOR_SPEED_H_
 
 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 
@@ -29,10 +29,13 @@ namespace gpu {
 // opposite of useful on other GPUs, so you should check what GPU you're
 // targeting before running this pass.
 //
+// TODO(jlebar): Rework this. For one thing, it should not be Volta-only.
+// Padding input channels 3 to 4 is (we think) applicable to Pascal as well.
+//
 // TODO(jlebar): Also pad dots.
-class PadForTensorCores : public HloModulePass {
+class CudnnConvPadForSpeed : public HloModulePass {
  public:
-  absl::string_view name() const override { return "pad for tensor cores"; }
+  absl::string_view name() const override { return "cudnn-conv-pad-for-speed"; }
 
   StatusOr<bool> Run(HloModule* module) override;
 };
@@ -40,4 +43,4 @@ class PadForTensorCores : public HloModulePass {
 }  // namespace gpu
 }  // namespace xla
 
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_FOR_TENSOR_CORES_H_
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PAD_FOR_SPEED_H_
tensorflow/compiler/xla/service/gpu/{pad_for_tensor_cores_test.cc → cudnn_conv_pad_for_speed_test.cc}

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_speed.h"
 
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -29,9 +29,9 @@ namespace {
 namespace op = xla::testing::opcode_matchers;
 using ::testing::_;
 
-class PadForTensorCoresTest : public HloVerifiedTestBase {};
+class CudnnConvPadForSpeedTest : public HloVerifiedTestBase {};
 
-TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) {
+TEST_F(CudnnConvPadForSpeedTest, PadF16ForwardConvInputChannels) {
   ParseAndVerifyModule(R"(
   HloModule TestModule
 
@@ -42,7 +42,7 @@ TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) {
       window={size=2x2}, dim_labels=b01f_01io->b01f,
      custom_call_target="__cudnn$convForward"
   })");
-  EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+  EXPECT_TRUE(CudnnConvPadForSpeed().Run(&module()).ValueOrDie());
   auto* root = module().entry_computation()->root_instruction();
 
   SCOPED_TRACE(module().ToString());
@@ -55,7 +55,7 @@ TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) {
                         ShapeUtil::MakeShape(F16, {2, 2, 48, 40})));
 }
 
-TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) {
+TEST_F(CudnnConvPadForSpeedTest, PadF16BackwardInputConvOutputChannels) {
   ParseAndVerifyModule(R"(
   HloModule TestModule
 
@@ -66,7 +66,7 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) {
      window={size=2x2}, dim_labels=b01f_01io->b01f,
      custom_call_target="__cudnn$convBackwardInput"
   })");
-  EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+  EXPECT_TRUE(CudnnConvPadForSpeed().Run(&module()).ValueOrDie());
   auto* root = module().entry_computation()->root_instruction();
   EXPECT_THAT(root, op::CustomCall(kCudnnConvBackwardInputCallTarget,
                                    op::Pad(op::Parameter(0), _),
@@ -77,7 +77,7 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvOutputChannels) {
                         ShapeUtil::MakeShape(F16, {2, 2, 40, 48})));
 }
 
-TEST_F(PadForTensorCoresTest, PadF16ForwardConvOutputChannels) {
+TEST_F(CudnnConvPadForSpeedTest, PadF16ForwardConvOutputChannels) {
   ParseAndVerifyModule(R"(
   HloModule TestModule
 
@@ -88,7 +88,7 @@ TEST_F(PadForTensorCoresTest, PadF16ForwardConvOutputChannels) {
      window={size=2x2}, dim_labels=b01f_01io->b01f,
      custom_call_target="__cudnn$convForward"
   })");
-  EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+  EXPECT_TRUE(CudnnConvPadForSpeed().Run(&module()).ValueOrDie());
   auto* root = module().entry_computation()->root_instruction();
   EXPECT_THAT(root, op::Tuple(op::Slice(op::GetTupleElement(op::CustomCall(
                                   kCudnnConvForwardCallTarget, op::Parameter(0),
@@ -96,7 +96,7 @@ TEST_F(PadForTensorCoresTest, PadF16ForwardConvOutputChannels) {
                               _));
 }
 
-TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvInputChannels) {
+TEST_F(CudnnConvPadForSpeedTest, PadF16BackwardInputConvInputChannels) {
   ParseAndVerifyModule(R"(
   HloModule TestModule
 
@@ -108,7 +108,7 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvInputChannels) {
      custom_call_target="__cudnn$convBackwardInput"
    ROOT gte = f16[10,20,30,41] get-tuple-element(result), index=0
   })");
-  EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+  EXPECT_TRUE(CudnnConvPadForSpeed().Run(&module()).ValueOrDie());
   auto* root = module().entry_computation()->root_instruction();
   EXPECT_THAT(root, op::GetTupleElement(op::Tuple(
                         op::Slice(op::GetTupleElement(op::CustomCall(
@@ -117,7 +117,7 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardInputConvInputChannels) {
                         _)));
 }
 
-TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) {
+TEST_F(CudnnConvPadForSpeedTest, PadF16BackwardFilterConvInputChannels) {
   ParseAndVerifyModule(R"(
   HloModule TestModule
 
@@ -129,7 +129,7 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) {
      custom_call_target="__cudnn$convBackwardFilter"
    ROOT gte = f16[2,2,41,40] get-tuple-element(result), index=0
   })");
-  EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+  EXPECT_TRUE(CudnnConvPadForSpeed().Run(&module()).ValueOrDie());
   auto* root = module().entry_computation()->root_instruction();
   EXPECT_THAT(root, op::GetTupleElement(op::Tuple(
                         op::Slice(op::GetTupleElement(op::CustomCall(
@@ -138,7 +138,7 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvInputChannels) {
                         _)));
 }
 
-TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) {
+TEST_F(CudnnConvPadForSpeedTest, PadF16BackwardFilterConvOutputChannels) {
   ParseAndVerifyModule(R"(
   HloModule TestModule
 
@@ -150,7 +150,7 @@ TEST_F(PadForTensorCoresTest, PadF16BackwardFilterConvOutputChannels) {
      custom_call_target="__cudnn$convBackwardFilter"
    ROOT gte = f16[2,2,40,41] get-tuple-element(result), index=0
   })");
-  EXPECT_TRUE(PadForTensorCores().Run(&module()).ValueOrDie());
+  EXPECT_TRUE(CudnnConvPadForSpeed().Run(&module()).ValueOrDie());
   auto* root = module().entry_computation()->root_instruction();
   EXPECT_THAT(root, op::GetTupleElement(op::Tuple(
                         op::Slice(op::GetTupleElement(op::CustomCall(
tensorflow/compiler/xla/service/gpu/{pad_insertion.cc → cudnn_conv_padding_legalization.cc}

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h"
 
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/literal.h"
@@ -132,7 +132,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window,
 }
 }  // namespace
 
-bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
+bool CudnnConvPaddingLegalization::CanonicalizeForwardConvolution(
+    HloInstruction* conv) {
   if (IsForwardConvolutionCanonical(*conv)) {
     return false;
   }
@@ -187,7 +188,7 @@ void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) {
 }
 }  // namespace
 
-bool PadInsertion::CanonicalizeBackwardFilterConvolution(
+bool CudnnConvPaddingLegalization::CanonicalizeBackwardFilterConvolution(
     HloInstruction* backward_conv) {
   CHECK_EQ(backward_conv->custom_call_target(),
            kCudnnConvBackwardFilterCallTarget);
@@ -260,7 +261,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
   return true;
 }
 
-bool PadInsertion::CanonicalizeBackwardInputConvolution(
+bool CudnnConvPaddingLegalization::CanonicalizeBackwardInputConvolution(
    HloInstruction* backward_conv) {
   if (window_util::HasSymmetricPadding(backward_conv->window())) {
     return false;
@@ -377,7 +378,8 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
   return true;
 }
 
-StatusOr<bool> PadInsertion::RunOnComputation(HloComputation* computation) {
+StatusOr<bool> CudnnConvPaddingLegalization::RunOnComputation(
+    HloComputation* computation) {
   bool changed = false;
   std::vector<HloCustomCallInstruction*> convs;
   for (auto* instr : computation->instructions()) {
@@ -402,7 +404,7 @@ StatusOr<bool> PadInsertion::RunOnComputation(HloComputation* computation) {
   return changed;
 }
 
-StatusOr<bool> PadInsertion::Run(HloModule* module) {
+StatusOr<bool> CudnnConvPaddingLegalization::Run(HloModule* module) {
   bool changed = false;
   for (HloComputation* computation : module->MakeNonfusionComputations()) {
     TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
tensorflow/compiler/xla/service/gpu/{pad_insertion.h → cudnn_conv_padding_legalization.h}

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PADDING_LEGALIZATION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PADDING_LEGALIZATION_H_
 
 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 
@@ -24,9 +24,11 @@ namespace gpu {
 // An HLO pass that canonicalizes convolution instructions for GPU codegen. It
 // inserts Pad instructions before Convolution instructions with uncanonicalized
 // padding, so that they can be lowered to cuDNN convolution.
-class PadInsertion : public HloModulePass {
+class CudnnConvPaddingLegalization : public HloModulePass {
 public:
-  absl::string_view name() const override { return "pad insertion"; }
+  absl::string_view name() const override {
+    return "cudnn-conv-padding-legalization";
+  }
 
   StatusOr<bool> Run(HloModule* module) override;
 
@@ -41,4 +43,4 @@ class PadInsertion : public HloModulePass {
 }  // namespace gpu
 }  // namespace xla
 
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_PADDING_LEGALIZATION_H_
tensorflow/compiler/xla/service/gpu/{cudnn_convolution_rewriter.cc → cudnn_conv_rewriter.cc}

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h"
 
 #include <cstdlib>
 #include <numeric>
@@ -188,9 +188,9 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
     // the amount of high padding the same as the amount of low padding as long
     // as it is between min_padding_high and max_padding_high. If it is not in
     // that range, we pick the one that's closest to dim->padding_low() and let
-    // PadInsertion canonicalize the resultant backward convolution later.
-    // Picking the closest one minimizes the cost of the kPad instruction to be
-    // inserted by PadInsertion.
+    // CudnnConvPaddingLegalization canonicalize the resultant backward
+    // convolution later. Picking the closest one minimizes the cost of the kPad
+    // instruction to be inserted by CudnnConvPaddingLegalization.
     if (dim->padding_low() >= min_padding_high &&
         dim->padding_low() <= max_padding_high) {
       dim->set_padding_high(dim->padding_low());
@@ -207,7 +207,8 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardFilter(
                      "negative padding ("
                   << dim->padding_high()
                   << ") on right/bottom of the weight gradients, which is not "
-                     "supported by PadInsertion (b/32744257). Falling back to "
+                     "supported by CudnnConvPaddingLegalization (b/32744257). "
+                     "Falling back to "
                      "unfused convolution for instruction: "
                   << conv->ToString();
       return no_match_result;
@@ -342,7 +343,8 @@ MatchBackwardInput(HloInstruction* conv) {
       LOG(ERROR)
          << "The low padding of the backward convolution would be negative ("
          << backward_padding_low
-          << "), which isn't supported by PadInsertion for now (b/32744257).";
+          << "), which isn't supported by CudnnConvPaddingLegalization "
+             "for now (b/32744257).";
       return no_match_result;
     }
     dim->set_padding_low(backward_padding_low);
@@ -371,8 +373,8 @@ MatchBackwardInput(HloInstruction* conv) {
      dim->set_padding_high(backward_padding_low);
     } else {
      // Otherwise, we choose the amount that's closest to backward_padding_low,
-      // and PadInsertion will later insert kSlice instructions to enforce even
-      // padding.
+      // and CudnnConvPaddingLegalization will later insert kSlice
+      // instructions to enforce even padding.
       //
       // For example, consider the backward convolution pattern
       //
@@ -398,9 +400,9 @@ MatchBackwardInput(HloInstruction* conv) {
         dim->set_padding_high(max_padding_high);
       }
     }
-    // PadInsertion doesn't handle backward input convolution with negative
-    // padding for now. So fall back to unfused convolution in case of negative
-    // padding. For example,
+    // CudnnConvPaddingLegalization doesn't handle backward input
+    // convolution with negative padding for now. So fall back to unfused
+    // convolution in case of negative padding. For example,
     // ABCD = Conv(abc, reverse(xy), padding_high=2)
     // could be fused to
     // ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1)
@@ -410,8 +412,8 @@ MatchBackwardInput(HloInstruction* conv) {
                    "negative padding ("
                 << dim->padding_high()
                 << ") on right/bottom of the activations, which is not "
-                   "supported by PadInsertion (b/32744257). Falling back to "
-                   "unfused convolution for instruction: "
+                   "supported by CudnnConvPaddingLegalization (b/32744257). "
+                   "Falling back to unfused convolution for instruction: "
                 << conv->ToString();
       return no_match_result;
     }
@@ -555,7 +557,7 @@ StatusOr<bool> RunOnComputation(HloComputation* computation) {
 }
 }  // namespace
 
-StatusOr<bool> CudnnConvolutionRewriter::Run(HloModule* module) {
+StatusOr<bool> CudnnConvRewriter::Run(HloModule* module) {
   bool changed = false;
   for (HloComputation* computation : module->MakeNonfusionComputations()) {
     TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
tensorflow/compiler/xla/service/gpu/{cudnn_convolution_rewriter.h → cudnn_conv_rewriter.h}

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_REWRITER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_REWRITER_H_
 
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
@@ -24,11 +24,9 @@ namespace gpu {
 
 // Rewrites plain convolutions, backwards-filter convolutions, and
 // backwards-input convolutions into CustomCall HLOs that call into cuDNN.
-class CudnnConvolutionRewriter : public HloModulePass {
+class CudnnConvRewriter : public HloModulePass {
 public:
-  absl::string_view name() const override {
-    return "cudnn-convolution-rewriter";
-  }
+  absl::string_view name() const override { return "cudnn-conv-rewriter"; }
 
   StatusOr<bool> Run(HloModule* module) override;
 };
@@ -36,4 +34,4 @@ class CudnnConvolutionRewriter : public HloModulePass {
 }  // namespace gpu
 }  // namespace xla
 
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_REWRITER_H_
tensorflow/compiler/xla/service/gpu/{cudnn_convolution_rewriter_test.cc → cudnn_conv_rewriter_test.cc}

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h"
 
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -34,9 +34,9 @@ namespace {
 namespace op = xla::testing::opcode_matchers;
 using ::testing::_;
 
-class CudnnConvolutionRewriterTest : public HloVerifiedTestBase {
+class CudnnConvRewriterTest : public HloVerifiedTestBase {
 public:
-  CudnnConvolutionRewriterTest()
+  CudnnConvRewriterTest()
       : HloVerifiedTestBase(/*layout_sensitive=*/true,
                             /*allow_mixed_precision=*/false) {
     for (int i = 0; i < 2; ++i) {
@@ -85,7 +85,7 @@ class CudnnConvolutionRewriterTest : public HloVerifiedTestBase {
 
 protected:
   bool RunPass(HloModule* module) {
-    return CudnnConvolutionRewriter().Run(module).ValueOrDie();
+    return CudnnConvRewriter().Run(module).ValueOrDie();
   }
 
   // A convolution window with stride 1 and zero padding. The size fields are
@@ -95,7 +95,7 @@ class CudnnConvolutionRewriterTest : public HloVerifiedTestBase {
   ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_;
 };
 
-TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) {
+TEST_F(CudnnConvRewriterTest, BackwardFilterConvolve) {
   HloComputation::Builder builder(TestName());
   HloInstruction* activations =
       builder.AddInstruction(HloInstruction::CreateParameter(
@@ -123,7 +123,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) {
           op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
 }
 
-TEST_F(CudnnConvolutionRewriterTest,
+TEST_F(CudnnConvRewriterTest,
        BackwardFilterConvolveEquivalentToForwardConvolution) {
   HloComputation::Builder builder(TestName());
   HloInstruction* activations =
@@ -152,8 +152,7 @@ TEST_F(CudnnConvolutionRewriterTest,
 }
 
 // Extracted from block35 training.
-TEST_F(CudnnConvolutionRewriterTest,
-       BackwardFilterConvolveWithPaddedActivations) {
+TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) {
   auto builder = HloComputation::Builder(TestName());
   HloInstruction* activations =
       builder.AddInstruction(HloInstruction::CreateParameter(
@@ -183,8 +182,7 @@ TEST_F(CudnnConvolutionRewriterTest,
 }
 
 // Extracted from inception v3 training.
-TEST_F(CudnnConvolutionRewriterTest,
-       BackwardFilterConvolveWithPaddedGradients) {
+TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) {
   auto builder = HloComputation::Builder(TestName());
   HloInstruction* activations =
       builder.AddInstruction(HloInstruction::CreateParameter(
@@ -213,7 +211,7 @@ TEST_F(CudnnConvolutionRewriterTest,
           op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
 }
 
-TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) {
+TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) {
   auto builder = HloComputation::Builder(TestName());
   HloInstruction* activations =
       builder.AddInstruction(HloInstruction::CreateParameter(
@@ -242,7 +240,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) {
           op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0));
 }
 
-TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) {
+TEST_F(CudnnConvRewriterTest, BackwardInputConvolveEvenPadding) {
   auto builder = HloComputation::Builder(TestName());
   HloInstruction* output =
       builder.AddInstruction(HloInstruction::CreateParameter(
@@ -307,7 +305,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) {
 // Convolve([abc], [x], base_dilation=2)
 //   = Convolve([abc], Reverse([x]), base_dilation=2)
 //   = BackwardInputConvolve([abc], [x], stride=2)
-TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) {
+TEST_F(CudnnConvRewriterTest, BackwardInputConvolve1x1Filter) {
   auto builder = HloComputation::Builder(TestName());
   // NHWC dimension order.
   HloInstruction* output =
@@ -341,7 +339,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) {
 // BackwardInputConvolve([abc], [x], stride=1) is equivalent to
 // ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input
 // convolution.
-TEST_F(CudnnConvolutionRewriterTest,
+TEST_F(CudnnConvRewriterTest,
        BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) {
   auto builder = HloComputation::Builder(TestName());
   // NHWC dimension order.
@@ -385,8 +383,7 @@ TEST_F(CudnnConvolutionRewriterTest,
 //   20x10x10x192
 //
 // Gradients are padded unevenly.
-TEST_F(CudnnConvolutionRewriterTest,
-       BackwardInputConvolveUnevenPaddingOnGradients) {
+TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) {
   auto builder = HloComputation::Builder(TestName());
   HloInstruction* output =
      builder.AddInstruction(HloInstruction::CreateParameter(
@@ -436,7 +433,7 @@ TEST_F(CudnnConvolutionRewriterTest,
 
 // Similar to BackwardInputConvolveUnevenPadding, but the low padding of the
 // gradients exceeds kernel_size - 1. Therefore, this pattern cannot be fused.
-TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
+TEST_F(CudnnConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
   auto builder = HloComputation::Builder(TestName());
   HloInstruction* output =
       builder.AddInstruction(HloInstruction::CreateParameter(
@@ -488,9 +485,8 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) {
 //                       padding_low=2, padding_high=1, base_dilation=2)
 //
 // We should fuse BC even though padding on activations is uneven, because
-// PadInsertion will canonicalize the fusion HLO.
-TEST_F(CudnnConvolutionRewriterTest,
-       BackwardInputConvolveUnevenPaddingOnActivations) {
+// CudnnConvPaddingLegalization will canonicalize the fusion HLO.
+TEST_F(CudnnConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) {
   auto builder = HloComputation::Builder(TestName());
   // The gradients are in NCHW layout.
   HloInstruction* output =
@@ -543,9 +539,10 @@ TEST_F(CudnnConvolutionRewriterTest,
 // BC = BackwardInput(FC) does:
 //   [4] = conv([3], reverse([2]), padding_high=2)
 //
-// We currently don't fuse BC because PadInsertion doesn't support negative
-// padding on the gradients of backward convolution (b/32744257).
-TEST_F(CudnnConvolutionRewriterTest,
+// We currently don't fuse BC because CudnnConvPaddingLegalization
+// doesn't support negative padding on the gradients of backward convolution
+// (b/32744257).
+TEST_F(CudnnConvRewriterTest,
       BackwardInputConvolveNegativePaddingHighOnActivations) {
   auto builder = HloComputation::Builder(TestName());
   // The gradients are in NCHW layout.
@@ -586,7 +583,7 @@ TEST_F(CudnnConvolutionRewriterTest,
 
 // Check that we will materialize a reversed version of a constant in order to
 // pattern-match a backwards input convolution.
-TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) {
+TEST_F(CudnnConvRewriterTest, BackwardInputConvolveConstantFilter) {
   Array4D<float> constant_arr(4, 4, 2, 2);
   constant_arr.FillIota(0);
   string constant_str =
tensorflow/compiler/xla/service/gpu/{cudnn_convolution_runner.cc → cudnn_conv_runner.cc}

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
@@ -110,10 +110,10 @@ class ScratchBufAllocator : public se::ScratchAllocator {
 };
 
 template <typename T>
-Status RunCudnnConvolutionImpl(CudnnConvParams params,
-                               se::ScratchAllocator* scratch_allocator,
-                               se::Stream* stream,
-                               se::dnn::ProfileResult* profile_result) {
+Status RunCudnnConvImpl(CudnnConvParams params,
+                        se::ScratchAllocator* scratch_allocator,
+                        se::Stream* stream,
+                        se::dnn::ProfileResult* profile_result) {
   CudnnConvKind kind = params.kind;
   const Shape& input_shape = *params.input_shape;
   const Shape& filter_shape = *params.filter_shape;
@@ -380,22 +380,21 @@ StatusOr<CudnnConvParams> GetCudnnConvParams(
 
 }  // anonymous namespace
 
-Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
-                           absl::Span<se::DeviceMemoryBase> operand_buffers,
-                           se::DeviceMemoryBase result_buffer,
-                           se::DeviceMemoryBase scratch_buf, se::Stream* stream,
-                           se::dnn::ProfileResult* profile_result) {
+Status RunCudnnConv(const HloCustomCallInstruction* conv,
+                    absl::Span<se::DeviceMemoryBase> operand_buffers,
+                    se::DeviceMemoryBase result_buffer,
+                    se::DeviceMemoryBase scratch_buf, se::Stream* stream,
+                    se::dnn::ProfileResult* profile_result) {
   ScratchBufAllocator scratch_allocator(scratch_buf);
-  return RunCudnnConvolution(conv, operand_buffers, result_buffer,
-                             &scratch_allocator, stream, profile_result);
+  return RunCudnnConv(conv, operand_buffers, result_buffer, &scratch_allocator,
+                      stream, profile_result);
 }
 
-Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
-                           absl::Span<se::DeviceMemoryBase> operand_buffers,
-                           se::DeviceMemoryBase result_buffer,
-                           se::ScratchAllocator* scratch_allocator,
-                           se::Stream* stream,
-                           se::dnn::ProfileResult* profile_result) {
+Status RunCudnnConv(const HloCustomCallInstruction* conv,
+                    absl::Span<se::DeviceMemoryBase> operand_buffers,
+                    se::DeviceMemoryBase result_buffer,
+                    se::ScratchAllocator* scratch_allocator, se::Stream* stream,
+                    se::dnn::ProfileResult* profile_result) {
   TF_ASSIGN_OR_RETURN(CudnnConvParams params,
                       GetCudnnConvParams(conv, operand_buffers, result_buffer));
 
@@ -403,14 +402,14 @@ Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
       conv->shape().tuple_shapes(0).element_type();
   switch (output_primitive_type) {
     case F16:
-      return RunCudnnConvolutionImpl<Eigen::half>(params, scratch_allocator,
-                                                  stream, profile_result);
+      return RunCudnnConvImpl<Eigen::half>(params, scratch_allocator, stream,
+                                           profile_result);
    case F32:
-      return RunCudnnConvolutionImpl<float>(params, scratch_allocator, stream,
-                                            profile_result);
+      return RunCudnnConvImpl<float>(params, scratch_allocator, stream,
+                                     profile_result);
    case F64:
-      return RunCudnnConvolutionImpl<double>(params, scratch_allocator, stream,
-                                             profile_result);
+      return RunCudnnConvImpl<double>(params, scratch_allocator, stream,
+                                      profile_result);
    default:
      LOG(FATAL) << ShapeUtil::HumanString(*params.output_shape);
  }
tensorflow/compiler/xla/service/gpu/{cudnn_convolution_runner.h → cudnn_conv_runner.h}

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_RUNNER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_RUNNER_H_
 
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -42,20 +42,19 @@ namespace gpu {
 // allocator and take note of how much memory is used. The next time you call
 // the same conv, you can provide an explicitly preallocated scratch buffer of
 // that size, if you like.
-Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
-                           absl::Span<se::DeviceMemoryBase> operand_buffers,
-                           se::DeviceMemoryBase result_buffer,
-                           se::DeviceMemoryBase scratch_buf, se::Stream* stream,
-                           se::dnn::ProfileResult* profile_result = nullptr);
+Status RunCudnnConv(const HloCustomCallInstruction* conv,
+                    absl::Span<se::DeviceMemoryBase> operand_buffers,
+                    se::DeviceMemoryBase result_buffer,
+                    se::DeviceMemoryBase scratch_buf, se::Stream* stream,
+                    se::dnn::ProfileResult* profile_result = nullptr);
 
-Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
-                           absl::Span<se::DeviceMemoryBase> operand_buffers,
-                           se::DeviceMemoryBase result_buffer,
-                           se::ScratchAllocator* scratch_allocator,
-                           se::Stream* stream,
-                           se::dnn::ProfileResult* profile_result = nullptr);
+Status RunCudnnConv(const HloCustomCallInstruction* conv,
+                    absl::Span<se::DeviceMemoryBase> operand_buffers,
+                    se::DeviceMemoryBase result_buffer,
+                    se::ScratchAllocator* scratch_allocator, se::Stream* stream,
+                    se::dnn::ProfileResult* profile_result = nullptr);
 
 }  // namespace gpu
 }  // namespace xla
 
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_RUNNER_H_
tensorflow/compiler/xla/service/gpu/{cudnn_fused_convolution_rewriter.cc → cudnn_fused_conv_rewriter.cc}

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h"
 
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
@@ -242,7 +242,7 @@ StatusOr<std::unique_ptr<HloInstruction>> TryRewriteToCudnnForwardRelu(
 
 }  // namespace
 
-StatusOr<bool> CudnnFusedConvolutionRewriter::Run(HloModule* module) {
+StatusOr<bool> CudnnFusedConvRewriter::Run(HloModule* module) {
   bool changed = false;
   for (HloComputation* computation : module->MakeNonfusionComputations()) {
     std::vector<ConvWithRelu> matches;
tensorflow/compiler/xla/service/gpu/{cudnn_fused_convolution_rewriter.h → cudnn_fused_conv_rewriter.h}

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_
 
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
@@ -22,7 +22,7 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-class CudnnFusedConvolutionRewriter : public HloModulePass {
+class CudnnFusedConvRewriter : public HloModulePass {
 public:
   absl::string_view name() const override {
     return "cudnn-fused-convolution-rewriter";
@@ -34,4 +34,4 @@ class CudnnFusedConvolutionRewriter : public HloModulePass {
 }  // namespace gpu
 }  // namespace xla
 
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_
tensorflow/compiler/xla/service/gpu/ir_emission_utils.h

@@ -108,9 +108,9 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo);
 // memory used by cudnn. Callers shouldn't inspect scratch_memory, as its value
 // is not well-defined.
 //
-// CudnnConvolutionRewriter lowers kConvolution HLOs to these custom calls.
+// CudnnConvRewriter lowers kConvolution HLOs to these custom calls.
 // When it does so, it chooses algorithm -1 and 0 bytes of scratch space. Later
-// on in the pipeline, CudnnConvolutionAlgorithmChooser chooses an explicit
+// on in the pipeline, CudnnConvAlgorithmChooser chooses an explicit
 // algorithm for each conv and sets the amount of scratch space needed.
 //
 // (Representing the scratch memory as an output may seem strange at first, but
tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc

@@ -43,7 +43,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
-#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h"
 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc

@@ -38,9 +38,11 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
-#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
-#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
-#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_speed.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h"
 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
@@ -54,8 +56,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
 #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h"
 #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
-#include "tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h"
-#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
@@ -201,21 +201,22 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
 
   {
     // Convert convolutions into CustomCalls to cudnn, then canonicalize them
-    // (PadInsertion).
+    // (CudnnConvPaddingLegalization).
     HloPassPipeline pipeline("conv_canonicalization");
     pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
                                               /*allow_mixed_precision=*/false);
-    pipeline.AddPass<CudnnConvolutionRewriter>();
-    pipeline.AddPass<CudnnFusedConvolutionRewriter>();
-    pipeline.AddPass<PadInsertion>();
+    pipeline.AddPass<CudnnConvRewriter>();
+    pipeline.AddPass<CudnnFusedConvRewriter>();
+    pipeline.AddPass<CudnnConvPaddingLegalization>();
     if (IsVoltaOrLater(*stream_exec)) {
-      pipeline.AddPass<PadForTensorCores>();
-      // PadForTensorCores leaves behind unnecessary tuple/get-tuple-element
+      pipeline.AddPass<CudnnConvPadForSpeed>();
+      // CudnnConvPadForSpeed leaves behind unnecessary tuple/get-tuple-element
       // pairs that TupleSimplifier fixes.
       pipeline.AddPass<TupleSimplifier>();
     }
-    // CudnnConvolutionRewriter, PadInsertion and PadForTensorCores may add
-    // instructions which can be simplified by constant folding.
+    // CudnnConvRewriter, CudnnConvPaddingLegalization and
+    // CudnnConvPadForSpeed may add instructions which can be simplified by
+    // constant folding.
     pipeline.AddPass<HloConstantFolding>();
     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
   }
@@ -252,7 +253,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
   // Choose the fastest algorithm for each conv.
   //
   // We pick the algorithm before fusion so we can generate better HLO. After
-  // CudnnConvolutionRewriter, our convolutions are CustomCalls which return a
+  // CudnnConvRewriter, our convolutions are CustomCalls which return a
   // tuple (conv_result, scratch_memory), and the each conv uses 0 bytes of
   // scratch:
   //
@@ -270,12 +271,12 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
   // The new tuple and gte instructions then be simplified away, because
   // nobody is expected to use the scratch value.
   //
-  // However, if we were to run CudnnConvolutionAlgorithmPicker after fusion
+  // However, if we were to run CudnnConvAlgorithmPicker after fusion
   // the gte(customcall, 0) would probably already be into a fusion node. We
   // can't simplify across HloComputation boundaries, so in this case we
   // wouldn't be able to simplify away the new_tuple bits.
-  pipeline.AddPass<CudnnConvolutionAlgorithmPicker>(
-      stream_exec, device_allocator, compiler);
+  pipeline.AddPass<CudnnConvAlgorithmPicker>(stream_exec, device_allocator,
+                                             compiler);
   // Clean up new_tuple described above.
   pipeline.AddPass<TupleSimplifier>();
 
tensorflow/compiler/xla/service/gpu/tests/BUILD

@@ -211,8 +211,8 @@ tf_cc_test(
 )
 
 tf_cc_test(
-    name = "cudnn_fused_convolution_rewriter_test",
-    srcs = ["cudnn_fused_convolution_rewriter_test.cc"],
+    name = "cudnn_fused_conv_rewriter_test",
+    srcs = ["cudnn_fused_conv_rewriter_test.cc"],
     tags = tf_cuda_tests_tags(),
     deps = [
         ":gpu_codegen_test",
tensorflow/compiler/xla/service/gpu/tests/{cudnn_fused_convolution_rewriter_test.cc → cudnn_fused_conv_rewriter_test.cc}

@@ -22,7 +22,7 @@ namespace xla {
 namespace gpu {
 namespace {
 
-class CudnnFusedConvolutionRewriterTest : public HloTestBase {
+class CudnnFusedConvRewriterTest : public HloTestBase {
 protected:
   string GetOptimizedHlo(absl::string_view hlo_string) {
     return backend()
@@ -66,7 +66,7 @@ class CudnnFusedConvolutionRewriterTest : public HloTestBase {
   }
 };
 
-TEST_F(CudnnFusedConvolutionRewriterTest, TestConvOnly) {
+TEST_F(CudnnFusedConvRewriterTest, TestConvOnly) {
   // max(0, conv(x, w));
   TestMatchWithAllTypes(R"(
     HloModule Test
@@ -83,7 +83,7 @@ TEST_F(CudnnFusedConvolutionRewriterTest, TestConvOnly) {
   })");
 }
 
-TEST_F(CudnnFusedConvolutionRewriterTest, TestBias) {
+TEST_F(CudnnFusedConvRewriterTest, TestBias) {
   // max(0, conv(x, w) + bias);
   TestMatchWithAllTypes(R"(
     HloModule Test
@@ -103,7 +103,7 @@ TEST_F(CudnnFusedConvolutionRewriterTest, TestBias) {
   })");
 }
 
-TEST_F(CudnnFusedConvolutionRewriterTest, TestSideInputOnly) {
+TEST_F(CudnnFusedConvRewriterTest, TestSideInputOnly) {
   // max(0, conv(x, w) + side_input);
   TestMatchWithAllTypes(R"(
     HloModule Test
@@ -122,7 +122,7 @@ TEST_F(CudnnFusedConvolutionRewriterTest, TestSideInputOnly) {
   })");
 }
 
-TEST_F(CudnnFusedConvolutionRewriterTest, TestBiasAndSideInput) {
+TEST_F(CudnnFusedConvRewriterTest, TestBiasAndSideInput) {
   // max(0, conv(x, w) + side_input + bias);
   TestMatchWithAllTypes(R"(
     HloModule Test
@@ -144,7 +144,7 @@ TEST_F(CudnnFusedConvolutionRewriterTest, TestBiasAndSideInput) {
   })");
 }
 
-TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConv) {
+TEST_F(CudnnFusedConvRewriterTest, TestScaledConv) {
   // max(0, 0.999994934 * conv(x, w));
   TestMatchWithAllTypes(R"(
     HloModule Test
@@ -164,7 +164,7 @@ TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConv) {
   })");
 }
 
-TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndSideInput) {
+TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndSideInput) {
   // max(0, conv(x, w) + 0.899994934 * side_input);
   TestMatchWithAllTypes(R"(
     HloModule Test
@@ -186,7 +186,7 @@ TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndSideInput) {
   })");
 }
 
-TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndScaledSideInput) {
+TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndScaledSideInput) {
   // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input);
   TestMatchWithAllTypes(R"(
     HloModule Test
@@ -211,8 +211,7 @@ TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndScaledSideInput) {
   })");
 }
 
-TEST_F(CudnnFusedConvolutionRewriterTest,
-       TestScaledConvAndScaledSideInputWithBias) {
+TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndScaledSideInputWithBias) {
   // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias);
   TestMatchWithAllTypes(R"(
     HloModule Test
@@ -240,7 +239,7 @@ TEST_F(CudnnFusedConvolutionRewriterTest,
   })");
 }
 
-TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchMaxZeroOnly) {
+TEST_F(CudnnFusedConvRewriterTest, TestMatchMaxZeroOnly) {
   // max(0.1, conv(x, w)) shouldn't match.
   TestNotMatchWithAllTypes(R"(
     HloModule Test
@@ -257,7 +256,7 @@ TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchMaxZeroOnly) {
   })");
 }
 
-TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchBroadcastedBiasOnly) {
+TEST_F(CudnnFusedConvRewriterTest, TestMatchBroadcastedBiasOnly) {
   // max(0, conv(x, w) + side_input1 + side_input2) shouldn't match.
   TestNotMatchWithAllTypes(R"(
     HloModule Test