Added a test to make sure that graph properties for variables are properly
reported PiperOrigin-RevId: 158053084
This commit is contained in:
parent
2ccfe8e764
commit
0df6760fe9
@ -149,6 +149,54 @@ TEST_F(GraphPropertiesTest, DynamicProperties) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphPropertiesTest, Variables) {
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(NodeDefBuilder("Var", "Variable")
|
||||
.Attr("dtype", DT_FLOAT)
|
||||
.Attr("shape", TensorShape({3, 7}))
|
||||
.Finalize(item.graph.add_node()));
|
||||
item.fetch.push_back("Var");
|
||||
|
||||
Tensor initial_val(DT_FLOAT, TensorShape({3, 7}));
|
||||
TF_CHECK_OK(NodeDefBuilder("InitialVal", "Const")
|
||||
.Attr("dtype", DT_FLOAT)
|
||||
.Attr("value", initial_val)
|
||||
.Finalize(item.graph.add_node()));
|
||||
TF_CHECK_OK(NodeDefBuilder("InitVar", "Assign")
|
||||
.Input("Var", 0, DT_FLOAT_REF)
|
||||
.Input("InitialVal", 0, DT_FLOAT)
|
||||
.Finalize(item.graph.add_node()));
|
||||
item.init_ops.push_back("InitVar");
|
||||
|
||||
{
|
||||
GraphProperties static_properties(item);
|
||||
TF_CHECK_OK(static_properties.InferStatically());
|
||||
|
||||
const auto props = static_properties.GetOutputProperties("Var");
|
||||
EXPECT_EQ(1, props.size());
|
||||
const OpInfo::TensorProperties& prop = props[0];
|
||||
EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
|
||||
EXPECT_FALSE(prop.shape().unknown_rank());
|
||||
EXPECT_EQ(2, prop.shape().dim_size());
|
||||
EXPECT_EQ(3, prop.shape().dim(0).size());
|
||||
EXPECT_EQ(7, prop.shape().dim(1).size());
|
||||
}
|
||||
{
|
||||
TF_CHECK_OK(cluster_->Initialize(item));
|
||||
GraphProperties dynamic_properties(item);
|
||||
TF_CHECK_OK(dynamic_properties.InferDynamically(cluster_.get()));
|
||||
|
||||
const auto props = dynamic_properties.GetOutputProperties("Var");
|
||||
EXPECT_EQ(1, props.size());
|
||||
const OpInfo::TensorProperties& prop = props[0];
|
||||
EXPECT_EQ(DT_FLOAT_REF, prop.dtype());
|
||||
EXPECT_FALSE(prop.shape().unknown_rank());
|
||||
EXPECT_EQ(2, prop.shape().dim_size());
|
||||
EXPECT_EQ(3, prop.shape().dim(0).size());
|
||||
EXPECT_EQ(7, prop.shape().dim(1).size());
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphPropertiesTest, VarHandles) {
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(NodeDefBuilder("Var", "VarHandleOp")
|
||||
|
Loading…
Reference in New Issue
Block a user