[XLA] Propagate the TF Stream to the XLA Compiler via the allocator
The GPU BFC allocator that XLA gets from TF should only be used on the main TF stream; otherwise data races are possible. XLA therefore needs access to this stream during compilation to avoid such races. PiperOrigin-RevId: 264260382
This commit is contained in:
parent
a0ee95db22
commit
8e2ca26f57
@ -63,8 +63,7 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
|
||||
DeviceType device_type = ctx->device_type();
|
||||
se::Platform::Id platform_id = nullptr;
|
||||
const XlaDevice::Metadata* xla_device_metadata = nullptr;
|
||||
std::unique_ptr<se::TfAllocatorAdapter> xla_allocator;
|
||||
se::DeviceMemoryAllocator* device_allocator = nullptr;
|
||||
se::DeviceMemoryAllocator* custom_allocator = nullptr;
|
||||
|
||||
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
|
||||
platform_id = se::host::kHostPlatformId;
|
||||
@ -84,23 +83,13 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
|
||||
// (which xla_allocator above uses) as on an XlaDevice, this is a dummy
|
||||
// allocator that returns XlaTensor objects. The XlaCompiler needs a real
|
||||
// allocator to allocate real buffers.
|
||||
|
||||
platform_id = xla_device_metadata->platform()->id();
|
||||
device_allocator =
|
||||
custom_allocator =
|
||||
xla_device_metadata->client()->backend().memory_allocator();
|
||||
}
|
||||
|
||||
if (!device_allocator) {
|
||||
xla::StatusOr<se::Platform*> maybe_platform =
|
||||
se::MultiPlatformManager::PlatformWithId(platform_id);
|
||||
OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status());
|
||||
|
||||
xla_allocator = absl::make_unique<se::TfAllocatorAdapter>(
|
||||
maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({}));
|
||||
}
|
||||
|
||||
return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
|
||||
std::move(xla_allocator), device_allocator);
|
||||
custom_allocator);
|
||||
}
|
||||
|
||||
// A closure describing how to run a compiled version of a TensorFlow function.
|
||||
@ -185,6 +174,33 @@ class XlaExecutableClosureStore {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
|
||||
};
|
||||
|
||||
// Return allocator from platform info if non-null, or populate and return a
|
||||
// pointer to the allocator adapter with allocator from context.
|
||||
//
|
||||
// This is necessary because for XLA devices the underlying TF allocator returns
|
||||
// dummy tensors.
|
||||
se::DeviceMemoryAllocator* GetAllocator(
|
||||
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
|
||||
OpKernelContext* ctx, const XlaPlatformInfo& platform_info) {
|
||||
if (platform_info.custom_allocator()) {
|
||||
return platform_info.custom_allocator();
|
||||
}
|
||||
if (!ctx->op_device_context()) {
|
||||
// Stream is not set for the host platform.
|
||||
se::Platform* platform =
|
||||
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
|
||||
.ValueOrDie();
|
||||
tf_allocator_adapter->emplace(platform, ctx->device()->GetAllocator({}),
|
||||
/*stream=*/nullptr);
|
||||
return &tf_allocator_adapter->value();
|
||||
}
|
||||
// platform_info.
|
||||
tf_allocator_adapter->emplace(
|
||||
ctx->op_device_context()->stream()->parent()->platform(),
|
||||
ctx->device()->GetAllocator({}), ctx->op_device_context()->stream());
|
||||
return &tf_allocator_adapter->value();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
|
||||
@ -281,6 +297,7 @@ static Status CompileToLocalExecutable(
|
||||
TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, resources, variables));
|
||||
*client = static_cast<xla::LocalClient*>(cache->client());
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
XlaCompiler::Options options;
|
||||
options.client = *client;
|
||||
if (ctx->op_device_context() != nullptr) {
|
||||
@ -292,7 +309,8 @@ static Status CompileToLocalExecutable(
|
||||
options.graph_def_version = ctx->function_library()->graph_def_version();
|
||||
options.allow_cpu_custom_calls =
|
||||
(platform_info.platform_id() == se::host::kHostPlatformId);
|
||||
options.device_allocator = platform_info.allocator();
|
||||
options.device_allocator =
|
||||
GetAllocator(&tf_allocator_adapter, ctx, platform_info);
|
||||
if (platform_info.xla_device_metadata()) {
|
||||
options.shape_representation_fn =
|
||||
platform_info.xla_device_metadata()->shape_representation_fn();
|
||||
@ -350,8 +368,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
|
||||
VLOG(1) << "Executing XLA Computation...";
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
se::DeviceMemoryAllocator* allocator =
|
||||
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
|
||||
XlaComputationLaunchContext launch_context(
|
||||
client, platform_info_.allocator(),
|
||||
client, allocator,
|
||||
/*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
|
||||
platform_info_.UseMultipleStreams());
|
||||
launch_context.PopulateInputs(ctx, kernel, variables,
|
||||
@ -361,7 +382,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
VLOG(2) << "Executing computation.";
|
||||
xla::ExecutableRunOptions run_options;
|
||||
run_options.set_stream(stream);
|
||||
run_options.set_allocator(platform_info_.allocator());
|
||||
run_options.set_allocator(allocator);
|
||||
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
|
||||
run_options.set_rng_seed(GetXLARandomSeed());
|
||||
Env* env = Env::Default();
|
||||
@ -528,8 +549,11 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
|
||||
XlaExecutableClosure closure =
|
||||
XlaExecutableClosureStore::Global()->Consume(key);
|
||||
|
||||
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
|
||||
se::DeviceMemoryAllocator* allocator =
|
||||
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
|
||||
XlaComputationLaunchContext launch_context(
|
||||
closure.client(), platform_info_.allocator(),
|
||||
closure.client(), allocator,
|
||||
/*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
|
||||
/*use_multiple_streams=*/platform_info_.UseMultipleStreams());
|
||||
|
||||
@ -554,7 +578,7 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
xla::ExecutableRunOptions run_options;
|
||||
run_options.set_stream(stream);
|
||||
run_options.set_allocator(platform_info_.allocator());
|
||||
run_options.set_allocator(allocator);
|
||||
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
|
||||
run_options.set_rng_seed(GetXLARandomSeed());
|
||||
Env* env = Env::Default();
|
||||
|
@ -37,18 +37,14 @@ class XlaPlatformInfo {
|
||||
public:
|
||||
XlaPlatformInfo() : device_type_("") {}
|
||||
XlaPlatformInfo(XlaPlatformInfo&&) = default;
|
||||
explicit XlaPlatformInfo(
|
||||
const DeviceType device_type, se::Platform::Id platform_id,
|
||||
const XlaDevice::Metadata* xla_device_metadata,
|
||||
std::unique_ptr<se::TfAllocatorAdapter> xla_allocator,
|
||||
se::DeviceMemoryAllocator* device_allocator)
|
||||
explicit XlaPlatformInfo(const DeviceType device_type,
|
||||
se::Platform::Id platform_id,
|
||||
const XlaDevice::Metadata* xla_device_metadata,
|
||||
se::DeviceMemoryAllocator* device_allocator)
|
||||
: device_type_(device_type),
|
||||
platform_id_(platform_id),
|
||||
xla_device_metadata_(xla_device_metadata),
|
||||
xla_allocator_(std::move(xla_allocator)),
|
||||
device_allocator_(device_allocator) {
|
||||
CHECK((device_allocator_ != nullptr) ^ (xla_allocator_.get() != nullptr));
|
||||
}
|
||||
device_allocator_(device_allocator) {}
|
||||
|
||||
XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
|
||||
|
||||
@ -56,9 +52,11 @@ class XlaPlatformInfo {
|
||||
return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
|
||||
}
|
||||
|
||||
se::DeviceMemoryAllocator* allocator() const {
|
||||
return device_allocator_ ? device_allocator_ : xla_allocator_.get();
|
||||
// Non-null only when run on an XLA device.
|
||||
se::DeviceMemoryAllocator* custom_allocator() const {
|
||||
return device_allocator_;
|
||||
}
|
||||
|
||||
DeviceType device_type() const { return device_type_; }
|
||||
|
||||
// This is equal to xla_device_metadata()->platform()->id() if
|
||||
@ -82,11 +80,8 @@ class XlaPlatformInfo {
|
||||
const XlaDevice::Metadata* xla_device_metadata_;
|
||||
|
||||
// If the op associated with this XlaPlatformInfo is placed on an XLA device
|
||||
// then device_allocator_ is the xla::Backend's memory allocator and
|
||||
// xla_allocator_ is null. If the op is placed on a regular CPU or GPU device
|
||||
// then device_allocator_ is null and xla_allocator_ points to an appropriate
|
||||
// se::TfAllocatorAdapter instance.
|
||||
std::unique_ptr<se::TfAllocatorAdapter> xla_allocator_;
|
||||
// then device_allocator_ is the xla::Backend's memory allocator. If the op
|
||||
// is placed on a regular CPU or GPU device then device_allocator_ is null.
|
||||
se::DeviceMemoryAllocator* device_allocator_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
|
||||
|
@ -248,9 +248,6 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
|
||||
return InternalError("Failed to synchronize GPU for autotuning.");
|
||||
}
|
||||
|
||||
// Create a stream for us to do our work on.
|
||||
se::Stream stream{stream_exec_};
|
||||
stream.Init();
|
||||
// allocator either points to this->allocator_ or, if that's null, to a
|
||||
// se::StreamExecutorMemoryAllocator for stream_exec_.
|
||||
se::DeviceMemoryAllocator* allocator;
|
||||
@ -262,11 +259,21 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
|
||||
allocator = &*se_allocator;
|
||||
}
|
||||
|
||||
absl::optional<se::Stream> stream_opt;
|
||||
se::Stream* stream = [&] {
|
||||
if (allocator->GetStream()) {
|
||||
return allocator->GetStream();
|
||||
}
|
||||
stream_opt.emplace(stream_exec_);
|
||||
stream_opt->Init();
|
||||
return &stream_opt.value();
|
||||
}();
|
||||
|
||||
int64 rng_state = 0;
|
||||
|
||||
const auto initialize_buffer = [&stream, &result_shape,
|
||||
const auto initialize_buffer = [stream, &result_shape,
|
||||
&rng_state](DeviceMemoryBase buffer) {
|
||||
InitializeFloatBuffer(&stream, result_shape.element_type(), &rng_state,
|
||||
InitializeFloatBuffer(stream, result_shape.element_type(), &rng_state,
|
||||
buffer);
|
||||
};
|
||||
|
||||
@ -274,7 +281,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
|
||||
|
||||
// Allocate space for the input, filter, and output of the convolution.
|
||||
se::cuda::RedzoneAllocator input_output_allocator(
|
||||
&stream, allocator, PtxOptsFromConfig(hlo_module_config));
|
||||
stream, allocator, PtxOptsFromConfig(hlo_module_config));
|
||||
std::vector<se::DeviceMemoryBase> operand_buffers;
|
||||
for (const auto* operand : instr->operands()) {
|
||||
TF_ASSIGN_OR_RETURN(auto buffer,
|
||||
@ -328,7 +335,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
|
||||
}
|
||||
|
||||
se::cuda::RedzoneAllocator scratch_allocator(
|
||||
&stream, allocator, PtxOptsFromConfig(hlo_module_config));
|
||||
stream, allocator, PtxOptsFromConfig(hlo_module_config));
|
||||
se::dnn::ProfileResult profile_result;
|
||||
VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
|
||||
<< instr->ToString();
|
||||
@ -339,7 +346,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
|
||||
options.algo_override = alg;
|
||||
Status launch_status =
|
||||
RunCudnnConv(instr, absl::MakeSpan(operand_buffers), result_buffer,
|
||||
&scratch_allocator, &stream, options);
|
||||
&scratch_allocator, stream, options);
|
||||
|
||||
if (!launch_status.ok()) {
|
||||
continue;
|
||||
@ -362,12 +369,12 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
|
||||
|
||||
// Check for writes to redzones.
|
||||
TF_ASSIGN_OR_RETURN(bool input_output_allocator_redzone_clear,
|
||||
CheckRedzones(input_output_allocator, &stream,
|
||||
CheckRedzones(input_output_allocator, stream,
|
||||
"input/output", instr, &result));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
bool scratch_allocator_redzone_clear,
|
||||
CheckRedzones(scratch_allocator, &stream, "scratch", instr, &result));
|
||||
CheckRedzones(scratch_allocator, stream, "scratch", instr, &result));
|
||||
|
||||
if (!input_output_allocator_redzone_clear ||
|
||||
!scratch_allocator_redzone_clear) {
|
||||
@ -393,7 +400,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
|
||||
if (comparator.has_value()) {
|
||||
XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::CompareEqual", 2);
|
||||
StatusOr<bool> compare_result = comparator->CompareEqual(
|
||||
&stream, reference_result_buffer, result_buffer);
|
||||
stream, reference_result_buffer, result_buffer);
|
||||
if (!compare_result.ok()) {
|
||||
LOG(ERROR) << "Unable to compare " << AlgorithmToString(first_algorithm)
|
||||
<< " against " << AlgorithmToString(alg) << " for "
|
||||
@ -411,7 +418,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
|
||||
<< instr->ToString() << " for "
|
||||
<< AlgorithmToString(first_algorithm) << " vs "
|
||||
<< AlgorithmToString(alg);
|
||||
PrintPlatformInfo(&stream);
|
||||
PrintPlatformInfo(stream);
|
||||
VLOG(1) << "Full module on failure: \n"
|
||||
<< instr->GetModule()->ToString();
|
||||
auto* fail = result.mutable_failure();
|
||||
@ -429,8 +436,8 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
reference_result_buffer,
|
||||
input_output_allocator.AllocateBytes(result_buffer.size()));
|
||||
stream.ThenMemcpy(&reference_result_buffer, result_buffer,
|
||||
result_buffer.size());
|
||||
stream->ThenMemcpy(&reference_result_buffer, result_buffer,
|
||||
result_buffer.size());
|
||||
first_algorithm = alg;
|
||||
}
|
||||
}
|
||||
|
@ -239,16 +239,22 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoGemmAutotune(
|
||||
static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
|
||||
se::StreamExecutor* executor,
|
||||
se::DeviceMemoryAllocator* allocator) {
|
||||
se::Stream stream{executor};
|
||||
stream.Init();
|
||||
|
||||
if (allocator == nullptr) {
|
||||
allocator = executor->GetAllocator();
|
||||
}
|
||||
absl::optional<se::Stream> stream_opt;
|
||||
se::Stream* stream = [&]() {
|
||||
if (allocator->GetStream()) {
|
||||
return allocator->GetStream();
|
||||
}
|
||||
stream_opt.emplace(executor);
|
||||
stream_opt->Init();
|
||||
return &stream_opt.value();
|
||||
}();
|
||||
|
||||
const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
|
||||
se::cuda::RedzoneAllocator input_output_allocator(
|
||||
&stream, allocator, PtxOptsFromConfig(hlo_module_config));
|
||||
stream, allocator, PtxOptsFromConfig(hlo_module_config));
|
||||
|
||||
BufferComparator comparator(instr->shape(), hlo_module_config);
|
||||
|
||||
@ -258,7 +264,7 @@ static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
|
||||
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer,
|
||||
input_output_allocator.AllocateBytes(
|
||||
ShapeUtil::ByteSizeOf(op->shape())));
|
||||
InitializeFloatBuffer(&stream, op->shape().element_type(), &rng_state,
|
||||
InitializeFloatBuffer(stream, op->shape().element_type(), &rng_state,
|
||||
buffer);
|
||||
return buffer;
|
||||
};
|
||||
@ -283,11 +289,11 @@ static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
|
||||
const bool crash_on_checking_failure =
|
||||
debug_options.xla_gpu_crash_on_verification_failures();
|
||||
|
||||
TF_ASSIGN_OR_RETURN(absl::optional<se::blas::AlgorithmType> gemm_algorithm,
|
||||
DoGemmAutotune(instr, lhs, rhs, lhs_buffer, rhs_buffer,
|
||||
output_buffer, reference_result_buffer,
|
||||
&stream, crash_on_checking_failure,
|
||||
input_output_allocator, comparator));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
absl::optional<se::blas::AlgorithmType> gemm_algorithm,
|
||||
DoGemmAutotune(instr, lhs, rhs, lhs_buffer, rhs_buffer, output_buffer,
|
||||
reference_result_buffer, stream, crash_on_checking_failure,
|
||||
input_output_allocator, comparator));
|
||||
|
||||
// We update instruction->backend_config(); if no algorithms are supported,
|
||||
// a different API is used, which does not require specifying an algorithm.
|
||||
|
@ -7,8 +7,8 @@ tensorflow/contrib/mpi/BUILD
|
||||
tensorflow/python/autograph/core/config.py
|
||||
tensorflow/python/tpu/profiler/pip_package/BUILD
|
||||
tensorflow/python/tpu/profiler/pip_package/README
|
||||
tensorflow/python/tpu/profiler/pip_package/build_pip_package.sh
|
||||
tensorflow/python/tpu/profiler/pip_package/setup.py
|
||||
tensorflow/python/tpu/profiler/pip_package/build_pip_package.sh
|
||||
tensorflow/stream_executor/build_defs.bzl
|
||||
tensorflow/third_party/BUILD
|
||||
tensorflow/third_party/android/BUILD
|
||||
@ -28,31 +28,31 @@ tensorflow/third_party/com_google_absl.BUILD
|
||||
tensorflow/third_party/common.bzl
|
||||
tensorflow/third_party/cub.BUILD
|
||||
tensorflow/third_party/curl.BUILD
|
||||
tensorflow/third_party/cython.BUILD
|
||||
tensorflow/third_party/double_conversion.BUILD
|
||||
tensorflow/third_party/eigen3/Eigen/Cholesky
|
||||
tensorflow/third_party/eigen3/Eigen/Core
|
||||
tensorflow/third_party/eigen3/Eigen/Eigenvalues
|
||||
tensorflow/third_party/eigen3/Eigen/Core
|
||||
tensorflow/third_party/eigen3/Eigen/LU
|
||||
tensorflow/third_party/eigen3/Eigen/QR
|
||||
tensorflow/third_party/eigen3/Eigen/SVD
|
||||
tensorflow/third_party/eigen3/Eigen/QR
|
||||
tensorflow/third_party/eigen3/BUILD
|
||||
tensorflow/third_party/eigen3/LICENSE
|
||||
tensorflow/third_party/eigen3/gpu_packet_math.patch
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
|
||||
tensorflow/third_party/cython.BUILD
|
||||
tensorflow/third_party/eigen.BUILD
|
||||
tensorflow/third_party/fft2d/BUILD
|
||||
tensorflow/third_party/fft2d/LICENSE
|
||||
@ -61,11 +61,11 @@ tensorflow/third_party/fft2d/fft2d.BUILD
|
||||
tensorflow/third_party/fft2d/fft2d.h
|
||||
tensorflow/third_party/enum34.BUILD
|
||||
tensorflow/third_party/farmhash.BUILD
|
||||
tensorflow/third_party/git/BUILD
|
||||
tensorflow/third_party/git/BUILD.tpl
|
||||
tensorflow/third_party/git/git_configure.bzl
|
||||
tensorflow/third_party/functools32.BUILD
|
||||
tensorflow/third_party/gast.BUILD
|
||||
tensorflow/third_party/git/BUILD
|
||||
tensorflow/third_party/git/git_configure.bzl
|
||||
tensorflow/third_party/git/BUILD.tpl
|
||||
tensorflow/third_party/gif.BUILD
|
||||
tensorflow/third_party/gpus/crosstool/BUILD
|
||||
tensorflow/third_party/gpus/crosstool/BUILD.tpl
|
||||
@ -74,18 +74,18 @@ tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
|
||||
tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
|
||||
tensorflow/third_party/gpus/BUILD
|
||||
tensorflow/third_party/gpus/cuda/BUILD.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD
|
||||
tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD.tpl
|
||||
tensorflow/third_party/gpus/cuda/LICENSE
|
||||
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
|
||||
tensorflow/third_party/gpus/rocm/BUILD
|
||||
tensorflow/third_party/gpus/rocm/BUILD.tpl
|
||||
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
|
||||
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/cuda_configure.bzl
|
||||
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
|
||||
tensorflow/third_party/gpus/find_cuda_config.py
|
||||
tensorflow/third_party/gpus/cuda_configure.bzl
|
||||
tensorflow/third_party/gpus/rocm_configure.bzl
|
||||
tensorflow/third_party/googleapis.BUILD
|
||||
tensorflow/third_party/grpc/BUILD
|
||||
@ -93,18 +93,18 @@ tensorflow/third_party/icu/udata.patch
|
||||
tensorflow/third_party/kafka/BUILD
|
||||
tensorflow/third_party/kafka/config.patch
|
||||
tensorflow/third_party/jsoncpp.BUILD
|
||||
tensorflow/third_party/libxsmm.BUILD
|
||||
tensorflow/third_party/llvm/BUILD
|
||||
tensorflow/third_party/llvm/expand_cmake_vars.py
|
||||
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
|
||||
tensorflow/third_party/llvm/llvm.bzl
|
||||
tensorflow/third_party/libxsmm.BUILD
|
||||
tensorflow/third_party/linenoise.BUILD
|
||||
tensorflow/third_party/lmdb.BUILD
|
||||
tensorflow/third_party/mkl/BUILD
|
||||
tensorflow/third_party/mkl/LICENSE
|
||||
tensorflow/third_party/mkl/MKL_LICENSE
|
||||
tensorflow/third_party/mkl/build_defs.bzl
|
||||
tensorflow/third_party/mkl/mkl.BUILD
|
||||
tensorflow/third_party/mkl/build_defs.bzl
|
||||
tensorflow/third_party/lmdb.BUILD
|
||||
tensorflow/third_party/mkl_dnn/LICENSE
|
||||
tensorflow/third_party/mkl_dnn/mkldnn.BUILD
|
||||
tensorflow/third_party/mpi/.gitignore
|
||||
@ -112,25 +112,25 @@ tensorflow/third_party/mpi/BUILD
|
||||
tensorflow/third_party/mpi_collectives/BUILD
|
||||
tensorflow/third_party/nanopb.BUILD
|
||||
tensorflow/third_party/nccl/BUILD
|
||||
tensorflow/third_party/nccl/LICENSE
|
||||
tensorflow/third_party/nccl/archive.BUILD
|
||||
tensorflow/third_party/nccl/archive.patch
|
||||
tensorflow/third_party/nccl/LICENSE
|
||||
tensorflow/third_party/nccl/build_defs.bzl.tpl
|
||||
tensorflow/third_party/nccl/archive.patch
|
||||
tensorflow/third_party/nccl/nccl_configure.bzl
|
||||
tensorflow/third_party/nccl/system.BUILD.tpl
|
||||
tensorflow/third_party/ngraph/BUILD
|
||||
tensorflow/third_party/ngraph/LICENSE
|
||||
tensorflow/third_party/ngraph/NGRAPH_LICENSE
|
||||
tensorflow/third_party/ngraph/ngraph_tf.BUILD
|
||||
tensorflow/third_party/ngraph/build_defs.bzl
|
||||
tensorflow/third_party/ngraph/ngraph.BUILD
|
||||
tensorflow/third_party/ngraph/nlohmann_json.BUILD
|
||||
tensorflow/third_party/ngraph/ngraph_tf.BUILD
|
||||
tensorflow/third_party/ngraph/tbb.BUILD
|
||||
tensorflow/third_party/opt_einsum.BUILD
|
||||
tensorflow/third_party/pcre.BUILD
|
||||
tensorflow/third_party/png.BUILD
|
||||
tensorflow/third_party/png_fix_rpi.patch
|
||||
tensorflow/third_party/protobuf/BUILD
|
||||
tensorflow/third_party/png_fix_rpi.patch
|
||||
tensorflow/third_party/pprof.BUILD
|
||||
tensorflow/third_party/py/numpy/BUILD
|
||||
tensorflow/third_party/py/BUILD
|
||||
@ -142,55 +142,54 @@ tensorflow/third_party/repo.bzl
|
||||
tensorflow/third_party/six.BUILD
|
||||
tensorflow/third_party/snappy.BUILD
|
||||
tensorflow/third_party/sqlite.BUILD
|
||||
tensorflow/third_party/swig.BUILD
|
||||
tensorflow/third_party/sycl/crosstool/BUILD
|
||||
tensorflow/third_party/swig.BUILD
|
||||
tensorflow/third_party/systemlibs/BUILD.tpl
|
||||
tensorflow/third_party/systemlibs/BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
|
||||
tensorflow/third_party/systemlibs/astor.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
|
||||
tensorflow/third_party/systemlibs/boringssl.BUILD
|
||||
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
|
||||
tensorflow/third_party/systemlibs/curl.BUILD
|
||||
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
|
||||
tensorflow/third_party/systemlibs/cython.BUILD
|
||||
tensorflow/third_party/systemlibs/double_conversion.BUILD
|
||||
tensorflow/third_party/systemlibs/gast.BUILD
|
||||
tensorflow/third_party/systemlibs/gif.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
|
||||
tensorflow/third_party/systemlibs/googleapis.BUILD
|
||||
tensorflow/third_party/systemlibs/jsoncpp.BUILD
|
||||
tensorflow/third_party/systemlibs/grpc.BUILD
|
||||
tensorflow/third_party/systemlibs/lmdb.BUILD
|
||||
tensorflow/third_party/systemlibs/opt_einsum.BUILD
|
||||
tensorflow/third_party/systemlibs/jsoncpp.BUILD
|
||||
tensorflow/third_party/systemlibs/nsync.BUILD
|
||||
tensorflow/third_party/systemlibs/opt_einsum.BUILD
|
||||
tensorflow/third_party/systemlibs/pcre.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
tensorflow/third_party/systemlibs/six.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.bzl
|
||||
tensorflow/third_party/systemlibs/png.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.bzl
|
||||
tensorflow/third_party/systemlibs/re2.BUILD
|
||||
tensorflow/third_party/systemlibs/snappy.BUILD
|
||||
tensorflow/third_party/systemlibs/six.BUILD
|
||||
tensorflow/third_party/systemlibs/sqlite.BUILD
|
||||
tensorflow/third_party/systemlibs/swig.BUILD
|
||||
tensorflow/third_party/systemlibs/snappy.BUILD
|
||||
tensorflow/third_party/systemlibs/syslibs_configure.bzl
|
||||
tensorflow/third_party/systemlibs/termcolor.BUILD
|
||||
tensorflow/third_party/systemlibs/zlib.BUILD
|
||||
tensorflow/third_party/tensorrt/BUILD
|
||||
tensorflow/third_party/tensorrt/BUILD.tpl
|
||||
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
|
||||
tensorflow/third_party/tensorrt/LICENSE
|
||||
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
|
||||
tensorflow/third_party/tensorrt/tensorrt/include/tensorrt_config.h.tpl
|
||||
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
|
||||
tensorflow/third_party/termcolor.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet_float.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet_quant.BUILD
|
||||
tensorflow/third_party/toolchains/clang6/BUILD
|
||||
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
|
||||
tensorflow/third_party/toolchains/clang6/README.md
|
||||
tensorflow/third_party/toolchains/clang6/clang.BUILD
|
||||
tensorflow/third_party/toolchains/clang6/README.md
|
||||
tensorflow/third_party/toolchains/clang6/repo.bzl
|
||||
tensorflow/third_party/toolchains/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/arm/BUILD
|
||||
@ -215,8 +214,8 @@ tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
@ -225,9 +224,9 @@ tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/cc_too
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc5-rocm/BUILD
|
||||
@ -235,24 +234,25 @@ tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc5-rocm/cc_toolchain_c
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3_opt/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
|
||||
tensorflow/third_party/toolchains/remote/BUILD
|
||||
tensorflow/third_party/toolchains/remote/configure.bzl
|
||||
tensorflow/third_party/toolchains/remote/BUILD.tpl
|
||||
tensorflow/third_party/toolchains/remote/configure.bzl
|
||||
tensorflow/third_party/toolchains/remote/BUILD
|
||||
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
|
||||
tensorflow/third_party/tflite_ovic_testdata.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet_quant.BUILD
|
||||
tensorflow/third_party/tflite_smartreply.BUILD
|
||||
tensorflow/third_party/wrapt.BUILD
|
||||
tensorflow/third_party/zlib.BUILD
|
||||
@ -264,19 +264,19 @@ tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
|
||||
tensorflow/tools/lib_package/BUILD
|
||||
tensorflow/tools/lib_package/LibTensorFlowTest.java
|
||||
tensorflow/tools/lib_package/README.md
|
||||
tensorflow/tools/lib_package/libtensorflow_java_test.sh
|
||||
tensorflow/tools/lib_package/concat_licenses.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_test.c
|
||||
tensorflow/tools/lib_package/libtensorflow_java_test.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_test.sh
|
||||
tensorflow/tools/pip_package/BUILD
|
||||
tensorflow/tools/pip_package/MANIFEST.in
|
||||
tensorflow/tools/pip_package/build_pip_package.sh
|
||||
tensorflow/tools/pip_package/check_load_py_test.py
|
||||
tensorflow/tools/pip_package/BUILD
|
||||
tensorflow/tools/pip_package/README
|
||||
tensorflow/tools/pip_package/build_pip_package.sh
|
||||
tensorflow/tools/pip_package/pip_smoke_test.py
|
||||
tensorflow/tools/pip_package/setup.py
|
||||
tensorflow/tools/pip_package/simple_console.py
|
||||
tensorflow/tools/pip_package/simple_console_for_windows.py
|
||||
tensorflow/virtual_root_template_v1.__init__.py
|
||||
tensorflow/tools/pip_package/check_load_py_test.py
|
||||
tensorflow/virtual_root_template_v2.__init__.py
|
||||
tensorflow/virtual_root_template_v1.__init__.py
|
||||
llvm/llvm/projects/google_mlir/WORKSPACE
|
@ -199,6 +199,16 @@ class DeviceMemoryAllocator {
|
||||
// a stream, or do we have to wait for the computation to complete first?
|
||||
virtual bool AllowsAsynchronousDeallocation() const { return false; }
|
||||
|
||||
// Returns nullable stream pointer.
|
||||
//
|
||||
// If the pointer is non-null, then it is always safe to access the memory
|
||||
// allocated by the allocator on the returned stream. This condition is not
|
||||
// required though, as streams could be synchronized by other means.
|
||||
//
|
||||
// TODO(cheshire): clean up the interface, it might be cleaner to explicitly
|
||||
// pass the stream to Compiler.
|
||||
virtual Stream *GetStream() const { return nullptr; }
|
||||
|
||||
protected:
|
||||
const Platform* platform_;
|
||||
};
|
||||
|
@ -21,8 +21,9 @@ limitations under the License.
|
||||
namespace stream_executor {
|
||||
|
||||
TfAllocatorAdapter::TfAllocatorAdapter(const Platform *platform,
|
||||
tensorflow::Allocator *wrapped)
|
||||
: DeviceMemoryAllocator(platform), wrapped_(wrapped) {}
|
||||
tensorflow::Allocator *wrapped,
|
||||
Stream *stream)
|
||||
: DeviceMemoryAllocator(platform), wrapped_(wrapped), stream_(stream) {}
|
||||
|
||||
TfAllocatorAdapter::~TfAllocatorAdapter() {}
|
||||
|
||||
|
@ -30,7 +30,10 @@ namespace stream_executor {
|
||||
// see comment on `AllowsAsynchronousDeallocation()`.
|
||||
class TfAllocatorAdapter : public DeviceMemoryAllocator {
|
||||
public:
|
||||
TfAllocatorAdapter(const Platform *platform, tensorflow::Allocator *wrapped);
|
||||
// stream: a Stream on which the allocator can only be used. If non-null, the
|
||||
// allocator can not be used on any other stream.
|
||||
TfAllocatorAdapter(const Platform *platform, tensorflow::Allocator *wrapped,
|
||||
Stream *stream = nullptr);
|
||||
~TfAllocatorAdapter() override;
|
||||
|
||||
port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
|
||||
@ -47,8 +50,11 @@ class TfAllocatorAdapter : public DeviceMemoryAllocator {
|
||||
// (This attribute has no effect on CPU.)
|
||||
bool AllowsAsynchronousDeallocation() const override { return true; }
|
||||
|
||||
Stream *GetStream() const override { return stream_; }
|
||||
|
||||
private:
|
||||
tensorflow::Allocator *wrapped_;
|
||||
Stream *stream_;
|
||||
};
|
||||
|
||||
// Adapter class that wraps per-device TF allocators as an XLA allocator.
|
||||
|
Loading…
Reference in New Issue
Block a user