Make the constant in hlo_algorithm_blacklist a more primitive type.
MSVC does not like absl::string_view for constexpr constants. PiperOrigin-RevId: 293209567 Change-Id: Ie4f171a76816f26d7ecfc2189106e48df59d94ff
This commit is contained in:
parent
9bdcb357b5
commit
be4b1b795e
@ -1636,7 +1636,6 @@ cc_library(
|
||||
"//tensorflow/core:autotuning_proto_cc",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@ -15,17 +15,16 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
// MSVC requires the extra const. Without, it reports an
|
||||
// "error C2131: expression did not evaluate to a constant".
|
||||
constexpr const absl::string_view kDefaultBlacklist = R"pb(
|
||||
constexpr char kDefaultBlacklist[] = R"pb(
|
||||
entries {
|
||||
hlo: "(f32[4,32,32,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[4,32,32,32]{2,1,3,0}, f32[5,5,32,32]{1,0,2,3}), window={size=5x5 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, custom_call_target=\"__cudnn$convForward\", backend_config=\"{conv_result_scale:1}\""
|
||||
cc { major: 7 }
|
||||
@ -45,8 +44,8 @@ constexpr const absl::string_view kDefaultBlacklist = R"pb(
|
||||
absl::Span<const stream_executor::dnn::AlgorithmDesc>
|
||||
GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc,
|
||||
tensorflow::CudnnVersion cudnn_version,
|
||||
absl::string_view blas_version,
|
||||
absl::string_view hlo) {
|
||||
const std::string& blas_version,
|
||||
const std::string& hlo) {
|
||||
// Key is the tuple of canonicalized hlo, compute capability major/minor,
|
||||
// cudnn version major/minor/patch, blas version.
|
||||
using MapType = absl::flat_hash_map<
|
||||
@ -79,8 +78,8 @@ GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc,
|
||||
}();
|
||||
|
||||
auto iter = blacklist->find(std::make_tuple(
|
||||
std::string(hlo), cc.major(), cc.minor(), cudnn_version.major(),
|
||||
cudnn_version.minor(), cudnn_version.patch(), std::string(blas_version)));
|
||||
hlo, cc.major(), cc.minor(), cudnn_version.major(), cudnn_version.minor(),
|
||||
cudnn_version.patch(), std::string(blas_version)));
|
||||
if (iter != blacklist->end()) {
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
@ -18,7 +18,6 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
||||
|
||||
@ -28,8 +27,8 @@ namespace gpu {
|
||||
absl::Span<const stream_executor::dnn::AlgorithmDesc>
|
||||
GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc,
|
||||
tensorflow::CudnnVersion cudnn_version,
|
||||
absl::string_view blas_version,
|
||||
absl::string_view hlo);
|
||||
const std::string& blas_version,
|
||||
const std::string& hlo);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user