Allow parsing HLO constant without literals. This feature is used when parsing HLO for evaluating cost model instead of running the HLO graph.

PiperOrigin-RevId: 281393128
Change-Id: I6e986f9b34304844942731cb364c9479a8edaf00
This commit is contained in:
A. Unique TensorFlower 2019-11-19 15:31:11 -08:00 committed by TensorFlower Gardener
parent bc45d196c3
commit 0231137ed0
11 changed files with 69 additions and 22 deletions

View File

@ -424,6 +424,7 @@ cc_library(
":xla_data_proto_cc", ":xla_data_proto_cc",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"@com_google_absl//absl/base", "@com_google_absl//absl/base",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",

View File

@ -72,6 +72,16 @@ T GetRawValue(T val) {
} }
uint16 GetRawValue(Eigen::half val) { return val.x; } uint16 GetRawValue(Eigen::half val) { return val.x; }
bool LiteralProtoHasValues(const LiteralProto& proto) {
return proto.preds_size() || !proto.s8s().empty() || !proto.u8s().empty() ||
proto.s32s_size() || proto.s64s_size() || proto.u32s_size() ||
proto.u64s_size() || proto.f32s_size() || proto.f64s_size() ||
proto.c64s_size() || proto.c128s_size() ||
proto.tuple_literals_size() || !proto.f16s().empty() ||
!proto.bf16s().empty() || !proto.u16s().empty() ||
!proto.s16s().empty() || proto.sparse_indices_size();
}
} // namespace } // namespace
LiteralBase::~LiteralBase() {} LiteralBase::~LiteralBase() {}
@ -288,7 +298,7 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
} }
/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto( /* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
const LiteralProto& proto) { const LiteralProto& proto, bool prohibit_empty_literal) {
if (!proto.has_shape()) { if (!proto.has_shape()) {
return InvalidArgument("LiteralProto has no shape"); return InvalidArgument("LiteralProto has no shape");
} }
@ -328,7 +338,13 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
} }
CHECK(piece->subshape().IsArray()); CHECK(piece->subshape().IsArray());
// When prohibit_empty_literal is false (allowing literal with no
// values), only copy from proto if the literal proto has values. This
// mode is used for a learned cost model.
if (prohibit_empty_literal || LiteralProtoHasValues(*proto_element)) {
TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
}
return Status::OK(); return Status::OK();
})); }));

View File

@ -732,7 +732,8 @@ class MutableLiteralBase : public LiteralBase {
static Literal MoveIntoTuple(absl::Span<Literal> elements); static Literal MoveIntoTuple(absl::Span<Literal> elements);
// Serialize from a proto. // Serialize from a proto.
static StatusOr<Literal> CreateFromProto(const LiteralProto& proto); static StatusOr<Literal> CreateFromProto(const LiteralProto& proto,
bool prohibit_empty_literal = true);
protected: protected:
// Returns the piece at the given ShapeIndex. // Returns the piece at the given ShapeIndex.

View File

@ -1847,6 +1847,30 @@ TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
HasSubstr("Expected 3 elements in LiteralProto")); HasSubstr("Expected 3 elements in LiteralProto"));
} }
TEST_F(LiteralUtilTest, ValidProtoNoValues) {
// Proto contains a shape, but no values.
LiteralProto proto;
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();
Status status =
Literal::CreateFromProto(proto, /*prohibit_empty_literal=*/false)
.status();
EXPECT_TRUE(status.ok());
}
TEST_F(LiteralUtilTest, ValidProtoWithClearedValues) {
auto literal = LiteralUtil::CreateR1<bool>({true, false, true});
LiteralProto proto = literal.ToProto();
EXPECT_EQ(proto.preds_size(), 3);
// Clear values.
proto.clear_preds();
EXPECT_EQ(proto.preds_size(), 0);
Status status =
Literal::CreateFromProto(proto, /*prohibit_empty_literal=*/false)
.status();
EXPECT_TRUE(status.ok());
}
TEST_F(LiteralUtilTest, InvalidProtoNoShape) { TEST_F(LiteralUtilTest, InvalidProtoNoShape) {
// Proto contains values, but no shape. // Proto contains values, but no shape.
LiteralProto proto; LiteralProto proto;

View File

@ -640,16 +640,17 @@ HloComputationProto HloComputation::ToProto() const {
/* static */ StatusOr<std::unique_ptr<HloComputation>> /* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto( HloComputation::CreateFromProto(
const HloComputationProto& proto, const HloComputationProto& proto,
const absl::flat_hash_map<int64, HloComputation*>& computation_map) { const absl::flat_hash_map<int64, HloComputation*>& computation_map,
bool prohibit_empty_literal) {
absl::flat_hash_map<int64, HloInstruction*> instruction_map; absl::flat_hash_map<int64, HloInstruction*> instruction_map;
absl::flat_hash_map<HloInstruction*, int64> to_proto_id; absl::flat_hash_map<HloInstruction*, int64> to_proto_id;
std::vector<std::unique_ptr<HloInstruction>> instructions; std::vector<std::unique_ptr<HloInstruction>> instructions;
int64 parameter_count = 0; int64 parameter_count = 0;
for (const HloInstructionProto& instruction_proto : proto.instructions()) { for (const HloInstructionProto& instruction_proto : proto.instructions()) {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction,
std::unique_ptr<HloInstruction> instruction, HloInstruction::CreateFromProto(
HloInstruction::CreateFromProto(instruction_proto, instruction_map, instruction_proto, instruction_map, computation_map,
computation_map)); prohibit_empty_literal));
if (instruction->opcode() == HloOpcode::kParameter) { if (instruction->opcode() == HloOpcode::kParameter) {
parameter_count++; parameter_count++;
} }

View File

@ -235,7 +235,8 @@ class HloComputation {
// calls. // calls.
static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto( static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
const HloComputationProto& proto, const HloComputationProto& proto,
const absl::flat_hash_map<int64, HloComputation*>& computation_map); const absl::flat_hash_map<int64, HloComputation*>& computation_map,
bool prohibit_empty_literal = true);
using InstructionSequence = tensorflow::gtl::iterator_range< using InstructionSequence = tensorflow::gtl::iterator_range<
UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>; UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>;

View File

@ -62,7 +62,8 @@ using absl::StrJoin;
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
const HloInstructionProto& proto, const HloInstructionProto& proto,
const absl::flat_hash_map<int64, HloInstruction*>& instruction_map, const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
const absl::flat_hash_map<int64, HloComputation*>& computation_map) { const absl::flat_hash_map<int64, HloComputation*>& computation_map,
bool prohibit_empty_literal) {
TF_RET_CHECK(!proto.opcode().empty()); TF_RET_CHECK(!proto.opcode().empty());
HloOpcode opcode; HloOpcode opcode;
auto opcode_or = StringToHloOpcode(proto.opcode()); auto opcode_or = StringToHloOpcode(proto.opcode());
@ -300,8 +301,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
case HloOpcode::kConstant: { case HloOpcode::kConstant: {
// TODO(b/110214922): Revert this to CHECK(proto.has_literal()). // TODO(b/110214922): Revert this to CHECK(proto.has_literal()).
if (proto.has_literal()) { if (proto.has_literal()) {
TF_ASSIGN_OR_RETURN(auto literal, TF_ASSIGN_OR_RETURN(
Literal::CreateFromProto(proto.literal())); auto literal,
Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
instruction = CreateConstant(std::move(literal)); instruction = CreateConstant(std::move(literal));
// Literal's shape may have no/different tiling info. // Literal's shape may have no/different tiling info.
TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()( TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
@ -314,8 +316,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
} }
case HloOpcode::kTrace: { case HloOpcode::kTrace: {
TF_RET_CHECK(proto.has_literal()); TF_RET_CHECK(proto.has_literal());
TF_ASSIGN_OR_RETURN(auto literal, TF_ASSIGN_OR_RETURN(
Literal::CreateFromProto(proto.literal())); auto literal,
Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
instruction = CreateTrace(literal.GetR1U8AsString(), operands(0)); instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
break; break;
} }

View File

@ -475,7 +475,8 @@ class HloInstruction {
static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto( static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
const HloInstructionProto& proto, const HloInstructionProto& proto,
const absl::flat_hash_map<int64, HloInstruction*>& instruction_map, const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
const absl::flat_hash_map<int64, HloComputation*>& computation_map); const absl::flat_hash_map<int64, HloComputation*>& computation_map,
bool prohibit_empty_literal = true);
// Creates a parameter-retrieving instruction. // Creates a parameter-retrieving instruction.
static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number, static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,

View File

@ -294,7 +294,8 @@ Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() const {
/* static */ /* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
const HloModuleProto& proto, const HloModuleConfig& module_config) { const HloModuleProto& proto, const HloModuleConfig& module_config,
bool prohibit_empty_literal) {
VLOG(2) << "CreateFromProto()"; VLOG(2) << "CreateFromProto()";
XLA_VLOG_LINES(3, proto.DebugString()); XLA_VLOG_LINES(3, proto.DebugString());
@ -332,7 +333,8 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
for (const HloComputationProto& computation_proto : proto.computations()) { for (const HloComputationProto& computation_proto : proto.computations()) {
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloComputation> computation, std::unique_ptr<HloComputation> computation,
HloComputation::CreateFromProto(computation_proto, computation_map)); HloComputation::CreateFromProto(computation_proto, computation_map,
prohibit_empty_literal));
CHECK_NE(computation.get(), nullptr); CHECK_NE(computation.get(), nullptr);
int64 computation_id = computation_proto.id(); int64 computation_id = computation_proto.id();
TF_RET_CHECK(computation_id != -1); TF_RET_CHECK(computation_id != -1);

View File

@ -219,7 +219,8 @@ class HloModule {
// Convert an HloModule to or from a proto. // Convert an HloModule to or from a proto.
HloModuleProto ToProto() const; HloModuleProto ToProto() const;
static StatusOr<std::unique_ptr<HloModule>> CreateFromProto( static StatusOr<std::unique_ptr<HloModule>> CreateFromProto(
const HloModuleProto& proto, const HloModuleConfig& module_config); const HloModuleProto& proto, const HloModuleConfig& module_config,
bool prohibit_empty_literal = true);
// Creates and returns an HloModuleConfig with an appropriate program shape // Creates and returns an HloModuleConfig with an appropriate program shape
// for the HLO module in the given proto. // for the HLO module in the given proto.

View File

@ -30,10 +30,6 @@ tf_upgrade_v2 --intree coolcode --outtree coolcode-upgraded --copyotherfiles Fal
*Note: `tf_upgrade_v2` is installed automatically as a script by the pip install *Note: `tf_upgrade_v2` is installed automatically as a script by the pip install
after TensorFlow 1.12. after TensorFlow 1.12.
You may want to retain revision history, especially when preparing a CL:
```
g4 integrate --retroactive coolcode/... coolcode-upgraded/...
```
## Report ## Report