[XLA:HLO] Also clone metadata when cloning instructions e.g. in fusion.
Without the metadata, it's hard to correlate HLO instructions to TF ops after fusion. Change: 153709552
This commit is contained in:
parent
3a95e41426
commit
88b81ac944
@ -178,6 +178,7 @@ cc_test(
|
||||
deps = [
|
||||
":hlo",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:protobuf_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
|
@ -435,6 +435,7 @@ HloInstruction::CreateSelectAndScatter(
|
||||
auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
|
||||
instruction->fusion_kind_ = fusion_kind;
|
||||
instruction->set_parent(fused_root->parent());
|
||||
instruction->set_metadata(fused_root->metadata());
|
||||
instruction->CloneAndFuseInternal(fused_root);
|
||||
instruction->CheckFusionInstruction();
|
||||
return instruction;
|
||||
@ -858,6 +859,7 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone(const string& suffix) {
|
||||
CloneWithNewOperands(shape_, operands_);
|
||||
clone->name_ = name() + "." + suffix;
|
||||
clone->set_parent(parent());
|
||||
clone->set_metadata(metadata_);
|
||||
return clone;
|
||||
}
|
||||
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/protobuf_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
@ -609,6 +610,25 @@ TEST_F(HloInstructionTest, ChainFusionOp) {
|
||||
UnorderedElementsAre(fusion.get(), exp1.get()));
|
||||
}
|
||||
|
||||
TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
|
||||
// Create a chain of fused unary ops.
|
||||
auto constant =
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f));
|
||||
auto exp1 =
|
||||
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get());
|
||||
auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get());
|
||||
OpMetadata metadata;
|
||||
metadata.set_op_name("tf_op");
|
||||
exp1->set_metadata(metadata);
|
||||
exp2->set_metadata(metadata);
|
||||
|
||||
auto fusion = HloInstruction::CreateFusion(
|
||||
r0f32_, HloInstruction::FusionKind::kLoop, exp2.get());
|
||||
auto* fused = fusion->FuseInstruction(exp1.get());
|
||||
EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata()));
|
||||
EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fused->metadata()));
|
||||
}
|
||||
|
||||
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
|
||||
// Create a fusion instruction containing a single unary operation.
|
||||
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
|
||||
|
Loading…
Reference in New Issue
Block a user