Merge pull request #40540 from Agoniii:cathyx/amp

PiperOrigin-RevId: 317163221
Change-Id: Ic126a1f533d9dd95098425b1eeee3e03acfdd17d
This commit is contained in:
TensorFlower Gardener 2020-06-18 13:06:36 -07:00
commit 543d7c47a0

View File

@ -287,10 +287,10 @@ TEST_F(AutoMixedPrecisionTest, Simple) {
Output clr2 = ops::Relu(s.WithOpName("clr2"), gry1);
Output wht1 = ops::MatMul(s.WithOpName("wht1"), clr2, clr2);
Output clr3 = ops::Relu(s.WithOpName("clr3"), wht1);
Output blk2 = ops::Log(s.WithOpName("blk2"), clr3);
Output clr4 = ops::Relu(s.WithOpName("clr4"), blk2);
Output blk3 = ops::SparseMatMul(s.WithOpName("blk3"), clr4, clr4);
Output clr5 = ops::Relu(s.WithOpName("clr5"), blk3);
Output gry2 = ops::Log(s.WithOpName("gry2"), clr3);
Output clr4 = ops::Relu(s.WithOpName("clr4"), gry2);
Output blk2 = ops::SparseMatMul(s.WithOpName("blk2"), clr4, clr4);
Output clr5 = ops::Relu(s.WithOpName("clr5"), blk2);
Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);
GrapplerItem item;
@ -313,10 +313,10 @@ TEST_F(AutoMixedPrecisionTest, Simple) {
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
EXPECT_EQ(output_view.GetNode("blk2")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("gry2")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("blk3")->attr().at("Ta").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("blk3")->attr().at("Tb").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("blk2")->attr().at("Ta").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("blk2")->attr().at("Tb").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);
auto tensors = EvaluateNodes(output, item.fetch);