Remove unused tensor-reference-recording feature from the executor.
This change also removes the `Device::RequiresRecordingAccessedTensors()` and `Device::ConsumeListOfAccessedTensors()` methods. Some device objects (historically, GPUs with experimental multi-stream support) needed to record which tensors were accessed during kernel execution. That support has bit-rotted since it was introduced, and it adds runtime overhead for the majority of devices, which do not use the feature.

PiperOrigin-RevId: 300443774
Change-Id: Ia44ff65dee57f4d9f971f0079f79edd2fde2a1dc
Parent: aedb53e371
Commit: 06e20a2fe2
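For context, the removed hooks let a device opt in to having the executor collect a TensorReference for every tensor a kernel touched, and then hand that list to the device so it could keep the buffers alive until asynchronous work finished. Below is an illustrative sketch only, not code from this change: the class name and the ReleaseAfterStreamCompletes() helper are invented for the example, while the two overridden methods are the ones whose declarations are deleted in the diff that follows.

// Hypothetical device that opted in to tensor-reference recording via the
// now-removed virtual methods on DeviceBase / Device.
class MultiStreamDevice : public Device {
 public:
  MultiStreamDevice(Env* env, const DeviceAttributes& attrs)
      : Device(env, attrs) {}

  // Tells the executor to collect a TensorReference for every tensor a
  // kernel accesses while running on this device.
  bool RequiresRecordingAccessedTensors() const override { return true; }

  // Receives those references after the kernel has been launched; the device
  // keeps the buffers alive until its compute stream drains, then unrefs.
  void ConsumeListOfAccessedTensors(
      DeviceContext* context, const TensorReferenceVector& tensors) override {
    ReleaseAfterStreamCompletes(context, tensors);  // hypothetical helper
  }

  Status Sync() override { return Status::OK(); }

 private:
  void ReleaseAfterStreamCompletes(DeviceContext* context,
                                   const TensorReferenceVector& tensors);
};

With the feature gone, every device uses the default path of releasing tensor references as soon as the kernel launch returns, which is what the single-stream GPU device already did (see the removed BaseGPUDevice::RequiresRecordingAccessedTensors() below).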
@@ -27,14 +27,10 @@ namespace {
 
 class DummyDevice : public DeviceBase {
  public:
-  DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
-  bool RequiresRecordingAccessedTensors() const override { return save_; }
+  explicit DummyDevice(Env* env) : DeviceBase(env) {}
   Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
     return cpu_allocator();
   }
-
- private:
-  bool save_;
 };
 
 void TestBitcastOp(Tensor* input_tensor, DataType out_type,
@@ -61,7 +57,7 @@ void TestBitcastOp(Tensor* input_tensor, DataType out_type,
   ASSERT_TRUE(status.ok()) << status.ToString();
 
   OpKernelContext::Params params;
-  DummyDevice dummy_device(nullptr, false);
+  DummyDevice dummy_device(nullptr);
   params.device = &dummy_device;
   params.op_kernel = kernel.get();
   gtl::InlinedVector<TensorValue, 4> inputs;
@@ -155,14 +155,10 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
 
 class DummyDevice : public DeviceBase {
  public:
-  DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
-  bool RequiresRecordingAccessedTensors() const override { return save_; }
+  explicit DummyDevice(Env* env) : DeviceBase(env) {}
   Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
     return cpu_allocator();
   }
-
- private:
-  bool save_;
 };
 
 TEST(TestKernel, TestInputAndOutputCount) {
@@ -223,7 +219,7 @@ TEST(TestKernel, TestInputAndOutputCount) {
 
   {
     OpKernelContext::Params p;
-    DummyDevice dummy_device(nullptr, false);
+    DummyDevice dummy_device(nullptr);
    p.device = &dummy_device;
    p.step_id = 43;
 
@@ -92,16 +92,6 @@ class Device : public DeviceBase {
     op_kernel->ComputeAsync(context, std::move(done));
   }
 
-  // Takes ownership of the references in tensors. If necessary, a
-  // device may override this method to keep a reference to the
-  // accessed tensors until the async computation has completed.
-  virtual void ConsumeListOfAccessedTensors(
-      DeviceContext* context, const TensorReferenceVector& tensors) {
-    for (const auto& ref : tensors) {
-      ref.Unref();
-    }
-  }
-
   // Blocks until all operations queued on the device at the time of
   // the call have completed. Returns any error pending on the device
   // at completion.
@@ -121,12 +121,6 @@ void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
   stats->SetMemory(ctx);
 }
 
-void SetReferencedTensors(NodeExecStatsInterface* stats,
-                          const TensorReferenceVector& tensors) {
-  if (!stats) return;
-  stats->SetReferencedTensors(tensors);
-}
-
 }  // namespace nodestats
 
 class ExecutorImpl;
@@ -403,9 +397,6 @@ class ExecutorImpl : public Executor {
   LocalExecutorParams params_;
   GraphView gview_;
 
-  // A cached value of params_
-  bool device_record_tensor_accesses_ = false;
-
   // Root nodes (with no in edges) that should form the initial ready queue
   std::vector<const NodeItem*> root_nodes_;
 
@@ -639,11 +630,6 @@ Status ExecutorImpl::Initialize(const Graph& graph) {
   ControlFlowInfo cf_info;
   TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph, &cf_info));
 
-  // Cache this value so we make this virtual function call once, rather
-  // that O(# steps * # nodes per step) times.
-  device_record_tensor_accesses_ =
-      params_.device->RequiresRecordingAccessedTensors();
-
   for (auto& it : cf_info.unique_frame_names) {
     EnsureFrameInfo(it)->nodes = new std::vector<const NodeItem*>;
   }
@@ -1444,8 +1430,6 @@ class ExecutorState {
 
   Status ProcessSync(const NodeItem& item, OpKernelContext::Params* params,
                      EntryVector* outputs,
-                     TensorReferenceVector* accessed_tensors,
-                     DeviceContext** device_context,
                      NodeExecStatsInterface* stats);
   void ProcessAsync(const NodeItem& item, const OpKernelContext::Params& params,
                     const TaggedNode& tagged_node, Entry* first_input,
@@ -1751,8 +1735,6 @@ bool MightTrace(const NodeItem& item,
 Status ExecutorState::ProcessSync(const NodeItem& item,
                                   OpKernelContext::Params* params,
                                   EntryVector* outputs,
-                                  TensorReferenceVector* accessed_tensors,
-                                  DeviceContext** device_context,
                                   NodeExecStatsInterface* stats) {
   Status s;
   OpKernelContext ctx(params, item.num_outputs);
@@ -1785,11 +1767,6 @@ Status ExecutorState::ProcessSync(const NodeItem& item,
     nodestats::SetOpEnd(stats);
     s = ProcessOutputs(item, &ctx, outputs, stats);
   }
-  if (TF_PREDICT_FALSE(impl_->device_record_tensor_accesses_) && s.ok()) {
-    // Get the list of all tensors accessed during the execution
-    ctx.retrieve_accessed_tensors(accessed_tensors);
-    *device_context = ctx.op_device_context();
-  }
   nodestats::SetMemory(stats, &ctx);
   return s;
 }
@@ -1833,15 +1810,6 @@ void ExecutorState::ProcessAsync(const NodeItem& item,
       PropagateOutputs(state->tagged_node, state->item, &outputs, &ready);
     }
     outputs.clear();
-    if (TF_PREDICT_FALSE(impl_->device_record_tensor_accesses_) && s.ok()) {
-      // Get the list of all tensors accessed during the execution
-      TensorReferenceVector accessed;
-      state->ctx.retrieve_accessed_tensors(&accessed);
-      nodestats::SetReferencedTensors(stats, accessed);
-      // callee takes ownership of the vector
-      device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
-                                           accessed);
-    }
     const bool completed = NodeDone(s, &ready, stats, nullptr);
     delete state;
     if (completed) ScheduleFinish();
@@ -1905,7 +1873,6 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
     params.device = device;
   }
   params.log_memory = log_memory_;
-  params.record_tensor_accesses = impl_->device_record_tensor_accesses_;
   params.rendezvous = rendezvous_;
   params.collective_executor = collective_executor_;
   params.session_state = session_state_;
@@ -1988,7 +1955,6 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
     Entry* first_input = input_tensors + item.input_start;
     outputs.clear();
 
-    TensorReferenceVector accessed_tensors;
     // Only execute this node if it is not dead or it is a send/recv
     // transfer node. For transfer nodes, we need to propagate the "dead"
     // bit even when the node is dead.
@@ -2028,15 +1994,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
        ProcessAsync(item, params, tagged_node, first_input, stats);
        launched_asynchronously = true;
      } else {
-        DeviceContext* device_context = nullptr;
-        s = ProcessSync(item, &params, &outputs, &accessed_tensors,
-                        &device_context, stats);
-        if (!accessed_tensors.empty()) {
-          nodestats::SetReferencedTensors(stats, accessed_tensors);
-          // device_context is set above in `ProcessSync()`.
-          device->ConsumeListOfAccessedTensors(device_context,
-                                               accessed_tensors);
-        }
+        s = ProcessSync(item, &params, &outputs, stats);
      }
    }
 
@@ -462,13 +462,6 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
   return Status::OK();
 }
 
-bool BaseGPUDevice::RequiresRecordingAccessedTensors() const {
-  // Since there is only one stream, we release the tensor reference
-  // at the end of the kernel launch, instead of at the end of the kernel
-  // execution.
-  return false;
-}
-
 string BaseGPUDevice::ComputeOpKernelDebugString(const OpKernel& op_kernel,
                                                  const int& stream_id) {
   return strings::StrCat(op_kernel.name(), " op ", op_kernel.type_string(),
@@ -541,16 +534,6 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
   }
 }
 
-void BaseGPUDevice::ConsumeListOfAccessedTensors(
-    DeviceContext* device_context, const TensorReferenceVector& tensor_refs) {
-  GPUDeviceContext* gpu_device_context = device_context_;
-  if (device_context != nullptr) {
-    gpu_device_context = static_cast<GPUDeviceContext*>(device_context);
-  }
-  se::Stream* stream = gpu_device_context->stream();
-  em_->ThenDeleteTensors(stream, tensor_refs);
-}
-
 // Based on the semantics of Device::Sync this call should wait for
 // all streams not just the current one.
 Status BaseGPUDevice::Sync() { return GPUUtil::SyncAll(this); }
@@ -62,15 +62,6 @@ class BaseGPUDevice : public LocalDevice {
   // Initialize the device and return the status of initialization.
   Status Init(const SessionOptions& options);
 
-  // GPU devices require the Op Compute method to save a reference to
-  // any temporary tensors that are allocated until the Op execution
-  // completes.
-  bool RequiresRecordingAccessedTensors() const override;
-
-  void ConsumeListOfAccessedTensors(
-      DeviceContext* device_context,
-      const TensorReferenceVector& tensor_refs) override;
-
   void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
 
   Status Sync() override;
@@ -426,7 +426,6 @@ class EMBenchmarkHelper {
     params->step_id = 1;
     params->device = gpu_helper_->gpu();
     params->log_memory = false;
-    params->record_tensor_accesses = false;
     params->rendezvous = nullptr;
     params->collective_executor = nullptr;
     params->session_state = nullptr;  // ???
@@ -36,11 +36,6 @@ class RenamedDevice : public Device {
 
   ~RenamedDevice() override;
 
-  // Below are virtual methods defined on DeviceBase
-  bool RequiresRecordingAccessedTensors() const override {
-    return underlying_device_->RequiresRecordingAccessedTensors();
-  }
-
   const DeviceBase* UnderlyingDevice() const override {
     return underlying_device_->UnderlyingDevice();
   }
@@ -138,11 +133,6 @@ class RenamedDevice : public Device {
     underlying_device_->ComputeAsync(op_kernel, context, std::move(done));
   }
 
-  void ConsumeListOfAccessedTensors(
-      DeviceContext* context, const TensorReferenceVector& tensors) override {
-    underlying_device_->ConsumeListOfAccessedTensors(context, tensors);
-  }
-
   Status Sync() override { return underlying_device_->Sync(); }
 
   Status MaybeRewriteGraph(std::unique_ptr<Graph>* graph) override {
@@ -120,11 +120,6 @@ class DeviceBase {
 
   Env* env() const { return env_; }
 
-  // Override this to return true for devices that require an Op's
-  // compute method to save references to the temporary tensors it
-  // allocates until the Op execution completes
-  virtual bool RequiresRecordingAccessedTensors() const { return false; }
-
   struct CpuWorkerThreads {
     int num_threads = 0;
     thread::ThreadPool* workers = nullptr;
@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/node_properties.h"
 #include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/framework/tensor_reference.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/notification.h"
@@ -222,7 +223,6 @@ Tensor* PersistentTensor::AccessTensor(OpKernelConstruction* context) {
 }
 
 Tensor* PersistentTensor::AccessTensor(OpKernelContext* context) {
-  context->NotifyUseOfPersistentTensor(tensor_);
   return &tensor_;
 }
 
@@ -331,7 +331,7 @@ OpKernelContext::OpKernelContext(Params* params)
 
 OpKernelContext::OpKernelContext(Params* params, int num_outputs)
     : params_(params), outputs_(num_outputs) {
-  if (params_->record_tensor_accesses || params_->track_allocations) {
+  if (params_->track_allocations) {
     tracking_state_ = absl::make_unique<TrackingState>();
   }
 
@@ -393,13 +393,6 @@ void OpKernelContext::SetStatus(const Status& status) {
   status_.Update(status);
 }
 
-void OpKernelContext::really_record_tensor_reference(const Tensor& tensor) {
-  DCHECK(tracking_state_);
-  mutex_lock l(tracking_state_->mu);
-  // Keep a reference to the underlying memory around.
-  tracking_state_->referenced_tensors.Add(tensor);
-}
-
 Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
   int start, stop;
   TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
@@ -414,7 +407,6 @@ Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
         "' when non-ref input was expected");
   }
   *tensor = (*params_->inputs)[start].tensor;
-  record_tensor_reference(**tensor);
   return Status::OK();
 }
 
@@ -449,7 +441,6 @@ const Tensor& OpKernelContext::input(int index) {
   CHECK_LT(index, num_inputs()) << " name: " << op_kernel().name();
   CHECK(!input_is_ref(index));
   const Tensor& tensor = *((*params_->inputs)[index].tensor);
-  record_tensor_reference(tensor);
   return tensor;
 }
 
@@ -460,12 +451,10 @@ Tensor OpKernelContext::mutable_input(int index, bool lock_held) {
   // return a copy of the Ref acquired while holding the mutex
   if (lock_held) {
     Tensor& tensor = *((*params_->inputs)[index].tensor);
-    record_tensor_reference(tensor);
     return tensor;
   } else {
     tf_shared_lock l(*input_ref_mutex(index));
     Tensor& tensor = *((*params_->inputs)[index].tensor);
-    record_tensor_reference(tensor);
     return tensor;
   }
 }
@@ -482,7 +471,6 @@ void OpKernelContext::replace_ref_input(int index, const Tensor& tensor,
     mutex_lock l(*input_ref_mutex(index));
     *(*params_->inputs)[index].tensor = tensor;
   }
-  record_tensor_reference(tensor);
 }
 
 void OpKernelContext::forward_ref_input_to_ref_output(int input_index,
@@ -658,7 +646,6 @@ Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
     tf_shared_lock l(*input_ref_mutex(start));
     *tensor = *(*params_->inputs)[start].tensor;
   }
-  record_tensor_reference(*tensor);
   return Status::OK();
 }
 
@@ -781,7 +768,6 @@ Status OpKernelContext::allocate_tensor(
     LogMemory::RecordTensorAllocation(params_->op_kernel->name(),
                                       params_->step_id, new_tensor);
   }
-  record_tensor_reference(new_tensor);
   *out_tensor = std::move(new_tensor);
   return Status::OK();
 }
@@ -969,7 +955,6 @@ void OpKernelContext::set_output(int index, const Tensor& tensor) {
   } else {
     // Input can be forwarded to output; incref on `tensor` and set output at
     // `index` to this tensor.
-    record_tensor_reference(tensor);
     outputs_[index] = TensorValue(new Tensor(tensor));
     if (track_allocations() && tensor.TotalBytes() > 0) {
       DCHECK(tracking_state_);
@@ -994,7 +979,6 @@ void OpKernelContext::set_output_ref(int index, mutex* mu,
   CHECK_GE(index, 0);
   CHECK_LT(index, outputs_.size());
   CHECK(IsRefType(params_->op_kernel->output_type(index)));
-  record_tensor_reference(*tensor_for_ref);
   outputs_[index] = TensorValue(mu, tensor_for_ref);
 }
 
@@ -673,7 +673,6 @@ class OpKernelContext {
 
     bool track_allocations = false;
    bool log_memory = false;
-    bool record_tensor_accesses = false;
 
    // Array indexed by output number for this node
    const AllocatorAttributes* output_attr_array = nullptr;
@@ -1221,11 +1220,6 @@ class OpKernelContext {
   // TODO(tucker): Add example usage.
   DeviceBase* device() const { return params_->device; }
 
-  // Retrieve list of referenced tensors in out_vector. Once this is
-  // called, it is not legal to reference any more tensors. Should
-  // not be called from Op kernels.
-  void retrieve_accessed_tensors(TensorReferenceVector* out_vector);
-
   // Per-step container for use by white-listed internal ops.
   ScopedStepContainer* step_container() const {
     return params_->step_container;
@@ -1308,13 +1302,6 @@ class OpKernelContext {
  private:
   bool record_memory_consumption_ = false;
 
-  // Internal method to add a tensor's buffer to the list of buffers
-  // referenced during the execution of the Op, so that GPUs may
-  // accurately track the memory that may not be reused until the Op
-  // execution completes.
-  void record_tensor_reference(const Tensor& tensor);
-  void really_record_tensor_reference(const Tensor& tensor);
-
   // Internal common method used when allocating tensor memory
   Status allocate_tensor(DataType type, const TensorShape& shape,
                          Tensor* out_tensor,
@@ -1331,14 +1318,6 @@ class OpKernelContext {
   // called.
   void maybe_initialize_scope_id_set();
 
-  // This is called by PersistentTensor::AccessTensor whenever the
-  // wrapped tensor is retrieved, to ensure the runtime knows that the
-  // Tensor is being accessed within an Op. This is necessary for
-  // memory safety of devices like GPUs that queue Ops for
-  // asynchronous execution after the Compute() method completes.
-  friend class PersistentTensor;
-  void NotifyUseOfPersistentTensor(const Tensor& tensor);
-
   Status status_;
   friend class CollectiveExecutor;  // for access to params_
   Params* params_;                  // not owned
@@ -1356,8 +1335,6 @@ class OpKernelContext {
     gtl::InlinedVector<WrappedAllocator, 4> wrapped_allocators
         TF_GUARDED_BY(mu);
 
-    UniqueTensorReferences referenced_tensors TF_GUARDED_BY(mu);
-
     mutable mutex stats_mu;
     int64 temp_memory_allocated TF_GUARDED_BY(stats_mu) = 0;
 
@@ -1658,23 +1635,6 @@ inline bool OpKernelContext::input_is_ref(int index) const {
   return value.is_ref();
 }
 
-inline void OpKernelContext::record_tensor_reference(const Tensor& tensor) {
-  DCHECK_EQ(params_->device->RequiresRecordingAccessedTensors(),
-            params_->record_tensor_accesses);
-  if (params_->record_tensor_accesses) {
-    really_record_tensor_reference(tensor);
-  }
-}
-
-inline void OpKernelContext::retrieve_accessed_tensors(
-    TensorReferenceVector* out_vector) {
-  if (params_->record_tensor_accesses) {
-    DCHECK(tracking_state_);
-    mutex_lock l(tracking_state_->mu);
-    tracking_state_->referenced_tensors.FreezeAndReturnReferences(out_vector);
-  }
-}
-
 // no input if tensor == nullptr.
 inline bool OpKernelContext::has_input(int index) const {
   DCHECK_GE(index, 0);
@@ -1689,17 +1649,9 @@ inline mutex* OpKernelContext::input_ref_mutex(int index) {
   return (*params_->inputs)[index].mutex_if_ref;
 }
 
-inline void OpKernelContext::NotifyUseOfPersistentTensor(const Tensor& t) {
-  if (t.IsInitialized()) {
-    record_tensor_reference(t);
-  }
-}
-
 inline Tensor* OpKernelContext::mutable_output(int index) {
   DCHECK_GE(index, 0);
   DCHECK_LT(index, num_outputs());
-  // No need to record_tensor_reference since the output must already
-  // have been set by a call that did so.
   return outputs_[index].tensor;
 }
 
@@ -348,74 +348,17 @@ TEST_F(OpKernelTest, MatchSignatureFailes) {
 
 class DummyDevice : public DeviceBase {
  public:
-  DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {}
-  bool RequiresRecordingAccessedTensors() const override { return save_; }
+  explicit DummyDevice(Env* env) : DeviceBase(env) {}
   Allocator* GetAllocator(AllocatorAttributes /*attr*/) override {
     return cpu_allocator();
   }
-
- private:
-  bool save_;
 };
 
-TEST_F(OpKernelTest, SaveTempFalse) {
-  Env* env = Env::Default();
-  OpKernelContext::Params params;
-  params.record_tensor_accesses = false;
-  auto device =
-      absl::make_unique<DummyDevice>(env, params.record_tensor_accesses);
-  params.device = device.get();
-  Status status;
-  std::unique_ptr<OpKernel> op(
-      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
-                     CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
-                     TF_GRAPH_DEF_VERSION, &status));
-  EXPECT_TRUE(status.ok());
-  params.op_kernel = op.get();
-  auto ctx = absl::make_unique<OpKernelContext>(&params);
-
-  Tensor t;
-  TF_EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t));
-
-  TensorReferenceVector referenced_tensors;
-  ctx->retrieve_accessed_tensors(&referenced_tensors);
-  EXPECT_EQ(0, referenced_tensors.size());
-}
-
-TEST_F(OpKernelTest, SaveTempTrue) {
-  Env* env = Env::Default();
-  OpKernelContext::Params params;
-  params.record_tensor_accesses = true;
-  auto device =
-      absl::make_unique<DummyDevice>(env, params.record_tensor_accesses);
-  params.device = device.get();
-  Status status;
-  std::unique_ptr<OpKernel> op(
-      CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
-                     CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}),
-                     TF_GRAPH_DEF_VERSION, &status));
-  EXPECT_TRUE(status.ok());
-  params.op_kernel = op.get();
-  auto ctx = absl::make_unique<OpKernelContext>(&params);
-
-  Tensor t;
-  TF_EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t));
-
-  TensorReferenceVector referenced_tensors;
-  ctx->retrieve_accessed_tensors(&referenced_tensors);
-  EXPECT_EQ(1, referenced_tensors.size());
-  for (auto& ref : referenced_tensors) {
-    ref.Unref();
-  }
-}
-
 TEST_F(OpKernelTest, InputDtype) {
   Env* env = Env::Default();
   OpKernelContext::Params params;
-  params.record_tensor_accesses = false;
-  auto device =
-      absl::make_unique<DummyDevice>(env, params.record_tensor_accesses);
-  params.device = device.get();
+  DummyDevice device(env);
+  params.device = &device;
   Status status;
   std::unique_ptr<OpKernel> op(
       CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(),
@@ -499,7 +442,6 @@ class ScopedAllocatorDevice : public DeviceBase {
 TEST_F(OpKernelTest, ScopedAllocationTest) {
   Env* env = Env::Default();
   OpKernelContext::Params params;
-  params.record_tensor_accesses = false;
   auto sa_device = absl::make_unique<ScopedAllocatorDevice>(env);
   params.device = sa_device.get();
   Status status;
@@ -788,10 +730,8 @@ REGISTER_KERNEL_BUILDER(Name("ListOut").Device(tensorflow::DEVICE_CPU),
 TEST_F(OpKernelBuilderTest, OpOutputList) {
   Env* env = Env::Default();
   OpKernelContext::Params params;
-  params.record_tensor_accesses = false;
-  auto device =
-      absl::make_unique<DummyDevice>(env, params.record_tensor_accesses);
-  params.device = device.get();
+  DummyDevice device(env);
+  params.device = &device;
   Status status;
   std::unique_ptr<OpKernel> op(CreateOpKernel(
       DEVICE_CPU, params.device, cpu_allocator(),
@@ -1066,7 +1006,7 @@ void BM_InputRangeHelper(int iters, const NodeDef& node_def,
                          const char* input_name, int expected_start,
                          int expected_stop) {
   Status status;
-  auto device = absl::make_unique<DummyDevice>(Env::Default(), false);
+  auto device = absl::make_unique<DummyDevice>(Env::Default());
 
   std::unique_ptr<OpKernel> op(CreateOpKernel(DEVICE_CPU, device.get(),
                                               cpu_allocator(), node_def,
@@ -243,7 +243,6 @@ class SingleThreadedExecutorImpl : public Executor {
     Device* device = params_.device;
     params.device = device;
     params.log_memory = false;              // TODO(mrry): Too severe?
-    params.record_tensor_accesses = false;  // TODO(mrry): Too severe?
     params.rendezvous = args.rendezvous;
     params.session_state = args.session_state;
    params.tensor_store = args.tensor_store;