[TF:XLA] Register the standard Stack kernels on XLA_... devices.
Refactor the core stack ops to split the op classes from their registrations on CPU/GPU/SYCL devices. Refactor the stack push op to be templated on an allow_swapping bool, rather than a specific device; the device was only ever used in a type-equality test to decide whether to swap or not.

Previously, on XLA_... devices, stack operators worked only when the entire computation was grouped into a single cluster (e.g., via xla.compile()). This change also allows stack-using operators to work in "on-demand" or eager modes, when running ops one at a time. However, since the compiled and interpreted representations of stacks are still different, there is not yet any support for passing stacks into or out of compiled blocks: stack usage must remain entirely inside or entirely outside a compiled block until we rectify this in a future change.

PiperOrigin-RevId: 220132340
commit b9b0354ebc
parent 6dc0f05f84
Changed paths: tensorflow/compiler, tensorflow/contrib/makefile, tensorflow/core/kernels
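The heart of the refactor is easiest to see in the new tensorflow/core/kernels/stack.h further down: the device template parameter of the old StackPushOp, which was only compared against the GPU/SYCL device types to decide whether swapping to host memory is allowed, becomes an explicit allow_swapping bool picked at registration time. A minimal, self-contained sketch of that pattern (the toy device structs and main() here are illustrative assumptions; only the TemplatedStackPushOp/allow_swapping idea comes from the actual change):

#include <iostream>
#include <type_traits>

// Toy stand-ins for the real Eigen device types (illustrative only).
struct CPUDevice {};
struct GPUDevice {};

// Before: behaviour keyed off a device type parameter that was only ever
// used in a type-equality test.
template <typename Device>
struct OldStackPushOp {
  bool swapping_allowed() const {
    return std::is_same<Device, GPUDevice>::value;
  }
};

// After: the policy is an explicit constructor argument...
struct NewStackPushOp {
  explicit NewStackPushOp(bool allow_swapping)
      : allow_swapping_(allow_swapping) {}
  bool allow_swapping_;
};

// ...and a thin template adapter lets each device registration pick it.
template <bool allow_swapping>
struct TemplatedStackPushOp : NewStackPushOp {
  TemplatedStackPushOp() : NewStackPushOp(allow_swapping) {}
};

int main() {
  TemplatedStackPushOp</*allow_swapping=*/false> cpu_or_xla;  // CPU and XLA_* devices
  TemplatedStackPushOp</*allow_swapping=*/true> gpu;          // GPU/SYCL devices
  std::cout << cpu_or_xla.allow_swapping_ << " " << gpu.allow_swapping_ << "\n";
  return 0;
}

In the diff below, the CPU and XLA_* registrations use TemplatedStackPushOp</*allow_swapping=*/false>, while the GPU and SYCL registrations use /*allow_swapping=*/true, so only the kernels that can actually benefit from swapping read the swap_memory attr.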
@@ -190,6 +190,7 @@ cc_library(
        "//tensorflow/core/kernels:resource_variable_ops",
        "//tensorflow/core/kernels:sendrecv_ops",
        "//tensorflow/core/kernels:shape_ops",
        "//tensorflow/core/kernels:stack",
        "//tensorflow/core/kernels:variable_ops",
        "//tensorflow/core/kernels/data:generator_dataset_op",
        "//tensorflow/core/kernels/data:iterator_ops",
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/sendrecv_ops.h"
#include "tensorflow/core/kernels/shape_ops.h"
#include "tensorflow/core/kernels/stack.h"
#include "tensorflow/core/kernels/variable_ops.h"

namespace tensorflow {
@@ -257,9 +258,27 @@ class XlaAssignVariableOp : public OpKernel {
                              .Device(DEVICE) \
                              .TypeConstraint<string>("T") \
                              .HostMemory("input"), \
                          RetvalOp);
                          RetvalOp); \
                                     \
  REGISTER_KERNEL_BUILDER(Name("StackV2") \
                              .Device(DEVICE) \
                              .HostMemory("max_size") \
                              .HostMemory("handle"), \
                          StackOp); \
  REGISTER_KERNEL_BUILDER(Name("StackPushV2") \
                              .Device(DEVICE) \
                              .HostMemory("handle") \
                              .TypeConstraint("T", TYPES), \
                          TemplatedStackPushOp</*allow_swapping=*/false>); \
  REGISTER_KERNEL_BUILDER(Name("StackPopV2") \
                              .Device(DEVICE) \
                              .HostMemory("handle") \
                              .TypeConstraint("elem_type", TYPES), \
                          StackPopOp); \
  REGISTER_KERNEL_BUILDER( \
      Name("StackCloseV2").Device(DEVICE).HostMemory("handle"), StackCloseOp);

  // TODO(phawkins): currently we do not register the QueueEnqueueMany,
  // TODO(b/118881356): currently we do not register the QueueEnqueueMany,
  // QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
  // and write the tensors they access in order to concatenate them into a batch.
  // We would need either to call out to an XLA computation to perform the
@@ -837,8 +837,6 @@ tf_xla_py_test(
    name = "stack_ops_test",
    size = "small",
    srcs = ["stack_ops_test.py"],
    # Stack ops are not implemented in the on-demand compilation model yet.
    disabled_backends = ["cpu_ondemand"],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
@@ -126,7 +126,9 @@ class StackOp : public XlaOpKernel {
  TF_DISALLOW_COPY_AND_ASSIGN(StackOp);
};

REGISTER_XLA_OP(Name("StackV2").CompileTimeConstantInput("max_size"), StackOp);
REGISTER_XLA_OP(
    Name("StackV2").CompileTimeConstantInput("max_size").CompilationOnly(),
    StackOp);

class StackPushOp : public XlaOpKernel {
 public:

@@ -173,7 +175,7 @@ class StackPushOp : public XlaOpKernel {
  TF_DISALLOW_COPY_AND_ASSIGN(StackPushOp);
};

REGISTER_XLA_OP(Name("StackPushV2"), StackPushOp);
REGISTER_XLA_OP(Name("StackPushV2").CompilationOnly(), StackPushOp);

class StackPopOp : public XlaOpKernel {
 public:

@@ -227,7 +229,7 @@ class StackPopOp : public XlaOpKernel {
  TF_DISALLOW_COPY_AND_ASSIGN(StackPopOp);
};

REGISTER_XLA_OP(Name("StackPopV2"), StackPopOp);
REGISTER_XLA_OP(Name("StackPopV2").CompilationOnly(), StackPopOp);

class StackCloseOp : public XlaOpKernel {
 public:

@@ -241,7 +243,7 @@ class StackCloseOp : public XlaOpKernel {
  TF_DISALLOW_COPY_AND_ASSIGN(StackCloseOp);
};

REGISTER_XLA_OP(Name("StackCloseV2"), StackCloseOp);
REGISTER_XLA_OP(Name("StackCloseV2").CompilationOnly(), StackCloseOp);

}  // anonymous namespace
}  // namespace tensorflow
@@ -248,6 +248,7 @@ tensorflow/core/kernels/spectrogram_op.cc
tensorflow/core/kernels/split_lib_cpu.cc
tensorflow/core/kernels/split_op.cc
tensorflow/core/kernels/split_v_op.cc
tensorflow/core/kernels/stack.cc
tensorflow/core/kernels/stack_ops.cc
tensorflow/core/kernels/strided_slice_op.cc
tensorflow/core/kernels/strided_slice_op_inst_0.cc
@@ -1887,10 +1887,22 @@ tf_kernel_library(
    deps = DATA_FLOW_DEPS,
)

cc_library(
    name = "stack",
    srcs = ["stack.cc"],
    hdrs = ["stack.h"],
    deps = [
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
    ],
)

tf_kernel_library(
    name = "stack_ops",
    prefix = "stack_ops",
    deps = DATA_FLOW_DEPS,
    deps = DATA_FLOW_DEPS + [":stack"],
)

tf_kernel_library(

@@ -5485,6 +5497,8 @@ filegroup(
        "sparse_to_dense_op.cc",
        "spectrogram.cc",
        "spectrogram_op.cc",
        "stack.cc",
        "stack.h",
        "stack_ops.cc",
        "string_join_op.cc",
        "string_util.cc",
tensorflow/core/kernels/stack.cc (new file, 339 lines)
@@ -0,0 +1,339 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/stack.h"

#include <limits.h>
#include <atomic>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class Stack : public ResourceBase {
 public:
  static std::atomic<int64> stack_counter;

  struct TensorAndAllocation {
    Tensor tensor;
    AllocatorAttributes alloc_attrs;
    bool swapped_to_cpu;
  };

  Stack(const DataType& elem_type, const string& stack_name, int max_size)
      : elem_type_(elem_type),
        stack_name_(stack_name),
        max_size_(max_size),
        closed_(false) {}

  Status Push(const TensorAndAllocation& value) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(CheckNotClosed());
    if (max_size_ >= 0 && stack_.size() >= max_size_) {
      return errors::InvalidArgument("Stack[", stack_name_, "] overflowed ",
                                     "its max_size (", max_size_, ")");
    }
    stack_.push_back(value);
    return Status::OK();
  }

  Status Pop(TensorAndAllocation* value) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(CheckNotClosed());
    if (stack_.empty()) {
      return errors::InvalidArgument("Stack[", stack_name_,
                                     "] is empty when calling Pop().");
    }
    *value = stack_.back();
    stack_.pop_back();
    return Status::OK();
  }

  // We don't swap the first tensor on the stack and any subsequent tensors
  // that share the buffer with the first tensor.
  bool IsUsefulToSwap(const Tensor& tensor) const {
    mutex_lock l(mu_);
    if (stack_.empty()) {
      return false;
    }
    const Tensor& first = stack_.front().tensor;
    return !tensor.SharesBufferWith(first);
  }

  void Close() {
    mutex_lock l(mu_);
    stack_.clear();
    closed_ = true;
  }

  DataType ElemType() { return elem_type_; }

  string DebugString() override {
    mutex_lock l(mu_);
    return strings::StrCat("Stack[", stack_name_, "]");
  }

  const string& stack_name() { return stack_name_; }

 private:
  friend class StackOp;
  mutex* mu() { return &mu_; }

  mutable mutex mu_;
  DataType elem_type_;
  const string stack_name_;
  Tensor handle_;
  int max_size_;
  bool closed_ GUARDED_BY(mu_);
  std::vector<TensorAndAllocation> stack_ GUARDED_BY(mu_);

  Status CheckNotClosed() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (closed_) {
      return errors::InvalidArgument("Stack[", stack_name_,
                                     "] has already been closed.");
    }
    return Status::OK();
  }
};

Status GetStack(OpKernelContext* ctx, Stack** stack) {
  if (ctx->input_dtype(0) == DT_RESOURCE) {
    return LookupResource(ctx, HandleFromInput(ctx, 0), stack);
  } else {
    Tensor Tstack_handle = ctx->mutable_input(0, false);
    if (Tstack_handle.NumElements() != 2) {
      return errors::InvalidArgument(
          "Stack handle must have two elements, but had shape: ",
          Tstack_handle.shape().DebugString());
    }
    const string& container = Tstack_handle.flat<string>()(0);
    const string& stack_name = Tstack_handle.flat<string>()(1);
    string key = strings::StrCat(container, stack_name);
    ResourceMgr* rm = ctx->resource_manager();
    if (rm == nullptr) {
      return errors::Internal("No resource manager.");
    }
    auto* step_container = ctx->step_container();
    if (step_container == nullptr) {
      return errors::Internal("No step container.");
    }
    TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack));
    return Status::OK();
  }
}

std::atomic<int64> Stack::stack_counter{0};

// StackOp

StackOp::StackOp(OpKernelConstruction* context) : OpKernel(context) {
  OP_REQUIRES_OK(context, context->GetAttr("elem_type", &elem_type_));
  OP_REQUIRES_OK(context, context->GetAttr("stack_name", &stack_name_));
  if (stack_name_.empty()) stack_name_ = name();
}

void StackOp::Compute(OpKernelContext* ctx) {
  int32 size = std::numeric_limits<int32>::max();
  if (ctx->num_inputs() > 0) {
    const Tensor* tensor_size;
    OP_REQUIRES_OK(ctx, ctx->input("max_size", &tensor_size));

    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(tensor_size->shape()),
        errors::InvalidArgument("Stack size must be a scalar, but had shape: ",
                                tensor_size->shape().DebugString()));

    int32 size_value = tensor_size->scalar<int32>()();
    if (size_value >= 0) {
      size = size_value;
    }
  }

  static const char kContainer[] = "_stacks";
  auto stack_id = Stack::stack_counter.fetch_add(1);
  string stack_name = strings::StrCat(stack_name_, "_", stack_id);
  // Store the handle in a per-step container.
  ResourceMgr* rm = ctx->resource_manager();
  OP_REQUIRES(ctx, rm != nullptr, errors::Internal("No resource manager."));
  string key = strings::StrCat(kContainer, stack_name);
  Stack* stack = new Stack(elem_type_, stack_name, size);
  auto* step_container = ctx->step_container();
  OP_REQUIRES(ctx, step_container != nullptr,
              errors::Internal("No step container."));
  OP_REQUIRES_OK(ctx, rm->Create(step_container->name(), key, stack));
  if (IsRefType(ctx->expected_output_dtype(0))) {
    // Create the stack handle.
    AllocatorAttributes alloc_attr;
    alloc_attr.set_on_host(true);
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_STRING,
                                           tensorflow::TensorShape({2}),
                                           &stack->handle_, alloc_attr));
    auto handle = stack->handle_.flat<string>();
    handle(0) = kContainer;
    handle(1) = std::move(stack_name);
    ctx->set_output_ref(0, stack->mu(), &stack->handle_);
  } else {
    Tensor* handle;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
    handle->flat<ResourceHandle>()(0) =
        MakePerStepResourceHandle<Stack>(ctx, key);
  }
}

// StackPushOp

StackPushOp::StackPushOp(OpKernelConstruction* context, bool allow_swapping)
    : AsyncOpKernel(context) {
  if (allow_swapping) {
    OP_REQUIRES_OK(context, context->GetAttr("swap_memory", &swap_memory_));
  }
}

void StackPushOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  // Get the stack from the handle.
  Stack* stack = nullptr;
  OP_REQUIRES_OK_ASYNC(ctx, GetStack(ctx, &stack), done);
  core::ScopedUnref unref(stack);

  if (ctx->input_dtype(1) != stack->ElemType()) {
    ctx->CtxFailure(errors::InvalidArgument("Must have type ",
                                            stack->ElemType(), " but got ",
                                            ctx->input_dtype(1)));
    done();
    return;
  }

  // Push the tensor onto the stack. Swap the tensor to CPU if instructed.
  const Tensor& tensor = ctx->input(1);
  AllocatorAttributes alloc_attrs = ctx->input_alloc_attr(1);
  // For now, we use a simple heuristic for swapping: A GPU tensor is moved
  // to CPU if the tensor has more than kCopyThreshold bytes and the GPU
  // allocator says more than kOccupancy of the memory is in use.
  static constexpr int kCopyThreshold = 2048;
  static constexpr double kOccupancy = 0.7;
  if (swap_memory_ && !alloc_attrs.on_host() &&
      tensor.TotalBytes() > kCopyThreshold && stack->IsUsefulToSwap(tensor)) {
    DeviceContext* device_ctxt = ctx->op_device_context();
    auto device = static_cast<tensorflow::Device*>(ctx->device());
    Allocator* allocator = device->GetAllocator(alloc_attrs);
    AllocatorStats stats;
    allocator->GetStats(&stats);
    if (stats.bytes_in_use > (stats.bytes_limit * kOccupancy)) {
      // Asynchronously copy the tensor from GPU to CPU memory.
      // TODO(yuanbyu): Swap the oldest tensor first.
      AllocatorAttributes host_alloc_attrs;
      host_alloc_attrs.set_gpu_compatible(true);
      host_alloc_attrs.set_on_host(true);
      Allocator* cpu_allocator = device->GetAllocator(host_alloc_attrs);
      Tensor* cpu_tensor =
          new Tensor(cpu_allocator, tensor.dtype(), tensor.shape());
      device_ctxt->CopyDeviceTensorToCPU(
          &tensor, "StackPush", device, cpu_tensor,
          [cpu_tensor, stack, ctx, done](const Status& s) {
            ctx->SetStatus(s);
            if (s.ok()) {
              AllocatorAttributes alloc_attrs = ctx->input_alloc_attr(1);
              ctx->SetStatus(stack->Push({*cpu_tensor, alloc_attrs, true}));
            }
            if (ctx->status().ok()) {
              ctx->set_output(0, *cpu_tensor);
            }
            done();
            delete cpu_tensor;
          });
      return;
    }
  }

  // Execute synchronously if not swapped.
  OP_REQUIRES_OK_ASYNC(ctx, stack->Push({tensor, alloc_attrs, false}), done);
  ctx->set_output(0, tensor);
  done();
}

bool StackPushOp::IsExpensive() { return false; }

// StackPopOp

StackPopOp::StackPopOp(OpKernelConstruction* context)
    : AsyncOpKernel(context) {}

void StackPopOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  // Get the stack from the handle.
  Stack* stack = nullptr;
  OP_REQUIRES_OK_ASYNC(ctx, GetStack(ctx, &stack), done);
  core::ScopedUnref unref(stack);

  // Pop the tensor. Transfer the tensor back to device if it was
  // swapped out to CPU.
  Stack::TensorAndAllocation value;
  OP_REQUIRES_OK_ASYNC(ctx, stack->Pop(&value), done);
  if (value.swapped_to_cpu) {
    // Asynchronously copy the tensor back from CPU to GPU memory.
    DeviceContext* device_ctxt = ctx->op_device_context();
    Device* device = static_cast<Device*>(ctx->device());
    Tensor* cpu_tensor = &value.tensor;
    Allocator* gpu_allocator = device->GetAllocator(value.alloc_attrs);
    Tensor* device_tensor =
        new Tensor(gpu_allocator, cpu_tensor->dtype(), cpu_tensor->shape());
    device_ctxt->CopyCPUTensorToDevice(
        cpu_tensor, device, device_tensor,
        [device_tensor, ctx, done](const Status& s) {
          ctx->SetStatus(s);
          if (s.ok()) {
            ctx->set_output(0, *device_tensor);
          }
          done();
          delete device_tensor;
        });
  } else {
    // Execute synchronously if not swapped.
    ctx->set_output(0, value.tensor);
    done();
  }
}

bool StackPopOp::IsExpensive() { return false; }

// StackCloseOp

StackCloseOp::StackCloseOp(OpKernelConstruction* context) : OpKernel(context) {}

void StackCloseOp::Compute(OpKernelContext* ctx) {
  Stack* stack = nullptr;
  OP_REQUIRES_OK(ctx, GetStack(ctx, &stack));
  core::ScopedUnref unref(stack);
  stack->Close();
}

bool StackCloseOp::IsExpensive() { return false; }

}  // namespace tensorflow
tensorflow/core/kernels/stack.h (new file, 76 lines)
@@ -0,0 +1,76 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_STACK_H_
#define TENSORFLOW_CORE_KERNELS_STACK_H_

// See docs in ../ops/data_flow_ops.cc.

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// A per-run local stack. The stack uses a "per-step" resource manager which
// ensures that correct garbage collection on error or successful completion.
class StackOp : public OpKernel {
 public:
  explicit StackOp(OpKernelConstruction* context);
  void Compute(OpKernelContext* ctx) override;

 private:
  DataType elem_type_;
  string stack_name_;

  TF_DISALLOW_COPY_AND_ASSIGN(StackOp);
};

class StackPushOp : public AsyncOpKernel {
 public:
  StackPushOp(OpKernelConstruction* context, bool allow_swapping);
  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
  bool IsExpensive() override;

 private:
  bool swap_memory_ = false;
};

// Templated helper to make it easier to register kernels with or without
// swapping.
template <bool allow_swapping>
class TemplatedStackPushOp : public StackPushOp {
 public:
  TemplatedStackPushOp(OpKernelConstruction* context)
      : StackPushOp(context, allow_swapping) {}
};

class StackPopOp : public AsyncOpKernel {
 public:
  explicit StackPopOp(OpKernelConstruction* context);
  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
  bool IsExpensive() override;
};

class StackCloseOp : public OpKernel {
 public:
  explicit StackCloseOp(OpKernelConstruction* context);
  void Compute(OpKernelContext* ctx) override;
  bool IsExpensive() override;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_STACK_H_
@@ -15,6 +15,8 @@ limitations under the License.

// See docs in ../ops/data_flow_ops.cc.

#include "tensorflow/core/kernels/stack.h"

#include <limits.h>
#include <atomic>
#include <vector>
@@ -38,191 +40,6 @@ limitations under the License.

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

class Stack : public ResourceBase {
 public:
  static std::atomic<int64> stack_counter;

  struct TensorAndAllocation {
    Tensor tensor;
    AllocatorAttributes alloc_attrs;
    bool swapped_to_cpu;
  };

  Stack(const DataType& elem_type, const string& stack_name, int max_size)
      : elem_type_(elem_type),
        stack_name_(stack_name),
        max_size_(max_size),
        closed_(false) {}

  Status Push(const TensorAndAllocation& value) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(CheckNotClosed());
    if (max_size_ >= 0 && stack_.size() >= max_size_) {
      return errors::InvalidArgument("Stack[", stack_name_, "] overflowed ",
                                     "its max_size (", max_size_, ")");
    }
    stack_.push_back(value);
    return Status::OK();
  }

  Status Pop(TensorAndAllocation* value) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(CheckNotClosed());
    if (stack_.empty()) {
      return errors::InvalidArgument("Stack[", stack_name_,
                                     "] is empty when calling Pop().");
    }
    *value = stack_.back();
    stack_.pop_back();
    return Status::OK();
  }

  // We don't swap the first tensor on the stack and any subsequent tensors
  // that share the buffer with the first tensor.
  bool IsUsefulToSwap(const Tensor& tensor) const {
    mutex_lock l(mu_);
    if (stack_.empty()) {
      return false;
    }
    const Tensor& first = stack_.front().tensor;
    return !tensor.SharesBufferWith(first);
  }

  void Close() {
    mutex_lock l(mu_);
    stack_.clear();
    closed_ = true;
  }

  DataType ElemType() { return elem_type_; }

  string DebugString() override {
    mutex_lock l(mu_);
    return strings::StrCat("Stack[", stack_name_, "]");
  }

  const string& stack_name() { return stack_name_; }

 private:
  friend class StackOp;
  mutex* mu() { return &mu_; }

  mutable mutex mu_;
  DataType elem_type_;
  const string stack_name_;
  Tensor handle_;
  int max_size_;
  bool closed_ GUARDED_BY(mu_);
  std::vector<TensorAndAllocation> stack_ GUARDED_BY(mu_);

  Status CheckNotClosed() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (closed_) {
      return errors::InvalidArgument("Stack[", stack_name_,
                                     "] has already been closed.");
    }
    return Status::OK();
  }
};

Status GetStack(OpKernelContext* ctx, Stack** stack) {
  if (ctx->input_dtype(0) == DT_RESOURCE) {
    return LookupResource(ctx, HandleFromInput(ctx, 0), stack);
  } else {
    Tensor Tstack_handle = ctx->mutable_input(0, false);
    if (Tstack_handle.NumElements() != 2) {
      return errors::InvalidArgument(
          "Stack handle must have two elements, but had shape: ",
          Tstack_handle.shape().DebugString());
    }
    const string& container = Tstack_handle.flat<string>()(0);
    const string& stack_name = Tstack_handle.flat<string>()(1);
    string key = strings::StrCat(container, stack_name);
    ResourceMgr* rm = ctx->resource_manager();
    if (rm == nullptr) {
      return errors::Internal("No resource manager.");
    }
    auto* step_container = ctx->step_container();
    if (step_container == nullptr) {
      return errors::Internal("No step container.");
    }
    TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack));
    return Status::OK();
  }
}

std::atomic<int64> Stack::stack_counter{0};

// A per-run local stack. The stack uses a "per-step" resource manager which
// ensures that correct garbage collection on error or successful completion.
class StackOp : public OpKernel {
 public:
  explicit StackOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("elem_type", &elem_type_));
    OP_REQUIRES_OK(context, context->GetAttr("stack_name", &stack_name_));
    if (stack_name_.empty()) stack_name_ = name();
  }

  void Compute(OpKernelContext* ctx) override {
    int32 size = std::numeric_limits<int32>::max();
    if (ctx->num_inputs() > 0) {
      const Tensor* tensor_size;
      OP_REQUIRES_OK(ctx, ctx->input("max_size", &tensor_size));

      OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_size->shape()),
                  errors::InvalidArgument(
                      "Stack size must be a scalar, but had shape: ",
                      tensor_size->shape().DebugString()));

      int32 size_value = tensor_size->scalar<int32>()();
      if (size_value >= 0) {
        size = size_value;
      }
    }

    static const char kContainer[] = "_stacks";
    auto stack_id = Stack::stack_counter.fetch_add(1);
    string stack_name = strings::StrCat(stack_name_, "_", stack_id);
    // Store the handle in a per-step container.
    ResourceMgr* rm = ctx->resource_manager();
    OP_REQUIRES(ctx, rm != nullptr, errors::Internal("No resource manager."));
    string key = strings::StrCat(kContainer, stack_name);
    Stack* stack = new Stack(elem_type_, stack_name, size);
    auto* step_container = ctx->step_container();
    OP_REQUIRES(ctx, step_container != nullptr,
                errors::Internal("No step container."));
    OP_REQUIRES_OK(ctx, rm->Create(step_container->name(), key, stack));
    if (IsRefType(ctx->expected_output_dtype(0))) {
      // Create the stack handle.
      AllocatorAttributes alloc_attr;
      alloc_attr.set_on_host(true);
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_STRING,
                                             tensorflow::TensorShape({2}),
                                             &stack->handle_, alloc_attr));
      auto handle = stack->handle_.flat<string>();
      handle(0) = kContainer;
      handle(1) = std::move(stack_name);
      ctx->set_output_ref(0, stack->mu(), &stack->handle_);
    } else {
      Tensor* handle;
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
      handle->flat<ResourceHandle>()(0) =
          MakePerStepResourceHandle<Stack>(ctx, key);
    }
  }

 private:
  DataType elem_type_;
  string stack_name_;

  TF_DISALLOW_COPY_AND_ASSIGN(StackOp);
};

REGISTER_KERNEL_BUILDER(Name("Stack").Device(DEVICE_CPU), StackOp);
REGISTER_KERNEL_BUILDER(Name("Stack").Device(DEVICE_GPU).HostMemory("handle"),
                        StackOp);
@@ -242,102 +59,22 @@ REGISTER_KERNEL_BUILDER(Name("StackV2")
                        StackOp);
#endif  // TENSORFLOW_USE_SYCL

template <typename Device>
class StackPushOp : public AsyncOpKernel {
 public:
  explicit StackPushOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("swap_memory", &swap_memory_));
  }

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    // Get the stack from the handle.
    Stack* stack = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, GetStack(ctx, &stack), done);
    core::ScopedUnref unref(stack);

    if (ctx->input_dtype(1) != stack->ElemType()) {
      ctx->CtxFailure(errors::InvalidArgument("Must have type ",
                                              stack->ElemType(), " but got ",
                                              ctx->input_dtype(1)));
      done();
      return;
    }

    // Push the tensor onto the stack. Swap the tensor to CPU if instructed.
    const Tensor& tensor = ctx->input(1);
    AllocatorAttributes alloc_attrs = ctx->input_alloc_attr(1);
    // For now, we use a simple heuristic for swapping: A GPU tensor is moved
    // to CPU if the tensor has more than kCopyThreshold bytes and the GPU
    // allocator says more than kOccupancy of the memory is in use.
    static constexpr int kCopyThreshold = 2048;
    static constexpr double kOccupancy = 0.7;
    if (swap_memory_ && !alloc_attrs.on_host() &&
        (std::is_same<Device, GPUDevice>::value
#ifdef TENSORFLOW_USE_SYCL
         || std::is_same<Device, SYCLDevice>::value
#endif  // TENSORFLOW_USE_SYCL
         ) &&
        tensor.TotalBytes() > kCopyThreshold && stack->IsUsefulToSwap(tensor)) {
      DeviceContext* device_ctxt = ctx->op_device_context();
      auto device = static_cast<tensorflow::Device*>(ctx->device());
      Allocator* allocator = device->GetAllocator(alloc_attrs);
      AllocatorStats stats;
      allocator->GetStats(&stats);
      if (stats.bytes_in_use > (stats.bytes_limit * kOccupancy)) {
        // Asynchronously copy the tensor from GPU to CPU memory.
        // TODO(yuanbyu): Swap the oldest tensor first.
        AllocatorAttributes host_alloc_attrs;
        host_alloc_attrs.set_gpu_compatible(true);
        host_alloc_attrs.set_on_host(true);
        Allocator* cpu_allocator = device->GetAllocator(host_alloc_attrs);
        Tensor* cpu_tensor =
            new Tensor(cpu_allocator, tensor.dtype(), tensor.shape());
        device_ctxt->CopyDeviceTensorToCPU(
            &tensor, "StackPush", device, cpu_tensor,
            [cpu_tensor, stack, ctx, done](const Status& s) {
              ctx->SetStatus(s);
              if (s.ok()) {
                AllocatorAttributes alloc_attrs = ctx->input_alloc_attr(1);
                ctx->SetStatus(stack->Push({*cpu_tensor, alloc_attrs, true}));
              }
              if (ctx->status().ok()) {
                ctx->set_output(0, *cpu_tensor);
              }
              done();
              delete cpu_tensor;
            });
        return;
      }
    }

    // Execute synchronously if not swapped.
    OP_REQUIRES_OK_ASYNC(ctx, stack->Push({tensor, alloc_attrs, false}), done);
    ctx->set_output(0, tensor);
    done();
  }

  bool IsExpensive() override { return false; }

 private:
  bool swap_memory_;
};

REGISTER_KERNEL_BUILDER(Name("StackPush").Device(DEVICE_CPU),
                        StackPushOp<CPUDevice>);
                        TemplatedStackPushOp</*allow_swapping=*/false>);
REGISTER_KERNEL_BUILDER(Name("StackPushV2").Device(DEVICE_CPU),
                        StackPushOp<CPUDevice>);
                        TemplatedStackPushOp</*allow_swapping=*/false>);

#define REGISTER_GPU_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(Name("StackPush") \
                              .Device(DEVICE_GPU) \
                              .HostMemory("handle") \
                              .TypeConstraint<type>("T"), \
                          StackPushOp<GPUDevice>); \
  REGISTER_KERNEL_BUILDER(Name("StackPushV2") \
                              .Device(DEVICE_GPU) \
                              .HostMemory("handle") \
                              .TypeConstraint<type>("T"), \
                          StackPushOp<GPUDevice>);
#define REGISTER_GPU_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(Name("StackPush") \
                              .Device(DEVICE_GPU) \
                              .HostMemory("handle") \
                              .TypeConstraint<type>("T"), \
                          TemplatedStackPushOp</*allow_swapping=*/true>); \
  REGISTER_KERNEL_BUILDER(Name("StackPushV2") \
                              .Device(DEVICE_GPU) \
                              .HostMemory("handle") \
                              .TypeConstraint<type>("T"), \
                          TemplatedStackPushOp</*allow_swapping=*/true>);

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
@@ -345,21 +82,21 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
// Special GPU kernels for int32 and bool.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(Name("StackPush") \
                              .Device(DEVICE_GPU) \
                              .HostMemory("handle") \
                              .HostMemory("elem") \
                              .HostMemory("output") \
                              .TypeConstraint<type>("T"), \
                          StackPushOp<GPUDevice>); \
  REGISTER_KERNEL_BUILDER(Name("StackPushV2") \
                              .Device(DEVICE_GPU) \
                              .HostMemory("handle") \
                              .HostMemory("elem") \
                              .HostMemory("output") \
                              .TypeConstraint<type>("T"), \
                          StackPushOp<GPUDevice>);
#define REGISTER_GPU_HOST_KERNEL(type) \
  REGISTER_KERNEL_BUILDER(Name("StackPush") \
                              .Device(DEVICE_GPU) \
                              .HostMemory("handle") \
                              .HostMemory("elem") \
                              .HostMemory("output") \
                              .TypeConstraint<type>("T"), \
                          TemplatedStackPushOp</*allow_swapping=*/true>); \
  REGISTER_KERNEL_BUILDER(Name("StackPushV2") \
                              .Device(DEVICE_GPU) \
                              .HostMemory("handle") \
                              .HostMemory("elem") \
                              .HostMemory("output") \
                              .TypeConstraint<type>("T"), \
                          TemplatedStackPushOp</*allow_swapping=*/true>);

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(bool);
@@ -372,7 +109,7 @@ REGISTER_GPU_HOST_KERNEL(bool);
                              .Device(DEVICE_SYCL) \
                              .HostMemory("handle") \
                              .TypeConstraint<type>("T"), \
                          StackPushOp<SYCLDevice>);
                          TemplatedStackPushOp</*allow_swapping=*/true>);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
@@ -383,7 +120,7 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL_KERNEL);
                              .HostMemory("elem") \
                              .HostMemory("output") \
                              .TypeConstraint<type>("T"), \
                          StackPushOp<SYCLDevice>)
                          TemplatedStackPushOp</*allow_swapping=*/true>)

REGISTER_SYCL_HOST_KERNEL(int32);
REGISTER_SYCL_HOST_KERNEL(bool);
@@ -391,48 +128,6 @@ REGISTER_SYCL_HOST_KERNEL(bool);
#undef REGISTER_SYCL_HOST_KERNEL
#endif  // TENSORFLOW_USE_SYCL

class StackPopOp : public AsyncOpKernel {
 public:
  explicit StackPopOp(OpKernelConstruction* context) : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    // Get the stack from the handle.
    Stack* stack = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, GetStack(ctx, &stack), done);
    core::ScopedUnref unref(stack);

    // Pop the tensor. Transfer the tensor back to device if it was
    // swapped out to CPU.
    Stack::TensorAndAllocation value;
    OP_REQUIRES_OK_ASYNC(ctx, stack->Pop(&value), done);
    if (value.swapped_to_cpu) {
      // Asynchronously copy the tensor back from CPU to GPU memory.
      DeviceContext* device_ctxt = ctx->op_device_context();
      Device* device = static_cast<Device*>(ctx->device());
      Tensor* cpu_tensor = &value.tensor;
      Allocator* gpu_allocator = device->GetAllocator(value.alloc_attrs);
      Tensor* device_tensor =
          new Tensor(gpu_allocator, cpu_tensor->dtype(), cpu_tensor->shape());
      device_ctxt->CopyCPUTensorToDevice(
          cpu_tensor, device, device_tensor,
          [device_tensor, ctx, done](const Status& s) {
            ctx->SetStatus(s);
            if (s.ok()) {
              ctx->set_output(0, *device_tensor);
            }
            done();
            delete device_tensor;
          });
    } else {
      // Execute synchronously if not swapped.
      ctx->set_output(0, value.tensor);
      done();
    }
  }

  bool IsExpensive() override { return false; }
};

REGISTER_KERNEL_BUILDER(Name("StackPop").Device(DEVICE_CPU), StackPopOp);
REGISTER_KERNEL_BUILDER(Name("StackPopV2").Device(DEVICE_CPU), StackPopOp);
@@ -498,20 +193,6 @@ REGISTER_SYCL_HOST_KERNEL(bool);
#undef REGISTER_SYCL_HOST_KERNEL
#endif  // TENSORFLOW_USE_SYCL

class StackCloseOp : public OpKernel {
 public:
  explicit StackCloseOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    Stack* stack = nullptr;
    OP_REQUIRES_OK(ctx, GetStack(ctx, &stack));
    core::ScopedUnref unref(stack);
    stack->Close();
  }

  bool IsExpensive() override { return false; }
};

REGISTER_KERNEL_BUILDER(Name("StackClose").Device(DEVICE_CPU), StackCloseOp);
REGISTER_KERNEL_BUILDER(
    Name("StackClose").Device(DEVICE_GPU).HostMemory("handle"), StackCloseOp);