Addressing review feedbacks

This commit is contained in:
jerryyin 2019-10-01 18:38:30 +00:00
parent 98e4579b39
commit b6171a4eb3
2 changed files with 5 additions and 4 deletions

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_ALGORITHM_PICKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_ALGORITHM_PICKER_H_
#include "absl/time/time.h"
#include "absl/types/optional.h"
@ -67,4 +67,4 @@ class GpuConvAlgorithmPicker : public HloModulePass {
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_ALGORITHM_PICKER_H_

View File

@ -229,7 +229,8 @@ Status RunGpuConvImpl(const GpuConvParams& params,
// first call we need to ensure that the AlgorithmConfig::algorithm is
// empty. For all subsequent calls, we should use the value retrieved from
// the backend_config
if ((options.algo_override.has_value()) &&
if ((stream->parent()->platform_kind() == se::PlatformKind::kROCm) &&
(options.algo_override.has_value()) &&
(*options.algo_override == se::dnn::AlgorithmDesc())) {
algorithm = AlgorithmConfig();
} else if (options.algo_override.has_value()) {