[TF-XLA] Include TF metedata in HLO GraphDef node's attributes.
This includes TF op names and op types. Change: 153708854
This commit is contained in:
		
							parent
							
								
									3cca863359
								
							
						
					
					
						commit
						3a95e41426
					
				@ -121,6 +121,10 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction,
 | 
			
		||||
  attrs["type"].set_s(
 | 
			
		||||
      xla::PrimitiveType_Name(instruction->shape().element_type()));
 | 
			
		||||
 | 
			
		||||
  // Set the framework op (e.g. Tensorflow op) that generated this XLA op.
 | 
			
		||||
  attrs["tf_op_type"].set_s(instruction->metadata().op_type());
 | 
			
		||||
  attrs["tf_op_name"].set_s(instruction->metadata().op_name());
 | 
			
		||||
 | 
			
		||||
  // Set the shape of the output tensor. "_output_shapes" is a special attribute
 | 
			
		||||
  // name used by Tensorboard for shapes of output tensors.
 | 
			
		||||
  tensorflow::AttrValue shapes;
 | 
			
		||||
 | 
			
		||||
@ -53,6 +53,13 @@ class HloTfGraphBuilderTest : public HloTestBase {
 | 
			
		||||
  Shape r0f32_ = ShapeUtil::MakeShape(PrimitiveType::F32, {});
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
static const tensorflow::AttrValue &GetNodeAttr(const tensorflow::NodeDef &node,
 | 
			
		||||
                                                const string &attr_name) {
 | 
			
		||||
  auto attr = node.attr().find(attr_name);
 | 
			
		||||
  CHECK(attr != node.attr().end());
 | 
			
		||||
  return attr->second;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) {
 | 
			
		||||
  auto builder = HloComputation::Builder("Concatenate");
 | 
			
		||||
  Shape shape = ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2});
 | 
			
		||||
@ -69,35 +76,34 @@ TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) {
 | 
			
		||||
  EXPECT_EQ(node.name(), "Concatenate/concatenate");
 | 
			
		||||
 | 
			
		||||
  // Check dimensions.
 | 
			
		||||
  auto dims_value = node.attr().find("dims");
 | 
			
		||||
  CHECK(dims_value != node.attr().end());
 | 
			
		||||
  EXPECT_EQ(dims_value->second.list().i_size(), 1);
 | 
			
		||||
  EXPECT_EQ(dims_value->second.list().i(0), 1);
 | 
			
		||||
  auto dims_value = GetNodeAttr(node, "dims");
 | 
			
		||||
  EXPECT_EQ(dims_value.list().i_size(), 1);
 | 
			
		||||
  EXPECT_EQ(dims_value.list().i(0), 1);
 | 
			
		||||
 | 
			
		||||
  // Check shapes.
 | 
			
		||||
  auto shape_value = node.attr().find("_output_shapes");
 | 
			
		||||
  CHECK(shape_value != node.attr().end());
 | 
			
		||||
  EXPECT_EQ(shape_value->second.list().shape_size(), 1);
 | 
			
		||||
  EXPECT_EQ(shape_value->second.list().shape(0).dim_size(), 2);
 | 
			
		||||
  EXPECT_EQ(shape_value->second.list().shape(0).dim(0).size(), 2);
 | 
			
		||||
  EXPECT_EQ(shape_value->second.list().shape(0).dim(1).size(), 4);
 | 
			
		||||
  auto shape_value = GetNodeAttr(node, "_output_shapes");
 | 
			
		||||
  EXPECT_EQ(shape_value.list().shape_size(), 1);
 | 
			
		||||
  EXPECT_EQ(shape_value.list().shape(0).dim_size(), 2);
 | 
			
		||||
  EXPECT_EQ(shape_value.list().shape(0).dim(0).size(), 2);
 | 
			
		||||
  EXPECT_EQ(shape_value.list().shape(0).dim(1).size(), 4);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(HloTfGraphBuilderTest, CheckScalarValue) {
 | 
			
		||||
  auto builder = HloComputation::Builder("Const");
 | 
			
		||||
  builder.AddInstruction(
 | 
			
		||||
  HloInstruction *instruction = builder.AddInstruction(
 | 
			
		||||
      HloInstruction::CreateConstant(LiteralUtil::CreateR0(123)));
 | 
			
		||||
  OpMetadata metadata;
 | 
			
		||||
  metadata.set_op_name("x");
 | 
			
		||||
  metadata.set_op_type("y");
 | 
			
		||||
  instruction->set_metadata(metadata);
 | 
			
		||||
  TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
 | 
			
		||||
  GraphDef graph_def = generator_.GetGraphDef();
 | 
			
		||||
  EXPECT_EQ(graph_def.node_size(), 1);
 | 
			
		||||
  const auto &node = graph_def.node(0);
 | 
			
		||||
  auto value = node.attr().find("value");
 | 
			
		||||
  CHECK(value != node.attr().end());
 | 
			
		||||
  EXPECT_EQ(value->second.s(), "123");
 | 
			
		||||
 | 
			
		||||
  auto type = node.attr().find("type");
 | 
			
		||||
  CHECK(type != node.attr().end());
 | 
			
		||||
  EXPECT_EQ(type->second.s(), "S32");
 | 
			
		||||
  EXPECT_EQ(GetNodeAttr(node, "value").s(), "123");
 | 
			
		||||
  EXPECT_EQ(GetNodeAttr(node, "type").s(), "S32");
 | 
			
		||||
  EXPECT_EQ(GetNodeAttr(node, "tf_op_name").s(), "x");
 | 
			
		||||
  EXPECT_EQ(GetNodeAttr(node, "tf_op_type").s(), "y");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(HloTfGraphBuilderTest, SimpleNegateComputation) {
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user