From 06e20a2fe22e28e00a0868f35bbbf10dfa345659 Mon Sep 17 00:00:00 2001
From: Derek Murray
Date: Wed, 11 Mar 2020 17:36:43 -0700
Subject: [PATCH] Remove unused tensor-reference-recording feature from the
 executor.

This change also removes the `Device::RequiresRecordingAccessedTensors()`
and `Device::ConsumeListOfAccessedTensors()` methods.

Some device objects (historically, GPUs with experimental multi-stream
support) required the ability to record which tensors were used during
kernel execution. This support has bit-rotted since it was introduced, and
causes runtime overhead for most devices that do not use the feature.

PiperOrigin-RevId: 300443774
Change-Id: Ia44ff65dee57f4d9f971f0079f79edd2fde2a1dc
---
 tensorflow/c/kernels/bitcast_op_test.cc       |  8 +--
 tensorflow/c/kernels_test.cc                  |  8 +--
 tensorflow/core/common_runtime/device.h       | 10 ---
 tensorflow/core/common_runtime/executor.cc    | 44 +-----
 .../core/common_runtime/gpu/gpu_device.cc     | 17 -----
 .../core/common_runtime/gpu/gpu_device.h      |  9 ---
 .../common_runtime/gpu/gpu_event_mgr_test.cc  |  1 -
 .../core/common_runtime/renamed_device.h      | 10 ---
 tensorflow/core/framework/device_base.h       |  5 --
 tensorflow/core/framework/op_kernel.cc        | 20 +-----
 tensorflow/core/framework/op_kernel.h         | 48 ------
 tensorflow/core/framework/op_kernel_test.cc   | 72 ++-----
 .../kernels/data/single_threaded_executor.cc  |  1 -
 13 files changed, 13 insertions(+), 240 deletions(-)

diff --git a/tensorflow/c/kernels/bitcast_op_test.cc b/tensorflow/c/kernels/bitcast_op_test.cc
index 7da27e99d1f..33028ea6bd9 100644
--- a/tensorflow/c/kernels/bitcast_op_test.cc
+++ b/tensorflow/c/kernels/bitcast_op_test.cc
@@ -27,14 +27,10 @@ namespace {
 
 class DummyDevice : public DeviceBase {
  public:
-  DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
-  bool RequiresRecordingAccessedTensors() const override { return save_; }
+  explicit DummyDevice(Env* env) : DeviceBase(env) {}
   Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
     return cpu_allocator();
   }
-
- private:
-  bool save_;
 };
 
 void TestBitcastOp(Tensor* input_tensor, DataType out_type,
@@ -61,7 +57,7 @@ void TestBitcastOp(Tensor* input_tensor, DataType out_type,
   ASSERT_TRUE(status.ok()) << status.ToString();
 
   OpKernelContext::Params params;
-  DummyDevice dummy_device(nullptr, false);
+  DummyDevice dummy_device(nullptr);
   params.device = &dummy_device;
   params.op_kernel = kernel.get();
   gtl::InlinedVector<TensorValue, 4> inputs;
diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc
index 80e90e7cdf9..423302741de 100644
--- a/tensorflow/c/kernels_test.cc
+++ b/tensorflow/c/kernels_test.cc
@@ -155,14 +155,10 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
 
 class DummyDevice : public DeviceBase {
  public:
-  DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
-  bool RequiresRecordingAccessedTensors() const override { return save_; }
+  explicit DummyDevice(Env* env) : DeviceBase(env) {}
   Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
     return cpu_allocator();
   }
-
- private:
-  bool save_;
 };
 
 TEST(TestKernel, TestInputAndOutputCount) {
@@ -223,7 +219,7 @@ TEST(TestKernel, TestInputAndOutputCount) {
 
   {
     OpKernelContext::Params p;
-    DummyDevice dummy_device(nullptr, false);
+    DummyDevice dummy_device(nullptr);
     p.device = &dummy_device;
 
     p.step_id = 43;
diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h
index 9b7dadadc3a..13877933ce6 100644
--- a/tensorflow/core/common_runtime/device.h
+++ b/tensorflow/core/common_runtime/device.h
@@ -92,16 +92,6 @@ class Device : public DeviceBase {
     op_kernel->ComputeAsync(context, std::move(done));
   }
 
-  // Takes ownership of the references in tensors. If necessary, a
-  // device may override this method to keep a reference to the
-  // accessed tensors until the async computation has completed.
-  virtual void ConsumeListOfAccessedTensors(
-      DeviceContext* context, const TensorReferenceVector& tensors) {
-    for (const auto& ref : tensors) {
-      ref.Unref();
-    }
-  }
-
   // Blocks until all operations queued on the device at the time of
   // the call have completed. Returns any error pending on the device
   // at completion.
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 995c88cf9a2..e6aa8fc3ed2 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -121,12 +121,6 @@ void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
   stats->SetMemory(ctx);
 }
 
-void SetReferencedTensors(NodeExecStatsInterface* stats,
-                          const TensorReferenceVector& tensors) {
-  if (!stats) return;
-  stats->SetReferencedTensors(tensors);
-}
-
 }  // namespace nodestats
 
 class ExecutorImpl;
@@ -403,9 +397,6 @@ class ExecutorImpl : public Executor {
   LocalExecutorParams params_;
   GraphView gview_;
 
-  // A cached value of params_
-  bool device_record_tensor_accesses_ = false;
-
   // Root nodes (with no in edges) that should form the initial ready queue
   std::vector<const NodeItem*> root_nodes_;
 
@@ -639,11 +630,6 @@ Status ExecutorImpl::Initialize(const Graph& graph) {
   ControlFlowInfo cf_info;
   TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph, &cf_info));
 
-  // Cache this value so we make this virtual function call once, rather
-  // that O(# steps * # nodes per step) times.
-  device_record_tensor_accesses_ =
-      params_.device->RequiresRecordingAccessedTensors();
-
   for (auto& it : cf_info.unique_frame_names) {
     EnsureFrameInfo(it)->nodes = new std::vector<const Node*>;
   }
@@ -1444,8 +1430,6 @@ class ExecutorState {
   Status ProcessSync(const NodeItem& item, OpKernelContext::Params* params,
                      EntryVector* outputs,
-                     TensorReferenceVector* accessed_tensors,
-                     DeviceContext** device_context,
                      NodeExecStatsInterface* stats);
   void ProcessAsync(const NodeItem& item,
                     const OpKernelContext::Params& params,
                     const TaggedNode& tagged_node, Entry* first_input,
@@ -1751,8 +1735,6 @@ bool MightTrace(const NodeItem& item,
 Status ExecutorState::ProcessSync(const NodeItem& item,
                                   OpKernelContext::Params* params,
                                   EntryVector* outputs,
-                                  TensorReferenceVector* accessed_tensors,
-                                  DeviceContext** device_context,
                                   NodeExecStatsInterface* stats) {
   Status s;
   OpKernelContext ctx(params, item.num_outputs);
@@ -1785,11 +1767,6 @@ Status ExecutorState::ProcessSync(const NodeItem& item,
     nodestats::SetOpEnd(stats);
     s = ProcessOutputs(item, &ctx, outputs, stats);
   }
-  if (TF_PREDICT_FALSE(impl_->device_record_tensor_accesses_) && s.ok()) {
-    // Get the list of all tensors accessed during the execution
-    ctx.retrieve_accessed_tensors(accessed_tensors);
-    *device_context = ctx.op_device_context();
-  }
   nodestats::SetMemory(stats, &ctx);
   return s;
 }
@@ -1833,15 +1810,6 @@ void ExecutorState::ProcessAsync(const NodeItem& item,
       PropagateOutputs(state->tagged_node, state->item, &outputs, &ready);
     }
     outputs.clear();
-    if (TF_PREDICT_FALSE(impl_->device_record_tensor_accesses_) && s.ok()) {
-      // Get the list of all tensors accessed during the execution
-      TensorReferenceVector accessed;
-      state->ctx.retrieve_accessed_tensors(&accessed);
-      nodestats::SetReferencedTensors(stats, accessed);
-      // callee takes ownership of the vector
-      device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
-                                           accessed);
-    }
     const bool completed = NodeDone(s, &ready, stats, nullptr);
     delete state;
     if (completed) ScheduleFinish();
@@ -1905,7 +1873,6 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
     params.device = device;
   }
   params.log_memory = log_memory_;
-  params.record_tensor_accesses = impl_->device_record_tensor_accesses_;
   params.rendezvous = rendezvous_;
   params.collective_executor = collective_executor_;
   params.session_state = session_state_;
@@ -1988,7 +1955,6 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
       Entry* first_input = input_tensors + item.input_start;
       outputs.clear();
-      TensorReferenceVector accessed_tensors;
 
       // Only execute this node if it is not dead or it is a send/recv
       // transfer node. For transfer nodes, we need to propagate the "dead"
       // bit even when the node is dead.
@@ -2028,15 +1994,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
           ProcessAsync(item, params, tagged_node, first_input, stats);
           launched_asynchronously = true;
         } else {
-          DeviceContext* device_context = nullptr;
-          s = ProcessSync(item, &params, &outputs, &accessed_tensors,
-                          &device_context, stats);
-          if (!accessed_tensors.empty()) {
-            nodestats::SetReferencedTensors(stats, accessed_tensors);
-            // device_context is set above in `ProcessSync()`.
-            device->ConsumeListOfAccessedTensors(device_context,
-                                                 accessed_tensors);
-          }
+          s = ProcessSync(item, &params, &outputs, stats);
         }
       }
 
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index d72a99f3ca7..dcc40c3d3de 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -462,13 +462,6 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
   return Status::OK();
 }
 
-bool BaseGPUDevice::RequiresRecordingAccessedTensors() const {
-  // Since there is only one stream, we release the tensor reference
-  // at the end of the kernel launch, instead of at the end of the kernel
-  // execution.
-  return false;
-}
-
 string BaseGPUDevice::ComputeOpKernelDebugString(const OpKernel& op_kernel,
                                                  const int& stream_id) {
   return strings::StrCat(op_kernel.name(), " op ", op_kernel.type_string(),
@@ -541,16 +534,6 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
   }
 }
 
-void BaseGPUDevice::ConsumeListOfAccessedTensors(
-    DeviceContext* device_context, const TensorReferenceVector& tensor_refs) {
-  GPUDeviceContext* gpu_device_context = device_context_;
-  if (device_context != nullptr) {
-    gpu_device_context = static_cast<GPUDeviceContext*>(device_context);
-  }
-  se::Stream* stream = gpu_device_context->stream();
-  em_->ThenDeleteTensors(stream, tensor_refs);
-}
-
 // Based on the semantics of Device::Sync this call should wait for
 // all streams not just the current one.
 Status BaseGPUDevice::Sync() { return GPUUtil::SyncAll(this); }
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index eec2c099279..3646c59cec1 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -62,15 +62,6 @@ class BaseGPUDevice : public LocalDevice {
   // Initialize the device and return the status of initialization.
   Status Init(const SessionOptions& options);
 
-  // GPU devices require the Op Compute method to save a reference to
-  // any temporary tensors that are allocated until the Op execution
-  // completes.
-  bool RequiresRecordingAccessedTensors() const override;
-
-  void ConsumeListOfAccessedTensors(
-      DeviceContext* device_context,
-      const TensorReferenceVector& tensor_refs) override;
-
   void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
 
   Status Sync() override;
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
index 680aec1ab29..7d67a3a34fc 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
@@ -426,7 +426,6 @@ class EMBenchmarkHelper {
     params->step_id = 1;
     params->device = gpu_helper_->gpu();
     params->log_memory = false;
-    params->record_tensor_accesses = false;
    params->rendezvous = nullptr;
    params->collective_executor = nullptr;
    params->session_state = nullptr;  // ???
diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h
index c2ddd3a8dd1..cbec750e86c 100644
--- a/tensorflow/core/common_runtime/renamed_device.h
+++ b/tensorflow/core/common_runtime/renamed_device.h
@@ -36,11 +36,6 @@ class RenamedDevice : public Device {
 
   ~RenamedDevice() override;
 
-  // Below are virtual methods defined on DeviceBase
-  bool RequiresRecordingAccessedTensors() const override {
-    return underlying_device_->RequiresRecordingAccessedTensors();
-  }
-
   const DeviceBase* UnderlyingDevice() const override {
     return underlying_device_->UnderlyingDevice();
   }
@@ -138,11 +133,6 @@ class RenamedDevice : public Device {
     underlying_device_->ComputeAsync(op_kernel, context, std::move(done));
   }
 
-  void ConsumeListOfAccessedTensors(
-      DeviceContext* context, const TensorReferenceVector& tensors) override {
-    underlying_device_->ConsumeListOfAccessedTensors(context, tensors);
-  }
-
   Status Sync() override { return underlying_device_->Sync(); }
 
   Status MaybeRewriteGraph(std::unique_ptr<Graph>* graph) override {
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index b8890dd069b..eba64a6b41e 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -120,11 +120,6 @@ class DeviceBase {
 
   Env* env() const { return env_; }
 
-  // Override this to return true for devices that require an Op's
-  // compute method to save references to the temporary tensors it
-  // allocates until the Op execution completes
-  virtual bool RequiresRecordingAccessedTensors() const { return false; }
-
   struct CpuWorkerThreads {
     int num_threads = 0;
     thread::ThreadPool* workers = nullptr;
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index cc1bb8b9c77..15d55cc19d0 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/node_properties.h" #include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/framework/tensor_reference.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/notification.h" @@ -222,7 +223,6 @@ Tensor* PersistentTensor::AccessTensor(OpKernelConstruction* context) { } Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) { - context->NotifyUseOfPersistentTensor(tensor_); return &tensor_; } @@ -331,7 +331,7 @@ OpKernelContext::OpKernelContext(Params* params) OpKernelContext::OpKernelContext(Params* params, int num_outputs) : params_(params), outputs_(num_outputs) { - if (params_->record_tensor_accesses || params_->track_allocations) { + if (params_->track_allocations) { tracking_state_ = absl::make_unique(); } @@ -393,13 +393,6 @@ void OpKernelContext::SetStatus(const Status& status) { status_.Update(status); } -void OpKernelContext::really_record_tensor_reference(const Tensor& tensor) { - DCHECK(tracking_state_); - mutex_lock l(tracking_state_->mu); - // Keep a reference to the underlying memory around. - tracking_state_->referenced_tensors.Add(tensor); -} - Status OpKernelContext::input(StringPiece name, const Tensor** tensor) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); @@ -414,7 +407,6 @@ Status OpKernelContext::input(StringPiece name, const Tensor** tensor) { "' when non-ref input was expected"); } *tensor = (*params_->inputs)[start].tensor; - record_tensor_reference(**tensor); return Status::OK(); } @@ -449,7 +441,6 @@ const Tensor& OpKernelContext::input(int index) { CHECK_LT(index, num_inputs()) << " name: " << op_kernel().name(); CHECK(!input_is_ref(index)); const Tensor& tensor = *((*params_->inputs)[index].tensor); - record_tensor_reference(tensor); return tensor; } @@ -460,12 +451,10 @@ Tensor OpKernelContext::mutable_input(int index, bool lock_held) { // return a copy of the Ref acquired while holding the mutex if (lock_held) { Tensor& tensor = *((*params_->inputs)[index].tensor); - record_tensor_reference(tensor); return tensor; } else { tf_shared_lock l(*input_ref_mutex(index)); Tensor& tensor = *((*params_->inputs)[index].tensor); - record_tensor_reference(tensor); return tensor; } } @@ -482,7 +471,6 @@ void OpKernelContext::replace_ref_input(int index, const Tensor& tensor, mutex_lock l(*input_ref_mutex(index)); *(*params_->inputs)[index].tensor = tensor; } - record_tensor_reference(tensor); } void OpKernelContext::forward_ref_input_to_ref_output(int input_index, @@ -658,7 +646,6 @@ Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor, tf_shared_lock l(*input_ref_mutex(start)); *tensor = *(*params_->inputs)[start].tensor; } - record_tensor_reference(*tensor); return Status::OK(); } @@ -781,7 +768,6 @@ Status OpKernelContext::allocate_tensor( LogMemory::RecordTensorAllocation(params_->op_kernel->name(), params_->step_id, new_tensor); } - record_tensor_reference(new_tensor); *out_tensor = std::move(new_tensor); return Status::OK(); } @@ -969,7 +955,6 @@ void OpKernelContext::set_output(int index, const Tensor& tensor) { } else { // Input can be forwarded to output; incref on `tensor` and set output at // `index` to this tensor. 
-    record_tensor_reference(tensor);
     outputs_[index] = TensorValue(new Tensor(tensor));
     if (track_allocations() && tensor.TotalBytes() > 0) {
       DCHECK(tracking_state_);
@@ -994,7 +979,6 @@ void OpKernelContext::set_output_ref(int index, mutex* mu,
   CHECK_GE(index, 0);
   CHECK_LT(index, outputs_.size());
   CHECK(IsRefType(params_->op_kernel->output_type(index)));
-  record_tensor_reference(*tensor_for_ref);
   outputs_[index] = TensorValue(mu, tensor_for_ref);
 }
 
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index e274e7ee196..40891bb4b28 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -673,7 +673,6 @@ class OpKernelContext {
 
     bool track_allocations = false;
     bool log_memory = false;
-    bool record_tensor_accesses = false;
 
     // Array indexed by output number for this node
     const AllocatorAttributes* output_attr_array = nullptr;
@@ -1221,11 +1220,6 @@ class OpKernelContext {
   // TODO(tucker): Add example usage.
   DeviceBase* device() const { return params_->device; }
 
-  // Retrieve list of referenced tensors in out_vector. Once this is
-  // called, it is not legal to reference any more tensors. Should
-  // not be called from Op kernels.
-  void retrieve_accessed_tensors(TensorReferenceVector* out_vector);
-
   // Per-step container for use by white-listed internal ops.
   ScopedStepContainer* step_container() const {
     return params_->step_container;
  }
@@ -1308,13 +1302,6 @@ class OpKernelContext {
  private:
   bool record_memory_consumption_ = false;
 
-  // Internal method to add a tensor's buffer to the list of buffers
-  // referenced during the execution of the Op, so that GPUs may
-  // accurately track the memory that may not be reused until the Op
-  // execution completes.
-  void record_tensor_reference(const Tensor& tensor);
-  void really_record_tensor_reference(const Tensor& tensor);
-
   // Internal common method used when allocating tensor memory
   Status allocate_tensor(DataType type, const TensorShape& shape,
                          Tensor* out_tensor,
@@ -1331,14 +1318,6 @@ class OpKernelContext {
   // called.
   void maybe_initialize_scope_id_set();
 
-  // This is called by PersistentTensor::AccessTensor whenever the
-  // wrapped tensor is retrieved, to ensure the runtime knows that the
-  // Tensor is being accessed within an Op. This is necessary for
-  // memory safety of devices like GPUs that queue Ops for
-  // asynchronous execution after the Compute() method completes.
-  friend class PersistentTensor;
-  void NotifyUseOfPersistentTensor(const Tensor& tensor);
-
   Status status_;
   friend class CollectiveExecutor;  // for access to params_
   Params* params_;                  // not owned
@@ -1356,8 +1335,6 @@ class OpKernelContext {
     gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators
         TF_GUARDED_BY(mu);
 
-    UniqueTensorReferences referenced_tensors TF_GUARDED_BY(mu);
-
     mutable mutex stats_mu;
     int64 temp_memory_allocated TF_GUARDED_BY(stats_mu) = 0;
 
@@ -1658,23 +1635,6 @@ inline bool OpKernelContext::input_is_ref(int index) const {
   return value.is_ref();
 }
 
-inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) {
-  DCHECK_EQ(params_->device->RequiresRecordingAccessedTensors(),
-            params_->record_tensor_accesses);
-  if (params_->record_tensor_accesses) {
-    really_record_tensor_reference(tensor);
-  }
-}
-
-inline void OpKernelContext::retrieve_accessed_tensors(
-    TensorReferenceVector* out_vector) {
-  if (params_->record_tensor_accesses) {
-    DCHECK(tracking_state_);
-    mutex_lock l(tracking_state_->mu);
-    tracking_state_->referenced_tensors.FreezeAndReturnReferences(out_vector);
-  }
-}
-
 // no input if tensor == nullptr.
 inline bool OpKernelContext::has_input(int index) const {
   DCHECK_GE(index, 0);
@@ -1689,17 +1649,9 @@ inline mutex* OpKernelContext::input_ref_mutex(int index) {
   return (*params_->inputs)[index].mutex_if_ref;
 }
 
-inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) {
-  if (t.IsInitialized()) {
-    record_tensor_reference(t);
-  }
-}
-
 inline Tensor* OpKernelContext::mutable_output(int index) {
   DCHECK_GE(index, 0);
   DCHECK_LT(index, num_outputs());
-  // No need to record_tensor_reference since the output must already
-  // have been set by a call that did so.
   return outputs_[index].tensor;
 }
 
diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc
index 40425cf24e0..94b502f3f71 100644
--- a/tensorflow/core/framework/op_kernel_test.cc
+++ b/tensorflow/core/framework/op_kernel_test.cc
@@ -348,74 +348,17 @@ TEST_F(OpKernelTest, MatchSignatureFailes) {
 
 class DummyDevice : public DeviceBase {
  public:
-  DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
-  bool RequiresRecordingAccessedTensors() const override { return save_; }
+  explicit DummyDevice(Env* env) : DeviceBase(env) {}
   Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
     return cpu_allocator();
   }
-
- private:
-  bool save_;
 };
 
-TEST_F(OpKernelTest, SaveTempFalse) {
-  Env* env = Env::Default();
-  OpKernelContext::Params params;
-  params.record_tensor_accesses = false;
-  auto device =
-      absl::make_unique<DummyDevice>(env, params.record_tensor_accesses);
-  params.device = device.get();
-  Status status;
-  std::unique_ptr<OpKernel> op(
-      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
-                     CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
-                     TF_GRAPH_DEF_VERSION, &status));
-  EXPECT_TRUE(status.ok());
-  params.op_kernel = op.get();
-  auto ctx = absl::make_unique<OpKernelContext>(&params);
-
-  Tensor t;
-  TF_EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t));
-
-  TensorReferenceVector referenced_tensors;
-  ctx->retrieve_accessed_tensors(&referenced_tensors);
-  EXPECT_EQ(0, referenced_tensors.size());
-}
-
-TEST_F(OpKernelTest, SaveTempTrue) {
-  Env* env = Env::Default();
-  OpKernelContext::Params params;
-  params.record_tensor_accesses = true;
-  auto device =
-      absl::make_unique<DummyDevice>(env, params.record_tensor_accesses);
-  params.device = device.get();
-  Status status;
-  std::unique_ptr<OpKernel> op(
-      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}), - TF_GRAPH_DEF_VERSION, &status)); - EXPECT_TRUE(status.ok()); - params.op_kernel = op.get(); - auto ctx = absl::make_unique(¶ms); - - Tensor t; - TF_EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t)); - - TensorReferenceVector referenced_tensors; - ctx->retrieve_accessed_tensors(&referenced_tensors); - EXPECT_EQ(1, referenced_tensors.size()); - for (auto& ref : referenced_tensors) { - ref.Unref(); - } -} - TEST_F(OpKernelTest, InputDtype) { Env* env = Env::Default(); OpKernelContext::Params params; - params.record_tensor_accesses = false; - auto device = - absl::make_unique(env, params.record_tensor_accesses); - params.device = device.get(); + DummyDevice device(env); + params.device = &device; Status status; std::unique_ptr op( CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(), @@ -499,7 +442,6 @@ class ScopedAllocatorDevice : public DeviceBase { TEST_F(OpKernelTest, ScopedAllocationTest) { Env* env = Env::Default(); OpKernelContext::Params params; - params.record_tensor_accesses = false; auto sa_device = absl::make_unique(env); params.device = sa_device.get(); Status status; @@ -788,10 +730,8 @@ REGISTER_KERNEL_BUILDER(Name("ListOut").Device(tensorflow::DEVICE_CPU), TEST_F(OpKernelBuilderTest, OpOutputList) { Env* env = Env::Default(); OpKernelContext::Params params; - params.record_tensor_accesses = false; - auto device = - absl::make_unique(env, params.record_tensor_accesses); - params.device = device.get(); + DummyDevice device(env); + params.device = &device; Status status; std::unique_ptr op(CreateOpKernel( DEVICE_CPU, params.device, cpu_allocator(), @@ -1066,7 +1006,7 @@ void BM_InputRangeHelper(int iters, const NodeDef& node_def, const char* input_name, int expected_start, int expected_stop) { Status status; - auto device = absl::make_unique(Env::Default(), false); + auto device = absl::make_unique(Env::Default()); std::unique_ptr op(CreateOpKernel(DEVICE_CPU, device.get(), cpu_allocator(), node_def, diff --git a/tensorflow/core/kernels/data/single_threaded_executor.cc b/tensorflow/core/kernels/data/single_threaded_executor.cc index f708213bf70..770ffe61f92 100644 --- a/tensorflow/core/kernels/data/single_threaded_executor.cc +++ b/tensorflow/core/kernels/data/single_threaded_executor.cc @@ -243,7 +243,6 @@ class SingleThreadedExecutorImpl : public Executor { Device* device = params_.device; params.device = device; params.log_memory = false; // TODO(mrry): Too severe? - params.record_tensor_accesses = false; // TODO(mrry): Too severe? params.rendezvous = args.rendezvous; params.session_state = args.session_state; params.tensor_store = args.tensor_store;