Make the constant in hlo_algorithm_blacklist a more primitive type.
msvc does not like string_view for constant types. PiperOrigin-RevId: 293209567 Change-Id: Ie4f171a76816f26d7ecfc2189106e48df59d94ff
This commit is contained in:
parent
9bdcb357b5
commit
be4b1b795e
@ -1636,7 +1636,6 @@ cc_library(
|
|||||||
"//tensorflow/core:autotuning_proto_cc",
|
"//tensorflow/core:autotuning_proto_cc",
|
||||||
"//tensorflow/core:stream_executor_no_cuda",
|
"//tensorflow/core:stream_executor_no_cuda",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -15,17 +15,16 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h"
|
#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/strings/string_view.h"
|
|
||||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||||
#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h"
|
#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
|
|
||||||
// MSVC requires the extra const. Without, it reports an
|
constexpr char kDefaultBlacklist[] = R"pb(
|
||||||
// "error C2131: expression did not evaluate to a constant".
|
|
||||||
constexpr const absl::string_view kDefaultBlacklist = R"pb(
|
|
||||||
entries {
|
entries {
|
||||||
hlo: "(f32[4,32,32,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[4,32,32,32]{2,1,3,0}, f32[5,5,32,32]{1,0,2,3}), window={size=5x5 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, custom_call_target=\"__cudnn$convForward\", backend_config=\"{conv_result_scale:1}\""
|
hlo: "(f32[4,32,32,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[4,32,32,32]{2,1,3,0}, f32[5,5,32,32]{1,0,2,3}), window={size=5x5 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, custom_call_target=\"__cudnn$convForward\", backend_config=\"{conv_result_scale:1}\""
|
||||||
cc { major: 7 }
|
cc { major: 7 }
|
||||||
@ -45,8 +44,8 @@ constexpr const absl::string_view kDefaultBlacklist = R"pb(
|
|||||||
absl::Span<const stream_executor::dnn::AlgorithmDesc>
|
absl::Span<const stream_executor::dnn::AlgorithmDesc>
|
||||||
GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc,
|
GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc,
|
||||||
tensorflow::CudnnVersion cudnn_version,
|
tensorflow::CudnnVersion cudnn_version,
|
||||||
absl::string_view blas_version,
|
const std::string& blas_version,
|
||||||
absl::string_view hlo) {
|
const std::string& hlo) {
|
||||||
// Key is the tuple of canonicalized hlo, compute capability major/minor,
|
// Key is the tuple of canonicalized hlo, compute capability major/minor,
|
||||||
// cudnn version major/minor/patch, blas version.
|
// cudnn version major/minor/patch, blas version.
|
||||||
using MapType = absl::flat_hash_map<
|
using MapType = absl::flat_hash_map<
|
||||||
@ -79,8 +78,8 @@ GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc,
|
|||||||
}();
|
}();
|
||||||
|
|
||||||
auto iter = blacklist->find(std::make_tuple(
|
auto iter = blacklist->find(std::make_tuple(
|
||||||
std::string(hlo), cc.major(), cc.minor(), cudnn_version.major(),
|
hlo, cc.major(), cc.minor(), cudnn_version.major(), cudnn_version.minor(),
|
||||||
cudnn_version.minor(), cudnn_version.patch(), std::string(blas_version)));
|
cudnn_version.patch(), std::string(blas_version)));
|
||||||
if (iter != blacklist->end()) {
|
if (iter != blacklist->end()) {
|
||||||
return iter->second;
|
return iter->second;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,7 +18,6 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/strings/string_view.h"
|
|
||||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||||
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
||||||
|
|
||||||
@ -28,8 +27,8 @@ namespace gpu {
|
|||||||
absl::Span<const stream_executor::dnn::AlgorithmDesc>
|
absl::Span<const stream_executor::dnn::AlgorithmDesc>
|
||||||
GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc,
|
GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc,
|
||||||
tensorflow::CudnnVersion cudnn_version,
|
tensorflow::CudnnVersion cudnn_version,
|
||||||
absl::string_view blas_version,
|
const std::string& blas_version,
|
||||||
absl::string_view hlo);
|
const std::string& hlo);
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user