Remove unused tensor-reference-recording feature from the executor.

This change also removes the `Device::RequiresRecordingAccessedTensors()` and `Device::ConsumeListOfAccessedTensors()` methods.

Some device objects (historically, GPUs with experimental multi-stream support) required the ability to record which tensors were used during kernel execution. This support has bit-rotted since it was introduced, and it imposes runtime overhead on the majority of devices, which do not use the feature.
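
For reference, the two removed hooks looked roughly like the condensed sketch below (assembled from the deleted declarations in the diffs that follow; simplified, not the full class definitions):

    // Condensed sketch of the removed API surface, taken from the
    // deleted declarations in the diffs below (simplified).
    class DeviceBase {
     public:
      // Devices that needed tensor-access recording returned true here;
      // the default was false.
      virtual bool RequiresRecordingAccessedTensors() const { return false; }
    };

    class Device : public DeviceBase {
     public:
      // Took ownership of the recorded references; the default
      // implementation simply released them.
      virtual void ConsumeListOfAccessedTensors(
          DeviceContext* context, const TensorReferenceVector& tensors) {
        for (const auto& ref : tensors) {
          ref.Unref();
        }
      }
    };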

PiperOrigin-RevId: 300443774
Change-Id: Ia44ff65dee57f4d9f971f0079f79edd2fde2a1dc
Author: Derek Murray, 2020-03-11 17:36:43 -07:00 (committed by TensorFlower Gardener)
Parent: aedb53e371
Commit: 06e20a2fe2
13 changed files with 13 additions and 240 deletions

View File

@@ -27,14 +27,10 @@ namespace {
 class DummyDevice : public DeviceBase {
  public:
-  DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
-  bool RequiresRecordingAccessedTensors() const override { return save_; }
+  explicit DummyDevice(Env* env) : DeviceBase(env) {}
   Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
     return cpu_allocator();
   }
-
- private:
-  bool save_;
 };

 void TestBitcastOp(Tensor* input_tensor, DataType out_type,
@@ -61,7 +57,7 @@ void TestBitcastOp(Tensor* input_tensor, DataType out_type,
   ASSERT_TRUE(status.ok()) << status.ToString();

   OpKernelContext::Params params;
-  DummyDevice dummy_device(nullptr, false);
+  DummyDevice dummy_device(nullptr);
   params.device = &dummy_device;
   params.op_kernel = kernel.get();
   gtl::InlinedVector<TensorValue, 4> inputs;

View File

@@ -155,14 +155,10 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
 class DummyDevice : public DeviceBase {
  public:
-  DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
-  bool RequiresRecordingAccessedTensors() const override { return save_; }
+  explicit DummyDevice(Env* env) : DeviceBase(env) {}
   Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
     return cpu_allocator();
   }
-
- private:
-  bool save_;
 };

 TEST(TestKernel, TestInputAndOutputCount) {
@@ -223,7 +219,7 @@ TEST(TestKernel, TestInputAndOutputCount) {
   {
     OpKernelContext::Params p;
-    DummyDevice dummy_device(nullptr, false);
+    DummyDevice dummy_device(nullptr);
     p.device = &dummy_device;
     p.step_id = 43;

View File

@@ -92,16 +92,6 @@ class Device : public DeviceBase {
     op_kernel->ComputeAsync(context, std::move(done));
   }

-  // Takes ownership of the references in tensors. If necessary, a
-  // device may override this method to keep a reference to the
-  // accessed tensors until the async computation has completed.
-  virtual void ConsumeListOfAccessedTensors(
-      DeviceContext* context, const TensorReferenceVector& tensors) {
-    for (const auto& ref : tensors) {
-      ref.Unref();
-    }
-  }
-
   // Blocks until all operations queued on the device at the time of
   // the call have completed. Returns any error pending on the device
   // at completion.

View File

@@ -121,12 +121,6 @@ void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
   stats->SetMemory(ctx);
 }

-void SetReferencedTensors(NodeExecStatsInterface* stats,
-                          const TensorReferenceVector& tensors) {
-  if (!stats) return;
-  stats->SetReferencedTensors(tensors);
-}
-
 }  // namespace nodestats

 class ExecutorImpl;
@@ -403,9 +397,6 @@ class ExecutorImpl : public Executor {
   LocalExecutorParams params_;
   GraphView gview_;

-  // A cached value of params_
-  bool device_record_tensor_accesses_ = false;
-
   // Root nodes (with no in edges) that should form the initial ready queue
   std::vector<const NodeItem*> root_nodes_;
@@ -639,11 +630,6 @@ Status ExecutorImpl::Initialize(const Graph& graph) {
   ControlFlowInfo cf_info;
   TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph, &cf_info));

-  // Cache this value so we make this virtual function call once, rather
-  // that O(# steps * # nodes per step) times.
-  device_record_tensor_accesses_ =
-      params_.device->RequiresRecordingAccessedTensors();
-
   for (auto& it : cf_info.unique_frame_names) {
     EnsureFrameInfo(it)->nodes = new std::vector<const NodeItem*>;
   }
@@ -1444,8 +1430,6 @@ class ExecutorState {
   Status ProcessSync(const NodeItem& item, OpKernelContext::Params* params,
                      EntryVector* outputs,
-                     TensorReferenceVector* accessed_tensors,
-                     DeviceContext** device_context,
                      NodeExecStatsInterface* stats);
   void ProcessAsync(const NodeItem& item, const OpKernelContext::Params& params,
                     const TaggedNode& tagged_node, Entry* first_input,
@@ -1751,8 +1735,6 @@ bool MightTrace(const NodeItem& item,
 Status ExecutorState::ProcessSync(const NodeItem& item,
                                   OpKernelContext::Params* params,
                                   EntryVector* outputs,
-                                  TensorReferenceVector* accessed_tensors,
-                                  DeviceContext** device_context,
                                   NodeExecStatsInterface* stats) {
   Status s;
   OpKernelContext ctx(params, item.num_outputs);
@@ -1785,11 +1767,6 @@ Status ExecutorState::ProcessSync(const NodeItem& item,
     nodestats::SetOpEnd(stats);
     s = ProcessOutputs(item, &ctx, outputs, stats);
   }
-  if (TF_PREDICT_FALSE(impl_->device_record_tensor_accesses_) && s.ok()) {
-    // Get the list of all tensors accessed during the execution
-    ctx.retrieve_accessed_tensors(accessed_tensors);
-    *device_context = ctx.op_device_context();
-  }
   nodestats::SetMemory(stats, &ctx);
   return s;
 }
@@ -1833,15 +1810,6 @@ void ExecutorState::ProcessAsync(const NodeItem& item,
       PropagateOutputs(state->tagged_node, state->item, &outputs, &ready);
     }
     outputs.clear();
-    if (TF_PREDICT_FALSE(impl_->device_record_tensor_accesses_) && s.ok()) {
-      // Get the list of all tensors accessed during the execution
-      TensorReferenceVector accessed;
-      state->ctx.retrieve_accessed_tensors(&accessed);
-      nodestats::SetReferencedTensors(stats, accessed);
-      // callee takes ownership of the vector
-      device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
-                                           accessed);
-    }
     const bool completed = NodeDone(s, &ready, stats, nullptr);
     delete state;
     if (completed) ScheduleFinish();
@@ -1905,7 +1873,6 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
     params.device = device;
   }
   params.log_memory = log_memory_;
-  params.record_tensor_accesses = impl_->device_record_tensor_accesses_;
   params.rendezvous = rendezvous_;
   params.collective_executor = collective_executor_;
   params.session_state = session_state_;
@@ -1988,7 +1955,6 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
     Entry* first_input = input_tensors + item.input_start;
     outputs.clear();

-    TensorReferenceVector accessed_tensors;
     // Only execute this node if it is not dead or it is a send/recv
     // transfer node. For transfer nodes, we need to propagate the "dead"
     // bit even when the node is dead.
@@ -2028,15 +1994,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
        ProcessAsync(item, params, tagged_node, first_input, stats);
        launched_asynchronously = true;
      } else {
-       DeviceContext* device_context = nullptr;
-       s = ProcessSync(item, &params, &outputs, &accessed_tensors,
-                       &device_context, stats);
-       if (!accessed_tensors.empty()) {
-         nodestats::SetReferencedTensors(stats, accessed_tensors);
-         // device_context is set above in `ProcessSync()`.
-         device->ConsumeListOfAccessedTensors(device_context,
-                                              accessed_tensors);
-       }
+       s = ProcessSync(item, &params, &outputs, stats);
      }
    }

View File

@@ -462,13 +462,6 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
   return Status::OK();
 }

-bool BaseGPUDevice::RequiresRecordingAccessedTensors() const {
-  // Since there is only one stream, we release the tensor reference
-  // at the end of the kernel launch, instead of at the end of the kernel
-  // execution.
-  return false;
-}
-
 string BaseGPUDevice::ComputeOpKernelDebugString(const OpKernel& op_kernel,
                                                  const int& stream_id) {
   return strings::StrCat(op_kernel.name(), " op ", op_kernel.type_string(),
@@ -541,16 +534,6 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
   }
 }

-void BaseGPUDevice::ConsumeListOfAccessedTensors(
-    DeviceContext* device_context, const TensorReferenceVector& tensor_refs) {
-  GPUDeviceContext* gpu_device_context = device_context_;
-  if (device_context != nullptr) {
-    gpu_device_context = static_cast<GPUDeviceContext*>(device_context);
-  }
-  se::Stream* stream = gpu_device_context->stream();
-  em_->ThenDeleteTensors(stream, tensor_refs);
-}
-
 // Based on the semantics of Device::Sync this call should wait for
 // all streams not just the current one.
 Status BaseGPUDevice::Sync() { return GPUUtil::SyncAll(this); }

View File

@@ -62,15 +62,6 @@ class BaseGPUDevice : public LocalDevice {
   // Initialize the device and return the status of initialization.
   Status Init(const SessionOptions& options);

-  // GPU devices require the Op Compute method to save a reference to
-  // any temporary tensors that are allocated until the Op execution
-  // completes.
-  bool RequiresRecordingAccessedTensors() const override;
-
-  void ConsumeListOfAccessedTensors(
-      DeviceContext* device_context,
-      const TensorReferenceVector& tensor_refs) override;
-
   void Compute(OpKernel* op_kernel, OpKernelContext* context) override;

   Status Sync() override;

View File

@@ -426,7 +426,6 @@ class EMBenchmarkHelper {
     params->step_id = 1;
     params->device = gpu_helper_->gpu();
     params->log_memory = false;
-    params->record_tensor_accesses = false;
     params->rendezvous = nullptr;
     params->collective_executor = nullptr;
     params->session_state = nullptr;  // ???

View File

@@ -36,11 +36,6 @@ class RenamedDevice : public Device {
   ~RenamedDevice() override;

-  // Below are virtual methods defined on DeviceBase
-  bool RequiresRecordingAccessedTensors() const override {
-    return underlying_device_->RequiresRecordingAccessedTensors();
-  }
-
   const DeviceBase* UnderlyingDevice() const override {
     return underlying_device_->UnderlyingDevice();
   }
@@ -138,11 +133,6 @@ class RenamedDevice : public Device {
     underlying_device_->ComputeAsync(op_kernel, context, std::move(done));
   }

-  void ConsumeListOfAccessedTensors(
-      DeviceContext* context, const TensorReferenceVector& tensors) override {
-    underlying_device_->ConsumeListOfAccessedTensors(context, tensors);
-  }
-
   Status Sync() override { return underlying_device_->Sync(); }

   Status MaybeRewriteGraph(std::unique_ptr<Graph>* graph) override {

View File

@@ -120,11 +120,6 @@ class DeviceBase {
   Env* env() const { return env_; }

-  // Override this to return true for devices that require an Op's
-  // compute method to save references to the temporary tensors it
-  // allocates until the Op execution completes
-  virtual bool RequiresRecordingAccessedTensors() const { return false; }
-
   struct CpuWorkerThreads {
     int num_threads = 0;
     thread::ThreadPool* workers = nullptr;

View File

@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/node_properties.h"
 #include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/framework/tensor_reference.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/notification.h"
@@ -222,7 +223,6 @@ Tensor* PersistentTensor::AccessTensor(OpKernelConstruction* context) {
 }

 Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) {
-  context->NotifyUseOfPersistentTensor(tensor_);
   return &tensor_;
 }
@@ -331,7 +331,7 @@ OpKernelContext::OpKernelContext(Params* params)
 OpKernelContext::OpKernelContext(Params* params, int num_outputs)
     : params_(params), outputs_(num_outputs) {
-  if (params_->record_tensor_accesses || params_->track_allocations) {
+  if (params_->track_allocations) {
     tracking_state_ = absl::make_unique<TrackingState>();
   }
@@ -393,13 +393,6 @@ void OpKernelContext::SetStatus(const Status& status) {
   status_.Update(status);
 }

-void OpKernelContext::really_record_tensor_reference(const Tensor& tensor) {
-  DCHECK(tracking_state_);
-  mutex_lock l(tracking_state_->mu);
-  // Keep a reference to the underlying memory around.
-  tracking_state_->referenced_tensors.Add(tensor);
-}
-
 Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
   int start, stop;
   TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
@@ -414,7 +407,6 @@ Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
                                    "' when non-ref input was expected");
   }
   *tensor = (*params_->inputs)[start].tensor;
-  record_tensor_reference(**tensor);
   return Status::OK();
 }
@@ -449,7 +441,6 @@ const Tensor& OpKernelContext::input(int index) {
   CHECK_LT(index, num_inputs()) << " name: " << op_kernel().name();
   CHECK(!input_is_ref(index));
   const Tensor& tensor = *((*params_->inputs)[index].tensor);
-  record_tensor_reference(tensor);
   return tensor;
 }
@@ -460,12 +451,10 @@ Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
   // return a copy of the Ref acquired while holding the mutex
   if (lock_held) {
     Tensor& tensor = *((*params_->inputs)[index].tensor);
-    record_tensor_reference(tensor);
     return tensor;
   } else {
     tf_shared_lock l(*input_ref_mutex(index));
     Tensor& tensor = *((*params_->inputs)[index].tensor);
-    record_tensor_reference(tensor);
     return tensor;
   }
 }
@@ -482,7 +471,6 @@ void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
     mutex_lock l(*input_ref_mutex(index));
     *(*params_->inputs)[index].tensor = tensor;
   }
-  record_tensor_reference(tensor);
 }

 void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
@@ -658,7 +646,6 @@ Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
     tf_shared_lock l(*input_ref_mutex(start));
     *tensor = *(*params_->inputs)[start].tensor;
   }
-  record_tensor_reference(*tensor);
   return Status::OK();
 }
@@ -781,7 +768,6 @@ Status OpKernelContext::allocate_tensor(
     LogMemory::RecordTensorAllocation(params_->op_kernel->name(),
                                       params_->step_id, new_tensor);
   }
-  record_tensor_reference(new_tensor);
   *out_tensor = std::move(new_tensor);
   return Status::OK();
 }
@@ -969,7 +955,6 @@ void OpKernelContext::set_output(int index, const Tensor& tensor) {
   } else {
     // Input can be forwarded to output; incref on `tensor` and set output at
     // `index` to this tensor.
-    record_tensor_reference(tensor);
     outputs_[index] = TensorValue(new Tensor(tensor));
     if (track_allocations() && tensor.TotalBytes() > 0) {
       DCHECK(tracking_state_);
@@ -994,7 +979,6 @@ void OpKernelContext::set_output_ref(int index, mutex* mu,
   CHECK_GE(index, 0);
   CHECK_LT(index, outputs_.size());
   CHECK(IsRefType(params_->op_kernel->output_type(index)));
-  record_tensor_reference(*tensor_for_ref);
   outputs_[index] = TensorValue(mu, tensor_for_ref);
 }

View File

@@ -673,7 +673,6 @@ class OpKernelContext {
     bool track_allocations = false;
     bool log_memory = false;
-    bool record_tensor_accesses = false;

     // Array indexed by output number for this node
     const AllocatorAttributes* output_attr_array = nullptr;
@@ -1221,11 +1220,6 @@ class OpKernelContext {
   // TODO(tucker): Add example usage.
   DeviceBase* device() const { return params_->device; }

-  // Retrieve list of referenced tensors in out_vector. Once this is
-  // called, it is not legal to reference any more tensors. Should
-  // not be called from Op kernels.
-  void retrieve_accessed_tensors(TensorReferenceVector* out_vector);
-
   // Per-step container for use by white-listed internal ops.
   ScopedStepContainer* step_container() const {
     return params_->step_container;
@@ -1308,13 +1302,6 @@ class OpKernelContext {
  private:
   bool record_memory_consumption_ = false;

-  // Internal method to add a tensor's buffer to the list of buffers
-  // referenced during the execution of the Op, so that GPUs may
-  // accurately track the memory that may not be reused until the Op
-  // execution completes.
-  void record_tensor_reference(const Tensor& tensor);
-  void really_record_tensor_reference(const Tensor& tensor);
-
   // Internal common method used when allocating tensor memory
   Status allocate_tensor(DataType type, const TensorShape& shape,
                          Tensor* out_tensor,
@@ -1331,14 +1318,6 @@ class OpKernelContext {
   // called.
   void maybe_initialize_scope_id_set();

-  // This is called by PersistentTensor::AccessTensor whenever the
-  // wrapped tensor is retrieved, to ensure the runtime knows that the
-  // Tensor is being accessed within an Op. This is necessary for
-  // memory safety of devices like GPUs that queue Ops for
-  // asynchronous execution after the Compute() method completes.
-  friend class PersistentTensor;
-  void NotifyUseOfPersistentTensor(const Tensor& tensor);
-
   Status status_;
   friend class CollectiveExecutor;  // for access to params_
   Params* params_;                  // not owned
@@ -1356,8 +1335,6 @@ class OpKernelContext {
     gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators
         TF_GUARDED_BY(mu);

-    UniqueTensorReferences referenced_tensors TF_GUARDED_BY(mu);
-
     mutable mutex stats_mu;
     int64 temp_memory_allocated TF_GUARDED_BY(stats_mu) = 0;
@@ -1658,23 +1635,6 @@ inline bool OpKernelContext::input_is_ref(int index) const {
   return value.is_ref();
 }

-inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) {
-  DCHECK_EQ(params_->device->RequiresRecordingAccessedTensors(),
-            params_->record_tensor_accesses);
-  if (params_->record_tensor_accesses) {
-    really_record_tensor_reference(tensor);
-  }
-}
-
-inline void OpKernelContext::retrieve_accessed_tensors(
-    TensorReferenceVector* out_vector) {
-  if (params_->record_tensor_accesses) {
-    DCHECK(tracking_state_);
-    mutex_lock l(tracking_state_->mu);
-    tracking_state_->referenced_tensors.FreezeAndReturnReferences(out_vector);
-  }
-}
-
 // no input if tensor == nullptr.
 inline bool OpKernelContext::has_input(int index) const {
   DCHECK_GE(index, 0);
@@ -1689,17 +1649,9 @@ inline mutex* OpKernelContext::input_ref_mutex(int index) {
   return (*params_->inputs)[index].mutex_if_ref;
 }

-inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) {
-  if (t.IsInitialized()) {
-    record_tensor_reference(t);
-  }
-}
-
 inline Tensor* OpKernelContext::mutable_output(int index) {
   DCHECK_GE(index, 0);
   DCHECK_LT(index, num_outputs());
-  // No need to record_tensor_reference since the output must already
-  // have been set by a call that did so.
   return outputs_[index].tensor;
 }

View File

@@ -348,74 +348,17 @@ TEST_F(OpKernelTest, MatchSignatureFailes) {
 class DummyDevice : public DeviceBase {
  public:
-  DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
-  bool RequiresRecordingAccessedTensors() const override { return save_; }
+  explicit DummyDevice(Env* env) : DeviceBase(env) {}
   Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
     return cpu_allocator();
   }
-
- private:
-  bool save_;
 };

-TEST_F(OpKernelTest, SaveTempFalse) {
-  Env* env = Env::Default();
-  OpKernelContext::Params params;
-  params.record_tensor_accesses = false;
-  auto device =
-      absl::make_unique<DummyDevice>(env, params.record_tensor_accesses);
-  params.device = device.get();
-  Status status;
-  std::unique_ptr<OpKernel> op(
-      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
-                     CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
-                     TF_GRAPH_DEF_VERSION, &status));
-  EXPECT_TRUE(status.ok());
-  params.op_kernel = op.get();
-  auto ctx = absl::make_unique<OpKernelContext>(&params);
-  Tensor t;
-  TF_EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t));
-  TensorReferenceVector referenced_tensors;
-  ctx->retrieve_accessed_tensors(&referenced_tensors);
-  EXPECT_EQ(0, referenced_tensors.size());
-}
-
-TEST_F(OpKernelTest, SaveTempTrue) {
-  Env* env = Env::Default();
-  OpKernelContext::Params params;
-  params.record_tensor_accesses = true;
-  auto device =
-      absl::make_unique<DummyDevice>(env, params.record_tensor_accesses);
-  params.device = device.get();
-  Status status;
-  std::unique_ptr<OpKernel> op(
-      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
-                     CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
-                     TF_GRAPH_DEF_VERSION, &status));
-  EXPECT_TRUE(status.ok());
-  params.op_kernel = op.get();
-  auto ctx = absl::make_unique<OpKernelContext>(&params);
-  Tensor t;
-  TF_EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t));
-  TensorReferenceVector referenced_tensors;
-  ctx->retrieve_accessed_tensors(&referenced_tensors);
-  EXPECT_EQ(1, referenced_tensors.size());
-  for (auto& ref : referenced_tensors) {
-    ref.Unref();
-  }
-}
-
 TEST_F(OpKernelTest, InputDtype) {
   Env* env = Env::Default();
   OpKernelContext::Params params;
-  params.record_tensor_accesses = false;
-  auto device =
-      absl::make_unique<DummyDevice>(env, params.record_tensor_accesses);
-  params.device = device.get();
+  DummyDevice device(env);
+  params.device = &device;
   Status status;
   std::unique_ptr<OpKernel> op(
       CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
@@ -499,7 +442,6 @@ class ScopedAllocatorDevice : public DeviceBase {
 TEST_F(OpKernelTest, ScopedAllocationTest) {
   Env* env = Env::Default();
   OpKernelContext::Params params;
-  params.record_tensor_accesses = false;
   auto sa_device = absl::make_unique<ScopedAllocatorDevice>(env);
   params.device = sa_device.get();
   Status status;
@@ -788,10 +730,8 @@ REGISTER_KERNEL_BUILDER(Name("ListOut").Device(tensorflow::DEVICE_CPU),
 TEST_F(OpKernelBuilderTest, OpOutputList) {
   Env* env = Env::Default();
   OpKernelContext::Params params;
-  params.record_tensor_accesses = false;
-  auto device =
-      absl::make_unique<DummyDevice>(env, params.record_tensor_accesses);
-  params.device = device.get();
+  DummyDevice device(env);
+  params.device = &device;
   Status status;
   std::unique_ptr<OpKernel> op(CreateOpKernel(
       DEVICE_CPU, params.device, cpu_allocator(),
@@ -1066,7 +1006,7 @@ void BM_InputRangeHelper(int iters, const NodeDef& node_def,
                          const char* input_name, int expected_start,
                          int expected_stop) {
   Status status;
-  auto device = absl::make_unique<DummyDevice>(Env::Default(), false);
+  auto device = absl::make_unique<DummyDevice>(Env::Default());
   std::unique_ptr<OpKernel> op(CreateOpKernel(DEVICE_CPU, device.get(),
                                               cpu_allocator(), node_def,

View File

@@ -243,7 +243,6 @@ class SingleThreadedExecutorImpl : public Executor {
     Device* device = params_.device;
     params.device = device;
     params.log_memory = false;              // TODO(mrry): Too severe?
-    params.record_tensor_accesses = false;  // TODO(mrry): Too severe?
     params.rendezvous = args.rendezvous;
     params.session_state = args.session_state;
     params.tensor_store = args.tensor_store;