Pipe ConfigProto through FLR so that it can be accessed by Ops like PartitionedCallOp.
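As a rough illustration of what this enables, an op that has a FunctionLibraryRuntime can read the session's ConfigProto and forward it when instantiating a function, along the lines of the PartitionedCallOp and RemoteCallOp hunks further down. This is a minimal sketch, not code from this change; the helper name is made up and error handling is abbreviated:

    #include "tensorflow/core/framework/function.h"
    #include "tensorflow/core/framework/op_kernel.h"
    #include "tensorflow/core/lib/core/errors.h"
    #include "tensorflow/core/protobuf/config.pb.h"

    namespace tensorflow {

    // Hypothetical helper: copy the session ConfigProto (if any) into the
    // InstantiateOptions before instantiating a function from an op kernel.
    Status InstantiateWithSessionConfig(OpKernelContext* ctx,
                                        const NameAttrList& func,
                                        FunctionLibraryRuntime::Handle* handle) {
      FunctionLibraryRuntime* lib = ctx->function_library();
      if (lib == nullptr) {
        return errors::Internal("No function library runtime available.");
      }
      FunctionLibraryRuntime::InstantiateOptions opts;
      // config_proto() is the accessor added by this change; it returns null
      // when the runtime was created without a ConfigProto.
      if (const ConfigProto* config = lib->config_proto()) {
        opts.config_proto = *config;
      }
      return lib->Instantiate(func.name(), AttrSlice(&func.attr()), opts, handle);
    }

    }  // namespace tensorflow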

Also pass the ConfigProto through distributed function calls both in the standard
graph registration mode and in the new eager master setup.

The PFLR stores an absl::optional<ConfigProto> (a copy) rather than a pointer, because the pointer it is constructed with may dangle after construction. At the same time, we still need to know whether a ConfigProto was available at creation time, which is why the copy is wrapped in an optional rather than stored unconditionally. The FLR, in contrast, keeps a raw pointer, because in all cases it is handed a valid pointer that will outlive it.
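The ownership split described above boils down to the following pattern. This is a simplified sketch with made-up class names; the real constructors take many more parameters (see the process_function_library_runtime.cc and function.cc hunks below), but the initializer and accessors follow the same shape:

    #include "absl/types/optional.h"
    #include "tensorflow/core/protobuf/config.pb.h"

    // Process-level runtime: copies the ConfigProto so it cannot dangle, and
    // remembers (via the optional) whether one was supplied at all.
    class PflrConfigSketch {
     public:
      explicit PflrConfigSketch(const tensorflow::ConfigProto* config)
          : config_(config ? absl::make_optional(*config) : absl::nullopt) {}

      // Null when no ConfigProto was available at construction time.
      const tensorflow::ConfigProto* config() const {
        return config_ ? &(*config_) : nullptr;
      }

     private:
      const absl::optional<const tensorflow::ConfigProto> config_;
    };

    // Per-device runtime: the pointer it is given outlives it, so a borrowed
    // pointer is enough and no copy is made.
    class FlrConfigSketch {
     public:
      explicit FlrConfigSketch(const tensorflow::ConfigProto* config)
          : config_(config) {}

      const tensorflow::ConfigProto* config_proto() const { return config_; }

     private:
      const tensorflow::ConfigProto* const config_;
    };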

PiperOrigin-RevId: 272763578
Eugene Brevdo 2019-10-03 16:12:13 -07:00 committed by TensorFlower Gardener
parent 11b69dada6
commit 90f01af49a
43 changed files with 185 additions and 118 deletions

View File

@ -447,6 +447,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
for (const auto& da : local_device_attributes) {
*base_request.add_cluster_device_attributes() = da;
}
base_request.mutable_server_def()
->mutable_default_session_config()
->MergeFrom(server_def.default_session_config());
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
LOG_AND_RETURN_IF_ERROR(

View File

@ -82,7 +82,8 @@ class CompilabilityCheckUtilTest : public ::testing::Test {
FunctionLibraryRuntime* GetFunctionLibraryRuntime() {
OptimizerOptions opts;
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
nullptr, Env::Default(), TF_GRAPH_DEF_VERSION, flib_def_.get(), opts);
nullptr, Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION,
flib_def_.get(), opts);
return pflr_->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
}

View File

@ -1196,11 +1196,12 @@ Status EncapsulateSubgraphsPass::Run(
std::unique_ptr<DeviceMgr> device_mgr =
absl::make_unique<StaticDeviceMgr>(std::move(devices));
OptimizerOptions opts;
const auto* config = &options.session_options->config;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(device_mgr.get(),
options.session_options->env,
TF_GRAPH_DEF_VERSION, library, opts));
new ProcessFunctionLibraryRuntime(
device_mgr.get(), options.session_options->env,
/*config=*/config, TF_GRAPH_DEF_VERSION, library,
config->graph_options().optimizer_options()));
FunctionLibraryRuntime* flr =
pflr->GetFLR("/job:localhost/replica:0/task:0/device:CPU:0");
if (flr == nullptr) {

View File

@ -513,8 +513,9 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
OptimizerOptions opts;
auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def.get(),
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
device_mgr.get(), Env::Default(), /*config=*/nullptr,
TF_GRAPH_DEF_VERSION, lib_def.get(), opts,
/*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
std::unique_ptr<Graph> graph_out;

View File

@ -246,8 +246,9 @@ class ExtractOutsideCompilationForFunctionTest : public ::testing::Test {
bool *has_outside_compilation) {
OptimizerOptions opts;
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, fld, opts,
/*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
TF_GRAPH_DEF_VERSION, fld, opts,
/*default_thread_pool=*/nullptr);
auto flr = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
return ExtractOutsideCompilationForFunction(
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,

View File

@ -1074,8 +1074,8 @@ StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) {
Status MarkForCompilationPassImpl::FindCompilationCandidates() {
OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(nullptr, env_, TF_GRAPH_DEF_VERSION,
flib_def_, opts));
new ProcessFunctionLibraryRuntime(nullptr, env_, /*config=*/nullptr,
TF_GRAPH_DEF_VERSION, flib_def_, opts));
FunctionLibraryRuntime* lib_runtime =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
std::vector<bool> compile_time_const_nodes(graph_->num_node_ids(), false);

View File

@ -295,7 +295,7 @@ Status PartiallyDeclusterGraph(Graph* graph,
std::vector<bool> compile_time_const_nodes(graph->num_node_ids());
OptimizerOptions opts;
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
nullptr, env, TF_GRAPH_DEF_VERSION, flib_def, opts);
nullptr, env, /*config=*/nullptr, TF_GRAPH_DEF_VERSION, flib_def, opts);
FunctionLibraryRuntime* lib_runtime =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*graph, nullptr,

View File

@ -148,9 +148,9 @@ xla::StatusOr<std::vector<string>> GetNodesRelatedToRefVarsSorted(
TF_RETURN_IF_ERROR(scope.ToGraph(graph.get()));
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(nullptr, Env::Default(),
TF_GRAPH_DEF_VERSION, flib_def,
OptimizerOptions{}));
new ProcessFunctionLibraryRuntime(
nullptr, Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION,
flib_def, OptimizerOptions{}));
FunctionLibraryRuntime* lib_runtime =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

View File

@ -73,8 +73,9 @@ class XlaKernelCreatorTest : public ::testing::Test {
OptimizerOptions opts;
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts,
/*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
}

View File

@ -222,10 +222,12 @@ Status FunctionalizeControlFlowForXlaPass::Run(
DumpGraphToFile("functionalize_control_flow_before", *graph,
options.flib_def);
}
const auto* config = &options.session_options->config;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(
/*device_mgr=*/nullptr, options.session_options->env,
TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions()));
/*device_mgr=*/nullptr, options.session_options->env, config,
TF_GRAPH_DEF_VERSION, options.flib_def,
config->graph_options().optimizer_options()));
FunctionLibraryRuntime* flr =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

View File

@ -55,8 +55,8 @@ void AnalyzeAndVerify(
ConvertGraphDefToGraph(GraphConstructorOptions(), graphdef, graph.get()));
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
nullptr, Env::Default(), TF_GRAPH_DEF_VERSION, flib_def,
OptimizerOptions());
nullptr, Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION,
flib_def, OptimizerOptions());
FunctionLibraryRuntime* lib_runtime =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,

View File

@ -293,8 +293,8 @@ TEST(CachedFunctionHandles, Basic) {
FunctionLibraryDefinition fld(OpRegistry::Global(), proto);
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(
/*device_mgr=*/nullptr, Env::Default(), TF_GRAPH_DEF_VERSION, &fld,
OptimizerOptions()));
/*device_mgr=*/nullptr, Env::Default(), /*config=*/nullptr,
TF_GRAPH_DEF_VERSION, &fld, OptimizerOptions()));
FunctionLibraryRuntime* flr =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

View File

@ -517,11 +517,11 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
FunctionDefLibrary{}));
local_pflr_.reset(new ProcessFunctionLibraryRuntime(
&device_mgr_, Env::Default(), options.graph_def_version,
local_flib_def_.get(), OptimizerOptions()));
&device_mgr_, Env::Default(), /*config=*/nullptr,
options.graph_def_version, local_flib_def_.get(), OptimizerOptions()));
pflr_.reset(new ProcessFunctionLibraryRuntime(
&device_mgr_, Env::Default(), options.graph_def_version, options.flib_def,
OptimizerOptions()));
&device_mgr_, Env::Default(), /*config=*/nullptr,
options.graph_def_version, options.flib_def, OptimizerOptions()));
local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
flib_runtime_ = pflr_->GetFLR(device_->name());

View File

@ -1292,7 +1292,7 @@ Status DirectSession::CreateExecutors(
? &options_.config.experimental().session_metadata()
: nullptr;
func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), options_.env, graph_def_version,
device_mgr_.get(), options_.env, &options_.config, graph_def_version,
func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first,
nullptr, nullptr, session_metadata));

View File

@ -80,9 +80,9 @@ EagerContext::EagerContext(
thread_pool_(NewThreadPoolFromSessionOptions(opts)),
custom_kernel_creator_(custom_kernel_creator),
pflr_(new ProcessFunctionLibraryRuntime(
device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_,
opts.config.graph_options().optimizer_options(), thread_pool_.get(),
cluster_flr, custom_kernel_creator_)),
device_mgr, opts.env, &opts.config, TF_GRAPH_DEF_VERSION,
&func_lib_def_, opts.config.graph_options().optimizer_options(),
thread_pool_.get(), cluster_flr, custom_kernel_creator_)),
log_device_placement_(opts.config.log_device_placement()),
allow_soft_placement_(opts.config.allow_soft_placement()),
num_active_steps_(0),
@ -697,9 +697,13 @@ Status EagerContext::StoreCollectiveOpsServer(
}
}
const ConfigProto* config = pflr_ ? pflr_->config() : nullptr;
pflr_.reset(new ProcessFunctionLibraryRuntime(
local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_,
{}, thread_pool_.get()));
local_unowned_device_manager_, env_, /*config=*/config,
TF_GRAPH_DEF_VERSION, &func_lib_def_,
/*optimizer_options=*/
config ? config->graph_options().optimizer_options() : OptimizerOptions(),
thread_pool_.get()));
// Memory leak!
if (server_ != nullptr) {
@ -849,9 +853,11 @@ Status EagerContext::SetMasterContextState(
entry.second->ClearError();
}
}
const auto* config = pflr_->config();
pflr_.reset(new ProcessFunctionLibraryRuntime(
local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_,
{}, thread_pool_.get(), cluster_flr, custom_kernel_creator_));
local_unowned_device_manager_, env_, config, TF_GRAPH_DEF_VERSION,
&func_lib_def_, config->graph_options().optimizer_options(),
thread_pool_.get(), cluster_flr, custom_kernel_creator_));
keep_alive_secs_ = keep_alive_secs;
sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2);
@ -995,9 +1001,10 @@ Status EagerContext::UpdateRemoteWorker(
}
SessionOptions options = SessionOptions();
const auto* config = pflr_->config();
pflr_.reset(new ProcessFunctionLibraryRuntime(
worker_session_device_mgr, options.env, TF_GRAPH_DEF_VERSION,
FuncLibDef(), options.config.graph_options().optimizer_options(),
worker_session_device_mgr, options.env, config, TF_GRAPH_DEF_VERSION,
FuncLibDef(), config->graph_options().optimizer_options(),
thread_pool_.get(), cluster_flr, custom_kernel_creator_));
return Status::OK();
}

View File

@ -48,8 +48,9 @@ class TestEnv {
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
OptimizerOptions opts;
pflr_ = tensorflow::MakeUnique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, &flib_def_,
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
TF_GRAPH_DEF_VERSION, &flib_def_, opts,
/*default_thread_pool=*/nullptr);
flr_ = pflr_->GetFLR("/job:a/replica:0/task:0/device:CPU:0");
CHECK(flr_ != nullptr);

View File

@ -45,15 +45,15 @@ class EagerProcessFunctionLibraryRuntime
: public ProcessFunctionLibraryRuntime {
public:
EagerProcessFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, int graph_def_version,
const FunctionLibraryDefinition* lib_def,
const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
int graph_def_version, const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
thread::ThreadPool* thread_pool = nullptr,
DistributedFunctionLibraryRuntime* parent = nullptr,
const CustomKernelCreator* custom_kernel_creator = nullptr)
: ProcessFunctionLibraryRuntime(device_mgr, env, graph_def_version,
lib_def, optimizer_options, thread_pool,
parent, custom_kernel_creator) {}
: ProcessFunctionLibraryRuntime(
device_mgr, env, config, graph_def_version, lib_def,
optimizer_options, thread_pool, parent, custom_kernel_creator) {}
#if !defined(IS_MOBILE_PLATFORM)
void Run(const FunctionLibraryRuntime::Options& opts,

View File

@ -195,6 +195,7 @@ class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
const override;
Env* env() override;
const ConfigProto* const config_proto() override;
Device* device() override;
const Device* device() const override;
std::function<void(std::function<void()>)>* runner() override;
@ -277,6 +278,10 @@ bool FunctionLibraryRuntimeOverlay::IsStateful(
Env* FunctionLibraryRuntimeOverlay::env() { return base_flr_->env(); }
const ConfigProto* const FunctionLibraryRuntimeOverlay::config_proto() {
return base_flr_->config_proto();
}
Device* FunctionLibraryRuntimeOverlay::device() { return base_flr_->device(); }
const Device* FunctionLibraryRuntimeOverlay::device() const {
@ -317,7 +322,8 @@ Status FunctionLibraryRuntimeOverlay::Clone(
class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
public:
FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, Device* device,
FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env,
const ConfigProto* config, Device* device,
int graph_def_version,
const FunctionLibraryDefinition* lib_def,
thread::ThreadPool* default_thread_pool,
@ -361,6 +367,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
const DeviceMgr* device_mgr() const override { return device_mgr_; }
Env* env() override { return env_; }
const ConfigProto* const config_proto() override { return config_; }
int graph_def_version() const override { return graph_def_version_; }
string DebugString(Handle h) override;
@ -376,6 +383,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
const DeviceMgr* const device_mgr_;
Device* const device_;
Env* const env_;
const ConfigProto* const config_;
const int graph_def_version_;
const FunctionLibraryDefinition* const base_lib_def_;
GraphOptimizer optimizer_;
@ -442,8 +450,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
};
FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version,
const FunctionLibraryDefinition* lib_def,
const DeviceMgr* dmgr, Env* env, const ConfigProto* config, Device* device,
int graph_def_version, const FunctionLibraryDefinition* lib_def,
thread::ThreadPool* default_thread_pool,
const OptimizerOptions& optimizer_options,
const CustomKernelCreator* custom_kernel_creator,
@ -452,6 +460,7 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
: device_mgr_(dmgr),
device_(device),
env_(env),
config_(config),
graph_def_version_(graph_def_version),
base_lib_def_(lib_def),
optimizer_(optimizer_options),
@ -1283,14 +1292,15 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator* c) {
}
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, Device* device,
int graph_def_version, const FunctionLibraryDefinition* lib_def,
thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
Device* device, int graph_def_version,
const FunctionLibraryDefinition* lib_def, thread::ThreadPool* thread_pool,
const OptimizerOptions& optimizer_options,
const CustomKernelCreator* custom_kernel_creator,
const SessionMetadata* session_metadata,
ProcessFunctionLibraryRuntime* parent) {
return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
device_mgr, env, device, graph_def_version, lib_def, thread_pool,
device_mgr, env, config, device, graph_def_version, lib_def, thread_pool,
optimizer_options, custom_kernel_creator, session_metadata, parent));
}

View File

@ -59,11 +59,12 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator* c);
// typically owns the created FunctionLibraryRuntime object. The parent pointer
// is not owned by the FunctionLibraryRuntime object.
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, Device* device,
int graph_def_version, const FunctionLibraryDefinition* lib_def,
thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
Device* device, int graph_def_version,
const FunctionLibraryDefinition* lib_def, thread::ThreadPool* thread_pool,
const OptimizerOptions& optimizer_options,
const CustomKernelCreator* custom_kernel_creator,
const SessionMetadata* sesson_metadata,
const SessionMetadata* session_metadata,
ProcessFunctionLibraryRuntime* parent);
// FunctionLibraryRuntime::GetFunctionBody returns a description of an

View File

@ -164,8 +164,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
OptimizerOptions opts;
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts));
device_mgr_.get(), Env::Default(), &options.config,
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts));
flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1");
flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2");

View File

@ -64,8 +64,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
OptimizerOptions opts;
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, default_thread_pool));
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, default_thread_pool));
flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
}

View File

@ -65,8 +65,8 @@ Status ProcessFunctionLibraryRuntime::FunctionData::DistributedInit(
}
ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, int graph_def_version,
const FunctionLibraryDefinition* lib_def,
const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
int graph_def_version, const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
thread::ThreadPool* default_thread_pool,
DistributedFunctionLibraryRuntime* parent,
@ -74,6 +74,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
const SessionMetadata* session_metadata)
: parent_(parent),
env_(env),
config_(config ? absl::make_optional(*config) : absl::nullopt),
device_mgr_(device_mgr),
lib_def_(lib_def),
default_thread_pool_(default_thread_pool),
@ -83,14 +84,16 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
session_metadata_(session_metadata) {
if (device_mgr == nullptr) {
(*flr_map_)[nullptr] = NewFunctionLibraryRuntime(
nullptr, env, nullptr, graph_def_version, lib_def_, default_thread_pool,
optimizer_options, custom_kernel_creator, session_metadata_, this);
nullptr, env, config_ ? &(*config_) : nullptr, nullptr,
graph_def_version, lib_def_, default_thread_pool, optimizer_options,
custom_kernel_creator, session_metadata_, this);
return;
}
for (Device* d : device_mgr->ListDevices()) {
(*flr_map_)[d] = NewFunctionLibraryRuntime(
device_mgr, env, d, graph_def_version, lib_def_, default_thread_pool,
optimizer_options, custom_kernel_creator, session_metadata_, this);
device_mgr, env, config_ ? &(*config_) : nullptr, d, graph_def_version,
lib_def_, default_thread_pool, optimizer_options, custom_kernel_creator,
session_metadata_, this);
}
DeviceMgr const* all_devices = device_mgr_;
@ -1330,9 +1333,9 @@ Status ProcessFunctionLibraryRuntime::Clone(
*out_lib_def = absl::make_unique<FunctionLibraryDefinition>(*lib_def_);
}
*out_pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_, env, graph_def_version, out_lib_def->get(),
optimizer_options, default_thread_pool_, parent_, custom_kernel_creator,
session_metadata_);
device_mgr_, env, config_ ? &(*config_) : nullptr, graph_def_version,
out_lib_def->get(), optimizer_options, default_thread_pool_, parent_,
custom_kernel_creator, session_metadata_);
return Status::OK();
}

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/function.h"
@ -41,8 +42,8 @@ class ProcessFunctionLibraryRuntime {
// DeviceMgr. Caller needs to make sure that device_mgr, lib_def and parent
// (if provided) outlive this object.
ProcessFunctionLibraryRuntime(
const DeviceMgr* device_mgr, Env* env, int graph_def_version,
const FunctionLibraryDefinition* lib_def,
const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
int graph_def_version, const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
thread::ThreadPool* thread_pool = nullptr,
DistributedFunctionLibraryRuntime* parent = nullptr,
@ -159,6 +160,8 @@ class ProcessFunctionLibraryRuntime {
const DeviceSet* device_set() { return &device_set_; }
const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; }
const FunctionLibraryDefinition* GetFunctionLibraryDefinition() const {
return lib_def_;
}
@ -381,6 +384,7 @@ class ProcessFunctionLibraryRuntime {
mutable mutex mu_;
Env* const env_;
const absl::optional<const ConfigProto> config_;
const DeviceMgr* const device_mgr_;
DeviceSet device_set_;
const FunctionLibraryDefinition* lib_def_;

View File

@ -121,8 +121,9 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
OptimizerOptions opts;
cluster_flr_.reset(new TestClusterFLR(device_mgr_.get()));
proc_flr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, nullptr, cluster_flr_.get(), nullptr, session_metadata));
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, nullptr, cluster_flr_.get(),
nullptr, session_metadata));
rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
}
@ -295,8 +296,8 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) {
OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr(
new ProcessFunctionLibraryRuntime(
nullptr /* device_mgr */, Env::Default(), TF_GRAPH_DEF_VERSION,
lib_def.get(), opts, nullptr, nullptr /* cluster_flr */));
nullptr /* device_mgr */, Env::Default(), /*config=*/nullptr,
TF_GRAPH_DEF_VERSION, lib_def.get(), opts));
FunctionLibraryRuntime* flr =
proc_flr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
EXPECT_NE(flr, nullptr);

View File

@ -128,9 +128,10 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
LOG(INFO) << "Creating " << (request->async() ? "async" : "sync")
<< " eager service context with rendezvous_id on host "
<< port::Hostname() << " " << worker_session->worker_name();
SessionOptions opts;
opts.config = request->server_def().default_session_config();
tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
SessionOptions(),
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
opts, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, request->async(),
device_mgr, false, r, GetDefaultCustomKernelCreator(),
worker_session->cluster_flr());

View File

@ -234,7 +234,6 @@ TEST_F(EagerServiceImplTest, BasicTest) {
TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
EnqueueRequest remote_enqueue_request;
remote_enqueue_request.set_context_id(context_id);
EnqueueResponse remote_enqueue_response;
@ -400,8 +399,9 @@ TEST_F(EagerServiceImplTest, EagerPFLRTest) {
auto device_mgr = absl::make_unique<StaticDeviceMgr>(
DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:1"));
auto eager_pflr = absl::make_unique<EagerProcessFunctionLibraryRuntime>(
device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, &func_lib_def,
OptimizerOptions(), nullptr, eager_cluster_flr.get(), nullptr);
device_mgr.get(), Env::Default(), /*config=*/nullptr,
TF_GRAPH_DEF_VERSION, &func_lib_def, OptimizerOptions(), nullptr,
eager_cluster_flr.get(), nullptr);
tensorflow::FunctionDef fdef = MatMulFunction();
TF_ASSERT_OK(func_lib_def.AddFunctionDef(fdef));

View File

@ -121,13 +121,11 @@ Status GraphMgr::DecorateAndPublishGraphForDebug(
//
// "executors" are filled with one executor per device if success and
// the caller takes the ownership of returned executors.
Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef,
WorkerSession* session,
const GraphOptions& graph_options,
const DebugOptions& debug_options,
int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr,
Item* item) {
Status GraphMgr::InitItem(
const string& handle, const GraphDef& gdef, WorkerSession* session,
const GraphOptions& graph_options, const DebugOptions& debug_options,
const ConfigProto& config_proto, int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr, Item* item) {
item->session = handle;
item->collective_graph_key = collective_graph_key;
item->lib_def.reset(
@ -139,9 +137,10 @@ Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef,
// does that below.
item->proc_flr.reset(new ProcessFunctionLibraryRuntime(
device_mgr_, worker_env_->env, gdef.versions().producer(),
item->lib_def.get(), graph_options.optimizer_options(),
worker_env_->compute_pool, cluster_flr));
device_mgr_, worker_env_->env, /*config=*/&config_proto,
gdef.versions().producer(), item->lib_def.get(),
graph_options.optimizer_options(), worker_env_->compute_pool,
cluster_flr));
// Constructs the graph out of "gdef".
Graph graph(OpRegistry::Global());
@ -287,16 +286,14 @@ Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef,
return Status::OK();
}
Status GraphMgr::Register(const string& handle, const GraphDef& gdef,
WorkerSession* session,
const GraphOptions& graph_options,
const DebugOptions& debug_options,
int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr,
string* graph_handle) {
Status GraphMgr::Register(
const string& handle, const GraphDef& gdef, WorkerSession* session,
const GraphOptions& graph_options, const DebugOptions& debug_options,
const ConfigProto& config_proto, int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr, string* graph_handle) {
Item* item = new Item;
Status s = InitItem(handle, gdef, session, graph_options, debug_options,
collective_graph_key, cluster_flr, item);
config_proto, collective_graph_key, cluster_flr, item);
if (!s.ok()) {
item->Unref();
return s;

View File

@ -76,7 +76,8 @@ class GraphMgr {
// reference to cluster_flr to do cross process function calls.
Status Register(const string& handle, const GraphDef& gdef,
WorkerSession* session, const GraphOptions& graph_options,
const DebugOptions& debug_options, int64 collective_graph_key,
const DebugOptions& debug_options,
const ConfigProto& config_proto, int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr,
string* graph_handle);
@ -179,7 +180,8 @@ class GraphMgr {
Status InitItem(const string& handle, const GraphDef& gdef,
WorkerSession* session, const GraphOptions& graph_options,
const DebugOptions& debug_options, int64 collective_graph_key,
const DebugOptions& debug_options,
const ConfigProto& config_proto, int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr, Item* item);
Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options,

View File

@ -472,6 +472,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
c->req.set_session_handle(session_handle_);
c->req.set_create_worker_session_called(!should_deregister_);
c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
*c->req.mutable_config_proto() = session_opts_.config;
*c->req.mutable_graph_options() = session_opts_.config.graph_options();
*c->req.mutable_debug_options() =
callable_opts_.run_options().debug_options();

View File

@ -80,8 +80,8 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
s = session->graph_mgr()->Register(
request->session_handle(), request->graph_def(), session.get(),
request->graph_options(), request->debug_options(),
request->collective_graph_key(), session->cluster_flr(),
response->mutable_graph_handle());
request->config_proto(), request->collective_graph_key(),
session->cluster_flr(), response->mutable_graph_handle());
}
done(s);
}

View File

@ -297,6 +297,7 @@ class IteratorContext {
allocator_getter = [device](AllocatorAttributes attrs) {
return device->GetAllocator(attrs);
};
thread::ThreadPool* thread_pool =
ctx->device()->tensorflow_device_thread_pool();
if (thread_pool) {

View File

@ -752,6 +752,9 @@ class FunctionLibraryRuntime {
// Returns the environment on which the function executes.
virtual Env* env() = 0;
// Returns the ConfigProto passed to the session used to create the function.
virtual const ConfigProto* const config_proto() = 0;
// Returns a debug string showing the definition of the function of
// 'handle'.
virtual string DebugString(Handle handle) = 0;

View File

@ -256,7 +256,7 @@ Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
// Create the function library runtime.
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env,
new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env, &options.config,
graph_def.versions().producer(),
&function_library, *optimizer_opts));
FunctionLibraryRuntime* flr = pflr->GetFLR(cpu_device->name());

View File

@ -529,10 +529,11 @@ Status CapturedFunction::Instantiate(
// The context's runtime will be used for all subsequent calls.
FunctionLibraryRuntime* lib = ctx->flr();
FunctionLibraryRuntime::InstantiateOptions inst_opts;
// TODO(b/141576771): Propagate ctx ConfigProto into inst_opts.config_proto.
inst_opts.lib_def = metadata_->lib_def();
inst_opts.create_kernels_eagerly = true;
inst_opts.default_device_to_target = metadata_->use_default_device();
inst_opts.config_proto =
lib->config_proto() ? *lib->config_proto() : ConfigProto();
if (!metadata_->use_inter_op_parallelism()) {
inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
}

View File

@ -419,8 +419,9 @@ Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
OptimizerOptions opts;
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, thread_pool_.get(), nullptr /* cluster_flr */);
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, thread_pool_.get(),
nullptr /* cluster_flr */);
flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
if (thread_pool_ == nullptr) {
runner_ = [](std::function<void()> fn) { fn(); };

View File

@ -355,16 +355,18 @@ FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
// in its resource manager. The existing device will outlive the
// IteratorResource, because we are storing the IteratorResource
// in that device's resource manager.
*device_mgr =
absl::make_unique<StaticDeviceMgr>(RenamedDevice::NewRenamedDevice(
ctx->device()->name(), down_cast<Device*>(ctx->device()),
false /* owns_underlying */, false /* isolate_session_state */));
*flib_def = absl::make_unique<FunctionLibraryDefinition>(
*ctx->function_library()->GetFunctionLibraryDefinition());
const auto* config = ctx->function_library()->config_proto();
*pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr->get(), ctx->env(), graph_def_version_, flib_def->get(),
OptimizerOptions{} /* TODO(mrry): OptimizerOptions? */,
nullptr /* TODO(mrry): ClusterFLR */);
device_mgr->get(), ctx->env(),
/*config=*/config, graph_def_version_, flib_def->get(),
config->graph_options().optimizer_options());
return (*pflr)->GetFLR(ctx->device()->name());
}

View File

@ -325,7 +325,14 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
AttrValueMap attr_values = func_.attr();
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
// TODO(b/141576771): Propagate ctxt's ConfigProto into opts.config_proto.
const auto* config = (ctx->function_library())
? ctx->function_library()->config_proto()
: nullptr;
if (config) {
instantiate_opts.config_proto = *config;
}
instantiate_opts.target = target_device;
FunctionTarget function_target = {target_device, lib};

View File

@ -29,8 +29,8 @@ void OpsTestBase::SetDevice(const DeviceType& device_type,
device_ = device.get();
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device));
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, flib_def_.get(),
OptimizerOptions());
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions());
device_type_ = device_type;
#ifdef GOOGLE_CUDA

View File

@ -84,8 +84,8 @@ class OpsTestBase : public ::testing::Test {
flib_def_ = absl::make_unique<FunctionLibraryDefinition>(
OpRegistry::Global(), FunctionDefLibrary{});
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION,
flib_def_.get(), OptimizerOptions());
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions());
}
~OpsTestBase() override {

View File

@ -166,7 +166,12 @@ Status PartitionedCallOp::Instantiate(FunctionLibraryRuntime* lib,
std::vector<Tensor>* inputs,
FunctionLibraryRuntime::Handle* handle) {
FunctionLibraryRuntime::InstantiateOptions opts;
// TODO(b/141576771): Propagate ctxt's ConfigProto into opts.config_proto.
const auto* config = (ctx->function_library())
? ctx->function_library()->config_proto()
: nullptr;
if (config) {
opts.config_proto = *config;
}
#ifndef __ANDROID__
// Android tf library does not include grappler.

View File

@ -132,6 +132,11 @@ message RegisterGraphRequest {
// concurrently so that BufRendezvous entries will make the correct
// values accessible.
int64 collective_graph_key = 7;
// ConfigProto from the session in which this graph was created.
// Contains additional parameters beyond graph_options, including
// the name of the requested executor.
ConfigProto config_proto = 8;
}
message RegisterGraphResponse {

View File

@ -2237,8 +2237,9 @@ bool InlineAllFunctions(GraphDef* graphdef) {
tensorflow::StaticDeviceMgr device_mgr(std::move(devices));
tensorflow::OptimizerOptions o_opts;
tensorflow::ProcessFunctionLibraryRuntime pflr(
&device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld,
o_opts, nullptr);
&device_mgr, tensorflow::Env::Default(), &options.config,
TF_GRAPH_DEF_VERSION, &fld,
options.config.graph_options().optimizer_options(), nullptr);
tensorflow::FunctionLibraryRuntime* flr;
flr = pflr.GetFLR("/job:localhost/replica:0/task:0/cpu:0");

View File

@ -124,8 +124,11 @@ def connect_to_cluster(cluster_spec_or_resolver,
job_def.tasks[0] = "localhost:{}".format(local_port)
server_def = ServerDef(
cluster=cluster_def, job_name=job_name, task_index=task_index,
protocol=protocol)
cluster=cluster_def,
job_name=job_name,
task_index=task_index,
protocol=protocol,
default_session_config=context.context().config)
context.set_server_def(server_def)