Allow parsing HLO constant without literals. This feature is used when parsing HLO for evaluating cost model instead of running the HLO graph.
PiperOrigin-RevId: 281393128 Change-Id: I6e986f9b34304844942731cb364c9479a8edaf00
This commit is contained in:
parent
bc45d196c3
commit
0231137ed0
@ -424,6 +424,7 @@ cc_library(
|
||||
":xla_data_proto_cc",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
|
@ -72,6 +72,16 @@ T GetRawValue(T val) {
|
||||
}
|
||||
uint16 GetRawValue(Eigen::half val) { return val.x; }
|
||||
|
||||
bool LiteralProtoHasValues(const LiteralProto& proto) {
|
||||
return proto.preds_size() || !proto.s8s().empty() || !proto.u8s().empty() ||
|
||||
proto.s32s_size() || proto.s64s_size() || proto.u32s_size() ||
|
||||
proto.u64s_size() || proto.f32s_size() || proto.f64s_size() ||
|
||||
proto.c64s_size() || proto.c128s_size() ||
|
||||
proto.tuple_literals_size() || !proto.f16s().empty() ||
|
||||
!proto.bf16s().empty() || !proto.u16s().empty() ||
|
||||
!proto.s16s().empty() || proto.sparse_indices_size();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LiteralBase::~LiteralBase() {}
|
||||
@ -288,7 +298,7 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
|
||||
}
|
||||
|
||||
/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
|
||||
const LiteralProto& proto) {
|
||||
const LiteralProto& proto, bool prohibit_empty_literal) {
|
||||
if (!proto.has_shape()) {
|
||||
return InvalidArgument("LiteralProto has no shape");
|
||||
}
|
||||
@ -328,7 +338,13 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
|
||||
}
|
||||
|
||||
CHECK(piece->subshape().IsArray());
|
||||
TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
|
||||
|
||||
// When prohibit_empty_literal is false (allowing literal with no
|
||||
// values), only copy from proto if the literal proto has values. This
|
||||
// mode is used for a learned cost model.
|
||||
if (prohibit_empty_literal || LiteralProtoHasValues(*proto_element)) {
|
||||
TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}));
|
||||
|
@ -732,7 +732,8 @@ class MutableLiteralBase : public LiteralBase {
|
||||
static Literal MoveIntoTuple(absl::Span<Literal> elements);
|
||||
|
||||
// Serialize from a proto.
|
||||
static StatusOr<Literal> CreateFromProto(const LiteralProto& proto);
|
||||
static StatusOr<Literal> CreateFromProto(const LiteralProto& proto,
|
||||
bool prohibit_empty_literal = true);
|
||||
|
||||
protected:
|
||||
// Returns the piece at the given ShapeIndex.
|
||||
|
@ -1847,6 +1847,30 @@ TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
|
||||
HasSubstr("Expected 3 elements in LiteralProto"));
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, ValidProtoNoValues) {
|
||||
// Proto contains a shape, but no values.
|
||||
LiteralProto proto;
|
||||
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();
|
||||
Status status =
|
||||
Literal::CreateFromProto(proto, /*prohibit_empty_literal=*/false)
|
||||
.status();
|
||||
EXPECT_TRUE(status.ok());
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, ValidProtoWithClearedValues) {
|
||||
auto literal = LiteralUtil::CreateR1<bool>({true, false, true});
|
||||
LiteralProto proto = literal.ToProto();
|
||||
EXPECT_EQ(proto.preds_size(), 3);
|
||||
|
||||
// Clear values.
|
||||
proto.clear_preds();
|
||||
EXPECT_EQ(proto.preds_size(), 0);
|
||||
Status status =
|
||||
Literal::CreateFromProto(proto, /*prohibit_empty_literal=*/false)
|
||||
.status();
|
||||
EXPECT_TRUE(status.ok());
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, InvalidProtoNoShape) {
|
||||
// Proto contains values, but no shape.
|
||||
LiteralProto proto;
|
||||
|
@ -640,16 +640,17 @@ HloComputationProto HloComputation::ToProto() const {
|
||||
/* static */ StatusOr<std::unique_ptr<HloComputation>>
|
||||
HloComputation::CreateFromProto(
|
||||
const HloComputationProto& proto,
|
||||
const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
|
||||
const absl::flat_hash_map<int64, HloComputation*>& computation_map,
|
||||
bool prohibit_empty_literal) {
|
||||
absl::flat_hash_map<int64, HloInstruction*> instruction_map;
|
||||
absl::flat_hash_map<HloInstruction*, int64> to_proto_id;
|
||||
std::vector<std::unique_ptr<HloInstruction>> instructions;
|
||||
int64 parameter_count = 0;
|
||||
for (const HloInstructionProto& instruction_proto : proto.instructions()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<HloInstruction> instruction,
|
||||
HloInstruction::CreateFromProto(instruction_proto, instruction_map,
|
||||
computation_map));
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction,
|
||||
HloInstruction::CreateFromProto(
|
||||
instruction_proto, instruction_map, computation_map,
|
||||
prohibit_empty_literal));
|
||||
if (instruction->opcode() == HloOpcode::kParameter) {
|
||||
parameter_count++;
|
||||
}
|
||||
|
@ -235,7 +235,8 @@ class HloComputation {
|
||||
// calls.
|
||||
static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
|
||||
const HloComputationProto& proto,
|
||||
const absl::flat_hash_map<int64, HloComputation*>& computation_map);
|
||||
const absl::flat_hash_map<int64, HloComputation*>& computation_map,
|
||||
bool prohibit_empty_literal = true);
|
||||
|
||||
using InstructionSequence = tensorflow::gtl::iterator_range<
|
||||
UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>;
|
||||
|
@ -62,7 +62,8 @@ using absl::StrJoin;
|
||||
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
const HloInstructionProto& proto,
|
||||
const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
|
||||
const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
|
||||
const absl::flat_hash_map<int64, HloComputation*>& computation_map,
|
||||
bool prohibit_empty_literal) {
|
||||
TF_RET_CHECK(!proto.opcode().empty());
|
||||
HloOpcode opcode;
|
||||
auto opcode_or = StringToHloOpcode(proto.opcode());
|
||||
@ -300,8 +301,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
case HloOpcode::kConstant: {
|
||||
// TODO(b/110214922): Revert this to CHECK(proto.has_literal()).
|
||||
if (proto.has_literal()) {
|
||||
TF_ASSIGN_OR_RETURN(auto literal,
|
||||
Literal::CreateFromProto(proto.literal()));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto literal,
|
||||
Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
|
||||
instruction = CreateConstant(std::move(literal));
|
||||
// Literal's shape may have no/different tiling info.
|
||||
TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
|
||||
@ -314,8 +316,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
}
|
||||
case HloOpcode::kTrace: {
|
||||
TF_RET_CHECK(proto.has_literal());
|
||||
TF_ASSIGN_OR_RETURN(auto literal,
|
||||
Literal::CreateFromProto(proto.literal()));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto literal,
|
||||
Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
|
||||
instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
|
||||
break;
|
||||
}
|
||||
|
@ -475,7 +475,8 @@ class HloInstruction {
|
||||
static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
|
||||
const HloInstructionProto& proto,
|
||||
const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
|
||||
const absl::flat_hash_map<int64, HloComputation*>& computation_map);
|
||||
const absl::flat_hash_map<int64, HloComputation*>& computation_map,
|
||||
bool prohibit_empty_literal = true);
|
||||
|
||||
// Creates a parameter-retrieving instruction.
|
||||
static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
|
||||
|
@ -294,7 +294,8 @@ Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() const {
|
||||
|
||||
/* static */
|
||||
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
|
||||
const HloModuleProto& proto, const HloModuleConfig& module_config) {
|
||||
const HloModuleProto& proto, const HloModuleConfig& module_config,
|
||||
bool prohibit_empty_literal) {
|
||||
VLOG(2) << "CreateFromProto()";
|
||||
XLA_VLOG_LINES(3, proto.DebugString());
|
||||
|
||||
@ -332,7 +333,8 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
|
||||
for (const HloComputationProto& computation_proto : proto.computations()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<HloComputation> computation,
|
||||
HloComputation::CreateFromProto(computation_proto, computation_map));
|
||||
HloComputation::CreateFromProto(computation_proto, computation_map,
|
||||
prohibit_empty_literal));
|
||||
CHECK_NE(computation.get(), nullptr);
|
||||
int64 computation_id = computation_proto.id();
|
||||
TF_RET_CHECK(computation_id != -1);
|
||||
|
@ -219,7 +219,8 @@ class HloModule {
|
||||
// Convert an HloModule to or from a proto.
|
||||
HloModuleProto ToProto() const;
|
||||
static StatusOr<std::unique_ptr<HloModule>> CreateFromProto(
|
||||
const HloModuleProto& proto, const HloModuleConfig& module_config);
|
||||
const HloModuleProto& proto, const HloModuleConfig& module_config,
|
||||
bool prohibit_empty_literal = true);
|
||||
|
||||
// Creates and returns an HloModuleConfig with an appropriate program shape
|
||||
// for the HLO module in the given proto.
|
||||
|
@ -30,10 +30,6 @@ tf_upgrade_v2 --intree coolcode --outtree coolcode-upgraded --copyotherfiles Fal
|
||||
*Note: `tf_upgrade_v2` is installed automatically as a script by the pip install
|
||||
after TensorFlow 1.12.
|
||||
|
||||
You may want to retain revision history, especially when preparing a CL:
|
||||
```
|
||||
g4 integrate --retroactive coolcode/... coolcode-upgraded/...
|
||||
```
|
||||
|
||||
## Report
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user