Pipe ConfigProto through FLR so that it can be accessed by Ops like PartitionedCallOp.
Also pass the ConfigProto through distributed function calls both in the standard graph registration mode and in the new eager master setup. The PFLR stores a std::optional<ConfigProto> instead of a pointer, because it may be created with a pointer that would dangle after its creation. At the same time, we need to know if a ConfigProto was available at creation time, which is why it's a std::optional. In contrast, the FLR gets a pointer directly because it is given a valid pointer that will outlast it in all cases. PiperOrigin-RevId: 272763578
This commit is contained in:
parent
11b69dada6
commit
90f01af49a
@ -447,6 +447,9 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
for (const auto& da : local_device_attributes) {
|
||||
*base_request.add_cluster_device_attributes() = da;
|
||||
}
|
||||
base_request.mutable_server_def()
|
||||
->mutable_default_session_config()
|
||||
->MergeFrom(server_def.default_session_config());
|
||||
|
||||
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
|
@ -82,7 +82,8 @@ class CompilabilityCheckUtilTest : public ::testing::Test {
|
||||
FunctionLibraryRuntime* GetFunctionLibraryRuntime() {
|
||||
OptimizerOptions opts;
|
||||
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
nullptr, Env::Default(), TF_GRAPH_DEF_VERSION, flib_def_.get(), opts);
|
||||
nullptr, Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION,
|
||||
flib_def_.get(), opts);
|
||||
|
||||
return pflr_->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
|
||||
}
|
||||
|
@ -1196,11 +1196,12 @@ Status EncapsulateSubgraphsPass::Run(
|
||||
|
||||
std::unique_ptr<DeviceMgr> device_mgr =
|
||||
absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
OptimizerOptions opts;
|
||||
const auto* config = &options.session_options->config;
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
|
||||
new ProcessFunctionLibraryRuntime(device_mgr.get(),
|
||||
options.session_options->env,
|
||||
TF_GRAPH_DEF_VERSION, library, opts));
|
||||
new ProcessFunctionLibraryRuntime(
|
||||
device_mgr.get(), options.session_options->env,
|
||||
/*config=*/config, TF_GRAPH_DEF_VERSION, library,
|
||||
config->graph_options().optimizer_options()));
|
||||
FunctionLibraryRuntime* flr =
|
||||
pflr->GetFLR("/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
if (flr == nullptr) {
|
||||
|
@ -513,8 +513,9 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
|
||||
OptimizerOptions opts;
|
||||
auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def.get(),
|
||||
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
|
||||
device_mgr.get(), Env::Default(), /*config=*/nullptr,
|
||||
TF_GRAPH_DEF_VERSION, lib_def.get(), opts,
|
||||
/*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
|
||||
auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
|
||||
|
||||
std::unique_ptr<Graph> graph_out;
|
||||
|
@ -246,8 +246,9 @@ class ExtractOutsideCompilationForFunctionTest : public ::testing::Test {
|
||||
bool *has_outside_compilation) {
|
||||
OptimizerOptions opts;
|
||||
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, fld, opts,
|
||||
/*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
|
||||
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
|
||||
TF_GRAPH_DEF_VERSION, fld, opts,
|
||||
/*default_thread_pool=*/nullptr);
|
||||
auto flr = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
|
||||
return ExtractOutsideCompilationForFunction(
|
||||
xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
|
||||
|
@ -1074,8 +1074,8 @@ StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) {
|
||||
Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||
OptimizerOptions opts;
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
|
||||
new ProcessFunctionLibraryRuntime(nullptr, env_, TF_GRAPH_DEF_VERSION,
|
||||
flib_def_, opts));
|
||||
new ProcessFunctionLibraryRuntime(nullptr, env_, /*config=*/nullptr,
|
||||
TF_GRAPH_DEF_VERSION, flib_def_, opts));
|
||||
FunctionLibraryRuntime* lib_runtime =
|
||||
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
|
||||
std::vector<bool> compile_time_const_nodes(graph_->num_node_ids(), false);
|
||||
|
@ -295,7 +295,7 @@ Status PartiallyDeclusterGraph(Graph* graph,
|
||||
std::vector<bool> compile_time_const_nodes(graph->num_node_ids());
|
||||
OptimizerOptions opts;
|
||||
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
nullptr, env, TF_GRAPH_DEF_VERSION, flib_def, opts);
|
||||
nullptr, env, /*config=*/nullptr, TF_GRAPH_DEF_VERSION, flib_def, opts);
|
||||
FunctionLibraryRuntime* lib_runtime =
|
||||
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
|
||||
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*graph, nullptr,
|
||||
|
@ -148,9 +148,9 @@ xla::StatusOr<std::vector<string>> GetNodesRelatedToRefVarsSorted(
|
||||
TF_RETURN_IF_ERROR(scope.ToGraph(graph.get()));
|
||||
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
|
||||
new ProcessFunctionLibraryRuntime(nullptr, Env::Default(),
|
||||
TF_GRAPH_DEF_VERSION, flib_def,
|
||||
OptimizerOptions{}));
|
||||
new ProcessFunctionLibraryRuntime(
|
||||
nullptr, Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION,
|
||||
flib_def, OptimizerOptions{}));
|
||||
FunctionLibraryRuntime* lib_runtime =
|
||||
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
|
||||
|
||||
|
@ -73,8 +73,9 @@ class XlaKernelCreatorTest : public ::testing::Test {
|
||||
OptimizerOptions opts;
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
||||
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
|
||||
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
|
||||
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts,
|
||||
/*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
|
||||
flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
|
||||
}
|
||||
|
||||
|
@ -222,10 +222,12 @@ Status FunctionalizeControlFlowForXlaPass::Run(
|
||||
DumpGraphToFile("functionalize_control_flow_before", *graph,
|
||||
options.flib_def);
|
||||
}
|
||||
const auto* config = &options.session_options->config;
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
|
||||
new ProcessFunctionLibraryRuntime(
|
||||
/*device_mgr=*/nullptr, options.session_options->env,
|
||||
TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions()));
|
||||
/*device_mgr=*/nullptr, options.session_options->env, config,
|
||||
TF_GRAPH_DEF_VERSION, options.flib_def,
|
||||
config->graph_options().optimizer_options()));
|
||||
FunctionLibraryRuntime* flr =
|
||||
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
|
||||
|
||||
|
@ -55,8 +55,8 @@ void AnalyzeAndVerify(
|
||||
ConvertGraphDefToGraph(GraphConstructorOptions(), graphdef, graph.get()));
|
||||
|
||||
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
nullptr, Env::Default(), TF_GRAPH_DEF_VERSION, flib_def,
|
||||
OptimizerOptions());
|
||||
nullptr, Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION,
|
||||
flib_def, OptimizerOptions());
|
||||
FunctionLibraryRuntime* lib_runtime =
|
||||
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
|
||||
absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
|
||||
|
@ -293,8 +293,8 @@ TEST(CachedFunctionHandles, Basic) {
|
||||
FunctionLibraryDefinition fld(OpRegistry::Global(), proto);
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
|
||||
new ProcessFunctionLibraryRuntime(
|
||||
/*device_mgr=*/nullptr, Env::Default(), TF_GRAPH_DEF_VERSION, &fld,
|
||||
OptimizerOptions()));
|
||||
/*device_mgr=*/nullptr, Env::Default(), /*config=*/nullptr,
|
||||
TF_GRAPH_DEF_VERSION, &fld, OptimizerOptions()));
|
||||
FunctionLibraryRuntime* flr =
|
||||
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
|
||||
|
||||
|
@ -517,11 +517,11 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
|
||||
local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
|
||||
FunctionDefLibrary{}));
|
||||
local_pflr_.reset(new ProcessFunctionLibraryRuntime(
|
||||
&device_mgr_, Env::Default(), options.graph_def_version,
|
||||
local_flib_def_.get(), OptimizerOptions()));
|
||||
&device_mgr_, Env::Default(), /*config=*/nullptr,
|
||||
options.graph_def_version, local_flib_def_.get(), OptimizerOptions()));
|
||||
pflr_.reset(new ProcessFunctionLibraryRuntime(
|
||||
&device_mgr_, Env::Default(), options.graph_def_version, options.flib_def,
|
||||
OptimizerOptions()));
|
||||
&device_mgr_, Env::Default(), /*config=*/nullptr,
|
||||
options.graph_def_version, options.flib_def, OptimizerOptions()));
|
||||
|
||||
local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
|
||||
flib_runtime_ = pflr_->GetFLR(device_->name());
|
||||
|
@ -1292,7 +1292,7 @@ Status DirectSession::CreateExecutors(
|
||||
? &options_.config.experimental().session_metadata()
|
||||
: nullptr;
|
||||
func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
|
||||
device_mgr_.get(), options_.env, graph_def_version,
|
||||
device_mgr_.get(), options_.env, &options_.config, graph_def_version,
|
||||
func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first,
|
||||
nullptr, nullptr, session_metadata));
|
||||
|
||||
|
@ -80,9 +80,9 @@ EagerContext::EagerContext(
|
||||
thread_pool_(NewThreadPoolFromSessionOptions(opts)),
|
||||
custom_kernel_creator_(custom_kernel_creator),
|
||||
pflr_(new ProcessFunctionLibraryRuntime(
|
||||
device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_,
|
||||
opts.config.graph_options().optimizer_options(), thread_pool_.get(),
|
||||
cluster_flr, custom_kernel_creator_)),
|
||||
device_mgr, opts.env, &opts.config, TF_GRAPH_DEF_VERSION,
|
||||
&func_lib_def_, opts.config.graph_options().optimizer_options(),
|
||||
thread_pool_.get(), cluster_flr, custom_kernel_creator_)),
|
||||
log_device_placement_(opts.config.log_device_placement()),
|
||||
allow_soft_placement_(opts.config.allow_soft_placement()),
|
||||
num_active_steps_(0),
|
||||
@ -697,9 +697,13 @@ Status EagerContext::StoreCollectiveOpsServer(
|
||||
}
|
||||
}
|
||||
|
||||
const ConfigProto* config = pflr_ ? pflr_->config() : nullptr;
|
||||
pflr_.reset(new ProcessFunctionLibraryRuntime(
|
||||
local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_,
|
||||
{}, thread_pool_.get()));
|
||||
local_unowned_device_manager_, env_, /*config=*/config,
|
||||
TF_GRAPH_DEF_VERSION, &func_lib_def_,
|
||||
/*optimizer_options=*/
|
||||
config ? config->graph_options().optimizer_options() : OptimizerOptions(),
|
||||
thread_pool_.get()));
|
||||
|
||||
// Memory leak!
|
||||
if (server_ != nullptr) {
|
||||
@ -849,9 +853,11 @@ Status EagerContext::SetMasterContextState(
|
||||
entry.second->ClearError();
|
||||
}
|
||||
}
|
||||
const auto* config = pflr_->config();
|
||||
pflr_.reset(new ProcessFunctionLibraryRuntime(
|
||||
local_unowned_device_manager_, env_, TF_GRAPH_DEF_VERSION, &func_lib_def_,
|
||||
{}, thread_pool_.get(), cluster_flr, custom_kernel_creator_));
|
||||
local_unowned_device_manager_, env_, config, TF_GRAPH_DEF_VERSION,
|
||||
&func_lib_def_, config->graph_options().optimizer_options(),
|
||||
thread_pool_.get(), cluster_flr, custom_kernel_creator_));
|
||||
|
||||
keep_alive_secs_ = keep_alive_secs;
|
||||
sleep_for_secs_ = std::max(1, keep_alive_secs_ / 2);
|
||||
@ -995,9 +1001,10 @@ Status EagerContext::UpdateRemoteWorker(
|
||||
}
|
||||
|
||||
SessionOptions options = SessionOptions();
|
||||
const auto* config = pflr_->config();
|
||||
pflr_.reset(new ProcessFunctionLibraryRuntime(
|
||||
worker_session_device_mgr, options.env, TF_GRAPH_DEF_VERSION,
|
||||
FuncLibDef(), options.config.graph_options().optimizer_options(),
|
||||
worker_session_device_mgr, options.env, config, TF_GRAPH_DEF_VERSION,
|
||||
FuncLibDef(), config->graph_options().optimizer_options(),
|
||||
thread_pool_.get(), cluster_flr, custom_kernel_creator_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -48,8 +48,9 @@ class TestEnv {
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
OptimizerOptions opts;
|
||||
pflr_ = tensorflow::MakeUnique<ProcessFunctionLibraryRuntime>(
|
||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, &flib_def_,
|
||||
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
|
||||
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
|
||||
TF_GRAPH_DEF_VERSION, &flib_def_, opts,
|
||||
/*default_thread_pool=*/nullptr);
|
||||
|
||||
flr_ = pflr_->GetFLR("/job:a/replica:0/task:0/device:CPU:0");
|
||||
CHECK(flr_ != nullptr);
|
||||
|
@ -45,15 +45,15 @@ class EagerProcessFunctionLibraryRuntime
|
||||
: public ProcessFunctionLibraryRuntime {
|
||||
public:
|
||||
EagerProcessFunctionLibraryRuntime(
|
||||
const DeviceMgr* device_mgr, Env* env, int graph_def_version,
|
||||
const FunctionLibraryDefinition* lib_def,
|
||||
const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
|
||||
int graph_def_version, const FunctionLibraryDefinition* lib_def,
|
||||
const OptimizerOptions& optimizer_options,
|
||||
thread::ThreadPool* thread_pool = nullptr,
|
||||
DistributedFunctionLibraryRuntime* parent = nullptr,
|
||||
const CustomKernelCreator* custom_kernel_creator = nullptr)
|
||||
: ProcessFunctionLibraryRuntime(device_mgr, env, graph_def_version,
|
||||
lib_def, optimizer_options, thread_pool,
|
||||
parent, custom_kernel_creator) {}
|
||||
: ProcessFunctionLibraryRuntime(
|
||||
device_mgr, env, config, graph_def_version, lib_def,
|
||||
optimizer_options, thread_pool, parent, custom_kernel_creator) {}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
void Run(const FunctionLibraryRuntime::Options& opts,
|
||||
|
@ -195,6 +195,7 @@ class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
|
||||
const override;
|
||||
|
||||
Env* env() override;
|
||||
const ConfigProto* const config_proto() override;
|
||||
Device* device() override;
|
||||
const Device* device() const override;
|
||||
std::function<void(std::function<void()>)>* runner() override;
|
||||
@ -277,6 +278,10 @@ bool FunctionLibraryRuntimeOverlay::IsStateful(
|
||||
|
||||
Env* FunctionLibraryRuntimeOverlay::env() { return base_flr_->env(); }
|
||||
|
||||
const ConfigProto* const FunctionLibraryRuntimeOverlay::config_proto() {
|
||||
return base_flr_->config_proto();
|
||||
}
|
||||
|
||||
Device* FunctionLibraryRuntimeOverlay::device() { return base_flr_->device(); }
|
||||
|
||||
const Device* FunctionLibraryRuntimeOverlay::device() const {
|
||||
@ -317,7 +322,8 @@ Status FunctionLibraryRuntimeOverlay::Clone(
|
||||
|
||||
class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
|
||||
public:
|
||||
FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, Device* device,
|
||||
FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env,
|
||||
const ConfigProto* config, Device* device,
|
||||
int graph_def_version,
|
||||
const FunctionLibraryDefinition* lib_def,
|
||||
thread::ThreadPool* default_thread_pool,
|
||||
@ -361,6 +367,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
|
||||
|
||||
const DeviceMgr* device_mgr() const override { return device_mgr_; }
|
||||
Env* env() override { return env_; }
|
||||
const ConfigProto* const config_proto() override { return config_; }
|
||||
int graph_def_version() const override { return graph_def_version_; }
|
||||
|
||||
string DebugString(Handle h) override;
|
||||
@ -376,6 +383,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
|
||||
const DeviceMgr* const device_mgr_;
|
||||
Device* const device_;
|
||||
Env* const env_;
|
||||
const ConfigProto* const config_;
|
||||
const int graph_def_version_;
|
||||
const FunctionLibraryDefinition* const base_lib_def_;
|
||||
GraphOptimizer optimizer_;
|
||||
@ -442,8 +450,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
|
||||
};
|
||||
|
||||
FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
|
||||
const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version,
|
||||
const FunctionLibraryDefinition* lib_def,
|
||||
const DeviceMgr* dmgr, Env* env, const ConfigProto* config, Device* device,
|
||||
int graph_def_version, const FunctionLibraryDefinition* lib_def,
|
||||
thread::ThreadPool* default_thread_pool,
|
||||
const OptimizerOptions& optimizer_options,
|
||||
const CustomKernelCreator* custom_kernel_creator,
|
||||
@ -452,6 +460,7 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
|
||||
: device_mgr_(dmgr),
|
||||
device_(device),
|
||||
env_(env),
|
||||
config_(config),
|
||||
graph_def_version_(graph_def_version),
|
||||
base_lib_def_(lib_def),
|
||||
optimizer_(optimizer_options),
|
||||
@ -1283,14 +1292,15 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator* c) {
|
||||
}
|
||||
|
||||
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
|
||||
const DeviceMgr* device_mgr, Env* env, Device* device,
|
||||
int graph_def_version, const FunctionLibraryDefinition* lib_def,
|
||||
thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
|
||||
const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
|
||||
Device* device, int graph_def_version,
|
||||
const FunctionLibraryDefinition* lib_def, thread::ThreadPool* thread_pool,
|
||||
const OptimizerOptions& optimizer_options,
|
||||
const CustomKernelCreator* custom_kernel_creator,
|
||||
const SessionMetadata* session_metadata,
|
||||
ProcessFunctionLibraryRuntime* parent) {
|
||||
return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl(
|
||||
device_mgr, env, device, graph_def_version, lib_def, thread_pool,
|
||||
device_mgr, env, config, device, graph_def_version, lib_def, thread_pool,
|
||||
optimizer_options, custom_kernel_creator, session_metadata, parent));
|
||||
}
|
||||
|
||||
|
@ -59,11 +59,12 @@ void RegisterDefaultCustomKernelCreator(CustomKernelCreator* c);
|
||||
// typically owns the created FunctionLibraryRuntime object. The parent pointer
|
||||
// is not owned by the FunctionLibraryRuntime object.
|
||||
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
|
||||
const DeviceMgr* device_mgr, Env* env, Device* device,
|
||||
int graph_def_version, const FunctionLibraryDefinition* lib_def,
|
||||
thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
|
||||
const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
|
||||
Device* device, int graph_def_version,
|
||||
const FunctionLibraryDefinition* lib_def, thread::ThreadPool* thread_pool,
|
||||
const OptimizerOptions& optimizer_options,
|
||||
const CustomKernelCreator* custom_kernel_creator,
|
||||
const SessionMetadata* sesson_metadata,
|
||||
const SessionMetadata* session_metadata,
|
||||
ProcessFunctionLibraryRuntime* parent);
|
||||
|
||||
// FunctionLibraryRuntime::GetFunctionBody returns a description of an
|
||||
|
@ -164,8 +164,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
OptimizerOptions opts;
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
pflr_.reset(new ProcessFunctionLibraryRuntime(
|
||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
||||
opts));
|
||||
device_mgr_.get(), Env::Default(), &options.config,
|
||||
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts));
|
||||
flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
|
||||
flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1");
|
||||
flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2");
|
||||
|
@ -64,8 +64,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
OptimizerOptions opts;
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
|
||||
pflr_.reset(new ProcessFunctionLibraryRuntime(
|
||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
||||
opts, default_thread_pool));
|
||||
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
|
||||
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, default_thread_pool));
|
||||
flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
|
||||
}
|
||||
|
||||
|
@ -65,8 +65,8 @@ Status ProcessFunctionLibraryRuntime::FunctionData::DistributedInit(
|
||||
}
|
||||
|
||||
ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
|
||||
const DeviceMgr* device_mgr, Env* env, int graph_def_version,
|
||||
const FunctionLibraryDefinition* lib_def,
|
||||
const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
|
||||
int graph_def_version, const FunctionLibraryDefinition* lib_def,
|
||||
const OptimizerOptions& optimizer_options,
|
||||
thread::ThreadPool* default_thread_pool,
|
||||
DistributedFunctionLibraryRuntime* parent,
|
||||
@ -74,6 +74,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
|
||||
const SessionMetadata* session_metadata)
|
||||
: parent_(parent),
|
||||
env_(env),
|
||||
config_(config ? absl::make_optional(*config) : absl::nullopt),
|
||||
device_mgr_(device_mgr),
|
||||
lib_def_(lib_def),
|
||||
default_thread_pool_(default_thread_pool),
|
||||
@ -83,14 +84,16 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
|
||||
session_metadata_(session_metadata) {
|
||||
if (device_mgr == nullptr) {
|
||||
(*flr_map_)[nullptr] = NewFunctionLibraryRuntime(
|
||||
nullptr, env, nullptr, graph_def_version, lib_def_, default_thread_pool,
|
||||
optimizer_options, custom_kernel_creator, session_metadata_, this);
|
||||
nullptr, env, config_ ? &(*config_) : nullptr, nullptr,
|
||||
graph_def_version, lib_def_, default_thread_pool, optimizer_options,
|
||||
custom_kernel_creator, session_metadata_, this);
|
||||
return;
|
||||
}
|
||||
for (Device* d : device_mgr->ListDevices()) {
|
||||
(*flr_map_)[d] = NewFunctionLibraryRuntime(
|
||||
device_mgr, env, d, graph_def_version, lib_def_, default_thread_pool,
|
||||
optimizer_options, custom_kernel_creator, session_metadata_, this);
|
||||
device_mgr, env, config_ ? &(*config_) : nullptr, d, graph_def_version,
|
||||
lib_def_, default_thread_pool, optimizer_options, custom_kernel_creator,
|
||||
session_metadata_, this);
|
||||
}
|
||||
|
||||
DeviceMgr const* all_devices = device_mgr_;
|
||||
@ -1330,9 +1333,9 @@ Status ProcessFunctionLibraryRuntime::Clone(
|
||||
*out_lib_def = absl::make_unique<FunctionLibraryDefinition>(*lib_def_);
|
||||
}
|
||||
*out_pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
device_mgr_, env, graph_def_version, out_lib_def->get(),
|
||||
optimizer_options, default_thread_pool_, parent_, custom_kernel_creator,
|
||||
session_metadata_);
|
||||
device_mgr_, env, config_ ? &(*config_) : nullptr, graph_def_version,
|
||||
out_lib_def->get(), optimizer_options, default_thread_pool_, parent_,
|
||||
custom_kernel_creator, session_metadata_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
// clang-format on
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
@ -41,8 +42,8 @@ class ProcessFunctionLibraryRuntime {
|
||||
// DeviceMgr. Caller needs to make sure that device_mgr, lib_def and parent
|
||||
// (if provided) outlive this object.
|
||||
ProcessFunctionLibraryRuntime(
|
||||
const DeviceMgr* device_mgr, Env* env, int graph_def_version,
|
||||
const FunctionLibraryDefinition* lib_def,
|
||||
const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
|
||||
int graph_def_version, const FunctionLibraryDefinition* lib_def,
|
||||
const OptimizerOptions& optimizer_options,
|
||||
thread::ThreadPool* thread_pool = nullptr,
|
||||
DistributedFunctionLibraryRuntime* parent = nullptr,
|
||||
@ -159,6 +160,8 @@ class ProcessFunctionLibraryRuntime {
|
||||
|
||||
const DeviceSet* device_set() { return &device_set_; }
|
||||
|
||||
const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; }
|
||||
|
||||
const FunctionLibraryDefinition* GetFunctionLibraryDefinition() const {
|
||||
return lib_def_;
|
||||
}
|
||||
@ -381,6 +384,7 @@ class ProcessFunctionLibraryRuntime {
|
||||
mutable mutex mu_;
|
||||
|
||||
Env* const env_;
|
||||
const absl::optional<const ConfigProto> config_;
|
||||
const DeviceMgr* const device_mgr_;
|
||||
DeviceSet device_set_;
|
||||
const FunctionLibraryDefinition* lib_def_;
|
||||
|
@ -121,8 +121,9 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
OptimizerOptions opts;
|
||||
cluster_flr_.reset(new TestClusterFLR(device_mgr_.get()));
|
||||
proc_flr_.reset(new ProcessFunctionLibraryRuntime(
|
||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
||||
opts, nullptr, cluster_flr_.get(), nullptr, session_metadata));
|
||||
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
|
||||
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, nullptr, cluster_flr_.get(),
|
||||
nullptr, session_metadata));
|
||||
rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
|
||||
}
|
||||
|
||||
@ -295,8 +296,8 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) {
|
||||
OptimizerOptions opts;
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr(
|
||||
new ProcessFunctionLibraryRuntime(
|
||||
nullptr /* device_mgr */, Env::Default(), TF_GRAPH_DEF_VERSION,
|
||||
lib_def.get(), opts, nullptr, nullptr /* cluster_flr */));
|
||||
nullptr /* device_mgr */, Env::Default(), /*config=*/nullptr,
|
||||
TF_GRAPH_DEF_VERSION, lib_def.get(), opts));
|
||||
FunctionLibraryRuntime* flr =
|
||||
proc_flr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
|
||||
EXPECT_NE(flr, nullptr);
|
||||
|
@ -128,9 +128,10 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
|
||||
LOG(INFO) << "Creating " << (request->async() ? "async" : "sync")
|
||||
<< " eager service context with rendezvous_id on host "
|
||||
<< port::Hostname() << " " << worker_session->worker_name();
|
||||
SessionOptions opts;
|
||||
opts.config = request->server_def().default_session_config();
|
||||
tensorflow::EagerContext* ctx = new tensorflow::EagerContext(
|
||||
SessionOptions(),
|
||||
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
||||
opts, tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
||||
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, request->async(),
|
||||
device_mgr, false, r, GetDefaultCustomKernelCreator(),
|
||||
worker_session->cluster_flr());
|
||||
|
@ -234,7 +234,6 @@ TEST_F(EagerServiceImplTest, BasicTest) {
|
||||
|
||||
TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
|
||||
|
||||
|
||||
EnqueueRequest remote_enqueue_request;
|
||||
remote_enqueue_request.set_context_id(context_id);
|
||||
EnqueueResponse remote_enqueue_response;
|
||||
@ -400,8 +399,9 @@ TEST_F(EagerServiceImplTest, EagerPFLRTest) {
|
||||
auto device_mgr = absl::make_unique<StaticDeviceMgr>(
|
||||
DeviceFactory::NewDevice("CPU", {}, "/job:localhost/replica:0/task:1"));
|
||||
auto eager_pflr = absl::make_unique<EagerProcessFunctionLibraryRuntime>(
|
||||
device_mgr.get(), Env::Default(), TF_GRAPH_DEF_VERSION, &func_lib_def,
|
||||
OptimizerOptions(), nullptr, eager_cluster_flr.get(), nullptr);
|
||||
device_mgr.get(), Env::Default(), /*config=*/nullptr,
|
||||
TF_GRAPH_DEF_VERSION, &func_lib_def, OptimizerOptions(), nullptr,
|
||||
eager_cluster_flr.get(), nullptr);
|
||||
|
||||
tensorflow::FunctionDef fdef = MatMulFunction();
|
||||
TF_ASSERT_OK(func_lib_def.AddFunctionDef(fdef));
|
||||
|
@ -121,13 +121,11 @@ Status GraphMgr::DecorateAndPublishGraphForDebug(
|
||||
//
|
||||
// "executors" are filled with one executor per device if success and
|
||||
// the caller takes the ownership of returned executors.
|
||||
Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef,
|
||||
WorkerSession* session,
|
||||
const GraphOptions& graph_options,
|
||||
const DebugOptions& debug_options,
|
||||
int64 collective_graph_key,
|
||||
DistributedFunctionLibraryRuntime* cluster_flr,
|
||||
Item* item) {
|
||||
Status GraphMgr::InitItem(
|
||||
const string& handle, const GraphDef& gdef, WorkerSession* session,
|
||||
const GraphOptions& graph_options, const DebugOptions& debug_options,
|
||||
const ConfigProto& config_proto, int64 collective_graph_key,
|
||||
DistributedFunctionLibraryRuntime* cluster_flr, Item* item) {
|
||||
item->session = handle;
|
||||
item->collective_graph_key = collective_graph_key;
|
||||
item->lib_def.reset(
|
||||
@ -139,9 +137,10 @@ Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef,
|
||||
// does that below.
|
||||
|
||||
item->proc_flr.reset(new ProcessFunctionLibraryRuntime(
|
||||
device_mgr_, worker_env_->env, gdef.versions().producer(),
|
||||
item->lib_def.get(), graph_options.optimizer_options(),
|
||||
worker_env_->compute_pool, cluster_flr));
|
||||
device_mgr_, worker_env_->env, /*config=*/&config_proto,
|
||||
gdef.versions().producer(), item->lib_def.get(),
|
||||
graph_options.optimizer_options(), worker_env_->compute_pool,
|
||||
cluster_flr));
|
||||
|
||||
// Constructs the graph out of "gdef".
|
||||
Graph graph(OpRegistry::Global());
|
||||
@ -287,16 +286,14 @@ Status GraphMgr::InitItem(const string& handle, const GraphDef& gdef,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphMgr::Register(const string& handle, const GraphDef& gdef,
|
||||
WorkerSession* session,
|
||||
const GraphOptions& graph_options,
|
||||
const DebugOptions& debug_options,
|
||||
int64 collective_graph_key,
|
||||
DistributedFunctionLibraryRuntime* cluster_flr,
|
||||
string* graph_handle) {
|
||||
Status GraphMgr::Register(
|
||||
const string& handle, const GraphDef& gdef, WorkerSession* session,
|
||||
const GraphOptions& graph_options, const DebugOptions& debug_options,
|
||||
const ConfigProto& config_proto, int64 collective_graph_key,
|
||||
DistributedFunctionLibraryRuntime* cluster_flr, string* graph_handle) {
|
||||
Item* item = new Item;
|
||||
Status s = InitItem(handle, gdef, session, graph_options, debug_options,
|
||||
collective_graph_key, cluster_flr, item);
|
||||
config_proto, collective_graph_key, cluster_flr, item);
|
||||
if (!s.ok()) {
|
||||
item->Unref();
|
||||
return s;
|
||||
|
@ -76,7 +76,8 @@ class GraphMgr {
|
||||
// reference to cluster_flr to do cross process function calls.
|
||||
Status Register(const string& handle, const GraphDef& gdef,
|
||||
WorkerSession* session, const GraphOptions& graph_options,
|
||||
const DebugOptions& debug_options, int64 collective_graph_key,
|
||||
const DebugOptions& debug_options,
|
||||
const ConfigProto& config_proto, int64 collective_graph_key,
|
||||
DistributedFunctionLibraryRuntime* cluster_flr,
|
||||
string* graph_handle);
|
||||
|
||||
@ -179,7 +180,8 @@ class GraphMgr {
|
||||
|
||||
Status InitItem(const string& handle, const GraphDef& gdef,
|
||||
WorkerSession* session, const GraphOptions& graph_options,
|
||||
const DebugOptions& debug_options, int64 collective_graph_key,
|
||||
const DebugOptions& debug_options,
|
||||
const ConfigProto& config_proto, int64 collective_graph_key,
|
||||
DistributedFunctionLibraryRuntime* cluster_flr, Item* item);
|
||||
|
||||
Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options,
|
||||
|
@ -472,6 +472,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
|
||||
c->req.set_session_handle(session_handle_);
|
||||
c->req.set_create_worker_session_called(!should_deregister_);
|
||||
c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
|
||||
*c->req.mutable_config_proto() = session_opts_.config;
|
||||
*c->req.mutable_graph_options() = session_opts_.config.graph_options();
|
||||
*c->req.mutable_debug_options() =
|
||||
callable_opts_.run_options().debug_options();
|
||||
|
@ -80,8 +80,8 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
|
||||
s = session->graph_mgr()->Register(
|
||||
request->session_handle(), request->graph_def(), session.get(),
|
||||
request->graph_options(), request->debug_options(),
|
||||
request->collective_graph_key(), session->cluster_flr(),
|
||||
response->mutable_graph_handle());
|
||||
request->config_proto(), request->collective_graph_key(),
|
||||
session->cluster_flr(), response->mutable_graph_handle());
|
||||
}
|
||||
done(s);
|
||||
}
|
||||
|
@ -297,6 +297,7 @@ class IteratorContext {
|
||||
allocator_getter = [device](AllocatorAttributes attrs) {
|
||||
return device->GetAllocator(attrs);
|
||||
};
|
||||
|
||||
thread::ThreadPool* thread_pool =
|
||||
ctx->device()->tensorflow_device_thread_pool();
|
||||
if (thread_pool) {
|
||||
|
@ -752,6 +752,9 @@ class FunctionLibraryRuntime {
|
||||
// Returns the environment on which the function executes.
|
||||
virtual Env* env() = 0;
|
||||
|
||||
// Returns the ConfigProto passed to the session used to create the function.
|
||||
virtual const ConfigProto* const config_proto() = 0;
|
||||
|
||||
// Returns a debug string showing the definition of the function of
|
||||
// 'handle'.
|
||||
virtual string DebugString(Handle handle) = 0;
|
||||
|
@ -256,7 +256,7 @@ Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
|
||||
|
||||
// Create the function library runtime.
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
|
||||
new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env,
|
||||
new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env, &options.config,
|
||||
graph_def.versions().producer(),
|
||||
&function_library, *optimizer_opts));
|
||||
FunctionLibraryRuntime* flr = pflr->GetFLR(cpu_device->name());
|
||||
|
@ -529,10 +529,11 @@ Status CapturedFunction::Instantiate(
|
||||
// The context's runtime will be used for all subsequent calls.
|
||||
FunctionLibraryRuntime* lib = ctx->flr();
|
||||
FunctionLibraryRuntime::InstantiateOptions inst_opts;
|
||||
// TODO(b/141576771): Propagate ctx ConfigProto into inst_opts.config_proto.
|
||||
inst_opts.lib_def = metadata_->lib_def();
|
||||
inst_opts.create_kernels_eagerly = true;
|
||||
inst_opts.default_device_to_target = metadata_->use_default_device();
|
||||
inst_opts.config_proto =
|
||||
lib->config_proto() ? *lib->config_proto() : ConfigProto();
|
||||
if (!metadata_->use_inter_op_parallelism()) {
|
||||
inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
|
||||
}
|
||||
|
@ -419,8 +419,9 @@ Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
|
||||
|
||||
OptimizerOptions opts;
|
||||
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
||||
opts, thread_pool_.get(), nullptr /* cluster_flr */);
|
||||
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
|
||||
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, thread_pool_.get(),
|
||||
nullptr /* cluster_flr */);
|
||||
flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
|
||||
if (thread_pool_ == nullptr) {
|
||||
runner_ = [](std::function<void()> fn) { fn(); };
|
||||
|
@ -355,16 +355,18 @@ FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
|
||||
// in its resource manager. The existing device will outlive the
|
||||
// IteratorResource, because we are storing the IteratorResource
|
||||
// in that device's resource manager.
|
||||
|
||||
*device_mgr =
|
||||
absl::make_unique<StaticDeviceMgr>(RenamedDevice::NewRenamedDevice(
|
||||
ctx->device()->name(), down_cast<Device*>(ctx->device()),
|
||||
false /* owns_underlying */, false /* isolate_session_state */));
|
||||
*flib_def = absl::make_unique<FunctionLibraryDefinition>(
|
||||
*ctx->function_library()->GetFunctionLibraryDefinition());
|
||||
const auto* config = ctx->function_library()->config_proto();
|
||||
*pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
device_mgr->get(), ctx->env(), graph_def_version_, flib_def->get(),
|
||||
OptimizerOptions{} /* TODO(mrry): OptimizerOptions? */,
|
||||
nullptr /* TODO(mrry): ClusterFLR */);
|
||||
device_mgr->get(), ctx->env(),
|
||||
/*config=*/config, graph_def_version_, flib_def->get(),
|
||||
config->graph_options().optimizer_options());
|
||||
|
||||
return (*pflr)->GetFLR(ctx->device()->name());
|
||||
}
|
||||
|
@ -325,7 +325,14 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
|
||||
AttrValueMap attr_values = func_.attr();
|
||||
|
||||
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
|
||||
// TODO(b/141576771): Propagate ctxt's ConfigProto into opts.config_proto.
|
||||
|
||||
const auto* config = (ctx->function_library())
|
||||
? ctx->function_library()->config_proto()
|
||||
: nullptr;
|
||||
if (config) {
|
||||
instantiate_opts.config_proto = *config;
|
||||
}
|
||||
|
||||
instantiate_opts.target = target_device;
|
||||
|
||||
FunctionTarget function_target = {target_device, lib};
|
||||
|
@ -29,8 +29,8 @@ void OpsTestBase::SetDevice(const DeviceType& device_type,
|
||||
device_ = device.get();
|
||||
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(device));
|
||||
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, flib_def_.get(),
|
||||
OptimizerOptions());
|
||||
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
|
||||
TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions());
|
||||
|
||||
device_type_ = device_type;
|
||||
#ifdef GOOGLE_CUDA
|
||||
|
@ -84,8 +84,8 @@ class OpsTestBase : public ::testing::Test {
|
||||
flib_def_ = absl::make_unique<FunctionLibraryDefinition>(
|
||||
OpRegistry::Global(), FunctionDefLibrary{});
|
||||
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION,
|
||||
flib_def_.get(), OptimizerOptions());
|
||||
device_mgr_.get(), Env::Default(), /*config=*/nullptr,
|
||||
TF_GRAPH_DEF_VERSION, flib_def_.get(), OptimizerOptions());
|
||||
}
|
||||
|
||||
~OpsTestBase() override {
|
||||
|
@ -166,7 +166,12 @@ Status PartitionedCallOp::Instantiate(FunctionLibraryRuntime* lib,
|
||||
std::vector<Tensor>* inputs,
|
||||
FunctionLibraryRuntime::Handle* handle) {
|
||||
FunctionLibraryRuntime::InstantiateOptions opts;
|
||||
// TODO(b/141576771): Propagate ctxt's ConfigProto into opts.config_proto.
|
||||
const auto* config = (ctx->function_library())
|
||||
? ctx->function_library()->config_proto()
|
||||
: nullptr;
|
||||
if (config) {
|
||||
opts.config_proto = *config;
|
||||
}
|
||||
|
||||
#ifndef __ANDROID__
|
||||
// Android tf library does not include grappler.
|
||||
|
@ -132,6 +132,11 @@ message RegisterGraphRequest {
|
||||
// concurrently so that BufRendezvous entries will make the correct
|
||||
// values accessible.
|
||||
int64 collective_graph_key = 7;
|
||||
|
||||
// ConfigProto from the session in which this graph was created.
|
||||
// Contains additional parameters beyond graph_options, including
|
||||
// the name of the requested executor.
|
||||
ConfigProto config_proto = 8;
|
||||
}
|
||||
|
||||
message RegisterGraphResponse {
|
||||
|
@ -2237,8 +2237,9 @@ bool InlineAllFunctions(GraphDef* graphdef) {
|
||||
tensorflow::StaticDeviceMgr device_mgr(std::move(devices));
|
||||
tensorflow::OptimizerOptions o_opts;
|
||||
tensorflow::ProcessFunctionLibraryRuntime pflr(
|
||||
&device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld,
|
||||
o_opts, nullptr);
|
||||
&device_mgr, tensorflow::Env::Default(), &options.config,
|
||||
TF_GRAPH_DEF_VERSION, &fld,
|
||||
options.config.graph_options().optimizer_options(), nullptr);
|
||||
tensorflow::FunctionLibraryRuntime* flr;
|
||||
flr = pflr.GetFLR("/job:localhost/replica:0/task:0/cpu:0");
|
||||
|
||||
|
@ -124,8 +124,11 @@ def connect_to_cluster(cluster_spec_or_resolver,
|
||||
job_def.tasks[0] = "localhost:{}".format(local_port)
|
||||
|
||||
server_def = ServerDef(
|
||||
cluster=cluster_def, job_name=job_name, task_index=task_index,
|
||||
protocol=protocol)
|
||||
cluster=cluster_def,
|
||||
job_name=job_name,
|
||||
task_index=task_index,
|
||||
protocol=protocol,
|
||||
default_session_config=context.context().config)
|
||||
|
||||
context.set_server_def(server_def)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user