From b6171a4eb3d308b1794b75566c5e19cd275c1a2c Mon Sep 17 00:00:00 2001 From: jerryyin Date: Tue, 1 Oct 2019 18:38:30 +0000 Subject: [PATCH] Addressing review feedbacks --- .../compiler/xla/service/gpu/gpu_conv_algorithm_picker.h | 6 +++--- tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h index 7b6ca6a8e2c..dddbe2ddfdc 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_ALGORITHM_PICKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_ALGORITHM_PICKER_H_ #include "absl/time/time.h" #include "absl/types/optional.h" @@ -67,4 +67,4 @@ class GpuConvAlgorithmPicker : public HloModulePass { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_ALGORITHM_PICKER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc index 261d43d5938..07b6c9108ae 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc @@ -229,7 +229,8 @@ Status RunGpuConvImpl(const GpuConvParams& params, // first call we need to ensure that the AlgorithmConfig::algorithm is // empty. For all subsequent calls, we should use the value retrieved from // the backend_config - if ((options.algo_override.has_value()) && + if ((stream->parent()->platform_kind() == se::PlatformKind::kROCm) && + (options.algo_override.has_value()) && (*options.algo_override == se::dnn::AlgorithmDesc())) { algorithm = AlgorithmConfig(); } else if (options.algo_override.has_value()) {