Migrate FunctionDefHelper::Define() to NodeDef functions. For callers
of Define() where that doesn't work, switch to Create() instead. Change: 142718606
This commit is contained in:
parent
e14f7841de
commit
a439d9975e
@ -359,38 +359,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
|
||||
delete g;
|
||||
}
|
||||
|
||||
TEST_F(FunctionLibraryRuntimeTest, ManySwapsOld) {
|
||||
auto func = FDH::Define( // Creates a FunctionDef using FunctionDef::Nodes
|
||||
// Name
|
||||
"ManySwapsFirst",
|
||||
// Args
|
||||
{"x: float", "y: float"},
|
||||
// Return values
|
||||
{"o: float"},
|
||||
// attr def
|
||||
{},
|
||||
// Nodes
|
||||
{{{"a0", "b0"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}},
|
||||
{{"a1", "b1"}, "Swap", {"a0", "b0"}, {{"T", DT_FLOAT}}},
|
||||
{{"a2", "b2"}, "Swap", {"a1", "b1"}, {{"T", DT_FLOAT}}},
|
||||
{{"a3", "b3"}, "Swap", {"a2", "b2"}, {{"T", DT_FLOAT}}},
|
||||
{{"a4", "b4"}, "Swap", {"a3", "b3"}, {{"T", DT_FLOAT}}},
|
||||
{{"a5", "b5"}, "Swap", {"a4", "b4"}, {{"T", DT_FLOAT}}},
|
||||
{{"o"}, "Identity", {"a5"}, {{"T", DT_FLOAT}}}});
|
||||
Init({test::function::Swap(), func});
|
||||
Graph* g = GetFuncBody("ManySwapsFirst", {});
|
||||
ASSERT_TRUE(g != nullptr);
|
||||
OptimizeGraph(lib_, &g);
|
||||
const char* e0 = R"P(
|
||||
(n3:float, n2:float) -> (n3:float) {
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(e0, DebugString(g));
|
||||
delete g;
|
||||
}
|
||||
|
||||
// Like the above test, but using NodeDefs in the FunctionDef.
|
||||
TEST_F(FunctionLibraryRuntimeTest, DISABLED_ManySwapsNodeDef) {
|
||||
TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) {
|
||||
auto func = FDH::Create( // Creates a FunctionDef using NodeDefs
|
||||
// Name
|
||||
"ManySwapsNodeDef",
|
||||
@ -423,7 +392,7 @@ TEST_F(FunctionLibraryRuntimeTest, DISABLED_ManySwapsNodeDef) {
|
||||
}
|
||||
|
||||
TEST_F(FunctionLibraryRuntimeTest, ControlDeps) {
|
||||
auto func = FDH::Define(
|
||||
auto func = FDH::Create(
|
||||
// Name
|
||||
"ManySwapsFirst",
|
||||
// Args
|
||||
@ -438,11 +407,12 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) {
|
||||
// y2 depends on the 2nd swap. The 2nd swap has data dependency
|
||||
// on the 1st swap. The optimization should maintain the control
|
||||
// dependencies.
|
||||
{{{"a0", "b0"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}, {"x2"}},
|
||||
{{"a1", "b1"}, "Swap", {"a0", "b0"}, {{"T", DT_FLOAT}}},
|
||||
{{{"a0"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}, {"x2"}},
|
||||
{{"a1"}, "Swap", {"a0:o0:0", "a0:o1:0"}, {{"T", DT_FLOAT}}},
|
||||
{{"x2"}, "Mul", {"x", "x"}, {{"T", DT_FLOAT}}},
|
||||
{{"y2"}, "Mul", {"y", "y"}, {{"T", DT_FLOAT}}, {"a1"}},
|
||||
{{"o"}, "Add", {"x2", "y2"}, {{"T", DT_FLOAT}}}});
|
||||
{{"o"}, "Add", {"x2:z:0", "y2:z:0"}, {{"T", DT_FLOAT}}}},
|
||||
{{"o", "o:z:0"}});
|
||||
Init({test::function::Swap(), func});
|
||||
Graph* g = GetFuncBody("ManySwapsFirst", {});
|
||||
ASSERT_TRUE(g != nullptr);
|
||||
@ -608,7 +578,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
|
||||
auto grad =
|
||||
FDH::Define("TestGrad", {"x:float", "y:float"}, {"dx:float", "dy:float"},
|
||||
{}, {FDH::Const<float>("dz", 1),
|
||||
{{"grad"},
|
||||
{{"grad0", "grad1"},
|
||||
"SymbolicGradient",
|
||||
{"x", "y", "dz"},
|
||||
{
|
||||
@ -616,8 +586,8 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
|
||||
{"Tin", DataTypeSlice{T, T, T}},
|
||||
{"Tout", DataTypeSlice{T, T}},
|
||||
}},
|
||||
{{"dx"}, "Identity", {"grad:0"}, {{"T", DT_FLOAT}}},
|
||||
{{"dy"}, "Identity", {"grad:1"}, {{"T", DT_FLOAT}}}});
|
||||
{{"dx"}, "Identity", {"grad0"}, {{"T", DT_FLOAT}}},
|
||||
{{"dy"}, "Identity", {"grad1"}, {{"T", DT_FLOAT}}}});
|
||||
|
||||
Init({test, grad});
|
||||
|
||||
@ -660,19 +630,19 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
|
||||
OptimizeGraph(lib_, &g);
|
||||
const char* e2 = R"P(
|
||||
(n4:float, n3:float) -> (n25:float, n23:float) {
|
||||
n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]()
|
||||
n2 = Const[dtype=float, value=Tensor<type: float shape: [] values: 1>]()
|
||||
n7 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
|
||||
n8 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
|
||||
n7 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]()
|
||||
n19 = Shape[T=float, out_type=int32](n3)
|
||||
n8 = Add[T=float](n4, n3)
|
||||
n9 = Add[T=float](n4, n3)
|
||||
n20 = Shape[T=float, out_type=int32](n4)
|
||||
n9 = Rank[T=float](n8)
|
||||
n14 = Shape[T=float, out_type=int32](n8)
|
||||
n10 = Rank[T=float](n9)
|
||||
n14 = Shape[T=float, out_type=int32](n9)
|
||||
n21 = BroadcastGradientArgs[T=int32](n20, n19)
|
||||
n10 = Range[Tidx=int32](n7, n9, n11)
|
||||
n12 = Shape[T=int32, out_type=int32](n10)
|
||||
n13 = Fill[T=int32](n12, n11)
|
||||
n15 = DynamicStitch[N=2, T=int32](n10, n10, n14, n13)
|
||||
n11 = Range[Tidx=int32](n8, n10, n7)
|
||||
n12 = Shape[T=int32, out_type=int32](n11)
|
||||
n13 = Fill[T=int32](n12, n7)
|
||||
n15 = DynamicStitch[N=2, T=int32](n11, n11, n14, n13)
|
||||
n16 = Reshape[T=float, Tshape=int32](n2, n15)
|
||||
n17 = Div[T=int32](n14, n15)
|
||||
n18 = Tile[T=float, Tmultiples=int32](n16, n17)
|
||||
@ -842,33 +812,38 @@ TEST(OptimizationTest, RemoveIdentityNodes) {
|
||||
}
|
||||
|
||||
TEST(OptimizationTest, RemoveListArrayConverter) {
|
||||
auto func = FDH::Define(
|
||||
auto func = FDH::Create(
|
||||
// Name
|
||||
"Test",
|
||||
// Args
|
||||
{"i: float"},
|
||||
// Return values
|
||||
// Return signature
|
||||
{"o: float"},
|
||||
// Attrs
|
||||
{},
|
||||
// Nodes
|
||||
{FDH::Const("zero", 0),
|
||||
{{"s"}, "Split", {"zero", "i"}, {{"num_split", 4}, {"T", DT_FLOAT}}},
|
||||
{{"s"},
|
||||
"Split",
|
||||
{"zero:output:0", "i"},
|
||||
{{"num_split", 4}, {"T", DT_FLOAT}}},
|
||||
{{"a"},
|
||||
"_ArrayToList",
|
||||
{"s"},
|
||||
{"s:output"},
|
||||
{{"N", 4},
|
||||
{"T", DT_FLOAT},
|
||||
{"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}}}},
|
||||
{{"l"}, "Mul", {"a:0", "a:1"}, {{"T", DT_FLOAT}}},
|
||||
{{"r"}, "Mul", {"a:2", "a:3"}, {{"T", DT_FLOAT}}},
|
||||
{{"l"}, "Mul", {"a:output:0", "a:output:1"}, {{"T", DT_FLOAT}}},
|
||||
{{"r"}, "Mul", {"a:output:2", "a:output:3"}, {{"T", DT_FLOAT}}},
|
||||
{{"x"},
|
||||
"_ListToArray",
|
||||
{"l", "r"},
|
||||
{"l:z", "r:z"},
|
||||
{{"N", 2},
|
||||
{"T", DT_FLOAT},
|
||||
{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
|
||||
{{"o"}, "AddN", {"x"}, {{"N", 2}, {"T", DT_FLOAT}}}});
|
||||
{{"o"}, "AddN", {"x:output"}, {{"N", 2}, {"T", DT_FLOAT}}}},
|
||||
// Return values
|
||||
{{"o", "o:sum"}});
|
||||
|
||||
const char* e0 = R"P(
|
||||
(n0:float) -> (n7:float) {
|
||||
@ -916,7 +891,7 @@ TEST(OptimizationTest, RemoveListArrayConverter) {
|
||||
}
|
||||
|
||||
TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
|
||||
auto func = FDH::Define(
|
||||
auto func = FDH::Create(
|
||||
// Name
|
||||
"Test",
|
||||
// Args
|
||||
@ -935,10 +910,11 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
|
||||
{"dummy"}},
|
||||
{{"o"},
|
||||
"AddN",
|
||||
{"x"},
|
||||
{"x:output"},
|
||||
{{"N", 2}, {"T", DT_FLOAT}},
|
||||
// Control dep
|
||||
{"x"}}});
|
||||
{"x"}}},
|
||||
{{"o", "o:sum"}});
|
||||
|
||||
const char* e0 = R"P(
|
||||
(n0:float) -> (n3:float) {
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/function.pb_text.h"
|
||||
@ -1110,35 +1110,17 @@ FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
|
||||
return ret;
|
||||
}
|
||||
|
||||
FunctionDef::Node FunctionDefHelper::Node::ToProto() const {
|
||||
FunctionDef::Node n;
|
||||
for (const string& r : this->ret) {
|
||||
n.add_ret(r);
|
||||
}
|
||||
n.set_op(this->op);
|
||||
for (const string& a : arg) {
|
||||
n.add_arg(a);
|
||||
}
|
||||
for (const auto& a : this->attr) {
|
||||
n.mutable_attr()->insert({a.first, a.second.proto});
|
||||
}
|
||||
for (const string& d : dep) {
|
||||
n.add_dep(d);
|
||||
}
|
||||
return n;
|
||||
}
|
||||
|
||||
NodeDef FunctionDefHelper::Node::ToNodeDef() const {
|
||||
NodeDef n;
|
||||
n.set_op(this->op);
|
||||
n.set_name(this->ret[0]);
|
||||
for (const string& a : arg) {
|
||||
n.add_input(a);
|
||||
}
|
||||
for (const auto& a : this->attr) {
|
||||
n.mutable_attr()->insert({a.first, a.second.proto});
|
||||
}
|
||||
for (const string& d : dep) {
|
||||
for (const string& a : this->arg) {
|
||||
n.add_input(a);
|
||||
}
|
||||
for (const string& d : this->dep) {
|
||||
n.add_input(strings::StrCat("^", d));
|
||||
}
|
||||
return n;
|
||||
@ -1189,8 +1171,56 @@ FunctionDef FunctionDefHelper::Define(const string& name,
|
||||
OpRegistrationData op_reg_data;
|
||||
TF_CHECK_OK(b.Finalize(&op_reg_data));
|
||||
fdef.mutable_signature()->Swap(&op_reg_data.op_def);
|
||||
for (const auto& n : node_def) {
|
||||
*(fdef.add_node()) = n.ToProto();
|
||||
|
||||
// Mapping from legacy output names to NodeDef outputs.
|
||||
std::unordered_map<string, string> ret_index;
|
||||
for (const auto& a : fdef.signature().input_arg()) {
|
||||
ret_index[a.name()] = a.name();
|
||||
}
|
||||
|
||||
// For looking up OpDefs
|
||||
auto* op_def_registry = OpRegistry::Global();
|
||||
|
||||
// Function body
|
||||
for (const auto& src : node_def) {
|
||||
NodeDef* n = fdef.add_node_def();
|
||||
n->set_op(src.op);
|
||||
n->set_name(src.ret[0]);
|
||||
for (const auto& a : src.attr) {
|
||||
n->mutable_attr()->insert({a.first, a.second.proto});
|
||||
}
|
||||
for (const string& a : src.arg) {
|
||||
const auto iter = ret_index.find(a);
|
||||
CHECK(iter != ret_index.end()) << "Node input '" << a << "' in '"
|
||||
<< src.ret[0] << "' of " << name;
|
||||
n->add_input(iter->second);
|
||||
}
|
||||
for (const string& d : src.dep) {
|
||||
n->add_input(strings::StrCat("^", d));
|
||||
}
|
||||
|
||||
// Add the outputs of this node to ret_index.
|
||||
const OpDef* op_def = nullptr;
|
||||
TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
|
||||
CHECK(op_def != nullptr) << n->op();
|
||||
NameRangeMap output_names;
|
||||
TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
|
||||
for (const auto& o : output_names) {
|
||||
CHECK_LE(o.second.second, src.ret.size())
|
||||
<< "Missing ret for output '" << o.first << "' in '" << src.ret[0]
|
||||
<< "' of " << name;
|
||||
for (int i = o.second.first; i < o.second.second; ++i) {
|
||||
ret_index[src.ret[i]] =
|
||||
strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns
|
||||
for (const auto& r : fdef.signature().output_arg()) {
|
||||
const auto iter = ret_index.find(r.name());
|
||||
CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
|
||||
fdef.mutable_ret()->insert({r.name(), iter->second});
|
||||
}
|
||||
return fdef;
|
||||
}
|
||||
|
@ -108,7 +108,6 @@ class FunctionDefHelper {
|
||||
std::vector<std::pair<string, AttrValueWrapper>> attr;
|
||||
std::vector<string> dep;
|
||||
|
||||
FunctionDef::Node ToProto() const;
|
||||
NodeDef ToNodeDef() const;
|
||||
};
|
||||
|
||||
|
@ -71,7 +71,8 @@ TEST(TFunc, SquarePlusOneOld) {
|
||||
SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
|
||||
a = Square[T=$T](x)
|
||||
o = One[T=$T]()
|
||||
y = Add[T=$T](a, o)
|
||||
y = Add[T=$T](a:y:0, o:y:0)
|
||||
return y = y:z:0
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(DebugString(fdef), e);
|
||||
@ -91,7 +92,7 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
}
|
||||
|
||||
TEST(TFunc, DISABLED_SquarePlusOneNodeDef) {
|
||||
TEST(TFunc, SquarePlusOneNodeDef) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
// Name
|
||||
"SquarePlusOne",
|
||||
@ -137,7 +138,7 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
}
|
||||
|
||||
TEST(TFunc, DISABLED_ControlDepNodeDef) {
|
||||
TEST(TFunc, ControlDepNodeDef) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
// Name
|
||||
"ControlDep",
|
||||
@ -206,6 +207,7 @@ TEST(TFunc, MissingTypeAttrOld) {
|
||||
const char* e = R"P(
|
||||
BackCompat() -> (y:float) {
|
||||
y = HasDefaultType()
|
||||
return y = y:out:0
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(DebugString(fdef), e);
|
||||
@ -224,7 +226,7 @@ BackCompat() -> (y:float) {
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
}
|
||||
|
||||
TEST(TFunc, DISABLED_MissingTypeAttrNodeDef) {
|
||||
TEST(TFunc, MissingTypeAttrNodeDef) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
// Name
|
||||
"BackCompat",
|
||||
@ -262,7 +264,7 @@ BackCompat() -> (y:float) {
|
||||
EXPECT_EQ(DebugString(result.gdef), e2);
|
||||
}
|
||||
|
||||
TEST(TFunc, DISABLED_NTimesTNodeDef) {
|
||||
TEST(TFunc, NTimesTNodeDef) {
|
||||
// Note that the equivalent FunctionDef using FunctionDef::Node requires
|
||||
// using a _ListToArray to package up the two inputs to AddN as a single
|
||||
// N*T edge.
|
||||
@ -320,7 +322,7 @@ y: N tensors, each of type U;
|
||||
)doc");
|
||||
|
||||
TEST(TFunc, AddSquared) {
|
||||
auto fdef = FDH::Define(
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"AddSquared",
|
||||
// Args
|
||||
@ -339,12 +341,14 @@ TEST(TFunc, AddSquared) {
|
||||
{"U", "$T"},
|
||||
{"N", "$N"}}},
|
||||
// y = AddN<N=$N,T=$T>(a)
|
||||
{{"y"}, "AddN", {"a"}, {{"N", "$N"}, {"T", "$T"}}}});
|
||||
{{"y"}, "AddN", {"a:y"}, {{"N", "$N"}, {"T", "$T"}}}},
|
||||
{{"y", "y:sum"}});
|
||||
|
||||
const char* e = R"P(
|
||||
AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) {
|
||||
a = Map[N=$N, T=$T, U=$T, func=Square[T=$T]](x)
|
||||
y = AddN[N=$N, T=$T](a)
|
||||
y = AddN[N=$N, T=$T](a:y)
|
||||
return y = y:sum
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(DebugString(fdef), e);
|
||||
@ -413,8 +417,9 @@ TEST(TFunc, XTimesTwo) {
|
||||
auto expect = R"P(
|
||||
XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
|
||||
two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
|
||||
scale = Cast[DstT=$T, SrcT=int64](two)
|
||||
y = Mul[T=$T](x, scale)
|
||||
scale = Cast[DstT=$T, SrcT=int64](two:output:0)
|
||||
y = Mul[T=$T](x, scale:y:0)
|
||||
return y = y:z:0
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(expect, DebugString(test::function::XTimesTwo()));
|
||||
@ -424,7 +429,8 @@ TEST(TFunc, WXPlusB) {
|
||||
auto expect = R"P(
|
||||
WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
|
||||
mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x)
|
||||
y = Add[T=$T](mm, b)
|
||||
y = Add[T=$T](mm:product:0, b)
|
||||
return y = y:z:0
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(expect, DebugString(test::function::WXPlusB()));
|
||||
@ -432,7 +438,7 @@ WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
|
||||
|
||||
TEST(TFunc, Body_TypeList) {
|
||||
const Tensor kZero = test::AsScalar<int32>(0);
|
||||
auto fdef = FDH::Define(
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"Test",
|
||||
// Args
|
||||
@ -443,32 +449,30 @@ TEST(TFunc, Body_TypeList) {
|
||||
{},
|
||||
// Nodes
|
||||
{{{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT32}}},
|
||||
{{"s"}, "Split", {"zero", "i"}, {{"num_split", 4}, {"T", DT_FLOAT}}},
|
||||
{{"lst"},
|
||||
"_ArrayToList",
|
||||
{"s"},
|
||||
{{"N", 4},
|
||||
{"T", DT_FLOAT},
|
||||
{"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}}}},
|
||||
{{"l"}, "Mul", {"lst:0", "lst:1"}, {{"T", DT_FLOAT}}},
|
||||
{{"r"}, "Mul", {"lst:2", "lst:3"}, {{"T", DT_FLOAT}}},
|
||||
{{"s"},
|
||||
"Split",
|
||||
{"zero:output:0", "i"},
|
||||
{{"num_split", 4}, {"T", DT_FLOAT}}},
|
||||
{{"l"}, "Mul", {"s:output:0", "s:output:1"}, {{"T", DT_FLOAT}}},
|
||||
{{"r"}, "Mul", {"s:output:2", "s:output:3"}, {{"T", DT_FLOAT}}},
|
||||
{{"x"},
|
||||
"_ListToArray",
|
||||
{"l", "r"},
|
||||
{"l:z", "r:z"},
|
||||
{{"N", 2},
|
||||
{"T", DT_FLOAT},
|
||||
{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
|
||||
{{"o"}, "AddN", {"x"}, {{"N", 2}, {"T", DT_FLOAT}}}});
|
||||
{{"o"}, "AddN", {"x:output"}, {{"N", 2}, {"T", DT_FLOAT}}}},
|
||||
{{"o", "o:sum:0"}});
|
||||
|
||||
const char* e = R"P(
|
||||
Test(i:float) -> (o:float) {
|
||||
zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
|
||||
s = Split[T=float, num_split=4](zero, i)
|
||||
lst = _ArrayToList[N=4, T=float, out_types={float, float, float, float}](s)
|
||||
l = Mul[T=float](lst:0, lst:1)
|
||||
r = Mul[T=float](lst:2, lst:3)
|
||||
x = _ListToArray[N=2, T=float, Tin={float, float}](l, r)
|
||||
o = AddN[N=2, T=float](x)
|
||||
s = Split[T=float, num_split=4](zero:output:0, i)
|
||||
l = Mul[T=float](s:output:0, s:output:1)
|
||||
r = Mul[T=float](s:output:2, s:output:3)
|
||||
x = _ListToArray[N=2, T=float, Tin={float, float}](l:z, r:z)
|
||||
o = AddN[N=2, T=float](x:output)
|
||||
return o = o:sum:0
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(DebugString(fdef), e);
|
||||
@ -476,14 +480,13 @@ Test(i:float) -> (o:float) {
|
||||
InstantiationResult result;
|
||||
TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
|
||||
const char* e2 = R"P(
|
||||
(n0:float) -> (n7:float) {
|
||||
(n0:float) -> (n6:float) {
|
||||
n1 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
|
||||
n2 = Split[T=float, num_split=4](n1, n0)
|
||||
n3 = _ArrayToList[N=4, T=float, out_types={float, float, float, float}](n2, n2:1, n2:2, n2:3)
|
||||
n4 = Mul[T=float](n3, n3:1)
|
||||
n5 = Mul[T=float](n3:2, n3:3)
|
||||
n6 = _ListToArray[N=2, T=float, Tin={float, float}](n4, n5)
|
||||
n7 = AddN[N=2, T=float](n6, n6:1)
|
||||
n3 = Mul[T=float](n2, n2:1)
|
||||
n4 = Mul[T=float](n2:2, n2:3)
|
||||
n5 = _ListToArray[N=2, T=float, Tin={float, float}](n3, n4)
|
||||
n6 = AddN[N=2, T=float](n5, n5:1)
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
|
||||
@ -540,7 +543,8 @@ TEST(TFunc, Body_Array_List_Converter) {
|
||||
const char* e = R"P(
|
||||
MySelect(x:float) -> (z:float) {
|
||||
y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x)
|
||||
z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y, y)
|
||||
z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y:output:0, y:output:0)
|
||||
return z = z:output:0
|
||||
}
|
||||
)P";
|
||||
EXPECT_EQ(DebugString(fdef), e);
|
||||
@ -608,16 +612,6 @@ TEST(InstantiateErrors, DupArgs) {
|
||||
"Duplicated arg name");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, Dup_Arg_Node_Name) {
|
||||
auto fdef = FDH::Define("test", {"x:float"}, {}, {},
|
||||
{
|
||||
{{"x"}, "One", {}, {{"T", DT_FLOAT}}},
|
||||
});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"Duplicated ret name");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, Dup_Node_Names) {
|
||||
auto fdef = FDH::Define("test", {"x:float"}, {}, {},
|
||||
{
|
||||
@ -629,44 +623,25 @@ TEST(InstantiateErrors, Dup_Node_Names) {
|
||||
"Duplicated ret name");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, Node_Signature_Mismatch_NoOp) {
|
||||
auto fdef = FDH::Define("test", {"x:float"}, {}, {},
|
||||
{
|
||||
{{"y", "z"}, "NoOp", {}, {{"T", DT_FLOAT}}},
|
||||
});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"Expect one ret name");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, Node_Signature_Mismatch) {
|
||||
auto fdef = FDH::Define("test", {"x:float"}, {}, {},
|
||||
{
|
||||
{{"y", "z"}, "One", {}, {{"T", DT_FLOAT}}},
|
||||
});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"Malformed function node (#ret)");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, Node_Arg_Notfound) {
|
||||
auto fdef = FDH::Define("test", {"x:float"}, {}, {},
|
||||
auto fdef = FDH::Create("test", {"x:float"}, {}, {},
|
||||
{
|
||||
{{"y"}, "Add", {"x", "z"}, {{"T", DT_FLOAT}}},
|
||||
});
|
||||
},
|
||||
{});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"arg[1] is not found");
|
||||
"input z is not found");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, Node_Arg_Mismatch) {
|
||||
TEST(InstantiateErrors, Node_Arg_TypeMismatch) {
|
||||
auto fdef = FDH::Define("test", {"x:float"}, {}, {},
|
||||
{
|
||||
{{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}},
|
||||
});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"Invalid arg(0) for function arg");
|
||||
"input x[0] expected type int32 != float, the type of x[0]");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, Node_Arg_ControlMissing) {
|
||||
@ -677,20 +652,55 @@ TEST(InstantiateErrors, Node_Arg_ControlMissing) {
|
||||
});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"dep[0] is not found");
|
||||
"input[2] == '^z', is not found.");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, FuncRet_Missing) {
|
||||
auto fdef = FDH::Define("test", {}, {"y: float"}, {},
|
||||
auto fdef = FDH::Create("test", {}, {"y: float"}, {},
|
||||
{
|
||||
{{"x"}, "One", {}, {{"T", DT_FLOAT}}},
|
||||
});
|
||||
},
|
||||
{});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"ret is not found");
|
||||
"Return y missing");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, FuncRet_Mismatch) {
|
||||
TEST(InstantiateErrors, FuncRet_NotFound) {
|
||||
auto fdef = FDH::Create("test", {}, {"y: float"}, {},
|
||||
{
|
||||
{{"x"}, "One", {}, {{"T", DT_FLOAT}}},
|
||||
},
|
||||
{{"y", "z"}});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"Return y -> z is not found");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, FuncRet_NameMismatch) {
|
||||
auto fdef = FDH::Create("test", {}, {"y: float"}, {},
|
||||
{
|
||||
{{"x"}, "One", {}, {{"T", DT_FLOAT}}},
|
||||
},
|
||||
{{"z", "x:y:0"}});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"Return y missing");
|
||||
}
|
||||
|
||||
// TODO(josh11b): Make this an error.
|
||||
// TEST(InstantiateErrors, FuncRet_Extra) {
|
||||
// auto fdef = FDH::Create("test", {}, {"y: float"}, {},
|
||||
// {
|
||||
// {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
|
||||
// },
|
||||
// {{"y", "x:y:0"}, {"z", "x:y:0"}});
|
||||
// InstantiationResult result;
|
||||
// HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
// "ret is not found");
|
||||
// }
|
||||
|
||||
TEST(InstantiateErrors, FuncRet_TypeMismatch) {
|
||||
auto fdef = FDH::Define("test", {}, {"y: float"}, {},
|
||||
{
|
||||
{{"y"}, "One", {}, {{"T", DT_DOUBLE}}},
|
||||
@ -701,7 +711,7 @@ TEST(InstantiateErrors, FuncRet_Mismatch) {
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) {
|
||||
auto fdef = FDH::Define(
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"MySelect",
|
||||
// Args
|
||||
@ -719,14 +729,15 @@ TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) {
|
||||
{"cond", FDH::FunctionRef("MyCond2")},
|
||||
{"then_branch", FDH::FunctionRef("MyThen2")},
|
||||
{"else_branch", FDH::FunctionRef("MyElse2")}}},
|
||||
});
|
||||
},
|
||||
{{"y", "y:output"}});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"type attr not found: out_types");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) {
|
||||
auto fdef = FDH::Define(
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"MySelect",
|
||||
// Args
|
||||
@ -745,14 +756,15 @@ TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) {
|
||||
{"cond", FDH::FunctionRef("MyCond2")},
|
||||
{"then_branch", FDH::FunctionRef("MyThen2")},
|
||||
{"else_branch", FDH::FunctionRef("MyElse2")}}},
|
||||
});
|
||||
},
|
||||
{{"y", "y:output"}});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"Invalid ret types");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, TypeList_Missing_Arg) {
|
||||
auto fdef = FDH::Define(
|
||||
auto fdef = FDH::Create(
|
||||
// Name
|
||||
"MySelect",
|
||||
// Args
|
||||
@ -771,13 +783,14 @@ TEST(InstantiateErrors, TypeList_Missing_Arg) {
|
||||
{"cond", FDH::FunctionRef("MyCond2")},
|
||||
{"then_branch", FDH::FunctionRef("MyThen2")},
|
||||
{"else_branch", FDH::FunctionRef("MyElse2")}}},
|
||||
});
|
||||
},
|
||||
{{"y", "y:output"}});
|
||||
InstantiationResult result;
|
||||
HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
|
||||
"arg[1] is not found");
|
||||
"input unknown is not found");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputs) {
|
||||
TEST(InstantiateErrors, NodeDef_TooManyInputs) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
// Name
|
||||
"TooManyInputs",
|
||||
@ -798,7 +811,7 @@ TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputs) {
|
||||
"Expected input[2] == 'x' to be a control input.");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, DISABLED_NodeDef_TooFewInputs) {
|
||||
TEST(InstantiateErrors, NodeDef_TooFewInputs) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
// Name
|
||||
"TooFewInputs",
|
||||
@ -819,7 +832,7 @@ TEST(InstantiateErrors, DISABLED_NodeDef_TooFewInputs) {
|
||||
"Attempt to access beyond input size: 2 >= 2");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputsFromArray1) {
|
||||
TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray1) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
// Name
|
||||
"TooManyInputsFromArray",
|
||||
@ -847,7 +860,7 @@ TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputsFromArray1) {
|
||||
"Expected input[1] == 'y' to be a control input.");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputsFromArray2) {
|
||||
TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray2) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
// Name
|
||||
"TooManyInputsFromArray",
|
||||
@ -875,7 +888,7 @@ TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputsFromArray2) {
|
||||
"Input a:output too long for inputs");
|
||||
}
|
||||
|
||||
TEST(InstantiateErrors, DISABLED_NodeDef_TypeMismatch) {
|
||||
TEST(InstantiateErrors, NodeDef_TypeMismatch) {
|
||||
auto fdef = FDH::Create( // Create a FunctionDef using NodeDefs.
|
||||
// Name
|
||||
"TypeMismatch",
|
||||
@ -968,8 +981,9 @@ TEST(FunctionLibraryDefinitionTest, Find) {
|
||||
auto expect = R"P(
|
||||
XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
|
||||
two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
|
||||
scale = Cast[DstT=$T, SrcT=int64](two)
|
||||
y = Mul[T=$T](x, scale)
|
||||
scale = Cast[DstT=$T, SrcT=int64](two:output:0)
|
||||
y = Mul[T=$T](x, scale:y:0)
|
||||
return y = y:z:0
|
||||
}
|
||||
)P";
|
||||
auto found = lib_def.Find("XTimesTwo");
|
||||
|
@ -91,7 +91,7 @@ FunctionDef XTimesTwo() {
|
||||
}
|
||||
|
||||
FunctionDef XTimesFour() {
|
||||
return FDH::Define(
|
||||
return FDH::Create(
|
||||
// Name
|
||||
"XTimesFour",
|
||||
// Args
|
||||
@ -103,12 +103,13 @@ FunctionDef XTimesFour() {
|
||||
// Nodes
|
||||
{
|
||||
{{"x2"}, "XTimesTwo", {"x"}, {{"T", "$T"}}},
|
||||
{{"y"}, "XTimesTwo", {"x2"}, {{"T", "$T"}}},
|
||||
});
|
||||
{{"y"}, "XTimesTwo", {"x2:y:0"}, {{"T", "$T"}}},
|
||||
},
|
||||
{{"y", "y:y:0"}});
|
||||
}
|
||||
|
||||
FunctionDef XTimes16() {
|
||||
return FDH::Define(
|
||||
return FDH::Create(
|
||||
// Name
|
||||
"XTimes16",
|
||||
// Args
|
||||
@ -120,8 +121,9 @@ FunctionDef XTimes16() {
|
||||
// Nodes
|
||||
{
|
||||
{{"x4"}, "XTimesFour", {"x"}, {{"T", "$T"}}},
|
||||
{{"y"}, "XTimesFour", {"x4"}, {{"T", "$T"}}},
|
||||
});
|
||||
{{"y"}, "XTimesFour", {"x4:y:0"}, {{"T", "$T"}}},
|
||||
},
|
||||
{{"y", "y:y:0"}});
|
||||
}
|
||||
|
||||
FunctionDef WXPlusB() {
|
||||
|
@ -90,7 +90,8 @@ REGISTER_OP_GRADIENT("Identity", IdentityGrad);
|
||||
|
||||
Status PackGrad(const AttrSlice& attrs, FunctionDef* g) {
|
||||
// clang-format off
|
||||
*g = FDH::Define(
|
||||
*g = FDH::Create(
|
||||
"_",
|
||||
// Arg defs
|
||||
{"x: N*T", "dy: T"},
|
||||
// Ret val defs
|
||||
@ -105,7 +106,8 @@ Status PackGrad(const AttrSlice& attrs, FunctionDef* g) {
|
||||
{"dy"},
|
||||
{{"T", "$T"}, {"num", "$N"}, {"axis", "$axis"}}
|
||||
},
|
||||
});
|
||||
},
|
||||
{{"dx", "dx:output"}});
|
||||
// clang-format on
|
||||
VLOG(1) << "PackGrad " << DebugString(*g);
|
||||
return Status::OK();
|
||||
@ -147,9 +149,9 @@ Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g,
|
||||
std::vector<string> offset_i;
|
||||
std::vector<string> dx_i;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
shape_i.push_back(strings::StrCat("shape_lst:", i));
|
||||
offset_i.push_back(strings::StrCat("offset_lst:", i));
|
||||
dx_i.push_back(strings::StrCat("dx_", i));
|
||||
shape_i.push_back(strings::StrCat("shapes:output:", i));
|
||||
offset_i.push_back(strings::StrCat("offset:offset:", i));
|
||||
dx_i.push_back(strings::StrCat("dx_", i, ":output:0"));
|
||||
}
|
||||
DataTypeVector dtype_list(N, T);
|
||||
|
||||
@ -160,19 +162,7 @@ Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g,
|
||||
// which is the same as dx[i]'s offset within dy.
|
||||
std::vector<FDH::Node> nodes{
|
||||
{{"shapes"}, "ShapeN", {"x"}, {{"T", "$T"}, {"N", "$N"}}},
|
||||
{{"shape_lst"},
|
||||
"_ArrayToList",
|
||||
{"shapes"},
|
||||
{{"T", DT_INT32},
|
||||
{"N", "$N"},
|
||||
{"out_types", DataTypeVector(N, DT_INT32)}}},
|
||||
{{"offset"}, "ConcatOffset", {"dim", "shapes"}, {{"N", "$N"}}},
|
||||
{{"offset_lst"},
|
||||
"_ArrayToList",
|
||||
{"offset"},
|
||||
{{"T", DT_INT32},
|
||||
{"N", "$N"},
|
||||
{"out_types", DataTypeVector(N, DT_INT32)}}},
|
||||
{{"offset"}, "ConcatOffset", {"dim", "shapes:output"}, {{"N", "$N"}}},
|
||||
{{"d_dim"}, "ZerosLike", {"dim"}, {{"T", DT_INT32}}},
|
||||
{{"dx"},
|
||||
"_ListToArray",
|
||||
@ -182,34 +172,40 @@ Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g,
|
||||
// For each dx[i], we take a slice of dy. The offset and size of the
|
||||
// slice is given by offset[i] and shape[i].
|
||||
for (int i = 0; i < N; ++i) {
|
||||
nodes.push_back({{dx_i[i]},
|
||||
nodes.push_back({{strings::StrCat("dx_", i)},
|
||||
"Slice",
|
||||
{"dy", offset_i[i], shape_i[i]},
|
||||
{{"T", "$T"}, {"Index", DT_INT32}}});
|
||||
}
|
||||
if (dim_is_last_arg) {
|
||||
// clang-format off
|
||||
*g = FDH::Define(
|
||||
*g = FDH::Create(
|
||||
"_",
|
||||
// Arg defs
|
||||
{"x: N*T", "dim: int32", "dy: T"},
|
||||
// Ret val defs
|
||||
// Return signature
|
||||
{"dx: N*T", "d_dim: int32"},
|
||||
// Attr defs
|
||||
{"T: type", "N: int"},
|
||||
// Nodes
|
||||
nodes);
|
||||
nodes,
|
||||
// Return values
|
||||
{{"dx", "dx:output"}, {"d_dim", "d_dim:y:0"}});
|
||||
// clang-format on
|
||||
} else {
|
||||
// clang-format off
|
||||
*g = FDH::Define(
|
||||
*g = FDH::Create(
|
||||
"_",
|
||||
// Arg defs
|
||||
{"dim: int32", "x: N*T", "dy: T"},
|
||||
// Ret val defs
|
||||
// Return signature
|
||||
{"d_dim: int32", "dx: N*T"},
|
||||
// Attr defs
|
||||
{"T: type", "N: int"},
|
||||
// Nodes
|
||||
nodes);
|
||||
nodes,
|
||||
// Return values
|
||||
{{"dx", "dx:output"}, {"d_dim", "d_dim:y:0"}});
|
||||
// clang-format on
|
||||
}
|
||||
VLOG(1) << "ConcatGrad " << DebugString(*g);
|
||||
@ -399,15 +395,9 @@ Status SliceGrad(const AttrSlice& attrs, FunctionDef* g) {
|
||||
{{"xs_b"}, "Sub", {"xs", "begin"}, {{"T", DT_INT32}}},
|
||||
{{"xs_b_s"}, "Sub", {"xs_b", "size"}, {{"T", DT_INT32}}},
|
||||
{{"a1"}, "ExpandDims", {"xs_b_s", "one"}, {{"T", DT_INT32}}},
|
||||
{{"b_and_a"},
|
||||
"_ListToArray",
|
||||
{"b1", "a1"},
|
||||
{{"T", DT_INT32},
|
||||
{"N", 2},
|
||||
{"Tin", DataTypeVector{DT_INT32, DT_INT32}}}},
|
||||
{{"paddings"},
|
||||
"Concat",
|
||||
{"one", "b_and_a"},
|
||||
{"one", "b1", "a1"},
|
||||
{{"N", 2}, {"T", DT_INT32}}},
|
||||
// dx = Pad(dy, paddings)
|
||||
{{"dx"}, "Pad", {"dy", "paddings"}, {{"T", "$T"}}},
|
||||
|
@ -314,6 +314,7 @@ Status GradForBinaryCwise(FunctionDef* g, std::vector<FDH::Node> body) {
|
||||
};
|
||||
nodes.insert(nodes.end(), body.begin(), body.end());
|
||||
std::vector<FDH::Node> reshapes = {
|
||||
{{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "sy"}},
|
||||
{{"sum_gx"}, "Sum", {"gx", "rx"}},
|
||||
{{"dx"}, "Reshape", {"sum_gx", "sx"}},
|
||||
{{"sum_gy"}, "Sum", {"gy", "ry"}},
|
||||
@ -323,12 +324,11 @@ Status GradForBinaryCwise(FunctionDef* g, std::vector<FDH::Node> body) {
|
||||
|
||||
// clang-format on
|
||||
for (auto& n : nodes) {
|
||||
if (n.attr.empty()) {
|
||||
// "BroadcastGradientArgs" doesn't need any attrs.
|
||||
if (n.attr.empty() && n.op != "BroadcastGradientArgs") {
|
||||
n.attr = {{"T", "$T"}};
|
||||
}
|
||||
}
|
||||
// "BroadcastGradientArgs" doesn't need any attrs.
|
||||
nodes.push_back({{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "sy"}});
|
||||
*g = FDH::Define(
|
||||
// Arg defs
|
||||
{"x: T", "y: T", "dz: T"},
|
||||
@ -518,18 +518,14 @@ Status GradForReductionOp(FunctionDef* g, std::vector<FDH::Node> body) {
|
||||
FDH::Const("zero", 0),
|
||||
FDH::Const("one", 1),
|
||||
// stitch_idx0 = Range(0, x_rank, 1)
|
||||
{{"stitch_idx1"}, "Identity", {"i"}, {{"T", DT_INT32}}},
|
||||
{{"stitch_idx"}, "_ListToArray", {"stitch_idx0", "stitch_idx1"},
|
||||
{{"Tin", DataTypeSlice{DT_INT32, DT_INT32}},
|
||||
{"T", DT_INT32}, {"N", 2}}},
|
||||
{{"stitch_val0"}, "Identity", {"x_shape"}, {{"T", DT_INT32}}},
|
||||
{{"stitch_val1"}, "Fill", {"i_shape", "one"}, {{"T", DT_INT32}}},
|
||||
{{"stitch_val"}, "_ListToArray", {"stitch_val0", "stitch_val1"},
|
||||
{{"Tin", DataTypeSlice{DT_INT32, DT_INT32}},
|
||||
{"T", DT_INT32}, {"N", 2}}},
|
||||
{{"y_shape"}, "DynamicStitch", {"stitch_idx", "stitch_val"},
|
||||
{{"stitch_val1"}, "Fill", {"i_shape:output:0", "one:output:0"},
|
||||
{{"T", DT_INT32}}},
|
||||
{{"y_shape"}, "DynamicStitch",
|
||||
{"stitch_idx0:output:0", "i",
|
||||
"x_shape:output:0", "stitch_val1:output:0"},
|
||||
{{"N", 2}, {"T", DT_INT32}}},
|
||||
{{"tile_scaling"}, "Div", {"x_shape", "y_shape"}, {{"T", DT_INT32}}},
|
||||
{{"tile_scaling"}, "Div", {"x_shape:output:0", "y_shape:merged:0"},
|
||||
{{"T", DT_INT32}}},
|
||||
{{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
|
||||
};
|
||||
// clang-format on
|
||||
@ -540,41 +536,46 @@ Status GradForReductionOp(FunctionDef* g, std::vector<FDH::Node> body) {
|
||||
}
|
||||
}
|
||||
// "Range" doesn't need any attr.
|
||||
nodes.push_back({{"stitch_idx0"}, "Range", {"zero", "x_rank", "one"}, {}});
|
||||
*g = FDH::Define(
|
||||
// Arg defs
|
||||
nodes.push_back({{"stitch_idx0"},
|
||||
"Range",
|
||||
{"zero:output:0", "x_rank:output:0", "one:output:0"},
|
||||
{}});
|
||||
*g = FDH::Create("_",
|
||||
// Input defs
|
||||
{"x:T", "i:int32", "dy:T"},
|
||||
// Ret val defs
|
||||
{"dx:T", "di:int32"},
|
||||
// Attr defs
|
||||
{{"T: {half, float, double}"}},
|
||||
// Nodes
|
||||
nodes);
|
||||
nodes,
|
||||
// Return values
|
||||
{{"dx", "dx:output:0"}, {"di", "di:y:0"}});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SumGrad(const AttrSlice& attrs, FunctionDef* g) {
|
||||
// clang-format off
|
||||
return GradForReductionOp(g, {
|
||||
{{"dy_reshaped"}, "Reshape", {"dy", "y_shape"}},
|
||||
{{"dx"}, "Tile", {"dy_reshaped", "tile_scaling"}},
|
||||
{{"dy_reshaped"}, "Reshape", {"dy", "y_shape:merged:0"}},
|
||||
{{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
|
||||
});
|
||||
// clang-format on
|
||||
return Status::OK();
|
||||
}
|
||||
REGISTER_OP_GRADIENT("Sum", SumGrad);
|
||||
|
||||
Status MeanGrad(const AttrSlice& attrs, FunctionDef* g) {
|
||||
// clang-format off
|
||||
return GradForReductionOp(g, {
|
||||
{{"factor"}, "Prod", {"tile_scaling", "zero"}, {{"T", DT_INT32}}},
|
||||
{{"factor_T"}, "Cast", {"factor"}, {{"SrcT", DT_INT32}, {"DstT", "$T"}}},
|
||||
{{"dy_scaled"}, "Div", {"dy", "factor_T"}},
|
||||
{{"dy_reshaped"}, "Reshape", {"dy_scaled", "y_shape"}},
|
||||
{{"dx"}, "Tile", {"dy_reshaped", "tile_scaling"}},
|
||||
{{"factor"}, "Prod", {"tile_scaling:z:0", "zero:output:0"},
|
||||
{{"T", DT_INT32}}},
|
||||
{{"factor_T"}, "Cast", {"factor:output:0"},
|
||||
{{"SrcT", DT_INT32}, {"DstT", "$T"}}},
|
||||
{{"dy_scaled"}, "Div", {"dy", "factor_T:y:0"}},
|
||||
{{"dy_reshaped"}, "Reshape", {"dy_scaled:z:0", "y_shape:merged:0"}},
|
||||
{{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
|
||||
});
|
||||
// clang-format on
|
||||
return Status::OK();
|
||||
}
|
||||
REGISTER_OP_GRADIENT("Mean", MeanGrad);
|
||||
|
||||
|
@ -66,7 +66,7 @@ class MathGradTest : public ::testing::Test {
|
||||
{"Tin", DataTypeSlice{T, T}},
|
||||
{"Tout", DataTypeSlice{T}},
|
||||
}},
|
||||
{{"dx"}, "Identity", {"grad:0"}, {{"T", T}}},
|
||||
{{"dx"}, "Identity", {"grad"}, {{"T", T}}},
|
||||
});
|
||||
// Each test case will feed in "x:0" and expects to get "dx:0".
|
||||
auto gdef = test::function::GDef(
|
||||
@ -120,7 +120,7 @@ class MathGradTest : public ::testing::Test {
|
||||
{
|
||||
FDH::Const("one", 1),
|
||||
{{"dz"}, "Cast", {"one"}, {{"DstT", T}, {"SrcT", DT_INT32}}},
|
||||
{{"grad"},
|
||||
{{"grad0", "grad1"},
|
||||
"SymbolicGradient",
|
||||
{"x", "y", "dz"},
|
||||
{
|
||||
@ -128,8 +128,8 @@ class MathGradTest : public ::testing::Test {
|
||||
{"Tin", DataTypeSlice{T, T, T}},
|
||||
{"Tout", DataTypeSlice{T, T}},
|
||||
}},
|
||||
{{"dx"}, "Identity", {"grad:0"}, {{"T", T}}},
|
||||
{{"dy"}, "Identity", {"grad:1"}, {{"T", T}}},
|
||||
{{"dx"}, "Identity", {"grad0"}, {{"T", T}}},
|
||||
{{"dy"}, "Identity", {"grad1"}, {{"T", T}}},
|
||||
});
|
||||
// Each test case will feed in "x:0" and "y:0" and expects to get "d0" and
|
||||
// "d:0".
|
||||
@ -177,7 +177,7 @@ class MathGradTest : public ::testing::Test {
|
||||
{
|
||||
FDH::Const("one", 1),
|
||||
{{"dy"}, "Cast", {"one"}, {{"DstT", T}, {"SrcT", DT_INT32}}},
|
||||
{{"grad"},
|
||||
{{"grad0", "grad1"},
|
||||
"SymbolicGradient",
|
||||
{"x", "i", "dy"},
|
||||
{
|
||||
@ -185,8 +185,8 @@ class MathGradTest : public ::testing::Test {
|
||||
{"Tin", DataTypeSlice{T, DT_INT32, T}},
|
||||
{"Tout", DataTypeSlice{T, DT_INT32}},
|
||||
}},
|
||||
{{"dx"}, "Identity", {"grad:0"}, {{"T", T}}},
|
||||
{{"di"}, "Identity", {"grad:1"}, {{"T", DT_INT32}}},
|
||||
{{"dx"}, "Identity", {"grad0"}, {{"T", T}}},
|
||||
{{"di"}, "Identity", {"grad1"}, {{"T", DT_INT32}}},
|
||||
});
|
||||
// Each test case will feed in "x:0" and expects to get "dx:0".
|
||||
auto gdef = test::function::GDef(
|
||||
@ -267,7 +267,7 @@ class MathGradTest : public ::testing::Test {
|
||||
{
|
||||
FDH::Const("one", 1),
|
||||
{{"dz"}, "Cast", {"one"}, {{"DstT", T}, {"SrcT", DT_INT32}}},
|
||||
{{"grad"},
|
||||
{{"grad0", "grad1"},
|
||||
"SymbolicGradient",
|
||||
{"x", "y", "dz"},
|
||||
{
|
||||
@ -275,8 +275,8 @@ class MathGradTest : public ::testing::Test {
|
||||
{"Tin", DataTypeSlice{T, T, T}},
|
||||
{"Tout", DataTypeSlice{T, T}},
|
||||
}},
|
||||
{{"dx"}, "Identity", {"grad:0"}, {{"T", T}}},
|
||||
{{"dy"}, "Identity", {"grad:1"}, {{"T", T}}},
|
||||
{{"dx"}, "Identity", {"grad0"}, {{"T", T}}},
|
||||
{{"dy"}, "Identity", {"grad1"}, {{"T", T}}},
|
||||
});
|
||||
// Each test case will feed in "x:0" and "y:0" and expects to get "d0" and
|
||||
// "d:0".
|
||||
@ -331,7 +331,7 @@ class MathGradTest : public ::testing::Test {
|
||||
auto grad = FDH::Define("TestGrad", {"c:bool", "x:float", "y:float"},
|
||||
{"dc:bool", "dx:float", "dy:float"}, {},
|
||||
{FDH::Const("dz", 1.f),
|
||||
{{"grad"},
|
||||
{{"grad0", "grad1", "grad2"},
|
||||
"SymbolicGradient",
|
||||
{"c", "x", "y", "dz"},
|
||||
{
|
||||
@ -339,9 +339,9 @@ class MathGradTest : public ::testing::Test {
|
||||
{"Tin", DataTypeSlice{DT_BOOL, T, T, T}},
|
||||
{"Tout", DataTypeSlice{DT_BOOL, T, T}},
|
||||
}},
|
||||
{{"dc"}, "Identity", {"grad:0"}, {{"T", DT_BOOL}}},
|
||||
{{"dx"}, "Identity", {"grad:1"}, {{"T", T}}},
|
||||
{{"dy"}, "Identity", {"grad:2"}, {{"T", T}}}});
|
||||
{{"dc"}, "Identity", {"grad0"}, {{"T", DT_BOOL}}},
|
||||
{{"dx"}, "Identity", {"grad1"}, {{"T", T}}},
|
||||
{{"dy"}, "Identity", {"grad2"}, {{"T", T}}}});
|
||||
// Each test case will feed in "x:0" and expects to get "dx:0".
|
||||
auto gdef = test::function::GDef(
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user