[XLA:HLO] Also clone metadata when cloning instructions e.g. in fusion.
Without the metadata, it's hard to correlate HLO instructions to TF ops after fusion. Change: 153709552
This commit is contained in:
parent
3a95e41426
commit
88b81ac944
@ -178,6 +178,7 @@ cc_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":hlo",
|
":hlo",
|
||||||
"//tensorflow/compiler/xla:literal_util",
|
"//tensorflow/compiler/xla:literal_util",
|
||||||
|
"//tensorflow/compiler/xla:protobuf_util",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:test",
|
"//tensorflow/compiler/xla:test",
|
||||||
"//tensorflow/compiler/xla:test_helpers",
|
"//tensorflow/compiler/xla:test_helpers",
|
||||||
|
@ -435,6 +435,7 @@ HloInstruction::CreateSelectAndScatter(
|
|||||||
auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
|
auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
|
||||||
instruction->fusion_kind_ = fusion_kind;
|
instruction->fusion_kind_ = fusion_kind;
|
||||||
instruction->set_parent(fused_root->parent());
|
instruction->set_parent(fused_root->parent());
|
||||||
|
instruction->set_metadata(fused_root->metadata());
|
||||||
instruction->CloneAndFuseInternal(fused_root);
|
instruction->CloneAndFuseInternal(fused_root);
|
||||||
instruction->CheckFusionInstruction();
|
instruction->CheckFusionInstruction();
|
||||||
return instruction;
|
return instruction;
|
||||||
@ -858,6 +859,7 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone(const string& suffix) {
|
|||||||
CloneWithNewOperands(shape_, operands_);
|
CloneWithNewOperands(shape_, operands_);
|
||||||
clone->name_ = name() + "." + suffix;
|
clone->name_ = name() + "." + suffix;
|
||||||
clone->set_parent(parent());
|
clone->set_parent(parent());
|
||||||
|
clone->set_metadata(metadata_);
|
||||||
return clone;
|
return clone;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/literal_util.h"
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/protobuf_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
@ -609,6 +610,25 @@ TEST_F(HloInstructionTest, ChainFusionOp) {
|
|||||||
UnorderedElementsAre(fusion.get(), exp1.get()));
|
UnorderedElementsAre(fusion.get(), exp1.get()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
|
||||||
|
// Create a chain of fused unary ops.
|
||||||
|
auto constant =
|
||||||
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f));
|
||||||
|
auto exp1 =
|
||||||
|
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get());
|
||||||
|
auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get());
|
||||||
|
OpMetadata metadata;
|
||||||
|
metadata.set_op_name("tf_op");
|
||||||
|
exp1->set_metadata(metadata);
|
||||||
|
exp2->set_metadata(metadata);
|
||||||
|
|
||||||
|
auto fusion = HloInstruction::CreateFusion(
|
||||||
|
r0f32_, HloInstruction::FusionKind::kLoop, exp2.get());
|
||||||
|
auto* fused = fusion->FuseInstruction(exp1.get());
|
||||||
|
EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata()));
|
||||||
|
EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fused->metadata()));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
|
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
|
||||||
// Create a fusion instruction containing a single unary operation.
|
// Create a fusion instruction containing a single unary operation.
|
||||||
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
|
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
|
||||||
|
Loading…
Reference in New Issue
Block a user