[XLA:HLO] Also clone metadata when cloning instructions e.g. in fusion.

Without the metadata, it's hard to correlate HLO instructions to TF ops after fusion.
Change: 153709552
This commit is contained in:
A. Unique TensorFlower 2017-04-20 07:02:35 -08:00 committed by TensorFlower Gardener
parent 3a95e41426
commit 88b81ac944
3 changed files with 23 additions and 0 deletions

View File

@ -178,6 +178,7 @@ cc_test(
deps = [
":hlo",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",

View File

@ -435,6 +435,7 @@ HloInstruction::CreateSelectAndScatter(
auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
instruction->fusion_kind_ = fusion_kind;
instruction->set_parent(fused_root->parent());
instruction->set_metadata(fused_root->metadata());
instruction->CloneAndFuseInternal(fused_root);
instruction->CheckFusionInstruction();
return instruction;
@ -858,6 +859,7 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone(const string& suffix) {
CloneWithNewOperands(shape_, operands_);
clone->name_ = name() + "." + suffix;
clone->set_parent(parent());
clone->set_metadata(metadata_);
return clone;
}

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
@ -609,6 +610,25 @@ TEST_F(HloInstructionTest, ChainFusionOp) {
UnorderedElementsAre(fusion.get(), exp1.get()));
}
TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
// Create a chain of fused unary ops.
auto constant =
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f));
auto exp1 =
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get());
auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get());
OpMetadata metadata;
metadata.set_op_name("tf_op");
exp1->set_metadata(metadata);
exp2->set_metadata(metadata);
auto fusion = HloInstruction::CreateFusion(
r0f32_, HloInstruction::FusionKind::kLoop, exp2.get());
auto* fused = fusion->FuseInstruction(exp1.get());
EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata()));
EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fused->metadata()));
}
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
// Create a fusion instruction containing a single unary operation.
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});