From 0231137ed043ebc01517856f5472d26bc71a6fa8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Tue, 19 Nov 2019 15:31:11 -0800 Subject: [PATCH] Allow parsing HLO constant without literals. This feature is used when parsing HLO for evaluating cost model instead of running the HLO graph. PiperOrigin-RevId: 281393128 Change-Id: I6e986f9b34304844942731cb364c9479a8edaf00 --- tensorflow/compiler/xla/BUILD | 1 + tensorflow/compiler/xla/literal.cc | 20 ++++++++++++++-- tensorflow/compiler/xla/literal.h | 3 ++- tensorflow/compiler/xla/literal_test.cc | 24 +++++++++++++++++++ .../compiler/xla/service/hlo_computation.cc | 11 +++++---- .../compiler/xla/service/hlo_computation.h | 3 ++- .../compiler/xla/service/hlo_instruction.cc | 13 ++++++---- .../compiler/xla/service/hlo_instruction.h | 3 ++- tensorflow/compiler/xla/service/hlo_module.cc | 6 +++-- tensorflow/compiler/xla/service/hlo_module.h | 3 ++- tensorflow/tools/compatibility/README.md | 4 ---- 11 files changed, 69 insertions(+), 22 deletions(-) diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 24be75c3d62..90461005fac 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -424,6 +424,7 @@ cc_library( ":xla_data_proto_cc", "//tensorflow/core:lib", "@com_google_absl//absl/base", + "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index bbea6081975..3a219673304 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -72,6 +72,16 @@ T GetRawValue(T val) { } uint16 GetRawValue(Eigen::half val) { return val.x; } +bool LiteralProtoHasValues(const LiteralProto& proto) { + return proto.preds_size() || !proto.s8s().empty() || !proto.u8s().empty() || + proto.s32s_size() || proto.s64s_size() || proto.u32s_size() || + proto.u64s_size() || proto.f32s_size() || proto.f64s_size() || + proto.c64s_size() || proto.c128s_size() || + proto.tuple_literals_size() || !proto.f16s().empty() || + !proto.bf16s().empty() || !proto.u16s().empty() || + !proto.s16s().empty() || proto.sparse_indices_size(); +} + } // namespace LiteralBase::~LiteralBase() {} @@ -288,7 +298,7 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, } /* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto( - const LiteralProto& proto) { + const LiteralProto& proto, bool prohibit_empty_literal) { if (!proto.has_shape()) { return InvalidArgument("LiteralProto has no shape"); } @@ -328,7 +338,13 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, } CHECK(piece->subshape().IsArray()); - TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); + + // When prohibit_empty_literal is false (allowing literal with no + // values), only copy from proto if the literal proto has values. This + // mode is used for a learned cost model. + if (prohibit_empty_literal || LiteralProtoHasValues(*proto_element)) { + TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); + } return Status::OK(); })); diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index af15cab4a94..227717188ab 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -732,7 +732,8 @@ class MutableLiteralBase : public LiteralBase { static Literal MoveIntoTuple(absl::Span<Literal> elements); // Serialize from a proto. - static StatusOr<Literal> CreateFromProto(const LiteralProto& proto); + static StatusOr<Literal> CreateFromProto(const LiteralProto& proto, + bool prohibit_empty_literal = true); protected: // Returns the piece at the given ShapeIndex. diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 43863db5b3f..d1dd6b8fd77 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -1847,6 +1847,30 @@ TEST_F(LiteralUtilTest, InvalidProtoNoValues) { HasSubstr("Expected 3 elements in LiteralProto")); } +TEST_F(LiteralUtilTest, ValidProtoNoValues) { + // Proto contains a shape, but no values. + LiteralProto proto; + *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto(); + Status status = + Literal::CreateFromProto(proto, /*prohibit_empty_literal=*/false) + .status(); + EXPECT_TRUE(status.ok()); +} + +TEST_F(LiteralUtilTest, ValidProtoWithClearedValues) { + auto literal = LiteralUtil::CreateR1<bool>({true, false, true}); + LiteralProto proto = literal.ToProto(); + EXPECT_EQ(proto.preds_size(), 3); + + // Clear values. + proto.clear_preds(); + EXPECT_EQ(proto.preds_size(), 0); + Status status = + Literal::CreateFromProto(proto, /*prohibit_empty_literal=*/false) + .status(); + EXPECT_TRUE(status.ok()); +} + TEST_F(LiteralUtilTest, InvalidProtoNoShape) { // Proto contains values, but no shape. LiteralProto proto; diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 2e316b0f2d3..8f9ac5e2db7 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -640,16 +640,17 @@ HloComputationProto HloComputation::ToProto() const { /* static */ StatusOr<std::unique_ptr<HloComputation>> HloComputation::CreateFromProto( const HloComputationProto& proto, - const absl::flat_hash_map<int64, HloComputation*>& computation_map) { + const absl::flat_hash_map<int64, HloComputation*>& computation_map, + bool prohibit_empty_literal) { absl::flat_hash_map<int64, HloInstruction*> instruction_map; absl::flat_hash_map<HloInstruction*, int64> to_proto_id; std::vector<std::unique_ptr<HloInstruction>> instructions; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr<HloInstruction> instruction, - HloInstruction::CreateFromProto(instruction_proto, instruction_map, - computation_map)); + TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction, + HloInstruction::CreateFromProto( + instruction_proto, instruction_map, computation_map, + prohibit_empty_literal)); if (instruction->opcode() == HloOpcode::kParameter) { parameter_count++; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index d66e0b63d64..f96c9046c19 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -235,7 +235,8 @@ class HloComputation { // calls. static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto( const HloComputationProto& proto, - const absl::flat_hash_map<int64, HloComputation*>& computation_map); + const absl::flat_hash_map<int64, HloComputation*>& computation_map, + bool prohibit_empty_literal = true); using InstructionSequence = tensorflow::gtl::iterator_range< UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index e66707eb7fd..368a3876f8c 100755 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -62,7 +62,8 @@ using absl::StrJoin; StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( const HloInstructionProto& proto, const absl::flat_hash_map<int64, HloInstruction*>& instruction_map, - const absl::flat_hash_map<int64, HloComputation*>& computation_map) { + const absl::flat_hash_map<int64, HloComputation*>& computation_map, + bool prohibit_empty_literal) { TF_RET_CHECK(!proto.opcode().empty()); HloOpcode opcode; auto opcode_or = StringToHloOpcode(proto.opcode()); @@ -300,8 +301,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( case HloOpcode::kConstant: { // TODO(b/110214922): Revert this to CHECK(proto.has_literal()). if (proto.has_literal()) { - TF_ASSIGN_OR_RETURN(auto literal, - Literal::CreateFromProto(proto.literal())); + TF_ASSIGN_OR_RETURN( + auto literal, + Literal::CreateFromProto(proto.literal(), prohibit_empty_literal)); instruction = CreateConstant(std::move(literal)); // Literal's shape may have no/different tiling info. TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()( @@ -314,8 +316,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( } case HloOpcode::kTrace: { TF_RET_CHECK(proto.has_literal()); - TF_ASSIGN_OR_RETURN(auto literal, - Literal::CreateFromProto(proto.literal())); + TF_ASSIGN_OR_RETURN( + auto literal, + Literal::CreateFromProto(proto.literal(), prohibit_empty_literal)); instruction = CreateTrace(literal.GetR1U8AsString(), operands(0)); break; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index b35d9d07dcf..5e2e53ea6db 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -475,7 +475,8 @@ class HloInstruction { static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto( const HloInstructionProto& proto, const absl::flat_hash_map<int64, HloInstruction*>& instruction_map, - const absl::flat_hash_map<int64, HloComputation*>& computation_map); + const absl::flat_hash_map<int64, HloComputation*>& computation_map, + bool prohibit_empty_literal = true); // Creates a parameter-retrieving instruction. static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number, diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 27b79049688..74ef9a1ec4c 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -294,7 +294,8 @@ Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() const { /* static */ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( - const HloModuleProto& proto, const HloModuleConfig& module_config) { + const HloModuleProto& proto, const HloModuleConfig& module_config, + bool prohibit_empty_literal) { VLOG(2) << "CreateFromProto()"; XLA_VLOG_LINES(3, proto.DebugString()); @@ -332,7 +333,8 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( for (const HloComputationProto& computation_proto : proto.computations()) { TF_ASSIGN_OR_RETURN( std::unique_ptr<HloComputation> computation, - HloComputation::CreateFromProto(computation_proto, computation_map)); + HloComputation::CreateFromProto(computation_proto, computation_map, + prohibit_empty_literal)); CHECK_NE(computation.get(), nullptr); int64 computation_id = computation_proto.id(); TF_RET_CHECK(computation_id != -1); diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 40f3c972f63..9a96b787bb8 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -219,7 +219,8 @@ class HloModule { // Convert an HloModule to or from a proto. HloModuleProto ToProto() const; static StatusOr<std::unique_ptr<HloModule>> CreateFromProto( - const HloModuleProto& proto, const HloModuleConfig& module_config); + const HloModuleProto& proto, const HloModuleConfig& module_config, + bool prohibit_empty_literal = true); // Creates and returns an HloModuleConfig with an appropriate program shape // for the HLO module in the given proto. diff --git a/tensorflow/tools/compatibility/README.md b/tensorflow/tools/compatibility/README.md index 3611e85d28d..12d3a0faf44 100644 --- a/tensorflow/tools/compatibility/README.md +++ b/tensorflow/tools/compatibility/README.md @@ -30,10 +30,6 @@ tf_upgrade_v2 --intree coolcode --outtree coolcode-upgraded --copyotherfiles Fal *Note: `tf_upgrade_v2` is installed automatically as a script by the pip install after TensorFlow 1.12. -You may want to retain revision history, especially when preparing a CL: -``` -g4 integrate --retroactive coolcode/... coolcode-upgraded/... -``` ## Report