Make all function library runtime invocations use a new randomly generated step
ID. PiperOrigin-RevId: 246977125
This commit is contained in:
parent
4624a9ee5f
commit
ead79afcd6
@ -259,7 +259,6 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
|
||||
}
|
||||
auto lib = ctx->function_library();
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.step_id = ctx->step_id();
|
||||
opts.rendezvous = ctx->rendezvous();
|
||||
opts.cancellation_manager = ctx->cancellation_manager();
|
||||
opts.runner = ctx->runner();
|
||||
|
@ -498,7 +498,6 @@ class CallOp : public AsyncOpKernel {
|
||||
errors::Internal("No function library is provided."),
|
||||
done);
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.step_id = ctx->step_id();
|
||||
opts.rendezvous = ctx->rendezvous();
|
||||
opts.cancellation_manager = ctx->cancellation_manager();
|
||||
opts.step_container = ctx->step_container();
|
||||
|
@ -196,9 +196,7 @@ void ClusterFunctionLibraryRuntime::Run(
|
||||
req->set_session_handle(worker_session_->session_name);
|
||||
req->set_create_worker_session_called(create_worker_session_called_);
|
||||
req->set_graph_handle(function_data->graph_handle);
|
||||
// Borrowed from master_session.cc
|
||||
const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
|
||||
req->set_step_id(step_id);
|
||||
req->set_step_id(opts.step_id);
|
||||
int i = 0;
|
||||
for (const auto& send_key : function_data->send_keys) {
|
||||
NamedTensorProto* send = req->add_send();
|
||||
@ -212,7 +210,7 @@ void ClusterFunctionLibraryRuntime::Run(
|
||||
}
|
||||
|
||||
CleanupGraphRequest* cleanup_req = new CleanupGraphRequest;
|
||||
cleanup_req->set_step_id(step_id);
|
||||
cleanup_req->set_step_id(opts.step_id);
|
||||
|
||||
RunGraphResponse* resp = new RunGraphResponse();
|
||||
CleanupGraphResponse* cleanup_resp = new CleanupGraphResponse;
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
@ -27,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
@ -628,8 +630,13 @@ class FunctionLibraryRuntime {
|
||||
// In the cross-process scenario, runner isn't used for making the Async
|
||||
// RPC calls.
|
||||
struct Options {
|
||||
// The id of the step that is calling this function.
|
||||
int64 step_id = 0;
|
||||
// Choose a step ID that is guaranteed not to clash with any
|
||||
// Session-generated step ID. DirectSession only generates
|
||||
// non-negative step IDs (contiguous, starting from 0), and
|
||||
// MasterSession generates 56-bit random step IDs whose MSB is
|
||||
// always 0, so a negative random step ID should suffice.
|
||||
const int64 step_id = -std::abs(static_cast<int64>(random::New64()));
|
||||
|
||||
Rendezvous* rendezvous = nullptr;
|
||||
CancellationManager* cancellation_manager = nullptr;
|
||||
CollectiveExecutor* collective_executor = nullptr;
|
||||
|
@ -520,7 +520,6 @@ class BatchResource : public ResourceBase {
|
||||
return;
|
||||
}
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.step_id = last_task_context->step_id();
|
||||
opts.step_container = last_task_context->step_container();
|
||||
opts.cancellation_manager = last_task_context->cancellation_manager();
|
||||
opts.stats_collector = last_task_context->stats_collector();
|
||||
|
@ -516,7 +516,6 @@ Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime::Options f_opts;
|
||||
f_opts.step_id = InstantiatedCapturedFunction::generate_step_id();
|
||||
ScopedStepContainer step_container(
|
||||
f_opts.step_id, [this](const string& name) {
|
||||
lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
|
||||
@ -558,7 +557,6 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime::Options f_opts;
|
||||
f_opts.step_id = InstantiatedCapturedFunction::generate_step_id();
|
||||
ScopedStepContainer step_container(
|
||||
f_opts.step_id, [this](const string& name) {
|
||||
lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
|
||||
@ -599,7 +597,6 @@ Status InstantiatedCapturedFunction::RunInstantiated(
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime::Options f_opts;
|
||||
f_opts.step_id = InstantiatedCapturedFunction::generate_step_id();
|
||||
ScopedStepContainer step_container(
|
||||
f_opts.step_id, [this](const string& name) {
|
||||
lib_->device()->resource_manager()->Cleanup(name).IgnoreError();
|
||||
@ -655,7 +652,6 @@ void InstantiatedCapturedFunction::RunAsync(
|
||||
std::move(args), &captured_func_->captured_inputs(), ret_types_);
|
||||
|
||||
FunctionLibraryRuntime::Options f_opts;
|
||||
f_opts.step_id = InstantiatedCapturedFunction::generate_step_id();
|
||||
ResourceMgr* resource_mgr = lib_->device()->resource_manager();
|
||||
ScopedStepContainer* step_container = new ScopedStepContainer(
|
||||
f_opts.step_id, [resource_mgr](const string& name) {
|
||||
|
@ -88,16 +88,6 @@ class InstantiatedCapturedFunction {
|
||||
FunctionLibraryRuntime::DoneCallback done,
|
||||
const string& prefix) const;
|
||||
|
||||
// Returns a step ID for use when running an `InstantiatedCapturedFunction`.
|
||||
static int64 generate_step_id() {
|
||||
// Choose a step ID that is guaranteed not to clash with any
|
||||
// Session-generated step ID. DirectSession only generates
|
||||
// non-negative step IDs (contiguous, starting from 0), and
|
||||
// MasterSession generates 56-bit random step IDs whose MSB is
|
||||
// always 0, so a negative random step ID should suffice.
|
||||
return -std::abs(static_cast<int64>(random::New64()));
|
||||
}
|
||||
|
||||
private:
|
||||
InstantiatedCapturedFunction(
|
||||
FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle,
|
||||
|
@ -916,12 +916,6 @@ class OneShotIteratorOp : public AsyncOpKernel {
|
||||
&f_handle));
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.cancellation_manager = ctx->cancellation_manager();
|
||||
// Choose a step ID that is guaranteed not to clash with any
|
||||
// Session-generated step ID. DirectSession only generates
|
||||
// non-negative step IDs (contiguous, starting from 0), and
|
||||
// MasterSession generates 56-bit random step IDs whose MSB is
|
||||
// always 0, so a negative random step ID should suffice.
|
||||
opts.step_id = -std::abs(static_cast<int64>(random::New64()));
|
||||
ScopedStepContainer step_container(opts.step_id, [ctx](const string& name) {
|
||||
ctx->resource_manager()->Cleanup(name).IgnoreError();
|
||||
});
|
||||
|
@ -250,7 +250,6 @@ class MapDefunOp : public AsyncOpKernel {
|
||||
void SetRunOptions(OpKernelContext* ctx,
|
||||
FunctionLibraryRuntime::Options* opts,
|
||||
ComputeOptions* compute_opts, bool always_collect_stats) {
|
||||
opts->step_id = ctx->step_id();
|
||||
opts->rendezvous = ctx->rendezvous();
|
||||
if (always_collect_stats) {
|
||||
opts->stats_collector = ctx->stats_collector();
|
||||
|
@ -250,7 +250,6 @@ class SymbolicGradientOp : public AsyncOpKernel {
|
||||
ctx, lib->Instantiate(kGradientOp, AttrSlice(def()), &handle), done);
|
||||
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.step_id = ctx->step_id();
|
||||
opts.rendezvous = ctx->rendezvous();
|
||||
opts.cancellation_manager = ctx->cancellation_manager();
|
||||
opts.runner = ctx->runner();
|
||||
@ -352,7 +351,6 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
|
||||
OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);
|
||||
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
opts.step_id = ctx->step_id();
|
||||
opts.runner = ctx->runner();
|
||||
opts.source_device = source_device;
|
||||
if (opts.source_device != target_device) {
|
||||
|
@ -114,7 +114,6 @@ Status SetOutputs(const OpKernel* kernel, OpKernelContext* ctx,
|
||||
|
||||
void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
|
||||
bool always_collect_stats) {
|
||||
opts->step_id = ctx->step_id();
|
||||
opts->rendezvous = ctx->rendezvous();
|
||||
opts->cancellation_manager = ctx->cancellation_manager();
|
||||
if (always_collect_stats) {
|
||||
|
@ -218,7 +218,6 @@ void PartitionedCallOp::RunFunction(FunctionLibraryRuntime::Handle handle,
|
||||
FunctionLibraryRuntime* lib,
|
||||
OpKernelContext* ctx, DoneCallback done) {
|
||||
FunctionLibraryRuntime::Options run_opts;
|
||||
run_opts.step_id = ctx->step_id();
|
||||
run_opts.step_container = ctx->step_container();
|
||||
run_opts.cancellation_manager = ctx->cancellation_manager();
|
||||
run_opts.stats_collector = ctx->stats_collector();
|
||||
|
Loading…
Reference in New Issue
Block a user