Migrate FunctionDefHelper::Define() to NodeDef functions. For callers of
Define() where that doesn't work, switch to Create() instead.

Change: 142718606
parent e14f7841de
commit a439d9975e
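Before the hunks, a short orientation note. Define() keeps the legacy convention where a node's inputs refer to other nodes by bare names and output indices (such as "grad:0"); the Create() overload this commit switches callers to takes the same signature arguments plus the body as NodeDefs, whose inputs must use the canonical "node:output_name:index" form, and an explicit list of return-value bindings. A minimal sketch assembled from the SquarePlusOne test that appears below (FDH is the tests' alias for FunctionDefHelper; the overload shape is taken from the call sites in this diff):

```cpp
#include "tensorflow/core/framework/function.h"

using FDH = tensorflow::FunctionDefHelper;

// Create() = signature + NodeDef body + explicit return-value mapping.
tensorflow::FunctionDef SquarePlusOne() {
  return FDH::Create(
      // Name
      "SquarePlusOne",
      // Inputs
      {"x: T"},
      // Outputs
      {"y: T"},
      // Attrs
      {"T: {float, double, int32, int64}"},
      // Nodes: inputs name concrete op outputs ("a:y:0"), not bare nodes.
      {{{"a"}, "Square", {"x"}, {{"T", "$T"}}},
       {{"o"}, "One", {}, {{"T", "$T"}}},
       {{"y"}, "Add", {"a:y:0", "o:y:0"}, {{"T", "$T"}}}},
      // Return values: function output -> node output.
      {{"y", "y:z:0"}});
}
```

With Define(), the same body would end after the node list and the Add inputs would be written as bare {"a", "o"}; the helper now resolves those legacy names itself, as the function.cc hunk below shows.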
@@ -359,38 +359,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
   delete g;
 }
 
-TEST_F(FunctionLibraryRuntimeTest, ManySwapsOld) {
-  auto func = FDH::Define(  // Creates a FunctionDef using FunctionDef::Nodes
-      // Name
-      "ManySwapsFirst",
-      // Args
-      {"x: float", "y: float"},
-      // Return values
-      {"o: float"},
-      // attr def
-      {},
-      // Nodes
-      {{{"a0", "b0"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}},
-       {{"a1", "b1"}, "Swap", {"a0", "b0"}, {{"T", DT_FLOAT}}},
-       {{"a2", "b2"}, "Swap", {"a1", "b1"}, {{"T", DT_FLOAT}}},
-       {{"a3", "b3"}, "Swap", {"a2", "b2"}, {{"T", DT_FLOAT}}},
-       {{"a4", "b4"}, "Swap", {"a3", "b3"}, {{"T", DT_FLOAT}}},
-       {{"a5", "b5"}, "Swap", {"a4", "b4"}, {{"T", DT_FLOAT}}},
-       {{"o"}, "Identity", {"a5"}, {{"T", DT_FLOAT}}}});
-  Init({test::function::Swap(), func});
-  Graph* g = GetFuncBody("ManySwapsFirst", {});
-  ASSERT_TRUE(g != nullptr);
-  OptimizeGraph(lib_, &g);
-  const char* e0 = R"P(
-(n3:float, n2:float) -> (n3:float) {
-}
-)P";
-  EXPECT_EQ(e0, DebugString(g));
-  delete g;
-}
-
-// Like the above test, but using NodeDefs in the FunctionDef.
-TEST_F(FunctionLibraryRuntimeTest, DISABLED_ManySwapsNodeDef) {
+TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) {
   auto func = FDH::Create(  // Creates a FunctionDef using NodeDefs
       // Name
       "ManySwapsNodeDef",
@@ -423,7 +392,7 @@ TEST_F(FunctionLibraryRuntimeTest, DISABLED_ManySwapsNodeDef) {
 }
 
 TEST_F(FunctionLibraryRuntimeTest, ControlDeps) {
-  auto func = FDH::Define(
+  auto func = FDH::Create(
       // Name
       "ManySwapsFirst",
       // Args
@@ -438,11 +407,12 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) {
       // y2 depends on the 2nd swap. The 2nd swap has data dependency
       // on the 1st swap. The optimization should maintain the control
       // dependencies.
-      {{{"a0", "b0"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}, {"x2"}},
-       {{"a1", "b1"}, "Swap", {"a0", "b0"}, {{"T", DT_FLOAT}}},
+      {{{"a0"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}, {"x2"}},
+       {{"a1"}, "Swap", {"a0:o0:0", "a0:o1:0"}, {{"T", DT_FLOAT}}},
        {{"x2"}, "Mul", {"x", "x"}, {{"T", DT_FLOAT}}},
        {{"y2"}, "Mul", {"y", "y"}, {{"T", DT_FLOAT}}, {"a1"}},
-       {{"o"}, "Add", {"x2", "y2"}, {{"T", DT_FLOAT}}}});
+       {{"o"}, "Add", {"x2:z:0", "y2:z:0"}, {{"T", DT_FLOAT}}}},
+      {{"o", "o:z:0"}});
   Init({test::function::Swap(), func});
   Graph* g = GetFuncBody("ManySwapsFirst", {});
   ASSERT_TRUE(g != nullptr);
@@ -608,7 +578,7 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
   auto grad =
       FDH::Define("TestGrad", {"x:float", "y:float"}, {"dx:float", "dy:float"},
                   {}, {FDH::Const<float>("dz", 1),
-                       {{"grad"},
+                       {{"grad0", "grad1"},
                         "SymbolicGradient",
                         {"x", "y", "dz"},
                         {
@@ -616,8 +586,8 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
                          {"Tin", DataTypeSlice{T, T, T}},
                          {"Tout", DataTypeSlice{T, T}},
                         }},
-                       {{"dx"}, "Identity", {"grad:0"}, {{"T", DT_FLOAT}}},
-                       {{"dy"}, "Identity", {"grad:1"}, {{"T", DT_FLOAT}}}});
+                       {{"dx"}, "Identity", {"grad0"}, {{"T", DT_FLOAT}}},
+                       {{"dy"}, "Identity", {"grad1"}, {{"T", DT_FLOAT}}}});
 
   Init({test, grad});
 
@@ -660,19 +630,19 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
   OptimizeGraph(lib_, &g);
   const char* e2 = R"P(
 (n4:float, n3:float) -> (n25:float, n23:float) {
-  n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]()
   n2 = Const[dtype=float, value=Tensor<type: float shape: [] values: 1>]()
-  n7 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
+  n8 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
+  n7 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]()
   n19 = Shape[T=float, out_type=int32](n3)
-  n8 = Add[T=float](n4, n3)
+  n9 = Add[T=float](n4, n3)
   n20 = Shape[T=float, out_type=int32](n4)
-  n9 = Rank[T=float](n8)
-  n14 = Shape[T=float, out_type=int32](n8)
+  n10 = Rank[T=float](n9)
+  n14 = Shape[T=float, out_type=int32](n9)
   n21 = BroadcastGradientArgs[T=int32](n20, n19)
-  n10 = Range[Tidx=int32](n7, n9, n11)
-  n12 = Shape[T=int32, out_type=int32](n10)
-  n13 = Fill[T=int32](n12, n11)
-  n15 = DynamicStitch[N=2, T=int32](n10, n10, n14, n13)
+  n11 = Range[Tidx=int32](n8, n10, n7)
+  n12 = Shape[T=int32, out_type=int32](n11)
+  n13 = Fill[T=int32](n12, n7)
+  n15 = DynamicStitch[N=2, T=int32](n11, n11, n14, n13)
   n16 = Reshape[T=float, Tshape=int32](n2, n15)
   n17 = Div[T=int32](n14, n15)
   n18 = Tile[T=float, Tmultiples=int32](n16, n17)
@@ -842,33 +812,38 @@ TEST(OptimizationTest, RemoveIdentityNodes) {
 }
 
 TEST(OptimizationTest, RemoveListArrayConverter) {
-  auto func = FDH::Define(
+  auto func = FDH::Create(
       // Name
       "Test",
       // Args
       {"i: float"},
-      // Return values
+      // Return signature
       {"o: float"},
       // Attrs
       {},
       // Nodes
      {FDH::Const("zero", 0),
-      {{"s"}, "Split", {"zero", "i"}, {{"num_split", 4}, {"T", DT_FLOAT}}},
+      {{"s"},
+       "Split",
+       {"zero:output:0", "i"},
+       {{"num_split", 4}, {"T", DT_FLOAT}}},
       {{"a"},
        "_ArrayToList",
-       {"s"},
+       {"s:output"},
        {{"N", 4},
         {"T", DT_FLOAT},
         {"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}}}},
-      {{"l"}, "Mul", {"a:0", "a:1"}, {{"T", DT_FLOAT}}},
-      {{"r"}, "Mul", {"a:2", "a:3"}, {{"T", DT_FLOAT}}},
+      {{"l"}, "Mul", {"a:output:0", "a:output:1"}, {{"T", DT_FLOAT}}},
+      {{"r"}, "Mul", {"a:output:2", "a:output:3"}, {{"T", DT_FLOAT}}},
       {{"x"},
        "_ListToArray",
-       {"l", "r"},
+       {"l:z", "r:z"},
        {{"N", 2},
         {"T", DT_FLOAT},
         {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
-      {{"o"}, "AddN", {"x"}, {{"N", 2}, {"T", DT_FLOAT}}}});
+      {{"o"}, "AddN", {"x:output"}, {{"N", 2}, {"T", DT_FLOAT}}}},
+      // Return values
+      {{"o", "o:sum"}});
 
   const char* e0 = R"P(
 (n0:float) -> (n7:float) {
@@ -916,7 +891,7 @@ TEST(OptimizationTest, RemoveListArrayConverter) {
 }
 
 TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
-  auto func = FDH::Define(
+  auto func = FDH::Create(
       // Name
       "Test",
       // Args
@@ -935,10 +910,11 @@ TEST(OptimizationTest, RemoveListArrayConverter_WithContolDeps) {
        {"dummy"}},
       {{"o"},
        "AddN",
-       {"x"},
+       {"x:output"},
        {{"N", 2}, {"T", DT_FLOAT}},
        // Control dep
-       {"x"}}});
+       {"x"}}},
+      {{"o", "o:sum"}});
 
   const char* e0 = R"P(
 (n0:float) -> (n3:float) {
@@ -15,7 +15,7 @@ limitations under the License.
 
 #include "tensorflow/core/framework/function.h"
 
-#include <unordered_set>
+#include <unordered_map>
 #include <vector>
 
 #include "tensorflow/core/framework/function.pb_text.h"
@@ -1110,35 +1110,17 @@ FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
   return ret;
 }
 
-FunctionDef::Node FunctionDefHelper::Node::ToProto() const {
-  FunctionDef::Node n;
-  for (const string& r : this->ret) {
-    n.add_ret(r);
-  }
-  n.set_op(this->op);
-  for (const string& a : arg) {
-    n.add_arg(a);
-  }
-  for (const auto& a : this->attr) {
-    n.mutable_attr()->insert({a.first, a.second.proto});
-  }
-  for (const string& d : dep) {
-    n.add_dep(d);
-  }
-  return n;
-}
-
 NodeDef FunctionDefHelper::Node::ToNodeDef() const {
   NodeDef n;
   n.set_op(this->op);
   n.set_name(this->ret[0]);
-  for (const string& a : arg) {
-    n.add_input(a);
-  }
   for (const auto& a : this->attr) {
     n.mutable_attr()->insert({a.first, a.second.proto});
   }
-  for (const string& d : dep) {
+  for (const string& a : this->arg) {
+    n.add_input(a);
+  }
+  for (const string& d : this->dep) {
     n.add_input(strings::StrCat("^", d));
   }
   return n;
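A note on the ToNodeDef() change above: the loop over arg moves so that data inputs are serialized before control dependencies. As a sketch, the helper node {{"y2"}, "Mul", {"y", "y"}, {{"T", DT_FLOAT}}, {"a1"}} from the ControlDeps test earlier in this diff would come out roughly as this NodeDef (text-proto layout assumed for illustration):

```
name: "y2"
op: "Mul"
input: "y"
input: "y"
input: "^a1"   # control deps are emitted last, prefixed with '^'
attr { key: "T" value { type: DT_FLOAT } }
```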
@@ -1189,8 +1171,56 @@ FunctionDef FunctionDefHelper::Define(const string& name,
   OpRegistrationData op_reg_data;
   TF_CHECK_OK(b.Finalize(&op_reg_data));
   fdef.mutable_signature()->Swap(&op_reg_data.op_def);
-  for (const auto& n : node_def) {
-    *(fdef.add_node()) = n.ToProto();
+
+  // Mapping from legacy output names to NodeDef outputs.
+  std::unordered_map<string, string> ret_index;
+  for (const auto& a : fdef.signature().input_arg()) {
+    ret_index[a.name()] = a.name();
+  }
+
+  // For looking up OpDefs
+  auto* op_def_registry = OpRegistry::Global();
+
+  // Function body
+  for (const auto& src : node_def) {
+    NodeDef* n = fdef.add_node_def();
+    n->set_op(src.op);
+    n->set_name(src.ret[0]);
+    for (const auto& a : src.attr) {
+      n->mutable_attr()->insert({a.first, a.second.proto});
+    }
+    for (const string& a : src.arg) {
+      const auto iter = ret_index.find(a);
+      CHECK(iter != ret_index.end()) << "Node input '" << a << "' in '"
+                                     << src.ret[0] << "' of " << name;
+      n->add_input(iter->second);
+    }
+    for (const string& d : src.dep) {
+      n->add_input(strings::StrCat("^", d));
+    }
+
+    // Add the outputs of this node to ret_index.
+    const OpDef* op_def = nullptr;
+    TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
+    CHECK(op_def != nullptr) << n->op();
+    NameRangeMap output_names;
+    TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
+    for (const auto& o : output_names) {
+      CHECK_LE(o.second.second, src.ret.size())
+          << "Missing ret for output '" << o.first << "' in '" << src.ret[0]
+          << "' of " << name;
+      for (int i = o.second.first; i < o.second.second; ++i) {
+        ret_index[src.ret[i]] =
+            strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
+      }
+    }
+  }
+
+  // Returns
+  for (const auto& r : fdef.signature().output_arg()) {
+    const auto iter = ret_index.find(r.name());
+    CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
+    fdef.mutable_ret()->insert({r.name(), iter->second});
   }
   return fdef;
 }
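The ret_index map built above is what lets existing Define() callers keep writing bare legacy names: every node input is looked up and replaced by its canonical endpoint, and each new node registers its outputs under the legacy ret names. For the XTimesTwo body in the test expectations below, the resolution works out to (output names as they appear in this commit's expected debug strings):

```
two   -> two:output:0   // Const's single output is named "output"
scale -> scale:y:0      // Cast's output is named "y"
y     -> y:z:0          // Mul's output is named "z"
```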
@@ -108,7 +108,6 @@ class FunctionDefHelper {
     std::vector<std::pair<string, AttrValueWrapper>> attr;
     std::vector<string> dep;
 
-    FunctionDef::Node ToProto() const;
     NodeDef ToNodeDef() const;
   };
 
@@ -71,7 +71,8 @@ TEST(TFunc, SquarePlusOneOld) {
 SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
   a = Square[T=$T](x)
   o = One[T=$T]()
-  y = Add[T=$T](a, o)
+  y = Add[T=$T](a:y:0, o:y:0)
+  return y = y:z:0
 }
 )P";
   EXPECT_EQ(DebugString(fdef), e);
@@ -91,7 +92,7 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
   EXPECT_EQ(DebugString(result.gdef), e2);
 }
 
-TEST(TFunc, DISABLED_SquarePlusOneNodeDef) {
+TEST(TFunc, SquarePlusOneNodeDef) {
   auto fdef = FDH::Create(  // Create a FunctionDef using NodeDefs.
       // Name
       "SquarePlusOne",
@@ -137,7 +138,7 @@ SquarePlusOne[T:{float, double, int32, int64}](x:T) -> (y:T) {
   EXPECT_EQ(DebugString(result.gdef), e2);
 }
 
-TEST(TFunc, DISABLED_ControlDepNodeDef) {
+TEST(TFunc, ControlDepNodeDef) {
   auto fdef = FDH::Create(  // Create a FunctionDef using NodeDefs.
       // Name
       "ControlDep",
@@ -206,6 +207,7 @@ TEST(TFunc, MissingTypeAttrOld) {
   const char* e = R"P(
 BackCompat() -> (y:float) {
   y = HasDefaultType()
+  return y = y:out:0
 }
 )P";
   EXPECT_EQ(DebugString(fdef), e);
@@ -224,7 +226,7 @@ BackCompat() -> (y:float) {
   EXPECT_EQ(DebugString(result.gdef), e2);
 }
 
-TEST(TFunc, DISABLED_MissingTypeAttrNodeDef) {
+TEST(TFunc, MissingTypeAttrNodeDef) {
   auto fdef = FDH::Create(  // Create a FunctionDef using NodeDefs.
       // Name
       "BackCompat",
@@ -262,7 +264,7 @@ BackCompat() -> (y:float) {
   EXPECT_EQ(DebugString(result.gdef), e2);
 }
 
-TEST(TFunc, DISABLED_NTimesTNodeDef) {
+TEST(TFunc, NTimesTNodeDef) {
   // Note that the equivalent FunctionDef using FunctionDef::Node requires
   // using a _ListToArray to package up the two inputs to AddN as a single
   // N*T edge.
@@ -320,7 +322,7 @@ y: N tensors, each of type U;
 )doc");
 
 TEST(TFunc, AddSquared) {
-  auto fdef = FDH::Define(
+  auto fdef = FDH::Create(
       // Name
       "AddSquared",
       // Args
@@ -339,12 +341,14 @@ TEST(TFunc, AddSquared) {
         {"U", "$T"},
         {"N", "$N"}}},
       // y = AddN<N=$N,T=$T>(a)
-      {{"y"}, "AddN", {"a"}, {{"N", "$N"}, {"T", "$T"}}}});
+      {{"y"}, "AddN", {"a:y"}, {{"N", "$N"}, {"T", "$T"}}}},
+      {{"y", "y:sum"}});
 
   const char* e = R"P(
 AddSquared[N:int, T:{float, double, int32, int64}](x:N*T) -> (y:T) {
   a = Map[N=$N, T=$T, U=$T, func=Square[T=$T]](x)
-  y = AddN[N=$N, T=$T](a)
+  y = AddN[N=$N, T=$T](a:y)
+  return y = y:sum
 }
 )P";
   EXPECT_EQ(DebugString(fdef), e);
@@ -413,8 +417,9 @@ TEST(TFunc, XTimesTwo) {
   auto expect = R"P(
 XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
   two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
-  scale = Cast[DstT=$T, SrcT=int64](two)
-  y = Mul[T=$T](x, scale)
+  scale = Cast[DstT=$T, SrcT=int64](two:output:0)
+  y = Mul[T=$T](x, scale:y:0)
+  return y = y:z:0
 }
 )P";
   EXPECT_EQ(expect, DebugString(test::function::XTimesTwo()));
@@ -424,7 +429,8 @@ TEST(TFunc, WXPlusB) {
   auto expect = R"P(
 WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
   mm = MatMul[T=$T, _kernel="eigen", transpose_a=false, transpose_b=false](w, x)
-  y = Add[T=$T](mm, b)
+  y = Add[T=$T](mm:product:0, b)
+  return y = y:z:0
 }
 )P";
   EXPECT_EQ(expect, DebugString(test::function::WXPlusB()));
@@ -432,7 +438,7 @@ WXPlusB[T:{float, double}](w:T, x:T, b:T) -> (y:T) {
 
 TEST(TFunc, Body_TypeList) {
   const Tensor kZero = test::AsScalar<int32>(0);
-  auto fdef = FDH::Define(
+  auto fdef = FDH::Create(
       // Name
       "Test",
       // Args
@@ -443,32 +449,30 @@ TEST(TFunc, Body_TypeList) {
       {},
       // Nodes
       {{{"zero"}, "Const", {}, {{"value", kZero}, {"dtype", DT_INT32}}},
-       {{"s"}, "Split", {"zero", "i"}, {{"num_split", 4}, {"T", DT_FLOAT}}},
-       {{"lst"},
-        "_ArrayToList",
-        {"s"},
-        {{"N", 4},
-         {"T", DT_FLOAT},
-         {"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}}}},
-       {{"l"}, "Mul", {"lst:0", "lst:1"}, {{"T", DT_FLOAT}}},
-       {{"r"}, "Mul", {"lst:2", "lst:3"}, {{"T", DT_FLOAT}}},
+       {{"s"},
+        "Split",
+        {"zero:output:0", "i"},
+        {{"num_split", 4}, {"T", DT_FLOAT}}},
+       {{"l"}, "Mul", {"s:output:0", "s:output:1"}, {{"T", DT_FLOAT}}},
+       {{"r"}, "Mul", {"s:output:2", "s:output:3"}, {{"T", DT_FLOAT}}},
        {{"x"},
         "_ListToArray",
-        {"l", "r"},
+        {"l:z", "r:z"},
         {{"N", 2},
          {"T", DT_FLOAT},
          {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
-       {{"o"}, "AddN", {"x"}, {{"N", 2}, {"T", DT_FLOAT}}}});
+       {{"o"}, "AddN", {"x:output"}, {{"N", 2}, {"T", DT_FLOAT}}}},
+      {{"o", "o:sum:0"}});
 
   const char* e = R"P(
 Test(i:float) -> (o:float) {
   zero = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
-  s = Split[T=float, num_split=4](zero, i)
-  lst = _ArrayToList[N=4, T=float, out_types={float, float, float, float}](s)
-  l = Mul[T=float](lst:0, lst:1)
-  r = Mul[T=float](lst:2, lst:3)
-  x = _ListToArray[N=2, T=float, Tin={float, float}](l, r)
-  o = AddN[N=2, T=float](x)
+  s = Split[T=float, num_split=4](zero:output:0, i)
+  l = Mul[T=float](s:output:0, s:output:1)
+  r = Mul[T=float](s:output:2, s:output:3)
+  x = _ListToArray[N=2, T=float, Tin={float, float}](l:z, r:z)
+  o = AddN[N=2, T=float](x:output)
+  return o = o:sum:0
 }
 )P";
   EXPECT_EQ(DebugString(fdef), e);
@@ -476,14 +480,13 @@ Test(i:float) -> (o:float) {
   InstantiationResult result;
   TF_ASSERT_OK(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result));
   const char* e2 = R"P(
-(n0:float) -> (n7:float) {
+(n0:float) -> (n6:float) {
   n1 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
   n2 = Split[T=float, num_split=4](n1, n0)
-  n3 = _ArrayToList[N=4, T=float, out_types={float, float, float, float}](n2, n2:1, n2:2, n2:3)
-  n4 = Mul[T=float](n3, n3:1)
-  n5 = Mul[T=float](n3:2, n3:3)
-  n6 = _ListToArray[N=2, T=float, Tin={float, float}](n4, n5)
-  n7 = AddN[N=2, T=float](n6, n6:1)
+  n3 = Mul[T=float](n2, n2:1)
+  n4 = Mul[T=float](n2:2, n2:3)
+  n5 = _ListToArray[N=2, T=float, Tin={float, float}](n3, n4)
+  n6 = AddN[N=2, T=float](n5, n5:1)
 }
 )P";
   EXPECT_EQ(result.arg_types, DataTypeVector({DT_FLOAT}));
@@ -540,7 +543,8 @@ TEST(TFunc, Body_Array_List_Converter) {
   const char* e = R"P(
 MySelect(x:float) -> (z:float) {
   y = Cond[Tin={float}, cond=MyCond, else_branch=MyElse, out_types={float}, then_branch=MyThen](x)
-  z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y, y)
+  z = Cond[Tin={float, float}, cond=MyCond2, else_branch=MyElse2, out_types={float}, then_branch=MyThen2](y:output:0, y:output:0)
+  return z = z:output:0
 }
 )P";
   EXPECT_EQ(DebugString(fdef), e);
@@ -608,16 +612,6 @@ TEST(InstantiateErrors, DupArgs) {
            "Duplicated arg name");
 }
 
-TEST(InstantiateErrors, Dup_Arg_Node_Name) {
-  auto fdef = FDH::Define("test", {"x:float"}, {}, {},
-                          {
-                              {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
-                          });
-  InstantiationResult result;
-  HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
-           "Duplicated ret name");
-}
-
 TEST(InstantiateErrors, Dup_Node_Names) {
   auto fdef = FDH::Define("test", {"x:float"}, {}, {},
                           {
@@ -629,44 +623,25 @@ TEST(InstantiateErrors, Dup_Node_Names) {
            "Duplicated ret name");
 }
 
-TEST(InstantiateErrors, Node_Signature_Mismatch_NoOp) {
-  auto fdef = FDH::Define("test", {"x:float"}, {}, {},
-                          {
-                              {{"y", "z"}, "NoOp", {}, {{"T", DT_FLOAT}}},
-                          });
-  InstantiationResult result;
-  HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
-           "Expect one ret name");
-}
-
-TEST(InstantiateErrors, Node_Signature_Mismatch) {
-  auto fdef = FDH::Define("test", {"x:float"}, {}, {},
-                          {
-                              {{"y", "z"}, "One", {}, {{"T", DT_FLOAT}}},
-                          });
-  InstantiationResult result;
-  HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
-           "Malformed function node (#ret)");
-}
-
 TEST(InstantiateErrors, Node_Arg_Notfound) {
-  auto fdef = FDH::Define("test", {"x:float"}, {}, {},
+  auto fdef = FDH::Create("test", {"x:float"}, {}, {},
                           {
                               {{"y"}, "Add", {"x", "z"}, {{"T", DT_FLOAT}}},
-                          });
+                          },
+                          {});
   InstantiationResult result;
   HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
-           "arg[1] is not found");
+           "input z is not found");
 }
 
-TEST(InstantiateErrors, Node_Arg_Mismatch) {
+TEST(InstantiateErrors, Node_Arg_TypeMismatch) {
   auto fdef = FDH::Define("test", {"x:float"}, {}, {},
                           {
                               {{"y"}, "Add", {"x", "x"}, {{"T", DT_INT32}}},
                           });
   InstantiationResult result;
   HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
-           "Invalid arg(0) for function arg");
+           "input x[0] expected type int32 != float, the type of x[0]");
 }
 
 TEST(InstantiateErrors, Node_Arg_ControlMissing) {
@@ -677,20 +652,55 @@ TEST(InstantiateErrors, Node_Arg_ControlMissing) {
                           });
   InstantiationResult result;
   HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
-           "dep[0] is not found");
+           "input[2] == '^z', is not found.");
 }
 
 TEST(InstantiateErrors, FuncRet_Missing) {
-  auto fdef = FDH::Define("test", {}, {"y: float"}, {},
+  auto fdef = FDH::Create("test", {}, {"y: float"}, {},
                           {
                               {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
-                          });
+                          },
+                          {});
   InstantiationResult result;
   HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
-           "ret is not found");
+           "Return y missing");
 }
 
-TEST(InstantiateErrors, FuncRet_Mismatch) {
+TEST(InstantiateErrors, FuncRet_NotFound) {
+  auto fdef = FDH::Create("test", {}, {"y: float"}, {},
+                          {
+                              {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
+                          },
+                          {{"y", "z"}});
+  InstantiationResult result;
+  HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+           "Return y -> z is not found");
+}
+
+TEST(InstantiateErrors, FuncRet_NameMismatch) {
+  auto fdef = FDH::Create("test", {}, {"y: float"}, {},
+                          {
+                              {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
+                          },
+                          {{"z", "x:y:0"}});
+  InstantiationResult result;
+  HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+           "Return y missing");
+}
+
+// TODO(josh11b): Make this an error.
+// TEST(InstantiateErrors, FuncRet_Extra) {
+//   auto fdef = FDH::Create("test", {}, {"y: float"}, {},
+//                           {
+//                               {{"x"}, "One", {}, {{"T", DT_FLOAT}}},
+//                           },
+//                           {{"y", "x:y:0"}, {"z", "x:y:0"}});
+//   InstantiationResult result;
+//   HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
+//            "ret is not found");
+// }
+
+TEST(InstantiateErrors, FuncRet_TypeMismatch) {
   auto fdef = FDH::Define("test", {}, {"y: float"}, {},
                           {
                               {{"y"}, "One", {}, {{"T", DT_DOUBLE}}},
@@ -701,7 +711,7 @@ TEST(InstantiateErrors, FuncRet_Mismatch) {
 }
 
 TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) {
-  auto fdef = FDH::Define(
+  auto fdef = FDH::Create(
       // Name
       "MySelect",
       // Args
@@ -719,14 +729,15 @@ TEST(InstantiateErrors, TypeList_Missing_Retval_Attr) {
            {"cond", FDH::FunctionRef("MyCond2")},
            {"then_branch", FDH::FunctionRef("MyThen2")},
            {"else_branch", FDH::FunctionRef("MyElse2")}}},
-      });
+      },
+      {{"y", "y:output"}});
   InstantiationResult result;
   HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
            "type attr not found: out_types");
 }
 
 TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) {
-  auto fdef = FDH::Define(
+  auto fdef = FDH::Create(
       // Name
       "MySelect",
       // Args
@@ -745,14 +756,15 @@ TEST(InstantiateErrors, TypeList_Num_Retval_Mismatch) {
            {"cond", FDH::FunctionRef("MyCond2")},
            {"then_branch", FDH::FunctionRef("MyThen2")},
            {"else_branch", FDH::FunctionRef("MyElse2")}}},
-      });
+      },
+      {{"y", "y:output"}});
   InstantiationResult result;
   HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
            "Invalid ret types");
 }
 
 TEST(InstantiateErrors, TypeList_Missing_Arg) {
-  auto fdef = FDH::Define(
+  auto fdef = FDH::Create(
       // Name
       "MySelect",
       // Args
@@ -771,13 +783,14 @@ TEST(InstantiateErrors, TypeList_Missing_Arg) {
            {"cond", FDH::FunctionRef("MyCond2")},
            {"then_branch", FDH::FunctionRef("MyThen2")},
            {"else_branch", FDH::FunctionRef("MyElse2")}}},
-      });
+      },
+      {{"y", "y:output"}});
   InstantiationResult result;
   HasError(InstantiateFunction(fdef, kNoAttrs, GetOpSig, &result),
-           "arg[1] is not found");
+           "input unknown is not found");
 }
 
-TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputs) {
+TEST(InstantiateErrors, NodeDef_TooManyInputs) {
   auto fdef = FDH::Create(  // Create a FunctionDef using NodeDefs.
       // Name
       "TooManyInputs",
@@ -798,7 +811,7 @@ TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputs) {
            "Expected input[2] == 'x' to be a control input.");
 }
 
-TEST(InstantiateErrors, DISABLED_NodeDef_TooFewInputs) {
+TEST(InstantiateErrors, NodeDef_TooFewInputs) {
   auto fdef = FDH::Create(  // Create a FunctionDef using NodeDefs.
       // Name
       "TooFewInputs",
@@ -819,7 +832,7 @@ TEST(InstantiateErrors, DISABLED_NodeDef_TooFewInputs) {
            "Attempt to access beyond input size: 2 >= 2");
 }
 
-TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputsFromArray1) {
+TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray1) {
   auto fdef = FDH::Create(  // Create a FunctionDef using NodeDefs.
       // Name
       "TooManyInputsFromArray",
@@ -847,7 +860,7 @@ TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputsFromArray1) {
            "Expected input[1] == 'y' to be a control input.");
 }
 
-TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputsFromArray2) {
+TEST(InstantiateErrors, NodeDef_TooManyInputsFromArray2) {
   auto fdef = FDH::Create(  // Create a FunctionDef using NodeDefs.
       // Name
       "TooManyInputsFromArray",
@@ -875,7 +888,7 @@ TEST(InstantiateErrors, DISABLED_NodeDef_TooManyInputsFromArray2) {
            "Input a:output too long for inputs");
 }
 
-TEST(InstantiateErrors, DISABLED_NodeDef_TypeMismatch) {
+TEST(InstantiateErrors, NodeDef_TypeMismatch) {
   auto fdef = FDH::Create(  // Create a FunctionDef using NodeDefs.
       // Name
       "TypeMismatch",
@@ -968,8 +981,9 @@ TEST(FunctionLibraryDefinitionTest, Find) {
   auto expect = R"P(
 XTimesTwo[T:{float, double, int32, int64}](x:T) -> (y:T) {
   two = Const[dtype=int64, value=Tensor<type: int64 shape: [] values: 2>]()
-  scale = Cast[DstT=$T, SrcT=int64](two)
-  y = Mul[T=$T](x, scale)
+  scale = Cast[DstT=$T, SrcT=int64](two:output:0)
+  y = Mul[T=$T](x, scale:y:0)
+  return y = y:z:0
 }
 )P";
   auto found = lib_def.Find("XTimesTwo");
@@ -91,7 +91,7 @@ FunctionDef XTimesTwo() {
 }
 
 FunctionDef XTimesFour() {
-  return FDH::Define(
+  return FDH::Create(
       // Name
       "XTimesFour",
       // Args
@@ -103,12 +103,13 @@ FunctionDef XTimesFour() {
       // Nodes
       {
           {{"x2"}, "XTimesTwo", {"x"}, {{"T", "$T"}}},
-          {{"y"}, "XTimesTwo", {"x2"}, {{"T", "$T"}}},
-      });
+          {{"y"}, "XTimesTwo", {"x2:y:0"}, {{"T", "$T"}}},
+      },
+      {{"y", "y:y:0"}});
 }
 
 FunctionDef XTimes16() {
-  return FDH::Define(
+  return FDH::Create(
       // Name
       "XTimes16",
       // Args
@@ -120,8 +121,9 @@ FunctionDef XTimes16() {
       // Nodes
       {
           {{"x4"}, "XTimesFour", {"x"}, {{"T", "$T"}}},
-          {{"y"}, "XTimesFour", {"x4"}, {{"T", "$T"}}},
-      });
+          {{"y"}, "XTimesFour", {"x4:y:0"}, {{"T", "$T"}}},
+      },
+      {{"y", "y:y:0"}});
 }
 
 FunctionDef WXPlusB() {
@@ -90,7 +90,8 @@ REGISTER_OP_GRADIENT("Identity", IdentityGrad);
 
 Status PackGrad(const AttrSlice& attrs, FunctionDef* g) {
   // clang-format off
-  *g = FDH::Define(
+  *g = FDH::Create(
+      "_",
       // Arg defs
       {"x: N*T", "dy: T"},
       // Ret val defs
@@ -105,7 +106,8 @@ Status PackGrad(const AttrSlice& attrs, FunctionDef* g) {
          {"dy"},
          {{"T", "$T"}, {"num", "$N"}, {"axis", "$axis"}}
         },
-      });
+      },
+      {{"dx", "dx:output"}});
   // clang-format on
   VLOG(1) << "PackGrad " << DebugString(*g);
   return Status::OK();
@@ -147,9 +149,9 @@ Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g,
   std::vector<string> offset_i;
   std::vector<string> dx_i;
   for (int i = 0; i < N; ++i) {
-    shape_i.push_back(strings::StrCat("shape_lst:", i));
-    offset_i.push_back(strings::StrCat("offset_lst:", i));
-    dx_i.push_back(strings::StrCat("dx_", i));
+    shape_i.push_back(strings::StrCat("shapes:output:", i));
+    offset_i.push_back(strings::StrCat("offset:offset:", i));
+    dx_i.push_back(strings::StrCat("dx_", i, ":output:0"));
   }
   DataTypeVector dtype_list(N, T);
 
@@ -160,19 +162,7 @@ Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g,
   // which is the same as dx[i]'s offset within dy.
   std::vector<FDH::Node> nodes{
       {{"shapes"}, "ShapeN", {"x"}, {{"T", "$T"}, {"N", "$N"}}},
-      {{"shape_lst"},
-       "_ArrayToList",
-       {"shapes"},
-       {{"T", DT_INT32},
-        {"N", "$N"},
-        {"out_types", DataTypeVector(N, DT_INT32)}}},
-      {{"offset"}, "ConcatOffset", {"dim", "shapes"}, {{"N", "$N"}}},
-      {{"offset_lst"},
-       "_ArrayToList",
-       {"offset"},
-       {{"T", DT_INT32},
-        {"N", "$N"},
-        {"out_types", DataTypeVector(N, DT_INT32)}}},
+      {{"offset"}, "ConcatOffset", {"dim", "shapes:output"}, {{"N", "$N"}}},
       {{"d_dim"}, "ZerosLike", {"dim"}, {{"T", DT_INT32}}},
       {{"dx"},
        "_ListToArray",
@@ -182,34 +172,40 @@ Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g,
   // For each dx[i], we take a slice of dy. The offset and size of the
   // slice is given by offset[i] and shape[i].
   for (int i = 0; i < N; ++i) {
-    nodes.push_back({{dx_i[i]},
+    nodes.push_back({{strings::StrCat("dx_", i)},
                      "Slice",
                      {"dy", offset_i[i], shape_i[i]},
                      {{"T", "$T"}, {"Index", DT_INT32}}});
   }
   if (dim_is_last_arg) {
     // clang-format off
-    *g = FDH::Define(
+    *g = FDH::Create(
+        "_",
         // Arg defs
         {"x: N*T", "dim: int32", "dy: T"},
-        // Ret val defs
+        // Return signature
         {"dx: N*T", "d_dim: int32"},
         // Attr defs
         {"T: type", "N: int"},
         // Nodes
-        nodes);
+        nodes,
+        // Return values
+        {{"dx", "dx:output"}, {"d_dim", "d_dim:y:0"}});
     // clang-format on
   } else {
     // clang-format off
-    *g = FDH::Define(
+    *g = FDH::Create(
+        "_",
         // Arg defs
         {"dim: int32", "x: N*T", "dy: T"},
-        // Ret val defs
+        // Return signature
         {"d_dim: int32", "dx: N*T"},
         // Attr defs
         {"T: type", "N: int"},
         // Nodes
-        nodes);
+        nodes,
+        // Return values
+        {{"dx", "dx:output"}, {"d_dim", "d_dim:y:0"}});
     // clang-format on
   }
   VLOG(1) << "ConcatGrad " << DebugString(*g);
@@ -399,15 +395,9 @@ Status SliceGrad(const AttrSlice& attrs, FunctionDef* g) {
       {{"xs_b"}, "Sub", {"xs", "begin"}, {{"T", DT_INT32}}},
       {{"xs_b_s"}, "Sub", {"xs_b", "size"}, {{"T", DT_INT32}}},
       {{"a1"}, "ExpandDims", {"xs_b_s", "one"}, {{"T", DT_INT32}}},
-      {{"b_and_a"},
-       "_ListToArray",
-       {"b1", "a1"},
-       {{"T", DT_INT32},
-        {"N", 2},
-        {"Tin", DataTypeVector{DT_INT32, DT_INT32}}}},
       {{"paddings"},
        "Concat",
-       {"one", "b_and_a"},
+       {"one", "b1", "a1"},
        {{"N", 2}, {"T", DT_INT32}}},
       // dx = Pad(dy, paddings)
       {{"dx"}, "Pad", {"dy", "paddings"}, {{"T", "$T"}}},
@@ -314,6 +314,7 @@ Status GradForBinaryCwise(FunctionDef* g, std::vector<FDH::Node> body) {
   };
   nodes.insert(nodes.end(), body.begin(), body.end());
   std::vector<FDH::Node> reshapes = {
+      {{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "sy"}},
       {{"sum_gx"}, "Sum", {"gx", "rx"}},
       {{"dx"}, "Reshape", {"sum_gx", "sx"}},
       {{"sum_gy"}, "Sum", {"gy", "ry"}},
@@ -323,12 +324,11 @@ Status GradForBinaryCwise(FunctionDef* g, std::vector<FDH::Node> body) {
 
   // clang-format on
   for (auto& n : nodes) {
-    if (n.attr.empty()) {
+    // "BroadcastGradientArgs" doesn't need any attrs.
+    if (n.attr.empty() && n.op != "BroadcastGradientArgs") {
       n.attr = {{"T", "$T"}};
     }
   }
-  // "BroadcastGradientArgs" doesn't need any attrs.
-  nodes.push_back({{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "sy"}});
   *g = FDH::Define(
       // Arg defs
       {"x: T", "y: T", "dz: T"},
@@ -518,18 +518,14 @@ Status GradForReductionOp(FunctionDef* g, std::vector<FDH::Node> body) {
       FDH::Const("zero", 0),
       FDH::Const("one", 1),
       // stitch_idx0 = Range(0, x_rank, 1)
-      {{"stitch_idx1"}, "Identity", {"i"}, {{"T", DT_INT32}}},
-      {{"stitch_idx"}, "_ListToArray", {"stitch_idx0", "stitch_idx1"},
-       {{"Tin", DataTypeSlice{DT_INT32, DT_INT32}},
-        {"T", DT_INT32}, {"N", 2}}},
-      {{"stitch_val0"}, "Identity", {"x_shape"}, {{"T", DT_INT32}}},
-      {{"stitch_val1"}, "Fill", {"i_shape", "one"}, {{"T", DT_INT32}}},
-      {{"stitch_val"}, "_ListToArray", {"stitch_val0", "stitch_val1"},
-       {{"Tin", DataTypeSlice{DT_INT32, DT_INT32}},
-        {"T", DT_INT32}, {"N", 2}}},
-      {{"y_shape"}, "DynamicStitch", {"stitch_idx", "stitch_val"},
+      {{"stitch_val1"}, "Fill", {"i_shape:output:0", "one:output:0"},
+       {{"T", DT_INT32}}},
+      {{"y_shape"}, "DynamicStitch",
+       {"stitch_idx0:output:0", "i",
+        "x_shape:output:0", "stitch_val1:output:0"},
        {{"N", 2}, {"T", DT_INT32}}},
-      {{"tile_scaling"}, "Div", {"x_shape", "y_shape"}, {{"T", DT_INT32}}},
+      {{"tile_scaling"}, "Div", {"x_shape:output:0", "y_shape:merged:0"},
+       {{"T", DT_INT32}}},
       {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
   };
   // clang-format on
@@ -540,41 +536,46 @@ Status GradForReductionOp(FunctionDef* g, std::vector<FDH::Node> body) {
     }
   }
   // "Range" doesn't need any attr.
-  nodes.push_back({{"stitch_idx0"}, "Range", {"zero", "x_rank", "one"}, {}});
-  *g = FDH::Define(
-      // Arg defs
+  nodes.push_back({{"stitch_idx0"},
+                   "Range",
+                   {"zero:output:0", "x_rank:output:0", "one:output:0"},
+                   {}});
+  *g = FDH::Create("_",
+                   // Input defs
                    {"x:T", "i:int32", "dy:T"},
                    // Ret val defs
                    {"dx:T", "di:int32"},
                    // Attr defs
                    {{"T: {half, float, double}"}},
                    // Nodes
-                   nodes);
+                   nodes,
+                   // Return values
+                   {{"dx", "dx:output:0"}, {"di", "di:y:0"}});
   return Status::OK();
 }
 
 Status SumGrad(const AttrSlice& attrs, FunctionDef* g) {
   // clang-format off
   return GradForReductionOp(g, {
-      {{"dy_reshaped"}, "Reshape", {"dy", "y_shape"}},
-      {{"dx"}, "Tile", {"dy_reshaped", "tile_scaling"}},
+      {{"dy_reshaped"}, "Reshape", {"dy", "y_shape:merged:0"}},
+      {{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
   });
   // clang-format on
-  return Status::OK();
 }
 REGISTER_OP_GRADIENT("Sum", SumGrad);
 
 Status MeanGrad(const AttrSlice& attrs, FunctionDef* g) {
   // clang-format off
   return GradForReductionOp(g, {
-      {{"factor"}, "Prod", {"tile_scaling", "zero"}, {{"T", DT_INT32}}},
-      {{"factor_T"}, "Cast", {"factor"}, {{"SrcT", DT_INT32}, {"DstT", "$T"}}},
-      {{"dy_scaled"}, "Div", {"dy", "factor_T"}},
-      {{"dy_reshaped"}, "Reshape", {"dy_scaled", "y_shape"}},
-      {{"dx"}, "Tile", {"dy_reshaped", "tile_scaling"}},
+      {{"factor"}, "Prod", {"tile_scaling:z:0", "zero:output:0"},
+       {{"T", DT_INT32}}},
+      {{"factor_T"}, "Cast", {"factor:output:0"},
+       {{"SrcT", DT_INT32}, {"DstT", "$T"}}},
+      {{"dy_scaled"}, "Div", {"dy", "factor_T:y:0"}},
+      {{"dy_reshaped"}, "Reshape", {"dy_scaled:z:0", "y_shape:merged:0"}},
+      {{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
   });
   // clang-format on
-  return Status::OK();
 }
 REGISTER_OP_GRADIENT("Mean", MeanGrad);
 
@@ -66,7 +66,7 @@ class MathGradTest : public ::testing::Test {
           {"Tin", DataTypeSlice{T, T}},
           {"Tout", DataTypeSlice{T}},
          }},
-        {{"dx"}, "Identity", {"grad:0"}, {{"T", T}}},
+        {{"dx"}, "Identity", {"grad"}, {{"T", T}}},
        });
   // Each test case will feed in "x:0" and expects to get "dx:0".
   auto gdef = test::function::GDef(
@@ -120,7 +120,7 @@ class MathGradTest : public ::testing::Test {
        {
         FDH::Const("one", 1),
         {{"dz"}, "Cast", {"one"}, {{"DstT", T}, {"SrcT", DT_INT32}}},
-        {{"grad"},
+        {{"grad0", "grad1"},
          "SymbolicGradient",
          {"x", "y", "dz"},
          {
@@ -128,8 +128,8 @@ class MathGradTest : public ::testing::Test {
           {"Tin", DataTypeSlice{T, T, T}},
           {"Tout", DataTypeSlice{T, T}},
          }},
-        {{"dx"}, "Identity", {"grad:0"}, {{"T", T}}},
-        {{"dy"}, "Identity", {"grad:1"}, {{"T", T}}},
+        {{"dx"}, "Identity", {"grad0"}, {{"T", T}}},
+        {{"dy"}, "Identity", {"grad1"}, {{"T", T}}},
        });
   // Each test case will feed in "x:0" and "y:0" and expects to get "d0" and
   // "d:0".
@@ -177,7 +177,7 @@ class MathGradTest : public ::testing::Test {
        {
         FDH::Const("one", 1),
         {{"dy"}, "Cast", {"one"}, {{"DstT", T}, {"SrcT", DT_INT32}}},
-        {{"grad"},
+        {{"grad0", "grad1"},
          "SymbolicGradient",
          {"x", "i", "dy"},
          {
@@ -185,8 +185,8 @@ class MathGradTest : public ::testing::Test {
           {"Tin", DataTypeSlice{T, DT_INT32, T}},
           {"Tout", DataTypeSlice{T, DT_INT32}},
          }},
-        {{"dx"}, "Identity", {"grad:0"}, {{"T", T}}},
-        {{"di"}, "Identity", {"grad:1"}, {{"T", DT_INT32}}},
+        {{"dx"}, "Identity", {"grad0"}, {{"T", T}}},
+        {{"di"}, "Identity", {"grad1"}, {{"T", DT_INT32}}},
        });
   // Each test case will feed in "x:0" and expects to get "dx:0".
   auto gdef = test::function::GDef(
@@ -267,7 +267,7 @@ class MathGradTest : public ::testing::Test {
        {
         FDH::Const("one", 1),
         {{"dz"}, "Cast", {"one"}, {{"DstT", T}, {"SrcT", DT_INT32}}},
-        {{"grad"},
+        {{"grad0", "grad1"},
          "SymbolicGradient",
          {"x", "y", "dz"},
          {
@@ -275,8 +275,8 @@ class MathGradTest : public ::testing::Test {
           {"Tin", DataTypeSlice{T, T, T}},
           {"Tout", DataTypeSlice{T, T}},
          }},
-        {{"dx"}, "Identity", {"grad:0"}, {{"T", T}}},
-        {{"dy"}, "Identity", {"grad:1"}, {{"T", T}}},
+        {{"dx"}, "Identity", {"grad0"}, {{"T", T}}},
+        {{"dy"}, "Identity", {"grad1"}, {{"T", T}}},
        });
   // Each test case will feed in "x:0" and "y:0" and expects to get "d0" and
   // "d:0".
@@ -331,7 +331,7 @@ class MathGradTest : public ::testing::Test {
   auto grad = FDH::Define("TestGrad", {"c:bool", "x:float", "y:float"},
                           {"dc:bool", "dx:float", "dy:float"}, {},
                           {FDH::Const("dz", 1.f),
-                           {{"grad"},
+                           {{"grad0", "grad1", "grad2"},
                             "SymbolicGradient",
                             {"c", "x", "y", "dz"},
                             {
@@ -339,9 +339,9 @@ class MathGradTest : public ::testing::Test {
                              {"Tin", DataTypeSlice{DT_BOOL, T, T, T}},
                              {"Tout", DataTypeSlice{DT_BOOL, T, T}},
                             }},
-                           {{"dc"}, "Identity", {"grad:0"}, {{"T", DT_BOOL}}},
-                           {{"dx"}, "Identity", {"grad:1"}, {{"T", T}}},
-                           {{"dy"}, "Identity", {"grad:2"}, {{"T", T}}}});
+                           {{"dc"}, "Identity", {"grad0"}, {{"T", DT_BOOL}}},
+                           {{"dx"}, "Identity", {"grad1"}, {{"T", T}}},
+                           {{"dy"}, "Identity", {"grad2"}, {{"T", T}}}});
   // Each test case will feed in "x:0" and expects to get "dx:0".
   auto gdef = test::function::GDef(
       {