Minor modernizations, mostly more <memory>

PiperOrigin-RevId: 158793461
A. Unique TensorFlower 2017-06-12 18:12:30 -07:00 committed by TensorFlower Gardener
parent 995f5f4f40
commit 9f10f60fbd
7 changed files with 122 additions and 135 deletions
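
The recurring change across the files below: test-local factory helpers stop returning owning raw pointers and instead return std::unique_ptr, so call sites can write `auto session = CreateSession();` and drop manual cleanup. A minimal standalone sketch of the pattern (the Widget type and function names are hypothetical, not code from this commit):

#include <memory>

struct Widget {};

// Legacy style: the caller owns the raw pointer and must remember to delete it.
Widget* CreateWidgetLegacy() { return new Widget; }

// Modernized style: ownership is explicit in the signature and released automatically.
std::unique_ptr<Widget> CreateWidget() {
  return std::unique_ptr<Widget>(new Widget);
}

int main() {
  auto widget = CreateWidget();  // freed when `widget` goes out of scope
  return widget != nullptr ? 0 : 1;
}
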

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/direct_session.h"
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
@ -45,10 +46,10 @@ limitations under the License.
namespace tensorflow {
namespace {
Session* CreateSession() {
std::unique_ptr<Session> CreateSession() {
SessionOptions options;
(*options.config.mutable_device_count())["CPU"] = 2;
return NewSession(options);
return std::unique_ptr<Session>(NewSession(options));
}
class DirectSessionMinusAXTest : public ::testing::Test {
@ -87,7 +88,7 @@ class DirectSessionMinusAXTest : public ::testing::Test {
TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork) {
Initialize({3, 2, -1, 0});
std::unique_ptr<Session> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def_));
std::vector<std::pair<string, Tensor>> inputs;
@ -109,7 +110,7 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork) {
TEST_F(DirectSessionMinusAXTest, TestFeed) {
Initialize({1, 2, 3, 4});
std::unique_ptr<Session> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def_));
@ -138,7 +139,7 @@ TEST_F(DirectSessionMinusAXTest, TestFeed) {
TEST_F(DirectSessionMinusAXTest, TestConcurrency) {
Initialize({1, 2, 3, 4});
std::unique_ptr<Session> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def_));
@ -207,7 +208,7 @@ TEST_F(DirectSessionMinusAXTest, TestPerSessionThreads) {
TEST_F(DirectSessionMinusAXTest, TwoCreateCallsFails) {
Initialize({1, 2, 3, 4});
std::unique_ptr<Session> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def_));
@ -217,7 +218,7 @@ TEST_F(DirectSessionMinusAXTest, TwoCreateCallsFails) {
TEST_F(DirectSessionMinusAXTest, ForgetToCreate) {
Initialize({1, 2, 3, 4});
std::unique_ptr<Session> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
std::vector<std::pair<string, Tensor>> inputs;
std::vector<Tensor> outputs;
@ -261,7 +262,7 @@ TEST_F(DirectSessionMinusAXTest, InvalidDevice) {
TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts) {
Initialize({3, 2, -1, 0});
std::unique_ptr<Session> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def_));
std::vector<std::pair<string, Tensor>> inputs;
@ -313,7 +314,7 @@ TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) {
test::graph::ToGraphDef(&g, &def);
std::unique_ptr<Session> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def));
@ -348,7 +349,7 @@ TEST(DirectSessionTest, MultipleFeedTest) {
test::graph::ToGraphDef(&g, &def);
std::unique_ptr<Session> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def));
@ -436,7 +437,6 @@ TEST(DirectSessionTest, DarthKernel) {
std::vector<Tensor> outputs;
auto s = sess->Run({}, {y->name() + ":0"}, {}, &outputs);
EXPECT_TRUE(errors::IsInternal(s));
delete sess;
}
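
With the session held by a std::unique_ptr, the trailing `delete sess;` above becomes unnecessary, and unlike the manual delete, the destructor also runs when the test exits early on a failed assertion. A small sketch of the difference, with hypothetical names:

#include <memory>

struct Resource {};                                 // stands in for tensorflow::Session
Resource* MakeResource() { return new Resource; }   // hypothetical raw factory

void LegacyTest() {
  Resource* r = MakeResource();
  // An early return or a failed fatal assertion here would skip the delete and leak.
  delete r;
}

void ModernTest() {
  std::unique_ptr<Resource> r(MakeResource());
  // The destructor runs on every exit path, so no trailing delete is needed.
}

int main() {
  LegacyTest();
  ModernTest();
  return 0;
}
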
// Have the Darth op in the graph placed on GPU, but don't run it.
@ -500,7 +500,7 @@ TEST(DirectSessionTest, PartialRunTest) {
test::graph::ToGraphDef(&g, &def);
std::unique_ptr<Session> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def));
@ -556,7 +556,7 @@ TEST(DirectSessionTest, PartialRunMissingFeed) {
test::graph::ToGraphDef(&g, &def);
std::unique_ptr<Session> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def));
@ -589,7 +589,7 @@ TEST(DirectSessionTest, PartialRunMultiOutputFeed) {
test::graph::ToGraphDef(&g, &def);
std::unique_ptr<Session> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def));
@ -638,7 +638,7 @@ TEST(DirectSessionTest, RunHandleTest) {
test::graph::ToGraphDef(&g, &def);
std::unique_ptr<Session> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def));
@ -679,7 +679,7 @@ TEST(DirectSessionTest, CreateGraphFailsWhenAssigningAFedVar) {
// a = b
Node* assign = test::graph::Assign(&graph, a, b);
std::unique_ptr<Session> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
// The graph is invalid since a constant cannot be assigned to a constant.
@ -757,30 +757,35 @@ TEST(DirectSessionTest, TimeoutSession) {
)proto",
&graph);
// Creates a session with operation_timeout_in_ms set to 100 milliseconds.
SessionOptions options;
(*options.config.mutable_device_count())["CPU"] = 2;
options.config.set_operation_timeout_in_ms(100);
std::unique_ptr<Session> session(NewSession(options));
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(graph));
{
// Creates a session with operation_timeout_in_ms set to 100 milliseconds.
SessionOptions options;
(*options.config.mutable_device_count())["CPU"] = 2;
options.config.set_operation_timeout_in_ms(100);
// Verifies that the error code is DEADLINE_EXCEEDED.
Status s = session->Run({}, {}, {"fifo_queue_Dequeue"}, nullptr);
ASSERT_EQ(error::DEADLINE_EXCEEDED, s.code());
TF_ASSERT_OK(session->Close());
std::unique_ptr<Session> session(NewSession(options));
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(graph));
// Creates a session with no operation_timeout_in_ms.
session.reset(CreateSession());
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(graph));
RunOptions run_options;
run_options.set_timeout_in_ms(20);
// Verifies that the error code is DEADLINE_EXCEEDED.
Status s2 = session->Run(run_options, {}, {}, {"fifo_queue_Dequeue"}, nullptr,
nullptr);
ASSERT_EQ(error::DEADLINE_EXCEEDED, s2.code());
TF_ASSERT_OK(session->Close());
// Verifies that the error code is DEADLINE_EXCEEDED.
Status s = session->Run({}, {}, {"fifo_queue_Dequeue"}, nullptr);
ASSERT_EQ(error::DEADLINE_EXCEEDED, s.code());
TF_ASSERT_OK(session->Close());
}
{
// Creates a session with no operation_timeout_in_ms.
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(graph));
RunOptions run_options;
run_options.set_timeout_in_ms(20);
// Verifies that the error code is DEADLINE_EXCEEDED.
Status s2 = session->Run(run_options, {}, {}, {"fifo_queue_Dequeue"},
nullptr, nullptr);
ASSERT_EQ(error::DEADLINE_EXCEEDED, s2.code());
TF_ASSERT_OK(session->Close());
}
}
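
The TimeoutSession test is restructured so that each session now lives in its own braced block; its unique_ptr is destroyed at the closing brace before the next session is created, instead of being reused via reset(). A sketch of the scoping idiom, with a hypothetical Session type:

#include <cstdio>
#include <memory>

struct Session {
  ~Session() { std::puts("session destroyed"); }
};

int main() {
  {
    std::unique_ptr<Session> first(new Session);
    // ... run against `first` ...
  }  // `first` is destroyed here, before the second session exists
  {
    std::unique_ptr<Session> second(new Session);
    // ... run against `second` ...
  }
  return 0;
}
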
// Accesses the cancellation manager for the step after the step has been
@ -1090,7 +1095,7 @@ TEST(DirectSessionTest, TestDirectSessionPRunClose) {
test::graph::ToGraphDef(&g, &def);
std::unique_ptr<Session> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def));
@ -1194,8 +1199,8 @@ void FeedFetchBenchmarkHelper(int num_feeds, int iters) {
GraphDef gd;
g.ToGraphDef(&gd);
SessionOptions opts;
std::unique_ptr<Session> sess(NewSession(opts));
TF_CHECK_OK(sess->Create(gd));
std::unique_ptr<Session> session(NewSession(opts));
TF_CHECK_OK(session->Create(gd));
{
// NOTE(mrry): Ignore the first run, which will incur the graph
// partitioning/pruning overhead and skew the results.
@ -1204,12 +1209,12 @@ void FeedFetchBenchmarkHelper(int num_feeds, int iters) {
// the first run, which will impact application startup times, but
// that is not the object of study in this benchmark.
std::vector<Tensor> output_values;
TF_CHECK_OK(sess->Run(inputs, outputs, {}, &output_values));
TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
}
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
std::vector<Tensor> output_values;
TF_CHECK_OK(sess->Run(inputs, outputs, {}, &output_values));
TF_CHECK_OK(session->Run(inputs, outputs, {}, &output_values));
}
testing::StopTiming();
}

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_tracer.h"
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
@ -41,12 +42,12 @@ limitations under the License.
namespace tensorflow {
namespace {
Session* CreateSession() {
std::unique_ptr<Session> CreateSession() {
SessionOptions options;
(*options.config.mutable_device_count())["CPU"] = 1;
(*options.config.mutable_device_count())["GPU"] = 1;
options.config.set_allow_soft_placement(true);
return NewSession(options);
return std::unique_ptr<Session>(NewSession(options));
}
class GPUTracerTest : public ::testing::Test {
@ -97,24 +98,21 @@ class GPUTracerTest : public ::testing::Test {
};
TEST_F(GPUTracerTest, StartStop) {
std::unique_ptr<GPUTracer> tracer;
tracer.reset(CreateGPUTracer());
std::unique_ptr<GPUTracer> tracer(CreateGPUTracer());
if (!tracer) return;
TF_EXPECT_OK(tracer->Start());
TF_EXPECT_OK(tracer->Stop());
}
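
The tracer tests above collapse the declare-then-reset pair into direct construction of the unique_ptr from the factory result; the behavior is identical, it is just one statement instead of two. A sketch (the MakeTracer factory is hypothetical):

#include <memory>

struct GPUTracer {};
// Hypothetical factory; like CreateGPUTracer() above, it may return null when unsupported.
GPUTracer* MakeTracer() { return nullptr; }

int main() {
  // Before: default-construct, then reset with the factory result.
  std::unique_ptr<GPUTracer> a;
  a.reset(MakeTracer());

  // After: initialize directly from the factory result; same behavior, one statement.
  std::unique_ptr<GPUTracer> b(MakeTracer());

  return 0;
}
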
TEST_F(GPUTracerTest, StopBeforeStart) {
std::unique_ptr<GPUTracer> tracer;
tracer.reset(CreateGPUTracer());
std::unique_ptr<GPUTracer> tracer(CreateGPUTracer());
if (!tracer) return;
TF_EXPECT_OK(tracer->Stop());
TF_EXPECT_OK(tracer->Stop());
}
TEST_F(GPUTracerTest, CollectBeforeStart) {
std::unique_ptr<GPUTracer> tracer;
tracer.reset(CreateGPUTracer());
std::unique_ptr<GPUTracer> tracer(CreateGPUTracer());
if (!tracer) return;
StepStats stats;
StepStatsCollector collector(&stats);
@ -123,8 +121,7 @@ TEST_F(GPUTracerTest, CollectBeforeStart) {
}
TEST_F(GPUTracerTest, CollectBeforeStop) {
std::unique_ptr<GPUTracer> tracer;
tracer.reset(CreateGPUTracer());
std::unique_ptr<GPUTracer> tracer(CreateGPUTracer());
if (!tracer) return;
TF_EXPECT_OK(tracer->Start());
StepStats stats;
@ -135,10 +132,8 @@ TEST_F(GPUTracerTest, CollectBeforeStop) {
}
TEST_F(GPUTracerTest, StartTwoTracers) {
std::unique_ptr<GPUTracer> tracer1;
tracer1.reset(CreateGPUTracer());
std::unique_ptr<GPUTracer> tracer2;
tracer2.reset(CreateGPUTracer());
std::unique_ptr<GPUTracer> tracer1(CreateGPUTracer());
std::unique_ptr<GPUTracer> tracer2(CreateGPUTracer());
if (!tracer1 || !tracer2) return;
TF_EXPECT_OK(tracer1->Start());
@ -151,12 +146,11 @@ TEST_F(GPUTracerTest, StartTwoTracers) {
TEST_F(GPUTracerTest, RunWithTracer) {
// On non-GPU platforms, we may not support GPUTracer.
std::unique_ptr<GPUTracer> tracer;
tracer.reset(CreateGPUTracer());
std::unique_ptr<GPUTracer> tracer(CreateGPUTracer());
if (!tracer) return;
Initialize({3, 2, -1, 0});
std::unique_ptr<Session> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def_));
std::vector<std::pair<string, Tensor>> inputs;
@ -179,12 +173,11 @@ TEST_F(GPUTracerTest, RunWithTracer) {
}
TEST_F(GPUTracerTest, TraceToStepStatsCollector) {
std::unique_ptr<GPUTracer> tracer;
tracer.reset(CreateGPUTracer());
std::unique_ptr<GPUTracer> tracer(CreateGPUTracer());
if (!tracer) return;
Initialize({3, 2, -1, 0});
std::unique_ptr<Session> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def_));
std::vector<std::pair<string, Tensor>> inputs;
@ -209,7 +202,7 @@ TEST_F(GPUTracerTest, TraceToStepStatsCollector) {
TEST_F(GPUTracerTest, RunWithTraceOption) {
Initialize({3, 2, -1, 0});
std::unique_ptr<Session> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def_));
std::vector<std::pair<string, Tensor>> inputs;

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <algorithm>
#include <cstdlib>
#include <memory>
#include <unordered_map>
#include "tensorflow/core/debug/debug_graph_utils.h"
@ -29,14 +30,15 @@ limitations under the License.
namespace tensorflow {
namespace {
DirectSession* CreateSession() {
std::unique_ptr<DirectSession> CreateSession() {
SessionOptions options;
// Turn off graph optimizer so we can observe intermediate node states.
options.config.mutable_graph_options()
->mutable_optimizer_options()
->set_opt_level(OptimizerOptions_Level_L0);
return dynamic_cast<DirectSession*>(NewSession(options));
return std::unique_ptr<DirectSession>(
dynamic_cast<DirectSession*>(NewSession(options)));
}
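
The helper above downcasts the Session* from NewSession to DirectSession* and hands it straight to a unique_ptr. One caveat of this shape, noted here as an observation rather than something stated in the commit: if the dynamic_cast ever produced nullptr, the object returned by NewSession would not be owned by anything, which is why the tests keep the ASSERT_TRUE(session != nullptr) guard. A standalone sketch with hypothetical Base/Derived types:

#include <memory>

struct Base { virtual ~Base() = default; };
struct Derived : Base {};

// Hypothetical factory that returns the base type, like NewSession() returning Session*.
Base* MakeBase() { return new Derived; }

std::unique_ptr<Derived> MakeDerived() {
  // Downcast the factory result and take ownership in one expression. If the
  // dynamic_cast produced nullptr, the Base* from MakeBase() would be leaked,
  // so this pattern assumes the factory really builds a Derived.
  return std::unique_ptr<Derived>(dynamic_cast<Derived*>(MakeBase()));
}

int main() { return MakeDerived() != nullptr ? 0 : 1; }
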
class SessionDebugMinusAXTest : public ::testing::Test {
@ -85,7 +87,7 @@ class SessionDebugMinusAXTest : public ::testing::Test {
TEST_F(SessionDebugMinusAXTest, RunSimpleNetwork) {
Initialize({3, 2, -1, 0});
std::unique_ptr<DirectSession> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
DebugGateway debug_gateway(session.get());
@ -220,7 +222,7 @@ TEST_F(SessionDebugMinusAXTest, RunSimpleNetwork) {
TEST_F(SessionDebugMinusAXTest, RunSimpleNetworkWithTwoDebugNodesInserted) {
// Tensor contains one count of NaN
Initialize({3, std::numeric_limits<float>::quiet_NaN(), -1, 0});
std::unique_ptr<DirectSession> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
DebugGateway debug_gateway(session.get());
@ -350,7 +352,7 @@ TEST_F(SessionDebugMinusAXTest,
// Test concurrent Run() calls on a graph with different debug watches.
Initialize({3, 2, -1, 0});
std::unique_ptr<DirectSession> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def_));
@ -537,7 +539,7 @@ class SessionDebugOutputSlotWithoutOngoingEdgeTest : public ::testing::Test {
TEST_F(SessionDebugOutputSlotWithoutOngoingEdgeTest,
WatchSlotWithoutOutgoingEdge) {
Initialize();
std::unique_ptr<DirectSession> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
DebugGateway debug_gateway(session.get());
@ -662,7 +664,7 @@ class SessionDebugVariableTest : public ::testing::Test {
TEST_F(SessionDebugVariableTest, WatchUninitializedVariableWithDebugOps) {
Initialize();
std::unique_ptr<DirectSession> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
DebugGateway debug_gateway(session.get());
@ -741,7 +743,7 @@ TEST_F(SessionDebugVariableTest, WatchUninitializedVariableWithDebugOps) {
TEST_F(SessionDebugVariableTest, VariableAssignWithDebugOps) {
// Tensor contains one count of NaN
Initialize();
std::unique_ptr<DirectSession> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
DebugGateway debug_gateway(session.get());
@ -917,7 +919,7 @@ class SessionDebugGPUSwitchTest : public ::testing::Test {
// Test for debug-watching tensors marked as HOST_MEMORY on GPU.
TEST_F(SessionDebugGPUSwitchTest, RunSwitchWithHostMemoryDebugOp) {
Initialize();
std::unique_ptr<DirectSession> session(CreateSession());
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
DebugGateway debug_gateway(session.get());

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"
#include <memory>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/debug/debug_io_utils.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
@ -37,8 +39,9 @@ limitations under the License.
#include "tensorflow/core/util/port.h"
namespace tensorflow {
namespace {
static SessionOptions Devices(int num_cpus, int num_gpus) {
SessionOptions Devices(int num_cpus, int num_gpus) {
SessionOptions result;
(*result.config.mutable_device_count())["CPU"] = num_cpus;
(*result.config.mutable_device_count())["GPU"] = num_gpus;
@ -67,13 +70,13 @@ void CreateGraphDef(GraphDef* graph_def, string node_names[3]) {
// Asserts that "val" is a single float tensor. The only float is
// "expected_val".
static void IsSingleFloatValue(const Tensor& val, float expected_val) {
void IsSingleFloatValue(const Tensor& val, float expected_val) {
ASSERT_EQ(val.dtype(), DT_FLOAT);
ASSERT_EQ(val.NumElements(), 1);
ASSERT_EQ(val.flat<float>()(0), expected_val);
}
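
Several helpers in this file also lose their `static` keyword; since they already sit in an unnamed namespace they have internal linkage either way, so the keyword was redundant. A minimal sketch:

namespace {

// Previously spelled `static int Helper() { ... }`; inside an unnamed namespace
// a function already has internal linkage, so `static` adds nothing and can be dropped.
int Helper() { return 42; }

}  // namespace

int main() { return Helper() == 42 ? 0 : 1; }
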
static SessionOptions Options(const string& target, int placement_period) {
SessionOptions Options(const string& target, int placement_period) {
SessionOptions options;
// NOTE(mrry): GrpcSession requires a grpc:// scheme prefix in the target
// string.
@ -85,8 +88,8 @@ static SessionOptions Options(const string& target, int placement_period) {
return options;
}
static Session* NewRemote(const SessionOptions& options) {
return CHECK_NOTNULL(NewSession(options));
std::unique_ptr<Session> NewRemote(const SessionOptions& options) {
return std::unique_ptr<Session>(CHECK_NOTNULL(NewSession(options)));
}
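
NewRemote now wraps the CHECK_NOTNULL'd pointer directly in a unique_ptr. CHECK_NOTNULL is a pass-through check (it returns its argument after asserting it is non-null), so the check and the ownership transfer compose in a single return expression. A sketch with a hypothetical Connection/Dial pair:

#include <memory>

#include "tensorflow/core/platform/logging.h"

struct Connection {};
Connection* Dial() { return new Connection; }  // hypothetical factory

std::unique_ptr<Connection> NewConnection() {
  // CHECK_NOTNULL aborts the process if its argument is null and otherwise
  // returns the pointer unchanged, so it can sit inside the unique_ptr constructor.
  return std::unique_ptr<Connection>(CHECK_NOTNULL(Dial()));
}

int main() { return NewConnection() != nullptr ? 0 : 1; }
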
class GrpcSessionDebugTest : public ::testing::Test {
@ -149,9 +152,7 @@ TEST_F(GrpcSessionDebugTest, FileDebugURL) {
std::unique_ptr<test::TestCluster> cluster;
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
std::unique_ptr<Session> session(
NewRemote(Options(cluster->targets()[0], 1)));
ASSERT_TRUE(session != nullptr);
auto session = NewRemote(Options(cluster->targets()[0], 1));
TF_CHECK_OK(session->Create(graph));
// Iteration 0: No watch.
@ -220,9 +221,7 @@ void SetDevice(GraphDef* graph, const string& name, const string& dev) {
TEST_F(GrpcSessionDebugTest, MultiDevices_String) {
std::unique_ptr<test::TestCluster> cluster;
TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 1), 2, &cluster));
std::unique_ptr<Session> session(
NewRemote(Options(cluster->targets()[0], 1000)));
ASSERT_TRUE(session != nullptr);
auto session = NewRemote(Options(cluster->targets()[0], 1000));
// b = a
Graph graph(OpRegistry::Global());
@ -289,4 +288,5 @@ TEST_F(GrpcSessionDebugTest, MultiDevices_String) {
}
}
} // namespace
} // namespace tensorflow

View File

@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/example/example_parser_configuration.h"
#include <memory>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@ -30,10 +32,11 @@ namespace {
void ReadFileToStringOrDie(Env* env, const string& filename, string* output) {
TF_CHECK_OK(ReadFileToString(env, filename, output));
}
Session* CreateSession() {
std::unique_ptr<Session> CreateSession() {
SessionOptions options;
(*options.config.mutable_device_count())["CPU"] = 2;
return NewSession(options);
return std::unique_ptr<Session>(NewSession(options));
}
class ExtractExampleParserConfigurationTest : public ::testing::Test {
@ -45,19 +48,19 @@ class ExtractExampleParserConfigurationTest : public ::testing::Test {
"core/example/testdata/parse_example_graph_def.pbtxt");
ReadFileToStringOrDie(Env::Default(), filename, &proto_string);
protobuf::TextFormat::ParseFromString(proto_string, &graph_def_);
session_.reset(CreateSession());
session_ = CreateSession();
TF_CHECK_OK(session_->Create(graph_def_));
}
NodeDef* parse_example_node() {
for (int i = 0; i < graph_def_.node_size(); ++i) {
auto mutable_node = graph_def_.mutable_node(i);
if (mutable_node->name() == "ParseExample/ParseExample") {
return mutable_node;
for (auto& node : *graph_def_.mutable_node()) {
if (node.name() == "ParseExample/ParseExample") {
return &node;
}
}
return nullptr;
}
GraphDef graph_def_;
std::unique_ptr<Session> session_;
};
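
The node lookup above switches from index-based iteration to a range-for over *graph_def_.mutable_node(); the repeated field exposes begin()/end(), so each iteration yields a mutable NodeDef reference directly. A sketch of the same shape as a free function (the FindNode name is mine, not from the commit):

#include <string>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"

// mutable_node() returns a protobuf RepeatedPtrField<NodeDef>*, which has
// begin()/end(), so the range-for yields a mutable NodeDef& per element.
tensorflow::NodeDef* FindNode(tensorflow::GraphDef* graph_def,
                              const std::string& name) {
  for (auto& node : *graph_def->mutable_node()) {
    if (node.name() == name) return &node;
  }
  return nullptr;
}

int main() {
  tensorflow::GraphDef graph_def;
  graph_def.add_node()->set_name("ParseExample/ParseExample");
  return FindNode(&graph_def, "ParseExample/ParseExample") != nullptr ? 0 : 1;
}
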

View File

@ -13,23 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
namespace {
namespace f = test::function;
typedef FunctionDefHelper FDH;
using FDH = FunctionDefHelper;
class ArrayGradTest : public ::testing::Test {};
Session* NewSession() {
std::unique_ptr<Session> NewSession() {
SessionOptions opts;
(*opts.config.mutable_device_count())["CPU"] = 1;
return NewSession(opts);
return std::unique_ptr<Session>(NewSession(opts));
}
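
This file also rewrites `typedef FunctionDefHelper FDH;` as a using alias. The two are equivalent here; the alias form reads left-to-right and, unlike typedef, extends to alias templates. A small standalone sketch:

#include <map>
#include <string>

// The two declarations below mean the same thing.
typedef std::map<std::string, int> CountsTypedef;
using CountsAlias = std::map<std::string, int>;

// `using` additionally supports alias templates, which typedef cannot express.
template <typename T>
using StringMap = std::map<std::string, T>;

int main() {
  StringMap<int> counts;
  counts["x"] = 1;
  return counts["x"] == 1 ? 0 : 1;
}
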
std::vector<Tensor> PackGrad(const Tensor& x0, const Tensor& x1,
@ -56,11 +57,10 @@ std::vector<Tensor> PackGrad(const Tensor& x0, const Tensor& x1,
{"dx:0", "dx:1"}, {}, &out));
CHECK_EQ(out.size(), 2);
TF_CHECK_OK(sess->Close());
delete sess;
return out;
}
TEST_F(ArrayGradTest, PackGrad) {
TEST(ArrayGradTest, PackGrad) {
Tensor x0(DT_FLOAT, {2, 3});
x0.flat<float>().setZero();
Tensor x1(DT_FLOAT, {2, 3});
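
The test macro also changes from TEST_F to TEST: TEST_F instantiates the named fixture class for each test, while TEST needs no fixture at all, and these test bodies never used any fixture state. A tiny sketch (fixture and suite names are made up):

#include "tensorflow/core/platform/test.h"

// TEST_F requires a fixture class and runs its SetUp/TearDown; when the
// fixture is empty, plain TEST is sufficient.
class ExampleFixture : public ::testing::Test {};

TEST_F(ExampleFixture, UsesTheFixture) { EXPECT_EQ(1, 1); }
TEST(ExampleSuite, NeedsNoFixture) { EXPECT_EQ(1, 1); }
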
@ -98,11 +98,10 @@ std::vector<Tensor> UnpackGrad(const Tensor& x, const Tensor& dy0,
{"dx:0"}, {}, &out));
CHECK_EQ(out.size(), 1);
TF_CHECK_OK(sess->Close());
delete sess;
return out;
}
TEST_F(ArrayGradTest, UnpackGrad) {
TEST(ArrayGradTest, UnpackGrad) {
Tensor x(DT_FLOAT, {2, 2, 3});
x.flat<float>().setZero();
Tensor dy0(DT_FLOAT, {2, 3});
@ -136,7 +135,6 @@ std::vector<Tensor> ConcatGrad(int dim, const Tensor& x0, const Tensor& x1,
{"dx:0", "dx:1", "dx:2"}, {}, &out));
CHECK_EQ(out.size(), 3);
TF_CHECK_OK(sess->Close());
delete sess;
return out;
}
@ -161,11 +159,10 @@ std::vector<Tensor> ConcatGradV2(int dim, const Tensor& x0, const Tensor& x1,
{"dx:0", "dx:1", "dx:2"}, {}, &out));
CHECK_EQ(out.size(), 3);
TF_CHECK_OK(sess->Close());
delete sess;
return out;
}
TEST_F(ArrayGradTest, ConcatGrad) {
TEST(ArrayGradTest, ConcatGrad) {
Tensor x0(DT_FLOAT, {2, 3, 5});
x0.flat<float>().setZero();
Tensor x1(DT_FLOAT, {2, 1, 5});
@ -238,11 +235,10 @@ std::vector<Tensor> SplitGrad(int dim, const Tensor& x, const Tensor& dy0,
{"dx:0", "dx:1"}, {}, &out));
CHECK_EQ(out.size(), 2);
TF_CHECK_OK(sess->Close());
delete sess;
return out;
}
TEST_F(ArrayGradTest, SplitGrad) {
TEST(ArrayGradTest, SplitGrad) {
Tensor x(DT_FLOAT, {2, 4, 5});
x.flat<float>().setZero();
Tensor dy0(DT_FLOAT, {2, 2, 5});
@ -279,11 +275,10 @@ std::vector<Tensor> ReshapeGrad(const Tensor& x, const Tensor& s,
{"dx:0", "dx:1"}, {}, &out));
CHECK_EQ(out.size(), 2);
TF_CHECK_OK(sess->Close());
delete sess;
return out;
}
TEST_F(ArrayGradTest, ReshapeGrad) {
TEST(ArrayGradTest, ReshapeGrad) {
Tensor x(DT_FLOAT, {2, 4, 5});
x.flat<float>().setZero();
auto s = test::AsTensor<int32>({8, 5});
@ -319,11 +314,10 @@ std::vector<Tensor> ExpandDimsGrad(const Tensor& x, const Tensor& s,
{"dx:0", "dx:1"}, {}, &out));
CHECK_EQ(out.size(), 2);
TF_CHECK_OK(sess->Close());
delete sess;
return out;
}
TEST_F(ArrayGradTest, ExpandDimsGrad) {
TEST(ArrayGradTest, ExpandDimsGrad) {
Tensor x(DT_FLOAT, {2, 4, 5});
x.flat<float>().setZero();
auto s = test::AsTensor<int32>({1});
@ -356,11 +350,10 @@ std::vector<Tensor> SqueezeGrad(const Tensor& x, const Tensor& dy) {
TF_CHECK_OK(sess->Run({{"x:0", x}, {"dy:0", dy}}, {"dx:0"}, {}, &out));
CHECK_EQ(out.size(), 1);
TF_CHECK_OK(sess->Close());
delete sess;
return out;
}
TEST_F(ArrayGradTest, SqueezeGrad) {
TEST(ArrayGradTest, SqueezeGrad) {
Tensor x(DT_FLOAT, {2, 1, 3});
x.flat<float>().setZero();
Tensor dy(DT_FLOAT, {2, 3});
@ -389,11 +382,10 @@ std::vector<Tensor> TransposeGrad(const Tensor& x, const Tensor& p,
{"dx:0", "dx:1"}, {}, &out));
CHECK_EQ(out.size(), 2);
TF_CHECK_OK(sess->Close());
delete sess;
return out;
}
TEST_F(ArrayGradTest, TransposeGrad) {
TEST(ArrayGradTest, TransposeGrad) {
Tensor x(DT_FLOAT, {2, 4, 5});
x.flat<float>().setZero();
auto p = test::AsTensor<int32>({2, 0, 1});
@ -428,11 +420,10 @@ std::vector<Tensor> ReverseGrad(const Tensor& x, const Tensor& dims,
{"dx:0", "dx:1"}, {}, &out));
CHECK_EQ(out.size(), 2);
TF_CHECK_OK(sess->Close());
delete sess;
return out;
}
TEST_F(ArrayGradTest, ReverseGrad) {
TEST(ArrayGradTest, ReverseGrad) {
Tensor x(DT_FLOAT, {2, 3});
x.flat<float>().setZero();
auto dims = test::AsTensor<bool>({false, true});
@ -465,11 +456,10 @@ std::vector<Tensor> ReverseV2Grad(const Tensor& x, const Tensor& axis,
{"dx:0", "dx:1"}, {}, &out));
CHECK_EQ(out.size(), 2);
TF_CHECK_OK(sess->Close());
delete sess;
return out;
}
TEST_F(ArrayGradTest, ReverseV2Grad) {
TEST(ArrayGradTest, ReverseV2Grad) {
Tensor x(DT_FLOAT, {2, 3});
x.flat<float>().setZero();
auto axis = test::AsTensor<int32>({1});
@ -502,11 +492,10 @@ std::vector<Tensor> SliceGrad(const Tensor& x, const Tensor& b, const Tensor& s,
{"dx:0", "dx:1", "dx:2"}, {}, &out));
CHECK_EQ(out.size(), 3);
TF_CHECK_OK(sess->Close());
delete sess;
return out;
}
TEST_F(ArrayGradTest, SliceGrad) {
TEST(ArrayGradTest, SliceGrad) {
Tensor x(DT_FLOAT, {2, 3, 4});
x.flat<float>().setZero();
auto begin = test::AsTensor<int32>({1, 1, 1});
@ -564,7 +553,6 @@ std::vector<Tensor> StridedSliceGrad(const Tensor& x, const Tensor& begin,
{"dx:0", "dx:1", "dx:2", "dx:3"}, {}, &out));
CHECK_EQ(out.size(), 4);
TF_CHECK_OK(sess->Close());
delete sess;
return out;
}
@ -611,11 +599,10 @@ std::vector<Tensor> StridedSliceGradGrad(
{"dx:0", "dx:1", "dx:2", "dx:3", "dx:4"}, {}, &out));
CHECK_EQ(out.size(), 5);
TF_CHECK_OK(sess->Close());
delete sess;
return out;
}
TEST_F(ArrayGradTest, StridedSliceGrad) {
TEST(ArrayGradTest, StridedSliceGrad) {
Tensor x(DT_FLOAT, {2, 3, 4});
x.flat<float>().setZero();
Tensor x_shape = test::AsTensor<int32>({2, 3, 4}, {3});
@ -730,4 +717,5 @@ TEST_F(ArrayGradTest, StridedSliceGrad) {
}
}
} // namespace
} // namespace tensorflow

View File

@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@ -21,17 +23,16 @@ limitations under the License.
#include "tensorflow/core/public/session.h"
namespace tensorflow {
namespace {
namespace f = test::function;
typedef FunctionDefHelper FDH;
using FDH = FunctionDefHelper;
namespace {
Session* NewSession() {
std::unique_ptr<Session> NewSession() {
SessionOptions opts;
(*opts.config.mutable_device_count())["CPU"] = 1;
return NewSession(opts);
return std::unique_ptr<Session>(NewSession(opts));
}
} // end namespace
class MathGradTest : public ::testing::Test {
protected:
@ -85,7 +86,6 @@ class MathGradTest : public ::testing::Test {
*y = outputs[0];
}
TF_CHECK_OK(sess->Close());
delete sess;
return s;
}
@ -148,7 +148,6 @@ class MathGradTest : public ::testing::Test {
sess->Run({{"x:0", x}, {"y:0", y}}, {"d:0", "d:1"}, {}, &outputs));
CHECK_EQ(outputs.size(), 2);
TF_CHECK_OK(sess->Close());
delete sess;
*dx = outputs[0];
*dy = outputs[1];
}
@ -204,7 +203,6 @@ class MathGradTest : public ::testing::Test {
sess->Run({{"x:0", x}, {"i:0", idx}}, {"d:0", "d:1"}, {}, &outputs));
CHECK_EQ(outputs.size(), 2);
TF_CHECK_OK(sess->Close());
delete sess;
*dx = outputs[0];
*di = outputs[1];
}
@ -227,7 +225,6 @@ class MathGradTest : public ::testing::Test {
TF_CHECK_OK(sess->Run({{"x:0", x}, {"y:0", y}}, {"z:0"}, {}, &outputs));
CHECK_EQ(outputs.size(), 1);
TF_CHECK_OK(sess->Close());
delete sess;
return outputs[0];
}
@ -295,7 +292,6 @@ class MathGradTest : public ::testing::Test {
sess->Run({{"x:0", x}, {"y:0", y}}, {"d:0", "d:1"}, {}, &outputs));
CHECK_EQ(outputs.size(), 2);
TF_CHECK_OK(sess->Close());
delete sess;
*dx = outputs[0];
*dy = outputs[1];
}
@ -359,14 +355,13 @@ class MathGradTest : public ::testing::Test {
{"d:0", "d:1", "d:2"}, {}, &outputs));
CHECK_EQ(outputs.size(), 3);
TF_CHECK_OK(sess->Close());
delete sess;
*dc = outputs[0];
*dx = outputs[1];
*dy = outputs[2];
}
};
static void HasError(const Status& s, const string& substr) {
void HasError(const Status& s, const string& substr) {
EXPECT_TRUE(StringPiece(s.ToString()).contains(substr))
<< s << ", expected substring " << substr;
}
@ -1126,4 +1121,5 @@ TEST_F(MathGradTest, Max_dim0_dim1_Dups) {
di, test::AsTensor<int32>({0, 0}, TensorShape({2})));
}
} // end namespace tensorflow
} // namespace
} // namespace tensorflow