Allow parsing HLO constant without literals. This feature is used when parsing HLO for evaluating cost model instead of running the HLO graph.
PiperOrigin-RevId: 281393128 Change-Id: I6e986f9b34304844942731cb364c9479a8edaf00
This commit is contained in:
parent
bc45d196c3
commit
0231137ed0
@ -424,6 +424,7 @@ cc_library(
|
|||||||
":xla_data_proto_cc",
|
":xla_data_proto_cc",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@com_google_absl//absl/base",
|
"@com_google_absl//absl/base",
|
||||||
|
"@com_google_absl//absl/flags:flag",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
@ -72,6 +72,16 @@ T GetRawValue(T val) {
|
|||||||
}
|
}
|
||||||
uint16 GetRawValue(Eigen::half val) { return val.x; }
|
uint16 GetRawValue(Eigen::half val) { return val.x; }
|
||||||
|
|
||||||
|
bool LiteralProtoHasValues(const LiteralProto& proto) {
|
||||||
|
return proto.preds_size() || !proto.s8s().empty() || !proto.u8s().empty() ||
|
||||||
|
proto.s32s_size() || proto.s64s_size() || proto.u32s_size() ||
|
||||||
|
proto.u64s_size() || proto.f32s_size() || proto.f64s_size() ||
|
||||||
|
proto.c64s_size() || proto.c128s_size() ||
|
||||||
|
proto.tuple_literals_size() || !proto.f16s().empty() ||
|
||||||
|
!proto.bf16s().empty() || !proto.u16s().empty() ||
|
||||||
|
!proto.s16s().empty() || proto.sparse_indices_size();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LiteralBase::~LiteralBase() {}
|
LiteralBase::~LiteralBase() {}
|
||||||
@ -288,7 +298,7 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
|
|||||||
}
|
}
|
||||||
|
|
||||||
/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
|
/* static */ StatusOr<Literal> MutableLiteralBase::CreateFromProto(
|
||||||
const LiteralProto& proto) {
|
const LiteralProto& proto, bool prohibit_empty_literal) {
|
||||||
if (!proto.has_shape()) {
|
if (!proto.has_shape()) {
|
||||||
return InvalidArgument("LiteralProto has no shape");
|
return InvalidArgument("LiteralProto has no shape");
|
||||||
}
|
}
|
||||||
@ -328,7 +338,13 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
|
|||||||
}
|
}
|
||||||
|
|
||||||
CHECK(piece->subshape().IsArray());
|
CHECK(piece->subshape().IsArray());
|
||||||
|
|
||||||
|
// When prohibit_empty_literal is false (allowing literal with no
|
||||||
|
// values), only copy from proto if the literal proto has values. This
|
||||||
|
// mode is used for a learned cost model.
|
||||||
|
if (prohibit_empty_literal || LiteralProtoHasValues(*proto_element)) {
|
||||||
TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
|
TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}));
|
}));
|
||||||
|
@ -732,7 +732,8 @@ class MutableLiteralBase : public LiteralBase {
|
|||||||
static Literal MoveIntoTuple(absl::Span<Literal> elements);
|
static Literal MoveIntoTuple(absl::Span<Literal> elements);
|
||||||
|
|
||||||
// Serialize from a proto.
|
// Serialize from a proto.
|
||||||
static StatusOr<Literal> CreateFromProto(const LiteralProto& proto);
|
static StatusOr<Literal> CreateFromProto(const LiteralProto& proto,
|
||||||
|
bool prohibit_empty_literal = true);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Returns the piece at the given ShapeIndex.
|
// Returns the piece at the given ShapeIndex.
|
||||||
|
@ -1847,6 +1847,30 @@ TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
|
|||||||
HasSubstr("Expected 3 elements in LiteralProto"));
|
HasSubstr("Expected 3 elements in LiteralProto"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(LiteralUtilTest, ValidProtoNoValues) {
|
||||||
|
// Proto contains a shape, but no values.
|
||||||
|
LiteralProto proto;
|
||||||
|
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();
|
||||||
|
Status status =
|
||||||
|
Literal::CreateFromProto(proto, /*prohibit_empty_literal=*/false)
|
||||||
|
.status();
|
||||||
|
EXPECT_TRUE(status.ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LiteralUtilTest, ValidProtoWithClearedValues) {
|
||||||
|
auto literal = LiteralUtil::CreateR1<bool>({true, false, true});
|
||||||
|
LiteralProto proto = literal.ToProto();
|
||||||
|
EXPECT_EQ(proto.preds_size(), 3);
|
||||||
|
|
||||||
|
// Clear values.
|
||||||
|
proto.clear_preds();
|
||||||
|
EXPECT_EQ(proto.preds_size(), 0);
|
||||||
|
Status status =
|
||||||
|
Literal::CreateFromProto(proto, /*prohibit_empty_literal=*/false)
|
||||||
|
.status();
|
||||||
|
EXPECT_TRUE(status.ok());
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(LiteralUtilTest, InvalidProtoNoShape) {
|
TEST_F(LiteralUtilTest, InvalidProtoNoShape) {
|
||||||
// Proto contains values, but no shape.
|
// Proto contains values, but no shape.
|
||||||
LiteralProto proto;
|
LiteralProto proto;
|
||||||
|
@ -640,16 +640,17 @@ HloComputationProto HloComputation::ToProto() const {
|
|||||||
/* static */ StatusOr<std::unique_ptr<HloComputation>>
|
/* static */ StatusOr<std::unique_ptr<HloComputation>>
|
||||||
HloComputation::CreateFromProto(
|
HloComputation::CreateFromProto(
|
||||||
const HloComputationProto& proto,
|
const HloComputationProto& proto,
|
||||||
const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
|
const absl::flat_hash_map<int64, HloComputation*>& computation_map,
|
||||||
|
bool prohibit_empty_literal) {
|
||||||
absl::flat_hash_map<int64, HloInstruction*> instruction_map;
|
absl::flat_hash_map<int64, HloInstruction*> instruction_map;
|
||||||
absl::flat_hash_map<HloInstruction*, int64> to_proto_id;
|
absl::flat_hash_map<HloInstruction*, int64> to_proto_id;
|
||||||
std::vector<std::unique_ptr<HloInstruction>> instructions;
|
std::vector<std::unique_ptr<HloInstruction>> instructions;
|
||||||
int64 parameter_count = 0;
|
int64 parameter_count = 0;
|
||||||
for (const HloInstructionProto& instruction_proto : proto.instructions()) {
|
for (const HloInstructionProto& instruction_proto : proto.instructions()) {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction,
|
||||||
std::unique_ptr<HloInstruction> instruction,
|
HloInstruction::CreateFromProto(
|
||||||
HloInstruction::CreateFromProto(instruction_proto, instruction_map,
|
instruction_proto, instruction_map, computation_map,
|
||||||
computation_map));
|
prohibit_empty_literal));
|
||||||
if (instruction->opcode() == HloOpcode::kParameter) {
|
if (instruction->opcode() == HloOpcode::kParameter) {
|
||||||
parameter_count++;
|
parameter_count++;
|
||||||
}
|
}
|
||||||
|
@ -235,7 +235,8 @@ class HloComputation {
|
|||||||
// calls.
|
// calls.
|
||||||
static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
|
static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
|
||||||
const HloComputationProto& proto,
|
const HloComputationProto& proto,
|
||||||
const absl::flat_hash_map<int64, HloComputation*>& computation_map);
|
const absl::flat_hash_map<int64, HloComputation*>& computation_map,
|
||||||
|
bool prohibit_empty_literal = true);
|
||||||
|
|
||||||
using InstructionSequence = tensorflow::gtl::iterator_range<
|
using InstructionSequence = tensorflow::gtl::iterator_range<
|
||||||
UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>;
|
UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>;
|
||||||
|
@ -62,7 +62,8 @@ using absl::StrJoin;
|
|||||||
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||||
const HloInstructionProto& proto,
|
const HloInstructionProto& proto,
|
||||||
const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
|
const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
|
||||||
const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
|
const absl::flat_hash_map<int64, HloComputation*>& computation_map,
|
||||||
|
bool prohibit_empty_literal) {
|
||||||
TF_RET_CHECK(!proto.opcode().empty());
|
TF_RET_CHECK(!proto.opcode().empty());
|
||||||
HloOpcode opcode;
|
HloOpcode opcode;
|
||||||
auto opcode_or = StringToHloOpcode(proto.opcode());
|
auto opcode_or = StringToHloOpcode(proto.opcode());
|
||||||
@ -300,8 +301,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
|||||||
case HloOpcode::kConstant: {
|
case HloOpcode::kConstant: {
|
||||||
// TODO(b/110214922): Revert this to CHECK(proto.has_literal()).
|
// TODO(b/110214922): Revert this to CHECK(proto.has_literal()).
|
||||||
if (proto.has_literal()) {
|
if (proto.has_literal()) {
|
||||||
TF_ASSIGN_OR_RETURN(auto literal,
|
TF_ASSIGN_OR_RETURN(
|
||||||
Literal::CreateFromProto(proto.literal()));
|
auto literal,
|
||||||
|
Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
|
||||||
instruction = CreateConstant(std::move(literal));
|
instruction = CreateConstant(std::move(literal));
|
||||||
// Literal's shape may have no/different tiling info.
|
// Literal's shape may have no/different tiling info.
|
||||||
TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
|
TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
|
||||||
@ -314,8 +316,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
|||||||
}
|
}
|
||||||
case HloOpcode::kTrace: {
|
case HloOpcode::kTrace: {
|
||||||
TF_RET_CHECK(proto.has_literal());
|
TF_RET_CHECK(proto.has_literal());
|
||||||
TF_ASSIGN_OR_RETURN(auto literal,
|
TF_ASSIGN_OR_RETURN(
|
||||||
Literal::CreateFromProto(proto.literal()));
|
auto literal,
|
||||||
|
Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
|
||||||
instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
|
instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -475,7 +475,8 @@ class HloInstruction {
|
|||||||
static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
|
static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
|
||||||
const HloInstructionProto& proto,
|
const HloInstructionProto& proto,
|
||||||
const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
|
const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
|
||||||
const absl::flat_hash_map<int64, HloComputation*>& computation_map);
|
const absl::flat_hash_map<int64, HloComputation*>& computation_map,
|
||||||
|
bool prohibit_empty_literal = true);
|
||||||
|
|
||||||
// Creates a parameter-retrieving instruction.
|
// Creates a parameter-retrieving instruction.
|
||||||
static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
|
static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
|
||||||
|
@ -294,7 +294,8 @@ Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() const {
|
|||||||
|
|
||||||
/* static */
|
/* static */
|
||||||
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
|
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
|
||||||
const HloModuleProto& proto, const HloModuleConfig& module_config) {
|
const HloModuleProto& proto, const HloModuleConfig& module_config,
|
||||||
|
bool prohibit_empty_literal) {
|
||||||
VLOG(2) << "CreateFromProto()";
|
VLOG(2) << "CreateFromProto()";
|
||||||
XLA_VLOG_LINES(3, proto.DebugString());
|
XLA_VLOG_LINES(3, proto.DebugString());
|
||||||
|
|
||||||
@ -332,7 +333,8 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
|
|||||||
for (const HloComputationProto& computation_proto : proto.computations()) {
|
for (const HloComputationProto& computation_proto : proto.computations()) {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<HloComputation> computation,
|
std::unique_ptr<HloComputation> computation,
|
||||||
HloComputation::CreateFromProto(computation_proto, computation_map));
|
HloComputation::CreateFromProto(computation_proto, computation_map,
|
||||||
|
prohibit_empty_literal));
|
||||||
CHECK_NE(computation.get(), nullptr);
|
CHECK_NE(computation.get(), nullptr);
|
||||||
int64 computation_id = computation_proto.id();
|
int64 computation_id = computation_proto.id();
|
||||||
TF_RET_CHECK(computation_id != -1);
|
TF_RET_CHECK(computation_id != -1);
|
||||||
|
@ -219,7 +219,8 @@ class HloModule {
|
|||||||
// Convert an HloModule to or from a proto.
|
// Convert an HloModule to or from a proto.
|
||||||
HloModuleProto ToProto() const;
|
HloModuleProto ToProto() const;
|
||||||
static StatusOr<std::unique_ptr<HloModule>> CreateFromProto(
|
static StatusOr<std::unique_ptr<HloModule>> CreateFromProto(
|
||||||
const HloModuleProto& proto, const HloModuleConfig& module_config);
|
const HloModuleProto& proto, const HloModuleConfig& module_config,
|
||||||
|
bool prohibit_empty_literal = true);
|
||||||
|
|
||||||
// Creates and returns an HloModuleConfig with an appropriate program shape
|
// Creates and returns an HloModuleConfig with an appropriate program shape
|
||||||
// for the HLO module in the given proto.
|
// for the HLO module in the given proto.
|
||||||
|
@ -30,10 +30,6 @@ tf_upgrade_v2 --intree coolcode --outtree coolcode-upgraded --copyotherfiles Fal
|
|||||||
*Note: `tf_upgrade_v2` is installed automatically as a script by the pip install
|
*Note: `tf_upgrade_v2` is installed automatically as a script by the pip install
|
||||||
after TensorFlow 1.12.
|
after TensorFlow 1.12.
|
||||||
|
|
||||||
You may want to retain revision history, especially when preparing a CL:
|
|
||||||
```
|
|
||||||
g4 integrate --retroactive coolcode/... coolcode-upgraded/...
|
|
||||||
```
|
|
||||||
|
|
||||||
## Report
|
## Report
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user