Separate function tests from c_api_remote_test
PiperOrigin-RevId: 326133427 Change-Id: Ia611cd513ef84d393f0f32ba9410c037f339fe19
This commit is contained in:
parent
85cfe48f62
commit
497d94e9ac
@ -508,6 +508,27 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api_remote_test_util",
|
||||
testonly = 1,
|
||||
srcs = ["c_api_remote_test_util.cc"],
|
||||
hdrs = ["c_api_remote_test_util.h"],
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_internal",
|
||||
":c_api_test_util",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "c_api_remote_test",
|
||||
size = "small",
|
||||
@ -524,6 +545,7 @@ tf_cuda_cc_test(
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_internal",
|
||||
":c_api_remote_test_util",
|
||||
":c_api_test_util",
|
||||
":tfe_tensorhandle_internal",
|
||||
"//tensorflow/c:c_test_util",
|
||||
@ -540,6 +562,25 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "c_api_remote_function_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"c_api_remote_function_test.cc",
|
||||
],
|
||||
# TODO(b/136478427): Figure out how to correctly shut the server down
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
tags = [
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
":c_api_remote_test_util",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "c_api_distributed_test",
|
||||
size = "small",
|
||||
|
56
tensorflow/c/eager/c_api_remote_function_test.cc
Normal file
56
tensorflow/c/eager/c_api_remote_function_test.cc
Normal file
@ -0,0 +1,56 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/c_api_remote_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace {
|
||||
|
||||
void TestRemoteExecuteSilentCopiesFunc(bool async, bool remote,
|
||||
bool heavy_load_on_streaming_rpc,
|
||||
bool remote_func_outputs = false) {
|
||||
return TestRemoteExecuteSilentCopies(async, remote, /*func=*/true,
|
||||
heavy_load_on_streaming_rpc,
|
||||
remote_func_outputs);
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
|
||||
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
|
||||
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
}
|
||||
// TODO(b/162618595): Enable this test once we remove the check of remote
|
||||
// outputs in ProcessFunctionLibraryRuntime.
|
||||
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) {
|
||||
TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/false,
|
||||
/*heavy_load_on_streaming_rpc=*/false,
|
||||
/*remote_func_outputs=*/true);
|
||||
}
|
||||
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) {
|
||||
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
|
||||
/*heavy_load_on_streaming_rpc=*/false,
|
||||
/*remote_func_outputs=*/true);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
|
||||
// A remote input may be not ready when we start running a function. Test that
|
||||
// the function execution should wait until the remote input is ready.
|
||||
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
|
||||
/*heavy_load_on_streaming_rpc=*/true);
|
||||
}
|
||||
|
||||
} // namespace
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_remote_test_util.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
||||
@ -116,242 +117,24 @@ void TestRemoteExecute(bool async) {
|
||||
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
|
||||
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
|
||||
|
||||
string MatMulFunction(const string& matmul_device) {
|
||||
tensorflow::FunctionDef def;
|
||||
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
absl::StrCat(" signature {"
|
||||
" name: 'MatMulFunction'"
|
||||
" input_arg {"
|
||||
" name: 'a'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" input_arg {"
|
||||
" name: 'b'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" output_arg {"
|
||||
" name: 'm'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'matmul'"
|
||||
" op: 'MatMul'"
|
||||
" input: 'a'"
|
||||
" input: 'b'"
|
||||
" device: '",
|
||||
matmul_device, "'",
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" ret {"
|
||||
" key: 'm'"
|
||||
" value: 'matmul:product'"
|
||||
" }"),
|
||||
&def));
|
||||
return def.SerializeAsString();
|
||||
}
|
||||
|
||||
// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one
|
||||
// which creates a remote remote input, to simulate a scenario that the remote
|
||||
// input is not ready when we start running an op or a function.
|
||||
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
|
||||
bool heavy_load_on_streaming_rpc,
|
||||
bool remote_func_outputs = false) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server1->Start().ok());
|
||||
|
||||
server_def.set_task_index(2);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server2->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
|
||||
std::vector<TFE_TensorHandle*> handles_task0;
|
||||
if (heavy_load_on_streaming_rpc) {
|
||||
// Send 50 tensor copy requests to simulate that there have been some RPC
|
||||
// requests been enqueued.
|
||||
for (int i = 0; i < 50; ++i) {
|
||||
handles_task0.push_back(TestMatrixTensorHandle(ctx));
|
||||
}
|
||||
}
|
||||
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||
|
||||
std::vector<TFE_TensorHandle*> handles_task2;
|
||||
for (auto* h_task0 : handles_task0) {
|
||||
handles_task2.push_back(
|
||||
TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status));
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
}
|
||||
|
||||
auto* h1_task2 =
|
||||
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_Op* matmul = nullptr;
|
||||
if (func) {
|
||||
const string matmul_device = remote_func_outputs ? task2_name : "";
|
||||
string function_def = MatMulFunction(matmul_device);
|
||||
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
||||
status);
|
||||
CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, h0_task0, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, h1_task2, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
} else {
|
||||
// Handles are on task0 (local), and task2, but op is on task1.
|
||||
matmul = MatMulOp(ctx, h0_task0, h1_task2);
|
||||
}
|
||||
if (remote) {
|
||||
TFE_OpSetDevice(matmul, task1_name, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
} else if (!async) {
|
||||
// Set the local device to CPU to easily validate mirroring
|
||||
string cpu_device_name;
|
||||
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
|
||||
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
auto remote_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
|
||||
}
|
||||
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// TODO(gjn): Add support for waiting on async local mirrors
|
||||
if (!remote && !async && !remote_func_outputs) {
|
||||
auto remote_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
|
||||
}
|
||||
|
||||
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
|
||||
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(retval_task0);
|
||||
float product[4] = {0};
|
||||
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
|
||||
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
EXPECT_EQ(7, product[0]);
|
||||
EXPECT_EQ(10, product[1]);
|
||||
EXPECT_EQ(15, product[2]);
|
||||
EXPECT_EQ(22, product[3]);
|
||||
|
||||
TFE_DeleteTensorHandle(h0_task0);
|
||||
TFE_DeleteTensorHandle(h1_task0);
|
||||
TFE_DeleteTensorHandle(h1_task2);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
for (auto* h : handles_task0) {
|
||||
TFE_DeleteTensorHandle(h);
|
||||
}
|
||||
for (auto* h : handles_task2) {
|
||||
TFE_DeleteTensorHandle(h);
|
||||
}
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
if (func) {
|
||||
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
|
||||
}
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server1.release();
|
||||
worker_server2.release();
|
||||
void TestRemoteExecuteSilentCopiesOp(bool async, bool remote,
|
||||
bool remote_func_outputs = false) {
|
||||
return TestRemoteExecuteSilentCopies(async, remote, /*func=*/false,
|
||||
/*heavy_load_on_streaming_rpc=*/false,
|
||||
remote_func_outputs);
|
||||
}
|
||||
|
||||
TEST(CAPI, RemoteExecuteSilentCopies) {
|
||||
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/true,
|
||||
/*func=*/false,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
TestRemoteExecuteSilentCopiesOp(/*async=*/false, /*remote=*/true);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
|
||||
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/false,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
|
||||
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/true,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
TestRemoteExecuteSilentCopiesOp(/*async=*/true, /*remote=*/true);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
|
||||
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/false,
|
||||
/*func=*/false,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
TestRemoteExecuteSilentCopiesOp(/*async=*/false, /*remote=*/false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
|
||||
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false,
|
||||
/*func=*/false,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
|
||||
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
}
|
||||
// TODO(b/162618595): Enable this test once we remove the check of remote
|
||||
// outputs in ProcessFunctionLibraryRuntime.
|
||||
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) {
|
||||
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/false,
|
||||
/*func=*/true,
|
||||
/*heavy_load_on_streaming_rpc=*/false,
|
||||
/*remote_func_outputs=*/true);
|
||||
}
|
||||
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) {
|
||||
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
|
||||
/*heavy_load_on_streaming_rpc=*/false,
|
||||
/*remote_func_outputs=*/true);
|
||||
}
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
|
||||
// A remote input may be not ready when we start running a function. Test that
|
||||
// the function execution should wait until the remote input is ready.
|
||||
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
|
||||
/*heavy_load_on_streaming_rpc=*/true);
|
||||
TestRemoteExecuteSilentCopiesOp(/*async=*/true, /*remote=*/false);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
215
tensorflow/c/eager/c_api_remote_test_util.cc
Normal file
215
tensorflow/c/eager/c_api_remote_test_util.cc
Normal file
@ -0,0 +1,215 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/eager/c_api_remote_test_util.h"
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
||||
|
||||
using ::tensorflow::string;
|
||||
|
||||
string MatMulFunction(const string& matmul_device) {
|
||||
tensorflow::FunctionDef def;
|
||||
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
absl::StrCat(" signature {"
|
||||
" name: 'MatMulFunction'"
|
||||
" input_arg {"
|
||||
" name: 'a'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" input_arg {"
|
||||
" name: 'b'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" output_arg {"
|
||||
" name: 'm'"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" node_def {"
|
||||
" name: 'matmul'"
|
||||
" op: 'MatMul'"
|
||||
" input: 'a'"
|
||||
" input: 'b'"
|
||||
" device: '",
|
||||
matmul_device, "'",
|
||||
" attr {"
|
||||
" key: 'T'"
|
||||
" value {"
|
||||
" type: DT_FLOAT"
|
||||
" }"
|
||||
" }"
|
||||
" }"
|
||||
" ret {"
|
||||
" key: 'm'"
|
||||
" value: 'matmul:product'"
|
||||
" }"),
|
||||
&def));
|
||||
return def.SerializeAsString();
|
||||
}
|
||||
|
||||
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
|
||||
bool heavy_load_on_streaming_rpc,
|
||||
bool remote_func_outputs) {
|
||||
tensorflow::ServerDef server_def = GetServerDef(3);
|
||||
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
|
||||
server_def.set_task_index(1);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server1)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server1->Start().ok());
|
||||
|
||||
server_def.set_task_index(2);
|
||||
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
|
||||
ASSERT_TRUE(tensorflow::GrpcServer::Create(
|
||||
server_def, tensorflow::Env::Default(), &worker_server2)
|
||||
.ok());
|
||||
ASSERT_TRUE(worker_server2->Start().ok());
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
|
||||
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
|
||||
std::vector<TFE_TensorHandle*> handles_task0;
|
||||
if (heavy_load_on_streaming_rpc) {
|
||||
// Send 50 tensor copy requests to simulate that there have been some RPC
|
||||
// requests been enqueued.
|
||||
for (int i = 0; i < 50; ++i) {
|
||||
handles_task0.push_back(TestMatrixTensorHandle(ctx));
|
||||
}
|
||||
}
|
||||
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||
|
||||
std::vector<TFE_TensorHandle*> handles_task2;
|
||||
for (auto* h_task0 : handles_task0) {
|
||||
handles_task2.push_back(
|
||||
TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status));
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
}
|
||||
|
||||
auto* h1_task2 =
|
||||
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TFE_Op* matmul = nullptr;
|
||||
if (func) {
|
||||
const string matmul_device = remote_func_outputs ? task2_name : "";
|
||||
string function_def = MatMulFunction(matmul_device);
|
||||
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
|
||||
status);
|
||||
CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, h0_task0, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_OpAddInput(matmul, h1_task2, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
} else {
|
||||
// Handles are on task0 (local), and task2, but op is on task1.
|
||||
matmul = MatMulOp(ctx, h0_task0, h1_task2);
|
||||
}
|
||||
if (remote) {
|
||||
TFE_OpSetDevice(matmul, task1_name, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
} else if (!async) {
|
||||
// Set the local device to CPU to easily validate mirroring
|
||||
string cpu_device_name;
|
||||
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
|
||||
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
auto remote_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
|
||||
}
|
||||
|
||||
TFE_TensorHandle* retvals[1];
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
// TODO(gjn): Add support for waiting on async local mirrors
|
||||
if (!remote && !async && !remote_func_outputs) {
|
||||
auto remote_arg =
|
||||
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
|
||||
// The input handles should never change since they have been mirrored.
|
||||
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
|
||||
}
|
||||
|
||||
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
|
||||
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteTensorHandle(retval_task0);
|
||||
float product[4] = {0};
|
||||
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
|
||||
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
EXPECT_EQ(7, product[0]);
|
||||
EXPECT_EQ(10, product[1]);
|
||||
EXPECT_EQ(15, product[2]);
|
||||
EXPECT_EQ(22, product[3]);
|
||||
|
||||
TFE_DeleteTensorHandle(h0_task0);
|
||||
TFE_DeleteTensorHandle(h1_task0);
|
||||
TFE_DeleteTensorHandle(h1_task2);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
for (auto* h : handles_task0) {
|
||||
TFE_DeleteTensorHandle(h);
|
||||
}
|
||||
for (auto* h : handles_task2) {
|
||||
TFE_DeleteTensorHandle(h);
|
||||
}
|
||||
|
||||
TFE_DeleteOp(matmul);
|
||||
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
if (func) {
|
||||
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
|
||||
}
|
||||
TFE_DeleteContext(ctx);
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// TODO(b/136478427): Figure out how to correctly shut the server down.
|
||||
worker_server1.release();
|
||||
worker_server2.release();
|
||||
}
|
26
tensorflow/c/eager/c_api_remote_test_util.h
Normal file
26
tensorflow/c/eager/c_api_remote_test_util.h
Normal file
@ -0,0 +1,26 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_
|
||||
#define TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_
|
||||
|
||||
// Run a function containing a MatMul op and check its output.
|
||||
// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one
|
||||
// which creates a remote remote input, to simulate a scenario that the remote
|
||||
// input is not ready when we start running an op or a function.
|
||||
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
|
||||
bool heavy_load_on_streaming_rpc,
|
||||
bool remote_func_outputs = false);
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_
|
Loading…
Reference in New Issue
Block a user