/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "absl/strings/str_cat.h"
|
|
#include "tensorflow/c/eager/c_api.h"
|
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
|
#include "tensorflow/c/eager/c_api_internal.h"
|
|
#include "tensorflow/c/eager/c_api_remote_test_util.h"
|
|
#include "tensorflow/c/eager/c_api_test_util.h"
|
|
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
|
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
|
|
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
|
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
|
#include "tensorflow/core/framework/function.h"
|
|
#include "tensorflow/core/graph/graph.h"
|
|
#include "tensorflow/core/platform/casts.h"
|
|
#include "tensorflow/core/platform/errors.h"
|
|
#include "tensorflow/core/platform/protobuf.h"
|
|
#include "tensorflow/core/platform/test.h"
|
|
#include "tensorflow/core/protobuf/cluster.pb.h"
|
|
#include "tensorflow/core/protobuf/config.pb.h"
|
|
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
|
|
|
namespace {
|
|
|
|
using ::tensorflow::string;
|
|
|
|
// Spawns a 2-task cluster, runs a MatMul on the remote task (task 1), and
// verifies the result fetched back to the client. `async` toggles the eager
// executor between synchronous and asynchronous mode.
void TestRemoteExecute(bool async) {
  tensorflow::ServerDef server_def = GetServerDef(2);

  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

  // Reuse the same cluster spec for the worker, but as task 1.
  server_def.set_task_index(1);

  // Start the in-process gRPC server that will act as the remote worker.
  std::unique_ptr<tensorflow::GrpcServer> worker_server;
  ASSERT_TRUE(tensorflow::GrpcServer::Create(
                  server_def, tensorflow::Env::Default(), &worker_server)
                  .ok());
  ASSERT_TRUE(worker_server->Start().ok());

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
  // EXPLICIT placement: ops only run on the device they are assigned to, so
  // the TFE_OpSetDevice call below is what routes the MatMul to task 1.
  TFE_ContextOptionsSetDevicePlacementPolicy(opts,
                                             TFE_DEVICE_PLACEMENT_EXPLICIT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  // Connect the context (task 0) to the cluster defined above.
  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Two local 2x2 matrix handles; TestMatrixTensorHandle presumably yields
  // the fixture matrix whose self-product is {7, 10, 15, 22} (checked below).
  TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
  TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
  const char remote_device_name[] =
      "/job:localhost/replica:0/task:1/device:CPU:0";
  // Explicitly copy both inputs to the remote task's CPU.
  auto* h0_task1 =
      TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  auto* h1_task1 =
      TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Place the MatMul on the remote device and run it there.
  TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1);
  TFE_OpSetDevice(matmul, remote_device_name, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  // Resolving the remote handle transfers the result back to the client.
  TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  float product[4] = {0};
  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  // Expected product of the fixture matrix with itself (row-major 2x2).
  EXPECT_EQ(7, product[0]);
  EXPECT_EQ(10, product[1]);
  EXPECT_EQ(15, product[2]);
  EXPECT_EQ(22, product[3]);

  // Teardown order matters: all tensor handles and ops must be released
  // before the context is deleted.
  TFE_DeleteTensorHandle(h0_task0);
  TFE_DeleteTensorHandle(h1_task0);
  TFE_DeleteTensorHandle(h0_task1);
  TFE_DeleteTensorHandle(h1_task1);
  TFE_DeleteTensorHandle(retvals[0]);

  TFE_DeleteOp(matmul);

  // Drain any pending async nodes so failures surface here, not at context
  // destruction.
  TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_ExecutorWaitForAllPendingNodes(executor, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteExecutor(executor);
  TFE_DeleteContext(ctx);

  TF_DeleteStatus(status);

  // TODO(b/136478427): Figure out how to correctly shut the server down.
  worker_server.release();  // Deliberately leaked; see TODO above.
}
|
|
|
|
// Remote MatMul with a synchronous eager executor.
TEST(CAPI, RemoteExecute) { TestRemoteExecute(/*async=*/false); }
|
// Remote MatMul with an asynchronous eager executor.
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(/*async=*/true); }
|
|
|
void TestRemoteExecuteSilentCopiesOp(bool async, bool remote,
|
|
bool remote_func_outputs = false) {
|
|
return TestRemoteExecuteSilentCopies(async, remote, /*func=*/false,
|
|
/*heavy_load_on_streaming_rpc=*/false,
|
|
remote_func_outputs);
|
|
}
|
|
|
|
// Implicit (silent) input copies to a remote op, synchronous executor.
TEST(CAPI, RemoteExecuteSilentCopies) {
  TestRemoteExecuteSilentCopiesOp(/*async=*/false, /*remote=*/true);
}
|
|
// Implicit (silent) input copies to a remote op, asynchronous executor.
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
  TestRemoteExecuteSilentCopiesOp(/*async=*/true, /*remote=*/true);
}
|
|
// Silent-copies scenario with the op placed locally, synchronous executor.
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
  TestRemoteExecuteSilentCopiesOp(/*async=*/false, /*remote=*/false);
}
|
|
// Silent-copies scenario with the op placed locally, asynchronous executor.
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
  TestRemoteExecuteSilentCopiesOp(/*async=*/true, /*remote=*/false);
}
|
|
|
|
} // namespace
|