443 lines
15 KiB
C++
443 lines
15 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
|
|
|
|
#include "tensorflow/c/c_api.h"
|
|
#include "tensorflow/c/c_test_util.h"
|
|
#include "tensorflow/core/platform/logging.h"
|
|
#include "tensorflow/core/platform/strcat.h"
|
|
#include "tensorflow/core/platform/test.h"
|
|
|
|
using tensorflow::GraphDef;
|
|
|
|
namespace {
|
|
|
|
class CApiWhileLoopTest : public ::testing::Test {
|
|
protected:
|
|
CApiWhileLoopTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {}
|
|
|
|
~CApiWhileLoopTest() override {
|
|
TF_DeleteGraph(graph_);
|
|
TF_DeleteStatus(s_);
|
|
}
|
|
|
|
void Init(int ninputs) {
|
|
DCHECK(inputs_.empty());
|
|
DCHECK_GT(ninputs, 0);
|
|
|
|
for (int i = 0; i < ninputs; ++i) {
|
|
TF_Operation* placeholder = Placeholder(
|
|
graph_, s_, ::tensorflow::strings::StrCat("p", i).c_str());
|
|
DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
|
inputs_.push_back({placeholder, 0});
|
|
}
|
|
|
|
original_graph_description_ = GraphDebugString();
|
|
|
|
params_.reset(new TF_WhileParams(
|
|
TF_NewWhile(graph_, &inputs_[0], inputs_.size(), s_)));
|
|
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
|
ASSERT_EQ(original_graph_description_, GraphDebugString())
|
|
<< "TF_NewWhile() altered graph";
|
|
|
|
params_->name = "test_loop";
|
|
|
|
// Initialize outputs_ so we can easily detect errors/bugs
|
|
outputs_.resize(ninputs, {nullptr, -1});
|
|
}
|
|
|
|
void ExpectOK() {
|
|
TF_FinishWhile(params_.get(), s_, &outputs_[0]);
|
|
EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
|
}
|
|
|
|
void ExpectError(TF_Code expected_code, const string& expected_msg) {
|
|
TF_FinishWhile(params_.get(), s_, &outputs_[0]);
|
|
EXPECT_EQ(expected_code, TF_GetCode(s_));
|
|
EXPECT_EQ(expected_msg, TF_Message(s_));
|
|
// TODO(skyewm): this assert is currently broken. Fix or remove guarantee.
|
|
// ASSERT_EQ(original_graph_description_, GraphDebugString()) <<
|
|
// "TF_FinishWhile() altered graph on error";
|
|
}
|
|
|
|
void Run(std::initializer_list<int> input_values) {
|
|
Run(outputs_, input_values);
|
|
}
|
|
|
|
void Run(const std::vector<TF_Output>& run_outputs,
|
|
std::initializer_list<int> input_values) {
|
|
DCHECK_EQ(inputs_.size(), input_values.size());
|
|
std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs(inputs_.size());
|
|
int i = 0;
|
|
for (int v : input_values) {
|
|
inputs[i] = {inputs_[i].oper, Int32Tensor(v)};
|
|
++i;
|
|
}
|
|
// TODO(skyewm): use std::make_unique or absl::make_unique when possible.
|
|
csession_.reset(new CSession(graph_, s_));
|
|
csession_->SetInputs(inputs);
|
|
csession_->SetOutputs(run_outputs);
|
|
csession_->Run(s_);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
|
}
|
|
|
|
void ExpectOutputValue(int idx, int expected_value) {
|
|
TF_Tensor* out = csession_->output_tensor(idx);
|
|
ASSERT_TRUE(out != nullptr);
|
|
EXPECT_EQ(TF_INT32, TF_TensorType(out));
|
|
EXPECT_EQ(0, TF_NumDims(out));
|
|
ASSERT_EQ(sizeof(int32_t), TF_TensorByteSize(out));
|
|
int32_t* data = static_cast<int32_t*>(TF_TensorData(out));
|
|
EXPECT_EQ(expected_value, *data);
|
|
}
|
|
|
|
// Create a valid conditional graph. Useful for testing unrelated errors.
|
|
void CreateCondGraph() {
|
|
TF_Operation* one = ScalarConst(1, params_->cond_graph, s_);
|
|
TF_Operation* less_than =
|
|
LessThan(params_->cond_inputs[0], {one, 0}, params_->cond_graph, s_);
|
|
DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
|
params_->cond_output = {less_than, 0};
|
|
}
|
|
|
|
string GraphDebugString() const {
|
|
TF_Buffer* buf = TF_NewBuffer();
|
|
TF_GraphToGraphDef(graph_, buf, s_);
|
|
DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
|
GraphDef def;
|
|
bool success = def.ParseFromArray(buf->data, buf->length);
|
|
DCHECK(success);
|
|
TF_DeleteBuffer(buf);
|
|
return def.DebugString();
|
|
}
|
|
|
|
TF_Status* s_;
|
|
TF_Graph* graph_;
|
|
std::vector<TF_Output> inputs_; // The inputs to the while loop
|
|
std::vector<TF_Output> outputs_; // The final outputs of the while loop
|
|
std::unique_ptr<TF_WhileParams> params_;
|
|
std::unique_ptr<CSession> csession_;
|
|
|
|
private:
|
|
// Used to verify that errors don't change graph_
|
|
string original_graph_description_;
|
|
};
|
|
|
|
// Builds a two-input loop, checks the params returned by TF_NewWhile() and
// the outputs returned by TF_FinishWhile(), then runs the loop.
TEST_F(CApiWhileLoopTest, BasicLoop) {
  Init(2);

  // TF_NewWhile() must hand back both subgraphs and the right input count.
  EXPECT_TRUE(params_->body_graph != nullptr);
  EXPECT_TRUE(params_->cond_graph != nullptr);

  EXPECT_EQ(params_->ninputs, 2);

  ASSERT_TRUE(params_->cond_inputs != nullptr);
  ASSERT_TRUE(params_->cond_inputs[0].oper != nullptr);
  EXPECT_TRUE(params_->cond_inputs[1].oper != nullptr);

  ASSERT_TRUE(params_->body_inputs != nullptr);
  EXPECT_TRUE(params_->body_inputs[0].oper != nullptr);
  EXPECT_TRUE(params_->body_inputs[1].oper != nullptr);

  ASSERT_TRUE(params_->body_outputs != nullptr);

  // Loop under test: while (input1 < input2) input1 += input2 + 1
  //
  // Condition: input1 < input2.
  TF_Operation* cond_op = LessThan(params_->cond_inputs[0],
                                   params_->cond_inputs[1],
                                   params_->cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->cond_output = {cond_op, 0};

  // Body: (input1 + input2) + 1 becomes the new input1; input2 passes
  // through unchanged.
  TF_Operation* inputs_sum =
      Add(params_->body_inputs[0], params_->body_inputs[1],
          params_->body_graph, s_, "add1");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* const_one = ScalarConst(1, params_->body_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* sum_plus_one =
      Add(inputs_sum, const_one, params_->body_graph, s_, "add2");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->body_outputs[0] = {sum_plus_one, 0};
  params_->body_outputs[1] = params_->body_inputs[1];

  // Finalize while loop
  ExpectOK();

  // TF_FinishWhile() must have overwritten the sentinel outputs.
  EXPECT_TRUE(outputs_[0].oper != nullptr);
  EXPECT_GE(outputs_[0].index, 0);
  EXPECT_TRUE(outputs_[1].oper != nullptr);
  EXPECT_GE(outputs_[1].index, 0);

  // The cond/body input nodes must not appear in the final graph.
  for (int input_idx = 0; input_idx < params_->ninputs; ++input_idx) {
    string cond_input_name = ::tensorflow::strings::StrCat(
        params_->name, "/cond/cond_input", input_idx);
    string body_input_name = ::tensorflow::strings::StrCat(
        params_->name, "/body/body_input", input_idx);
    EXPECT_TRUE(TF_GraphOperationByName(graph_, cond_input_name.c_str()) ==
                nullptr);
    EXPECT_TRUE(TF_GraphOperationByName(graph_, body_input_name.c_str()) ==
                nullptr);
  }

  // Run the graph: input1 goes -9 -> -6 -> -3 -> 0 -> 3, input2 stays at 2.
  Run({-9, 2});
  ExpectOutputValue(0, 3);
  ExpectOutputValue(1, 2);
}
|
|
|
|
// Builds a while loop whose body contains a second while loop, checks that
// a few expected nodes exist in the final graph, and runs it.
TEST_F(CApiWhileLoopTest, NestedLoop) {
  Init(2);
  // Create nested loop:
  //  while (input1 < 6) {
  //    inner_input1 = input1
  //    while (inner_input1 < 3) {
  //      input2 += 1
  //      inner_input1 += 2
  //    }
  //    input1 += input2
  //  }
  //
  // Expected execution with initial values input1 = input2 = 0:
  //
  // outer inner               inner_
  // step# step# input1 input2 input1
  // ------------------------------------
  //   0     0     0      0      0
  //   0     1     0      1      2
  //   0     2     0      2      4
  //   0     -     2      2      -
  //   1     0     2      2      2
  //   1     1     2      3      4
  //   1     -     5      3      -
  //   2     0     5      3      5
  //   2     -     8      3      -

  // Create outer cond graph
  TF_Operation* six = ScalarConst(6, params_->cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* less_than =
      LessThan(params_->cond_inputs[0], {six, 0}, params_->cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->cond_output = {less_than, 0};

  // Create outer body graph
  // Init inner graph
  TF_Output inner_inputs[] = {params_->body_inputs[0], params_->body_inputs[1]};
  TF_WhileParams inner_params =
      TF_NewWhile(params_->body_graph, inner_inputs, 2, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  inner_params.name = "inner_loop";

  // Create inner cond graph
  TF_Operation* three = ScalarConst(3, inner_params.cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* inner_less_than = LessThan(
      inner_params.cond_inputs[0], {three, 0}, inner_params.cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  inner_params.cond_output = {inner_less_than, 0};

  // Create inner body graph
  TF_Operation* one = ScalarConst(1, inner_params.body_graph, s_, "one");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* two = ScalarConst(2, inner_params.body_graph, s_, "two");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_Operation* input2_add =
      Add(inner_params.body_inputs[1].oper, one, inner_params.body_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  inner_params.body_outputs[1] = {input2_add, 0};

  TF_Operation* inner_input1_add = Add(inner_params.body_inputs[0].oper, two,
                                       inner_params.body_graph, s_, "add2");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  inner_params.body_outputs[0] = {inner_input1_add, 0};

  // Finalize inner graph. Every element gets the {nullptr, -1} sentinel (the
  // same one the fixture uses for outputs_) so an output that
  // TF_FinishWhile() failed to fill in is recognizable.
  TF_Output inner_outputs[2] = {{nullptr, -1}, {nullptr, -1}};
  TF_FinishWhile(&inner_params, s_, inner_outputs);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // input1 += input2, using the inner loop's final value of input2.
  TF_Operation* input1_add =
      Add(params_->body_inputs[0], inner_outputs[1], params_->body_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->body_outputs[0] = {input1_add, 0};

  params_->body_outputs[1] = inner_outputs[1];

  // Finalize outer graph
  ExpectOK();

  // Check for a few expected nodes
  const char* node_name = "test_loop/cond/scalar";
  EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
  node_name = "test_loop/body/add";
  EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
  node_name = "test_loop/body/inner_loop/body/one";
  EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
  node_name = "test_loop/body/inner_loop/cond/less_than";
  EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);

  // Run the graph
  Run({0, 0});
  ExpectOutputValue(0, 8);
  ExpectOutputValue(1, 3);
}
|
|
|
|
// TF_FinishWhile() must reject params whose `cond_output` was never assigned.
TEST_F(CApiWhileLoopTest, UnsetCondOutput) {
  Init(1);
  // Pass-through body; the cond output is deliberately left unset.
  params_->body_outputs[0] = params_->body_inputs[0];
  const char* const expected_msg =
      "TF_WhileParams `cond_output` field isn't set";
  ExpectError(TF_INVALID_ARGUMENT, expected_msg);
}
|
|
|
|
// Using the cond graph's own input as the cond output must be rejected with
// a "must return a boolean output, got int32" error.
TEST_F(CApiWhileLoopTest, WrongCondOutputType) {
  Init(1);
  params_->cond_output = params_->cond_inputs[0];
  params_->body_outputs[0] = params_->body_inputs[0];
  const char* const expected_msg =
      "BuildWhileLoop: 'cond' argument must return a boolean output, "
      "got int32";
  ExpectError(TF_INVALID_ARGUMENT, expected_msg);
}
|
|
|
|
// A cond output that refers to a node of the parent graph, rather than the
// cond graph, must be rejected.
TEST_F(CApiWhileLoopTest, InvalidCondOutputNode) {
  Init(1);
  // Try to reuse node from parent graph
  params_->cond_output = inputs_[0];
  params_->body_outputs[0] = params_->body_inputs[0];
  // TODO(skyewm): this error message could be more informative. Add explicit
  // checks for this case in the while loop implementation?
  const char* const expected_msg =
      "Requested return tensor 'p0:0' not found in graph def";
  ExpectError(TF_INVALID_ARGUMENT, expected_msg);
}
|
|
|
|
// A cond output whose index is out of range for its node must be rejected.
TEST_F(CApiWhileLoopTest, InvalidCondOutputIndex) {
  Init(1);
  CreateCondGraph();
  // Keep the valid `less_than` node but point at an output index it does
  // not have.
  params_->cond_output.index = 100;
  params_->body_outputs[0] = params_->body_inputs[0];
  const char* const expected_msg =
      "Invalid return output 100 of node 'less_than', which has 1 "
      "output(s)";
  ExpectError(TF_INVALID_ARGUMENT, expected_msg);
}
|
|
|
|
// TODO(skyewm): test bad cond output shape
|
|
|
|
// TF_FinishWhile() must reject params whose body outputs were never assigned.
TEST_F(CApiWhileLoopTest, UnsetBodyOutput) {
  Init(1);
  // Valid condition; body_outputs[0] is deliberately left unset.
  CreateCondGraph();
  const char* const expected_msg =
      "TF_WhileParams `body_outputs[0]` field isn't set";
  ExpectError(TF_INVALID_ARGUMENT, expected_msg);
}
|
|
|
|
// TODO(skyewm): enable this when it works (currently doesn't error)
// TEST_F(CApiWhileLoopTest, WrongBodyOutputType) {
//   Init(1);
//   CreateCondGraph();
//   TF_Operation* double_scalar =
//       ScalarConst(1.0, params_->body_graph, s_, "double_scalar");
//   params_->body_outputs[0] = {double_scalar, 0};
//   ExpectError(TF_INVALID_ARGUMENT, "bad body output type");
// }
|
|
|
|
// A body output that refers to a node of the parent graph, rather than the
// body graph, must be rejected.
TEST_F(CApiWhileLoopTest, InvalidBodyOutputNode) {
  Init(1);
  CreateCondGraph();
  // Try to reuse node from parent graph
  params_->body_outputs[0] = inputs_[0];
  // TODO(skyewm): this error message could be more informative. Add explicit
  // checks for this case in the while loop implementation?
  const char* const expected_msg =
      "Requested return tensor 'p0:0' not found in graph def";
  ExpectError(TF_INVALID_ARGUMENT, expected_msg);
}
|
|
|
|
// TODO(skyewm): enable this when it works (currently segfaults!)
// NOTE(review): the expected message below looks copied from
// InvalidCondOutputIndex; confirm the node name and output count before
// enabling this test.
// TEST_F(CApiWhileLoopTest, InvalidBodyOutputIndex) {
//   Init(1);
//   CreateCondGraph();
//   params_->body_outputs[0] = params_->body_inputs[0];
//   params_->body_outputs[0].index = 100;
//   ExpectError(TF_INVALID_ARGUMENT,
//               "Invalid return output 100 of node 'less_than', which has 1 "
//               "output(s)");
// }
|
|
|
|
// TODO(skyewm): test bad body output shape
|
|
|
|
// TF_FinishWhile() must reject params whose `name` is null.
TEST_F(CApiWhileLoopTest, NullName) {
  Init(1);
  CreateCondGraph();
  params_->body_outputs[0] = params_->body_inputs[0];
  // Everything else is valid; clear only the name that Init() assigned.
  params_->name = nullptr;
  const char* const expected_msg = "TF_WhileParams `name` field is null";
  ExpectError(TF_INVALID_ARGUMENT, expected_msg);
}
|
|
|
|
// A body output taken from the outer graph must be rejected.
// NOTE(review): this performs the same steps and expects the same error as
// InvalidBodyOutputNode -- confirm whether both tests are needed.
TEST_F(CApiWhileLoopTest, WrongGraph) {
  Init(1);
  CreateCondGraph();
  // Set body output to output from outer graph
  params_->body_outputs[0] = inputs_[0];
  // TODO(skyewm): improve error message
  const char* const expected_msg =
      "Requested return tensor 'p0:0' not found in graph def";
  ExpectError(TF_INVALID_ARGUMENT, expected_msg);
}
|
|
|
|
// Adding an op with a float input to the body graph, fed by the loop's body
// input, must fail with a type-mismatch error. The unfinished loop is then
// released with TF_AbortWhile().
TEST_F(CApiWhileLoopTest, BadTypes) {
  Init(1);
  CreateCondGraph();

  // Op that has a float input + output
  TF_OperationDescription* float_op_desc = TF_NewOperation(
      params_->body_graph, "FakeQuantWithMinMaxArgs", "float_op");
  TF_AddInput(float_op_desc, params_->body_inputs[0]);
  TF_FinishOperation(float_op_desc, s_);
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));

  // Only a substring is matched; the full message may carry extra context.
  string error_text(TF_Message(s_));
  EXPECT_NE(error_text.find("Input 'inputs' passed int32 expected float while "
                            "building NodeDef 'float_op'"),
            error_text.npos);

  // The loop is never finished, so abort it instead.
  TF_AbortWhile(params_.get());
}
|
|
|
|
// This is a basic test to make sure the C++ gradient code can handle while
|
|
// loops created by the C API (which calls the C++ API under the hood). There
|
|
// are more while loop gradient tests in cc/framework/while_gradients_test.cc.
|
|
TEST_F(CApiWhileLoopTest, Gradients) {
|
|
Init(1);
|
|
|
|
// Create loop: while (i < 10) i += 1
|
|
TF_Operation* ten = ScalarConst(10, params_->cond_graph, s_);
|
|
TF_Operation* less_than =
|
|
LessThan(params_->cond_inputs[0], {ten, 0}, params_->cond_graph, s_);
|
|
DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
|
params_->cond_output = {less_than, 0};
|
|
|
|
TF_Operation* one = ScalarConst(1, params_->body_graph, s_);
|
|
TF_Operation* add =
|
|
Add(params_->body_inputs[0], {one, 0}, params_->body_graph, s_);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
|
params_->body_outputs[0] = {add, 0};
|
|
|
|
ExpectOK();
|
|
|
|
// Create backprop graph
|
|
TF_Output grad_output;
|
|
TF_AddGradients(graph_, outputs_.data(), outputs_.size(), inputs_.data(), 1,
|
|
nullptr, s_, &grad_output);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
|
|
|
// Run gradient
|
|
Run({grad_output}, {0});
|
|
ExpectOutputValue(0, 1);
|
|
}
|
|
|
|
} // namespace
|