Migrate FunctionDefHelper::Define() to NodeDef functions. For callers

of Define() where that doesn't work, switch to Create() instead.
Change: 142718606
This commit is contained in:
A. Unique TensorFlower 2016-12-21 17:28:44 -08:00 committed by TensorFlower Gardener
parent e14f7841de
commit a439d9975e
8 changed files with 273 additions and 261 deletions

View File

@ -359,38 +359,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
delete g; delete g;
} }
TEST_F(FunctionLibraryRuntimeTest, ManySwapsOld) { TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) {
auto func = FDH::Define( // Creates a FunctionDef using FunctionDef::Nodes
// Name
"ManySwapsFirst",
// Args
{"x: float", "y: float"},
// Return values
{"o: float"},
// attr def
{},
// Nodes
{{{"a0", "b0"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}},
{{"a1", "b1"}, "Swap", {"a0", "b0"}, {{"T", DT_FLOAT}}},
{{"a2", "b2"}, "Swap", {"a1", "b1"}, {{"T", DT_FLOAT}}},
{{"a3", "b3"}, "Swap", {"a2", "b2"}, {{"T", DT_FLOAT}}},
{{"a4", "b4"}, "Swap", {"a3", "b3"}, {{"T", DT_FLOAT}}},
{{"a5", "b5"}, "Swap", {"a4", "b4"}, {{"T", DT_FLOAT}}},
{{"o"}, "Identity", {"a5"}, {{"T", DT_FLOAT}}}});
Init({test::function::Swap(), func});
Graph* g = GetFuncBody("ManySwapsFirst", {});
ASSERT_TRUE(g != nullptr);
OptimizeGraph(lib_, &g);
const char* e0 = R"P(
(n3:float, n2:float) -> (n3:float) {
}
)P";
EXPECT_EQ(e0, DebugString(g));
delete g;
}
// Like the above test, but using NodeDefs in the FunctionDef.
TEST_F(FunctionLibraryRuntimeTest, DISABLED_ManySwapsNodeDef) {
auto func = FDH::Create( // Creates a FunctionDef using NodeDefs auto func = FDH::Create( // Creates a FunctionDef using NodeDefs
// Name // Name
"ManySwapsNodeDef", "ManySwapsNodeDef",
@ -423,7 +392,7 @@ TEST_F(FunctionLibraryRuntimeTest, DISABLED_ManySwapsNodeDef) {
} }
TEST_F(FunctionLibraryRuntimeTest, ControlDeps) { TEST_F(FunctionLibraryRuntimeTest, ControlDeps) {
auto func = FDH::Define( auto func = FDH::Create(
// Name // Name
"ManySwapsFirst", "ManySwapsFirst",
// Args // Args
@ -438,11 +407,12 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) {
// y2 depends on the 2nd swap. The 2nd swap has data dependency // y2 depends on the 2nd swap. The 2nd swap has data dependency
// on the 1st swap. The optimization should maintain the control // on the 1st swap. The optimization should maintain the control
// dependencies. // dependencies.
{{{"a0", "b0"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}, {"x2"}}, {{{"a0"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}, {"x2"}},
{{"a1", "b1"}, "Swap", {"a0", "b0"}, {{"T", DT_FLOAT}}}, {{"a1"}, "Swap", {"a0:o0:0", "a0:o1:0"}, {{"T", DT_FLOAT}}},
{{"x2"}, "Mul", {"x", "x"}, {{"T", DT_FLOAT}}}, {{"x2"}, "Mul", {"x", "x"}, {{"T", DT_FLOAT}}},
{{"y2"}, "Mul", {"y", "y"}, {{"T", DT_FLOAT}}, {"a1"}}, {{"y2"}, "Mul", {"y", "y"}, {{"T", DT_FLOAT}}, {"a1"}},
{{"o"}, "Add", {"x2", "y2"}, {{"T", DT_FLOAT}}}}); {{"o"}, "Add", {"x2:z:0", "y2:z:0"}, {{"T", DT_FLOAT}}}},
{{"o", "o:z:0"}});
Init({test::function::Swap(), func}); Init({test::function::Swap(), func});
Graph* g = GetFuncBody("ManySwapsFirst", {}); Graph* g = GetFuncBody("ManySwapsFirst", {});
ASSERT_TRUE(g != nullptr); ASSERT_TRUE(g != nullptr);
@ -608,7 +578,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
auto grad = auto grad =
FDH::Define("TestGrad", {"x:float", "y:float"}, {"dx:float", "dy:float"}, FDH::Define("TestGrad", {"x:float", "y:float"}, {"dx:float", "dy:float"},
{}, {FDH::Const<float>("dz", 1), {}, {FDH::Const<float>("dz", 1),
{{"grad"}, {{"grad0", "grad1"},
"SymbolicGradient", "SymbolicGradient",
{"x", "y", "dz"}, {"x", "y", "dz"},
{ {
@ -616,8 +586,8 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
{"Tin", DataTypeSlice{T, T, T}}, {"Tin", DataTypeSlice{T, T, T}},
{"Tout", DataTypeSlice{T, T}}, {"Tout", DataTypeSlice{T, T}},
}}, }},
{{"dx"}, "Identity", {"grad:0"}, {{"T", DT_FLOAT}}}, {{"dx"}, "Identity", {"grad0"}, {{"T", DT_FLOAT}}},
{{"dy"}, "Identity", {"grad:1"}, {{"T", DT_FLOAT}}}}); {{"dy"}, "Identity", {"grad1"}, {{"T", DT_FLOAT}}}});
Init({test, grad}); Init({test, grad});
@ -660,19 +630,19 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
OptimizeGraph(lib_, &g); OptimizeGraph(lib_, &g);
const char* e2 = R"P( const char* e2 = R"P(
(n4:float, n3:float) -> (n25:float, n23:float) { (n4:float, n3:float) -> (n25:float, n23:float) {
n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]()
n2 = Const[dtype=float, value=Tensor<type: float shape: [] values: 1>]() n2 = Const[dtype=float, value=Tensor<type: float shape: [] values: 1>]()
n7 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]() n8 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
n7 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]()
n19 = Shape[T=float, out_type=int32](n3) n19 = Shape[T=float, out_type=int32](n3)
n8 = Add[T=float](n4, n3) n9 = Add[T=float](n4, n3)
n20 = Shape[T=float, out_type=int32](n4) n20 = Shape[T=float, out_type=int32](n4)
n9 = Rank[T=float](n8) n10 = Rank[T=float](n9)
n14 = Shape[T=float, out_type=int32](n8) n14 = Shape[T=float, out_type=int32](n9)
n21 = BroadcastGradientArgs[T=int32](n20, n19) n21 = BroadcastGradientArgs[T=int32](n20, n19)
n10 = Range[Tidx=int32](n7, n9, n11) n11 = Range[Tidx=int32](n8, n10, n7)
n12 = Shape[T=int32, out_type=int32](n10) n12 = Shape[T=int32, out_type=int32](n11)
n13 = Fill[T=int32](n12, n11) n13 = Fill[T=int32](n12, n7)
n15 = DynamicStitch[N=2, T=int32](n10, n10, n14, n13) n15 = DynamicStitch[N=2, T=int32](n11, n11, n14, n13)
n16 = Reshape[T=float, Tshape=int32](n2, n15) n16 = Reshape[T=float, Tshape=int32](n2, n15)
n17 = Div[T=int32](n14, n15) n17 = Div[T=int32](n14, n15)
n18 = Tile[T=float, Tmultiples=int32](n16, n17) n18 = Tile[T=float, Tmultiples=int32](n16, n17)
@ -842,33 +812,38 @@ TEST(OptimizationTest, RemoveIdentityNodes) {
} }
TEST(OptimizationTest, RemoveListArrayConverter) { TEST(OptimizationTest, RemoveListArrayConverter) {
auto func = FDH::Define( auto func = FDH::Create(
// Name // Name
"Test", "Test",
// Args // Args
{"i: float"}, {"i: float"},
// Return values // Return signature
{"o: float"}, {"o: float"},
// Attrs // Attrs
{}, {},
// Nodes // Nodes
{FDH::Const("zero", 0), {FDH::Const("zero", 0),
{{"s"}, "Split", {"zero", "i"}, {{"num_split", 4}, {"T", DT_FLOAT}}}, {{"s"},
"Split",
{"zero:output:0", "i"},
{{"num_split", 4}, {"T", DT_FLOAT}}},
{{"a"}, {{"a"},
"_ArrayToList", "_ArrayToList",
{"s"}, {"s:output"},
{{"N", 4}, {{"N", 4},
{"T", DT_FLOAT}, {"T", DT_FLOAT},
{"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}}}}, {"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}}}},
{{"l"}, "Mul", {"a:0", "a:1"}, {{"T", DT_FLOAT}}}, {{"l"}, "Mul", {"a:output:0", "a:output:1"}, {{"T", DT_FLOAT}}},
{{"r"}, "Mul", {"a:2", "a:3"}, {{"T", DT_FLOAT}}}, {{"r"}, "Mul", {"a:output:2", "a:output:3"}, {{"T", DT_FLOAT}}},
{{"x"}, {{"x"},
"_ListToArray", "_ListToArray",
{"l", "r"}, {"l:z", "r:z"},
{{"N", 2}, {{"N", 2},
{"T", DT_FLOAT}, {"T", DT_FLOAT},
{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}}, {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
{{"o"}, "AddN", {"x"}, {{"N", 2}, {"T", DT_FLOAT}}}}); {{"o"}, "AddN", {"x:output"}, {{"N", 2}, {"T", DT_FLOAT}}}},
// Return values
{{"o", "o:sum"}});
const char* e0 = R"P( const char* e0 = R"P(
(n0:float) -> (n7:float) { (n0:float) -> (n7:float) {
@ -916,7 +891,7 @@ TEST(OptimizationTest, RemoveListArrayConverter) {
} }
TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) { TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
auto func = FDH::Define( auto func = FDH::Create(
// Name // Name
"Test", "Test",
// Args // Args
@ -935,10 +910,11 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
{"dummy"}}, {"dummy"}},
{{"o"}, {{"o"},
"AddN", "AddN",
{"x"}, {"x:output"},
{{"N", 2}, {"T", DT_FLOAT}}, {{"N", 2}, {"T", DT_FLOAT}},
// Control dep // Control dep
{"x"}}}); {"x"}}},
{{"o", "o:sum"}});
const char* e0 = R"P( const char* e0 = R"P(
(n0:float) -> (n3:float) { (n0:float) -> (n3:float) {

View File

@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.h"
#include <unordered_set> #include <unordered_map>
#include <vector> #include <vector>
#include "tensorflow/core/framework/function.pb_text.h" #include "tensorflow/core/framework/function.pb_text.h"
@ -1110,35 +1110,17 @@ FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
return ret; return ret;
} }
FunctionDef::Node FunctionDefHelper::Node::ToProto() const {
FunctionDef::Node n;
for (const string& r : this->ret) {
n.add_ret(r);
}
n.set_op(this->op);
for (const string& a : arg) {
n.add_arg(a);
}
for (const auto& a : this->attr) {
n.mutable_attr()->insert({a.first, a.second.proto});
}
for (const string& d : dep) {
n.add_dep(d);
}
return n;
}
NodeDef FunctionDefHelper::Node::ToNodeDef() const { NodeDef FunctionDefHelper::Node::ToNodeDef() const {
NodeDef n; NodeDef n;
n.set_op(this->op); n.set_op(this->op);
n.set_name(this->ret[0]); n.set_name(this->ret[0]);
for (const string& a : arg) {
n.add_input(a);
}
for (const auto& a : this->attr) { for (const auto& a : this->attr) {
n.mutable_attr()->insert({a.first, a.second.proto}); n.mutable_attr()->insert({a.first, a.second.proto});
} }
for (const string& d : dep) { for (const string& a : this->arg) {
n.add_input(a);
}
for (const string& d : this->dep) {
n.add_input(strings::StrCat("^", d)); n.add_input(strings::StrCat("^", d));
} }
return n; return n;
@ -1189,8 +1171,56 @@ FunctionDef FunctionDefHelper::Define(const string& name,
OpRegistrationData op_reg_data; OpRegistrationData op_reg_data;
TF_CHECK_OK(b.Finalize(&op_reg_data)); TF_CHECK_OK(b.Finalize(&op_reg_data));
fdef.mutable_signature()->Swap(&op_reg_data.op_def); fdef.mutable_signature()->Swap(&op_reg_data.op_def);
for (const auto& n : node_def) {
*(fdef.add_node()) = n.ToProto(); // Mapping from legacy output names to NodeDef outputs.
std::unordered_map<string, string> ret_index;
for (const auto& a : fdef.signature().input_arg()) {
ret_index[a.name()] = a.name();
}
// For looking up OpDefs
auto* op_def_registry = OpRegistry::Global();
// Function body
for (const auto& src : node_def) {
NodeDef* n = fdef.add_node_def();
n->set_op(src.op);
n->set_name(src.ret[0]);
for (const auto& a : src.attr) {
n->mutable_attr()->insert({a.first, a.second.proto});
}
for (const string& a : src.arg) {
const auto iter = ret_index.find(a);
CHECK(iter != ret_index.end()) << "Node input '" << a << "' in '"
<< src.ret[0] << "' of " << name;
n->add_input(iter->second);
}
for (const string& d : src.dep) {
n->add_input(strings::StrCat("^", d));
}
// Add the outputs of this node to ret_index.
const OpDef* op_def = nullptr;
TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
CHECK(op_def != nullptr) << n->op();
NameRangeMap output_names;
TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
for (const auto& o : output_names) {
CHECK_LE(o.second.second, src.ret.size())
<< "Missing ret for output '" << o.first << "' in '" << src.ret[0]
<< "' of " << name;
for (int i = o.second.first; i < o.second.second; ++i) {
ret_index[src.ret[i]] =
strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
}
}
}
// Returns
for (const auto& r : fdef.signature().output_arg()) {
const auto iter = ret_index.find(r.name());
CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
fdef.mutable_ret()->insert({r.name(), iter->second});
} }
return fdef; return fdef;
} }

View File

@ -108,7 +108,6 @@ class FunctionDefHelper {
std::vector<std::pair<string, AttrValueWrapper>> attr; std::vector<std::pair<string, AttrValueWrapper>> attr;
std::vector<string> dep; std::vector<string> dep;
FunctionDef::Node ToProto() const;
NodeDef ToNodeDef() const; NodeDef ToNodeDef() const;
}; };

View File

@ -71,7 +71,8 @@ TEST(TFunc, SquarePlusOneOld) {
SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) { SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
a = Square[T=$T](x) a = Square[T=$T](x)
o = One[T=$T]() o = One[T=$T]()
y = Add[T=$T](a, o) y = Add[T=$T](a:y:0, o:y:0)
return y = y:z:0
} }
)P"; )P";
EXPECT_EQ(DebugString(fdef), e); EXPECT_EQ(DebugString(fdef), e);
@ -91,7 +92,7 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
EXPECT_EQ(DebugString(result.gdef), e2); EXPECT_EQ(DebugString(result.gdef), e2);
} }
TEST(TFunc, DISABLED_SquarePlusOneNodeDef) { TEST(TFunc, SquarePlusOneNodeDef) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs. auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
// Name // Name
"SquarePlusOne", "SquarePlusOne",
@ -137,7 +138,7 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
EXPECT_EQ(DebugString(result.gdef), e2); EXPECT_EQ(DebugString(result.gdef), e2);
} }
TEST(TFunc, DISABLED_ControlDepNodeDef) { TEST(TFunc, ControlDepNodeDef) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs. auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
// Name // Name
"ControlDep", "ControlDep",
@ -206,6 +207,7 @@ TEST(TFunc, MissingTypeAttrOld) {
const char* e = R"P( const char* e = R"P(
BackCompat() -> (y:float) { BackCompat() -> (y:float) {
y = HasDefaultType() y = HasDefaultType()
return y = y:out:0
} }
)P"; )P";
EXPECT_EQ(DebugString(fdef), e); EXPECT_EQ(DebugString(fdef), e);
@ -224,7 +226,7 @@ BackCompat() -> (y:float) {
EXPECT_EQ(DebugString(result.gdef), e2); EXPECT_EQ(DebugString(result.gdef), e2);
} }
TEST(TFunc, DISABLED_MissingTypeAttrNodeDef) { TEST(TFunc, MissingTypeAttrNodeDef) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs. auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
// Name // Name
"BackCompat", "BackCompat",
@ -262,7 +264,7 @@ BackCompat() -> (y:float) {
EXPECT_EQ(DebugString(result.gdef), e2); EXPECT_EQ(DebugString(result.gdef), e2);
} }
TEST(TFunc, DISABLED_NTimesTNodeDef) { TEST(TFunc, NTimesTNodeDef) {
// Note that the equivalent FunctionDef using FunctionDef::Node requires // Note that the equivalent FunctionDef using FunctionDef::Node requires
// using a _ListToArray to package up the two inputs to AddN as a single // using a _ListToArray to package up the two inputs to AddN as a single
// N*T edge. // N*T edge.
@ -320,7 +322,7 @@ y: N tensors, each of type U;
)doc"); )doc");
TEST(TFunc, AddSquared) { TEST(TFunc, AddSquared) {
auto fdef = FDH::Define( auto fdef = FDH::Create(
// Name // Name
"AddSquared", "AddSquared",
// Args // Args
@ -339,12 +341,14 @@ TEST(TFunc, AddSquared) {
{"U", "$T"}, {"U", "$T"},
{"N", "$N"}}}, {"N", "$N"}}},
// y = AddN<N=$N,T=$T>(a) // y = AddN<N=$N,T=$T>(a)
{{"y"}, "AddN", {"a"}, {{"N", "$N"}, {"T", "$T"}}}}); {{"y"}, "AddN", {"a:y"}, {{"N", "$N"}, {"T", "$T"}}}},
{{"y", "y:sum"}});
const char* e = R"P( const char* e = R"P(
AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) { AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) {
a = Map[N=$N, T=$T, U=$T, func=Square[T=$T]](x) a = Map[N=$N, T=$T, U=$T, func=Square[T=$T]](x)
y = AddN[N=$N, T=$T](a) y = AddN[N=$N, T=$T](a:y)
return y = y:sum
} }
)P"; )P";
EXPECT_EQ(DebugString(fdef), e); EXPECT_EQ(DebugString(fdef), e);
@ -413,8 +417,9 @@ TEST(TFunc, XTimesTwo) {
auto expect = R"P( auto expect = R"P(
XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) { XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]() two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
scale = Cast[DstT=$T, SrcT=int64](two) scale = Cast[DstT=$T, SrcT=int64](two:output:0)
y = Mul[T=$T](x, scale) y = Mul[T=$T](x, scale:y:0)
return y = y:z:0
} }
)P"; )P";
EXPECT_EQ(expect, DebugString(test::function::XTimesTwo())); EXPECT_EQ(expect, DebugString(test::function::XTimesTwo()));
@ -424,7 +429,8 @@ TEST(TFunc, WXPlusB) {
auto expect = R"P( auto expect = R"P(
WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) { WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x) mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x)
y = Add[T=$T](mm, b) y = Add[T=$T](mm:product:0, b)
return y = y:z:0
} }
)P"; )P";
EXPECT_EQ(expect, DebugString(test::function::WXPlusB())); EXPECT_EQ(expect, DebugString(test::function::WXPlusB()));
@ -432,7 +438,7 @@ WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
TEST(TFunc, Body_TypeList) { TEST(TFunc, Body_TypeList) {
const Tensor kZero = test::AsScalar<int32>(0); const Tensor kZero = test::AsScalar<int32>(0);
auto fdef = FDH::Define( auto fdef = FDH::Create(
// Name // Name
"Test", "Test",
// Args // Args
@ -443,32 +449,30 @@ TEST(TFunc, Body_TypeList) {
{}, {},
// Nodes // Nodes
{{{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT32}}}, {{{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT32}}},
{{"s"}, "Split", {"zero", "i"}, {{"num_split", 4}, {"T", DT_FLOAT}}}, {{"s"},
{{"lst"}, "Split",
"_ArrayToList", {"zero:output:0", "i"},
{"s"}, {{"num_split", 4}, {"T", DT_FLOAT}}},
{{"N", 4}, {{"l"}, "Mul", {"s:output:0", "s:output:1"}, {{"T", DT_FLOAT}}},
{"T", DT_FLOAT}, {{"r"}, "Mul", {"s:output:2", "s:output:3"}, {{"T", DT_FLOAT}}},
{"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}}}},
{{"l"}, "Mul", {"lst:0", "lst:1"}, {{"T", DT_FLOAT}}},
{{"r"}, "Mul", {"lst:2", "lst:3"}, {{"T", DT_FLOAT}}},
{{"x"}, {{"x"},
"_ListToArray", "_ListToArray",
{"l", "r"}, {"l:z", "r:z"},
{{"N", 2}, {{"N", 2},
{"T", DT_FLOAT}, {"T", DT_FLOAT},
{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}}, {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
{{"o"}, "AddN", {"x"}, {{"N", 2}, {"T", DT_FLOAT}}}}); {{"o"}, "AddN", {"x:output"}, {{"N", 2}, {"T", DT_FLOAT}}}},
{{"o", "o:sum:0"}});
const char* e = R"P( const char* e = R"P(
Test(i:float) -> (o:float) { Test(i:float) -> (o:float) {
zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]() zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
s = Split[T=float, num_split=4](zero, i) s = Split[T=float, num_split=4](zero:output:0, i)
lst = _ArrayToList[N=4, T=float, out_types={float, float, float, float}](s) l = Mul[T=float](s:output:0, s:output:1)
l = Mul[T=float](lst:0, lst:1) r = Mul[T=float](s:output:2, s:output:3)
r = Mul[T=float](lst:2, lst:3) x = _ListToArray[N=2, T=float, Tin={float, float}](l:z, r:z)
x = _ListToArray[N=2, T=float, Tin={float, float}](l, r) o = AddN[N=2, T=float](x:output)
o = AddN[N=2, T=float](x) return o = o:sum:0
} }
)P"; )P";
EXPECT_EQ(DebugString(fdef), e); EXPECT_EQ(DebugString(fdef), e);
@ -476,14 +480,13 @@ Test(i:float) -> (o:float) {
InstantiationResult result; InstantiationResult result;
TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result)); TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
const char* e2 = R"P( const char* e2 = R"P(
(n0:float) -> (n7:float) { (n0:float) -> (n6:float) {
n1 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]() n1 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
n2 = Split[T=float, num_split=4](n1, n0) n2 = Split[T=float, num_split=4](n1, n0)
n3 = _ArrayToList[N=4, T=float, out_types={float, float, float, float}](n2, n2:1, n2:2, n2:3) n3 = Mul[T=float](n2, n2:1)
n4 = Mul[T=float](n3, n3:1) n4 = Mul[T=float](n2:2, n2:3)
n5 = Mul[T=float](n3:2, n3:3) n5 = _ListToArray[N=2, T=float, Tin={float, float}](n3, n4)
n6 = _ListToArray[N=2, T=float, Tin={float, float}](n4, n5) n6 = AddN[N=2, T=float](n5, n5:1)
n7 = AddN[N=2, T=float](n6, n6:1)
} }
)P"; )P";
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT})); EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
@ -540,7 +543,8 @@ TEST(TFunc, Body_Array_List_Converter) {
const char* e = R"P( const char* e = R"P(
MySelect(x:float) -> (z:float) { MySelect(x:float) -> (z:float) {
y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x) y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x)
z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y, y) z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y:output:0, y:output:0)
return z = z:output:0
} }
)P"; )P";
EXPECT_EQ(DebugString(fdef), e); EXPECT_EQ(DebugString(fdef), e);
@ -608,16 +612,6 @@ TEST(InstantiateErrors, DupArgs) {
"Duplicated arg name"); "Duplicated arg name");
} }
TEST(InstantiateErrors, Dup_Arg_Node_Name) {
auto fdef = FDH::Define("test", {"x:float"}, {}, {},
{
{{"x"}, "One", {}, {{"T", DT_FLOAT}}},
});
InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Duplicated ret name");
}
TEST(InstantiateErrors, Dup_Node_Names) { TEST(InstantiateErrors, Dup_Node_Names) {
auto fdef = FDH::Define("test", {"x:float"}, {}, {}, auto fdef = FDH::Define("test", {"x:float"}, {}, {},
{ {
@ -629,44 +623,25 @@ TEST(InstantiateErrors, Dup_Node_Names) {
"Duplicated ret name"); "Duplicated ret name");
} }
TEST(InstantiateErrors, Node_Signature_Mismatch_NoOp) {
auto fdef = FDH::Define("test", {"x:float"}, {}, {},
{
{{"y", "z"}, "NoOp", {}, {{"T", DT_FLOAT}}},
});
InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Expect one ret name");
}
TEST(InstantiateErrors, Node_Signature_Mismatch) {
auto fdef = FDH::Define("test", {"x:float"}, {}, {},
{
{{"y", "z"}, "One", {}, {{"T", DT_FLOAT}}},
});
InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Malformed function node (#ret)");
}
TEST(InstantiateErrors, Node_Arg_Notfound) { TEST(InstantiateErrors, Node_Arg_Notfound) {
auto fdef = FDH::Define("test", {"x:float"}, {}, {}, auto fdef = FDH::Create("test", {"x:float"}, {}, {},
{ {
{{"y"}, "Add", {"x", "z"}, {{"T", DT_FLOAT}}}, {{"y"}, "Add", {"x", "z"}, {{"T", DT_FLOAT}}},
}); },
{});
InstantiationResult result; InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"arg[1] is not found"); "input z is not found");
} }
TEST(InstantiateErrors, Node_Arg_Mismatch) { TEST(InstantiateErrors, Node_Arg_TypeMismatch) {
auto fdef = FDH::Define("test", {"x:float"}, {}, {}, auto fdef = FDH::Define("test", {"x:float"}, {}, {},
{ {
{{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}}, {{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}},
}); });
InstantiationResult result; InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Invalid arg(0) for function arg"); "input x[0] expected type int32 != float, the type of x[0]");
} }
TEST(InstantiateErrors, Node_Arg_ControlMissing) { TEST(InstantiateErrors, Node_Arg_ControlMissing) {
@ -677,20 +652,55 @@ TEST(InstantiateErrors, Node_Arg_ControlMissing) {
}); });
InstantiationResult result; InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"dep[0] is not found"); "input[2] == '^z', is not found.");
} }
TEST(InstantiateErrors, FuncRet_Missing) { TEST(InstantiateErrors, FuncRet_Missing) {
auto fdef = FDH::Define("test", {}, {"y: float"}, {}, auto fdef = FDH::Create("test", {}, {"y: float"}, {},
{ {
{{"x"}, "One", {}, {{"T", DT_FLOAT}}}, {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
}); },
{});
InstantiationResult result; InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"ret is not found"); "Return y missing");
} }
TEST(InstantiateErrors, FuncRet_Mismatch) { TEST(InstantiateErrors, FuncRet_NotFound) {
auto fdef = FDH::Create("test", {}, {"y: float"}, {},
{
{{"x"}, "One", {}, {{"T", DT_FLOAT}}},
},
{{"y", "z"}});
InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Return y -> z is not found");
}
TEST(InstantiateErrors, FuncRet_NameMismatch) {
auto fdef = FDH::Create("test", {}, {"y: float"}, {},
{
{{"x"}, "One", {}, {{"T", DT_FLOAT}}},
},
{{"z", "x:y:0"}});
InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Return y missing");
}
// TODO(josh11b): Make this an error.
// TEST(InstantiateErrors, FuncRet_Extra) {
// auto fdef = FDH::Create("test", {}, {"y: float"}, {},
// {
// {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
// },
// {{"y", "x:y:0"}, {"z", "x:y:0"}});
// InstantiationResult result;
// HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
// "ret is not found");
// }
TEST(InstantiateErrors, FuncRet_TypeMismatch) {
auto fdef = FDH::Define("test", {}, {"y: float"}, {}, auto fdef = FDH::Define("test", {}, {"y: float"}, {},
{ {
{{"y"}, "One", {}, {{"T", DT_DOUBLE}}}, {{"y"}, "One", {}, {{"T", DT_DOUBLE}}},
@ -701,7 +711,7 @@ TEST(InstantiateErrors, FuncRet_Mismatch) {
} }
TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) { TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) {
auto fdef = FDH::Define( auto fdef = FDH::Create(
// Name // Name
"MySelect", "MySelect",
// Args // Args
@ -719,14 +729,15 @@ TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) {
{"cond", FDH::FunctionRef("MyCond2")}, {"cond", FDH::FunctionRef("MyCond2")},
{"then_branch", FDH::FunctionRef("MyThen2")}, {"then_branch", FDH::FunctionRef("MyThen2")},
{"else_branch", FDH::FunctionRef("MyElse2")}}}, {"else_branch", FDH::FunctionRef("MyElse2")}}},
}); },
{{"y", "y:output"}});
InstantiationResult result; InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"type attr not found: out_types"); "type attr not found: out_types");
} }
TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) { TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) {
auto fdef = FDH::Define( auto fdef = FDH::Create(
// Name // Name
"MySelect", "MySelect",
// Args // Args
@ -745,14 +756,15 @@ TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) {
{"cond", FDH::FunctionRef("MyCond2")}, {"cond", FDH::FunctionRef("MyCond2")},
{"then_branch", FDH::FunctionRef("MyThen2")}, {"then_branch", FDH::FunctionRef("MyThen2")},
{"else_branch", FDH::FunctionRef("MyElse2")}}}, {"else_branch", FDH::FunctionRef("MyElse2")}}},
}); },
{{"y", "y:output"}});
InstantiationResult result; InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"Invalid ret types"); "Invalid ret types");
} }
TEST(InstantiateErrors, TypeList_Missing_Arg) { TEST(InstantiateErrors, TypeList_Missing_Arg) {
auto fdef = FDH::Define( auto fdef = FDH::Create(
// Name // Name
"MySelect", "MySelect",
// Args // Args
@ -771,13 +783,14 @@ TEST(InstantiateErrors, TypeList_Missing_Arg) {
{"cond", FDH::FunctionRef("MyCond2")}, {"cond", FDH::FunctionRef("MyCond2")},
{"then_branch", FDH::FunctionRef("MyThen2")}, {"then_branch", FDH::FunctionRef("MyThen2")},
{"else_branch", FDH::FunctionRef("MyElse2")}}}, {"else_branch", FDH::FunctionRef("MyElse2")}}},
}); },
{{"y", "y:output"}});
InstantiationResult result; InstantiationResult result;
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result), HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
"arg[1] is not found"); "input unknown is not found");
} }
TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputs) { TEST(InstantiateErrors, NodeDef_TooManyInputs) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs. auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
// Name // Name
"TooManyInputs", "TooManyInputs",
@ -798,7 +811,7 @@ TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputs) {
"Expected input[2] == 'x' to be a control input."); "Expected input[2] == 'x' to be a control input.");
} }
TEST(InstantiateErrors, DISABLED_NodeDef_TooFewInputs) { TEST(InstantiateErrors, NodeDef_TooFewInputs) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs. auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
// Name // Name
"TooFewInputs", "TooFewInputs",
@ -819,7 +832,7 @@ TEST(InstantiateErrors, DISABLED_NodeDef_TooFewInputs) {
"Attempt to access beyond input size: 2 >= 2"); "Attempt to access beyond input size: 2 >= 2");
} }
TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputsFromArray1) { TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray1) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs. auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
// Name // Name
"TooManyInputsFromArray", "TooManyInputsFromArray",
@ -847,7 +860,7 @@ TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputsFromArray1) {
"Expected input[1] == 'y' to be a control input."); "Expected input[1] == 'y' to be a control input.");
} }
TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputsFromArray2) { TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray2) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs. auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
// Name // Name
"TooManyInputsFromArray", "TooManyInputsFromArray",
@ -875,7 +888,7 @@ TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputsFromArray2) {
"Input a:output too long for inputs"); "Input a:output too long for inputs");
} }
TEST(InstantiateErrors, DISABLED_NodeDef_TypeMismatch) { TEST(InstantiateErrors, NodeDef_TypeMismatch) {
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs. auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
// Name // Name
"TypeMismatch", "TypeMismatch",
@ -968,8 +981,9 @@ TEST(FunctionLibraryDefinitionTest, Find) {
auto expect = R"P( auto expect = R"P(
XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) { XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]() two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
scale = Cast[DstT=$T, SrcT=int64](two) scale = Cast[DstT=$T, SrcT=int64](two:output:0)
y = Mul[T=$T](x, scale) y = Mul[T=$T](x, scale:y:0)
return y = y:z:0
} }
)P"; )P";
auto found = lib_def.Find("XTimesTwo"); auto found = lib_def.Find("XTimesTwo");

View File

@ -91,7 +91,7 @@ FunctionDef XTimesTwo() {
} }
FunctionDef XTimesFour() { FunctionDef XTimesFour() {
return FDH::Define( return FDH::Create(
// Name // Name
"XTimesFour", "XTimesFour",
// Args // Args
@ -103,12 +103,13 @@ FunctionDef XTimesFour() {
// Nodes // Nodes
{ {
{{"x2"}, "XTimesTwo", {"x"}, {{"T", "$T"}}}, {{"x2"}, "XTimesTwo", {"x"}, {{"T", "$T"}}},
{{"y"}, "XTimesTwo", {"x2"}, {{"T", "$T"}}}, {{"y"}, "XTimesTwo", {"x2:y:0"}, {{"T", "$T"}}},
}); },
{{"y", "y:y:0"}});
} }
FunctionDef XTimes16() { FunctionDef XTimes16() {
return FDH::Define( return FDH::Create(
// Name // Name
"XTimes16", "XTimes16",
// Args // Args
@ -120,8 +121,9 @@ FunctionDef XTimes16() {
// Nodes // Nodes
{ {
{{"x4"}, "XTimesFour", {"x"}, {{"T", "$T"}}}, {{"x4"}, "XTimesFour", {"x"}, {{"T", "$T"}}},
{{"y"}, "XTimesFour", {"x4"}, {{"T", "$T"}}}, {{"y"}, "XTimesFour", {"x4:y:0"}, {{"T", "$T"}}},
}); },
{{"y", "y:y:0"}});
} }
FunctionDef WXPlusB() { FunctionDef WXPlusB() {

View File

@ -90,7 +90,8 @@ REGISTER_OP_GRADIENT("Identity", IdentityGrad);
Status PackGrad(const AttrSlice& attrs, FunctionDef* g) { Status PackGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off // clang-format off
*g = FDH::Define( *g = FDH::Create(
"_",
// Arg defs // Arg defs
{"x: N*T", "dy: T"}, {"x: N*T", "dy: T"},
// Ret val defs // Ret val defs
@ -105,7 +106,8 @@ Status PackGrad(const AttrSlice& attrs, FunctionDef* g) {
{"dy"}, {"dy"},
{{"T", "$T"}, {"num", "$N"}, {"axis", "$axis"}} {{"T", "$T"}, {"num", "$N"}, {"axis", "$axis"}}
}, },
}); },
{{"dx", "dx:output"}});
// clang-format on // clang-format on
VLOG(1) << "PackGrad " << DebugString(*g); VLOG(1) << "PackGrad " << DebugString(*g);
return Status::OK(); return Status::OK();
@ -147,9 +149,9 @@ Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g,
std::vector<string> offset_i; std::vector<string> offset_i;
std::vector<string> dx_i; std::vector<string> dx_i;
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
shape_i.push_back(strings::StrCat("shape_lst:", i)); shape_i.push_back(strings::StrCat("shapes:output:", i));
offset_i.push_back(strings::StrCat("offset_lst:", i)); offset_i.push_back(strings::StrCat("offset:offset:", i));
dx_i.push_back(strings::StrCat("dx_", i)); dx_i.push_back(strings::StrCat("dx_", i, ":output:0"));
} }
DataTypeVector dtype_list(N, T); DataTypeVector dtype_list(N, T);
@ -160,19 +162,7 @@ Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g,
// which is the same as dx[i]'s offset within dy. // which is the same as dx[i]'s offset within dy.
std::vector<FDH::Node> nodes{ std::vector<FDH::Node> nodes{
{{"shapes"}, "ShapeN", {"x"}, {{"T", "$T"}, {"N", "$N"}}}, {{"shapes"}, "ShapeN", {"x"}, {{"T", "$T"}, {"N", "$N"}}},
{{"shape_lst"}, {{"offset"}, "ConcatOffset", {"dim", "shapes:output"}, {{"N", "$N"}}},
"_ArrayToList",
{"shapes"},
{{"T", DT_INT32},
{"N", "$N"},
{"out_types", DataTypeVector(N, DT_INT32)}}},
{{"offset"}, "ConcatOffset", {"dim", "shapes"}, {{"N", "$N"}}},
{{"offset_lst"},
"_ArrayToList",
{"offset"},
{{"T", DT_INT32},
{"N", "$N"},
{"out_types", DataTypeVector(N, DT_INT32)}}},
{{"d_dim"}, "ZerosLike", {"dim"}, {{"T", DT_INT32}}}, {{"d_dim"}, "ZerosLike", {"dim"}, {{"T", DT_INT32}}},
{{"dx"}, {{"dx"},
"_ListToArray", "_ListToArray",
@ -182,34 +172,40 @@ Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g,
// For each dx[i], we take a slice of dy. The offset and size of the // For each dx[i], we take a slice of dy. The offset and size of the
// slice is given by offset[i] and shape[i]. // slice is given by offset[i] and shape[i].
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
nodes.push_back({{dx_i[i]}, nodes.push_back({{strings::StrCat("dx_", i)},
"Slice", "Slice",
{"dy", offset_i[i], shape_i[i]}, {"dy", offset_i[i], shape_i[i]},
{{"T", "$T"}, {"Index", DT_INT32}}}); {{"T", "$T"}, {"Index", DT_INT32}}});
} }
if (dim_is_last_arg) { if (dim_is_last_arg) {
// clang-format off // clang-format off
*g = FDH::Define( *g = FDH::Create(
"_",
// Arg defs // Arg defs
{"x: N*T", "dim: int32", "dy: T"}, {"x: N*T", "dim: int32", "dy: T"},
// Ret val defs // Return signature
{"dx: N*T", "d_dim: int32"}, {"dx: N*T", "d_dim: int32"},
// Attr defs // Attr defs
{"T: type", "N: int"}, {"T: type", "N: int"},
// Nodes // Nodes
nodes); nodes,
// Return values
{{"dx", "dx:output"}, {"d_dim", "d_dim:y:0"}});
// clang-format on // clang-format on
} else { } else {
// clang-format off // clang-format off
*g = FDH::Define( *g = FDH::Create(
"_",
// Arg defs // Arg defs
{"dim: int32", "x: N*T", "dy: T"}, {"dim: int32", "x: N*T", "dy: T"},
// Ret val defs // Return signature
{"d_dim: int32", "dx: N*T"}, {"d_dim: int32", "dx: N*T"},
// Attr defs // Attr defs
{"T: type", "N: int"}, {"T: type", "N: int"},
// Nodes // Nodes
nodes); nodes,
// Return values
{{"dx", "dx:output"}, {"d_dim", "d_dim:y:0"}});
// clang-format on // clang-format on
} }
VLOG(1) << "ConcatGrad " << DebugString(*g); VLOG(1) << "ConcatGrad " << DebugString(*g);
@ -399,15 +395,9 @@ Status SliceGrad(const AttrSlice& attrs, FunctionDef* g) {
{{"xs_b"}, "Sub", {"xs", "begin"}, {{"T", DT_INT32}}}, {{"xs_b"}, "Sub", {"xs", "begin"}, {{"T", DT_INT32}}},
{{"xs_b_s"}, "Sub", {"xs_b", "size"}, {{"T", DT_INT32}}}, {{"xs_b_s"}, "Sub", {"xs_b", "size"}, {{"T", DT_INT32}}},
{{"a1"}, "ExpandDims", {"xs_b_s", "one"}, {{"T", DT_INT32}}}, {{"a1"}, "ExpandDims", {"xs_b_s", "one"}, {{"T", DT_INT32}}},
{{"b_and_a"},
"_ListToArray",
{"b1", "a1"},
{{"T", DT_INT32},
{"N", 2},
{"Tin", DataTypeVector{DT_INT32, DT_INT32}}}},
{{"paddings"}, {{"paddings"},
"Concat", "Concat",
{"one", "b_and_a"}, {"one", "b1", "a1"},
{{"N", 2}, {"T", DT_INT32}}}, {{"N", 2}, {"T", DT_INT32}}},
// dx = Pad(dy, paddings) // dx = Pad(dy, paddings)
{{"dx"}, "Pad", {"dy", "paddings"}, {{"T", "$T"}}}, {{"dx"}, "Pad", {"dy", "paddings"}, {{"T", "$T"}}},

View File

@ -314,6 +314,7 @@ Status GradForBinaryCwise(FunctionDef* g, std::vector<FDH::Node> body) {
}; };
nodes.insert(nodes.end(), body.begin(), body.end()); nodes.insert(nodes.end(), body.begin(), body.end());
std::vector<FDH::Node> reshapes = { std::vector<FDH::Node> reshapes = {
{{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "sy"}},
{{"sum_gx"}, "Sum", {"gx", "rx"}}, {{"sum_gx"}, "Sum", {"gx", "rx"}},
{{"dx"}, "Reshape", {"sum_gx", "sx"}}, {{"dx"}, "Reshape", {"sum_gx", "sx"}},
{{"sum_gy"}, "Sum", {"gy", "ry"}}, {{"sum_gy"}, "Sum", {"gy", "ry"}},
@ -323,12 +324,11 @@ Status GradForBinaryCwise(FunctionDef* g, std::vector<FDH::Node> body) {
// clang-format on // clang-format on
for (auto& n : nodes) { for (auto& n : nodes) {
if (n.attr.empty()) { // "BroadcastGradientArgs" doesn't need any attrs.
if (n.attr.empty() && n.op != "BroadcastGradientArgs") {
n.attr = {{"T", "$T"}}; n.attr = {{"T", "$T"}};
} }
} }
// "BroadcastGradientArgs" doesn't need any attrs.
nodes.push_back({{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "sy"}});
*g = FDH::Define( *g = FDH::Define(
// Arg defs // Arg defs
{"x: T", "y: T", "dz: T"}, {"x: T", "y: T", "dz: T"},
@ -518,18 +518,14 @@ Status GradForReductionOp(FunctionDef* g, std::vector<FDH::Node> body) {
FDH::Const("zero", 0), FDH::Const("zero", 0),
FDH::Const("one", 1), FDH::Const("one", 1),
// stitch_idx0 = Range(0, x_rank, 1) // stitch_idx0 = Range(0, x_rank, 1)
{{"stitch_idx1"}, "Identity", {"i"}, {{"T", DT_INT32}}}, {{"stitch_val1"}, "Fill", {"i_shape:output:0", "one:output:0"},
{{"stitch_idx"}, "_ListToArray", {"stitch_idx0", "stitch_idx1"}, {{"T", DT_INT32}}},
{{"Tin", DataTypeSlice{DT_INT32, DT_INT32}}, {{"y_shape"}, "DynamicStitch",
{"T", DT_INT32}, {"N", 2}}}, {"stitch_idx0:output:0", "i",
{{"stitch_val0"}, "Identity", {"x_shape"}, {{"T", DT_INT32}}}, "x_shape:output:0", "stitch_val1:output:0"},
{{"stitch_val1"}, "Fill", {"i_shape", "one"}, {{"T", DT_INT32}}}, {{"N", 2}, {"T", DT_INT32}}},
{{"stitch_val"}, "_ListToArray", {"stitch_val0", "stitch_val1"}, {{"tile_scaling"}, "Div", {"x_shape:output:0", "y_shape:merged:0"},
{{"Tin", DataTypeSlice{DT_INT32, DT_INT32}}, {{"T", DT_INT32}}},
{"T", DT_INT32}, {"N", 2}}},
{{"y_shape"}, "DynamicStitch", {"stitch_idx", "stitch_val"},
{{"N", 2}, {"T", DT_INT32}}},
{{"tile_scaling"}, "Div", {"x_shape", "y_shape"}, {{"T", DT_INT32}}},
{{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}} {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
}; };
// clang-format on // clang-format on
@ -540,41 +536,46 @@ Status GradForReductionOp(FunctionDef* g, std::vector<FDH::Node> body) {
} }
} }
// "Range" doesn't need any attr. // "Range" doesn't need any attr.
nodes.push_back({{"stitch_idx0"}, "Range", {"zero", "x_rank", "one"}, {}}); nodes.push_back({{"stitch_idx0"},
*g = FDH::Define( "Range",
// Arg defs {"zero:output:0", "x_rank:output:0", "one:output:0"},
{"x:T", "i:int32", "dy:T"}, {}});
// Ret val defs *g = FDH::Create("_",
{"dx:T", "di:int32"}, // Input defs
// Attr defs {"x:T", "i:int32", "dy:T"},
{{"T: {half, float, double}"}}, // Ret val defs
// Nodes {"dx:T", "di:int32"},
nodes); // Attr defs
{{"T: {half, float, double}"}},
// Nodes
nodes,
// Return values
{{"dx", "dx:output:0"}, {"di", "di:y:0"}});
return Status::OK(); return Status::OK();
} }
Status SumGrad(const AttrSlice& attrs, FunctionDef* g) { Status SumGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off // clang-format off
return GradForReductionOp(g, { return GradForReductionOp(g, {
{{"dy_reshaped"}, "Reshape", {"dy", "y_shape"}}, {{"dy_reshaped"}, "Reshape", {"dy", "y_shape:merged:0"}},
{{"dx"}, "Tile", {"dy_reshaped", "tile_scaling"}}, {{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
}); });
// clang-format on // clang-format on
return Status::OK();
} }
REGISTER_OP_GRADIENT("Sum", SumGrad); REGISTER_OP_GRADIENT("Sum", SumGrad);
Status MeanGrad(const AttrSlice& attrs, FunctionDef* g) { Status MeanGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off // clang-format off
return GradForReductionOp(g, { return GradForReductionOp(g, {
{{"factor"}, "Prod", {"tile_scaling", "zero"}, {{"T", DT_INT32}}}, {{"factor"}, "Prod", {"tile_scaling:z:0", "zero:output:0"},
{{"factor_T"}, "Cast", {"factor"}, {{"SrcT", DT_INT32}, {"DstT", "$T"}}}, {{"T", DT_INT32}}},
{{"dy_scaled"}, "Div", {"dy", "factor_T"}}, {{"factor_T"}, "Cast", {"factor:output:0"},
{{"dy_reshaped"}, "Reshape", {"dy_scaled", "y_shape"}}, {{"SrcT", DT_INT32}, {"DstT", "$T"}}},
{{"dx"}, "Tile", {"dy_reshaped", "tile_scaling"}}, {{"dy_scaled"}, "Div", {"dy", "factor_T:y:0"}},
{{"dy_reshaped"}, "Reshape", {"dy_scaled:z:0", "y_shape:merged:0"}},
{{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
}); });
// clang-format on // clang-format on
return Status::OK();
} }
REGISTER_OP_GRADIENT("Mean", MeanGrad); REGISTER_OP_GRADIENT("Mean", MeanGrad);

View File

@ -66,7 +66,7 @@ class MathGradTest : public ::testing::Test {
{"Tin", DataTypeSlice{T, T}}, {"Tin", DataTypeSlice{T, T}},
{"Tout", DataTypeSlice{T}}, {"Tout", DataTypeSlice{T}},
}}, }},
{{"dx"}, "Identity", {"grad:0"}, {{"T", T}}}, {{"dx"}, "Identity", {"grad"}, {{"T", T}}},
}); });
// Each test case will feed in "x:0" and expects to get "dx:0". // Each test case will feed in "x:0" and expects to get "dx:0".
auto gdef = test::function::GDef( auto gdef = test::function::GDef(
@ -120,7 +120,7 @@ class MathGradTest : public ::testing::Test {
{ {
FDH::Const("one", 1), FDH::Const("one", 1),
{{"dz"}, "Cast", {"one"}, {{"DstT", T}, {"SrcT", DT_INT32}}}, {{"dz"}, "Cast", {"one"}, {{"DstT", T}, {"SrcT", DT_INT32}}},
{{"grad"}, {{"grad0", "grad1"},
"SymbolicGradient", "SymbolicGradient",
{"x", "y", "dz"}, {"x", "y", "dz"},
{ {
@ -128,8 +128,8 @@ class MathGradTest : public ::testing::Test {
{"Tin", DataTypeSlice{T, T, T}}, {"Tin", DataTypeSlice{T, T, T}},
{"Tout", DataTypeSlice{T, T}}, {"Tout", DataTypeSlice{T, T}},
}}, }},
{{"dx"}, "Identity", {"grad:0"}, {{"T", T}}}, {{"dx"}, "Identity", {"grad0"}, {{"T", T}}},
{{"dy"}, "Identity", {"grad:1"}, {{"T", T}}}, {{"dy"}, "Identity", {"grad1"}, {{"T", T}}},
}); });
// Each test case will feed in "x:0" and "y:0" and expects to get "d0" and // Each test case will feed in "x:0" and "y:0" and expects to get "d0" and
// "d:0". // "d:0".
@ -177,7 +177,7 @@ class MathGradTest : public ::testing::Test {
{ {
FDH::Const("one", 1), FDH::Const("one", 1),
{{"dy"}, "Cast", {"one"}, {{"DstT", T}, {"SrcT", DT_INT32}}}, {{"dy"}, "Cast", {"one"}, {{"DstT", T}, {"SrcT", DT_INT32}}},
{{"grad"}, {{"grad0", "grad1"},
"SymbolicGradient", "SymbolicGradient",
{"x", "i", "dy"}, {"x", "i", "dy"},
{ {
@ -185,8 +185,8 @@ class MathGradTest : public ::testing::Test {
{"Tin", DataTypeSlice{T, DT_INT32, T}}, {"Tin", DataTypeSlice{T, DT_INT32, T}},
{"Tout", DataTypeSlice{T, DT_INT32}}, {"Tout", DataTypeSlice{T, DT_INT32}},
}}, }},
{{"dx"}, "Identity", {"grad:0"}, {{"T", T}}}, {{"dx"}, "Identity", {"grad0"}, {{"T", T}}},
{{"di"}, "Identity", {"grad:1"}, {{"T", DT_INT32}}}, {{"di"}, "Identity", {"grad1"}, {{"T", DT_INT32}}},
}); });
// Each test case will feed in "x:0" and expects to get "dx:0". // Each test case will feed in "x:0" and expects to get "dx:0".
auto gdef = test::function::GDef( auto gdef = test::function::GDef(
@ -267,7 +267,7 @@ class MathGradTest : public ::testing::Test {
{ {
FDH::Const("one", 1), FDH::Const("one", 1),
{{"dz"}, "Cast", {"one"}, {{"DstT", T}, {"SrcT", DT_INT32}}}, {{"dz"}, "Cast", {"one"}, {{"DstT", T}, {"SrcT", DT_INT32}}},
{{"grad"}, {{"grad0", "grad1"},
"SymbolicGradient", "SymbolicGradient",
{"x", "y", "dz"}, {"x", "y", "dz"},
{ {
@ -275,8 +275,8 @@ class MathGradTest : public ::testing::Test {
{"Tin", DataTypeSlice{T, T, T}}, {"Tin", DataTypeSlice{T, T, T}},
{"Tout", DataTypeSlice{T, T}}, {"Tout", DataTypeSlice{T, T}},
}}, }},
{{"dx"}, "Identity", {"grad:0"}, {{"T", T}}}, {{"dx"}, "Identity", {"grad0"}, {{"T", T}}},
{{"dy"}, "Identity", {"grad:1"}, {{"T", T}}}, {{"dy"}, "Identity", {"grad1"}, {{"T", T}}},
}); });
// Each test case will feed in "x:0" and "y:0" and expects to get "d0" and // Each test case will feed in "x:0" and "y:0" and expects to get "d0" and
// "d:0". // "d:0".
@ -331,7 +331,7 @@ class MathGradTest : public ::testing::Test {
auto grad = FDH::Define("TestGrad", {"c:bool", "x:float", "y:float"}, auto grad = FDH::Define("TestGrad", {"c:bool", "x:float", "y:float"},
{"dc:bool", "dx:float", "dy:float"}, {}, {"dc:bool", "dx:float", "dy:float"}, {},
{FDH::Const("dz", 1.f), {FDH::Const("dz", 1.f),
{{"grad"}, {{"grad0", "grad1", "grad2"},
"SymbolicGradient", "SymbolicGradient",
{"c", "x", "y", "dz"}, {"c", "x", "y", "dz"},
{ {
@ -339,9 +339,9 @@ class MathGradTest : public ::testing::Test {
{"Tin", DataTypeSlice{DT_BOOL, T, T, T}}, {"Tin", DataTypeSlice{DT_BOOL, T, T, T}},
{"Tout", DataTypeSlice{DT_BOOL, T, T}}, {"Tout", DataTypeSlice{DT_BOOL, T, T}},
}}, }},
{{"dc"}, "Identity", {"grad:0"}, {{"T", DT_BOOL}}}, {{"dc"}, "Identity", {"grad0"}, {{"T", DT_BOOL}}},
{{"dx"}, "Identity", {"grad:1"}, {{"T", T}}}, {{"dx"}, "Identity", {"grad1"}, {{"T", T}}},
{{"dy"}, "Identity", {"grad:2"}, {{"T", T}}}}); {{"dy"}, "Identity", {"grad2"}, {{"T", T}}}});
// Each test case will feed in "x:0" and expects to get "dx:0". // Each test case will feed in "x:0" and expects to get "dx:0".
auto gdef = test::function::GDef( auto gdef = test::function::GDef(
{ {