Make the constant in hlo_algorithm_blacklist a more primitive type.

msvc does not like string_view for constant types.

PiperOrigin-RevId: 293209567
Change-Id: Ie4f171a76816f26d7ecfc2189106e48df59d94ff
This commit is contained in:
Gunhan Gulsoy 2020-02-04 13:00:01 -08:00 committed by TensorFlower Gardener
parent 9bdcb357b5
commit be4b1b795e
3 changed files with 9 additions and 12 deletions

View File

@ -1636,7 +1636,6 @@ cc_library(
"//tensorflow/core:autotuning_proto_cc", "//tensorflow/core:autotuning_proto_cc",
"//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
], ],
) )

View File

@ -15,17 +15,16 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h" #include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h"
#include <string>
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h" #include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h"
namespace xla { namespace xla {
namespace gpu { namespace gpu {
// MSVC requires the extra const. Without, it reports an constexpr char kDefaultBlacklist[] = R"pb(
// "error C2131: expression did not evaluate to a constant".
constexpr const absl::string_view kDefaultBlacklist = R"pb(
entries { entries {
hlo: "(f32[4,32,32,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[4,32,32,32]{2,1,3,0}, f32[5,5,32,32]{1,0,2,3}), window={size=5x5 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, custom_call_target=\"__cudnn$convForward\", backend_config=\"{conv_result_scale:1}\"" hlo: "(f32[4,32,32,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[4,32,32,32]{2,1,3,0}, f32[5,5,32,32]{1,0,2,3}), window={size=5x5 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, custom_call_target=\"__cudnn$convForward\", backend_config=\"{conv_result_scale:1}\""
cc { major: 7 } cc { major: 7 }
@ -45,8 +44,8 @@ constexpr const absl::string_view kDefaultBlacklist = R"pb(
absl::Span<const stream_executor::dnn::AlgorithmDesc> absl::Span<const stream_executor::dnn::AlgorithmDesc>
GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc, GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc,
tensorflow::CudnnVersion cudnn_version, tensorflow::CudnnVersion cudnn_version,
absl::string_view blas_version, const std::string& blas_version,
absl::string_view hlo) { const std::string& hlo) {
// Key is the tuple of canonicalized hlo, compute capability major/minor, // Key is the tuple of canonicalized hlo, compute capability major/minor,
// cudnn version major/minor/patch, blas version. // cudnn version major/minor/patch, blas version.
using MapType = absl::flat_hash_map< using MapType = absl::flat_hash_map<
@ -79,8 +78,8 @@ GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc,
}(); }();
auto iter = blacklist->find(std::make_tuple( auto iter = blacklist->find(std::make_tuple(
std::string(hlo), cc.major(), cc.minor(), cudnn_version.major(), hlo, cc.major(), cc.minor(), cudnn_version.major(), cudnn_version.minor(),
cudnn_version.minor(), cudnn_version.patch(), std::string(blas_version))); cudnn_version.patch(), std::string(blas_version)));
if (iter != blacklist->end()) { if (iter != blacklist->end()) {
return iter->second; return iter->second;
} }

View File

@ -18,7 +18,6 @@ limitations under the License.
#include <vector> #include <vector>
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/protobuf/autotuning.pb.h" #include "tensorflow/core/protobuf/autotuning.pb.h"
@ -28,8 +27,8 @@ namespace gpu {
absl::Span<const stream_executor::dnn::AlgorithmDesc> absl::Span<const stream_executor::dnn::AlgorithmDesc>
GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc, GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc,
tensorflow::CudnnVersion cudnn_version, tensorflow::CudnnVersion cudnn_version,
absl::string_view blas_version, const std::string& blas_version,
absl::string_view hlo); const std::string& hlo);
} // namespace gpu } // namespace gpu
} // namespace xla } // namespace xla