[TF-XLA] Include TF metedata in HLO GraphDef node's attributes.
This includes TF op names and op types. Change: 153708854
This commit is contained in:
parent
3cca863359
commit
3a95e41426
@ -121,6 +121,10 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction,
|
|||||||
attrs["type"].set_s(
|
attrs["type"].set_s(
|
||||||
xla::PrimitiveType_Name(instruction->shape().element_type()));
|
xla::PrimitiveType_Name(instruction->shape().element_type()));
|
||||||
|
|
||||||
|
// Set the framework op (e.g. Tensorflow op) that generated this XLA op.
|
||||||
|
attrs["tf_op_type"].set_s(instruction->metadata().op_type());
|
||||||
|
attrs["tf_op_name"].set_s(instruction->metadata().op_name());
|
||||||
|
|
||||||
// Set the shape of the output tensor. "_output_shapes" is a special attribute
|
// Set the shape of the output tensor. "_output_shapes" is a special attribute
|
||||||
// name used by Tensorboard for shapes of output tensors.
|
// name used by Tensorboard for shapes of output tensors.
|
||||||
tensorflow::AttrValue shapes;
|
tensorflow::AttrValue shapes;
|
||||||
|
|||||||
@ -53,6 +53,13 @@ class HloTfGraphBuilderTest : public HloTestBase {
|
|||||||
Shape r0f32_ = ShapeUtil::MakeShape(PrimitiveType::F32, {});
|
Shape r0f32_ = ShapeUtil::MakeShape(PrimitiveType::F32, {});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
static const tensorflow::AttrValue &GetNodeAttr(const tensorflow::NodeDef &node,
|
||||||
|
const string &attr_name) {
|
||||||
|
auto attr = node.attr().find(attr_name);
|
||||||
|
CHECK(attr != node.attr().end());
|
||||||
|
return attr->second;
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) {
|
TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) {
|
||||||
auto builder = HloComputation::Builder("Concatenate");
|
auto builder = HloComputation::Builder("Concatenate");
|
||||||
Shape shape = ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2});
|
Shape shape = ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2});
|
||||||
@ -69,35 +76,34 @@ TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) {
|
|||||||
EXPECT_EQ(node.name(), "Concatenate/concatenate");
|
EXPECT_EQ(node.name(), "Concatenate/concatenate");
|
||||||
|
|
||||||
// Check dimensions.
|
// Check dimensions.
|
||||||
auto dims_value = node.attr().find("dims");
|
auto dims_value = GetNodeAttr(node, "dims");
|
||||||
CHECK(dims_value != node.attr().end());
|
EXPECT_EQ(dims_value.list().i_size(), 1);
|
||||||
EXPECT_EQ(dims_value->second.list().i_size(), 1);
|
EXPECT_EQ(dims_value.list().i(0), 1);
|
||||||
EXPECT_EQ(dims_value->second.list().i(0), 1);
|
|
||||||
|
|
||||||
// Check shapes.
|
// Check shapes.
|
||||||
auto shape_value = node.attr().find("_output_shapes");
|
auto shape_value = GetNodeAttr(node, "_output_shapes");
|
||||||
CHECK(shape_value != node.attr().end());
|
EXPECT_EQ(shape_value.list().shape_size(), 1);
|
||||||
EXPECT_EQ(shape_value->second.list().shape_size(), 1);
|
EXPECT_EQ(shape_value.list().shape(0).dim_size(), 2);
|
||||||
EXPECT_EQ(shape_value->second.list().shape(0).dim_size(), 2);
|
EXPECT_EQ(shape_value.list().shape(0).dim(0).size(), 2);
|
||||||
EXPECT_EQ(shape_value->second.list().shape(0).dim(0).size(), 2);
|
EXPECT_EQ(shape_value.list().shape(0).dim(1).size(), 4);
|
||||||
EXPECT_EQ(shape_value->second.list().shape(0).dim(1).size(), 4);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloTfGraphBuilderTest, CheckScalarValue) {
|
TEST_F(HloTfGraphBuilderTest, CheckScalarValue) {
|
||||||
auto builder = HloComputation::Builder("Const");
|
auto builder = HloComputation::Builder("Const");
|
||||||
builder.AddInstruction(
|
HloInstruction *instruction = builder.AddInstruction(
|
||||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0(123)));
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0(123)));
|
||||||
|
OpMetadata metadata;
|
||||||
|
metadata.set_op_name("x");
|
||||||
|
metadata.set_op_type("y");
|
||||||
|
instruction->set_metadata(metadata);
|
||||||
TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
|
TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
|
||||||
GraphDef graph_def = generator_.GetGraphDef();
|
GraphDef graph_def = generator_.GetGraphDef();
|
||||||
EXPECT_EQ(graph_def.node_size(), 1);
|
EXPECT_EQ(graph_def.node_size(), 1);
|
||||||
const auto &node = graph_def.node(0);
|
const auto &node = graph_def.node(0);
|
||||||
auto value = node.attr().find("value");
|
EXPECT_EQ(GetNodeAttr(node, "value").s(), "123");
|
||||||
CHECK(value != node.attr().end());
|
EXPECT_EQ(GetNodeAttr(node, "type").s(), "S32");
|
||||||
EXPECT_EQ(value->second.s(), "123");
|
EXPECT_EQ(GetNodeAttr(node, "tf_op_name").s(), "x");
|
||||||
|
EXPECT_EQ(GetNodeAttr(node, "tf_op_type").s(), "y");
|
||||||
auto type = node.attr().find("type");
|
|
||||||
CHECK(type != node.attr().end());
|
|
||||||
EXPECT_EQ(type->second.s(), "S32");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloTfGraphBuilderTest, SimpleNegateComputation) {
|
TEST_F(HloTfGraphBuilderTest, SimpleNegateComputation) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user