Migrate FunctionDefHelper::Define() to NodeDef functions. For callers

of Define() where that doesn't work, switch to Create() instead.
Change: 142718606
This commit is contained in:
A. Unique TensorFlower 2016-12-21 17:28:44 -08:00 committed by TensorFlower Gardener
parent e14f7841de
commit a439d9975e
8 changed files with 273 additions and 261 deletions

View File

@ -359,38 +359,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
delete g;
}
TEST_F(FunctionLibraryRuntimeTest, ManySwapsOld) {
auto func = FDH::Define( // Creates a FunctionDef using FunctionDef::Nodes
// Name
"ManySwapsFirst",
// Args
{"x: float", "y: float"},
// Return values
{"o: float"},
// attr def
{},
// Nodes
{{{"a0", "b0"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}},
{{"a1", "b1"}, "Swap", {"a0", "b0"}, {{"T", DT_FLOAT}}},
{{"a2", "b2"}, "Swap", {"a1", "b1"}, {{"T", DT_FLOAT}}},
{{"a3", "b3"}, "Swap", {"a2", "b2"}, {{"T", DT_FLOAT}}},
{{"a4", "b4"}, "Swap", {"a3", "b3"}, {{"T", DT_FLOAT}}},
{{"a5", "b5"}, "Swap", {"a4", "b4"}, {{"T", DT_FLOAT}}},
{{"o"}, "Identity", {"a5"}, {{"T", DT_FLOAT}}}});
Init({test::function::Swap(), func});
Graph* g = GetFuncBody("ManySwapsFirst", {});
ASSERT_TRUE(g != nullptr);
OptimizeGraph(lib_, &g);
const char* e0 = R"P(
(n3:float, n2:float) -> (n3:float) {
}
)P";
EXPECT_EQ(e0, DebugString(g));
delete g;
}
// Like the above test, but using NodeDefs in the FunctionDef.
TEST_F(FunctionLibraryRuntimeTest, DISABLED_ManySwapsNodeDef) {
TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) {
auto func = FDH::Create( // Creates a FunctionDef using NodeDefs
// Name
"ManySwapsNodeDef",
@ -423,7 +392,7 @@ TEST_F(FunctionLibraryRuntimeTest, DISABLED_ManySwapsNodeDef) {
}
TEST_F(FunctionLibraryRuntimeTest, ControlDeps) {
auto func = FDH::Define(
auto func = FDH::Create(
// Name
"ManySwapsFirst",
// Args
@ -438,11 +407,12 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) {
// y2 depends on the 2nd swap. The 2nd swap has data dependency
// on the 1st swap. The optimization should maintain the control
// dependencies.
{{{"a0", "b0"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}, {"x2"}},
{{"a1", "b1"}, "Swap", {"a0", "b0"}, {{"T", DT_FLOAT}}},
{{{"a0"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}, {"x2"}},
{{"a1"}, "Swap", {"a0:o0:0", "a0:o1:0"}, {{"T", DT_FLOAT}}},
{{"x2"}, "Mul", {"x", "x"}, {{"T", DT_FLOAT}}},
{{"y2"}, "Mul", {"y", "y"}, {{"T", DT_FLOAT}}, {"a1"}},
{{"o"}, "Add", {"x2", "y2"}, {{"T", DT_FLOAT}}}});
{{"o"}, "Add", {"x2:z:0", "y2:z:0"}, {{"T", DT_FLOAT}}}},
{{"o", "o:z:0"}});
Init({test::function::Swap(), func});
Graph* g = GetFuncBody("ManySwapsFirst", {});
ASSERT_TRUE(g != nullptr);
@ -608,7 +578,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
auto grad =
FDH::Define("TestGrad", {"x:float", "y:float"}, {"dx:float", "dy:float"},
{}, {FDH::Const<float>("dz", 1),
{{"grad"},
{{"grad0", "grad1"},
"SymbolicGradient",
{"x", "y", "dz"},
{
@ -616,8 +586,8 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
{"Tin", DataTypeSlice{T, T, T}},
{"Tout", DataTypeSlice{T, T}},
}},
{{"dx"}, "Identity", {"grad:0"}, {{"T", DT_FLOAT}}},
{{"dy"}, "Identity", {"grad:1"}, {{"T", DT_FLOAT}}}});
{{"dx"}, "Identity", {"grad0"}, {{"T", DT_FLOAT}}},
{{"dy"}, "Identity", {"grad1"}, {{"T", DT_FLOAT}}}});
Init({test, grad});
@ -660,19 +630,19 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
OptimizeGraph(lib_, &g);
const char* e2 = R"P(
(n4:float, n3:float) -> (n25:float, n23:float) {
n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]()
n2 = Const[dtype=float, value=Tensor<type: float shape: [] values: 1>]()
n7 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
n8 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
n7 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]()
n19 = Shape[T=float, out_type=int32](n3)
n8 = Add[T=float](n4, n3)
n9 = Add[T=float](n4, n3)
n20 = Shape[T=float, out_type=int32](n4)
n9 = Rank[T=float](n8)
n14 = Shape[T=float, out_type=int32](n8)
n10 = Rank[T=float](n9)
n14 = Shape[T=float, out_type=int32](n9)
n21 = BroadcastGradientArgs[T=int32](n20, n19)
n10 = Range[Tidx=int32](n7, n9, n11)
n12 = Shape[T=int32, out_type=int32](n10)
n13 = Fill[T=int32](n12, n11)
n15 = DynamicStitch[N=2, T=int32](n10, n10, n14, n13)
n11 = Range[Tidx=int32](n8, n10, n7)
n12 = Shape[T=int32, out_type=int32](n11)
n13 = Fill[T=int32](n12, n7)
n15 = DynamicStitch[N=2, T=int32](n11, n11, n14, n13)
n16 = Reshape[T=float, Tshape=int32](n2, n15)
n17 = Div[T=int32](n14, n15)
n18 = Tile[T=float, Tmultiples=int32](n16, n17)
@ -842,33 +812,38 @@ TEST(OptimizationTest, RemoveIdentityNodes) {
}
TEST(OptimizationTest, RemoveListArrayConverter) {
auto func = FDH::Define(
auto func = FDH::Create(
// Name
"Test",
// Args
{"i: float"},
// Return values
// Return signature
{"o: float"},
// Attrs
{},
// Nodes
{FDH::Const("zero", 0),
{{"s"}, "Split", {"zero", "i"}, {{"num_split", 4}, {"T", DT_FLOAT}}},
{{"s"},
"Split",
{"zero:output:0", "i"},
{{"num_split", 4}, {"T", DT_FLOAT}}},
{{"a"},
"_ArrayToList",
{"s"},
{"s:output"},
{{"N", 4},
{"T", DT_FLOAT},
{"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}}}},
{{"l"}, "Mul", {"a:0", "a:1"}, {{"T", DT_FLOAT}}},
{{"r"}, "Mul", {"a:2", "a:3"}, {{"T", DT_FLOAT}}},
{{"l"}, "Mul", {"a:output:0", "a:output:1"}, {{"T", DT_FLOAT}}},
{{"r"}, "Mul", {"a:output:2", "a:output:3"}, {{"T", DT_FLOAT}}},
{{"x"},
"_ListToArray",
{"l", "r"},
{"l:z", "r:z"},
{{"N", 2},
{"T", DT_FLOAT},
{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
{{"o"}, "AddN", {"x"}, {{"N", 2}, {"T", DT_FLOAT}}}});
{{"o"}, "AddN", {"x:output"}, {{"N", 2}, {"T", DT_FLOAT}}}},
// Return values
{{"o", "o:sum"}});
const char* e0 = R"P(
(n0:float) -> (n7:float) {
@ -916,7 +891,7 @@ TEST(OptimizationTest, RemoveListArrayConverter) {
}
TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
auto func = FDH::Define(
auto func = FDH::Create(
// Name
"Test",
// Args
@ -935,10 +910,11 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
{"dummy"}},
{{"o"},
"AddN",
{"x"},
{"x:output"},
{{"N", 2}, {"T", DT_FLOAT}},
// Control dep
{"x"}}});
{"x"}}},
{{"o", "o:sum"}});
const char* e0 = R"P(
(n0:float) -> (n3:float) {

View File

@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include "tensorflow/core/framework/function.pb_text.h"
@ -1110,35 +1110,17 @@ FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
return ret;
}
FunctionDef::Node FunctionDefHelper::Node::ToProto() const {
FunctionDef::Node n;
for (const string& r : this->ret) {
n.add_ret(r);
}
n.set_op(this->op);
for (const string& a : arg) {
n.add_arg(a);
}
for (const auto& a : this->attr) {
n.mutable_attr()->insert({a.first, a.second.proto});
}
for (const string& d : dep) {
n.add_dep(d);
}
return n;
}
NodeDef FunctionDefHelper::Node::ToNodeDef() const {
NodeDef n;
n.set_op(this->op);
n.set_name(this->ret[0]);
for (const string& a : arg) {
n.add_input(a);
}
for (const auto& a : this->attr) {
n.mutable_attr()->insert({a.first, a.second.proto});
}
for (const string& d : dep) {
for (const string& a : this->arg) {
n.add_input(a);
}
for (const string& d : this->dep) {
n.add_input(strings::StrCat("^", d));
}
return n;
@ -1189,8 +1171,56 @@ FunctionDef FunctionDefHelper::Define(const string& name,
OpRegistrationData op_reg_data;
TF_CHECK_OK(b.Finalize(&op_reg_data));
fdef.mutable_signature()->Swap(&op_reg_data.op_def);
for (const auto& n : node_def) {
*(fdef.add_node()) = n.ToProto();
// Mapping from legacy output names to NodeDef outputs.
std::unordered_map<string, string> ret_index;
for (const auto& a : fdef.signature().input_arg()) {
ret_index[a.name()] = a.name();
}
// For looking up OpDefs
auto* op_def_registry = OpRegistry::Global();
// Function body
for (const auto& src : node_def) {
NodeDef* n = fdef.add_node_def();
n->set_op(src.op);
n->set_name(src.ret[0]);
for (const auto& a : src.attr) {
n->mutable_attr()->insert({a.first, a.second.proto});
}
for (const string& a : src.arg) {
const auto iter = ret_index.find(a);
CHECK(iter != ret_index.end()) << "Node input '" << a << "' in '"
<< src.ret[0] << "' of " << name;
n->add_input(iter->second);
}
for (const string& d : src.dep) {
n->add_input(strings::StrCat("^", d));
}
// Add the outputs of this node to ret_index.
const OpDef* op_def = nullptr;
TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
CHECK(op_def != nullptr) << n->op();
NameRangeMap output_names;
TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
for (const auto& o : output_names) {
CHECK_LE(o.second.second, src.ret.size())
<< "Missing ret for output '" << o.first << "' in '" << src.ret[0]
<< "' of " << name;
for (int i = o.second.first; i < o.second.second; ++i) {
ret_index[src.ret[i]] =
strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
}
}
}
// Returns
for (const auto& r : fdef.signature().output_arg()) {
const auto iter = ret_index.find(r.name());
CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
fdef.mutable_ret()->insert({r.name(), iter->second});
}
return fdef;
}

View File

@ -108,7 +108,6 @@ class FunctionDefHelper {
std::vector<std::pair<string, AttrValueWrapper>> attr;
std::vector<string> dep;
FunctionDef::Node ToProto() const;
NodeDef ToNodeDef() const;
};

View File

@ -71,7 +71,8 @@ TEST(TFunc, SquarePlusOneOld) {
SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
a = Square[T=$T](x)
o = One[T=$T]()
y = Add[T=$T](a, o)
y = Add[T=$T](a:y:0, o:y:0)
return y = y:z:0
}
)P";
EXPECT_EQ(DebugString(fdef), e);
@ -91,7 +92,7 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
EXPECT_EQ(DebugString(result.gdef), e2);
}
TEST(TFunc, DISABLED_SquarePlusOneNodeDef) {
TEST(TFunc, SquarePlusOneNodeDef) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
// Name
"SquarePlusOne",
@ -137,7 +138,7 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
EXPECT_EQ(DebugString(result.gdef), e2);
}
TEST(TFunc, DISABLED_ControlDepNodeDef) {
TEST(TFunc, ControlDepNodeDef) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
// Name
"ControlDep",
@ -206,6 +207,7 @@ TEST(TFunc, MissingTypeAttrOld) {
const char* e = R"P(
BackCompat() -> (y:float) {
y = HasDefaultType()
return y = y:out:0
}
)P";
EXPECT_EQ(DebugString(fdef), e);
@ -224,7 +226,7 @@ BackCompat() -> (y:float) {
EXPECT_EQ(DebugString(result.gdef), e2);
}
TEST(TFunc, DISABLED_MissingTypeAttrNodeDef) {
TEST(TFunc, MissingTypeAttrNodeDef) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
// Name
"BackCompat",
@ -262,7 +264,7 @@ BackCompat() -> (y:float) {
EXPECT_EQ(DebugString(result.gdef), e2);
}
TEST(TFunc, DISABLED_NTimesTNodeDef) {
TEST(TFunc, NTimesTNodeDef) {
// Note that the equivalent FunctionDef using FunctionDef::Node requires
// using a _ListToArray to package up the two inputs to AddN as a single
// N*T edge.
@ -320,7 +322,7 @@ y: N tensors, each of type U;
)doc");
TEST(TFunc, AddSquared) {
auto fdef = FDH::Define(
auto fdef = FDH::Create(
// Name
"AddSquared",
// Args
@ -339,12 +341,14 @@ TEST(TFunc, AddSquared) {
{"U", "$T"},
{"N", "$N"}}},
// y = AddN<N=$N,T=$T>(a)
{{"y"}, "AddN", {"a"}, {{"N", "$N"}, {"T", "$T"}}}});
{{"y"}, "AddN", {"a:y"}, {{"N", "$N"}, {"T", "$T"}}}},
{{"y", "y:sum"}});
const char* e = R"P(
AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) {
a = Map[N=$N, T=$T, U=$T, func=Square[T=$T]](x)
y = AddN[N=$N, T=$T](a)
y = AddN[N=$N, T=$T](a:y)
return y = y:sum
}
)P";
EXPECT_EQ(DebugString(fdef), e);
@ -413,8 +417,9 @@ TEST(TFunc, XTimesTwo) {
auto expect = R"P(
XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
scale = Cast[DstT=$T, SrcT=int64](two)
y = Mul[T=$T](x, scale)
scale = Cast[DstT=$T, SrcT=int64](two:output:0)
y = Mul[T=$T](x, scale:y:0)
return y = y:z:0
}
)P";
EXPECT_EQ(expect, DebugString(test::function::XTimesTwo()));
@ -424,7 +429,8 @@ TEST(TFunc, WXPlusB) {
auto expect = R"P(
WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x)
y = Add[T=$T](mm, b)
y = Add[T=$T](mm:product:0, b)
return y = y:z:0
}
)P";
EXPECT_EQ(expect, DebugString(test::function::WXPlusB()));
@ -432,7 +438,7 @@ WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
TEST(TFunc, Body_TypeList) {
const Tensor kZero = test::AsScalar<int32>(0);
auto fdef = FDH::Define(
auto fdef = FDH::Create(
// Name
"Test",
// Args
@ -443,32 +449,30 @@ TEST(TFunc, Body_TypeList) {
{},
// Nodes
{{{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT32}}},
{{"s"}, "Split", {"zero", "i"}, {{"num_split", 4}, {"T", DT_FLOAT}}},
{{"lst"},
"_ArrayToList",
{"s"},
{{"N", 4},
{"T", DT_FLOAT},
{"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}}}},
{{"l"}, "Mul", {"lst:0", "lst:1"}, {{"T", DT_FLOAT}}},
{{"r"}, "Mul", {"lst:2", "lst:3"}, {{"T", DT_FLOAT}}},
{{"s"},
"Split",
{"zero:output:0", "i"},
{{"num_split", 4}, {"T", DT_FLOAT}}},
{{"l"}, "Mul", {"s:output:0", "s:output:1"}, {{"T", DT_FLOAT}}},
{{"r"}, "Mul", {"s:output:2", "s:output:3"}, {{"T", DT_FLOAT}}},
{{"x"},
"_ListToArray",
{"l", "r"},
{"l:z", "r:z"},
{{"N", 2},
{"T", DT_FLOAT},
{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
{{"o"}, "AddN", {"x"}, {{"N", 2}, {"T", DT_FLOAT}}}});
{{"o"}, "AddN", {"x:output"}, {{"N", 2}, {"T", DT_FLOAT}}}},
{{"o", "o:sum:0"}});
const char* e = R"P(
Test(i:float) -> (o:float) {
zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
s = Split[T=float, num_split=4](zero, i)
lst = _ArrayToList[N=4, T=float, out_types={float, float, float, float}](s)
l = Mul[T=float](lst:0, lst:1)
r = Mul[T=float](lst:2, lst:3)
x = _ListToArray[N=2, T=float, Tin={float, float}](l, r)
o = AddN[N=2, T=float](x)
s = Split[T=float, num_split=4](zero:output:0, i)
l = Mul[T=float](s:output:0, s:output:1)
r = Mul[T=float](s:output:2, s:output:3)
x = _ListToArray[N=2, T=float, Tin={float, float}](l:z, r:z)
o = AddN[N=2, T=float](x:output)
return o = o:sum:0
}
)P";
EXPECT_EQ(DebugString(fdef), e);
@ -476,14 +480,13 @@ Test(i:float) -> (o:float) {
InstantiationResult result;
TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
const char* e2 = R"P(
(n0:float) -> (n7:float) {
(n0:float) -> (n6:float) {
n1 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
n2 = Split[T=float, num_split=4](n1, n0)
n3 = _ArrayToList[N=4, T=float, out_types={float, float, float, float}](n2, n2:1, n2:2, n2:3)
n4 = Mul[T=float](n3, n3:1)
n5 = Mul[T=float](n3:2, n3:3)
n6 = _ListToArray[N=2, T=float, Tin={float, float}](n4, n5)
n7 = AddN[N=2, T=float](n6, n6:1)
n3 = Mul[T=float](n2, n2:1)
n4 = Mul[T=float](n2:2, n2:3)
n5 = _ListToArray[N=2, T=float, Tin={float, float}](n3, n4)
n6 = AddN[N=2, T=float](n5, n5:1)
}
)P";
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
@ -540,7 +543,8 @@ TEST(TFunc, Body_Array_List_Converter) {
const char* e = R"P(
MySelect(x:float) -> (z:float) {
y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x)
z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y, y)
z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y:output:0, y:output:0)
return z = z:output:0
}
)P";
EXPECT_EQ(DebugString(fdef), e);
@ -608,16 +612,6 @@ TEST(InstantiateErrors, DupArgs) {
"Duplicated arg name");
}
TEST(InstantiateErrors, Dup_Arg_Node_Name) {
auto fdef = FDH::Define("test", {"x:float"}, {}, {},
{
{{"x"}, "One", {}, {{"T", DT_FLOAT}}},
});
InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Duplicated ret name");
}
TEST(InstantiateErrors, Dup_Node_Names) {
auto fdef = FDH::Define("test", {"x:float"}, {}, {},
{
@ -629,44 +623,25 @@ TEST(InstantiateErrors, Dup_Node_Names) {
"Duplicated ret name");
}
TEST(InstantiateErrors, Node_Signature_Mismatch_NoOp) {
auto fdef = FDH::Define("test", {"x:float"}, {}, {},
{
{{"y", "z"}, "NoOp", {}, {{"T", DT_FLOAT}}},
});
InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Expect one ret name");
}
TEST(InstantiateErrors, Node_Signature_Mismatch) {
auto fdef = FDH::Define("test", {"x:float"}, {}, {},
{
{{"y", "z"}, "One", {}, {{"T", DT_FLOAT}}},
});
InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Malformed function node (#ret)");
}
TEST(InstantiateErrors, Node_Arg_Notfound) {
auto fdef = FDH::Define("test", {"x:float"}, {}, {},
auto fdef = FDH::Create("test", {"x:float"}, {}, {},
{
{{"y"}, "Add", {"x", "z"}, {{"T", DT_FLOAT}}},
});
},
{});
InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"arg[1] is not found");
"input z is not found");
}
TEST(InstantiateErrors, Node_Arg_Mismatch) {
TEST(InstantiateErrors, Node_Arg_TypeMismatch) {
auto fdef = FDH::Define("test", {"x:float"}, {}, {},
{
{{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}},
});
InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Invalid arg(0) for function arg");
"input x[0] expected type int32 != float, the type of x[0]");
}
TEST(InstantiateErrors, Node_Arg_ControlMissing) {
@ -677,20 +652,55 @@ TEST(InstantiateErrors, Node_Arg_ControlMissing) {
});
InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"dep[0] is not found");
"input[2] == '^z', is not found.");
}
TEST(InstantiateErrors, FuncRet_Missing) {
auto fdef = FDH::Define("test", {}, {"y: float"}, {},
auto fdef = FDH::Create("test", {}, {"y: float"}, {},
{
{{"x"}, "One", {}, {{"T", DT_FLOAT}}},
});
},
{});
InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"ret is not found");
"Return y missing");
}
TEST(InstantiateErrors, FuncRet_Mismatch) {
TEST(InstantiateErrors, FuncRet_NotFound) {
auto fdef = FDH::Create("test", {}, {"y: float"}, {},
{
{{"x"}, "One", {}, {{"T", DT_FLOAT}}},
},
{{"y", "z"}});
InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Return y -> z is not found");
}
TEST(InstantiateErrors, FuncRet_NameMismatch) {
auto fdef = FDH::Create("test", {}, {"y: float"}, {},
{
{{"x"}, "One", {}, {{"T", DT_FLOAT}}},
},
{{"z", "x:y:0"}});
InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Return y missing");
}
// TODO(josh11b): Make this an error.
// TEST(InstantiateErrors, FuncRet_Extra) {
// auto fdef = FDH::Create("test", {}, {"y: float"}, {},
// {
// {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
// },
// {{"y", "x:y:0"}, {"z", "x:y:0"}});
// InstantiationResult result;
// HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
// "ret is not found");
// }
TEST(InstantiateErrors, FuncRet_TypeMismatch) {
auto fdef = FDH::Define("test", {}, {"y: float"}, {},
{
{{"y"}, "One", {}, {{"T", DT_DOUBLE}}},
@ -701,7 +711,7 @@ TEST(InstantiateErrors, FuncRet_Mismatch) {
}
TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) {
auto fdef = FDH::Define(
auto fdef = FDH::Create(
// Name
"MySelect",
// Args
@ -719,14 +729,15 @@ TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) {
{"cond", FDH::FunctionRef("MyCond2")},
{"then_branch", FDH::FunctionRef("MyThen2")},
{"else_branch", FDH::FunctionRef("MyElse2")}}},
});
},
{{"y", "y:output"}});
InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"type attr not found: out_types");
}
TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) {
auto fdef = FDH::Define(
auto fdef = FDH::Create(
// Name
"MySelect",
// Args
@ -745,14 +756,15 @@ TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) {
{"cond", FDH::FunctionRef("MyCond2")},
{"then_branch", FDH::FunctionRef("MyThen2")},
{"else_branch", FDH::FunctionRef("MyElse2")}}},
});
},
{{"y", "y:output"}});
InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Invalid ret types");
}
TEST(InstantiateErrors, TypeList_Missing_Arg) {
auto fdef = FDH::Define(
auto fdef = FDH::Create(
// Name
"MySelect",
// Args
@ -771,13 +783,14 @@ TEST(InstantiateErrors, TypeList_Missing_Arg) {
{"cond", FDH::FunctionRef("MyCond2")},
{"then_branch", FDH::FunctionRef("MyThen2")},
{"else_branch", FDH::FunctionRef("MyElse2")}}},
});
},
{{"y", "y:output"}});
InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"arg[1] is not found");
"input unknown is not found");
}
TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputs) {
TEST(InstantiateErrors, NodeDef_TooManyInputs) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
// Name
"TooManyInputs",
@ -798,7 +811,7 @@ TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputs) {
"Expected input[2] == 'x' to be a control input.");
}
TEST(InstantiateErrors, DISABLED_NodeDef_TooFewInputs) {
TEST(InstantiateErrors, NodeDef_TooFewInputs) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
// Name
"TooFewInputs",
@ -819,7 +832,7 @@ TEST(InstantiateErrors, DISABLED_NodeDef_TooFewInputs) {
"Attempt to access beyond input size: 2 >= 2");
}
TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputsFromArray1) {
TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray1) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
// Name
"TooManyInputsFromArray",
@ -847,7 +860,7 @@ TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputsFromArray1) {
"Expected input[1] == 'y' to be a control input.");
}
TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputsFromArray2) {
TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray2) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
// Name
"TooManyInputsFromArray",
@ -875,7 +888,7 @@ TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputsFromArray2) {
"Input a:output too long for inputs");
}
TEST(InstantiateErrors, DISABLED_NodeDef_TypeMismatch) {
TEST(InstantiateErrors, NodeDef_TypeMismatch) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
// Name
"TypeMismatch",
@ -968,8 +981,9 @@ TEST(FunctionLibraryDefinitionTest, Find) {
auto expect = R"P(
XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
scale = Cast[DstT=$T, SrcT=int64](two)
y = Mul[T=$T](x, scale)
scale = Cast[DstT=$T, SrcT=int64](two:output:0)
y = Mul[T=$T](x, scale:y:0)
return y = y:z:0
}
)P";
auto found = lib_def.Find("XTimesTwo");

View File

@ -91,7 +91,7 @@ FunctionDef XTimesTwo() {
}
FunctionDef XTimesFour() {
return FDH::Define(
return FDH::Create(
// Name
"XTimesFour",
// Args
@ -103,12 +103,13 @@ FunctionDef XTimesFour() {
// Nodes
{
{{"x2"}, "XTimesTwo", {"x"}, {{"T", "$T"}}},
{{"y"}, "XTimesTwo", {"x2"}, {{"T", "$T"}}},
});
{{"y"}, "XTimesTwo", {"x2:y:0"}, {{"T", "$T"}}},
},
{{"y", "y:y:0"}});
}
FunctionDef XTimes16() {
return FDH::Define(
return FDH::Create(
// Name
"XTimes16",
// Args
@ -120,8 +121,9 @@ FunctionDef XTimes16() {
// Nodes
{
{{"x4"}, "XTimesFour", {"x"}, {{"T", "$T"}}},
{{"y"}, "XTimesFour", {"x4"}, {{"T", "$T"}}},
});
{{"y"}, "XTimesFour", {"x4:y:0"}, {{"T", "$T"}}},
},
{{"y", "y:y:0"}});
}
FunctionDef WXPlusB() {

View File

@ -90,7 +90,8 @@ REGISTER_OP_GRADIENT("Identity", IdentityGrad);
Status PackGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
*g = FDH::Define(
*g = FDH::Create(
"_",
// Arg defs
{"x: N*T", "dy: T"},
// Ret val defs
@ -105,7 +106,8 @@ Status PackGrad(const AttrSlice& attrs, FunctionDef* g) {
{"dy"},
{{"T", "$T"}, {"num", "$N"}, {"axis", "$axis"}}
},
});
},
{{"dx", "dx:output"}});
// clang-format on
VLOG(1) << "PackGrad " << DebugString(*g);
return Status::OK();
@ -147,9 +149,9 @@ Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g,
std::vector<string> offset_i;
std::vector<string> dx_i;
for (int i = 0; i < N; ++i) {
shape_i.push_back(strings::StrCat("shape_lst:", i));
offset_i.push_back(strings::StrCat("offset_lst:", i));
dx_i.push_back(strings::StrCat("dx_", i));
shape_i.push_back(strings::StrCat("shapes:output:", i));
offset_i.push_back(strings::StrCat("offset:offset:", i));
dx_i.push_back(strings::StrCat("dx_", i, ":output:0"));
}
DataTypeVector dtype_list(N, T);
@ -160,19 +162,7 @@ Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g,
// which is the same as dx[i]'s offset within dy.
std::vector<FDH::Node> nodes{
{{"shapes"}, "ShapeN", {"x"}, {{"T", "$T"}, {"N", "$N"}}},
{{"shape_lst"},
"_ArrayToList",
{"shapes"},
{{"T", DT_INT32},
{"N", "$N"},
{"out_types", DataTypeVector(N, DT_INT32)}}},
{{"offset"}, "ConcatOffset", {"dim", "shapes"}, {{"N", "$N"}}},
{{"offset_lst"},
"_ArrayToList",
{"offset"},
{{"T", DT_INT32},
{"N", "$N"},
{"out_types", DataTypeVector(N, DT_INT32)}}},
{{"offset"}, "ConcatOffset", {"dim", "shapes:output"}, {{"N", "$N"}}},
{{"d_dim"}, "ZerosLike", {"dim"}, {{"T", DT_INT32}}},
{{"dx"},
"_ListToArray",
@ -182,34 +172,40 @@ Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g,
// For each dx[i], we take a slice of dy. The offset and size of the
// slice is given by offset[i] and shape[i].
for (int i = 0; i < N; ++i) {
nodes.push_back({{dx_i[i]},
nodes.push_back({{strings::StrCat("dx_", i)},
"Slice",
{"dy", offset_i[i], shape_i[i]},
{{"T", "$T"}, {"Index", DT_INT32}}});
}
if (dim_is_last_arg) {
// clang-format off
*g = FDH::Define(
*g = FDH::Create(
"_",
// Arg defs
{"x: N*T", "dim: int32", "dy: T"},
// Ret val defs
// Return signature
{"dx: N*T", "d_dim: int32"},
// Attr defs
{"T: type", "N: int"},
// Nodes
nodes);
nodes,
// Return values
{{"dx", "dx:output"}, {"d_dim", "d_dim:y:0"}});
// clang-format on
} else {
// clang-format off
*g = FDH::Define(
*g = FDH::Create(
"_",
// Arg defs
{"dim: int32", "x: N*T", "dy: T"},
// Ret val defs
// Return signature
{"d_dim: int32", "dx: N*T"},
// Attr defs
{"T: type", "N: int"},
// Nodes
nodes);
nodes,
// Return values
{{"dx", "dx:output"}, {"d_dim", "d_dim:y:0"}});
// clang-format on
}
VLOG(1) << "ConcatGrad " << DebugString(*g);
@ -399,15 +395,9 @@ Status SliceGrad(const AttrSlice& attrs, FunctionDef* g) {
{{"xs_b"}, "Sub", {"xs", "begin"}, {{"T", DT_INT32}}},
{{"xs_b_s"}, "Sub", {"xs_b", "size"}, {{"T", DT_INT32}}},
{{"a1"}, "ExpandDims", {"xs_b_s", "one"}, {{"T", DT_INT32}}},
{{"b_and_a"},
"_ListToArray",
{"b1", "a1"},
{{"T", DT_INT32},
{"N", 2},
{"Tin", DataTypeVector{DT_INT32, DT_INT32}}}},
{{"paddings"},
"Concat",
{"one", "b_and_a"},
{"one", "b1", "a1"},
{{"N", 2}, {"T", DT_INT32}}},
// dx = Pad(dy, paddings)
{{"dx"}, "Pad", {"dy", "paddings"}, {{"T", "$T"}}},

View File

@ -314,6 +314,7 @@ Status GradForBinaryCwise(FunctionDef* g, std::vector<FDH::Node> body) {
};
nodes.insert(nodes.end(), body.begin(), body.end());
std::vector<FDH::Node> reshapes = {
{{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "sy"}},
{{"sum_gx"}, "Sum", {"gx", "rx"}},
{{"dx"}, "Reshape", {"sum_gx", "sx"}},
{{"sum_gy"}, "Sum", {"gy", "ry"}},
@ -323,12 +324,11 @@ Status GradForBinaryCwise(FunctionDef* g, std::vector<FDH::Node> body) {
// clang-format on
for (auto& n : nodes) {
if (n.attr.empty()) {
// "BroadcastGradientArgs" doesn't need any attrs.
if (n.attr.empty() && n.op != "BroadcastGradientArgs") {
n.attr = {{"T", "$T"}};
}
}
// "BroadcastGradientArgs" doesn't need any attrs.
nodes.push_back({{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "sy"}});
*g = FDH::Define(
// Arg defs
{"x: T", "y: T", "dz: T"},
@ -518,18 +518,14 @@ Status GradForReductionOp(FunctionDef* g, std::vector<FDH::Node> body) {
FDH::Const("zero", 0),
FDH::Const("one", 1),
// stitch_idx0 = Range(0, x_rank, 1)
{{"stitch_idx1"}, "Identity", {"i"}, {{"T", DT_INT32}}},
{{"stitch_idx"}, "_ListToArray", {"stitch_idx0", "stitch_idx1"},
{{"Tin", DataTypeSlice{DT_INT32, DT_INT32}},
{"T", DT_INT32}, {"N", 2}}},
{{"stitch_val0"}, "Identity", {"x_shape"}, {{"T", DT_INT32}}},
{{"stitch_val1"}, "Fill", {"i_shape", "one"}, {{"T", DT_INT32}}},
{{"stitch_val"}, "_ListToArray", {"stitch_val0", "stitch_val1"},
{{"Tin", DataTypeSlice{DT_INT32, DT_INT32}},
{"T", DT_INT32}, {"N", 2}}},
{{"y_shape"}, "DynamicStitch", {"stitch_idx", "stitch_val"},
{{"stitch_val1"}, "Fill", {"i_shape:output:0", "one:output:0"},
{{"T", DT_INT32}}},
{{"y_shape"}, "DynamicStitch",
{"stitch_idx0:output:0", "i",
"x_shape:output:0", "stitch_val1:output:0"},
{{"N", 2}, {"T", DT_INT32}}},
{{"tile_scaling"}, "Div", {"x_shape", "y_shape"}, {{"T", DT_INT32}}},
{{"tile_scaling"}, "Div", {"x_shape:output:0", "y_shape:merged:0"},
{{"T", DT_INT32}}},
{{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
};
// clang-format on
@ -540,41 +536,46 @@ Status GradForReductionOp(FunctionDef* g, std::vector<FDH::Node> body) {
}
}
// "Range" doesn't need any attr.
nodes.push_back({{"stitch_idx0"}, "Range", {"zero", "x_rank", "one"}, {}});
*g = FDH::Define(
// Arg defs
nodes.push_back({{"stitch_idx0"},
"Range",
{"zero:output:0", "x_rank:output:0", "one:output:0"},
{}});
*g = FDH::Create("_",
// Input defs
{"x:T", "i:int32", "dy:T"},
// Ret val defs
{"dx:T", "di:int32"},
// Attr defs
{{"T: {half, float, double}"}},
// Nodes
nodes);
nodes,
// Return values
{{"dx", "dx:output:0"}, {"di", "di:y:0"}});
return Status::OK();
}
Status SumGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
return GradForReductionOp(g, {
{{"dy_reshaped"}, "Reshape", {"dy", "y_shape"}},
{{"dx"}, "Tile", {"dy_reshaped", "tile_scaling"}},
{{"dy_reshaped"}, "Reshape", {"dy", "y_shape:merged:0"}},
{{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
});
// clang-format on
return Status::OK();
}
REGISTER_OP_GRADIENT("Sum", SumGrad);
Status MeanGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
return GradForReductionOp(g, {
{{"factor"}, "Prod", {"tile_scaling", "zero"}, {{"T", DT_INT32}}},
{{"factor_T"}, "Cast", {"factor"}, {{"SrcT", DT_INT32}, {"DstT", "$T"}}},
{{"dy_scaled"}, "Div", {"dy", "factor_T"}},
{{"dy_reshaped"}, "Reshape", {"dy_scaled", "y_shape"}},
{{"dx"}, "Tile", {"dy_reshaped", "tile_scaling"}},
{{"factor"}, "Prod", {"tile_scaling:z:0", "zero:output:0"},
{{"T", DT_INT32}}},
{{"factor_T"}, "Cast", {"factor:output:0"},
{{"SrcT", DT_INT32}, {"DstT", "$T"}}},
{{"dy_scaled"}, "Div", {"dy", "factor_T:y:0"}},
{{"dy_reshaped"}, "Reshape", {"dy_scaled:z:0", "y_shape:merged:0"}},
{{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
});
// clang-format on
return Status::OK();
}
REGISTER_OP_GRADIENT("Mean", MeanGrad);

View File

@ -66,7 +66,7 @@ class MathGradTest : public ::testing::Test {
{"Tin", DataTypeSlice{T, T}},
{"Tout", DataTypeSlice{T}},
}},
{{"dx"}, "Identity", {"grad:0"}, {{"T", T}}},
{{"dx"}, "Identity", {"grad"}, {{"T", T}}},
});
// Each test case will feed in "x:0" and expects to get "dx:0".
auto gdef = test::function::GDef(
@ -120,7 +120,7 @@ class MathGradTest : public ::testing::Test {
{
FDH::Const("one", 1),
{{"dz"}, "Cast", {"one"}, {{"DstT", T}, {"SrcT", DT_INT32}}},
{{"grad"},
{{"grad0", "grad1"},
"SymbolicGradient",
{"x", "y", "dz"},
{
@ -128,8 +128,8 @@ class MathGradTest : public ::testing::Test {
{"Tin", DataTypeSlice{T, T, T}},
{"Tout", DataTypeSlice{T, T}},
}},
{{"dx"}, "Identity", {"grad:0"}, {{"T", T}}},
{{"dy"}, "Identity", {"grad:1"}, {{"T", T}}},
{{"dx"}, "Identity", {"grad0"}, {{"T", T}}},
{{"dy"}, "Identity", {"grad1"}, {{"T", T}}},
});
// Each test case will feed in "x:0" and "y:0" and expects to get "d0" and
// "d:0".
@ -177,7 +177,7 @@ class MathGradTest : public ::testing::Test {
{
FDH::Const("one", 1),
{{"dy"}, "Cast", {"one"}, {{"DstT", T}, {"SrcT", DT_INT32}}},
{{"grad"},
{{"grad0", "grad1"},
"SymbolicGradient",
{"x", "i", "dy"},
{
@ -185,8 +185,8 @@ class MathGradTest : public ::testing::Test {
{"Tin", DataTypeSlice{T, DT_INT32, T}},
{"Tout", DataTypeSlice{T, DT_INT32}},
}},
{{"dx"}, "Identity", {"grad:0"}, {{"T", T}}},
{{"di"}, "Identity", {"grad:1"}, {{"T", DT_INT32}}},
{{"dx"}, "Identity", {"grad0"}, {{"T", T}}},
{{"di"}, "Identity", {"grad1"}, {{"T", DT_INT32}}},
});
// Each test case will feed in "x:0" and expects to get "dx:0".
auto gdef = test::function::GDef(
@ -267,7 +267,7 @@ class MathGradTest : public ::testing::Test {
{
FDH::Const("one", 1),
{{"dz"}, "Cast", {"one"}, {{"DstT", T}, {"SrcT", DT_INT32}}},
{{"grad"},
{{"grad0", "grad1"},
"SymbolicGradient",
{"x", "y", "dz"},
{
@ -275,8 +275,8 @@ class MathGradTest : public ::testing::Test {
{"Tin", DataTypeSlice{T, T, T}},
{"Tout", DataTypeSlice{T, T}},
}},
{{"dx"}, "Identity", {"grad:0"}, {{"T", T}}},
{{"dy"}, "Identity", {"grad:1"}, {{"T", T}}},
{{"dx"}, "Identity", {"grad0"}, {{"T", T}}},
{{"dy"}, "Identity", {"grad1"}, {{"T", T}}},
});
// Each test case will feed in "x:0" and "y:0" and expects to get "d0" and
// "d:0".
@ -331,7 +331,7 @@ class MathGradTest : public ::testing::Test {
auto grad = FDH::Define("TestGrad", {"c:bool", "x:float", "y:float"},
{"dc:bool", "dx:float", "dy:float"}, {},
{FDH::Const("dz", 1.f),
{{"grad"},
{{"grad0", "grad1", "grad2"},
"SymbolicGradient",
{"c", "x", "y", "dz"},
{
@ -339,9 +339,9 @@ class MathGradTest : public ::testing::Test {
{"Tin", DataTypeSlice{DT_BOOL, T, T, T}},
{"Tout", DataTypeSlice{DT_BOOL, T, T}},
}},
{{"dc"}, "Identity", {"grad:0"}, {{"T", DT_BOOL}}},
{{"dx"}, "Identity", {"grad:1"}, {{"T", T}}},
{{"dy"}, "Identity", {"grad:2"}, {{"T", T}}}});
{{"dc"}, "Identity", {"grad0"}, {{"T", DT_BOOL}}},
{{"dx"}, "Identity", {"grad1"}, {{"T", T}}},
{{"dy"}, "Identity", {"grad2"}, {{"T", T}}}});
// Each test case will feed in "x:0" and expects to get "dx:0".
auto gdef = test::function::GDef(
{