Fix c_api_remote_test tsan flakiness.
PiperOrigin-RevId: 313813747 Change-Id: I428eaa271adcb0cca3236edd0c52232dda9719a6
This commit is contained in:
parent
dae0485fed
commit
58f1e31019
@ -370,7 +370,6 @@ tf_cuda_cc_test(
|
|||||||
extra_copts = tfe_xla_copts(),
|
extra_copts = tfe_xla_copts(),
|
||||||
tags = [
|
tags = [
|
||||||
"noasan", # leaks gRPC server instances
|
"noasan", # leaks gRPC server instances
|
||||||
"notsan", # b/157098283
|
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":c_api",
|
":c_api",
|
||||||
@ -392,6 +391,36 @@ tf_cuda_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cuda_cc_test(
|
||||||
|
name = "c_api_distributed_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = [
|
||||||
|
"c_api_distributed_test.cc",
|
||||||
|
],
|
||||||
|
# TODO(b/136478427): Figure out how to correctly shut the server down
|
||||||
|
args = ["--heap_check=local"],
|
||||||
|
extra_copts = tfe_xla_copts(),
|
||||||
|
tags = ["noasan"], # leaks gRPC server instances
|
||||||
|
deps = [
|
||||||
|
":c_api",
|
||||||
|
":c_api_experimental",
|
||||||
|
":c_api_internal",
|
||||||
|
":c_api_test_util",
|
||||||
|
":tfe_tensorhandle_internal",
|
||||||
|
"//tensorflow/c:c_test_util",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:graph",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/common_runtime:function_optimization_registry",
|
||||||
|
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||||
|
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cuda_cc_test(
|
tf_cuda_cc_test(
|
||||||
name = "c_api_cluster_test",
|
name = "c_api_cluster_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
@ -30,26 +30,6 @@ namespace {
|
|||||||
|
|
||||||
using ::tensorflow::string;
|
using ::tensorflow::string;
|
||||||
|
|
||||||
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
|
|
||||||
tensorflow::ServerDef server_def;
|
|
||||||
server_def.set_protocol("grpc");
|
|
||||||
server_def.set_job_name(job_name);
|
|
||||||
server_def.set_task_index(0);
|
|
||||||
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
|
||||||
tensorflow::JobDef* job_def = cluster_def->add_job();
|
|
||||||
job_def->set_name(job_name);
|
|
||||||
for (int i = 0; i < num_tasks; i++) {
|
|
||||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
|
||||||
job_def->mutable_tasks()->insert(
|
|
||||||
{i, tensorflow::strings::StrCat("localhost", ":", port)});
|
|
||||||
}
|
|
||||||
return server_def;
|
|
||||||
}
|
|
||||||
|
|
||||||
tensorflow::ServerDef GetServerDef(int num_tasks) {
|
|
||||||
return GetServerDef("localhost", num_tasks);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
|
void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index) {
|
||||||
tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
|
tensorflow::JobDef* job_def = server_def->mutable_cluster()->mutable_job(0);
|
||||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||||
|
506
tensorflow/c/eager/c_api_distributed_test.cc
Normal file
506
tensorflow/c/eager/c_api_distributed_test.cc
Normal file
@ -0,0 +1,506 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||||
|
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||||
|
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||||
|
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||||
|
#include "tensorflow/core/framework/function.h"
|
||||||
|
#include "tensorflow/core/graph/graph.h"
|
||||||
|
#include "tensorflow/core/platform/casts.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||||
|
#include "tensorflow/core/protobuf/config.pb.h"
|
||||||
|
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::tensorflow::string;
|
||||||
|
|
||||||
|
// Add the values of three variables on three different tasks.
|
||||||
|
string AddVariablesFunction() {
|
||||||
|
tensorflow::FunctionDef def;
|
||||||
|
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||||
|
" signature {"
|
||||||
|
" name: 'AddVariablesFunction'"
|
||||||
|
" input_arg {"
|
||||||
|
" name: 'var'"
|
||||||
|
" type: DT_RESOURCE"
|
||||||
|
" }"
|
||||||
|
" output_arg {"
|
||||||
|
" name: 'sum'"
|
||||||
|
" type: DT_FLOAT"
|
||||||
|
" }"
|
||||||
|
" }"
|
||||||
|
" node_def {"
|
||||||
|
" name: 'read0'"
|
||||||
|
" op: 'ReadVariableOp'"
|
||||||
|
" input: 'var'"
|
||||||
|
" device: '/job:localhost/replica:0/task:0/device:CPU:0'"
|
||||||
|
" attr {"
|
||||||
|
" key: 'dtype'"
|
||||||
|
" value {"
|
||||||
|
" type: DT_FLOAT"
|
||||||
|
" }"
|
||||||
|
" }"
|
||||||
|
" }"
|
||||||
|
" node_def {"
|
||||||
|
" name: 'read1'"
|
||||||
|
" op: 'ReadVariableOp'"
|
||||||
|
" input: 'var'"
|
||||||
|
" device: '/job:localhost/replica:0/task:1/device:CPU:0'"
|
||||||
|
" attr {"
|
||||||
|
" key: 'dtype'"
|
||||||
|
" value {"
|
||||||
|
" type: DT_FLOAT"
|
||||||
|
" }"
|
||||||
|
" }"
|
||||||
|
" }"
|
||||||
|
" node_def {"
|
||||||
|
" name: 'read2'"
|
||||||
|
" op: 'ReadVariableOp'"
|
||||||
|
" input: 'var'"
|
||||||
|
" device: '/job:localhost/replica:0/task:2/device:CPU:0'"
|
||||||
|
" attr {"
|
||||||
|
" key: 'dtype'"
|
||||||
|
" value {"
|
||||||
|
" type: DT_FLOAT"
|
||||||
|
" }"
|
||||||
|
" }"
|
||||||
|
" }"
|
||||||
|
" node_def {"
|
||||||
|
" name: 'add1'"
|
||||||
|
" op: 'Add'"
|
||||||
|
" input: 'read0:value:0'"
|
||||||
|
" input: 'read1:value:0'"
|
||||||
|
" attr {"
|
||||||
|
" key: 'T'"
|
||||||
|
" value {"
|
||||||
|
" type: DT_FLOAT"
|
||||||
|
" }"
|
||||||
|
" }"
|
||||||
|
" }"
|
||||||
|
" node_def {"
|
||||||
|
" name: 'add2'"
|
||||||
|
" op: 'Add'"
|
||||||
|
" input: 'add1:z:0'"
|
||||||
|
" input: 'read2:value:0'"
|
||||||
|
" attr {"
|
||||||
|
" key: 'T'"
|
||||||
|
" value {"
|
||||||
|
" type: DT_FLOAT"
|
||||||
|
" }"
|
||||||
|
" }"
|
||||||
|
" }"
|
||||||
|
" ret {"
|
||||||
|
" key: 'sum'"
|
||||||
|
" value: 'add2:z:0'"
|
||||||
|
" }",
|
||||||
|
&def));
|
||||||
|
return def.SerializeAsString();
|
||||||
|
}
|
||||||
|
|
||||||
|
void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
|
||||||
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
TFE_OpAddInput(op, var_handle, status);
|
||||||
|
TFE_TensorHandle* is_initialized[1] = {nullptr};
|
||||||
|
int num_retvals = 1;
|
||||||
|
TFE_Execute(op, &is_initialized[0], &num_retvals, status);
|
||||||
|
CHECK_EQ(1, num_retvals);
|
||||||
|
TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
|
||||||
|
bool initialized = false;
|
||||||
|
memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t));
|
||||||
|
EXPECT_EQ(initialized, true);
|
||||||
|
TF_DeleteTensor(t);
|
||||||
|
TFE_DeleteTensorHandle(is_initialized[0]);
|
||||||
|
TFE_DeleteOp(op);
|
||||||
|
delete status;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestFunctionWithPackedInput(const bool remote) {
|
||||||
|
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||||
|
|
||||||
|
// This server def has the task index set to 0.
|
||||||
|
string serialized = server_def.SerializeAsString();
|
||||||
|
|
||||||
|
server_def.set_task_index(1);
|
||||||
|
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||||
|
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||||
|
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||||
|
.ok());
|
||||||
|
ASSERT_TRUE(worker_server1->Start().ok());
|
||||||
|
|
||||||
|
server_def.set_task_index(2);
|
||||||
|
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||||
|
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||||
|
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||||
|
.ok());
|
||||||
|
ASSERT_TRUE(worker_server2->Start().ok());
|
||||||
|
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
|
||||||
|
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||||
|
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||||
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
|
||||||
|
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||||
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
|
||||||
|
const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||||
|
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||||
|
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||||
|
|
||||||
|
// Create one variable per task.
|
||||||
|
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
|
||||||
|
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
|
||||||
|
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
|
||||||
|
|
||||||
|
// Add a sync point in order to make sure that variables have been initialized
|
||||||
|
// before the function execution starts.
|
||||||
|
// TODO(b/155789951): Remove once b/155789951 is fixed.
|
||||||
|
VarIsInitialized(ctx, h1);
|
||||||
|
VarIsInitialized(ctx, h2);
|
||||||
|
|
||||||
|
// Pack 3 variable handles into one TFE_TensorHandle.
|
||||||
|
int num_replicas = 3;
|
||||||
|
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
|
||||||
|
TFE_TensorHandle* packed_handle =
|
||||||
|
TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
|
||||||
|
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
|
||||||
|
EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
|
||||||
|
EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);
|
||||||
|
|
||||||
|
const string composite_device_name =
|
||||||
|
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
|
||||||
|
EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
|
||||||
|
composite_device_name);
|
||||||
|
EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
|
||||||
|
composite_device_name);
|
||||||
|
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
|
||||||
|
// Register and run a function which returns the sum of 3 variables.
|
||||||
|
const string function_def = AddVariablesFunction();
|
||||||
|
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
||||||
|
status);
|
||||||
|
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
|
||||||
|
TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
|
||||||
|
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
TFE_OpAddInput(func, packed_handle, status);
|
||||||
|
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
if (remote) {
|
||||||
|
TFE_OpSetDevice(func, task1_name, status);
|
||||||
|
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||||
|
int num_retvals = 1;
|
||||||
|
TFE_Execute(func, &retvals[0], &num_retvals, status);
|
||||||
|
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
ASSERT_EQ(1, num_retvals);
|
||||||
|
TFE_DeleteOp(func);
|
||||||
|
TFE_DeleteTensorHandle(packed_handle);
|
||||||
|
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
||||||
|
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
TFE_DeleteTensorHandle(retvals[0]);
|
||||||
|
float sum = 0;
|
||||||
|
EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
|
||||||
|
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
|
||||||
|
TF_DeleteTensor(t);
|
||||||
|
EXPECT_EQ(sum, 6.0);
|
||||||
|
|
||||||
|
TFE_DeleteTensorHandle(h0);
|
||||||
|
TFE_DeleteTensorHandle(h1);
|
||||||
|
TFE_DeleteTensorHandle(h2);
|
||||||
|
|
||||||
|
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||||
|
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||||
|
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
TFE_DeleteExecutor(executor);
|
||||||
|
TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
|
||||||
|
TFE_DeleteContext(ctx);
|
||||||
|
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
|
||||||
|
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||||
|
worker_server1.release();
|
||||||
|
worker_server2.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, TestLocalFunctionWithPackedInput) {
|
||||||
|
TestFunctionWithPackedInput(/*remote=*/false);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, TestRemoteFunctionWithPackedInput) {
|
||||||
|
TestFunctionWithPackedInput(/*remote=*/true);
|
||||||
|
}
|
||||||
|
|
||||||
|
string VariableAddFunction() {
|
||||||
|
tensorflow::FunctionDef def;
|
||||||
|
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||||
|
" signature {"
|
||||||
|
" name: 'VariableAddFunction'"
|
||||||
|
" input_arg {"
|
||||||
|
" name: 'var0'"
|
||||||
|
" type: DT_RESOURCE"
|
||||||
|
" }"
|
||||||
|
" output_arg {"
|
||||||
|
" name: 'var0_value'"
|
||||||
|
" type: DT_FLOAT"
|
||||||
|
" }"
|
||||||
|
" }"
|
||||||
|
" node_def {"
|
||||||
|
" name: 'read0'"
|
||||||
|
" op: 'ReadVariableOp'"
|
||||||
|
" input: 'var0'"
|
||||||
|
" attr {"
|
||||||
|
" key: 'dtype'"
|
||||||
|
" value {"
|
||||||
|
" type: DT_FLOAT"
|
||||||
|
" }"
|
||||||
|
" }"
|
||||||
|
" }"
|
||||||
|
" node_def {"
|
||||||
|
" name: 'add'"
|
||||||
|
" op: 'Add'"
|
||||||
|
" input: 'read0:value:0'"
|
||||||
|
" input: 'read0:value:0'"
|
||||||
|
" device: '/job:localhost/task:1/device:CPU:0'"
|
||||||
|
" attr {"
|
||||||
|
" key: 'T'"
|
||||||
|
" value {"
|
||||||
|
" type: DT_FLOAT"
|
||||||
|
" }"
|
||||||
|
" }"
|
||||||
|
" }"
|
||||||
|
" node_def {"
|
||||||
|
" name: 'identity'"
|
||||||
|
" op: 'Identity'"
|
||||||
|
" input: 'add:z:0'"
|
||||||
|
" device: '/job:localhost/task:0/device:CPU:0'"
|
||||||
|
" attr {"
|
||||||
|
" key: 'T'"
|
||||||
|
" value {"
|
||||||
|
" type: DT_FLOAT"
|
||||||
|
" }"
|
||||||
|
" }"
|
||||||
|
" }"
|
||||||
|
" ret {"
|
||||||
|
" key: 'var0_value'"
|
||||||
|
" value: 'identity:output:0'"
|
||||||
|
" }",
|
||||||
|
&def));
|
||||||
|
return def.SerializeAsString();
|
||||||
|
}
|
||||||
|
|
||||||
|
class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
|
||||||
|
public:
|
||||||
|
FunctionErrorInjectionPass(string error_node, string error_device)
|
||||||
|
: error_node_(error_node), error_device_(error_device) {}
|
||||||
|
tensorflow::Status Run(const tensorflow::DeviceSet& device_set,
|
||||||
|
const tensorflow::ConfigProto& config_proto,
|
||||||
|
std::unique_ptr<tensorflow::Graph>* graph,
|
||||||
|
tensorflow::FunctionLibraryDefinition* flib_def,
|
||||||
|
std::vector<std::string>* control_ret_node_names,
|
||||||
|
bool* control_rets_updated) override {
|
||||||
|
// Inject failure to function instantiation if finding a node that contains
|
||||||
|
// the given node name (error_node_) and requested device (error_device_).
|
||||||
|
for (const auto node : graph->get()->nodes()) {
|
||||||
|
if (node->name().find(error_node_) != string::npos &&
|
||||||
|
node->requested_device() == error_device_) {
|
||||||
|
return tensorflow::errors::Internal("Injected graph pass error.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tensorflow::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const string error_node_;
|
||||||
|
const string error_device_;
|
||||||
|
};
|
||||||
|
|
||||||
|
void TestDistributedFunctionCancellation(bool inject_error) {
|
||||||
|
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||||
|
// This server def has the task index set to 0.
|
||||||
|
string serialized = server_def.SerializeAsString();
|
||||||
|
|
||||||
|
server_def.set_task_index(1);
|
||||||
|
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||||
|
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||||
|
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||||
|
.ok());
|
||||||
|
ASSERT_TRUE(worker_server1->Start().ok());
|
||||||
|
server_def.set_task_index(2);
|
||||||
|
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||||
|
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||||
|
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||||
|
.ok());
|
||||||
|
ASSERT_TRUE(worker_server2->Start().ok());
|
||||||
|
const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||||
|
|
||||||
|
if (inject_error) {
|
||||||
|
// Inject a function optimization pass failure when it sees the 'read0' op
|
||||||
|
// having a requested device `dev2_name`. During execution:
|
||||||
|
// * task:0 processes the main function `VariableAddFunction` and places
|
||||||
|
// the read0 op on task:2
|
||||||
|
// * task:0 partitions the main function with a subgraph containing read0
|
||||||
|
// sent to task:2
|
||||||
|
// * task:2 graph pass reports an error when it sees read0 with dev2_name
|
||||||
|
tensorflow::function_optimization_registration::
|
||||||
|
FunctionOptimizationPassRegistration register_test_pass(
|
||||||
|
std::make_unique<FunctionErrorInjectionPass>("read0", dev2_name));
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||||
|
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
|
||||||
|
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
|
||||||
|
EXPECT_NE(var_handle, nullptr);
|
||||||
|
|
||||||
|
const string function_def = VariableAddFunction();
|
||||||
|
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
||||||
|
status);
|
||||||
|
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
|
||||||
|
TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
|
||||||
|
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
TFE_OpAddInput(func, var_handle, status);
|
||||||
|
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||||
|
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||||
|
int num_retvals = 1;
|
||||||
|
TFE_Execute(func, &retvals[0], &num_retvals, status);
|
||||||
|
|
||||||
|
if (inject_error) {
|
||||||
|
ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
} else {
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
ASSERT_EQ(1, num_retvals);
|
||||||
|
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteTensorHandle(retvals[0]);
|
||||||
|
float sum = 0;
|
||||||
|
ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
|
||||||
|
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
|
||||||
|
TF_DeleteTensor(t);
|
||||||
|
ASSERT_EQ(sum, 4.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_DeleteOp(func);
|
||||||
|
TFE_DeleteTensorHandle(var_handle);
|
||||||
|
TFE_DeleteContext(ctx);
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
|
||||||
|
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||||
|
worker_server1.release();
|
||||||
|
worker_server2.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, DistributedFunctionNoError) {
|
||||||
|
TestDistributedFunctionCancellation(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, DistributedFunctionCancelledOnError) {
|
||||||
|
TestDistributedFunctionCancellation(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
|
||||||
|
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||||
|
|
||||||
|
// This server def has the task index set to 0.
|
||||||
|
string serialized = server_def.SerializeAsString();
|
||||||
|
|
||||||
|
server_def.set_task_index(1);
|
||||||
|
|
||||||
|
std::unique_ptr<tensorflow::GrpcServer> worker_server;
|
||||||
|
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||||
|
server_def, tensorflow::Env::Default(), &worker_server)
|
||||||
|
.ok());
|
||||||
|
ASSERT_TRUE(worker_server->Start().ok());
|
||||||
|
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||||
|
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
|
||||||
|
TFE_DEVICE_PLACEMENT_EXPLICIT);
|
||||||
|
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
|
||||||
|
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
// Use large matrices so that RPCs don't return before we get a chance
|
||||||
|
// to call TFE_DeleteContext.
|
||||||
|
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
|
||||||
|
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
|
||||||
|
const char remote_device_name[] =
|
||||||
|
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||||
|
auto* h0_task1 =
|
||||||
|
TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
auto* h1_task1 =
|
||||||
|
TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
|
||||||
|
TFE_OpSetDevice(matmul, remote_device_name, status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
TFE_TensorHandle* retvals[1];
|
||||||
|
int num_retvals = 1;
|
||||||
|
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
|
||||||
|
TFE_DeleteTensorHandle(h0_task0);
|
||||||
|
TFE_DeleteTensorHandle(h1_task0);
|
||||||
|
TFE_DeleteTensorHandle(h0_task1);
|
||||||
|
TFE_DeleteTensorHandle(h1_task1);
|
||||||
|
TFE_DeleteTensorHandle(retvals[0]);
|
||||||
|
|
||||||
|
TFE_DeleteOp(matmul);
|
||||||
|
|
||||||
|
TFE_DeleteContext(ctx);
|
||||||
|
|
||||||
|
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||||
|
worker_server.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
|
||||||
|
TestRemoteExecuteDeleteContextWithOutstandingRPC(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
|
||||||
|
TestRemoteExecuteDeleteContextWithOutstandingRPC(true);
|
||||||
|
}
|
||||||
|
} // namespace
|
@ -35,26 +35,6 @@ namespace {
|
|||||||
|
|
||||||
using ::tensorflow::string;
|
using ::tensorflow::string;
|
||||||
|
|
||||||
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
|
|
||||||
tensorflow::ServerDef server_def;
|
|
||||||
server_def.set_protocol("grpc");
|
|
||||||
server_def.set_job_name(job_name);
|
|
||||||
server_def.set_task_index(0);
|
|
||||||
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
|
||||||
tensorflow::JobDef* job_def = cluster_def->add_job();
|
|
||||||
job_def->set_name(job_name);
|
|
||||||
for (int i = 0; i < num_tasks; i++) {
|
|
||||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
|
||||||
job_def->mutable_tasks()->insert(
|
|
||||||
{i, tensorflow::strings::StrCat("localhost:", port)});
|
|
||||||
}
|
|
||||||
return server_def;
|
|
||||||
}
|
|
||||||
|
|
||||||
tensorflow::ServerDef GetServerDef(int num_tasks) {
|
|
||||||
return GetServerDef("localhost", num_tasks);
|
|
||||||
}
|
|
||||||
|
|
||||||
void TestRemoteExecute(bool async) {
|
void TestRemoteExecute(bool async) {
|
||||||
tensorflow::ServerDef server_def = GetServerDef(2);
|
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||||
|
|
||||||
@ -356,472 +336,4 @@ TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
|
|||||||
/*heavy_load_on_streaming_rpc=*/true);
|
/*heavy_load_on_streaming_rpc=*/true);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the values of three variables on three different tasks.
|
|
||||||
string AddVariablesFunction() {
|
|
||||||
tensorflow::FunctionDef def;
|
|
||||||
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
|
||||||
" signature {"
|
|
||||||
" name: 'AddVariablesFunction'"
|
|
||||||
" input_arg {"
|
|
||||||
" name: 'var'"
|
|
||||||
" type: DT_RESOURCE"
|
|
||||||
" }"
|
|
||||||
" output_arg {"
|
|
||||||
" name: 'sum'"
|
|
||||||
" type: DT_FLOAT"
|
|
||||||
" }"
|
|
||||||
" }"
|
|
||||||
" node_def {"
|
|
||||||
" name: 'read0'"
|
|
||||||
" op: 'ReadVariableOp'"
|
|
||||||
" input: 'var'"
|
|
||||||
" device: '/job:localhost/replica:0/task:0/device:CPU:0'"
|
|
||||||
" attr {"
|
|
||||||
" key: 'dtype'"
|
|
||||||
" value {"
|
|
||||||
" type: DT_FLOAT"
|
|
||||||
" }"
|
|
||||||
" }"
|
|
||||||
" }"
|
|
||||||
" node_def {"
|
|
||||||
" name: 'read1'"
|
|
||||||
" op: 'ReadVariableOp'"
|
|
||||||
" input: 'var'"
|
|
||||||
" device: '/job:localhost/replica:0/task:1/device:CPU:0'"
|
|
||||||
" attr {"
|
|
||||||
" key: 'dtype'"
|
|
||||||
" value {"
|
|
||||||
" type: DT_FLOAT"
|
|
||||||
" }"
|
|
||||||
" }"
|
|
||||||
" }"
|
|
||||||
" node_def {"
|
|
||||||
" name: 'read2'"
|
|
||||||
" op: 'ReadVariableOp'"
|
|
||||||
" input: 'var'"
|
|
||||||
" device: '/job:localhost/replica:0/task:2/device:CPU:0'"
|
|
||||||
" attr {"
|
|
||||||
" key: 'dtype'"
|
|
||||||
" value {"
|
|
||||||
" type: DT_FLOAT"
|
|
||||||
" }"
|
|
||||||
" }"
|
|
||||||
" }"
|
|
||||||
" node_def {"
|
|
||||||
" name: 'add1'"
|
|
||||||
" op: 'Add'"
|
|
||||||
" input: 'read0:value:0'"
|
|
||||||
" input: 'read1:value:0'"
|
|
||||||
" attr {"
|
|
||||||
" key: 'T'"
|
|
||||||
" value {"
|
|
||||||
" type: DT_FLOAT"
|
|
||||||
" }"
|
|
||||||
" }"
|
|
||||||
" }"
|
|
||||||
" node_def {"
|
|
||||||
" name: 'add2'"
|
|
||||||
" op: 'Add'"
|
|
||||||
" input: 'add1:z:0'"
|
|
||||||
" input: 'read2:value:0'"
|
|
||||||
" attr {"
|
|
||||||
" key: 'T'"
|
|
||||||
" value {"
|
|
||||||
" type: DT_FLOAT"
|
|
||||||
" }"
|
|
||||||
" }"
|
|
||||||
" }"
|
|
||||||
" ret {"
|
|
||||||
" key: 'sum'"
|
|
||||||
" value: 'add2:z:0'"
|
|
||||||
" }",
|
|
||||||
&def));
|
|
||||||
return def.SerializeAsString();
|
|
||||||
}
|
|
||||||
|
|
||||||
void VarIsInitialized(TFE_Context* ctx, TFE_TensorHandle* var_handle) {
|
|
||||||
TF_Status* status = TF_NewStatus();
|
|
||||||
TFE_Op* op = TFE_NewOp(ctx, "VarIsInitializedOp", status);
|
|
||||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
|
||||||
TFE_OpAddInput(op, var_handle, status);
|
|
||||||
TFE_TensorHandle* is_initialized[1] = {nullptr};
|
|
||||||
int num_retvals = 1;
|
|
||||||
TFE_Execute(op, &is_initialized[0], &num_retvals, status);
|
|
||||||
CHECK_EQ(1, num_retvals);
|
|
||||||
TF_Tensor* t = TFE_TensorHandleResolve(is_initialized[0], status);
|
|
||||||
bool initialized = false;
|
|
||||||
memcpy(&initialized, TF_TensorData(t), TF_TensorByteSize(t));
|
|
||||||
EXPECT_EQ(initialized, true);
|
|
||||||
TF_DeleteTensor(t);
|
|
||||||
TFE_DeleteTensorHandle(is_initialized[0]);
|
|
||||||
TFE_DeleteOp(op);
|
|
||||||
delete status;
|
|
||||||
}
|
|
||||||
|
|
||||||
void TestFunctionWithPackedInput(const bool remote) {
|
|
||||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
|
||||||
|
|
||||||
// This server def has the task index set to 0.
|
|
||||||
string serialized = server_def.SerializeAsString();
|
|
||||||
|
|
||||||
server_def.set_task_index(1);
|
|
||||||
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
|
||||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
|
||||||
server_def, tensorflow::Env::Default(), &worker_server1)
|
|
||||||
.ok());
|
|
||||||
ASSERT_TRUE(worker_server1->Start().ok());
|
|
||||||
|
|
||||||
server_def.set_task_index(2);
|
|
||||||
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
|
||||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
|
||||||
server_def, tensorflow::Env::Default(), &worker_server2)
|
|
||||||
.ok());
|
|
||||||
ASSERT_TRUE(worker_server2->Start().ok());
|
|
||||||
|
|
||||||
TF_Status* status = TF_NewStatus();
|
|
||||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
|
||||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(/*enable=*/true));
|
|
||||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
|
||||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
|
||||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
|
||||||
TFE_DeleteContextOptions(opts);
|
|
||||||
|
|
||||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
|
||||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
|
||||||
|
|
||||||
const char task0_name[] = "/job:localhost/replica:0/task:0/device:CPU:0";
|
|
||||||
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
|
||||||
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
|
||||||
|
|
||||||
// Create one variable per task.
|
|
||||||
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
|
|
||||||
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
|
|
||||||
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
|
|
||||||
|
|
||||||
// Add a sync point in order to make sure that variables have been initialized
|
|
||||||
// before the function execution starts.
|
|
||||||
// TODO(b/155789951): Remove once b/155789951 is fixed.
|
|
||||||
VarIsInitialized(ctx, h1);
|
|
||||||
VarIsInitialized(ctx, h2);
|
|
||||||
|
|
||||||
// Pack 3 variable handles into one TFE_TensorHandle.
|
|
||||||
int num_replicas = 3;
|
|
||||||
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
|
|
||||||
TFE_TensorHandle* packed_handle =
|
|
||||||
TFE_CreatePackedTensorHandle(ctx, handles.data(), &num_replicas, status);
|
|
||||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
|
||||||
EXPECT_EQ(TFE_TensorHandleDataType(packed_handle), TF_RESOURCE);
|
|
||||||
EXPECT_EQ(TFE_TensorHandleNumDims(packed_handle, status), 0);
|
|
||||||
EXPECT_EQ(TFE_TensorHandleNumElements(packed_handle, status), 1);
|
|
||||||
|
|
||||||
const string composite_device_name =
|
|
||||||
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
|
|
||||||
EXPECT_EQ(TFE_TensorHandleDeviceName(packed_handle, status),
|
|
||||||
composite_device_name);
|
|
||||||
EXPECT_EQ(TFE_TensorHandleBackingDeviceName(packed_handle, status),
|
|
||||||
composite_device_name);
|
|
||||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
|
||||||
|
|
||||||
// Register and run a function which returns the sum of 3 variables.
|
|
||||||
const string function_def = AddVariablesFunction();
|
|
||||||
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
|
||||||
status);
|
|
||||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
|
||||||
|
|
||||||
TFE_Op* func = TFE_NewOp(ctx, "AddVariablesFunction", status);
|
|
||||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
|
||||||
TFE_OpAddInput(func, packed_handle, status);
|
|
||||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
|
||||||
if (remote) {
|
|
||||||
TFE_OpSetDevice(func, task1_name, status);
|
|
||||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
|
||||||
}
|
|
||||||
|
|
||||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
|
||||||
int num_retvals = 1;
|
|
||||||
TFE_Execute(func, &retvals[0], &num_retvals, status);
|
|
||||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
|
||||||
ASSERT_EQ(1, num_retvals);
|
|
||||||
TFE_DeleteOp(func);
|
|
||||||
TFE_DeleteTensorHandle(packed_handle);
|
|
||||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
|
||||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
|
||||||
TFE_DeleteTensorHandle(retvals[0]);
|
|
||||||
float sum = 0;
|
|
||||||
EXPECT_EQ(sizeof(sum), TF_TensorByteSize(t));
|
|
||||||
memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
|
|
||||||
TF_DeleteTensor(t);
|
|
||||||
EXPECT_EQ(sum, 6.0);
|
|
||||||
|
|
||||||
TFE_DeleteTensorHandle(h0);
|
|
||||||
TFE_DeleteTensorHandle(h1);
|
|
||||||
TFE_DeleteTensorHandle(h2);
|
|
||||||
|
|
||||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
|
||||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
|
||||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
|
||||||
TFE_DeleteExecutor(executor);
|
|
||||||
TFE_ContextRemoveFunction(ctx, "AddVariablesFunction", status);
|
|
||||||
TFE_DeleteContext(ctx);
|
|
||||||
|
|
||||||
TF_DeleteStatus(status);
|
|
||||||
|
|
||||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
|
||||||
worker_server1.release();
|
|
||||||
worker_server2.release();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Packed-input function executed on the local task (task:0).
TEST(CAPI, TestLocalFunctionWithPackedInput) {
  constexpr bool kRemote = false;
  TestFunctionWithPackedInput(kRemote);
}
|
|
||||||
|
|
||||||
// Packed-input function executed on a remote task (task:1).
TEST(CAPI, TestRemoteFunctionWithPackedInput) {
  constexpr bool kRemote = true;
  TestFunctionWithPackedInput(kRemote);
}
|
|
||||||
|
|
||||||
// Builds and serializes a FunctionDef named `VariableAddFunction` with
// signature (var0: DT_RESOURCE) -> (var0_value: DT_FLOAT).
// The body reads resource variable `var0` once and adds the read value to
// itself, i.e. it returns 2 * var0. The Add node requests placement on
// task:1 and the final Identity on task:0, so executing this function
// exercises cross-task partitioning.
string VariableAddFunction() {
  tensorflow::FunctionDef def;
  // Parse from a text-format proto; a CHECK failure here means the literal
  // below is malformed (a bug in this test file, not in the code under test).
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      " signature {"
      "   name: 'VariableAddFunction'"
      "   input_arg {"
      "     name: 'var0'"
      "     type: DT_RESOURCE"
      "   }"
      "   output_arg {"
      "     name: 'var0_value'"
      "     type: DT_FLOAT"
      "   }"
      " }"
      // read0 = ReadVariableOp(var0); no explicit device request.
      " node_def {"
      "   name: 'read0'"
      "   op: 'ReadVariableOp'"
      "   input: 'var0'"
      "   attr {"
      "     key: 'dtype'"
      "     value {"
      "       type: DT_FLOAT"
      "     }"
      "   }"
      " }"
      // add = read0 + read0, pinned to task:1.
      " node_def {"
      "   name: 'add'"
      "   op: 'Add'"
      "   input: 'read0:value:0'"
      "   input: 'read0:value:0'"
      "   device: '/job:localhost/task:1/device:CPU:0'"
      "   attr {"
      "     key: 'T'"
      "     value {"
      "       type: DT_FLOAT"
      "     }"
      "   }"
      " }"
      // identity = Identity(add), pinned back to task:0.
      " node_def {"
      "   name: 'identity'"
      "   op: 'Identity'"
      "   input: 'add:z:0'"
      "   device: '/job:localhost/task:0/device:CPU:0'"
      "   attr {"
      "     key: 'T'"
      "     value {"
      "       type: DT_FLOAT"
      "     }"
      "   }"
      " }"
      " ret {"
      "   key: 'var0_value'"
      "   value: 'identity:output:0'"
      " }",
      &def));
  return def.SerializeAsString();
}
|
|
||||||
|
|
||||||
class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass {
|
|
||||||
public:
|
|
||||||
FunctionErrorInjectionPass(string error_node, string error_device)
|
|
||||||
: error_node_(error_node), error_device_(error_device) {}
|
|
||||||
tensorflow::Status Run(const tensorflow::DeviceSet& device_set,
|
|
||||||
const tensorflow::ConfigProto& config_proto,
|
|
||||||
std::unique_ptr<tensorflow::Graph>* graph,
|
|
||||||
tensorflow::FunctionLibraryDefinition* flib_def,
|
|
||||||
std::vector<std::string>* control_ret_node_names,
|
|
||||||
bool* control_rets_updated) override {
|
|
||||||
// Inject failure to function instantiation if finding a node that contains
|
|
||||||
// the given node name (error_node_) and requested device (error_device_).
|
|
||||||
for (const auto node : graph->get()->nodes()) {
|
|
||||||
if (node->name().find(error_node_) != string::npos &&
|
|
||||||
node->requested_device() == error_device_) {
|
|
||||||
return tensorflow::errors::Internal("Injected graph pass error.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return tensorflow::Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
const string error_node_;
|
|
||||||
const string error_device_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Runs `VariableAddFunction` on a 3-task cluster with the input variable
// placed on task:2. When `inject_error` is true, a FunctionErrorInjectionPass
// matching the 'read0' op on task:2 is registered first and the execution is
// expected to fail with TF_INTERNAL; otherwise the function must succeed and
// return 4.0 (2 * the variable's initial value of 2.0).
void TestDistributedFunctionCancellation(bool inject_error) {
  tensorflow::ServerDef server_def = GetServerDef(3);
  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

  // Start in-process gRPC servers for task:1 and task:2; task:0 is this
  // process (the client), configured via TFE_ContextSetServerDef below.
  server_def.set_task_index(1);
  std::unique_ptr<tensorflow::GrpcServer> worker_server1;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server1)
                  .ok());
  ASSERT_TRUE(worker_server1->Start().ok());
  server_def.set_task_index(2);
  std::unique_ptr<tensorflow::GrpcServer> worker_server2;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server2)
                  .ok());
  ASSERT_TRUE(worker_server2->Start().ok());
  const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";

  if (inject_error) {
    // Inject a function optimization pass failure when it sees the 'read0' op
    // having a requested device `dev2_name`. During execution:
    // * task:0 processes the main function `VariableAddFunction` and places
    //   the read0 op on task:2
    // * task:0 partitions the main function with a subgraph containing read0
    //   sent to task:2
    // * task:2 graph pass reports an error when it sees read0 with dev2_name
    // NOTE(review): `register_test_pass` goes out of scope at the end of this
    // branch, but the pass presumably remains in the global registry for the
    // rest of the process — confirm this is intended for later tests.
    tensorflow::function_optimization_registration::
        FunctionOptimizationPassRegistration register_test_pass(
            std::make_unique<FunctionErrorInjectionPass>("read0", dev2_name));
  }

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  // Connect this context (task:0) to the cluster described above.
  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Create the input variable (value 2.0) on task:2.
  TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
  EXPECT_NE(var_handle, nullptr);

  const string function_def = VariableAddFunction();
  TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
                            status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

  TFE_Op* func = TFE_NewOp(ctx, "VariableAddFunction", status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_OpAddInput(func, var_handle, status);
  ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
  TFE_TensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TFE_Execute(func, &retvals[0], &num_retvals, status);

  if (inject_error) {
    // The injected graph-pass failure must surface as an Internal error.
    ASSERT_EQ(TF_INTERNAL, TF_GetCode(status)) << TF_Message(status);
  } else {
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    ASSERT_EQ(1, num_retvals);
    TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    TFE_DeleteTensorHandle(retvals[0]);
    float sum = 0;
    ASSERT_EQ(sizeof(sum), TF_TensorByteSize(t));
    memcpy(&sum, TF_TensorData(t), TF_TensorByteSize(t));
    TF_DeleteTensor(t);
    // var0 + var0 = 2.0 + 2.0.
    ASSERT_EQ(sum, 4.0);
  }

  TFE_DeleteOp(func);
  TFE_DeleteTensorHandle(var_handle);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server1.release();
  worker_server2.release();
}
|
|
||||||
|
|
||||||
// Baseline: the distributed function runs to completion with no injected
// graph-pass error.
TEST(CAPI, DistributedFunctionNoError) {
  constexpr bool kInjectError = false;
  TestDistributedFunctionCancellation(kInjectError);
}
|
|
||||||
|
|
||||||
// An injected graph-pass error on one partition must cancel the whole
// distributed function execution.
TEST(CAPI, DistributedFunctionCancelledOnError) {
  constexpr bool kInjectError = true;
  TestDistributedFunctionCancellation(kInjectError);
}
|
|
||||||
|
|
||||||
// Launches a remote MatMul on large inputs and then deletes the context
// right away, so context teardown overlaps with the still-outstanding RPCs
// to the worker. `async` toggles the context's async execution mode.
void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) {
  tensorflow::ServerDef server_def = GetServerDef(2);

  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

  server_def.set_task_index(1);

  // Start an in-process gRPC server for task:1; task:0 is this process.
  std::unique_ptr<tensorflow::GrpcServer> worker_server;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server)
                  .ok());
  ASSERT_TRUE(worker_server->Start().ok());

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  // EXPLICIT placement: ops run only where they are explicitly placed.
  TFE_ContextOptionsSetDevicePlacementPolicy(opts,
                                             TFE_DEVICE_PLACEMENT_EXPLICIT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Use large matrices so that RPCs don't return before we get a chance
  // to call TFE_DeleteContext.
  TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(ctx);
  TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(ctx);
  const char remote_device_name[] =
      "/job:localhost/replica:0/task:1/device:CPU:0";
  // Copy both operands to the remote task before running the MatMul there.
  auto* h0_task1 =
      TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  auto* h1_task1 =
      TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
  TFE_OpSetDevice(matmul, remote_device_name, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);

  // Delete everything, including the context, without waiting for the
  // remote execution to drain — this is the race under test.
  TFE_DeleteTensorHandle(h0_task0);
  TFE_DeleteTensorHandle(h1_task0);
  TFE_DeleteTensorHandle(h0_task1);
  TFE_DeleteTensorHandle(h1_task1);
  TFE_DeleteTensorHandle(retvals[0]);

  TFE_DeleteOp(matmul);

  TFE_DeleteContext(ctx);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server.release();
}
|
|
||||||
|
|
||||||
// Synchronous variant of the delete-context-with-outstanding-RPC race.
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) {
  constexpr bool kAsync = false;
  TestRemoteExecuteDeleteContextWithOutstandingRPC(kAsync);
}
|
|
||||||
|
|
||||||
// Async-executor variant of the delete-context-with-outstanding-RPC race.
TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) {
  constexpr bool kAsync = true;
  TestRemoteExecuteDeleteContextWithOutstandingRPC(kAsync);
}
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -18,7 +18,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/c_api.h"
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/strcat.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/protobuf/cluster.pb.h"
|
||||||
|
|
||||||
using tensorflow::string;
|
using tensorflow::string;
|
||||||
|
|
||||||
@ -296,3 +298,23 @@ bool GetDeviceName(TFE_Context* ctx, string* device_name,
|
|||||||
TF_DeleteDeviceList(devices);
|
TF_DeleteDeviceList(devices);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
|
||||||
|
tensorflow::ServerDef server_def;
|
||||||
|
server_def.set_protocol("grpc");
|
||||||
|
server_def.set_job_name(job_name);
|
||||||
|
server_def.set_task_index(0);
|
||||||
|
tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster();
|
||||||
|
tensorflow::JobDef* job_def = cluster_def->add_job();
|
||||||
|
job_def->set_name(job_name);
|
||||||
|
for (int i = 0; i < num_tasks; i++) {
|
||||||
|
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||||
|
job_def->mutable_tasks()->insert(
|
||||||
|
{i, tensorflow::strings::StrCat("localhost:", port)});
|
||||||
|
}
|
||||||
|
return server_def;
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorflow::ServerDef GetServerDef(int num_tasks) {
|
||||||
|
return GetServerDef("localhost", num_tasks);
|
||||||
|
}
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/c/eager/c_api.h"
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||||
|
|
||||||
// Return a tensor handle containing a float scalar
|
// Return a tensor handle containing a float scalar
|
||||||
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value);
|
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value);
|
||||||
@ -72,4 +73,11 @@ TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input,
|
|||||||
bool GetDeviceName(TFE_Context* ctx, tensorflow::string* device_name,
|
bool GetDeviceName(TFE_Context* ctx, tensorflow::string* device_name,
|
||||||
const char* device_type);
|
const char* device_type);
|
||||||
|
|
||||||
|
// Create a ServerDef with the given `job_name` and add `num_tasks` tasks in it.
|
||||||
|
tensorflow::ServerDef GetServerDef(const tensorflow::string& job_name,
|
||||||
|
int num_tasks);
|
||||||
|
|
||||||
|
// Create a ServerDef with job name "localhost" and add `num_tasks` tasks in it.
|
||||||
|
tensorflow::ServerDef GetServerDef(int num_tasks);
|
||||||
|
|
||||||
#endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
|
#endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
|
||||||
|
@ -181,7 +181,7 @@ class GrpcEagerClient : public EagerClient {
|
|||||||
StatusCallback done) override {
|
StatusCallback done) override {
|
||||||
StatusCallback done_wrapped = callback_wrapper(std::move(done));
|
StatusCallback done_wrapped = callback_wrapper(std::move(done));
|
||||||
if (EnableStreaming()) {
|
if (EnableStreaming()) {
|
||||||
tf_shared_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
auto it = enqueue_dispatchers_.find(request->context_id());
|
auto it = enqueue_dispatchers_.find(request->context_id());
|
||||||
if (it == enqueue_dispatchers_.end()) {
|
if (it == enqueue_dispatchers_.end()) {
|
||||||
auto it_and_bool = enqueue_dispatchers_.emplace(
|
auto it_and_bool = enqueue_dispatchers_.emplace(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user