Minor modernizations, mostly more <memory>
PiperOrigin-RevId: 158793461
This commit is contained in:
parent
995f5f4f40
commit
9f10f60fbd
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/direct_session.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
@ -45,10 +46,10 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Session* CreateSession() {
|
||||
std::unique_ptr<Session> CreateSession() {
|
||||
SessionOptions options;
|
||||
(*options.config.mutable_device_count())["CPU"] = 2;
|
||||
return NewSession(options);
|
||||
return std::unique_ptr<Session>(NewSession(options));
|
||||
}
|
||||
|
||||
class DirectSessionMinusAXTest : public ::testing::Test {
|
||||
@ -87,7 +88,7 @@ class DirectSessionMinusAXTest : public ::testing::Test {
|
||||
|
||||
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork) {
|
||||
Initialize({3, 2, -1, 0});
|
||||
std::unique_ptr<Session> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def_));
|
||||
std::vector<std::pair<string, Tensor>> inputs;
|
||||
@ -109,7 +110,7 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork) {
|
||||
|
||||
TEST_F(DirectSessionMinusAXTest, TestFeed) {
|
||||
Initialize({1, 2, 3, 4});
|
||||
std::unique_ptr<Session> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
|
||||
TF_ASSERT_OK(session->Create(def_));
|
||||
@ -138,7 +139,7 @@ TEST_F(DirectSessionMinusAXTest, TestFeed) {
|
||||
|
||||
TEST_F(DirectSessionMinusAXTest, TestConcurrency) {
|
||||
Initialize({1, 2, 3, 4});
|
||||
std::unique_ptr<Session> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def_));
|
||||
|
||||
@ -207,7 +208,7 @@ TEST_F(DirectSessionMinusAXTest, TestPerSessionThreads) {
|
||||
|
||||
TEST_F(DirectSessionMinusAXTest, TwoCreateCallsFails) {
|
||||
Initialize({1, 2, 3, 4});
|
||||
std::unique_ptr<Session> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def_));
|
||||
|
||||
@ -217,7 +218,7 @@ TEST_F(DirectSessionMinusAXTest, TwoCreateCallsFails) {
|
||||
|
||||
TEST_F(DirectSessionMinusAXTest, ForgetToCreate) {
|
||||
Initialize({1, 2, 3, 4});
|
||||
std::unique_ptr<Session> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
std::vector<std::pair<string, Tensor>> inputs;
|
||||
std::vector<Tensor> outputs;
|
||||
@ -261,7 +262,7 @@ TEST_F(DirectSessionMinusAXTest, InvalidDevice) {
|
||||
|
||||
TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts) {
|
||||
Initialize({3, 2, -1, 0});
|
||||
std::unique_ptr<Session> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def_));
|
||||
std::vector<std::pair<string, Tensor>> inputs;
|
||||
@ -313,7 +314,7 @@ TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) {
|
||||
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
|
||||
std::unique_ptr<Session> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def));
|
||||
|
||||
@ -348,7 +349,7 @@ TEST(DirectSessionTest, MultipleFeedTest) {
|
||||
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
|
||||
std::unique_ptr<Session> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def));
|
||||
|
||||
@ -436,7 +437,6 @@ TEST(DirectSessionTest, DarthKernel) {
|
||||
std::vector<Tensor> outputs;
|
||||
auto s = sess->Run({}, {y->name() + ":0"}, {}, &outputs);
|
||||
EXPECT_TRUE(errors::IsInternal(s));
|
||||
delete sess;
|
||||
}
|
||||
|
||||
// Have the Darth op in the graph placed on GPU, but don't run it.
|
||||
@ -500,7 +500,7 @@ TEST(DirectSessionTest, PartialRunTest) {
|
||||
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
|
||||
std::unique_ptr<Session> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def));
|
||||
|
||||
@ -556,7 +556,7 @@ TEST(DirectSessionTest, PartialRunMissingFeed) {
|
||||
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
|
||||
std::unique_ptr<Session> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def));
|
||||
|
||||
@ -589,7 +589,7 @@ TEST(DirectSessionTest, PartialRunMultiOutputFeed) {
|
||||
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
|
||||
std::unique_ptr<Session> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def));
|
||||
|
||||
@ -638,7 +638,7 @@ TEST(DirectSessionTest, RunHandleTest) {
|
||||
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
|
||||
std::unique_ptr<Session> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def));
|
||||
|
||||
@ -679,7 +679,7 @@ TEST(DirectSessionTest, CreateGraphFailsWhenAssigningAFedVar) {
|
||||
// a = b
|
||||
Node* assign = test::graph::Assign(&graph, a, b);
|
||||
|
||||
std::unique_ptr<Session> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
|
||||
// The graph is invalid since a constant cannot be assigned to a constant.
|
||||
@ -757,30 +757,35 @@ TEST(DirectSessionTest, TimeoutSession) {
|
||||
)proto",
|
||||
&graph);
|
||||
|
||||
// Creates a session with operation_timeout_in_ms set to 100 milliseconds.
|
||||
SessionOptions options;
|
||||
(*options.config.mutable_device_count())["CPU"] = 2;
|
||||
options.config.set_operation_timeout_in_ms(100);
|
||||
std::unique_ptr<Session> session(NewSession(options));
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(graph));
|
||||
{
|
||||
// Creates a session with operation_timeout_in_ms set to 100 milliseconds.
|
||||
SessionOptions options;
|
||||
(*options.config.mutable_device_count())["CPU"] = 2;
|
||||
options.config.set_operation_timeout_in_ms(100);
|
||||
|
||||
// Verifies that the error code is DEADLINE_EXCEEDED.
|
||||
Status s = session->Run({}, {}, {"fifo_queue_Dequeue"}, nullptr);
|
||||
ASSERT_EQ(error::DEADLINE_EXCEEDED, s.code());
|
||||
TF_ASSERT_OK(session->Close());
|
||||
std::unique_ptr<Session> session(NewSession(options));
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(graph));
|
||||
|
||||
// Creates a session with no operation_timeout_in_ms.
|
||||
session.reset(CreateSession());
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(graph));
|
||||
RunOptions run_options;
|
||||
run_options.set_timeout_in_ms(20);
|
||||
// Verifies that the error code is DEADLINE_EXCEEDED.
|
||||
Status s2 = session->Run(run_options, {}, {}, {"fifo_queue_Dequeue"}, nullptr,
|
||||
nullptr);
|
||||
ASSERT_EQ(error::DEADLINE_EXCEEDED, s2.code());
|
||||
TF_ASSERT_OK(session->Close());
|
||||
// Verifies that the error code is DEADLINE_EXCEEDED.
|
||||
Status s = session->Run({}, {}, {"fifo_queue_Dequeue"}, nullptr);
|
||||
ASSERT_EQ(error::DEADLINE_EXCEEDED, s.code());
|
||||
TF_ASSERT_OK(session->Close());
|
||||
}
|
||||
|
||||
{
|
||||
// Creates a session with no operation_timeout_in_ms.
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(graph));
|
||||
RunOptions run_options;
|
||||
run_options.set_timeout_in_ms(20);
|
||||
// Verifies that the error code is DEADLINE_EXCEEDED.
|
||||
Status s2 = session->Run(run_options, {}, {}, {"fifo_queue_Dequeue"},
|
||||
nullptr, nullptr);
|
||||
ASSERT_EQ(error::DEADLINE_EXCEEDED, s2.code());
|
||||
TF_ASSERT_OK(session->Close());
|
||||
}
|
||||
}
|
||||
|
||||
// Accesses the cancellation manager for the step after the step has been
|
||||
@ -1090,7 +1095,7 @@ TEST(DirectSessionTest, TestDirectSessionPRunClose) {
|
||||
|
||||
test::graph::ToGraphDef(&g, &def);
|
||||
|
||||
std::unique_ptr<Session> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def));
|
||||
|
||||
@ -1194,8 +1199,8 @@ void FeedFetchBenchmarkHelper(int num_feeds, int iters) {
|
||||
GraphDef gd;
|
||||
g.ToGraphDef(&gd);
|
||||
SessionOptions opts;
|
||||
std::unique_ptr<Session> sess(NewSession(opts));
|
||||
TF_CHECK_OK(sess->Create(gd));
|
||||
std::unique_ptr<Session> session(NewSession(opts));
|
||||
TF_CHECK_OK(session->Create(gd));
|
||||
{
|
||||
// NOTE(mrry): Ignore the first run, which will incur the graph
|
||||
// partitioning/pruning overhead and skew the results.
|
||||
@ -1204,12 +1209,12 @@ void FeedFetchBenchmarkHelper(int num_feeds, int iters) {
|
||||
// the first run, which will impact application startup times, but
|
||||
// that is not the object of study in this benchmark.
|
||||
std::vector<Tensor> output_values;
|
||||
TF_CHECK_OK(sess->Run(inputs, outputs, {}, &output_values));
|
||||
TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
|
||||
}
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
std::vector<Tensor> output_values;
|
||||
TF_CHECK_OK(sess->Run(inputs, outputs, {}, &output_values));
|
||||
TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_tracer.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
@ -41,12 +42,12 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Session* CreateSession() {
|
||||
std::unique_ptr<Session> CreateSession() {
|
||||
SessionOptions options;
|
||||
(*options.config.mutable_device_count())["CPU"] = 1;
|
||||
(*options.config.mutable_device_count())["GPU"] = 1;
|
||||
options.config.set_allow_soft_placement(true);
|
||||
return NewSession(options);
|
||||
return std::unique_ptr<Session>(NewSession(options));
|
||||
}
|
||||
|
||||
class GPUTracerTest : public ::testing::Test {
|
||||
@ -97,24 +98,21 @@ class GPUTracerTest : public ::testing::Test {
|
||||
};
|
||||
|
||||
TEST_F(GPUTracerTest, StartStop) {
|
||||
std::unique_ptr<GPUTracer> tracer;
|
||||
tracer.reset(CreateGPUTracer());
|
||||
std::unique_ptr<GPUTracer> tracer(CreateGPUTracer());
|
||||
if (!tracer) return;
|
||||
TF_EXPECT_OK(tracer->Start());
|
||||
TF_EXPECT_OK(tracer->Stop());
|
||||
}
|
||||
|
||||
TEST_F(GPUTracerTest, StopBeforeStart) {
|
||||
std::unique_ptr<GPUTracer> tracer;
|
||||
tracer.reset(CreateGPUTracer());
|
||||
std::unique_ptr<GPUTracer> tracer(CreateGPUTracer());
|
||||
if (!tracer) return;
|
||||
TF_EXPECT_OK(tracer->Stop());
|
||||
TF_EXPECT_OK(tracer->Stop());
|
||||
}
|
||||
|
||||
TEST_F(GPUTracerTest, CollectBeforeStart) {
|
||||
std::unique_ptr<GPUTracer> tracer;
|
||||
tracer.reset(CreateGPUTracer());
|
||||
std::unique_ptr<GPUTracer> tracer(CreateGPUTracer());
|
||||
if (!tracer) return;
|
||||
StepStats stats;
|
||||
StepStatsCollector collector(&stats);
|
||||
@ -123,8 +121,7 @@ TEST_F(GPUTracerTest, CollectBeforeStart) {
|
||||
}
|
||||
|
||||
TEST_F(GPUTracerTest, CollectBeforeStop) {
|
||||
std::unique_ptr<GPUTracer> tracer;
|
||||
tracer.reset(CreateGPUTracer());
|
||||
std::unique_ptr<GPUTracer> tracer(CreateGPUTracer());
|
||||
if (!tracer) return;
|
||||
TF_EXPECT_OK(tracer->Start());
|
||||
StepStats stats;
|
||||
@ -135,10 +132,8 @@ TEST_F(GPUTracerTest, CollectBeforeStop) {
|
||||
}
|
||||
|
||||
TEST_F(GPUTracerTest, StartTwoTracers) {
|
||||
std::unique_ptr<GPUTracer> tracer1;
|
||||
tracer1.reset(CreateGPUTracer());
|
||||
std::unique_ptr<GPUTracer> tracer2;
|
||||
tracer2.reset(CreateGPUTracer());
|
||||
std::unique_ptr<GPUTracer> tracer1(CreateGPUTracer());
|
||||
std::unique_ptr<GPUTracer> tracer2(CreateGPUTracer());
|
||||
if (!tracer1 || !tracer2) return;
|
||||
|
||||
TF_EXPECT_OK(tracer1->Start());
|
||||
@ -151,12 +146,11 @@ TEST_F(GPUTracerTest, StartTwoTracers) {
|
||||
|
||||
TEST_F(GPUTracerTest, RunWithTracer) {
|
||||
// On non-GPU platforms, we may not support GPUTracer.
|
||||
std::unique_ptr<GPUTracer> tracer;
|
||||
tracer.reset(CreateGPUTracer());
|
||||
std::unique_ptr<GPUTracer> tracer(CreateGPUTracer());
|
||||
if (!tracer) return;
|
||||
|
||||
Initialize({3, 2, -1, 0});
|
||||
std::unique_ptr<Session> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def_));
|
||||
std::vector<std::pair<string, Tensor>> inputs;
|
||||
@ -179,12 +173,11 @@ TEST_F(GPUTracerTest, RunWithTracer) {
|
||||
}
|
||||
|
||||
TEST_F(GPUTracerTest, TraceToStepStatsCollector) {
|
||||
std::unique_ptr<GPUTracer> tracer;
|
||||
tracer.reset(CreateGPUTracer());
|
||||
std::unique_ptr<GPUTracer> tracer(CreateGPUTracer());
|
||||
if (!tracer) return;
|
||||
|
||||
Initialize({3, 2, -1, 0});
|
||||
std::unique_ptr<Session> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def_));
|
||||
std::vector<std::pair<string, Tensor>> inputs;
|
||||
@ -209,7 +202,7 @@ TEST_F(GPUTracerTest, TraceToStepStatsCollector) {
|
||||
|
||||
TEST_F(GPUTracerTest, RunWithTraceOption) {
|
||||
Initialize({3, 2, -1, 0});
|
||||
std::unique_ptr<Session> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def_));
|
||||
std::vector<std::pair<string, Tensor>> inputs;
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/debug/debug_graph_utils.h"
|
||||
@ -29,14 +30,15 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
DirectSession* CreateSession() {
|
||||
std::unique_ptr<DirectSession> CreateSession() {
|
||||
SessionOptions options;
|
||||
// Turn off graph optimizer so we can observe intermediate node states.
|
||||
options.config.mutable_graph_options()
|
||||
->mutable_optimizer_options()
|
||||
->set_opt_level(OptimizerOptions_Level_L0);
|
||||
|
||||
return dynamic_cast<DirectSession*>(NewSession(options));
|
||||
return std::unique_ptr<DirectSession>(
|
||||
dynamic_cast<DirectSession*>(NewSession(options)));
|
||||
}
|
||||
|
||||
class SessionDebugMinusAXTest : public ::testing::Test {
|
||||
@ -85,7 +87,7 @@ class SessionDebugMinusAXTest : public ::testing::Test {
|
||||
|
||||
TEST_F(SessionDebugMinusAXTest, RunSimpleNetwork) {
|
||||
Initialize({3, 2, -1, 0});
|
||||
std::unique_ptr<DirectSession> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
|
||||
DebugGateway debug_gateway(session.get());
|
||||
@ -220,7 +222,7 @@ TEST_F(SessionDebugMinusAXTest, RunSimpleNetwork) {
|
||||
TEST_F(SessionDebugMinusAXTest, RunSimpleNetworkWithTwoDebugNodesInserted) {
|
||||
// Tensor contains one count of NaN
|
||||
Initialize({3, std::numeric_limits<float>::quiet_NaN(), -1, 0});
|
||||
std::unique_ptr<DirectSession> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
|
||||
DebugGateway debug_gateway(session.get());
|
||||
@ -350,7 +352,7 @@ TEST_F(SessionDebugMinusAXTest,
|
||||
// Test concurrent Run() calls on a graph with different debug watches.
|
||||
|
||||
Initialize({3, 2, -1, 0});
|
||||
std::unique_ptr<DirectSession> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
TF_ASSERT_OK(session->Create(def_));
|
||||
|
||||
@ -537,7 +539,7 @@ class SessionDebugOutputSlotWithoutOngoingEdgeTest : public ::testing::Test {
|
||||
TEST_F(SessionDebugOutputSlotWithoutOngoingEdgeTest,
|
||||
WatchSlotWithoutOutgoingEdge) {
|
||||
Initialize();
|
||||
std::unique_ptr<DirectSession> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
|
||||
DebugGateway debug_gateway(session.get());
|
||||
@ -662,7 +664,7 @@ class SessionDebugVariableTest : public ::testing::Test {
|
||||
|
||||
TEST_F(SessionDebugVariableTest, WatchUninitializedVariableWithDebugOps) {
|
||||
Initialize();
|
||||
std::unique_ptr<DirectSession> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
|
||||
DebugGateway debug_gateway(session.get());
|
||||
@ -741,7 +743,7 @@ TEST_F(SessionDebugVariableTest, WatchUninitializedVariableWithDebugOps) {
|
||||
TEST_F(SessionDebugVariableTest, VariableAssignWithDebugOps) {
|
||||
// Tensor contains one count of NaN
|
||||
Initialize();
|
||||
std::unique_ptr<DirectSession> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
|
||||
DebugGateway debug_gateway(session.get());
|
||||
@ -917,7 +919,7 @@ class SessionDebugGPUSwitchTest : public ::testing::Test {
|
||||
// Test for debug-watching tensors marked as HOST_MEMORY on GPU.
|
||||
TEST_F(SessionDebugGPUSwitchTest, RunSwitchWithHostMemoryDebugOp) {
|
||||
Initialize();
|
||||
std::unique_ptr<DirectSession> session(CreateSession());
|
||||
auto session = CreateSession();
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
|
||||
DebugGateway debug_gateway(session.get());
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/debug/debug_io_utils.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
|
||||
@ -37,8 +39,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/port.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
static SessionOptions Devices(int num_cpus, int num_gpus) {
|
||||
SessionOptions Devices(int num_cpus, int num_gpus) {
|
||||
SessionOptions result;
|
||||
(*result.config.mutable_device_count())["CPU"] = num_cpus;
|
||||
(*result.config.mutable_device_count())["GPU"] = num_gpus;
|
||||
@ -67,13 +70,13 @@ void CreateGraphDef(GraphDef* graph_def, string node_names[3]) {
|
||||
|
||||
// Asserts that "val" is a single float tensor. The only float is
|
||||
// "expected_val".
|
||||
static void IsSingleFloatValue(const Tensor& val, float expected_val) {
|
||||
void IsSingleFloatValue(const Tensor& val, float expected_val) {
|
||||
ASSERT_EQ(val.dtype(), DT_FLOAT);
|
||||
ASSERT_EQ(val.NumElements(), 1);
|
||||
ASSERT_EQ(val.flat<float>()(0), expected_val);
|
||||
}
|
||||
|
||||
static SessionOptions Options(const string& target, int placement_period) {
|
||||
SessionOptions Options(const string& target, int placement_period) {
|
||||
SessionOptions options;
|
||||
// NOTE(mrry): GrpcSession requires a grpc:// scheme prefix in the target
|
||||
// string.
|
||||
@ -85,8 +88,8 @@ static SessionOptions Options(const string& target, int placement_period) {
|
||||
return options;
|
||||
}
|
||||
|
||||
static Session* NewRemote(const SessionOptions& options) {
|
||||
return CHECK_NOTNULL(NewSession(options));
|
||||
std::unique_ptr<Session> NewRemote(const SessionOptions& options) {
|
||||
return std::unique_ptr<Session>(CHECK_NOTNULL(NewSession(options)));
|
||||
}
|
||||
|
||||
class GrpcSessionDebugTest : public ::testing::Test {
|
||||
@ -149,9 +152,7 @@ TEST_F(GrpcSessionDebugTest, FileDebugURL) {
|
||||
std::unique_ptr<test::TestCluster> cluster;
|
||||
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
|
||||
|
||||
std::unique_ptr<Session> session(
|
||||
NewRemote(Options(cluster->targets()[0], 1)));
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
auto session = NewRemote(Options(cluster->targets()[0], 1));
|
||||
TF_CHECK_OK(session->Create(graph));
|
||||
|
||||
// Iteration 0: No watch.
|
||||
@ -220,9 +221,7 @@ void SetDevice(GraphDef* graph, const string& name, const string& dev) {
|
||||
TEST_F(GrpcSessionDebugTest, MultiDevices_String) {
|
||||
std::unique_ptr<test::TestCluster> cluster;
|
||||
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 1), 2, &cluster));
|
||||
std::unique_ptr<Session> session(
|
||||
NewRemote(Options(cluster->targets()[0], 1000)));
|
||||
ASSERT_TRUE(session != nullptr);
|
||||
auto session = NewRemote(Options(cluster->targets()[0], 1000));
|
||||
|
||||
// b = a
|
||||
Graph graph(OpRegistry::Global());
|
||||
@ -289,4 +288,5 @@ TEST_F(GrpcSessionDebugTest, MultiDevices_String) {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/example/example_parser_configuration.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
@ -30,10 +32,11 @@ namespace {
|
||||
void ReadFileToStringOrDie(Env* env, const string& filename, string* output) {
|
||||
TF_CHECK_OK(ReadFileToString(env, filename, output));
|
||||
}
|
||||
Session* CreateSession() {
|
||||
|
||||
std::unique_ptr<Session> CreateSession() {
|
||||
SessionOptions options;
|
||||
(*options.config.mutable_device_count())["CPU"] = 2;
|
||||
return NewSession(options);
|
||||
return std::unique_ptr<Session>(NewSession(options));
|
||||
}
|
||||
|
||||
class ExtractExampleParserConfigurationTest : public ::testing::Test {
|
||||
@ -45,19 +48,19 @@ class ExtractExampleParserConfigurationTest : public ::testing::Test {
|
||||
"core/example/testdata/parse_example_graph_def.pbtxt");
|
||||
ReadFileToStringOrDie(Env::Default(), filename, &proto_string);
|
||||
protobuf::TextFormat::ParseFromString(proto_string, &graph_def_);
|
||||
session_.reset(CreateSession());
|
||||
session_ = CreateSession();
|
||||
TF_CHECK_OK(session_->Create(graph_def_));
|
||||
}
|
||||
|
||||
NodeDef* parse_example_node() {
|
||||
for (int i = 0; i < graph_def_.node_size(); ++i) {
|
||||
auto mutable_node = graph_def_.mutable_node(i);
|
||||
if (mutable_node->name() == "ParseExample/ParseExample") {
|
||||
return mutable_node;
|
||||
for (auto& node : *graph_def_.mutable_node()) {
|
||||
if (node.name() == "ParseExample/ParseExample") {
|
||||
return &node;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
GraphDef graph_def_;
|
||||
std::unique_ptr<Session> session_;
|
||||
};
|
||||
|
@ -13,23 +13,24 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
namespace f = test::function;
|
||||
typedef FunctionDefHelper FDH;
|
||||
using FDH = FunctionDefHelper;
|
||||
|
||||
class ArrayGradTest : public ::testing::Test {};
|
||||
|
||||
Session* NewSession() {
|
||||
std::unique_ptr<Session> NewSession() {
|
||||
SessionOptions opts;
|
||||
(*opts.config.mutable_device_count())["CPU"] = 1;
|
||||
return NewSession(opts);
|
||||
return std::unique_ptr<Session>(NewSession(opts));
|
||||
}
|
||||
|
||||
std::vector<Tensor> PackGrad(const Tensor& x0, const Tensor& x1,
|
||||
@ -56,11 +57,10 @@ std::vector<Tensor> PackGrad(const Tensor& x0, const Tensor& x1,
|
||||
{"dx:0", "dx:1"}, {}, &out));
|
||||
CHECK_EQ(out.size(), 2);
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
return out;
|
||||
}
|
||||
|
||||
TEST_F(ArrayGradTest, PackGrad) {
|
||||
TEST(ArrayGradTest, PackGrad) {
|
||||
Tensor x0(DT_FLOAT, {2, 3});
|
||||
x0.flat<float>().setZero();
|
||||
Tensor x1(DT_FLOAT, {2, 3});
|
||||
@ -98,11 +98,10 @@ std::vector<Tensor> UnpackGrad(const Tensor& x, const Tensor& dy0,
|
||||
{"dx:0"}, {}, &out));
|
||||
CHECK_EQ(out.size(), 1);
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
return out;
|
||||
}
|
||||
|
||||
TEST_F(ArrayGradTest, UnpackGrad) {
|
||||
TEST(ArrayGradTest, UnpackGrad) {
|
||||
Tensor x(DT_FLOAT, {2, 2, 3});
|
||||
x.flat<float>().setZero();
|
||||
Tensor dy0(DT_FLOAT, {2, 3});
|
||||
@ -136,7 +135,6 @@ std::vector<Tensor> ConcatGrad(int dim, const Tensor& x0, const Tensor& x1,
|
||||
{"dx:0", "dx:1", "dx:2"}, {}, &out));
|
||||
CHECK_EQ(out.size(), 3);
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -161,11 +159,10 @@ std::vector<Tensor> ConcatGradV2(int dim, const Tensor& x0, const Tensor& x1,
|
||||
{"dx:0", "dx:1", "dx:2"}, {}, &out));
|
||||
CHECK_EQ(out.size(), 3);
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
return out;
|
||||
}
|
||||
|
||||
TEST_F(ArrayGradTest, ConcatGrad) {
|
||||
TEST(ArrayGradTest, ConcatGrad) {
|
||||
Tensor x0(DT_FLOAT, {2, 3, 5});
|
||||
x0.flat<float>().setZero();
|
||||
Tensor x1(DT_FLOAT, {2, 1, 5});
|
||||
@ -238,11 +235,10 @@ std::vector<Tensor> SplitGrad(int dim, const Tensor& x, const Tensor& dy0,
|
||||
{"dx:0", "dx:1"}, {}, &out));
|
||||
CHECK_EQ(out.size(), 2);
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
return out;
|
||||
}
|
||||
|
||||
TEST_F(ArrayGradTest, SplitGrad) {
|
||||
TEST(ArrayGradTest, SplitGrad) {
|
||||
Tensor x(DT_FLOAT, {2, 4, 5});
|
||||
x.flat<float>().setZero();
|
||||
Tensor dy0(DT_FLOAT, {2, 2, 5});
|
||||
@ -279,11 +275,10 @@ std::vector<Tensor> ReshapeGrad(const Tensor& x, const Tensor& s,
|
||||
{"dx:0", "dx:1"}, {}, &out));
|
||||
CHECK_EQ(out.size(), 2);
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
return out;
|
||||
}
|
||||
|
||||
TEST_F(ArrayGradTest, ReshapeGrad) {
|
||||
TEST(ArrayGradTest, ReshapeGrad) {
|
||||
Tensor x(DT_FLOAT, {2, 4, 5});
|
||||
x.flat<float>().setZero();
|
||||
auto s = test::AsTensor<int32>({8, 5});
|
||||
@ -319,11 +314,10 @@ std::vector<Tensor> ExpandDimsGrad(const Tensor& x, const Tensor& s,
|
||||
{"dx:0", "dx:1"}, {}, &out));
|
||||
CHECK_EQ(out.size(), 2);
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
return out;
|
||||
}
|
||||
|
||||
TEST_F(ArrayGradTest, ExpandDimsGrad) {
|
||||
TEST(ArrayGradTest, ExpandDimsGrad) {
|
||||
Tensor x(DT_FLOAT, {2, 4, 5});
|
||||
x.flat<float>().setZero();
|
||||
auto s = test::AsTensor<int32>({1});
|
||||
@ -356,11 +350,10 @@ std::vector<Tensor> SqueezeGrad(const Tensor& x, const Tensor& dy) {
|
||||
TF_CHECK_OK(sess->Run({{"x:0", x}, {"dy:0", dy}}, {"dx:0"}, {}, &out));
|
||||
CHECK_EQ(out.size(), 1);
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
return out;
|
||||
}
|
||||
|
||||
TEST_F(ArrayGradTest, SqueezeGrad) {
|
||||
TEST(ArrayGradTest, SqueezeGrad) {
|
||||
Tensor x(DT_FLOAT, {2, 1, 3});
|
||||
x.flat<float>().setZero();
|
||||
Tensor dy(DT_FLOAT, {2, 3});
|
||||
@ -389,11 +382,10 @@ std::vector<Tensor> TransposeGrad(const Tensor& x, const Tensor& p,
|
||||
{"dx:0", "dx:1"}, {}, &out));
|
||||
CHECK_EQ(out.size(), 2);
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
return out;
|
||||
}
|
||||
|
||||
TEST_F(ArrayGradTest, TransposeGrad) {
|
||||
TEST(ArrayGradTest, TransposeGrad) {
|
||||
Tensor x(DT_FLOAT, {2, 4, 5});
|
||||
x.flat<float>().setZero();
|
||||
auto p = test::AsTensor<int32>({2, 0, 1});
|
||||
@ -428,11 +420,10 @@ std::vector<Tensor> ReverseGrad(const Tensor& x, const Tensor& dims,
|
||||
{"dx:0", "dx:1"}, {}, &out));
|
||||
CHECK_EQ(out.size(), 2);
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
return out;
|
||||
}
|
||||
|
||||
TEST_F(ArrayGradTest, ReverseGrad) {
|
||||
TEST(ArrayGradTest, ReverseGrad) {
|
||||
Tensor x(DT_FLOAT, {2, 3});
|
||||
x.flat<float>().setZero();
|
||||
auto dims = test::AsTensor<bool>({false, true});
|
||||
@ -465,11 +456,10 @@ std::vector<Tensor> ReverseV2Grad(const Tensor& x, const Tensor& axis,
|
||||
{"dx:0", "dx:1"}, {}, &out));
|
||||
CHECK_EQ(out.size(), 2);
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
return out;
|
||||
}
|
||||
|
||||
TEST_F(ArrayGradTest, ReverseV2Grad) {
|
||||
TEST(ArrayGradTest, ReverseV2Grad) {
|
||||
Tensor x(DT_FLOAT, {2, 3});
|
||||
x.flat<float>().setZero();
|
||||
auto axis = test::AsTensor<int32>({1});
|
||||
@ -502,11 +492,10 @@ std::vector<Tensor> SliceGrad(const Tensor& x, const Tensor& b, const Tensor& s,
|
||||
{"dx:0", "dx:1", "dx:2"}, {}, &out));
|
||||
CHECK_EQ(out.size(), 3);
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
return out;
|
||||
}
|
||||
|
||||
TEST_F(ArrayGradTest, SliceGrad) {
|
||||
TEST(ArrayGradTest, SliceGrad) {
|
||||
Tensor x(DT_FLOAT, {2, 3, 4});
|
||||
x.flat<float>().setZero();
|
||||
auto begin = test::AsTensor<int32>({1, 1, 1});
|
||||
@ -564,7 +553,6 @@ std::vector<Tensor> StridedSliceGrad(const Tensor& x, const Tensor& begin,
|
||||
{"dx:0", "dx:1", "dx:2", "dx:3"}, {}, &out));
|
||||
CHECK_EQ(out.size(), 4);
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -611,11 +599,10 @@ std::vector<Tensor> StridedSliceGradGrad(
|
||||
{"dx:0", "dx:1", "dx:2", "dx:3", "dx:4"}, {}, &out));
|
||||
CHECK_EQ(out.size(), 5);
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
return out;
|
||||
}
|
||||
|
||||
TEST_F(ArrayGradTest, StridedSliceGrad) {
|
||||
TEST(ArrayGradTest, StridedSliceGrad) {
|
||||
Tensor x(DT_FLOAT, {2, 3, 4});
|
||||
x.flat<float>().setZero();
|
||||
Tensor x_shape = test::AsTensor<int32>({2, 3, 4}, {3});
|
||||
@ -730,4 +717,5 @@ TEST_F(ArrayGradTest, StridedSliceGrad) {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
@ -21,17 +23,16 @@ limitations under the License.
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
namespace f = test::function;
|
||||
typedef FunctionDefHelper FDH;
|
||||
using FDH = FunctionDefHelper;
|
||||
|
||||
namespace {
|
||||
Session* NewSession() {
|
||||
std::unique_ptr<Session> NewSession() {
|
||||
SessionOptions opts;
|
||||
(*opts.config.mutable_device_count())["CPU"] = 1;
|
||||
return NewSession(opts);
|
||||
return std::unique_ptr<Session>(NewSession(opts));
|
||||
}
|
||||
} // end namespace
|
||||
|
||||
class MathGradTest : public ::testing::Test {
|
||||
protected:
|
||||
@ -85,7 +86,6 @@ class MathGradTest : public ::testing::Test {
|
||||
*y = outputs[0];
|
||||
}
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
return s;
|
||||
}
|
||||
|
||||
@ -148,7 +148,6 @@ class MathGradTest : public ::testing::Test {
|
||||
sess->Run({{"x:0", x}, {"y:0", y}}, {"d:0", "d:1"}, {}, &outputs));
|
||||
CHECK_EQ(outputs.size(), 2);
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
*dx = outputs[0];
|
||||
*dy = outputs[1];
|
||||
}
|
||||
@ -204,7 +203,6 @@ class MathGradTest : public ::testing::Test {
|
||||
sess->Run({{"x:0", x}, {"i:0", idx}}, {"d:0", "d:1"}, {}, &outputs));
|
||||
CHECK_EQ(outputs.size(), 2);
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
*dx = outputs[0];
|
||||
*di = outputs[1];
|
||||
}
|
||||
@ -227,7 +225,6 @@ class MathGradTest : public ::testing::Test {
|
||||
TF_CHECK_OK(sess->Run({{"x:0", x}, {"y:0", y}}, {"z:0"}, {}, &outputs));
|
||||
CHECK_EQ(outputs.size(), 1);
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
return outputs[0];
|
||||
}
|
||||
|
||||
@ -295,7 +292,6 @@ class MathGradTest : public ::testing::Test {
|
||||
sess->Run({{"x:0", x}, {"y:0", y}}, {"d:0", "d:1"}, {}, &outputs));
|
||||
CHECK_EQ(outputs.size(), 2);
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
*dx = outputs[0];
|
||||
*dy = outputs[1];
|
||||
}
|
||||
@ -359,14 +355,13 @@ class MathGradTest : public ::testing::Test {
|
||||
{"d:0", "d:1", "d:2"}, {}, &outputs));
|
||||
CHECK_EQ(outputs.size(), 3);
|
||||
TF_CHECK_OK(sess->Close());
|
||||
delete sess;
|
||||
*dc = outputs[0];
|
||||
*dx = outputs[1];
|
||||
*dy = outputs[2];
|
||||
}
|
||||
};
|
||||
|
||||
static void HasError(const Status& s, const string& substr) {
|
||||
void HasError(const Status& s, const string& substr) {
|
||||
EXPECT_TRUE(StringPiece(s.ToString()).contains(substr))
|
||||
<< s << ", expected substring " << substr;
|
||||
}
|
||||
@ -1126,4 +1121,5 @@ TEST_F(MathGradTest, Max_dim0_dim1_Dups) {
|
||||
di, test::AsTensor<int32>({0, 0}, TensorShape({2})));
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
x
Reference in New Issue
Block a user