[XLA] Propagate the TF Stream to the XLA Compiler via the allocator

The GPU BFC allocator that XLA gets from TF may only be used on the main TF
stream; using it on any other stream can cause data races. XLA therefore needs
access to this stream during compilation.

PiperOrigin-RevId: 264260382
George Karpenkov 2019-08-19 15:47:05 -07:00 committed by TensorFlower Gardener
parent a0ee95db22
commit 8e2ca26f57
8 changed files with 154 additions and 105 deletions
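
Before the per-file diffs, a minimal standalone sketch of the contract this
change introduces. The types below are simplified stand-ins, not the real
se::Stream / se::DeviceMemoryAllocator API: an allocator may be bound to a
single stream, and consumers query GetStream() to learn where its memory is
safe to use.

#include <cassert>
#include <cstdint>
#include <vector>

struct Stream {};  // stand-in for se::Stream

class DeviceMemoryAllocator {
 public:
  virtual ~DeviceMemoryAllocator() = default;
  // Nullable; null means "no stream affinity".
  virtual Stream* GetStream() const { return nullptr; }
  virtual void* Allocate(uint64_t size) = 0;
};

// Stand-in for se::TfAllocatorAdapter: remembers the one stream on which
// its memory may be used.
class StreamBoundAllocator : public DeviceMemoryAllocator {
 public:
  explicit StreamBoundAllocator(Stream* stream) : stream_(stream) {}
  Stream* GetStream() const override { return stream_; }
  void* Allocate(uint64_t size) override {
    storage_.emplace_back(size);
    return storage_.back().data();
  }

 private:
  Stream* stream_;
  std::vector<std::vector<char>> storage_;
};

int main() {
  Stream main_tf_stream;
  StreamBoundAllocator allocator(&main_tf_stream);
  // A consumer (e.g. the XLA compiler during autotuning) should do its work
  // on the allocator's stream when one is set, instead of creating its own.
  assert(allocator.GetStream() == &main_tf_stream);
}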

@@ -63,8 +63,7 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
DeviceType device_type = ctx->device_type();
se::Platform::Id platform_id = nullptr;
const XlaDevice::Metadata* xla_device_metadata = nullptr;
std::unique_ptr<se::TfAllocatorAdapter> xla_allocator;
se::DeviceMemoryAllocator* device_allocator = nullptr;
se::DeviceMemoryAllocator* custom_allocator = nullptr;
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
platform_id = se::host::kHostPlatformId;
@@ -84,23 +83,13 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
// (which xla_allocator above uses) as on an XlaDevice, this is a dummy
// allocator that returns XlaTensor objects. The XlaCompiler needs a real
// allocator to allocate real buffers.
platform_id = xla_device_metadata->platform()->id();
device_allocator =
custom_allocator =
xla_device_metadata->client()->backend().memory_allocator();
}
if (!device_allocator) {
xla::StatusOr<se::Platform*> maybe_platform =
se::MultiPlatformManager::PlatformWithId(platform_id);
OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status());
xla_allocator = absl::make_unique<se::TfAllocatorAdapter>(
maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({}));
}
return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
std::move(xla_allocator), device_allocator);
custom_allocator);
}
// A closure describing how to run a compiled version of a TensorFlow function.
@@ -185,6 +174,33 @@ class XlaExecutableClosureStore {
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
};
// Returns the allocator from platform_info if non-null; otherwise populates
// the given allocator adapter with the allocator from the context and returns
// a pointer to it.
//
// This is necessary because for XLA devices the underlying TF allocator returns
// dummy tensors.
se::DeviceMemoryAllocator* GetAllocator(
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
OpKernelContext* ctx, const XlaPlatformInfo& platform_info) {
if (platform_info.custom_allocator()) {
return platform_info.custom_allocator();
}
if (!ctx->op_device_context()) {
// Stream is not set for the host platform.
se::Platform* platform =
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
.ValueOrDie();
tf_allocator_adapter->emplace(platform, ctx->device()->GetAllocator({}),
/*stream=*/nullptr);
return &tf_allocator_adapter->value();
}
// Use the stream from the op device context; the platform comes from the
// stream's parent executor rather than from platform_info.
tf_allocator_adapter->emplace(
ctx->op_device_context()->stream()->parent()->platform(),
ctx->device()->GetAllocator({}), ctx->op_device_context()->stream());
return &tf_allocator_adapter->value();
}
} // namespace
XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
@@ -281,6 +297,7 @@ static Status CompileToLocalExecutable(
TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, resources, variables));
*client = static_cast<xla::LocalClient*>(cache->client());
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
XlaCompiler::Options options;
options.client = *client;
if (ctx->op_device_context() != nullptr) {
@@ -292,7 +309,8 @@ static Status CompileToLocalExecutable(
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls =
(platform_info.platform_id() == se::host::kHostPlatformId);
options.device_allocator = platform_info.allocator();
options.device_allocator =
GetAllocator(&tf_allocator_adapter, ctx, platform_info);
if (platform_info.xla_device_metadata()) {
options.shape_representation_fn =
platform_info.xla_device_metadata()->shape_representation_fn();
@@ -350,8 +368,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
VLOG(1) << "Executing XLA Computation...";
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator =
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
XlaComputationLaunchContext launch_context(
client, platform_info_.allocator(),
client, allocator,
/*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
platform_info_.UseMultipleStreams());
launch_context.PopulateInputs(ctx, kernel, variables,
@@ -361,7 +382,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
VLOG(2) << "Executing computation.";
xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
run_options.set_allocator(platform_info_.allocator());
run_options.set_allocator(allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(GetXLARandomSeed());
Env* env = Env::Default();
@@ -528,8 +549,11 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
XlaExecutableClosure closure =
XlaExecutableClosureStore::Global()->Consume(key);
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator =
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
XlaComputationLaunchContext launch_context(
closure.client(), platform_info_.allocator(),
closure.client(), allocator,
/*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
/*use_multiple_streams=*/platform_info_.UseMultipleStreams());
@@ -554,7 +578,7 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
run_options.set_allocator(platform_info_.allocator());
run_options.set_allocator(allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(GetXLARandomSeed());
Env* env = Env::Default();
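
A side note on the GetAllocator() pattern above: the function may return a
pointer into the caller-owned absl::optional, so that optional must be
declared in the caller's frame and outlive every use of the returned
allocator. A compilable sketch of this lifetime pattern, using simplified
stand-in types (std::optional standing in for absl::optional) rather than the
TF API:

#include <optional>

struct Allocator {
  virtual ~Allocator() = default;
};

// Stand-in for se::TfAllocatorAdapter; the int models the bound stream.
struct Adapter : Allocator {
  explicit Adapter(int stream_id) : stream_id(stream_id) {}
  int stream_id;
};

// Mirrors GetAllocator() above: pass through a custom allocator when one
// exists, otherwise materialize an adapter inside the caller-owned optional.
Allocator* GetAllocatorSketch(std::optional<Adapter>* adapter,
                              Allocator* custom_allocator, int stream_id) {
  if (custom_allocator != nullptr) {
    return custom_allocator;  // XLA device: backend-owned allocator.
  }
  adapter->emplace(stream_id);
  return &adapter->value();
}

void ComputeSketch(Allocator* custom_allocator) {
  std::optional<Adapter> adapter;  // must outlive all uses of `allocator`
  Allocator* allocator =
      GetAllocatorSketch(&adapter, custom_allocator, /*stream_id=*/0);
  (void)allocator;  // ... compile and launch using `allocator` ...
}

int main() { ComputeSketch(nullptr); }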

@@ -37,18 +37,14 @@ class XlaPlatformInfo {
public:
XlaPlatformInfo() : device_type_("") {}
XlaPlatformInfo(XlaPlatformInfo&&) = default;
explicit XlaPlatformInfo(
const DeviceType device_type, se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata,
std::unique_ptr<se::TfAllocatorAdapter> xla_allocator,
se::DeviceMemoryAllocator* device_allocator)
explicit XlaPlatformInfo(const DeviceType device_type,
se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata,
se::DeviceMemoryAllocator* device_allocator)
: device_type_(device_type),
platform_id_(platform_id),
xla_device_metadata_(xla_device_metadata),
xla_allocator_(std::move(xla_allocator)),
device_allocator_(device_allocator) {
CHECK((device_allocator_ != nullptr) ^ (xla_allocator_.get() != nullptr));
}
device_allocator_(device_allocator) {}
XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
@@ -56,9 +52,11 @@ class XlaPlatformInfo {
return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
}
se::DeviceMemoryAllocator* allocator() const {
return device_allocator_ ? device_allocator_ : xla_allocator_.get();
// Non-null only when run on an XLA device.
se::DeviceMemoryAllocator* custom_allocator() const {
return device_allocator_;
}
DeviceType device_type() const { return device_type_; }
// This is equal to xla_device_metadata()->platform()->id() if
@@ -82,11 +80,8 @@ class XlaPlatformInfo {
const XlaDevice::Metadata* xla_device_metadata_;
// If the op associated with this XlaPlatformInfo is placed on an XLA device
// then device_allocator_ is the xla::Backend's memory allocator and
// xla_allocator_ is null. If the op is placed on a regular CPU or GPU device
// then device_allocator_ is null and xla_allocator_ points to an appropriate
// se::TfAllocatorAdapter instance.
std::unique_ptr<se::TfAllocatorAdapter> xla_allocator_;
// then device_allocator_ is the xla::Backend's memory allocator. If the op
// is placed on a regular CPU or GPU device then device_allocator_ is null.
se::DeviceMemoryAllocator* device_allocator_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);

@@ -248,9 +248,6 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
return InternalError("Failed to synchronize GPU for autotuning.");
}
// Create a stream for us to do our work on.
se::Stream stream{stream_exec_};
stream.Init();
// allocator either points to this->allocator_ or, if that's null, to a
// se::StreamExecutorMemoryAllocator for stream_exec_.
se::DeviceMemoryAllocator* allocator;
@@ -262,11 +259,21 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
allocator = &*se_allocator;
}
absl::optional<se::Stream> stream_opt;
se::Stream* stream = [&] {
if (allocator->GetStream()) {
return allocator->GetStream();
}
stream_opt.emplace(stream_exec_);
stream_opt->Init();
return &stream_opt.value();
}();
int64 rng_state = 0;
const auto initialize_buffer = [&stream, &result_shape,
const auto initialize_buffer = [stream, &result_shape,
&rng_state](DeviceMemoryBase buffer) {
InitializeFloatBuffer(&stream, result_shape.element_type(), &rng_state,
InitializeFloatBuffer(stream, result_shape.element_type(), &rng_state,
buffer);
};
@@ -274,7 +281,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
// Allocate space for the input, filter, and output of the convolution.
se::cuda::RedzoneAllocator input_output_allocator(
&stream, allocator, PtxOptsFromConfig(hlo_module_config));
stream, allocator, PtxOptsFromConfig(hlo_module_config));
std::vector<se::DeviceMemoryBase> operand_buffers;
for (const auto* operand : instr->operands()) {
TF_ASSIGN_OR_RETURN(auto buffer,
@@ -328,7 +335,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
}
se::cuda::RedzoneAllocator scratch_allocator(
&stream, allocator, PtxOptsFromConfig(hlo_module_config));
stream, allocator, PtxOptsFromConfig(hlo_module_config));
se::dnn::ProfileResult profile_result;
VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
<< instr->ToString();
@@ -339,7 +346,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
options.algo_override = alg;
Status launch_status =
RunCudnnConv(instr, absl::MakeSpan(operand_buffers), result_buffer,
&scratch_allocator, &stream, options);
&scratch_allocator, stream, options);
if (!launch_status.ok()) {
continue;
@@ -362,12 +369,12 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
// Check for writes to redzones.
TF_ASSIGN_OR_RETURN(bool input_output_allocator_redzone_clear,
CheckRedzones(input_output_allocator, &stream,
CheckRedzones(input_output_allocator, stream,
"input/output", instr, &result));
TF_ASSIGN_OR_RETURN(
bool scratch_allocator_redzone_clear,
CheckRedzones(scratch_allocator, &stream, "scratch", instr, &result));
CheckRedzones(scratch_allocator, stream, "scratch", instr, &result));
if (!input_output_allocator_redzone_clear ||
!scratch_allocator_redzone_clear) {
@@ -393,7 +400,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
if (comparator.has_value()) {
XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::CompareEqual", 2);
StatusOr<bool> compare_result = comparator->CompareEqual(
&stream, reference_result_buffer, result_buffer);
stream, reference_result_buffer, result_buffer);
if (!compare_result.ok()) {
LOG(ERROR) << "Unable to compare " << AlgorithmToString(first_algorithm)
<< " against " << AlgorithmToString(alg) << " for "
@@ -411,7 +418,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
<< instr->ToString() << " for "
<< AlgorithmToString(first_algorithm) << " vs "
<< AlgorithmToString(alg);
PrintPlatformInfo(&stream);
PrintPlatformInfo(stream);
VLOG(1) << "Full module on failure: \n"
<< instr->GetModule()->ToString();
auto* fail = result.mutable_failure();
@@ -429,8 +436,8 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
TF_ASSIGN_OR_RETURN(
reference_result_buffer,
input_output_allocator.AllocateBytes(result_buffer.size()));
stream.ThenMemcpy(&reference_result_buffer, result_buffer,
result_buffer.size());
stream->ThenMemcpy(&reference_result_buffer, result_buffer,
result_buffer.size());
first_algorithm = alg;
}
}
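
The stream-selection lambda above appears in both autotuners. A
self-contained sketch of the pattern, with stand-in types instead of
se::Stream and se::DeviceMemoryAllocator (and std::optional standing in for
absl::optional): reuse the allocator's stream when it has one, otherwise
lazily create a private stream whose lifetime is tied to the enclosing scope
through an optional.

#include <optional>

struct Executor {};

struct Stream {
  explicit Stream(Executor*) {}
  void Init() {}
};

struct Allocator {
  Stream* stream = nullptr;
  Stream* GetStream() const { return stream; }
};

void AutotuneSketch(Executor* executor, Allocator* allocator) {
  std::optional<Stream> stream_opt;
  Stream* stream = [&] {
    if (allocator->GetStream()) {
      return allocator->GetStream();  // honor the allocator's stream binding
    }
    stream_opt.emplace(executor);  // fall back to a scope-local stream
    stream_opt->Init();
    return &stream_opt.value();
  }();
  (void)stream;  // ... run candidate algorithms on `stream` ...
}

int main() {
  Executor executor;
  Allocator allocator;
  AutotuneSketch(&executor, &allocator);
}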

@@ -239,16 +239,22 @@ static StatusOr<absl::optional<se::blas::AlgorithmType>> DoGemmAutotune(
static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
se::StreamExecutor* executor,
se::DeviceMemoryAllocator* allocator) {
se::Stream stream{executor};
stream.Init();
if (allocator == nullptr) {
allocator = executor->GetAllocator();
}
absl::optional<se::Stream> stream_opt;
se::Stream* stream = [&]() {
if (allocator->GetStream()) {
return allocator->GetStream();
}
stream_opt.emplace(executor);
stream_opt->Init();
return &stream_opt.value();
}();
const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
se::cuda::RedzoneAllocator input_output_allocator(
&stream, allocator, PtxOptsFromConfig(hlo_module_config));
stream, allocator, PtxOptsFromConfig(hlo_module_config));
BufferComparator comparator(instr->shape(), hlo_module_config);
@@ -258,7 +264,7 @@ static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer,
input_output_allocator.AllocateBytes(
ShapeUtil::ByteSizeOf(op->shape())));
InitializeFloatBuffer(&stream, op->shape().element_type(), &rng_state,
InitializeFloatBuffer(stream, op->shape().element_type(), &rng_state,
buffer);
return buffer;
};
@@ -283,11 +289,11 @@ static StatusOr<bool> RunOnInstruction(HloInstruction* instr,
const bool crash_on_checking_failure =
debug_options.xla_gpu_crash_on_verification_failures();
TF_ASSIGN_OR_RETURN(absl::optional<se::blas::AlgorithmType> gemm_algorithm,
DoGemmAutotune(instr, lhs, rhs, lhs_buffer, rhs_buffer,
output_buffer, reference_result_buffer,
&stream, crash_on_checking_failure,
input_output_allocator, comparator));
TF_ASSIGN_OR_RETURN(
absl::optional<se::blas::AlgorithmType> gemm_algorithm,
DoGemmAutotune(instr, lhs, rhs, lhs_buffer, rhs_buffer, output_buffer,
reference_result_buffer, stream, crash_on_checking_failure,
input_output_allocator, comparator));
// We update instruction->backend_config(); if no algorithms are supported,
// a different API is used, which does not require specifying an algorithm.

@@ -7,8 +7,8 @@ tensorflow/contrib/mpi/BUILD
tensorflow/python/autograph/core/config.py
tensorflow/python/tpu/profiler/pip_package/BUILD
tensorflow/python/tpu/profiler/pip_package/README
tensorflow/python/tpu/profiler/pip_package/build_pip_package.sh
tensorflow/python/tpu/profiler/pip_package/setup.py
tensorflow/python/tpu/profiler/pip_package/build_pip_package.sh
tensorflow/stream_executor/build_defs.bzl
tensorflow/third_party/BUILD
tensorflow/third_party/android/BUILD
@@ -28,31 +28,31 @@ tensorflow/third_party/com_google_absl.BUILD
tensorflow/third_party/common.bzl
tensorflow/third_party/cub.BUILD
tensorflow/third_party/curl.BUILD
tensorflow/third_party/cython.BUILD
tensorflow/third_party/double_conversion.BUILD
tensorflow/third_party/eigen3/Eigen/Cholesky
tensorflow/third_party/eigen3/Eigen/Core
tensorflow/third_party/eigen3/Eigen/Eigenvalues
tensorflow/third_party/eigen3/Eigen/Core
tensorflow/third_party/eigen3/Eigen/LU
tensorflow/third_party/eigen3/Eigen/QR
tensorflow/third_party/eigen3/Eigen/SVD
tensorflow/third_party/eigen3/Eigen/QR
tensorflow/third_party/eigen3/BUILD
tensorflow/third_party/eigen3/LICENSE
tensorflow/third_party/eigen3/gpu_packet_math.patch
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
tensorflow/third_party/cython.BUILD
tensorflow/third_party/eigen.BUILD
tensorflow/third_party/fft2d/BUILD
tensorflow/third_party/fft2d/LICENSE
@@ -61,11 +61,11 @@ tensorflow/third_party/fft2d/fft2d.BUILD
tensorflow/third_party/fft2d/fft2d.h
tensorflow/third_party/enum34.BUILD
tensorflow/third_party/farmhash.BUILD
tensorflow/third_party/git/BUILD
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/git/git_configure.bzl
tensorflow/third_party/functools32.BUILD
tensorflow/third_party/gast.BUILD
tensorflow/third_party/git/BUILD
tensorflow/third_party/git/git_configure.bzl
tensorflow/third_party/git/BUILD.tpl
tensorflow/third_party/gif.BUILD
tensorflow/third_party/gpus/crosstool/BUILD
tensorflow/third_party/gpus/crosstool/BUILD.tpl
@@ -74,18 +74,18 @@ tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
tensorflow/third_party/gpus/BUILD
tensorflow/third_party/gpus/cuda/BUILD.tpl
tensorflow/third_party/gpus/cuda/BUILD
tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
tensorflow/third_party/gpus/cuda/BUILD.tpl
tensorflow/third_party/gpus/cuda/LICENSE
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
tensorflow/third_party/gpus/rocm/BUILD
tensorflow/third_party/gpus/rocm/BUILD.tpl
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
tensorflow/third_party/gpus/cuda_configure.bzl
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
tensorflow/third_party/gpus/find_cuda_config.py
tensorflow/third_party/gpus/cuda_configure.bzl
tensorflow/third_party/gpus/rocm_configure.bzl
tensorflow/third_party/googleapis.BUILD
tensorflow/third_party/grpc/BUILD
@@ -93,18 +93,18 @@ tensorflow/third_party/icu/udata.patch
tensorflow/third_party/kafka/BUILD
tensorflow/third_party/kafka/config.patch
tensorflow/third_party/jsoncpp.BUILD
tensorflow/third_party/libxsmm.BUILD
tensorflow/third_party/llvm/BUILD
tensorflow/third_party/llvm/expand_cmake_vars.py
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
tensorflow/third_party/llvm/llvm.bzl
tensorflow/third_party/libxsmm.BUILD
tensorflow/third_party/linenoise.BUILD
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/mkl/BUILD
tensorflow/third_party/mkl/LICENSE
tensorflow/third_party/mkl/MKL_LICENSE
tensorflow/third_party/mkl/build_defs.bzl
tensorflow/third_party/mkl/mkl.BUILD
tensorflow/third_party/mkl/build_defs.bzl
tensorflow/third_party/lmdb.BUILD
tensorflow/third_party/mkl_dnn/LICENSE
tensorflow/third_party/mkl_dnn/mkldnn.BUILD
tensorflow/third_party/mpi/.gitignore
@@ -112,25 +112,25 @@ tensorflow/third_party/mpi/BUILD
tensorflow/third_party/mpi_collectives/BUILD
tensorflow/third_party/nanopb.BUILD
tensorflow/third_party/nccl/BUILD
tensorflow/third_party/nccl/LICENSE
tensorflow/third_party/nccl/archive.BUILD
tensorflow/third_party/nccl/archive.patch
tensorflow/third_party/nccl/LICENSE
tensorflow/third_party/nccl/build_defs.bzl.tpl
tensorflow/third_party/nccl/archive.patch
tensorflow/third_party/nccl/nccl_configure.bzl
tensorflow/third_party/nccl/system.BUILD.tpl
tensorflow/third_party/ngraph/BUILD
tensorflow/third_party/ngraph/LICENSE
tensorflow/third_party/ngraph/NGRAPH_LICENSE
tensorflow/third_party/ngraph/ngraph_tf.BUILD
tensorflow/third_party/ngraph/build_defs.bzl
tensorflow/third_party/ngraph/ngraph.BUILD
tensorflow/third_party/ngraph/nlohmann_json.BUILD
tensorflow/third_party/ngraph/ngraph_tf.BUILD
tensorflow/third_party/ngraph/tbb.BUILD
tensorflow/third_party/opt_einsum.BUILD
tensorflow/third_party/pcre.BUILD
tensorflow/third_party/png.BUILD
tensorflow/third_party/png_fix_rpi.patch
tensorflow/third_party/protobuf/BUILD
tensorflow/third_party/png_fix_rpi.patch
tensorflow/third_party/pprof.BUILD
tensorflow/third_party/py/numpy/BUILD
tensorflow/third_party/py/BUILD
@@ -142,55 +142,54 @@ tensorflow/third_party/repo.bzl
tensorflow/third_party/six.BUILD
tensorflow/third_party/snappy.BUILD
tensorflow/third_party/sqlite.BUILD
tensorflow/third_party/swig.BUILD
tensorflow/third_party/sycl/crosstool/BUILD
tensorflow/third_party/swig.BUILD
tensorflow/third_party/systemlibs/BUILD.tpl
tensorflow/third_party/systemlibs/BUILD
tensorflow/third_party/systemlibs/absl_py.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
tensorflow/third_party/systemlibs/astor.BUILD
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
tensorflow/third_party/systemlibs/boringssl.BUILD
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
tensorflow/third_party/systemlibs/curl.BUILD
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
tensorflow/third_party/systemlibs/cython.BUILD
tensorflow/third_party/systemlibs/double_conversion.BUILD
tensorflow/third_party/systemlibs/gast.BUILD
tensorflow/third_party/systemlibs/gif.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
tensorflow/third_party/systemlibs/googleapis.BUILD
tensorflow/third_party/systemlibs/jsoncpp.BUILD
tensorflow/third_party/systemlibs/grpc.BUILD
tensorflow/third_party/systemlibs/lmdb.BUILD
tensorflow/third_party/systemlibs/opt_einsum.BUILD
tensorflow/third_party/systemlibs/jsoncpp.BUILD
tensorflow/third_party/systemlibs/nsync.BUILD
tensorflow/third_party/systemlibs/opt_einsum.BUILD
tensorflow/third_party/systemlibs/pcre.BUILD
tensorflow/third_party/systemlibs/protobuf.BUILD
tensorflow/third_party/systemlibs/six.BUILD
tensorflow/third_party/systemlibs/protobuf.bzl
tensorflow/third_party/systemlibs/png.BUILD
tensorflow/third_party/systemlibs/protobuf.bzl
tensorflow/third_party/systemlibs/re2.BUILD
tensorflow/third_party/systemlibs/snappy.BUILD
tensorflow/third_party/systemlibs/six.BUILD
tensorflow/third_party/systemlibs/sqlite.BUILD
tensorflow/third_party/systemlibs/swig.BUILD
tensorflow/third_party/systemlibs/snappy.BUILD
tensorflow/third_party/systemlibs/syslibs_configure.bzl
tensorflow/third_party/systemlibs/termcolor.BUILD
tensorflow/third_party/systemlibs/zlib.BUILD
tensorflow/third_party/tensorrt/BUILD
tensorflow/third_party/tensorrt/BUILD.tpl
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
tensorflow/third_party/tensorrt/LICENSE
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
tensorflow/third_party/tensorrt/tensorrt/include/tensorrt_config.h.tpl
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
tensorflow/third_party/termcolor.BUILD
tensorflow/third_party/tflite_mobilenet.BUILD
tensorflow/third_party/tflite_mobilenet_float.BUILD
tensorflow/third_party/tflite_mobilenet_quant.BUILD
tensorflow/third_party/toolchains/clang6/BUILD
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
tensorflow/third_party/toolchains/clang6/README.md
tensorflow/third_party/toolchains/clang6/clang.BUILD
tensorflow/third_party/toolchains/clang6/README.md
tensorflow/third_party/toolchains/clang6/repo.bzl
tensorflow/third_party/toolchains/BUILD
tensorflow/third_party/toolchains/cpus/arm/BUILD
@@ -215,8 +214,8 @@ tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/generate/BUILD
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
@@ -225,9 +224,9 @@ tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/cc_too
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc5-rocm/BUILD
@@ -235,24 +234,25 @@ tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc5-rocm/cc_toolchain_c
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010/dummy_toolchain.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/cc_toolchain_config.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3_opt/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/rocm/rocm/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/BUILD
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5.1/build_defs.bzl
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
tensorflow/third_party/toolchains/remote/BUILD
tensorflow/third_party/toolchains/remote/configure.bzl
tensorflow/third_party/toolchains/remote/BUILD.tpl
tensorflow/third_party/toolchains/remote/configure.bzl
tensorflow/third_party/toolchains/remote/BUILD
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
tensorflow/third_party/tflite_ovic_testdata.BUILD
tensorflow/third_party/tflite_mobilenet_quant.BUILD
tensorflow/third_party/tflite_smartreply.BUILD
tensorflow/third_party/wrapt.BUILD
tensorflow/third_party/zlib.BUILD
@@ -264,19 +264,19 @@ tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
tensorflow/tools/lib_package/BUILD
tensorflow/tools/lib_package/LibTensorFlowTest.java
tensorflow/tools/lib_package/README.md
tensorflow/tools/lib_package/libtensorflow_java_test.sh
tensorflow/tools/lib_package/concat_licenses.sh
tensorflow/tools/lib_package/libtensorflow_test.c
tensorflow/tools/lib_package/libtensorflow_java_test.sh
tensorflow/tools/lib_package/libtensorflow_test.sh
tensorflow/tools/pip_package/BUILD
tensorflow/tools/pip_package/MANIFEST.in
tensorflow/tools/pip_package/build_pip_package.sh
tensorflow/tools/pip_package/check_load_py_test.py
tensorflow/tools/pip_package/BUILD
tensorflow/tools/pip_package/README
tensorflow/tools/pip_package/build_pip_package.sh
tensorflow/tools/pip_package/pip_smoke_test.py
tensorflow/tools/pip_package/setup.py
tensorflow/tools/pip_package/simple_console.py
tensorflow/tools/pip_package/simple_console_for_windows.py
tensorflow/virtual_root_template_v1.__init__.py
tensorflow/tools/pip_package/check_load_py_test.py
tensorflow/virtual_root_template_v2.__init__.py
tensorflow/virtual_root_template_v1.__init__.py
llvm/llvm/projects/google_mlir/WORKSPACE

@@ -199,6 +199,16 @@ class DeviceMemoryAllocator {
// a stream, or do we have to wait for the computation to complete first?
virtual bool AllowsAsynchronousDeallocation() const { return false; }
// Returns a nullable stream pointer.
//
// If the pointer is non-null, then it is always safe to access the memory
// allocated by this allocator on the returned stream. This is not required,
// though, as streams can be synchronized by other means.
//
// TODO(cheshire): clean up the interface, it might be cleaner to explicitly
// pass the stream to Compiler.
virtual Stream *GetStream() const { return nullptr; }
protected:
const Platform* platform_;
};
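
For the "synchronized by other means" escape hatch mentioned in the comment
above, a minimal sketch (stand-in types; WaitFor is a hypothetical stand-in
for stream-ordering primitives such as se::Stream::ThenWaitFor) of how a
consumer could safely touch the allocator's memory on a different stream by
ordering that stream after the bound one first:

struct Stream {
  // Stand-in for a stream-ordering primitive: this stream will not run
  // subsequent work until everything already enqueued on `other` completes.
  void WaitFor(Stream* other) { (void)other; }
};

struct Allocator {
  Stream* stream = nullptr;
  Stream* GetStream() const { return stream; }
};

void UseOnOtherStream(Allocator& alloc, Stream& other) {
  if (Stream* bound = alloc.GetStream()) {
    other.WaitFor(bound);  // make `other` observe all prior work on `bound`
  }
  // ... now accessing alloc's memory on `other` is safe ...
}

int main() {
  Stream main_stream, side_stream;
  Allocator alloc{&main_stream};
  UseOnOtherStream(alloc, side_stream);
}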

@@ -21,8 +21,9 @@ limitations under the License.
namespace stream_executor {
TfAllocatorAdapter::TfAllocatorAdapter(const Platform *platform,
tensorflow::Allocator *wrapped)
: DeviceMemoryAllocator(platform), wrapped_(wrapped) {}
tensorflow::Allocator *wrapped,
Stream *stream)
: DeviceMemoryAllocator(platform), wrapped_(wrapped), stream_(stream) {}
TfAllocatorAdapter::~TfAllocatorAdapter() {}

@@ -30,7 +30,10 @@ namespace stream_executor {
// see comment on `AllowsAsynchronousDeallocation()`.
class TfAllocatorAdapter : public DeviceMemoryAllocator {
public:
TfAllocatorAdapter(const Platform *platform, tensorflow::Allocator *wrapped);
// stream: the only Stream on which this allocator may be used. If non-null,
// the allocator must not be used on any other stream.
TfAllocatorAdapter(const Platform *platform, tensorflow::Allocator *wrapped,
Stream *stream = nullptr);
~TfAllocatorAdapter() override;
port::StatusOr<OwningDeviceMemory> Allocate(int device_ordinal, uint64 size,
@@ -47,8 +50,11 @@ class TfAllocatorAdapter : public DeviceMemoryAllocator {
// (This attribute has no effect on CPU.)
bool AllowsAsynchronousDeallocation() const override { return true; }
Stream *GetStream() const override { return stream_; }
private:
tensorflow::Allocator *wrapped_;
Stream *stream_;
};
// Adapter class that wraps per-device TF allocators as an XLA allocator.