diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 7f2f5bedee1..fdc1c0ba2d7 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -121,6 +121,10 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, attrs["type"].set_s( xla::PrimitiveType_Name(instruction->shape().element_type())); + // Set the framework op (e.g. Tensorflow op) that generated this XLA op. + attrs["tf_op_type"].set_s(instruction->metadata().op_type()); + attrs["tf_op_name"].set_s(instruction->metadata().op_name()); + // Set the shape of the output tensor. "_output_shapes" is a special attribute // name used by Tensorboard for shapes of output tensors. tensorflow::AttrValue shapes; diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index 3190f2d703a..df664080228 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -53,6 +53,13 @@ class HloTfGraphBuilderTest : public HloTestBase { Shape r0f32_ = ShapeUtil::MakeShape(PrimitiveType::F32, {}); }; +static const tensorflow::AttrValue &GetNodeAttr(const tensorflow::NodeDef &node, + const string &attr_name) { + auto attr = node.attr().find(attr_name); + CHECK(attr != node.attr().end()); + return attr->second; +} + TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) { auto builder = HloComputation::Builder("Concatenate"); Shape shape = ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2}); @@ -69,35 +76,34 @@ TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) { EXPECT_EQ(node.name(), "Concatenate/concatenate"); // Check dimensions. - auto dims_value = node.attr().find("dims"); - CHECK(dims_value != node.attr().end()); - EXPECT_EQ(dims_value->second.list().i_size(), 1); - EXPECT_EQ(dims_value->second.list().i(0), 1); + auto dims_value = GetNodeAttr(node, "dims"); + EXPECT_EQ(dims_value.list().i_size(), 1); + EXPECT_EQ(dims_value.list().i(0), 1); // Check shapes. - auto shape_value = node.attr().find("_output_shapes"); - CHECK(shape_value != node.attr().end()); - EXPECT_EQ(shape_value->second.list().shape_size(), 1); - EXPECT_EQ(shape_value->second.list().shape(0).dim_size(), 2); - EXPECT_EQ(shape_value->second.list().shape(0).dim(0).size(), 2); - EXPECT_EQ(shape_value->second.list().shape(0).dim(1).size(), 4); + auto shape_value = GetNodeAttr(node, "_output_shapes"); + EXPECT_EQ(shape_value.list().shape_size(), 1); + EXPECT_EQ(shape_value.list().shape(0).dim_size(), 2); + EXPECT_EQ(shape_value.list().shape(0).dim(0).size(), 2); + EXPECT_EQ(shape_value.list().shape(0).dim(1).size(), 4); } TEST_F(HloTfGraphBuilderTest, CheckScalarValue) { auto builder = HloComputation::Builder("Const"); - builder.AddInstruction( + HloInstruction *instruction = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); + OpMetadata metadata; + metadata.set_op_name("x"); + metadata.set_op_type("y"); + instruction->set_metadata(metadata); TF_CHECK_OK(generator_.AddComputation(*builder.Build())); GraphDef graph_def = generator_.GetGraphDef(); EXPECT_EQ(graph_def.node_size(), 1); const auto &node = graph_def.node(0); - auto value = node.attr().find("value"); - CHECK(value != node.attr().end()); - EXPECT_EQ(value->second.s(), "123"); - - auto type = node.attr().find("type"); - CHECK(type != node.attr().end()); - EXPECT_EQ(type->second.s(), "S32"); + EXPECT_EQ(GetNodeAttr(node, "value").s(), "123"); + EXPECT_EQ(GetNodeAttr(node, "type").s(), "S32"); + EXPECT_EQ(GetNodeAttr(node, "tf_op_name").s(), "x"); + EXPECT_EQ(GetNodeAttr(node, "tf_op_type").s(), "y"); } TEST_F(HloTfGraphBuilderTest, SimpleNegateComputation) {