Merge pull request #31489 from ROCmSoftwarePlatform:google_upstream_xla_amdgpu_frontend_impl

PiperOrigin-RevId: 264256263
This commit is contained in:
TensorFlower Gardener 2019-08-19 15:57:26 -07:00
commit d52eda989e
5 changed files with 295 additions and 42 deletions

View File

@ -1133,8 +1133,7 @@ cc_library(
cc_library(
name = "amdgpu_compiler",
srcs = [
# TODO(whchung@gmail.com): Enable in the subsequent PR.
# "amdgpu_compiler_registration.cc",
"amdgpu_compiler_registration.cc",
],
deps = [
":amdgpu_compiler_impl",
@ -1145,18 +1144,32 @@ cc_library(
cc_library(
name = "amdgpu_compiler_impl",
srcs = [
# TODO(whchung@gmail.com) : enable in the subsequent PR.
#"amdgpu_compiler.cc",
"amdgpu_compiler.cc",
],
hdrs = [
# TODO(whchung@gmail.com): enable in the subsequent PR.
#"amdgpu_compiler.h"
"amdgpu_compiler.h",
],
deps = [
# TODO(whchung@gmail.com): Enable these after pending PRs get merged.
#":gpu_compiler_impl",
#":miopen_conv_algorithm_picker",
#"//tensorflow/core:rocm_rocdl_path",
":cudnn_conv_padding_legalization",
":cudnn_conv_rewriter",
":gpu_compiler",
":gpu_layout_assignment",
":target_constants",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/service:algebraic_simplifier",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_constant_folding",
"//tensorflow/compiler/xla/service:hlo_cse",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service:tuple_simplifier",
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/platform:rocm_rocdl_path",
],
)

View File

@ -0,0 +1,156 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h"
// TODO(whchung@gmail.com): Add gpu_conv_algorithm_picker after its PR merged.
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/core/platform/rocm_rocdl_path.h"
namespace xla {
namespace gpu {
namespace {
// Returns the directory containing ROCm-Device-Libs files. This function is
// called in AMDGPUCompiler's constructor, so can't return an error. But
// AMDGPUCompiler::Compile will return an error when the wanted rocdl file
// doesn't exist in the folder this function returns.
string GetROCDLDir(const HloModuleConfig& config) {
std::vector<string> potential_rocdl_dirs;
const string datadir = config.debug_options().xla_gpu_cuda_data_dir();
if (!datadir.empty()) {
potential_rocdl_dirs.push_back(datadir);
}
potential_rocdl_dirs.push_back(tensorflow::RocdlRoot());
// Tries all potential ROCDL directories in the order they are inserted.
// Returns the first directory that exists in the file system.
for (const string& potential_rocdl_dir : potential_rocdl_dirs) {
if (tensorflow::Env::Default()->IsDirectory(potential_rocdl_dir).ok()) {
VLOG(2) << "Found ROCm-Device-Libs dir " << potential_rocdl_dir;
return potential_rocdl_dir;
}
VLOG(2) << "Unable to find potential ROCm-Device-Libs dir "
<< potential_rocdl_dir;
}
// Last resort: maybe in the current folder.
return ".";
}
} // namespace
Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
HloModule* hlo_module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) {
// Convert convolutions into CustomCalls to MIOpen, then canonicalize them
// (PadInsertion).
HloPassPipeline pipeline("conv_canonicalization");
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
pipeline.AddPass<CudnnConvRewriter>();
pipeline.AddPass<CudnnConvPaddingLegalization>();
pipeline.AddPass<HloConstantFolding>();
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
return Status::OK();
}
Status AMDGPUCompiler::OptimizeHloPostLayoutAssignment(
HloModule* hlo_module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) {
HloPassPipeline pipeline("post-layout_assignment");
pipeline.AddInvariantChecker<HloVerifier>(
/*layout_sensitive=*/true,
/*allow_mixed_precision=*/false,
LayoutAssignment::InstructionCanChangeLayout);
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
AlgebraicSimplifierOptions options;
options.set_is_layout_sensitive(true);
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
// TODO(whchung@gmail.com): Add gpu_conv_algorithm_picker after its PR merged.
// Clean up new_tuple described above.
pipeline.AddPass<TupleSimplifier>();
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
return Status::OK();
}
AMDGPUCompiler::AMDGPUCompiler()
: GpuCompiler(stream_executor::rocm::kROCmPlatformId, amdgpu::kTargetTriple,
amdgpu::kDataLayout) {}
GpuVersion AMDGPUCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) {
int isa_version = 0;
if (!stream_exec->GetDeviceDescription().rocm_amdgpu_isa_version(
&isa_version)) {
LOG(WARNING)
<< "Couldn't get AMDGPU ISA version for device; assuming gfx803.";
isa_version = 803;
}
return isa_version;
}
StatusOr<std::pair<std::string, std::vector<uint8>>>
AMDGPUCompiler::CompileTargetBinary(const HloModule* module,
llvm::Module* llvm_module,
GpuVersion gpu_version,
se::StreamExecutor* stream_exec) {
if (rocdl_dir_.empty()) {
// Compute rocdl_dir_ just once and cache it in this member.
rocdl_dir_ = GetROCDLDir(module->config());
}
std::vector<uint8> hsaco;
{
XLA_SCOPED_LOGGING_TIMER(
"AMDGPUCompiler::CompileTargetBinary - CompileToHsaco");
TF_ASSIGN_OR_RETURN(hsaco,
amdgpu::CompileToHsaco(llvm_module, gpu_version,
module->config(), rocdl_dir_));
}
llvm_ir::DumpIrIfEnabled(*module, *llvm_module, /*optimized=*/false);
if (user_post_optimization_hook_) {
user_post_optimization_hook_(*llvm_module);
}
return std::pair<std::string, std::vector<uint8>>("", std::move(hsaco));
}
} // namespace gpu
} // namespace xla

View File

@ -0,0 +1,60 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_AMDGPU_COMPILER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_AMDGPU_COMPILER_H_
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
namespace gpu {
// AMDGPUCompiler generates efficient GPU executables for AMDGPU target.
class AMDGPUCompiler : public GpuCompiler {
public:
AMDGPUCompiler();
~AMDGPUCompiler() override {}
Status OptimizeHloConvolutionCanonicalization(
HloModule* hlo_module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override;
Status OptimizeHloPostLayoutAssignment(
HloModule* hlo_module, se::StreamExecutor* stream_exec,
se::DeviceMemoryAllocator* device_allocator) override;
GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override;
StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
const HloModule* hlo_module, llvm::Module* llvm_module,
GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
private:
// The parent directory of ROCm-Device-Libs IR libraries.
string rocdl_dir_;
TF_DISALLOW_COPY_AND_ASSIGN(AMDGPUCompiler);
};
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_AMDGPU_COMPILER_H_

View File

@ -0,0 +1,24 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h"
static bool InitModule() {
xla::Compiler::RegisterCompilerFactory(
stream_executor::rocm::kROCmPlatformId,
[]() { return absl::make_unique<xla::gpu::AMDGPUCompiler>(); });
return true;
}
static bool module_initialized = InitModule();

View File

@ -24,9 +24,9 @@ tensorflow/third_party/clang_toolchain/BUILD
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
tensorflow/third_party/clang_toolchain/download_clang.bzl
tensorflow/third_party/codegen.BUILD
tensorflow/third_party/com_google_absl.BUILD
tensorflow/third_party/common.bzl
tensorflow/third_party/cub.BUILD
tensorflow/third_party/com_google_absl.BUILD
tensorflow/third_party/curl.BUILD
tensorflow/third_party/cython.BUILD
tensorflow/third_party/eigen3/Eigen/Cholesky
@ -43,49 +43,49 @@ tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
tensorflow/third_party/double_conversion.BUILD
tensorflow/third_party/eigen.BUILD
tensorflow/third_party/enum34.BUILD
tensorflow/third_party/farmhash.BUILD
tensorflow/third_party/fft2d/BUILD
tensorflow/third_party/fft2d/LICENSE
tensorflow/third_party/fft2d/fft.h
tensorflow/third_party/fft2d/fft2d.BUILD
tensorflow/third_party/fft2d/fft2d.h
tensorflow/third_party/farmhash.BUILD
tensorflow/third_party/functools32.BUILD
tensorflow/third_party/gast.BUILD
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/BUILD
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/git_configure.bzl
tensorflow/third_party/gif.BUILD
tensorflow/third_party/gpus/crosstool/BUILD.tpl
tensorflow/third_party/gpus/crosstool/BUILD
tensorflow/third_party/gpus/crosstool/LICENSE
tensorflow/third_party/gpus/crosstool/BUILD.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
tensorflow/third_party/gpus/BUILD
tensorflow/third_party/gpus/cuda/BUILD.tpl
tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
tensorflow/third_party/gpus/cuda/BUILD
tensorflow/third_party/gpus/cuda/BUILD.tpl
tensorflow/third_party/gpus/cuda/LICENSE
tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
tensorflow/third_party/gpus/cuda_configure.bzl
tensorflow/third_party/gpus/rocm/BUILD
tensorflow/third_party/gpus/rocm/BUILD.tpl
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
tensorflow/third_party/gpus/find_cuda_config.py
tensorflow/third_party/gpus/cuda_configure.bzl
tensorflow/third_party/gpus/rocm_configure.bzl
tensorflow/third_party/googleapis.BUILD
tensorflow/third_party/grpc/BUILD
@ -97,14 +97,14 @@ tensorflow/third_party/llvm/BUILD
tensorflow/third_party/llvm/expand_cmake_vars.py
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
tensorflow/third_party/llvm/llvm.bzl
tensorflow/third_party/linenoise.BUILD
tensorflow/third_party/libxsmm.BUILD
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/linenoise.BUILD
tensorflow/third_party/mkl/BUILD
tensorflow/third_party/mkl/LICENSE
tensorflow/third_party/mkl/MKL_LICENSE
tensorflow/third_party/mkl/build_defs.bzl
tensorflow/third_party/mkl/mkl.BUILD
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/mkl_dnn/LICENSE
tensorflow/third_party/mkl_dnn/mkldnn.BUILD
tensorflow/third_party/mpi/.gitignore
@ -113,28 +113,28 @@ tensorflow/third_party/mpi_collectives/BUILD
tensorflow/third_party/nanopb.BUILD
tensorflow/third_party/nccl/BUILD
tensorflow/third_party/nccl/archive.BUILD
tensorflow/third_party/nccl/archive.patch
tensorflow/third_party/nccl/LICENSE
tensorflow/third_party/nccl/build_defs.bzl.tpl
tensorflow/third_party/nccl/archive.patch
tensorflow/third_party/nccl/system.BUILD.tpl
tensorflow/third_party/nccl/nccl_configure.bzl
tensorflow/third_party/nccl/system.BUILD.tpl
tensorflow/third_party/ngraph/BUILD
tensorflow/third_party/ngraph/LICENSE
tensorflow/third_party/ngraph/NGRAPH_LICENSE
tensorflow/third_party/ngraph/ngraph.BUILD
tensorflow/third_party/ngraph/build_defs.bzl
tensorflow/third_party/ngraph/ngraph.BUILD
tensorflow/third_party/ngraph/ngraph_tf.BUILD
tensorflow/third_party/ngraph/nlohmann_json.BUILD
tensorflow/third_party/ngraph/tbb.BUILD
tensorflow/third_party/opt_einsum.BUILD
tensorflow/third_party/pcre.BUILD
tensorflow/third_party/png_fix_rpi.patch
tensorflow/third_party/png.BUILD
tensorflow/third_party/protobuf/BUILD
tensorflow/third_party/png_fix_rpi.patch
tensorflow/third_party/pprof.BUILD
tensorflow/third_party/py/numpy/BUILD
tensorflow/third_party/py/BUILD
tensorflow/third_party/py/BUILD.tpl
tensorflow/third_party/py/BUILD
tensorflow/third_party/py/python_configure.bzl
tensorflow/third_party/python_runtime/BUILD
tensorflow/third_party/pybind11.BUILD
@ -160,14 +160,14 @@ tensorflow/third_party/systemlibs/gif.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
tensorflow/third_party/systemlibs/googleapis.BUILD
tensorflow/third_party/systemlibs/grpc.BUILD
tensorflow/third_party/systemlibs/jsoncpp.BUILD
tensorflow/third_party/systemlibs/lmdb.BUILD
tensorflow/third_party/systemlibs/grpc.BUILD
tensorflow/third_party/systemlibs/nsync.BUILD
tensorflow/third_party/systemlibs/opt_einsum.BUILD
tensorflow/third_party/systemlibs/lmdb.BUILD
tensorflow/third_party/systemlibs/pcre.BUILD
tensorflow/third_party/systemlibs/png.BUILD
tensorflow/third_party/systemlibs/opt_einsum.BUILD
tensorflow/third_party/systemlibs/protobuf.BUILD
tensorflow/third_party/systemlibs/png.BUILD
tensorflow/third_party/systemlibs/protobuf.bzl
tensorflow/third_party/systemlibs/re2.BUILD
tensorflow/third_party/systemlibs/six.BUILD
@ -179,8 +179,8 @@ tensorflow/third_party/systemlibs/termcolor.BUILD
tensorflow/third_party/systemlibs/zlib.BUILD
tensorflow/third_party/tensorrt/BUILD
tensorflow/third_party/tensorrt/BUILD.tpl
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
tensorflow/third_party/tensorrt/LICENSE
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
tensorflow/third_party/tensorrt/tensorrt/include/tensorrt_config.h.tpl
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
tensorflow/third_party/termcolor.BUILD
@ -188,13 +188,13 @@ tensorflow/third_party/tflite_mobilenet.BUILD
tensorflow/third_party/tflite_mobilenet_float.BUILD
tensorflow/third_party/tflite_mobilenet_quant.BUILD
tensorflow/third_party/toolchains/clang6/BUILD
tensorflow/third_party/toolchains/clang6/README.md
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
tensorflow/third_party/toolchains/clang6/README.md
tensorflow/third_party/toolchains/clang6/clang.BUILD
tensorflow/third_party/toolchains/clang6/repo.bzl
tensorflow/third_party/toolchains/BUILD
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
tensorflow/third_party/toolchains/cpus/arm/BUILD
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
tensorflow/third_party/toolchains/cpus/arm/cc_config.bzl.tpl
tensorflow/third_party/toolchains/cpus/py/BUILD
tensorflow/third_party/toolchains/cpus/py3/BUILD
@ -225,8 +225,8 @@ tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/cc_too
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/build_defs.bzl
@ -240,22 +240,22 @@ tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3_opt/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
tensorflow/third_party/toolchains/remote/BUILD
tensorflow/third_party/toolchains/remote/BUILD.tpl
tensorflow/third_party/toolchains/remote/configure.bzl
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
tensorflow/third_party/toolchains/remote/BUILD.tpl
tensorflow/third_party/tflite_smartreply.BUILD
tensorflow/third_party/tflite_ovic_testdata.BUILD
tensorflow/third_party/zlib.BUILD
tensorflow/third_party/tflite_smartreply.BUILD
tensorflow/third_party/wrapt.BUILD
tensorflow/third_party/zlib.BUILD
tensorflow/tools/ci_build/remote/BUILD
tensorflow/tools/def_file_filter/BUILD
tensorflow/tools/def_file_filter/BUILD.tpl
@ -265,8 +265,8 @@ tensorflow/tools/lib_package/BUILD
tensorflow/tools/lib_package/LibTensorFlowTest.java
tensorflow/tools/lib_package/README.md
tensorflow/tools/lib_package/concat_licenses.sh
tensorflow/tools/lib_package/libtensorflow_test.c
tensorflow/tools/lib_package/libtensorflow_java_test.sh
tensorflow/tools/lib_package/libtensorflow_test.c
tensorflow/tools/lib_package/libtensorflow_test.sh
tensorflow/tools/pip_package/BUILD
tensorflow/tools/pip_package/MANIFEST.in
@ -275,8 +275,8 @@ tensorflow/tools/pip_package/build_pip_package.sh
tensorflow/tools/pip_package/check_load_py_test.py
tensorflow/tools/pip_package/pip_smoke_test.py
tensorflow/tools/pip_package/setup.py
tensorflow/tools/pip_package/simple_console_for_windows.py
tensorflow/tools/pip_package/simple_console.py
tensorflow/tools/pip_package/simple_console_for_windows.py
tensorflow/virtual_root_template_v1.__init__.py
tensorflow/virtual_root_template_v2.__init__.py
llvm/llvm/projects/google_mlir/WORKSPACE