[TF-XLA] Include TF metedata in HLO GraphDef node's attributes.
This includes TF op names and op types. Change: 153708854
This commit is contained in:
parent
3cca863359
commit
3a95e41426
@ -121,6 +121,10 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction,
|
||||
attrs["type"].set_s(
|
||||
xla::PrimitiveType_Name(instruction->shape().element_type()));
|
||||
|
||||
// Set the framework op (e.g. Tensorflow op) that generated this XLA op.
|
||||
attrs["tf_op_type"].set_s(instruction->metadata().op_type());
|
||||
attrs["tf_op_name"].set_s(instruction->metadata().op_name());
|
||||
|
||||
// Set the shape of the output tensor. "_output_shapes" is a special attribute
|
||||
// name used by Tensorboard for shapes of output tensors.
|
||||
tensorflow::AttrValue shapes;
|
||||
|
@ -53,6 +53,13 @@ class HloTfGraphBuilderTest : public HloTestBase {
|
||||
Shape r0f32_ = ShapeUtil::MakeShape(PrimitiveType::F32, {});
|
||||
};
|
||||
|
||||
static const tensorflow::AttrValue &GetNodeAttr(const tensorflow::NodeDef &node,
|
||||
const string &attr_name) {
|
||||
auto attr = node.attr().find(attr_name);
|
||||
CHECK(attr != node.attr().end());
|
||||
return attr->second;
|
||||
}
|
||||
|
||||
TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) {
|
||||
auto builder = HloComputation::Builder("Concatenate");
|
||||
Shape shape = ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2});
|
||||
@ -69,35 +76,34 @@ TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) {
|
||||
EXPECT_EQ(node.name(), "Concatenate/concatenate");
|
||||
|
||||
// Check dimensions.
|
||||
auto dims_value = node.attr().find("dims");
|
||||
CHECK(dims_value != node.attr().end());
|
||||
EXPECT_EQ(dims_value->second.list().i_size(), 1);
|
||||
EXPECT_EQ(dims_value->second.list().i(0), 1);
|
||||
auto dims_value = GetNodeAttr(node, "dims");
|
||||
EXPECT_EQ(dims_value.list().i_size(), 1);
|
||||
EXPECT_EQ(dims_value.list().i(0), 1);
|
||||
|
||||
// Check shapes.
|
||||
auto shape_value = node.attr().find("_output_shapes");
|
||||
CHECK(shape_value != node.attr().end());
|
||||
EXPECT_EQ(shape_value->second.list().shape_size(), 1);
|
||||
EXPECT_EQ(shape_value->second.list().shape(0).dim_size(), 2);
|
||||
EXPECT_EQ(shape_value->second.list().shape(0).dim(0).size(), 2);
|
||||
EXPECT_EQ(shape_value->second.list().shape(0).dim(1).size(), 4);
|
||||
auto shape_value = GetNodeAttr(node, "_output_shapes");
|
||||
EXPECT_EQ(shape_value.list().shape_size(), 1);
|
||||
EXPECT_EQ(shape_value.list().shape(0).dim_size(), 2);
|
||||
EXPECT_EQ(shape_value.list().shape(0).dim(0).size(), 2);
|
||||
EXPECT_EQ(shape_value.list().shape(0).dim(1).size(), 4);
|
||||
}
|
||||
|
||||
TEST_F(HloTfGraphBuilderTest, CheckScalarValue) {
|
||||
auto builder = HloComputation::Builder("Const");
|
||||
builder.AddInstruction(
|
||||
HloInstruction *instruction = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0(123)));
|
||||
OpMetadata metadata;
|
||||
metadata.set_op_name("x");
|
||||
metadata.set_op_type("y");
|
||||
instruction->set_metadata(metadata);
|
||||
TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
|
||||
GraphDef graph_def = generator_.GetGraphDef();
|
||||
EXPECT_EQ(graph_def.node_size(), 1);
|
||||
const auto &node = graph_def.node(0);
|
||||
auto value = node.attr().find("value");
|
||||
CHECK(value != node.attr().end());
|
||||
EXPECT_EQ(value->second.s(), "123");
|
||||
|
||||
auto type = node.attr().find("type");
|
||||
CHECK(type != node.attr().end());
|
||||
EXPECT_EQ(type->second.s(), "S32");
|
||||
EXPECT_EQ(GetNodeAttr(node, "value").s(), "123");
|
||||
EXPECT_EQ(GetNodeAttr(node, "type").s(), "S32");
|
||||
EXPECT_EQ(GetNodeAttr(node, "tf_op_name").s(), "x");
|
||||
EXPECT_EQ(GetNodeAttr(node, "tf_op_type").s(), "y");
|
||||
}
|
||||
|
||||
TEST_F(HloTfGraphBuilderTest, SimpleNegateComputation) {
|
||||
|
Loading…
Reference in New Issue
Block a user