Add resource_mgr to the op_kernel_ctx and iterator_ctx in DatasetOpsTestBase

This commit is contained in:
Fei Hu 2019-04-18 16:25:46 -07:00
parent df72d996d3
commit 7f64e72c1b
2 changed files with 4 additions and 0 deletions

View File

@ -190,6 +190,7 @@ Status DatasetOpsTestBase::CreateIteratorContext(
OpKernelContext* const op_context, OpKernelContext* const op_context,
std::unique_ptr<IteratorContext>* iterator_context) { std::unique_ptr<IteratorContext>* iterator_context) {
IteratorContext::Params params(op_context); IteratorContext::Params params(op_context);
params.resource_mgr = op_context->resource_manager();
function_handle_cache_ = absl::make_unique<FunctionHandleCache>(flr_); function_handle_cache_ = absl::make_unique<FunctionHandleCache>(flr_);
params.function_handle_cache = function_handle_cache_.get(); params.function_handle_cache = function_handle_cache_.get();
*iterator_context = absl::make_unique<IteratorContext>(params); *iterator_context = absl::make_unique<IteratorContext>(params);
@ -228,6 +229,7 @@ Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices( TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices)); options, "/job:localhost/replica:0/task:0", &devices));
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices)); device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
resource_mgr_ = absl::make_unique<ResourceMgr>("default_container");
FunctionDefLibrary proto; FunctionDefLibrary proto;
for (const auto& fdef : flib) *(proto.add_function()) = fdef; for (const auto& fdef : flib) *(proto.add_function()) = fdef;
@ -269,6 +271,7 @@ Status DatasetOpsTestBase::CreateOpKernelContext(
step_container_ = step_container_ =
absl::make_unique<ScopedStepContainer>(0, [](const string&) {}); absl::make_unique<ScopedStepContainer>(0, [](const string&) {});
params_->step_container = step_container_.get(); params_->step_container = step_container_.get();
params_->resource_manager = resource_mgr_.get();
checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper; checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
slice_reader_cache_ = slice_reader_cache_ =
absl::make_unique<checkpoint::TensorSliceReaderCacheWrapper>(); absl::make_unique<checkpoint::TensorSliceReaderCacheWrapper>();

View File

@ -206,6 +206,7 @@ class DatasetOpsTestBase : public ::testing::Test {
std::function<void(std::function<void()>)> runner_; std::function<void(std::function<void()>)> runner_;
std::unique_ptr<DeviceMgr> device_mgr_; std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> lib_def_; std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<ResourceMgr> resource_mgr_;
std::unique_ptr<OpKernelContext::Params> params_; std::unique_ptr<OpKernelContext::Params> params_;
std::unique_ptr<checkpoint::TensorSliceReaderCacheWrapper> std::unique_ptr<checkpoint::TensorSliceReaderCacheWrapper>
slice_reader_cache_; slice_reader_cache_;