Allow parsing HLO constant without literals. This feature is used when parsing HLO for evaluating cost model instead of running the HLO graph.

PiperOrigin-RevId: 281393128
Change-Id: I6e986f9b34304844942731cb364c9479a8edaf00
This commit is contained in:
A. Unique TensorFlower 2019-11-19 15:31:11 -08:00 committed by TensorFlower Gardener
parent bc45d196c3
commit 0231137ed0
11 changed files with 69 additions and 22 deletions

View File

@ -424,6 +424,7 @@ cc_library(
":xla_data_proto_cc",
"//tensorflow/core:lib",
"@com_google_absl//absl/base",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",

View File

@ -72,6 +72,16 @@ T GetRawValue(T val) {
}
uint16 GetRawValue(Eigen::half val) { return val.x; }
bool LiteralProtoHasValues(const LiteralProto& proto) {
return proto.preds_size() || !proto.s8s().empty() || !proto.u8s().empty() ||
proto.s32s_size() || proto.s64s_size() || proto.u32s_size() ||
proto.u64s_size() || proto.f32s_size() || proto.f64s_size() ||
proto.c64s_size() || proto.c128s_size() ||
proto.tuple_literals_size() || !proto.f16s().empty() ||
!proto.bf16s().empty() || !proto.u16s().empty() ||
!proto.s16s().empty() || proto.sparse_indices_size();
}
} // namespace
LiteralBase::~LiteralBase() {}
@ -288,7 +298,7 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
}
/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
const LiteralProto& proto) {
const LiteralProto& proto, bool prohibit_empty_literal) {
if (!proto.has_shape()) {
return InvalidArgument("LiteralProto has no shape");
}
@ -328,7 +338,13 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
}
CHECK(piece->subshape().IsArray());
TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
// When prohibit_empty_literal is false (allowing literal with no
// values), only copy from proto if the literal proto has values. This
// mode is used for a learned cost model.
if (prohibit_empty_literal || LiteralProtoHasValues(*proto_element)) {
TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
}
return Status::OK();
}));

View File

@ -732,7 +732,8 @@ class MutableLiteralBase : public LiteralBase {
static Literal MoveIntoTuple(absl::Span<Literal> elements);
// Serialize from a proto.
static StatusOr<Literal> CreateFromProto(const LiteralProto& proto);
static StatusOr<Literal> CreateFromProto(const LiteralProto& proto,
bool prohibit_empty_literal = true);
protected:
// Returns the piece at the given ShapeIndex.

View File

@ -1847,6 +1847,30 @@ TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
HasSubstr("Expected 3 elements in LiteralProto"));
}
TEST_F(LiteralUtilTest, ValidProtoNoValues) {
// Proto contains a shape, but no values.
LiteralProto proto;
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();
Status status =
Literal::CreateFromProto(proto, /*prohibit_empty_literal=*/false)
.status();
EXPECT_TRUE(status.ok());
}
TEST_F(LiteralUtilTest, ValidProtoWithClearedValues) {
auto literal = LiteralUtil::CreateR1<bool>({true, false, true});
LiteralProto proto = literal.ToProto();
EXPECT_EQ(proto.preds_size(), 3);
// Clear values.
proto.clear_preds();
EXPECT_EQ(proto.preds_size(), 0);
Status status =
Literal::CreateFromProto(proto, /*prohibit_empty_literal=*/false)
.status();
EXPECT_TRUE(status.ok());
}
TEST_F(LiteralUtilTest, InvalidProtoNoShape) {
// Proto contains values, but no shape.
LiteralProto proto;

View File

@ -640,16 +640,17 @@ HloComputationProto HloComputation::ToProto() const {
/* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto(
const HloComputationProto& proto,
const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
const absl::flat_hash_map<int64, HloComputation*>& computation_map,
bool prohibit_empty_literal) {
absl::flat_hash_map<int64, HloInstruction*> instruction_map;
absl::flat_hash_map<HloInstruction*, int64> to_proto_id;
std::vector<std::unique_ptr<HloInstruction>> instructions;
int64 parameter_count = 0;
for (const HloInstructionProto& instruction_proto : proto.instructions()) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloInstruction> instruction,
HloInstruction::CreateFromProto(instruction_proto, instruction_map,
computation_map));
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction,
HloInstruction::CreateFromProto(
instruction_proto, instruction_map, computation_map,
prohibit_empty_literal));
if (instruction->opcode() == HloOpcode::kParameter) {
parameter_count++;
}

View File

@ -235,7 +235,8 @@ class HloComputation {
// calls.
static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
const HloComputationProto& proto,
const absl::flat_hash_map<int64, HloComputation*>& computation_map);
const absl::flat_hash_map<int64, HloComputation*>& computation_map,
bool prohibit_empty_literal = true);
using InstructionSequence = tensorflow::gtl::iterator_range<
UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>;

View File

@ -62,7 +62,8 @@ using absl::StrJoin;
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
const HloInstructionProto& proto,
const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
const absl::flat_hash_map<int64, HloComputation*>& computation_map,
bool prohibit_empty_literal) {
TF_RET_CHECK(!proto.opcode().empty());
HloOpcode opcode;
auto opcode_or = StringToHloOpcode(proto.opcode());
@ -300,8 +301,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
case HloOpcode::kConstant: {
// TODO(b/110214922): Revert this to CHECK(proto.has_literal()).
if (proto.has_literal()) {
TF_ASSIGN_OR_RETURN(auto literal,
Literal::CreateFromProto(proto.literal()));
TF_ASSIGN_OR_RETURN(
auto literal,
Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
instruction = CreateConstant(std::move(literal));
// Literal's shape may have no/different tiling info.
TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
@ -314,8 +316,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
}
case HloOpcode::kTrace: {
TF_RET_CHECK(proto.has_literal());
TF_ASSIGN_OR_RETURN(auto literal,
Literal::CreateFromProto(proto.literal()));
TF_ASSIGN_OR_RETURN(
auto literal,
Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
break;
}

View File

@ -475,7 +475,8 @@ class HloInstruction {
static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
const HloInstructionProto& proto,
const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
const absl::flat_hash_map<int64, HloComputation*>& computation_map);
const absl::flat_hash_map<int64, HloComputation*>& computation_map,
bool prohibit_empty_literal = true);
// Creates a parameter-retrieving instruction.
static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,

View File

@ -294,7 +294,8 @@ Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() const {
/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
const HloModuleProto& proto, const HloModuleConfig& module_config) {
const HloModuleProto& proto, const HloModuleConfig& module_config,
bool prohibit_empty_literal) {
VLOG(2) << "CreateFromProto()";
XLA_VLOG_LINES(3, proto.DebugString());
@ -332,7 +333,8 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
for (const HloComputationProto& computation_proto : proto.computations()) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloComputation> computation,
HloComputation::CreateFromProto(computation_proto, computation_map));
HloComputation::CreateFromProto(computation_proto, computation_map,
prohibit_empty_literal));
CHECK_NE(computation.get(), nullptr);
int64 computation_id = computation_proto.id();
TF_RET_CHECK(computation_id != -1);

View File

@ -219,7 +219,8 @@ class HloModule {
// Convert an HloModule to or from a proto.
HloModuleProto ToProto() const;
static StatusOr<std::unique_ptr<HloModule>> CreateFromProto(
const HloModuleProto& proto, const HloModuleConfig& module_config);
const HloModuleProto& proto, const HloModuleConfig& module_config,
bool prohibit_empty_literal = true);
// Creates and returns an HloModuleConfig with an appropriate program shape
// for the HLO module in the given proto.

View File

@ -30,10 +30,6 @@ tf_upgrade_v2 --intree coolcode --outtree coolcode-upgraded --copyotherfiles Fal
*Note: `tf_upgrade_v2` is installed automatically as a script by the pip install
after TensorFlow 1.12.
You may want to retain revision history, especially when preparing a CL:
```
g4 integrate --retroactive coolcode/... coolcode-upgraded/...
```
## Report