Merge pull request #31489 from ROCmSoftwarePlatform:google_upstream_xla_amdgpu_frontend_impl
PiperOrigin-RevId: 264256263
This commit is contained in:
commit
d52eda989e
@ -1133,8 +1133,7 @@ cc_library(
|
||||
cc_library(
|
||||
name = "amdgpu_compiler",
|
||||
srcs = [
|
||||
# TODO(whchung@gmail.com): Enable in the subsequent PR.
|
||||
# "amdgpu_compiler_registration.cc",
|
||||
"amdgpu_compiler_registration.cc",
|
||||
],
|
||||
deps = [
|
||||
":amdgpu_compiler_impl",
|
||||
@ -1145,18 +1144,32 @@ cc_library(
|
||||
cc_library(
|
||||
name = "amdgpu_compiler_impl",
|
||||
srcs = [
|
||||
# TODO(whchung@gmail.com) : enable in the subsequent PR.
|
||||
#"amdgpu_compiler.cc",
|
||||
"amdgpu_compiler.cc",
|
||||
],
|
||||
hdrs = [
|
||||
# TODO(whchung@gmail.com): enable in the subsequent PR.
|
||||
#"amdgpu_compiler.h"
|
||||
"amdgpu_compiler.h",
|
||||
],
|
||||
deps = [
|
||||
# TODO(whchung@gmail.com): Enable these after pending PRs get merged.
|
||||
#":gpu_compiler_impl",
|
||||
#":miopen_conv_algorithm_picker",
|
||||
#"//tensorflow/core:rocm_rocdl_path",
|
||||
":cudnn_conv_padding_legalization",
|
||||
":cudnn_conv_rewriter",
|
||||
":gpu_compiler",
|
||||
":gpu_layout_assignment",
|
||||
":target_constants",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/service:algebraic_simplifier",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_constant_folding",
|
||||
"//tensorflow/compiler/xla/service:hlo_cse",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
|
||||
"//tensorflow/compiler/xla/service:hlo_verifier",
|
||||
"//tensorflow/compiler/xla/service:llvm_compiler",
|
||||
"//tensorflow/compiler/xla/service:tuple_simplifier",
|
||||
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/platform:rocm_rocdl_path",
|
||||
],
|
||||
)
|
||||
|
||||
|
156
tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
Normal file
156
tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
Normal file
@ -0,0 +1,156 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h"
|
||||
// TODO(whchung@gmail.com): Add gpu_conv_algorithm_picker after its PR merged.
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_cse.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
|
||||
#include "tensorflow/core/platform/rocm_rocdl_path.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns the directory containing ROCm-Device-Libs files. This function is
|
||||
// called in AMDGPUCompiler's constructor, so can't return an error. But
|
||||
// AMDGPUCompiler::Compile will return an error when the wanted rocdl file
|
||||
// doesn't exist in the folder this function returns.
|
||||
string GetROCDLDir(const HloModuleConfig& config) {
|
||||
std::vector<string> potential_rocdl_dirs;
|
||||
const string datadir = config.debug_options().xla_gpu_cuda_data_dir();
|
||||
if (!datadir.empty()) {
|
||||
potential_rocdl_dirs.push_back(datadir);
|
||||
}
|
||||
potential_rocdl_dirs.push_back(tensorflow::RocdlRoot());
|
||||
|
||||
// Tries all potential ROCDL directories in the order they are inserted.
|
||||
// Returns the first directory that exists in the file system.
|
||||
for (const string& potential_rocdl_dir : potential_rocdl_dirs) {
|
||||
if (tensorflow::Env::Default()->IsDirectory(potential_rocdl_dir).ok()) {
|
||||
VLOG(2) << "Found ROCm-Device-Libs dir " << potential_rocdl_dir;
|
||||
return potential_rocdl_dir;
|
||||
}
|
||||
VLOG(2) << "Unable to find potential ROCm-Device-Libs dir "
|
||||
<< potential_rocdl_dir;
|
||||
}
|
||||
|
||||
// Last resort: maybe in the current folder.
|
||||
return ".";
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
|
||||
HloModule* hlo_module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) {
|
||||
// Convert convolutions into CustomCalls to MIOpen, then canonicalize them
|
||||
// (PadInsertion).
|
||||
HloPassPipeline pipeline("conv_canonicalization");
|
||||
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
|
||||
/*allow_mixed_precision=*/false);
|
||||
pipeline.AddPass<CudnnConvRewriter>();
|
||||
pipeline.AddPass<CudnnConvPaddingLegalization>();
|
||||
|
||||
pipeline.AddPass<HloConstantFolding>();
|
||||
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AMDGPUCompiler::OptimizeHloPostLayoutAssignment(
|
||||
HloModule* hlo_module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) {
|
||||
HloPassPipeline pipeline("post-layout_assignment");
|
||||
pipeline.AddInvariantChecker<HloVerifier>(
|
||||
/*layout_sensitive=*/true,
|
||||
/*allow_mixed_precision=*/false,
|
||||
LayoutAssignment::InstructionCanChangeLayout);
|
||||
|
||||
// The LayoutAssignment pass may leave behind kCopy instructions which are
|
||||
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
|
||||
AlgebraicSimplifierOptions options;
|
||||
options.set_is_layout_sensitive(true);
|
||||
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
|
||||
|
||||
// TODO(whchung@gmail.com): Add gpu_conv_algorithm_picker after its PR merged.
|
||||
|
||||
// Clean up new_tuple described above.
|
||||
pipeline.AddPass<TupleSimplifier>();
|
||||
|
||||
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
|
||||
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
AMDGPUCompiler::AMDGPUCompiler()
|
||||
: GpuCompiler(stream_executor::rocm::kROCmPlatformId, amdgpu::kTargetTriple,
|
||||
amdgpu::kDataLayout) {}
|
||||
|
||||
GpuVersion AMDGPUCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) {
|
||||
int isa_version = 0;
|
||||
if (!stream_exec->GetDeviceDescription().rocm_amdgpu_isa_version(
|
||||
&isa_version)) {
|
||||
LOG(WARNING)
|
||||
<< "Couldn't get AMDGPU ISA version for device; assuming gfx803.";
|
||||
isa_version = 803;
|
||||
}
|
||||
|
||||
return isa_version;
|
||||
}
|
||||
|
||||
StatusOr<std::pair<std::string, std::vector<uint8>>>
|
||||
AMDGPUCompiler::CompileTargetBinary(const HloModule* module,
|
||||
llvm::Module* llvm_module,
|
||||
GpuVersion gpu_version,
|
||||
se::StreamExecutor* stream_exec) {
|
||||
if (rocdl_dir_.empty()) {
|
||||
// Compute rocdl_dir_ just once and cache it in this member.
|
||||
rocdl_dir_ = GetROCDLDir(module->config());
|
||||
}
|
||||
|
||||
std::vector<uint8> hsaco;
|
||||
{
|
||||
XLA_SCOPED_LOGGING_TIMER(
|
||||
"AMDGPUCompiler::CompileTargetBinary - CompileToHsaco");
|
||||
TF_ASSIGN_OR_RETURN(hsaco,
|
||||
amdgpu::CompileToHsaco(llvm_module, gpu_version,
|
||||
module->config(), rocdl_dir_));
|
||||
}
|
||||
|
||||
llvm_ir::DumpIrIfEnabled(*module, *llvm_module, /*optimized=*/false);
|
||||
|
||||
if (user_post_optimization_hook_) {
|
||||
user_post_optimization_hook_(*llvm_module);
|
||||
}
|
||||
|
||||
return std::pair<std::string, std::vector<uint8>>("", std::move(hsaco));
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
60
tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
Normal file
60
tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h
Normal file
@ -0,0 +1,60 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_AMDGPU_COMPILER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_AMDGPU_COMPILER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
// AMDGPUCompiler generates efficient GPU executables for AMDGPU target.
|
||||
class AMDGPUCompiler : public GpuCompiler {
|
||||
public:
|
||||
AMDGPUCompiler();
|
||||
~AMDGPUCompiler() override {}
|
||||
|
||||
Status OptimizeHloConvolutionCanonicalization(
|
||||
HloModule* hlo_module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) override;
|
||||
|
||||
Status OptimizeHloPostLayoutAssignment(
|
||||
HloModule* hlo_module, se::StreamExecutor* stream_exec,
|
||||
se::DeviceMemoryAllocator* device_allocator) override;
|
||||
|
||||
GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override;
|
||||
|
||||
StatusOr<std::pair<std::string, std::vector<uint8>>> CompileTargetBinary(
|
||||
const HloModule* hlo_module, llvm::Module* llvm_module,
|
||||
GpuVersion gpu_version, se::StreamExecutor* stream_exec) override;
|
||||
|
||||
private:
|
||||
// The parent directory of ROCm-Device-Libs IR libraries.
|
||||
string rocdl_dir_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(AMDGPUCompiler);
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_AMDGPU_COMPILER_H_
|
@ -0,0 +1,24 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h"
|
||||
|
||||
static bool InitModule() {
|
||||
xla::Compiler::RegisterCompilerFactory(
|
||||
stream_executor::rocm::kROCmPlatformId,
|
||||
[]() { return absl::make_unique<xla::gpu::AMDGPUCompiler>(); });
|
||||
return true;
|
||||
}
|
||||
static bool module_initialized = InitModule();
|
@ -24,9 +24,9 @@ tensorflow/third_party/clang_toolchain/BUILD
|
||||
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
|
||||
tensorflow/third_party/clang_toolchain/download_clang.bzl
|
||||
tensorflow/third_party/codegen.BUILD
|
||||
tensorflow/third_party/com_google_absl.BUILD
|
||||
tensorflow/third_party/common.bzl
|
||||
tensorflow/third_party/cub.BUILD
|
||||
tensorflow/third_party/com_google_absl.BUILD
|
||||
tensorflow/third_party/curl.BUILD
|
||||
tensorflow/third_party/cython.BUILD
|
||||
tensorflow/third_party/eigen3/Eigen/Cholesky
|
||||
@ -43,49 +43,49 @@ tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
|
||||
tensorflow/third_party/double_conversion.BUILD
|
||||
tensorflow/third_party/eigen.BUILD
|
||||
tensorflow/third_party/enum34.BUILD
|
||||
tensorflow/third_party/farmhash.BUILD
|
||||
tensorflow/third_party/fft2d/BUILD
|
||||
tensorflow/third_party/fft2d/LICENSE
|
||||
tensorflow/third_party/fft2d/fft.h
|
||||
tensorflow/third_party/fft2d/fft2d.BUILD
|
||||
tensorflow/third_party/fft2d/fft2d.h
|
||||
tensorflow/third_party/farmhash.BUILD
|
||||
tensorflow/third_party/functools32.BUILD
|
||||
tensorflow/third_party/gast.BUILD
|
||||
tensorflow/third_party/git/BUILD.tpl
|
||||
tensorflow/third_party/git/BUILD
|
||||
tensorflow/third_party/git/BUILD.tpl
|
||||
tensorflow/third_party/git/git_configure.bzl
|
||||
tensorflow/third_party/gif.BUILD
|
||||
tensorflow/third_party/gpus/crosstool/BUILD.tpl
|
||||
tensorflow/third_party/gpus/crosstool/BUILD
|
||||
tensorflow/third_party/gpus/crosstool/LICENSE
|
||||
tensorflow/third_party/gpus/crosstool/BUILD.tpl
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
|
||||
tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
|
||||
tensorflow/third_party/gpus/BUILD
|
||||
tensorflow/third_party/gpus/cuda/BUILD.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD
|
||||
tensorflow/third_party/gpus/cuda/BUILD.tpl
|
||||
tensorflow/third_party/gpus/cuda/LICENSE
|
||||
tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
|
||||
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
|
||||
tensorflow/third_party/gpus/cuda_configure.bzl
|
||||
tensorflow/third_party/gpus/rocm/BUILD
|
||||
tensorflow/third_party/gpus/rocm/BUILD.tpl
|
||||
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
|
||||
tensorflow/third_party/gpus/find_cuda_config.py
|
||||
tensorflow/third_party/gpus/cuda_configure.bzl
|
||||
tensorflow/third_party/gpus/rocm_configure.bzl
|
||||
tensorflow/third_party/googleapis.BUILD
|
||||
tensorflow/third_party/grpc/BUILD
|
||||
@ -97,14 +97,14 @@ tensorflow/third_party/llvm/BUILD
|
||||
tensorflow/third_party/llvm/expand_cmake_vars.py
|
||||
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
|
||||
tensorflow/third_party/llvm/llvm.bzl
|
||||
tensorflow/third_party/linenoise.BUILD
|
||||
tensorflow/third_party/libxsmm.BUILD
|
||||
tensorflow/third_party/lmdb.BUILD
|
||||
tensorflow/third_party/linenoise.BUILD
|
||||
tensorflow/third_party/mkl/BUILD
|
||||
tensorflow/third_party/mkl/LICENSE
|
||||
tensorflow/third_party/mkl/MKL_LICENSE
|
||||
tensorflow/third_party/mkl/build_defs.bzl
|
||||
tensorflow/third_party/mkl/mkl.BUILD
|
||||
tensorflow/third_party/lmdb.BUILD
|
||||
tensorflow/third_party/mkl_dnn/LICENSE
|
||||
tensorflow/third_party/mkl_dnn/mkldnn.BUILD
|
||||
tensorflow/third_party/mpi/.gitignore
|
||||
@ -113,28 +113,28 @@ tensorflow/third_party/mpi_collectives/BUILD
|
||||
tensorflow/third_party/nanopb.BUILD
|
||||
tensorflow/third_party/nccl/BUILD
|
||||
tensorflow/third_party/nccl/archive.BUILD
|
||||
tensorflow/third_party/nccl/archive.patch
|
||||
tensorflow/third_party/nccl/LICENSE
|
||||
tensorflow/third_party/nccl/build_defs.bzl.tpl
|
||||
tensorflow/third_party/nccl/archive.patch
|
||||
tensorflow/third_party/nccl/system.BUILD.tpl
|
||||
tensorflow/third_party/nccl/nccl_configure.bzl
|
||||
tensorflow/third_party/nccl/system.BUILD.tpl
|
||||
tensorflow/third_party/ngraph/BUILD
|
||||
tensorflow/third_party/ngraph/LICENSE
|
||||
tensorflow/third_party/ngraph/NGRAPH_LICENSE
|
||||
tensorflow/third_party/ngraph/ngraph.BUILD
|
||||
tensorflow/third_party/ngraph/build_defs.bzl
|
||||
tensorflow/third_party/ngraph/ngraph.BUILD
|
||||
tensorflow/third_party/ngraph/ngraph_tf.BUILD
|
||||
tensorflow/third_party/ngraph/nlohmann_json.BUILD
|
||||
tensorflow/third_party/ngraph/tbb.BUILD
|
||||
tensorflow/third_party/opt_einsum.BUILD
|
||||
tensorflow/third_party/pcre.BUILD
|
||||
tensorflow/third_party/png_fix_rpi.patch
|
||||
tensorflow/third_party/png.BUILD
|
||||
tensorflow/third_party/protobuf/BUILD
|
||||
tensorflow/third_party/png_fix_rpi.patch
|
||||
tensorflow/third_party/pprof.BUILD
|
||||
tensorflow/third_party/py/numpy/BUILD
|
||||
tensorflow/third_party/py/BUILD
|
||||
tensorflow/third_party/py/BUILD.tpl
|
||||
tensorflow/third_party/py/BUILD
|
||||
tensorflow/third_party/py/python_configure.bzl
|
||||
tensorflow/third_party/python_runtime/BUILD
|
||||
tensorflow/third_party/pybind11.BUILD
|
||||
@ -160,14 +160,14 @@ tensorflow/third_party/systemlibs/gif.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
|
||||
tensorflow/third_party/systemlibs/googleapis.BUILD
|
||||
tensorflow/third_party/systemlibs/grpc.BUILD
|
||||
tensorflow/third_party/systemlibs/jsoncpp.BUILD
|
||||
tensorflow/third_party/systemlibs/lmdb.BUILD
|
||||
tensorflow/third_party/systemlibs/grpc.BUILD
|
||||
tensorflow/third_party/systemlibs/nsync.BUILD
|
||||
tensorflow/third_party/systemlibs/opt_einsum.BUILD
|
||||
tensorflow/third_party/systemlibs/lmdb.BUILD
|
||||
tensorflow/third_party/systemlibs/pcre.BUILD
|
||||
tensorflow/third_party/systemlibs/png.BUILD
|
||||
tensorflow/third_party/systemlibs/opt_einsum.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
tensorflow/third_party/systemlibs/png.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.bzl
|
||||
tensorflow/third_party/systemlibs/re2.BUILD
|
||||
tensorflow/third_party/systemlibs/six.BUILD
|
||||
@ -179,8 +179,8 @@ tensorflow/third_party/systemlibs/termcolor.BUILD
|
||||
tensorflow/third_party/systemlibs/zlib.BUILD
|
||||
tensorflow/third_party/tensorrt/BUILD
|
||||
tensorflow/third_party/tensorrt/BUILD.tpl
|
||||
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
|
||||
tensorflow/third_party/tensorrt/LICENSE
|
||||
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
|
||||
tensorflow/third_party/tensorrt/tensorrt/include/tensorrt_config.h.tpl
|
||||
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
|
||||
tensorflow/third_party/termcolor.BUILD
|
||||
@ -188,13 +188,13 @@ tensorflow/third_party/tflite_mobilenet.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet_float.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet_quant.BUILD
|
||||
tensorflow/third_party/toolchains/clang6/BUILD
|
||||
tensorflow/third_party/toolchains/clang6/README.md
|
||||
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
|
||||
tensorflow/third_party/toolchains/clang6/README.md
|
||||
tensorflow/third_party/toolchains/clang6/clang.BUILD
|
||||
tensorflow/third_party/toolchains/clang6/repo.bzl
|
||||
tensorflow/third_party/toolchains/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
|
||||
tensorflow/third_party/toolchains/cpus/arm/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
|
||||
tensorflow/third_party/toolchains/cpus/arm/cc_config.bzl.tpl
|
||||
tensorflow/third_party/toolchains/cpus/py/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/py3/BUILD
|
||||
@ -225,8 +225,8 @@ tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/cc_too
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
@ -240,22 +240,22 @@ tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3_opt/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
|
||||
tensorflow/third_party/toolchains/remote/BUILD
|
||||
tensorflow/third_party/toolchains/remote/BUILD.tpl
|
||||
tensorflow/third_party/toolchains/remote/configure.bzl
|
||||
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
|
||||
tensorflow/third_party/toolchains/remote/BUILD.tpl
|
||||
tensorflow/third_party/tflite_smartreply.BUILD
|
||||
tensorflow/third_party/tflite_ovic_testdata.BUILD
|
||||
tensorflow/third_party/zlib.BUILD
|
||||
tensorflow/third_party/tflite_smartreply.BUILD
|
||||
tensorflow/third_party/wrapt.BUILD
|
||||
tensorflow/third_party/zlib.BUILD
|
||||
tensorflow/tools/ci_build/remote/BUILD
|
||||
tensorflow/tools/def_file_filter/BUILD
|
||||
tensorflow/tools/def_file_filter/BUILD.tpl
|
||||
@ -265,8 +265,8 @@ tensorflow/tools/lib_package/BUILD
|
||||
tensorflow/tools/lib_package/LibTensorFlowTest.java
|
||||
tensorflow/tools/lib_package/README.md
|
||||
tensorflow/tools/lib_package/concat_licenses.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_test.c
|
||||
tensorflow/tools/lib_package/libtensorflow_java_test.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_test.c
|
||||
tensorflow/tools/lib_package/libtensorflow_test.sh
|
||||
tensorflow/tools/pip_package/BUILD
|
||||
tensorflow/tools/pip_package/MANIFEST.in
|
||||
@ -275,8 +275,8 @@ tensorflow/tools/pip_package/build_pip_package.sh
|
||||
tensorflow/tools/pip_package/check_load_py_test.py
|
||||
tensorflow/tools/pip_package/pip_smoke_test.py
|
||||
tensorflow/tools/pip_package/setup.py
|
||||
tensorflow/tools/pip_package/simple_console_for_windows.py
|
||||
tensorflow/tools/pip_package/simple_console.py
|
||||
tensorflow/tools/pip_package/simple_console_for_windows.py
|
||||
tensorflow/virtual_root_template_v1.__init__.py
|
||||
tensorflow/virtual_root_template_v2.__init__.py
|
||||
llvm/llvm/projects/google_mlir/WORKSPACE
|
Loading…
Reference in New Issue
Block a user