Avoid partially creating/updating cluster when some workers fail during update.
PiperOrigin-RevId: 309397022
Change-Id: I0c20db46a5c0bc629a662a5854d4b72e39b82322
parent 228aa4119b
commit 9596f52784
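In `UpdateTFE_ContextWithServerDef`, the update path now collects errors from the individual steps (removing/adding remote devices, creating/updating remote contexts, updating the session and master) in a `tensorflow::StatusGroup` instead of returning on the first failure, so a partial failure no longer leaves the cluster with inconsistent context views. The following is a minimal, hypothetical sketch of that aggregation pattern; the `*Step` helpers and their error messages are placeholders, not the real update functions:

#include <iostream>

#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

// Hypothetical stand-ins for the real update steps; each returns a Status.
tensorflow::Status RemoveDevicesStep() { return tensorflow::Status::OK(); }
tensorflow::Status CreateContextsStep() {
  return tensorflow::errors::Unavailable("worker /job:worker/task:2 unreachable");
}
tensorflow::Status UpdateMasterStep() { return tensorflow::Status::OK(); }

int main() {
  tensorflow::StatusGroup sg;
  // Unlike an early-returning LOG_AND_RETURN_IF_ERROR, every step still runs
  // even if an earlier one failed, so all workers see the same sequence of
  // updates.
  sg.Update(RemoveDevicesStep());
  sg.Update(CreateContextsStep());
  sg.Update(UpdateMasterStep());
  // A single summary status carries all collected errors and is checked once
  // at the end of the update.
  tensorflow::Status summary = sg.as_summary_status();
  std::cout << summary.ToString() << std::endl;
  return summary.ok() ? 0 : 1;
}

The diff below applies this pattern to the device-manager updates, remote context creation/update, session update, and master update, surfacing the combined result once via LOG_AND_RETURN_IF_ERROR(sg.as_summary_status()).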
tensorflow/c/eager/BUILD
@@ -342,7 +342,7 @@ tf_cuda_cc_test(
 
 tf_cuda_cc_test(
     name = "c_api_remote_test",
-    size = "small",
+    size = "medium",
     srcs = [
         "c_api_remote_test.cc",
     ],
@@ -364,6 +364,8 @@ tf_cuda_cc_test(
         "//tensorflow/core:test_main",
         "//tensorflow/core/common_runtime/eager:eager_operation",
         "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+        "//tensorflow/core/platform:env",
+        "@com_google_absl//absl/debugging:leak_check",
         "@com_google_absl//absl/strings",
     ],
 )
tensorflow/c/eager/c_api.cc
@@ -500,6 +500,17 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
       grpc_server->master_env()->worker_cache->GetEagerClientCache(
           &remote_eager_workers));
 
+  // For cluster update, use a status group to aggregate statuses from
+  //   * adding and removing remote devices
+  //   * creating remote contexts on newly added workers
+  //   * updating remote contexts on existing workers
+  //   * updating the master context
+  // Note that we should not return immediately on errors in the middle of
+  // these updates to prevent cluster from having inconsistent context views.
+  //
+  // Unused if `reset_context` is True.
+  tensorflow::StatusGroup sg;
+
   // When updating an existing context, populate the following lists with:
   //   * added_workers: set(remote_workers) - set(curr_remote_workers)
   //   * removed_workers: set(curr_remote_workers) - set(remote_workers)
@@ -535,7 +546,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
     DifferentiateWorkerLists(&curr_remote_workers, &remote_workers,
                              &added_workers, &removed_workers,
                              &existing_workers);
-    LOG_AND_RETURN_IF_ERROR(GetReplacedFromExistingWorkers(
+    sg.Update(GetReplacedFromExistingWorkers(
         &existing_workers, context_id, context->GetContextViewId(), server_def,
         remote_eager_workers.get(), &replaced_workers));
     if (VLOG_IS_ON(1)) {
@@ -559,11 +570,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
             existing_workers.end());
       }
     }
-    LOG_AND_RETURN_IF_ERROR(
-        RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
-    LOG_AND_RETURN_IF_ERROR(AddRemoteDevicesToMgr(
-        added_workers, grpc_server->master_env()->worker_cache,
-        remote_device_mgr));
+    sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr));
+    sg.Update(AddRemoteDevicesToMgr(added_workers,
+                                    grpc_server->master_env()->worker_cache,
+                                    remote_device_mgr));
   }
 
   std::vector<tensorflow::DeviceAttributes> cluster_device_attributes;
@@ -584,7 +594,6 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
   }
 
   // Initialize remote eager workers.
-  // TODO(b/138847548) Create remote eager contexts in async mode by default.
   if (reset_context) {
     LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
         ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
@@ -596,7 +605,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
       // existing workers to also have the updated context_view_id, so
       // we must set their context_view_id to the existing master's
       // context_view_id + 1.
-      LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
+      sg.Update(CreateRemoteContexts(
           ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs,
           server_def, remote_eager_workers.get(), context->Executor().Async(),
           context->LazyCopyFunctionRemoteInputs(), base_request));
@@ -606,10 +615,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
           VLOG(1) << "Updating cluster with existing worker " << w;
         }
       }
-      LOG_AND_RETURN_IF_ERROR(UpdateRemoteContexts(
-          ctx, existing_workers, added_workers, removed_workers, context_id,
-          context_view_id + 1, server_def, remote_eager_workers.get(),
-          base_request));
+      sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers,
+                                     removed_workers, context_id,
+                                     context_view_id + 1, server_def,
+                                     remote_eager_workers.get(), base_request));
     }
   }
 
@@ -645,13 +654,13 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
     // GrpcServer cannot be destroyed after it is started.
     LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
   } else {
-    LOG_AND_RETURN_IF_ERROR(
-        grpc_server->worker_env()->session_mgr->UpdateSession(
-            session_name, server_def, base_request.cluster_device_attributes(),
-            /*isolate_session_state=*/true));
-    LOG_AND_RETURN_IF_ERROR(
-        context->UpdateRemoteMaster(context_id, std::move(remote_eager_workers),
-                                    added_workers, removed_workers));
+    sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession(
+        session_name, server_def, base_request.cluster_device_attributes(),
+        /*isolate_session_state=*/true));
+    sg.Update(context->UpdateRemoteMaster(context_id,
+                                          std::move(remote_eager_workers),
+                                          added_workers, removed_workers));
+    LOG_AND_RETURN_IF_ERROR(sg.as_summary_status());
   }
 #undef LOG_AND_RETURN_IF_ERROR
 
tensorflow/c/eager/c_api_remote_test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "absl/debugging/leak_check.h"
 #include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/c/eager/c_api_experimental.h"
 #include "tensorflow/c/eager/c_api_internal.h"
@@ -21,6 +22,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
 #include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/cluster.pb.h"
@@ -527,4 +529,124 @@ TEST(CAPI, RemoteExecuteChangeServerDefAsync) {
   TestRemoteExecuteChangeServerDef(true);
 }
 
+void TestRemoteExecuteUpdateServerDef(bool async) {
+  // TODO(b/136478427): Skip heap checker for leaked gRPC server instances.
+  absl::LeakCheckDisabler disabler;
+
+  tensorflow::ServerDef server_def = GetServerDef(2);
+  // This server def has the task index set to 0.
+  string serialized = server_def.SerializeAsString();
+
+  server_def.set_task_index(1);
+  std::unique_ptr<tensorflow::GrpcServer> worker_server;
+  ASSERT_TRUE(tensorflow::GrpcServer::Create(
+                  server_def, tensorflow::Env::Default(), &worker_server)
+                  .ok());
+  ASSERT_TRUE(worker_server->Start().ok());
+
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  const char local_device_name[] =
+      "/job:localhost/replica:0/task:0/device:CPU:0";
+  const char remote_device_name[] =
+      "/job:localhost/replica:0/task:1/device:CPU:0";
+  CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
+
+  TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
+                             status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
+
+  TFE_DeleteContext(ctx);
+  TF_DeleteStatus(status);
+
+  // TODO(b/136478427): Figure out how to correctly shut the server down.
+  worker_server.release();
+}
+
+TEST(CAPI, RemoteExecuteUpdateServerDef) {
+  TestRemoteExecuteUpdateServerDef(false);
+}
+
+TEST(CAPI, RemoteExecuteUpdateServerDefAsync) {
+  TestRemoteExecuteUpdateServerDef(true);
+}
+
+void TestRemoteExecuteUpdateServerDefWithFailures(bool async) {
+  // TODO(b/136478427): Skip heap checker for leaked gRPC server instances.
+  absl::LeakCheckDisabler disabler;
+  // Fail fast on GetStatus requests so we get errors instead of timeouts
+  // when updating the cluster with a non-existent worker.
+  tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
+
+  tensorflow::ServerDef server_def = GetServerDef(2);
+  // This server def has the task index set to 0.
+  string serialized = server_def.SerializeAsString();
+
+  server_def.set_task_index(1);
+  std::unique_ptr<tensorflow::GrpcServer> worker_server;
+  ASSERT_TRUE(tensorflow::GrpcServer::Create(
+                  server_def, tensorflow::Env::Default(), &worker_server)
+                  .ok());
+  ASSERT_TRUE(worker_server->Start().ok());
+
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  const char local_device_name[] =
+      "/job:localhost/replica:0/task:0/device:CPU:0";
+  const char remote_device_name[] =
+      "/job:localhost/replica:0/task:1/device:CPU:0";
+  CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
+
+  // Add a non-existent remote worker to the cluster def. This should cause the
+  // UpdateServerDef call to fail.
+  tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
+  tensorflow::JobDef* job_def = cluster_def->mutable_job(0);
+  int port = tensorflow::testing::PickUnusedPortOrDie();
+  job_def->mutable_tasks()->insert(
+      {2, tensorflow::strings::StrCat("localhost:", port)});
+  string serialized_update = server_def.SerializeAsString();
+  TFE_ContextUpdateServerDef(ctx, 0, serialized_update.data(),
+                             serialized_update.size(), status);
+  EXPECT_NE(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  // Even after the previously failed cluster update, another update and op
+  // execution should work fine as long as the provided server_def is valid.
+  TFE_ContextUpdateServerDef(ctx, 0, serialized.data(), serialized.size(),
+                             status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name);
+
+  TFE_DeleteContext(ctx);
+  TF_DeleteStatus(status);
+
+  // TODO(b/136478427): Figure out how to correctly shut the server down.
+  worker_server.release();
+  tensorflow::unsetenv("GRPC_FAIL_FAST");
+}
+
+TEST(CAPI, RemoteExecuteUpdateServerDefWithFailures) {
+  TestRemoteExecuteUpdateServerDefWithFailures(false);
+}
+
+TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
+  TestRemoteExecuteUpdateServerDefWithFailures(true);
+}
+
 }  // namespace