This change splits many (but not all) of the function-related targets into separate cc_library targets. The main changes are: * Move "graph/graph_constructor.{h,cc}" to "common_runtime/graph_constructor.{h,cc}" and leave a forwarding header. This code depends on common_runtime and is built as part of it, so it makes sense to move it across. The "graph_constructor" library includes "shape_refiner.{h,cc}", "graph_runner.{h,cc}", and "eval_const_tensor.{h,cc}" because of a circular dependency between these modules. * Split "function.{h,cc}" into "function_body.{h,cc}", "function_utils.{h,cc}", and "inline_function_utils.{h,cc}" (plus the original, slimmed-down module). This enables other targets in common_runtime to depend on just the function utilities they need, without the whole runtime, which breaks some cycles. * New fine-grained targets for "constant_folding", "function_optimization_registry", and "graph_optimizer". PiperOrigin-RevId: 308651243 Change-Id: Iac59c1db4ebdd16609f89d6caee6b7e6ba7ff0a1
1770 lines
64 KiB
C++
1770 lines
64 KiB
C++
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/c/c_api.h"
|
|
#include "tensorflow/c/c_api_internal.h"
|
|
#include "tensorflow/c/c_test_util.h"
|
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
|
#include "tensorflow/core/framework/function.pb.h"
|
|
#include "tensorflow/core/framework/op_def.pb.h"
|
|
#include "tensorflow/core/lib/hash/hash.h"
|
|
#include "tensorflow/core/lib/strings/proto_serialization.h"
|
|
#include "tensorflow/core/platform/logging.h"
|
|
#include "tensorflow/core/platform/status.h"
|
|
#include "tensorflow/core/platform/str_util.h"
|
|
#include "tensorflow/core/platform/strcat.h"
|
|
#include "tensorflow/core/platform/test.h"
|
|
|
|
namespace tensorflow {
|
|
namespace {
|
|
|
|
// Specification for expected input/output and its type.
|
|
// DataType value of DT_INVALID signifies that we don't want to
|
|
// check the data type.
|
|
typedef std::pair<string, DataType> IOSpec;
|
|
|
|
std::vector<IOSpec> M(const std::initializer_list<string>& names) {
|
|
std::vector<IOSpec> v;
|
|
for (const string& name : names) {
|
|
v.push_back(IOSpec(name, DT_INVALID));
|
|
}
|
|
return v;
|
|
}
|
|
|
|
// Specification for an expected edge.
|
|
// src is either:
|
|
// - input name (as it appears in FunctionDef)
|
|
// - name of output tensor (in nested "add:z:0" format)
|
|
// dst is either:
|
|
// - output name (as it appears in FunctionDef)
|
|
// - <name_of_node>:<index_of_this_input_into_node> (this looks the same as
|
|
// output tensor naming, but it the index is actually an input index)
|
|
struct EdgeSpec : public std::pair<string, string> {
|
|
typedef std::pair<string, string> Base;
|
|
|
|
// Inherit the set of constructors
|
|
using Base::pair;
|
|
|
|
string ToString() const { return strings::StrCat(first, "->", second); }
|
|
};
|
|
|
|
class CApiFunctionTest : public ::testing::Test {
|
|
protected:
|
|
CApiFunctionTest()
|
|
: s_(TF_NewStatus()),
|
|
func_graph_(TF_NewGraph()),
|
|
host_graph_(TF_NewGraph()),
|
|
func_(nullptr) {}
|
|
|
|
void SetUp() override {}
|
|
|
|
~CApiFunctionTest() override {
|
|
TF_DeleteFunction(func_);
|
|
TF_DeleteGraph(host_graph_);
|
|
TF_DeleteGraph(func_graph_);
|
|
TF_DeleteStatus(s_);
|
|
}
|
|
|
|
void Run(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
|
|
TF_Operation* output, int32_t expected_result) {
|
|
Run(inputs, {{output, 0}}, {expected_result});
|
|
}
|
|
|
|
// Run the host graph, which now contains a function and check that
|
|
// outputs are as expected.
|
|
// 'T' stands for 'tensor' since the outputs are tensors, not scalars.
|
|
void RunT(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
|
|
std::initializer_list<TF_Output> outputs,
|
|
const std::vector<std::vector<int32_t>>& expected_results) {
|
|
// Create a session for this graph
|
|
CSession csession(host_graph_, s_);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
|
|
|
// Run
|
|
csession.SetInputs(inputs);
|
|
csession.SetOutputs(outputs);
|
|
csession.Run(s_);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
|
|
|
// Check results
|
|
for (int i = 0; i < expected_results.size(); ++i) {
|
|
TF_Tensor* out = csession.output_tensor(i);
|
|
ASSERT_TRUE(out != nullptr);
|
|
EXPECT_EQ(TF_INT32, TF_TensorType(out));
|
|
EXPECT_EQ(1, TF_NumDims(out));
|
|
CompareInt32Tensor(expected_results[i], out);
|
|
}
|
|
}
|
|
|
|
// Run the host graph, which now contains a function and check that
|
|
// outputs are as expected.
|
|
void Run(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
|
|
std::initializer_list<TF_Output> outputs,
|
|
const std::vector<int32_t>& expected_results) {
|
|
// Create a session for this graph.
|
|
CSession csession(host_graph_, s_);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
|
|
|
csession.SetInputs(inputs);
|
|
csession.SetOutputs(outputs);
|
|
csession.Run(s_);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
|
|
|
for (int i = 0; i < expected_results.size(); ++i) {
|
|
TF_Tensor* out = csession.output_tensor(i);
|
|
ASSERT_TRUE(out != nullptr);
|
|
EXPECT_EQ(TF_INT32, TF_TensorType(out));
|
|
EXPECT_EQ(0, TF_NumDims(out)); // scalar
|
|
ASSERT_EQ(sizeof(int32_t), TF_TensorByteSize(out));
|
|
int32_t* output_contents = static_cast<int32_t*>(TF_TensorData(out));
|
|
EXPECT_EQ(expected_results[i], *output_contents);
|
|
}
|
|
}
|
|
|
|
void CompareInt32Tensor(const std::vector<int32_t>& expected, TF_Tensor* t) {
|
|
int32_t* data = static_cast<int32_t*>(TF_TensorData(t));
|
|
size_t size = TF_TensorByteSize(t);
|
|
ASSERT_EQ(expected.size() * sizeof(int32_t), size);
|
|
for (int i = 0; i < expected.size(); ++i) {
|
|
ASSERT_EQ(expected[i], data[i]) << "Different data at index " << i;
|
|
}
|
|
}
|
|
|
|
std::vector<TF_Output> ToOutput(const std::vector<TF_Operation*> ops) {
|
|
std::vector<TF_Output> out;
|
|
for (auto op : ops) {
|
|
out.push_back({op, 0});
|
|
}
|
|
return out;
|
|
}
|
|
|
|
void Define(int num_opers, const std::vector<TF_Operation*>& opers,
|
|
const std::vector<TF_Operation*>& inputs,
|
|
const std::vector<TF_Operation*>& outputs,
|
|
const std::vector<string>& output_names,
|
|
bool expect_failure = false) {
|
|
DefineT(num_opers, opers, ToOutput(inputs), ToOutput(outputs), output_names,
|
|
expect_failure);
|
|
}
|
|
|
|
// Caller must delete[] the returned value
|
|
static const char** ToArray(const std::vector<string>& strs) {
|
|
const char** ptr = nullptr;
|
|
if (!strs.empty()) {
|
|
ptr = new const char*[strs.size()];
|
|
for (size_t i = 0; i < strs.size(); ++i) {
|
|
ptr[i] = strs[i].c_str();
|
|
}
|
|
}
|
|
return ptr;
|
|
}
|
|
|
|
// An explicit `num_opers` is needed so that we can distinguish between the
|
|
// case of no operations specified (-1) and the case of an empty set of
|
|
// operations specified (0).
|
|
void DefineT(int num_opers, const std::vector<TF_Operation*>& opers,
|
|
const std::vector<TF_Output>& inputs,
|
|
const std::vector<TF_Output>& outputs,
|
|
const std::vector<string>& output_names,
|
|
bool expect_failure = false) {
|
|
ASSERT_EQ(func_, nullptr);
|
|
const char** output_names_ptr = ToArray(output_names);
|
|
func_ = TF_GraphToFunction(func_graph_, func_name_, false, num_opers,
|
|
num_opers == -1 ? nullptr : opers.data(),
|
|
inputs.size(), inputs.data(), outputs.size(),
|
|
outputs.data(), output_names_ptr,
|
|
/*opts=*/nullptr, /*description=*/nullptr, s_);
|
|
delete[] output_names_ptr;
|
|
if (expect_failure) {
|
|
ASSERT_EQ(func_, nullptr);
|
|
return;
|
|
}
|
|
|
|
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
|
ASSERT_NE(func_, nullptr);
|
|
ASSERT_EQ(std::string(func_name_), std::string(TF_FunctionName(func_)));
|
|
TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
|
}
|
|
|
|
TF_Operation* Use(const std::vector<TF_Operation*>& inputs) {
|
|
return UseT(ToOutput(inputs));
|
|
}
|
|
|
|
TF_Operation* UseT(const std::vector<TF_Output>& inputs) {
|
|
TF_Operation* op;
|
|
UseHelper(inputs, &op);
|
|
return op;
|
|
}
|
|
|
|
// All the *Helper methods are used as a workaround for the restrictions that
|
|
// one cannot call ASSERT_* methods in non-void-returning functions (when
|
|
// exceptions are disabled during compilation)
|
|
void UseHelper(const std::vector<TF_Output>& inputs, TF_Operation** op) {
|
|
TF_OperationDescription* desc =
|
|
TF_NewOperation(host_graph_, func_name_, func_node_name_);
|
|
for (auto input : inputs) {
|
|
TF_AddInput(desc, input);
|
|
}
|
|
// Set device to CPU because some ops inside the function might not be
|
|
// available on GPU.
|
|
TF_SetDevice(desc, "/cpu:0");
|
|
*op = TF_FinishOperation(desc, s_);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
|
ASSERT_NE(*op, nullptr);
|
|
}
|
|
|
|
FunctionDef fdef() {
|
|
tensorflow::FunctionDef fdef;
|
|
EXPECT_TRUE(GetFunctionDef(func_, &fdef));
|
|
return fdef;
|
|
}
|
|
|
|
// logging utility
|
|
template <class Container>
|
|
string ToString(const Container& v) {
|
|
std::stringstream ss;
|
|
ss << "{";
|
|
size_t i = 0;
|
|
for (const auto& e : v) {
|
|
if (i != 0) {
|
|
ss << ", ";
|
|
}
|
|
ss << e.ToString();
|
|
++i;
|
|
}
|
|
ss << "}";
|
|
return ss.str();
|
|
}
|
|
|
|
void VerifyFDefNodes(const tensorflow::FunctionDef& fdef,
|
|
const std::unordered_set<string>& nodes) {
|
|
ASSERT_EQ(nodes.size(), fdef.node_def_size())
|
|
<< "Got unexpected number of nodes. Expected: ["
|
|
<< absl::StrJoin(nodes, ", ")
|
|
<< "] Actual nodes in fdef: " << fdef.DebugString();
|
|
for (const NodeDef& node_def : fdef.node_def()) {
|
|
ASSERT_TRUE(nodes.find(node_def.name()) != nodes.end())
|
|
<< "Got unexpected node: " << node_def.name()
|
|
<< " in fdef: " << fdef.DebugString();
|
|
}
|
|
}
|
|
|
|
void VerifyFDefInputs(const tensorflow::FunctionDef& fdef,
|
|
const std::vector<IOSpec>& inputs) {
|
|
const OpDef& signature = fdef.signature();
|
|
ASSERT_EQ(inputs.size(), signature.input_arg_size());
|
|
for (int i = 0; i < inputs.size(); ++i) {
|
|
const OpDef::ArgDef& arg = signature.input_arg(i);
|
|
const IOSpec& in = inputs[i];
|
|
if (in.second != DT_INVALID) {
|
|
ASSERT_EQ(arg.type(), in.second)
|
|
<< "Got unexpected type for input " << i
|
|
<< ". fdef: " << fdef.DebugString();
|
|
}
|
|
ASSERT_EQ(arg.name(), in.first) << "Got unexpected name for input " << i
|
|
<< ". fdef: " << fdef.DebugString();
|
|
}
|
|
}
|
|
|
|
void VerifyFDefOutputs(const tensorflow::FunctionDef& fdef,
|
|
const std::vector<IOSpec>& outputs) {
|
|
const OpDef& signature = fdef.signature();
|
|
ASSERT_EQ(outputs.size(), signature.output_arg_size());
|
|
for (int i = 0; i < outputs.size(); ++i) {
|
|
const OpDef::ArgDef& arg = signature.output_arg(i);
|
|
const IOSpec& out = outputs[i];
|
|
if (out.second != DT_INVALID) {
|
|
ASSERT_EQ(arg.type(), out.second)
|
|
<< "Got unexpected type for output " << i
|
|
<< ". fdef: " << fdef.DebugString();
|
|
}
|
|
ASSERT_EQ(arg.name(), out.first) << "Got unexpected name for output " << i
|
|
<< ". fdef: " << fdef.DebugString();
|
|
}
|
|
}
|
|
|
|
void VerifyFDefEdges(
|
|
const tensorflow::FunctionDef& fdef,
|
|
const std::vector<EdgeSpec>& e_edges, // expected edges
|
|
const std::vector<EdgeSpec>& c_edges, // expected ctrl edges
|
|
bool is_exact_edges = true) {
|
|
// Build a set of edges from fdef
|
|
std::set<EdgeSpec> a_edges; // actual edges
|
|
// Get edges from inputs to body nodes and between body nodes
|
|
for (const NodeDef& node_def : fdef.node_def()) {
|
|
for (int i = 0; i < node_def.input_size(); ++i) {
|
|
const string& in = node_def.input(i);
|
|
const auto& v =
|
|
a_edges.insert({in, strings::StrCat(node_def.name(), ":", i)});
|
|
ASSERT_TRUE(v.second) << "Duplicate edge " << in << " -> "
|
|
<< strings::StrCat(node_def.name(), ":", i)
|
|
<< ". fdef: " << fdef.DebugString();
|
|
}
|
|
}
|
|
// Get edges from body nodes to outputs and from inputs to outputs
|
|
for (const OpDef::ArgDef& arg : fdef.signature().output_arg()) {
|
|
const auto& iter = fdef.ret().find(arg.name());
|
|
if (iter != fdef.ret().end()) {
|
|
const auto& v = a_edges.insert({iter->second, arg.name()});
|
|
ASSERT_TRUE(v.second) << "Duplicate edge " << iter->second << " -> "
|
|
<< arg.name() << ". fdef: " << fdef.DebugString();
|
|
} else {
|
|
const auto& v = a_edges.insert({arg.name(), arg.name()});
|
|
ASSERT_TRUE(v.second) << "Duplicate edge " << arg.name() << " -> "
|
|
<< arg.name() << ". fdef: " << fdef.DebugString();
|
|
}
|
|
}
|
|
|
|
// Verify edges
|
|
for (const EdgeSpec& e : e_edges) {
|
|
ASSERT_TRUE(a_edges.find(e) != a_edges.end())
|
|
<< "Failed to find expected edge " << e.ToString()
|
|
<< " in fdef: " << fdef.DebugString();
|
|
}
|
|
for (const EdgeSpec& e : c_edges) {
|
|
ASSERT_TRUE(a_edges.find(e) != a_edges.end())
|
|
<< "Failed to find expected control edge " << e.ToString()
|
|
<< " in fdef: " << fdef.DebugString();
|
|
}
|
|
|
|
// If caller specified all edges, check that we have seen all
|
|
if (is_exact_edges) {
|
|
ASSERT_EQ(e_edges.size() + c_edges.size(), a_edges.size())
|
|
<< "Expected edges: " << ToString(e_edges)
|
|
<< " Expected Control edges: " << ToString(c_edges)
|
|
<< " Actual edges: " << ToString(a_edges)
|
|
<< " in fdef: " << fdef.DebugString();
|
|
}
|
|
}
|
|
|
|
void VerifyFDef(const std::unordered_set<string>& nodes,
|
|
const std::vector<IOSpec>& inputs,
|
|
const std::vector<IOSpec>& outputs,
|
|
const std::vector<EdgeSpec>& e_edges, // expected edges
|
|
const std::vector<EdgeSpec>& c_edges, // expected ctrl edges
|
|
bool is_exact_edges = true) {
|
|
tensorflow::FunctionDef fdef;
|
|
ASSERT_TRUE(GetFunctionDef(func_, &fdef));
|
|
VerifyFDefNodes(fdef, nodes);
|
|
VerifyFDefInputs(fdef, inputs);
|
|
VerifyFDefOutputs(fdef, outputs);
|
|
VerifyFDefEdges(fdef, e_edges, c_edges, is_exact_edges);
|
|
}
|
|
|
|
// Serialize func_ to fdef and import it back
|
|
void Reincarnate() {
|
|
// func_ -> fdef
|
|
tensorflow::FunctionDef fdef;
|
|
ASSERT_TRUE(GetFunctionDef(func_, &fdef));
|
|
TF_DeleteFunction(func_);
|
|
|
|
// fdef -> func_
|
|
string buf;
|
|
ASSERT_TRUE(fdef.SerializeToString(&buf));
|
|
func_ = TF_FunctionImportFunctionDef(buf.data(), buf.size(), s_);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
|
|
}
|
|
|
|
void GetAttr(const char* attr_name, AttrValue* out_attr) {
|
|
TF_Buffer* attr_buf = TF_NewBuffer();
|
|
TF_FunctionGetAttrValueProto(func_, attr_name, attr_buf, s_);
|
|
ASSERT_TRUE(out_attr->ParseFromArray(attr_buf->data, attr_buf->length));
|
|
TF_DeleteBuffer(attr_buf);
|
|
}
|
|
|
|
const char* func_name_ = "MyFunc";
|
|
const char* func_node_name_ = "MyFunc_0";
|
|
TF_Status* s_;
|
|
TF_Graph* func_graph_;
|
|
TF_Graph* host_graph_;
|
|
TF_Function* func_;
|
|
|
|
// Workaround for not being able to initialize empty map using {}
|
|
std::unordered_set<string> empty_;
|
|
};
|
|
|
|
/*
 * A function with no arguments whose body is a single constant:
 *
 *   constant
 *      |
 *      v
 */
TEST_F(CApiFunctionTest, OneOp_ZeroInputs_OneOutput) {
  // Body: one scalar constant, exported as the only output.
  TF_Operation* const ten = ScalarConst(10, func_graph_, s_, "scalar10");
  Define(-1, {}, {}, {ten}, {});

  // Call the function with no arguments and check that it yields 10.
  TF_Operation* const call = Use({});
  Run({}, call, 10);

  // The output takes the name "scalar10", and the body node appears as
  // "scalar10_0".
  VerifyFDef({"scalar10_0"}, {}, {{"scalar10", DT_INT32}},
             {{"scalar10_0:output:0", "scalar10"}}, {});
}
|
|
|
|
/*
 * A function with one argument and a single Neg op:
 *
 *     |
 *     v
 *   negate
 *     |
 *     v
 */
TEST_F(CApiFunctionTest, OneOp_OneInput_OneOutput) {
  // Body: placeholder (the function argument) feeding a Neg op.
  TF_Operation* const arg = Placeholder(func_graph_, s_);
  TF_Operation* const negated = Neg(arg, func_graph_, s_);
  Define(-1, {}, {arg}, {negated}, {});

  // Call the function on 3 from the host graph; expect -3.
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, -3);

  // Without explicit output names, the output is called "neg" and the body
  // node is expected as "neg_0".
  VerifyFDef({"neg_0"}, {{"feed", DT_INT32}}, {{"neg", DT_INT32}},
             {{"feed", "neg_0:0"}, {"neg_0:y:0", "neg"}}, {});
}
|
|
|
|
/*
 * Same graph as OneOp_OneInput_OneOutput, but the output is given an
 * explicit name:
 *
 *     |
 *     v
 *   negate
 *     |
 *     v
 */
TEST_F(CApiFunctionTest, OneOutput_OutputNames) {
  // Body: placeholder -> Neg, with the output named "negated_num".
  TF_Operation* const arg = Placeholder(func_graph_, s_);
  TF_Operation* const negated = Neg(arg, func_graph_, s_);
  Define(-1, {}, {arg}, {negated}, {"negated_num"});

  // Call the function on 3 from the host graph; expect -3.
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, -3);

  // With an explicit output name, the body node keeps its own name "neg".
  VerifyFDef({"neg"}, {{"feed", DT_INT32}}, {{"negated_num", DT_INT32}},
             {{"feed", "neg:0"}, {"neg:y:0", "negated_num"}}, {});
}
|
|
|
|
/*
 * The requested output name collides with the input's name:
 *
 *     |
 *     v
 *   negate
 *     |
 *     v
 */
TEST_F(CApiFunctionTest, OutputNames_SameNameAsInput) {
  // Input placeholder "negation"; the output is also requested as "negation".
  TF_Operation* const arg = Placeholder(func_graph_, s_, "negation");
  TF_Operation* const negated = Neg(arg, func_graph_, s_, "neg");
  Define(-1, {}, {arg}, {negated}, {"negation"});

  // Call the function on 3 from the host graph; expect -3.
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, -3);

  // The output keeps "negation"; the colliding input is expected as
  // "negation_0".
  VerifyFDef({"neg"}, {{"negation_0", DT_INT32}}, {{"negation", DT_INT32}},
             {{"negation_0", "neg:0"}, {"neg:y:0", "negation"}}, {});
}
|
|
|
|
/*
 * A function with no body ops that returns its argument unchanged:
 *
 *   |
 *   |
 *   |
 *   v
 */
TEST_F(CApiFunctionTest, ZeroOps_Identity) {
  // The same placeholder is both the input and the output.
  TF_Operation* const arg = Placeholder(func_graph_, s_);
  Define(-1, {}, {arg}, {arg}, {});

  // Call the function on 3 from the host graph; expect 3 back.
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, 3);

  // No body nodes: the input "feed_0" flows straight to the output "feed".
  VerifyFDef(empty_, {{"feed_0", DT_INT32}}, {{"feed", DT_INT32}},
             {{"feed_0", "feed"}}, {});
}
|
|
|
|
/*
 * A function with no body ops that swaps its two arguments:
 *
 *   |   |
 *   \   /
 *    \ /
 *     x
 *    / \
 *   /   \
 *   |   |
 *   v   v
 */
TEST_F(CApiFunctionTest, ZeroOps_Permutation) {
  // Two placeholders, returned in reverse order.
  TF_Operation* const arg1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const arg2 = Placeholder(func_graph_, s_, "feed2");
  Define(-1, {}, {arg1, arg2}, {arg2, arg1}, {});

  // Call with (2, 3); the outputs come back swapped as (3, 2).
  TF_Operation* const two = ScalarConst(2, host_graph_, s_);
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {3, 2});

  VerifyFDef(empty_, M({{"feed1_0"}, {"feed2_0"}}), M({{"feed2"}, {"feed1"}}),
             {{"feed1_0", "feed1"}, {"feed2_0", "feed2"}}, {});
}
|
|
|
|
/*
 * Same swap as ZeroOps_Permutation, but with explicit output names:
 *
 *   |   |
 *   \   /
 *    \ /
 *     x
 *    / \
 *   /   \
 *   |   |
 *   v   v
 */
TEST_F(CApiFunctionTest, ZeroOps_Permutation_OutputNames) {
  // Two placeholders returned in reverse order as "first" and "second".
  TF_Operation* const arg1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const arg2 = Placeholder(func_graph_, s_, "feed2");
  Define(-1, {}, {arg1, arg2}, {arg2, arg1}, {"first", "second"});

  // Call with (2, 3); expect (3, 2).
  TF_Operation* const two = ScalarConst(2, host_graph_, s_);
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {3, 2});

  // Explicit output names leave the input names untouched.
  VerifyFDef(empty_, M({{"feed1"}, {"feed2"}}), M({{"first"}, {"second"}}),
             {{"feed1", "second"}, {"feed2", "first"}}, {});
}
|
|
|
|
/*
 * Two arguments added together:
 *
 *   |   |
 *   v   v
 *    add
 *     |
 *     v
 */
TEST_F(CApiFunctionTest, OneOp_TwoInputs_OneOutput) {
  // Body: feed1 + feed2.
  TF_Operation* const arg1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const arg2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* const sum = Add(arg1, arg2, func_graph_, s_);
  Define(-1, {}, {arg1, arg2}, {sum}, {});

  // Call with (2, 3); expect 5.
  TF_Operation* const two = ScalarConst(2, host_graph_, s_);
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, 2 + 3);

  VerifyFDef(
      {"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}),
      {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}}, {});
}
|
|
|
|
/*
 * Two arguments added together, with the result not exported:
 *
 *   |   |
 *   v   v
 *    add
 *
 *   (output ignored)
 */
TEST_F(CApiFunctionTest, OneOp_TwoInputs_ZeroOutputs) {
  // Body: feed1 + feed2; the sum is intentionally not a function output.
  TF_Operation* const arg1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const arg2 = Placeholder(func_graph_, s_, "feed2");
  Add(arg1, arg2, func_graph_, s_);
  Define(-1, {}, {arg1, arg2}, {}, {});

  // Only instantiate the call; the function has no outputs to fetch, so it
  // is not run.
  TF_Operation* const two = ScalarConst(2, host_graph_, s_);
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  Use({two, host_arg});

  VerifyFDef({"add"}, M({{"feed1"}, {"feed2"}}), {},
             {{"feed1", "add:0"}, {"feed2", "add:1"}}, {});
}
|
|
|
|
/*
 * Three arguments summed through two chained adds:
 *
 *   |  |   |
 *   v  v   /
 *   add1  /
 *     |  |
 *     v  v
 *     add2
 *      |
 *      v
 */
TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_OneOutput) {
  // Body: add2 = (feed1 + feed2) + feed3.
  TF_Operation* const arg1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const arg2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* const arg3 = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* const partial = Add(arg1, arg2, func_graph_, s_, "add1");
  TF_Operation* const total = Add(partial, arg3, func_graph_, s_, "add2");
  Define(-1, {}, {arg1, arg2, arg3}, {total}, {});

  // Call with (2, 10, 3); expect 15.
  TF_Operation* const two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* const ten = ScalarConst(10, host_graph_, s_, "ten");
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({two, ten, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, 2 + 10 + 3);

  VerifyFDef({"add1", "add2_0"}, M({{"feed1"}, {"feed2"}, {"feed3"}}),
             M({{"add2"}}),
             {{"feed1", "add1:0"},
              {"feed2", "add1:1"},
              {"add1:sum:0", "add2_0:0"},
              {"feed3", "add2_0:1"},
              {"add2_0:sum:0", "add2"}},
             {});
}
|
|
|
|
/*
 * One add whose result is exported twice:
 *
 *   |   |
 *   v   v
 *    add
 *     |
 *   +-+-+
 *   |   |
 *   v   v
 */
TEST_F(CApiFunctionTest, OneOp_TwoInputs_TwoDuplicateOutputs) {
  // Body: feed1 + feed2, listed twice among the outputs.
  TF_Operation* const arg1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const arg2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* const sum = Add(arg1, arg2, func_graph_, s_);
  Define(-1, {}, {arg1, arg2}, {sum, sum}, {});

  // Call with (2, 3); both outputs should be 5.
  TF_Operation* const two = ScalarConst(2, host_graph_, s_);
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {5, 5});

  // The two outputs take "add" and "add_0", so the body node becomes "add_1".
  VerifyFDef({"add_1"}, M({{"feed1"}, {"feed2"}}), M({{"add"}, {"add_0"}}),
             {{"feed1", "add_1:0"},
              {"feed2", "add_1:1"},
              {"add_1:sum:0", "add"},
              {"add_1:sum:0", "add_0"}},
             {});
}
|
|
|
|
/*
 * One add exported twice under explicit names:
 *
 *   |   |
 *   v   v
 *    add
 *     |
 *   +-+-+
 *   |   |
 *   v   v
 */
TEST_F(CApiFunctionTest, TwoDuplicateOutputs_OutputNames) {
  // Body: feed1 + feed2, exported as "out1" and "out2".
  TF_Operation* const arg1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const arg2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* const sum = Add(arg1, arg2, func_graph_, s_);
  Define(-1, {}, {arg1, arg2}, {sum, sum}, {"out1", "out2"});

  // Call with (2, 3); both outputs should be 5.
  TF_Operation* const two = ScalarConst(2, host_graph_, s_);
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {5, 5});

  // With explicit output names, the body node keeps the name "add".
  VerifyFDef({"add"}, M({{"feed1"}, {"feed2"}}), M({{"out1"}, {"out2"}}),
             {{"feed1", "add:0"},
              {"feed2", "add:1"},
              {"add:sum:0", "out1"},
              {"add:sum:0", "out2"}},
             {});
}
|
|
|
|
/*
 * Chained adds where both the intermediate and the final sum are outputs:
 *
 *   |  |  |
 *   v  v  /
 *   add  /
 *    |  |
 *   +-+ |
 *   | | |
 *   | v v
 *   | add
 *   |  |
 *   v  v
 */
TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_TwoOutputs) {
  // Body: add1 = feed1 + feed2; add2 = add1 + feed3. Both are exported.
  TF_Operation* const arg1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const arg2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* const arg3 = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* const partial = Add(arg1, arg2, func_graph_, s_, "add1");
  TF_Operation* const total = Add(partial, arg3, func_graph_, s_, "add2");
  Define(-1, {}, {arg1, arg2, arg3}, {partial, total}, {});

  // Call with (2, 10, 3); expect (12, 15).
  TF_Operation* const two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* const ten = ScalarConst(10, host_graph_, s_, "ten");
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({two, ten, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {12, 15});

  VerifyFDef({"add1_0", "add2_0"}, M({{"feed1"}, {"feed2"}, {"feed3"}}),
             M({{"add1"}, {"add2"}}),
             {{"feed1", "add1_0:0"},
              {"feed2", "add1_0:1"},
              {"add1_0:sum:0", "add2_0:0"},
              {"feed3", "add2_0:1"},
              {"add1_0:sum:0", "add1"},
              {"add2_0:sum:0", "add2"}},
             {});
}
|
|
|
|
/*
 * Only the second add is placed in the function body; the first add's
 * output and feed3 become the function's inputs:
 *
 *            |  |  |
 *            v  v  /
 *            add  /
 *             |  |
 *         +---+--+---+
 *  Ops    |   |  |   |
 *  used   |   v  v   |
 *  for    |   add    |
 *  func   |    |     |
 *         |    v     |
 *         +----------+
 */
TEST_F(CApiFunctionTest, FromSubsetOfOps) {
  TF_Operation* const arg1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const arg2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* const arg3 = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* const partial = Add(arg1, arg2, func_graph_, s_, "add1");
  TF_Operation* const total = Add(partial, arg3, func_graph_, s_, "add2");
  // Explicit operation list containing only add2.
  Define(1, {total}, {partial, arg3}, {total}, {});

  // Call with (2, 3); expect 5.
  TF_Operation* const two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, 2 + 3);

  VerifyFDef(
      {"add2_0"}, M({{"add1"}, {"feed3"}}), M({{"add2"}}),
      {{"add1", "add2_0:0"}, {"feed3", "add2_0:1"}, {"add2_0:sum:0", "add2"}},
      {});
}
|
|
|
|
/*
 *                  feed
 *                   |
 *       +---------+---+
 *       | const0  |   |
 *       |    |    |   |
 *       |    v    /   |
 *       |    split    |
 *       |   |  |  |   |
 *       |   v  |  v   |
 *       |      |      |
 *       +------+------+
 *              |
 *              v
 *
 * Only the second output from split is used as function output
 */
TEST_F(CApiFunctionTest, UsingOneOutputOfSplit) {
  // Body: Split3 of the argument; only piece #1 is exported.
  TF_Operation* const arg = Placeholder(func_graph_, s_);
  TF_Operation* const pieces = Split3(arg, func_graph_, s_);
  DefineT(-1, {}, {{arg, 0}}, {{pieces, 1}}, {});

  // Splitting {1..6} into three gives {1,2}, {3,4}, {5,6}; expect {3,4}.
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({host_arg});
  RunT({{host_arg, Int32Tensor({1, 2, 3, 4, 5, 6})}}, {{call, 0}},
       {{3, 4}});

  VerifyFDef({"split3_const0", "split3_0"}, M({{"feed"}}), M({{"split3"}}),
             {{"split3_const0:output:0", "split3_0:0"},
              {"feed", "split3_0:1"},
              {"split3_0:output:1", "split3"}},
             {});
}
|
|
|
|
/*
 *                  feed
 *                   |
 *       +---------+---+
 *       | const0  |   |
 *       |    |    |   |
 *       |    v    /   |
 *       |    split    |
 *       |   |  |  |   |
 *       |   |  v  |   |
 *       |   |     |   |
 *       +---+-----+---+
 *           |     |
 *           v     v
 *
 * Second output from split is not used as function output
 */
TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplit) {
  // Body: Split3 of the argument; pieces #0 and #2 are exported.
  TF_Operation* const arg = Placeholder(func_graph_, s_);
  TF_Operation* const pieces = Split3(arg, func_graph_, s_);
  DefineT(-1, {}, {{arg, 0}}, {{pieces, 0}, {pieces, 2}}, {});

  // Splitting {1..6} into three gives {1,2}, {3,4}, {5,6}; expect the outer
  // two pieces.
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({host_arg});
  RunT({{host_arg, Int32Tensor({1, 2, 3, 4, 5, 6})}},
       {{call, 0}, {call, 1}}, {{1, 2}, {5, 6}});

  VerifyFDef({"split3_const0", "split3_1"}, M({{"feed"}}),
             M({{"split3"}, {"split3_0"}}),
             {{"split3_const0:output:0", "split3_1:0"},
              {"feed", "split3_1:1"},
              {"split3_1:output:0", "split3"},
              {"split3_1:output:2", "split3_0"}},
             {});
}
|
|
|
|
/*
 * Two outputs of a split (outside the function) become the function's
 * inputs, and an add consumes them:
 *
 *             |
 *             v
 *           split
 *           | | |
 *           | v |
 *           |   |
 *       +---+---+---+
 *       |   |   |   |
 *       |   v   v   |
 *       |    add    |
 *       |     |     |
 *       |     |     |
 *       +-----+-----+
 *             |
 *             v
 */
TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplitAsInputs) {
  TF_Operation* const arg = Placeholder(func_graph_, s_);
  TF_Operation* const pieces = Split3(arg, func_graph_, s_);
  TF_Operation* const sum = Add({pieces, 0}, {pieces, 2}, func_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  // Explicit operation list {add}; split outputs 0 and 2 are the inputs.
  DefineT(1, {sum}, {{pieces, 0}, {pieces, 2}}, {{sum, 0}}, {});

  // Call with (2, 3); expect 5.
  TF_Operation* const two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, 2 + 3);

  VerifyFDef(
      {"add_0"}, M({{"split3"}, {"split3_0"}}), M({{"add"}}),
      {{"split3", "add_0:0"}, {"split3_0", "add_0:1"}, {"add_0:sum:0", "add"}},
      {});
}
|
|
|
|
// Checks that TF_GraphToFunction rejects inputs on multi-output nodes when
// num_opers is -1, and reports the expected error message.
TEST_F(CApiFunctionTest, NodesUsedInInputsMustHaveSingleOutput) {
  /*
   *             |
   *             v
   *           split
   *           | | |
   *           | v |
   *           |   |
   *  input -->|   |<-- input
   *           |   |
   *           v   v
   *            add
   *             |
   *             |
   *             v
   */
  // Define
  // The test keeps ownership of the tensor (Const() does not take it). Hold it
  // in a unique_ptr so that an early return from an ASSERT_* below cannot leak
  // it; it is released when the test body ends, as before.
  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor_123(
      Int32Tensor({1, 2, 3}), TF_DeleteTensor);
  TF_Operation* c = Const(tensor_123.get(), func_graph_, s_, "const_array");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* split = Split3(c, func_graph_, s_);
  TF_Operation* add = Add({split, 0}, {split, 2}, func_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  DefineT(-1, {}, {{split, 0}, {split, 2}}, {{add, 0}}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("When `num_opers` is set to -1, nodes referenced in "
                   "`inputs` must have a single output. Node split3 has "
                   "3 outputs. Encountered while creating function 'MyFunc'"),
            string(TF_Message(s_)));
}
|
|
|
|
// Builds a function whose body is a while loop, runs it, and checks that the
// chain from the first input to the output goes through the loop.
TEST_F(CApiFunctionTest, FunctionWithWhileLoop) {
  // Inputs to the while loop and the function as a whole
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");

  // Outputs of the while loop corresponding to the two inputs above.
  // The first one will be the function's output.
  std::vector<TF_Output> outputs;

  // Add while loop to func_graph_
  {
    // The inputs to the while loop
    std::vector<TF_Output> inputs = {{feed1, 0}, {feed2, 0}};
    std::unique_ptr<TF_WhileParams> params(new TF_WhileParams(
        TF_NewWhile(func_graph_, inputs.data(), inputs.size(), s_)));
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    params->name = "test_loop";

    // Initialize outputs so we can easily detect errors/bugs
    outputs.resize(2, {nullptr, -1});

    // Create loop: while (input1 < input2) input1 += input2 + 1
    TF_Operation* less_than = LessThan(
        params->cond_inputs[0], params->cond_inputs[1], params->cond_graph, s_);
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    params->cond_output = {less_than, 0};

    TF_Operation* add1 = Add(params->body_inputs[0], params->body_inputs[1],
                             params->body_graph, s_, "add1");
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    TF_Operation* one = ScalarConst(1, params->body_graph, s_);
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    TF_Operation* add2 = Add(add1, one, params->body_graph, s_, "add2");
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    params->body_outputs[0] = {add2, 0};
    params->body_outputs[1] = params->body_inputs[1];

    // Finalize while loop. This must be fatal on failure: otherwise
    // outputs[0] would still be {nullptr, -1} and would be handed to
    // DefineT() below as a function output.
    TF_FinishWhile(params.get(), s_, &outputs[0]);
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  }

  // Define function, use it in graph, and run.
  // DefineT() reports failures through ASSERT_*, which only return from
  // DefineT() itself; stop here if it failed, since GetFunctionDef(func_)
  // below needs a valid func_.
  ASSERT_NO_FATAL_FAILURE(
      DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {outputs[0]}, {}));
  TF_Operation* five = ScalarConst(5, host_graph_, s_, "five");
  TF_Operation* func_feed = Placeholder(host_graph_, s_);
  TF_Operation* func_op = Use({func_feed, five});
  // 2 < 5, so the body runs once: 2 + 5 + 1 = 8; then 8 < 5 is false.
  Run({{func_feed, Int32Tensor(2)}}, func_op, 2 /*+=*/ + 5 + 1);

  // Verify input, output, and subset of edges in fdef.
  // The subset of edges we verify is a chain between feed1 and output to
  // make sure that the correct output is picked.
  tensorflow::FunctionDef fdef;
  ASSERT_TRUE(GetFunctionDef(func_, &fdef));
  VerifyFDefInputs(fdef, M({{"feed1"}, {"feed2"}}));
  VerifyFDefOutputs(fdef, M({{"test_loop_exit"}}));
  VerifyFDefEdges(fdef,
                  {{"feed1", "test_loop/Enter:0"},
                   {"test_loop/Enter:output:0", "test_loop/Merge:0"},
                   {"test_loop/Merge:output:0", "test_loop/Switch:0"},
                   {"test_loop/Switch:output_false:0", "test_loop/Exit:0"},
                   {"test_loop/Exit:output:0", "test_loop_exit"}},
                  {}, false);
}
|
|
|
|
TEST_F(CApiFunctionTest, ControlDependency) {
  // Function body under test:
  //
  //   feed1  feed2   scalar
  //     |      |       .
  //     v      v      .    <---- control dependency
  //        add   < . .
  //         |
  //         v
  //
  // `add` sums the two feeds and carries a control edge from the constant.
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* ctrl_src = ScalarConst(5, func_graph_, s_);
  TF_Operation* sum =
      AddWithCtrlDependency(lhs, rhs, func_graph_, ctrl_src, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  Define(-1, {}, {lhs, rhs}, {sum}, {});

  // Call the function from the host graph with (2, 3) and expect 5.
  TF_Operation* const_two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_feed = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({const_two, host_feed});
  Run({{host_feed, Int32Tensor(3)}}, call, 2 + 3);

  // The control edge must be kept in the FunctionDef as "^scalar".
  VerifyFDef({"add_0", "scalar"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}),
             {{"feed1", "add_0:0"},
              {"feed2", "add_0:1"},
              {"add_0:sum:0", "add"}},
             {{"^scalar", "add_0:2"}});
}
|
|
|
|
TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody) {
  // Same graph as in ControlDependency, but only `add` is listed as the body:
  //
  //   feed1  feed2   scalar    (scalar is NOT part of the body)
  //     |      |       .
  //     v      v      .    <---- control dependency
  //        add   < . .
  //         |
  //         v
  //
  // Defining the function must fail because the control edge's source is
  // outside the body.
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* ctrl_src = ScalarConst(5, func_graph_, s_);
  TF_Operation* sum =
      AddWithCtrlDependency(lhs, rhs, func_graph_, ctrl_src, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  Define(1, {sum}, {lhs, rhs}, {sum}, {}, /*expect_failure=*/true);

  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  const string expected_msg =
      "The source of control edge [id=3 scalar:-1 -> add:-1] "
      "is not in the body. Encountered while creating "
      "function 'MyFunc'";
  EXPECT_EQ(expected_msg, string(TF_Message(s_)));
}
|
|
|
|
TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody_FromInputNode) {
  // `add` sums feed1 and feed2 and also has a control edge from feed1:
  //
  //   feed1 . . . . . .
  //     |    feed2     .   <---- control dependency
  //     |      |       .
  //     v      v       .
  //        add   < . . .
  //         |
  //         v
  //
  // The control source is a function input, so the FunctionDef keeps the
  // edge as "^feed1".
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* sum = AddWithCtrlDependency(lhs, rhs, func_graph_, lhs, s_);
  EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  Define(-1, {}, {lhs, rhs}, {sum}, {});

  // Call the function from the host graph with (2, 3) and expect 5.
  TF_Operation* const_two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_feed = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({const_two, host_feed});
  Run({{host_feed, Int32Tensor(3)}}, call, 2 + 3);

  VerifyFDef({"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}),
             {{"feed1", "add_0:0"},
              {"feed2", "add_0:1"},
              {"add_0:sum:0", "add"}},
             {{"^feed1", "add_0:2"}});
}
|
|
|
|
TEST_F(CApiFunctionTest, DuplicateInputsAreNotAllowed) {
  // One placeholder feeds both operands of `add`, and the same TF_Output is
  // listed twice in the function's input list:
  //
  //        feed1
  //        |   |
  //        v   v
  //         add
  //          |
  //          v
  //
  // Listing an input twice must be rejected.
  TF_Operation* only_feed = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* doubled = Add(only_feed, only_feed, func_graph_, s_);
  Define(-1, {}, {only_feed, only_feed}, {doubled}, {},
         /*expect_failure=*/true);

  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  const string expected_msg =
      "TF_Output feed1:0 appears more than once in the input list";
  EXPECT_EQ(expected_msg, string(TF_Message(s_)));
}
|
|
|
|
TEST_F(CApiFunctionTest, DuplicateOutputNamesAreNotAllowed) {
  // Two chained adds whose outputs are both exported under the same name:
  //
  //   feed1  feed2   feed3
  //     |      |       |
  //     v      v       |
  //       add1         |
  //       |  |         |
  //       |  v         v
  //       |     add2
  //       |      |
  //       v      v
  //    my_out  my_out      (duplicate name -> error)
  TF_Operation* in1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* in2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* in3 = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* first_sum = Add(in1, in2, func_graph_, s_, "add1");
  TF_Operation* second_sum = Add(first_sum, in3, func_graph_, s_, "add2");
  Define(-1, {}, {in1, in2, in3}, {first_sum, second_sum},
         {"my_out", "my_out"}, /*expect_failure=*/true);

  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  const string expected_msg =
      "Cannot have duplicate output names. Name 'my_out' "
      "appears more than once in 'output_names' array.";
  EXPECT_EQ(expected_msg, string(TF_Message(s_)));
}
|
|
|
|
TEST_F(CApiFunctionTest, InvalidInputTensor_HighIndex) {
  // Body: add(feed1, feed2). The second function input names output #2 of
  // feed2, which has only one output, so the definition must fail with
  // OUT_OF_RANGE.
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* sum = Add(lhs, rhs, func_graph_, s_);
  DefineT(-1, {}, {{lhs, 0}, {rhs, 2}}, {{sum, 0}}, {},
          /*expect_failure=*/true);

  EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_));
  const string expected_msg =
      "Node 'feed2' (type: 'Placeholder', num of outputs: 1) does "
      "not have output 2\n\tEncountered while processing "
      "input 1 into function 'MyFunc'";
  EXPECT_EQ(expected_msg, string(TF_Message(s_)));
}
|
|
|
|
TEST_F(CApiFunctionTest, InvalidInputTensor_BadNodePtr) {
  // Body: add(feed1, feed2). The second function input has a null operation,
  // which must be reported as INVALID_ARGUMENT.
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* sum = Add(lhs, rhs, func_graph_, s_);
  DefineT(-1, {}, {{lhs, 0}, {nullptr, 0}}, {{sum, 0}}, {},
          /*expect_failure=*/true);

  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  const string expected_msg =
      "Node is null\n\tEncountered while processing input 1 "
      "into function 'MyFunc'";
  EXPECT_EQ(expected_msg, string(TF_Message(s_)));
}
|
|
|
|
TEST_F(CApiFunctionTest, InvalidOutputTensor_HighIndex) {
  // Body: add(feed1, feed2). The function output names output #3 of `add`,
  // which has only one output, so the definition must fail with OUT_OF_RANGE.
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* sum = Add(lhs, rhs, func_graph_, s_);
  DefineT(-1, {}, {{lhs, 0}, {rhs, 0}}, {{sum, 3}}, {},
          /*expect_failure=*/true);

  EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_));
  const string expected_msg =
      "Node 'add' (type: 'AddN', num of outputs: 1) does "
      "not have output 3\n\tEncountered while processing "
      "output 0 from function 'MyFunc'";
  EXPECT_EQ(expected_msg, string(TF_Message(s_)));
}
|
|
|
|
TEST_F(CApiFunctionTest, InvalidOutputTensor_BadNodePtr) {
  // Body: add(feed1, feed2). The function output has a null operation, which
  // must be reported as INVALID_ARGUMENT.
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  // Only needed so the body has an op; the output below deliberately does not
  // reference it.
  Add(lhs, rhs, func_graph_, s_);
  DefineT(-1, {}, {{lhs, 0}, {rhs, 0}}, {{nullptr, 3}}, {},
          /*expect_failure=*/true);

  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  const string expected_msg =
      "Node is null\n\tEncountered while processing output 0 "
      "from function 'MyFunc'";
  EXPECT_EQ(expected_msg, string(TF_Message(s_)));
}
|
|
|
|
TEST_F(CApiFunctionTest, NodeMissingInput) {
  // Only feed1 is a function input and only `add` is in the body:
  //
  //   input ---> |    | <---- missing input (feed2)
  //              v    v
  //   body ----> add
  //               |
  //               v
  //
  // feed2 is neither an input nor in the body, so defining must fail.
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* sum = Add(lhs, rhs, func_graph_, s_);
  DefineT(1, {sum}, {{lhs, 0}}, {{sum, 0}}, {}, /*expect_failure=*/true);

  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  const string expected_msg =
      "Input 1, 'feed2:0', of node 'add' in function 'MyFunc' "
      "is not available. You might need to include it in inputs "
      "or include its source node in the body";
  EXPECT_EQ(expected_msg, string(TF_Message(s_)));
}
|
|
|
|
TEST_F(CApiFunctionTest, OutputOpNotInBody) {
  // The function has two outputs, but only `add` is in the body:
  //
  //    |     |
  //    v     v
  //     add      scalar    (scalar not included in body)
  //      |         |
  //      v         v       (function has two outputs)
  //
  // An output produced outside the body must be rejected.
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* outside = ScalarConst(2, func_graph_, s_);
  TF_Operation* sum = Add(lhs, rhs, func_graph_, s_);
  Define(1, {sum}, {lhs, rhs}, {sum, outside}, {}, /*expect_failure=*/true);

  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  const string expected_msg =
      "TF_Output scalar:0 is neither in the function body nor "
      "among function inputs. Encountered while creating "
      "function 'MyFunc'";
  EXPECT_EQ(expected_msg, string(TF_Message(s_)));
}
|
|
|
|
void DefineFunction(const char* name, TF_Function** func,
|
|
const char* description = nullptr,
|
|
bool append_hash = false) {
|
|
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
|
|
TF_NewGraph(), TF_DeleteGraph);
|
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
|
|
TF_DeleteStatus);
|
|
|
|
TF_Operation* feed = Placeholder(func_graph.get(), s.get());
|
|
TF_Operation* neg = Neg(feed, func_graph.get(), s.get());
|
|
|
|
TF_Output inputs[] = {{feed, 0}};
|
|
TF_Output outputs[] = {{neg, 0}};
|
|
*func = TF_GraphToFunction(func_graph.get(), name, append_hash, -1,
|
|
/*opers=*/nullptr, 1, inputs, 1, outputs,
|
|
/*output_names=*/nullptr,
|
|
/*opts=*/nullptr, description, s.get());
|
|
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
|
|
ASSERT_NE(*func, nullptr);
|
|
}
|
|
|
|
REGISTER_OP("CustomOp")
|
|
.Output("output: float32")
|
|
.Attr("index: int")
|
|
.SetShapeFn(tensorflow::shape_inference::UnknownShape);
|
|
|
|
void NodeWithPlaceholderAttrHelper(TF_Graph* graph, TF_Status* s,
|
|
const char* name, const char* placeholder,
|
|
TF_Operation** op) {
|
|
TF_OperationDescription* desc = TF_NewOperation(graph, "CustomOp", name);
|
|
TF_SetAttrPlaceholder(desc, "index", placeholder);
|
|
*op = TF_FinishOperation(desc, s);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
|
ASSERT_NE(*op, nullptr);
|
|
}
|
|
|
|
TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
      TF_NewGraph(), TF_DeleteGraph);
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
                                                           TF_DeleteStatus);

  // Three CustomOp nodes whose "index" attrs reference two distinct
  // placeholders ("v1" twice, "v2" once). The helper reports failures with
  // ASSERT_*, which only returns from the helper itself, so propagate them
  // instead of continuing to build a function from broken nodes.
  TF_Operation* node1 = nullptr;
  TF_Operation* node2 = nullptr;
  TF_Operation* node3 = nullptr;
  ASSERT_NO_FATAL_FAILURE(NodeWithPlaceholderAttrHelper(
      func_graph.get(), s.get(), "node1", "v1", &node1));
  ASSERT_NO_FATAL_FAILURE(NodeWithPlaceholderAttrHelper(
      func_graph.get(), s.get(), "node2", "v1", &node2));
  ASSERT_NO_FATAL_FAILURE(NodeWithPlaceholderAttrHelper(
      func_graph.get(), s.get(), "node3", "v2", &node3));

  TF_Output outputs[] = {{node1, 0}, {node2, 0}, {node3, 0}};
  func_ = TF_GraphToFunction(
      func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
      /*opers=*/nullptr, 0, nullptr, 3, outputs,
      /*output_names=*/nullptr,
      /*opts=*/nullptr, /*description=*/nullptr, s.get());
  ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
  ASSERT_NE(func_, nullptr);

  // Verify that FunctionDef has 2 attributes, "v1" and "v2".
  ASSERT_EQ(func_->fdef.signature().attr().size(), 2);
  EXPECT_EQ(func_->fdef.signature().attr(0).name(), "v1");
  EXPECT_EQ(func_->fdef.signature().attr(0).type(), "int");
  EXPECT_EQ(func_->fdef.signature().attr(1).name(), "v2");
  EXPECT_EQ(func_->fdef.signature().attr(1).type(), "int");
}
|
|
|
|
// Adds an int32 Placeholder named `name` that also carries the string attr
// `attr_name` = `attr_value`, and returns it through *op. Failures are
// reported with gtest fatal assertions.
void NodeWithAttrHelper(TF_Graph* graph, TF_Status* s, const char* name,
                        const char* attr_name, const char* attr_value,
                        TF_Operation** op) {
  const size_t value_len = strlen(attr_value);
  TF_OperationDescription* builder =
      TF_NewOperation(graph, "Placeholder", name);
  TF_SetAttrType(builder, "dtype", TF_INT32);
  TF_SetAttrString(builder, attr_name, attr_value, value_len);
  TF_Operation* const created = TF_FinishOperation(builder, s);
  *op = created;
  ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
  ASSERT_NE(created, nullptr);
}
|
|
|
|
TEST_F(CApiFunctionTest, GraphToFunctionDefWithArgAttr) {
  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
      TF_NewGraph(), TF_DeleteGraph);
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
                                                           TF_DeleteStatus);

  // A Placeholder carrying the "_test_attr" attr becomes the only function
  // input. The helper reports failures with ASSERT_*, which only returns from
  // the helper itself, so propagate them instead of continuing.
  TF_Operation* node = nullptr;
  ASSERT_NO_FATAL_FAILURE(NodeWithAttrHelper(func_graph.get(), s.get(), "node",
                                             "_test_attr", "value", &node));

  TF_Output inputs[] = {{node, 0}};
  func_ = TF_GraphToFunction(
      func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
      /*opers=*/nullptr, 1, inputs, 0, nullptr,
      /*output_names=*/nullptr,
      /*opts=*/nullptr, /*description=*/nullptr, s.get());
  ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
  ASSERT_NE(func_, nullptr);

  // Verify that the node's attr shows up in the FunctionDef's arg_attr entry
  // for argument 0.
  ASSERT_EQ(func_->fdef.arg_attr_size(), 1);
  auto arg_attrs = func_->fdef.arg_attr().find(0);
  ASSERT_NE(arg_attrs, func_->fdef.arg_attr().end());
  auto iter = arg_attrs->second.attr().find("_test_attr");
  ASSERT_NE(iter, arg_attrs->second.attr().end());
  EXPECT_EQ(iter->second.s(), "value");
}
|
|
|
|
TEST_F(CApiFunctionTest, SetGradientAndRun) {
  // Define the function and its grad. DefineFunction() reports failures with
  // ASSERT_*, which only returns from the helper, so propagate them here.
  ASSERT_NO_FATAL_FAILURE(DefineFunction(func_name_, &func_));
  TF_Function* grad_func = nullptr;
  ASSERT_NO_FATAL_FAILURE(DefineFunction("MyGrad", &grad_func));
  // Own grad_func so that an early ASSERT_* return below cannot leak it.
  std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)> grad_func_owner(
      grad_func, TF_DeleteFunction);

  // Add func and its gradient to host graph
  TF_GraphCopyFunction(host_graph_, func_, grad_func, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Verify that function and its grad are in host graph's GraphDef
  GraphDef gdef;
  GetGraphDef(host_graph_, &gdef);
  std::vector<string> func_names = GetFuncNames(gdef);
  ASSERT_EQ(2, func_names.size());
  ASSERT_EQ(func_name_, func_names[0]);
  ASSERT_EQ("MyGrad", func_names[1]);
  std::vector<std::pair<string, string>> grads = GetGradDefs(gdef);
  ASSERT_EQ(1, grads.size());
  ASSERT_EQ(func_name_, grads[0].first);
  ASSERT_EQ("MyGrad", grads[0].second);

  // These calls must be noops
  TF_GraphCopyFunction(host_graph_, func_, grad_func, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Delete the gradient func.
  // It is safe to delete after adding a copy to host graph.
  grad_func_owner.reset();

  // Check that GraphDef did not change
  GraphDef gdef2;
  GetGraphDef(host_graph_, &gdef2);
  ASSERT_EQ(gdef.DebugString(), gdef2.DebugString());

  // Use and run func
  TF_Operation* func_feed = Placeholder(host_graph_, s_);
  TF_Operation* func_op = Use({func_feed});
  Run({{func_feed, Int32Tensor(3)}}, func_op, -3);
}
|
|
|
|
TEST_F(CApiFunctionTest, SameGradForTwoFunctions) {
  using FunctionPtr =
      std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>;

  // Define the functions. Each one is owned by a FunctionPtr as soon as it
  // exists, so an early ASSERT_* return cannot leak it. DefineFunction()
  // failures are propagated because its ASSERT_* only exits the helper.
  TF_Function* func1 = nullptr;
  ASSERT_NO_FATAL_FAILURE(DefineFunction("FooFunc1", &func1));
  FunctionPtr func1_owner(func1, TF_DeleteFunction);
  TF_Function* func2 = nullptr;
  ASSERT_NO_FATAL_FAILURE(DefineFunction("FooFunc2", &func2));
  FunctionPtr func2_owner(func2, TF_DeleteFunction);
  TF_Function* grad_func = nullptr;
  ASSERT_NO_FATAL_FAILURE(DefineFunction("MyGrad", &grad_func));
  FunctionPtr grad_func_owner(grad_func, TF_DeleteFunction);

  // Make grad_func be a gradient of func1 and func2
  TF_GraphCopyFunction(host_graph_, func1, grad_func, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, func2, grad_func, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Verify that functions and their gradients are in host graph's GraphDef
  GraphDef gdef;
  GetGraphDef(host_graph_, &gdef);
  std::vector<std::pair<string, string>> grads = GetGradDefs(gdef);
  ASSERT_EQ(2, grads.size());
  ASSERT_EQ("FooFunc1", grads[0].first);
  ASSERT_EQ("MyGrad", grads[0].second);
  ASSERT_EQ("FooFunc2", grads[1].first);
  ASSERT_EQ("MyGrad", grads[1].second);
}
|
|
|
|
TEST_F(CApiFunctionTest, AddFunctionsThenMakeOneGradientOfAnother) {
  using FunctionPtr =
      std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>;

  // Define the functions. Each one is owned by a FunctionPtr as soon as it
  // exists, so an early ASSERT_* return cannot leak it.
  TF_Function* func = nullptr;
  ASSERT_NO_FATAL_FAILURE(DefineFunction("FooFunc", &func));
  FunctionPtr func_owner(func, TF_DeleteFunction);
  TF_Function* grad_func = nullptr;
  ASSERT_NO_FATAL_FAILURE(DefineFunction("MyGrad", &grad_func));
  FunctionPtr grad_func_owner(grad_func, TF_DeleteFunction);

  // Add functions individually
  TF_GraphCopyFunction(host_graph_, func, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, grad_func, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Check that functions are added but not linked
  GraphDef gdef;
  GetGraphDef(host_graph_, &gdef);
  std::vector<string> func_names = GetFuncNames(gdef);
  ASSERT_EQ(2, func_names.size());
  ASSERT_EQ("FooFunc", func_names[0]);
  ASSERT_EQ("MyGrad", func_names[1]);
  ASSERT_EQ(0, GetGradDefs(gdef).size());

  // Make grad_func a gradient of func
  TF_GraphCopyFunction(host_graph_, func, grad_func, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Verify that function and its grad are linked
  gdef.Clear();
  GetGraphDef(host_graph_, &gdef);
  std::vector<std::pair<string, string>> grads = GetGradDefs(gdef);
  ASSERT_EQ(1, grads.size());
  ASSERT_EQ("FooFunc", grads[0].first);
  ASSERT_EQ("MyGrad", grads[0].second);
}
|
|
|
|
TEST_F(CApiFunctionTest, GradientErrorCases) {
  using FunctionPtr =
      std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>;

  // Define the function and two candidate gradients. Each gradient is owned
  // by a FunctionPtr as soon as it exists, so an early ASSERT_* return
  // cannot leak it.
  ASSERT_NO_FATAL_FAILURE(DefineFunction(func_name_, &func_));
  TF_Function* grad_func1 = nullptr;
  ASSERT_NO_FATAL_FAILURE(DefineFunction("MyGrad1", &grad_func1));
  FunctionPtr grad_func1_owner(grad_func1, TF_DeleteFunction);
  TF_Function* grad_func2 = nullptr;
  ASSERT_NO_FATAL_FAILURE(DefineFunction("MyGrad2", &grad_func2));
  FunctionPtr grad_func2_owner(grad_func2, TF_DeleteFunction);

  // func cannot be null
  TF_GraphCopyFunction(host_graph_, nullptr, func_, s_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("'func' argument to TF_GraphCopyFunction cannot be null"),
            string(TF_Message(s_)));

  // Cannot change gradient
  TF_GraphCopyFunction(host_graph_, func_, grad_func1, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, func_, grad_func2, s_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Cannot assign gradient function 'MyGrad2' to 'MyFunc' "
                   "because it already has gradient function 'MyGrad1'"),
            string(TF_Message(s_)));
}
|
|
|
|
TEST_F(CApiFunctionTest, ImportFunctionDef) {
  // Round-trips a function with named outputs through its FunctionDef:
  //
  //   feed1  feed2   feed3
  //     |      |       |
  //     v      v       |
  //       add1         |
  //       |  |         |
  //       |  v         v
  //       |     add2
  //       |      |
  //       v      v
  //  internal_out  final_out
  TF_Operation* in1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* in2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* in3 = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* partial_sum = Add(in1, in2, func_graph_, s_, "add1");
  TF_Operation* total_sum = Add(partial_sum, in3, func_graph_, s_, "add2");
  Define(-1, {}, {in1, in2, in3}, {partial_sum, total_sum},
         {"internal_out", "final_out"});

  // Save func_ to a FunctionDef and import it back.
  Reincarnate();

  // Call it with (2, 10, 3): internal_out = 2 + 10, final_out = 12 + 3.
  TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten");
  TF_Operation* host_feed = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, ten, host_feed});
  Run({{host_feed, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {12, 15});

  VerifyFDef({"add1", "add2"}, M({{"feed1"}, {"feed2"}, {"feed3"}}),
             M({{"internal_out"}, {"final_out"}}),
             {{"feed1", "add1:0"},
              {"feed2", "add1:1"},
              {"add1:sum:0", "add2:0"},
              {"feed3", "add2:1"},
              {"add1:sum:0", "internal_out"},
              {"add2:sum:0", "final_out"}},
             {});
}
|
|
|
|
TEST_F(CApiFunctionTest, ImportFunctionDef_InvalidProto) {
  // Invalid protobuf data (protos cannot start with 4 bytes of zeros).
  // Pass sizeof(proto) rather than a separate literal so that the length
  // cannot drift out of sync with the buffer.
  const char proto[] = {0x0, 0x0, 0x0, 0x0};
  func_ = TF_FunctionImportFunctionDef(proto, sizeof(proto), s_);
  EXPECT_EQ(nullptr, func_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Invalid FunctionDef given to TF_FunctionImportFunctionDef"),
            string(TF_Message(s_)));
}
|
|
|
|
TEST_F(CApiFunctionTest, Attribute) {
  ASSERT_NO_FATAL_FAILURE(DefineFunction(func_name_, &func_));

  // Get non existent attribute
  TF_Buffer* attr_buf = TF_NewBuffer();
  TF_FunctionGetAttrValueProto(func_, "foo_attr", attr_buf, s_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Function 'MyFunc' has no attr named 'foo_attr'."),
            string(TF_Message(s_)));
  TF_DeleteBuffer(attr_buf);

  // Set attr. The serialization result is checked so a failure is reported
  // here, not as a confusing mismatch further down.
  tensorflow::AttrValue attr;
  attr.set_s("test_attr_value");
  string bytes;
  ASSERT_TRUE(attr.SerializeToString(&bytes));
  TF_FunctionSetAttrValueProto(func_, "test_attr_name", bytes.data(),
                               bytes.size(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Get attr
  tensorflow::AttrValue read_attr;
  GetAttr("test_attr_name", &read_attr);
  ASSERT_EQ(attr.DebugString(), read_attr.DebugString());

  // Retrieve the same attr after save/restore
  Reincarnate();
  tensorflow::AttrValue read_attr2;
  GetAttr("test_attr_name", &read_attr2);
  ASSERT_EQ(attr.DebugString(), read_attr2.DebugString());
}
|
|
|
|
TEST_F(CApiFunctionTest, Description) {
  // The description supplied at creation time must be stored verbatim in the
  // function signature.
  const char* const kDescription = "Return something";
  DefineFunction(func_name_, &func_, kDescription);

  tensorflow::FunctionDef fdef;
  ASSERT_TRUE(GetFunctionDef(func_, &fdef));
  ASSERT_EQ(string(kDescription), fdef.signature().description());
}
|
|
|
|
TEST_F(CApiFunctionTest, Name) {
  // When no hash is appended, the signature name is exactly the requested
  // function name.
  const char* const kFuncName = "long_func_name";
  DefineFunction(kFuncName, &func_, "Return something",
                 /*append_hash=*/false);

  tensorflow::FunctionDef fdef;
  ASSERT_TRUE(GetFunctionDef(func_, &fdef));
  ASSERT_EQ(string(kFuncName), fdef.signature().name());
}
|
|
|
|
TEST_F(CApiFunctionTest, AppendHash) {
  ASSERT_NO_FATAL_FAILURE(DefineFunction("func_name_base", &func_,
                                         "Return something",
                                         /*append_hash=*/true));
  tensorflow::FunctionDef fdef;
  ASSERT_TRUE(GetFunctionDef(func_, &fdef));
  // The expected hash suffix differs between big- and little-endian hosts.
  // Check that both byte-order macros exist before comparing them. Inside
  // `#if`, an undefined identifier evaluates to 0, so on a compiler that
  // defines neither macro a bare `__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__`
  // would be `0 == 0` and would wrongly select the big-endian expectation.
#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
    (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
  ASSERT_EQ(string("func_name_base_ZpgUD4x8oqk"), fdef.signature().name());
#else
  ASSERT_EQ(string("func_name_base_qaJ8jA8UmGY"), fdef.signature().name());
#endif
}
|
|
|
|
TEST_F(CApiFunctionTest, GetOpDef) {
  ASSERT_NO_FATAL_FAILURE(DefineFunction(func_name_, &func_));
  TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Test we can retrieve function OpDef from graph. The buffer is owned by a
  // unique_ptr so that a failing ASSERT_* below does not leak it.
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> buffer(
      TF_NewBuffer(), TF_DeleteBuffer);
  TF_GraphGetOpDef(host_graph_, func_name_, buffer.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Sanity check returned OpDef. The parse result is checked explicitly so a
  // malformed buffer is reported as such instead of as field mismatches.
  string data(static_cast<const char*>(buffer->data), buffer->length);
  OpDef op_def;
  ASSERT_TRUE(op_def.ParseFromString(data));
  EXPECT_EQ(op_def.name(), func_name_);
  EXPECT_EQ(op_def.input_arg_size(), 1);
  EXPECT_EQ(op_def.output_arg_size(), 1);
  EXPECT_FALSE(op_def.is_stateful());
}
|
|
|
|
void DefineStatefulFunction(const char* name, TF_Function** func) {
|
|
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
|
|
TF_NewGraph(), TF_DeleteGraph);
|
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
|
|
TF_DeleteStatus);
|
|
|
|
TF_Tensor* tensor_shape = Int32Tensor({37, 1});
|
|
TF_Operation* shape = Const(tensor_shape, func_graph.get(), s.get(), "shape");
|
|
TF_Operation* random =
|
|
RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get());
|
|
|
|
TF_Output outputs[] = {{random, 0}};
|
|
*func = TF_GraphToFunction(func_graph.get(), name,
|
|
/*append_hash_to_fn_name=*/false, -1,
|
|
/*opers=*/nullptr, 0, nullptr, 1, outputs,
|
|
/*output_names=*/nullptr,
|
|
/*opts=*/nullptr, "", s.get());
|
|
ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
|
|
ASSERT_NE(*func, nullptr);
|
|
TF_DeleteTensor(tensor_shape);
|
|
}
|
|
|
|
TEST_F(CApiFunctionTest, StatefulOpDef) {
  ASSERT_NO_FATAL_FAILURE(DefineStatefulFunction(func_name_, &func_));
  TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Test we can retrieve function OpDef from graph. The buffer is owned by a
  // unique_ptr so that a failing ASSERT_* below does not leak it.
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> buffer(
      TF_NewBuffer(), TF_DeleteBuffer);
  TF_GraphGetOpDef(host_graph_, func_name_, buffer.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Sanity check returned OpDef. The parse result is checked explicitly so a
  // malformed buffer is reported as such instead of as field mismatches.
  string data(static_cast<const char*>(buffer->data), buffer->length);
  OpDef op_def;
  ASSERT_TRUE(op_def.ParseFromString(data));
  EXPECT_EQ(op_def.name(), func_name_);
  EXPECT_EQ(op_def.input_arg_size(), 0);
  EXPECT_EQ(op_def.output_arg_size(), 1);
  EXPECT_TRUE(op_def.is_stateful());
}
|
|
|
|
// Asserts that `f1` and `f2` have identical FunctionDefs by comparing their
// deterministic serializations. Every step's result is checked, so a failed
// serialization is reported directly instead of as a confusing string
// mismatch.
void AssertEqual(TF_Function* f1, TF_Function* f2) {
  tensorflow::FunctionDef fdef1, fdef2;
  ASSERT_TRUE(GetFunctionDef(f1, &fdef1));
  ASSERT_TRUE(GetFunctionDef(f2, &fdef2));
  string s1, s2;
  ASSERT_TRUE(SerializeToStringDeterministic(fdef1, &s1));
  ASSERT_TRUE(SerializeToStringDeterministic(fdef2, &s2));
  ASSERT_EQ(s1, s2);
}
|
|
|
|
// Returns the signature name of `func`. If its FunctionDef cannot be
// retrieved, this records a non-fatal test failure and returns an empty
// string, rather than failing silently.
string GetName(TF_Function* func) {
  tensorflow::FunctionDef fdef;
  EXPECT_TRUE(GetFunctionDef(func, &fdef));
  return fdef.signature().name();
}
|
|
|
|
TEST_F(CApiFunctionTest, GetFunctionsFromGraph) {
  using FunctionPtr =
      std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>;
  TF_Function* funcs[2] = {nullptr, nullptr};

  // Get functions from empty graph
  EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 0);
  TF_GraphGetFunctions(host_graph_, nullptr, 0, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Define a function and add it to host_graph_. func0 is owned by a
  // FunctionPtr so the early returns of the ASSERT_* checks cannot leak it.
  TF_Function* func0 = nullptr;
  ASSERT_NO_FATAL_FAILURE(DefineFunction("FooFunc0", &func0));
  FunctionPtr func0_owner(func0, TF_DeleteFunction);
  TF_GraphCopyFunction(host_graph_, func0, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Get this function from host_graph_. Each TF_Function returned through
  // `funcs` is released with TF_DeleteFunction once it has been checked.
  EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 1);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 0, s_), 0);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 1, s_), 1);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  AssertEqual(func0, funcs[0]);
  TF_DeleteFunction(funcs[0]);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 2, s_), 1);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  AssertEqual(func0, funcs[0]);
  TF_DeleteFunction(funcs[0]);

  // Define a second function
  TF_Function* func1 = nullptr;
  ASSERT_NO_FATAL_FAILURE(DefineFunction("FooFunc1", &func1));
  FunctionPtr func1_owner(func1, TF_DeleteFunction);
  TF_GraphCopyFunction(host_graph_, func1, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Get both functions from host_graph_. The test does not rely on the order
  // in which they are returned, so match them by name.
  EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 2);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 0, s_), 0);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 2, s_), 2);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  if (GetName(funcs[0]) == GetName(func0)) {
    AssertEqual(func0, funcs[0]);
    AssertEqual(func1, funcs[1]);
  } else {
    AssertEqual(func0, funcs[1]);
    AssertEqual(func1, funcs[0]);
  }

  TF_DeleteFunction(funcs[0]);
  TF_DeleteFunction(funcs[1]);
}
|
|
|
|
// This test only works when the TF build includes XLA compiler. One way to set
// this up is via bazel build option "--define with_xla_support=true".
//
// FIXME: generalize the macro name TENSORFLOW_EAGER_USE_XLA to
// something like TENSORFLOW_CAPI_USE_XLA.
#ifdef TENSORFLOW_EAGER_USE_XLA
TEST_F(CApiFunctionTest, StatelessIf_XLA) {
  // Both branches of the StatelessIf call the same function (the Neg function
  // built by DefineFunction), so feeding 17 must produce -17.
  TF_Function* func = nullptr;
  const std::string funcName = "BranchFunc";
  ASSERT_NO_FATAL_FAILURE(DefineFunction(funcName.c_str(), &func));
  // Own `func` so the early returns of the ASSERT_* checks below cannot leak
  // it.
  std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)> func_owner(
      func, TF_DeleteFunction);
  TF_GraphCopyFunction(host_graph_, func, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_Operation* feed = Placeholder(host_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_Operation* true_cond = ScalarConst(true, host_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Build the StatelessIf node. There is deliberately no ASSERT_* between
  // TF_NewOperation() and TF_FinishOperation(): an early return there would
  // leave `desc` unfinished. TF_AddInput/TF_AddInputList take no status, so
  // nothing can be checked until TF_FinishOperation().
  TF_OperationDescription* desc =
      TF_NewOperation(host_graph_, "StatelessIf", "IfNode");
  TF_AddInput(desc, {true_cond, 0});
  TF_Output inputs[] = {{feed, 0}};
  TF_AddInputList(desc, inputs, TF_ARRAYSIZE(inputs));
  TF_SetAttrType(desc, "Tcond", TF_BOOL);
  TF_DataType inputType = TF_INT32;
  TF_SetAttrTypeList(desc, "Tin", &inputType, 1);
  TF_SetAttrTypeList(desc, "Tout", &inputType, 1);
  TF_SetAttrFuncName(desc, "then_branch", funcName.data(), funcName.size());
  TF_SetAttrFuncName(desc, "else_branch", funcName.data(), funcName.size());
  TF_SetDevice(desc, "/device:XLA_CPU:0");
  auto op = TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  ASSERT_NE(op, nullptr);

  // Create a session for this graph.
  CSession csession(host_graph_, s_, /*use_XLA*/ true);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Run the graph.
  csession.SetInputs({{feed, Int32Tensor(17)}});
  csession.SetOutputs({op});
  csession.Run(s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Tensor* out = csession.output_tensor(0);
  ASSERT_TRUE(out != nullptr);
  EXPECT_EQ(TF_INT32, TF_TensorType(out));
  EXPECT_EQ(0, TF_NumDims(out));  // scalar
  ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
  int32* output_contents = static_cast<int32*>(TF_TensorData(out));
  EXPECT_EQ(-17, *output_contents);

  // Clean up
  csession.CloseAndDelete(s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
}
#endif  // TENSORFLOW_EAGER_USE_XLA
|
|
|
|
} // namespace
|
|
} // namespace tensorflow
|