[TF-XLA] Include TF metedata in HLO GraphDef node's attributes.

This includes TF op names and op types.
Change: 153708854
This commit is contained in:
A. Unique TensorFlower 2017-04-20 06:55:26 -08:00 committed by TensorFlower Gardener
parent 3cca863359
commit 3a95e41426
2 changed files with 28 additions and 18 deletions

View File

@ -121,6 +121,10 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction,
attrs["type"].set_s(
xla::PrimitiveType_Name(instruction->shape().element_type()));
// Set the framework op (e.g. Tensorflow op) that generated this XLA op.
attrs["tf_op_type"].set_s(instruction->metadata().op_type());
attrs["tf_op_name"].set_s(instruction->metadata().op_name());
// Set the shape of the output tensor. "_output_shapes" is a special attribute
// name used by Tensorboard for shapes of output tensors.
tensorflow::AttrValue shapes;

View File

@ -53,6 +53,13 @@ class HloTfGraphBuilderTest : public HloTestBase {
Shape r0f32_ = ShapeUtil::MakeShape(PrimitiveType::F32, {});
};
static const tensorflow::AttrValue &GetNodeAttr(const tensorflow::NodeDef &node,
const string &attr_name) {
auto attr = node.attr().find(attr_name);
CHECK(attr != node.attr().end());
return attr->second;
}
TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) {
auto builder = HloComputation::Builder("Concatenate");
Shape shape = ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2});
@ -69,35 +76,34 @@ TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) {
EXPECT_EQ(node.name(), "Concatenate/concatenate");
// Check dimensions.
auto dims_value = node.attr().find("dims");
CHECK(dims_value != node.attr().end());
EXPECT_EQ(dims_value->second.list().i_size(), 1);
EXPECT_EQ(dims_value->second.list().i(0), 1);
auto dims_value = GetNodeAttr(node, "dims");
EXPECT_EQ(dims_value.list().i_size(), 1);
EXPECT_EQ(dims_value.list().i(0), 1);
// Check shapes.
auto shape_value = node.attr().find("_output_shapes");
CHECK(shape_value != node.attr().end());
EXPECT_EQ(shape_value->second.list().shape_size(), 1);
EXPECT_EQ(shape_value->second.list().shape(0).dim_size(), 2);
EXPECT_EQ(shape_value->second.list().shape(0).dim(0).size(), 2);
EXPECT_EQ(shape_value->second.list().shape(0).dim(1).size(), 4);
auto shape_value = GetNodeAttr(node, "_output_shapes");
EXPECT_EQ(shape_value.list().shape_size(), 1);
EXPECT_EQ(shape_value.list().shape(0).dim_size(), 2);
EXPECT_EQ(shape_value.list().shape(0).dim(0).size(), 2);
EXPECT_EQ(shape_value.list().shape(0).dim(1).size(), 4);
}
TEST_F(HloTfGraphBuilderTest, CheckScalarValue) {
auto builder = HloComputation::Builder("Const");
builder.AddInstruction(
HloInstruction *instruction = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0(123)));
OpMetadata metadata;
metadata.set_op_name("x");
metadata.set_op_type("y");
instruction->set_metadata(metadata);
TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
GraphDef graph_def = generator_.GetGraphDef();
EXPECT_EQ(graph_def.node_size(), 1);
const auto &node = graph_def.node(0);
auto value = node.attr().find("value");
CHECK(value != node.attr().end());
EXPECT_EQ(value->second.s(), "123");
auto type = node.attr().find("type");
CHECK(type != node.attr().end());
EXPECT_EQ(type->second.s(), "S32");
EXPECT_EQ(GetNodeAttr(node, "value").s(), "123");
EXPECT_EQ(GetNodeAttr(node, "type").s(), "S32");
EXPECT_EQ(GetNodeAttr(node, "tf_op_name").s(), "x");
EXPECT_EQ(GetNodeAttr(node, "tf_op_type").s(), "y");
}
TEST_F(HloTfGraphBuilderTest, SimpleNegateComputation) {