diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 47be6aa7435..ec8689e983f 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -145,6 +145,7 @@ load( "if_dynamic_kernels", "if_static", "tf_cuda_tests_tags", + "tf_gpu_tests_tags", ) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") @@ -4524,6 +4525,20 @@ tf_cuda_cc_test( ], ) +tf_cc_test_gpu( + name = "rocm_rocdl_path_test", + size = "small", + srcs = ["//tensorflow/core/platform:rocm_rocdl_path_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_gpu_tests_tags(), + deps = [ + ":lib", + ":test", + ":test_main", + "//tensorflow/core/platform:rocm_rocdl_path", + ], +) + tf_cuda_only_cc_test( name = "util_gpu_kernel_helper_test", srcs = [ diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD index 4ab8f755f44..c50fc306097 100644 --- a/tensorflow/core/platform/BUILD +++ b/tensorflow/core/platform/BUILD @@ -17,6 +17,8 @@ load( "tf_additional_minimal_lib_srcs", "tf_additional_monitoring_srcs", "tf_additional_proto_hdrs", + "tf_additional_rocdl_deps", + "tf_additional_rocdl_srcs", "tf_additional_test_srcs", "tf_env_time_srcs", "tf_logging_absl_deps", @@ -134,6 +136,16 @@ cc_library( hdrs = ["macros.h"], ) +cc_library( + name = "rocm_rocdl_path", + srcs = ["rocm_rocdl_path.cc"] + tf_additional_rocdl_srcs(), + hdrs = ["rocm_rocdl_path.h"], + deps = [ + ":types", + "//tensorflow/core:lib", + ] + tf_additional_rocdl_deps(), +) + cc_library( name = "platform", hdrs = ["platform.h"], @@ -230,6 +242,8 @@ filegroup( "**/logger.cc", "**/logging.cc", "**/human_readable_json.cc", + "**/rocm.h", + "**/rocm_rocdl_path.cc", "abi.cc", "cpu_info.cc", "platform_strings.cc", @@ -268,6 +282,7 @@ filegroup( # :platform_base, a common dependency for downstream targets. "**/env_time.cc", "**/logging.cc", + "**/rocm_rocdl_path.*", "default/test_benchmark.*", "cuda.h", "rocm.h", @@ -320,6 +335,7 @@ filegroup( "**/logger.cc", "**/logging.cc", "**/human_readable_json.cc", + "**/rocm_rocdl_path.cc", "abi.cc", "cpu_info.cc", "platform_strings.cc", diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 6404fde5504..417f37f3694 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -624,6 +624,12 @@ def tf_additional_libdevice_deps(): def tf_additional_libdevice_srcs(): return ["default/cuda_libdevice_path.cc"] +def tf_additional_rocdl_deps(): + return ["@local_config_rocm//rocm:rocm_headers"] + +def tf_additional_rocdl_srcs(): + return ["default/rocm_rocdl_path.cc"] + def tf_additional_test_deps(): return [] diff --git a/tensorflow/core/platform/default/rocm_rocdl_path.cc b/tensorflow/core/platform/default/rocm_rocdl_path.cc new file mode 100644 index 00000000000..14196044656 --- /dev/null +++ b/tensorflow/core/platform/default/rocm_rocdl_path.cc @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/rocm_rocdl_path.h" + +#include <stdlib.h> + +#if !defined(PLATFORM_GOOGLE) && TENSORFLOW_USE_ROCM +#include "rocm/rocm_config.h" +#endif +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +string RocmRoot() { +#if TENSORFLOW_USE_ROCM + VLOG(3) << "ROCM root = " << TF_ROCM_TOOLKIT_PATH; + return TF_ROCM_TOOLKIT_PATH; +#else + return ""; +#endif +} + +} // namespace tensorflow diff --git a/tensorflow/core/platform/rocm_rocdl_path.cc b/tensorflow/core/platform/rocm_rocdl_path.cc new file mode 100644 index 00000000000..bf5b2bf722c --- /dev/null +++ b/tensorflow/core/platform/rocm_rocdl_path.cc @@ -0,0 +1,26 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/rocm_rocdl_path.h" + +#include "tensorflow/core/lib/io/path.h" + +namespace tensorflow { + +string RocdlRoot() { + return tensorflow::io::JoinPath(tensorflow::RocmRoot(), "hcc/lib"); +} + +} // namespace tensorflow diff --git a/tensorflow/core/platform/rocm_rocdl_path.h b/tensorflow/core/platform/rocm_rocdl_path.h new file mode 100644 index 00000000000..e83ef5b8235 --- /dev/null +++ b/tensorflow/core/platform/rocm_rocdl_path.h @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_ROCM_ROCDL_PATH_H_ +#define TENSORFLOW_CORE_PLATFORM_ROCM_ROCDL_PATH_H_ + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Returns the root directory of the ROCM SDK, which contains sub-folders such +// as bin, lib, and rocdl. +string RocmRoot(); + +// Returns the directory that contains ROCm-Device-Libs files in the ROCm SDK. +string RocdlRoot(); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_ROCM_ROCDL_PATH_H_ diff --git a/tensorflow/core/platform/rocm_rocdl_path_test.cc b/tensorflow/core/platform/rocm_rocdl_path_test.cc new file mode 100644 index 00000000000..4a4d9b89c59 --- /dev/null +++ b/tensorflow/core/platform/rocm_rocdl_path_test.cc @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/rocm_rocdl_path.h" + +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +#if TENSORFLOW_USE_ROCM +TEST(RocmRocdlPathTest, ROCDLPath) { + VLOG(2) << "ROCm-Deivce-Libs root = " << RocdlRoot(); + std::vector<string> rocdl_files; + TF_EXPECT_OK(Env::Default()->GetMatchingPaths( + io::JoinPath(RocdlRoot(), "*.amdgcn.bc"), &rocdl_files)); + EXPECT_LT(0, rocdl_files.size()); +} +#endif + +} // namespace tensorflow