From 07aae20777c8da8d1dfd33b76afa8609cb364743 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 22 Oct 2019 06:50:02 -0700 Subject: [PATCH] Move ParseHloString to VerifiedHloModule. To make this possible, add HloParser interface which only allows VerifiedHloModule to instantiate a parser and make the private HloParser class a HloParserImpl child class. Also use std::string instead of string in hlo_parser.cc. Finally add the verification at the end of the parsing instead of calling Verify() at various call sites. PiperOrigin-RevId: 276055152 Change-Id: I647e7a14e1ff9ae0aa1ba764af8718c753226e6b --- tensorflow/compiler/xla/service/hlo_parser.cc | 476 +++++++++--------- tensorflow/compiler/xla/service/hlo_parser.h | 21 +- .../compiler/xla/service/hlo_parser_test.cc | 6 +- .../conv_emitter/conv_emitter_test.cc | 3 +- .../compiler/xla/service/tuple_util_test.cc | 3 +- tensorflow/compiler/xla/tests/BUILD | 2 + .../compiler/xla/tests/hlo_test_base.cc | 3 +- .../xla/tests/local_client_test_base.cc | 3 +- .../compiler/xla/tests/verified_hlo_module.cc | 27 +- .../compiler/xla/tests/verified_hlo_module.h | 14 +- 10 files changed, 299 insertions(+), 259 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 1a701e343d7..da02b6b405d 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -15,7 +15,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include +#include #include +#include +#include +#include #include "absl/algorithm/container.h" #include "absl/memory/memory.h" @@ -72,18 +77,18 @@ HloSchedule ScheduleFromInstructionOrder(HloModule* module) { using LinearOrMultiIndex = absl::variant>; // Parser for the HloModule::ToString() format text. -class HloParser { +class HloParserImpl : public HloParser { public: using LocTy = HloLexer::LocTy; - explicit HloParser(absl::string_view str) : lexer_(str) {} + explicit HloParserImpl(absl::string_view str) : lexer_(str) {} // Runs the parser and constructs the resulting HLO in the given (empty) - // HloModule. Returns false if an error occurred. - Status Run(HloModule* module); + // HloModule. Returns the error status in case an error occurred. + Status Run(HloModule* module) override; // Returns the error information. - string GetError() const { return StrJoin(error_, "\n"); } + std::string GetError() const { return StrJoin(error_, "\n"); } // Stand alone parsing utils for various aggregate data types. StatusOr ParseShapeOnly(); @@ -97,7 +102,7 @@ class HloParser { private: using InstrNameTable = - std::unordered_map>; + std::unordered_map>; // Returns the map from the instruction name to the instruction itself and its // location in the current scope. @@ -111,7 +116,7 @@ class HloParser { // create an instruction. This is useful when we reify parameters as they're // resolved; i.e. for ParseSingleInstruction. std::pair* FindInstruction( - const string& name, const optional& shape = nullopt); + const std::string& name, const optional& shape = nullopt); // Parse a single instruction worth of text. bool ParseSingleInstruction(HloModule* module); @@ -122,10 +127,11 @@ class HloParser { bool ParseComputations(HloModule* module); bool ParseComputation(HloComputation** entry_computation); bool ParseInstructionList(HloComputation** computation, - const string& computation_name); - bool ParseInstruction(HloComputation::Builder* builder, string* root_name); - bool ParseInstructionRhs(HloComputation::Builder* builder, const string& name, - LocTy name_loc); + const std::string& computation_name); + bool ParseInstruction(HloComputation::Builder* builder, + std::string* root_name); + bool ParseInstructionRhs(HloComputation::Builder* builder, + const std::string& name, LocTy name_loc); bool ParseControlPredecessors(HloInstruction* instruction); bool ParseLiteral(Literal* literal, const Shape& shape); bool ParseTupleLiteral(Literal* literal, const Shape& shape); @@ -222,7 +228,7 @@ class HloParser { // // Example usage: // - // std::unordered_map attrs; + // std::unordered_map attrs; // optional foo; // attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo}; // optional bar; @@ -233,25 +239,28 @@ class HloParser { // // Do something with 'bar'. // if (foo) { // If attr foo is seen, do something with 'foo'. } // - bool ParseAttributes(const std::unordered_map& attrs); + bool ParseAttributes( + const std::unordered_map& attrs); // sub_attributes ::= '{' (','? attribute)* '}' // // Usage is the same as ParseAttributes. See immediately above. - bool ParseSubAttributes(const std::unordered_map& attrs); + bool ParseSubAttributes( + const std::unordered_map& attrs); // Parses one attribute. If it has already been seen, return error. Returns // true and adds to seen_attrs on success. // // Do not call this except in ParseAttributes or ParseSubAttributes. - bool ParseAttributeHelper(const std::unordered_map& attrs, - std::unordered_set* seen_attrs); + bool ParseAttributeHelper( + const std::unordered_map& attrs, + std::unordered_set* seen_attrs); // Parses an attribute string into a protocol buffer `message`. // Since proto3 has no notion of mandatory fields, `required_attrs` gives the // set of mandatory attributes. bool ParseAttributesAsProtoMessage( - const std::unordered_set& required_attrs, + const std::unordered_set& required_attrs, tensorflow::protobuf::Message* message); // Parses one attribute. If it has already been seen, return error. Returns @@ -260,7 +269,7 @@ class HloParser { // Do not call this except in ParseAttributesAsProtoMessage. bool ParseAttributeAsProtoMessageHelper( tensorflow::protobuf::Message* message, - std::unordered_set* seen_attrs); + std::unordered_set* seen_attrs); // Parses a name and finds the corresponding hlo computation. bool ParseComputationName(HloComputation** value); @@ -282,7 +291,7 @@ class HloParser { bool ParseDomain(DomainData* domain); // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3. - bool ParseDxD(const string& name, std::vector* result); + bool ParseDxD(const std::string& name, std::vector* result); // Parses window's pad sub-attriute, e.g., pad=0_0x3x3. bool ParseWindowPad(std::vector>* pad); @@ -303,9 +312,9 @@ class HloParser { bool ParseParamListToShape(Shape* shape, LocTy* shape_loc); bool ParseParamList(); - bool ParseName(string* result); - bool ParseAttributeName(string* result); - bool ParseString(string* result); + bool ParseName(std::string* result); + bool ParseAttributeName(std::string* result); + bool ParseString(std::string* result); bool ParseDimensionSizes(std::vector* dimension_sizes, std::vector* dynamic_dimensions); bool ParseShape(Shape* result); @@ -323,7 +332,7 @@ class HloParser { bool ParseDouble(double* result); bool ParseComplex(std::complex* result); bool ParseBool(bool* result); - bool ParseToken(TokKind kind, const string& msg); + bool ParseToken(TokKind kind, const std::string& msg); // Returns true if the current token is the beginning of a shape. bool CanBeShape(); @@ -341,11 +350,11 @@ class HloParser { // Adds the instruction to the pool. Returns false and emits an error if the // instruction already exists. - bool AddInstruction(const string& name, HloInstruction* instruction, + bool AddInstruction(const std::string& name, HloInstruction* instruction, LocTy name_loc); // Adds the computation to the pool. Returns false and emits an error if the // computation already exists. - bool AddComputation(const string& name, HloComputation* computation, + bool AddComputation(const std::string& name, HloComputation* computation, LocTy name_loc); HloLexer lexer_; @@ -374,11 +383,11 @@ class HloParser { }; // Map from the computation name to the computation itself and its location. - std::unordered_map> + std::unordered_map> computation_pool_; std::vector> computations_; - std::vector error_; + std::vector error_; // When an operand name cannot be resolved, this function is called to create // a parameter instruction with the given name and shape. It registers the @@ -386,7 +395,7 @@ class HloParser { // the newly-created instruction and the placeholder location. If `name` is // empty, this should create the parameter with a generated name. This is // supposed to be set and used only in ParseSingleInstruction. - std::function*(const string& name, + std::function*(const std::string& name, const Shape& shape)> create_missing_instruction_; }; @@ -416,26 +425,26 @@ std::vector CreateReplicaGroups( return replica_groups; } -bool HloParser::Error(LocTy loc, absl::string_view msg) { +bool HloParserImpl::Error(LocTy loc, absl::string_view msg) { auto line_col = lexer_.GetLineAndColumn(loc); const unsigned line = line_col.first; const unsigned col = line_col.second; - std::vector error_lines; + std::vector error_lines; error_lines.push_back( StrCat("was parsing ", line, ":", col, ": error: ", msg)); error_lines.emplace_back(lexer_.GetLine(loc)); - error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^")); + error_lines.push_back(col == 0 ? "" : StrCat(std::string(col - 1, ' '), "^")); error_.push_back(StrJoin(error_lines, "\n")); VLOG(1) << "Error: " << error_.back(); return false; } -bool HloParser::TokenError(absl::string_view msg) { +bool HloParserImpl::TokenError(absl::string_view msg) { return Error(lexer_.GetLoc(), msg); } -Status HloParser::Run(HloModule* module) { +Status HloParserImpl::Run(HloModule* module) { lexer_.Lex(); if (lexer_.GetKind() == TokKind::kw_HloModule) { // This means that the text contains a full HLO module. @@ -456,8 +465,9 @@ Status HloParser::Run(HloModule* module) { return Status::OK(); } -std::pair* HloParser::FindInstruction( - const string& name, const optional& shape) { +std::pair* +HloParserImpl::FindInstruction(const std::string& name, + const optional& shape) { std::pair* instr = nullptr; if (!name.empty()) { instr = tensorflow::gtl::FindOrNull(current_name_table(), name); @@ -490,20 +500,20 @@ std::pair* HloParser::FindInstruction( } // ::= 'HloModule' name computations -bool HloParser::ParseHloModule(HloModule* module) { +bool HloParserImpl::ParseHloModule(HloModule* module) { if (lexer_.GetKind() != TokKind::kw_HloModule) { return TokenError("expects HloModule"); } // Eat 'HloModule' lexer_.Lex(); - string name; + std::string name; if (!ParseName(&name)) { return false; } absl::optional is_scheduled; - std::unordered_map attrs; + std::unordered_map attrs; attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled}; if (!ParseAttributes(attrs)) { return false; @@ -522,7 +532,7 @@ bool HloParser::ParseHloModule(HloModule* module) { } // computations ::= (computation)+ -bool HloParser::ParseComputations(HloModule* module) { +bool HloParserImpl::ParseComputations(HloModule* module) { HloComputation* entry_computation = nullptr; do { if (!ParseComputation(&entry_computation)) { @@ -559,11 +569,11 @@ bool HloParser::ParseComputations(HloModule* module) { } // computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list -bool HloParser::ParseComputation(HloComputation** entry_computation) { +bool HloParserImpl::ParseComputation(HloComputation** entry_computation) { LocTy maybe_entry_loc = lexer_.GetLoc(); const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY); - string name; + std::string name; LocTy name_loc = lexer_.GetLoc(); if (!ParseName(&name)) { return false; @@ -604,15 +614,15 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { // instruction_list ::= '{' instruction_list1 '}' // instruction_list1 ::= (instruction)+ -bool HloParser::ParseInstructionList(HloComputation** computation, - const string& computation_name) { +bool HloParserImpl::ParseInstructionList(HloComputation** computation, + const std::string& computation_name) { Scope scope(&scoped_name_tables_); HloComputation::Builder builder(computation_name); if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of instruction list.")) { return false; } - string root_name; + std::string root_name; do { if (!ParseInstruction(&builder, &root_name)) { return false; @@ -645,9 +655,9 @@ bool HloParser::ParseInstructionList(HloComputation** computation, } // instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)* -bool HloParser::ParseInstruction(HloComputation::Builder* builder, - string* root_name) { - string name; +bool HloParserImpl::ParseInstruction(HloComputation::Builder* builder, + std::string* root_name) { + std::string name; LocTy maybe_root_loc = lexer_.GetLoc(); bool is_root = EatIfPresent(TokKind::kw_ROOT); @@ -667,8 +677,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, return ParseInstructionRhs(builder, name, name_loc); } -bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, - const string& name, LocTy name_loc) { +bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, + const std::string& name, + LocTy name_loc) { Shape shape; HloOpcode opcode; std::vector operands; @@ -678,7 +689,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, } // Add optional attributes. - std::unordered_map attrs; + std::unordered_map attrs; optional sharding; optional frontend_attributes; attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding}; @@ -694,7 +705,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional metadata; attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata}; - optional backend_config; + optional backend_config; attrs["backend_config"] = {/*required=*/false, AttrTy::kString, &backend_config}; optional> outer_dimension_partitions; @@ -1155,7 +1166,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, TriangularSolveOptions options; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributesAsProtoMessage( - /*required_attrs=*/std::unordered_set(), &options)) { + /*required_attrs=*/std::unordered_set(), &options)) { return false; } instruction = @@ -1179,7 +1190,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, CholeskyOptions options; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributesAsProtoMessage( - /*required_attrs=*/std::unordered_set(), &options)) { + /*required_attrs=*/std::unordered_set(), &options)) { return false; } instruction = builder->AddInstruction( @@ -1419,7 +1430,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kInfeed: { - optional config; + optional config; attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -1440,7 +1451,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kOutfeed: { - optional config; + optional config; attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { @@ -1533,7 +1544,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, break; } case HloOpcode::kCustomCall: { - optional custom_call_target; + optional custom_call_target; optional window; optional dnums; optional feature_group_count; @@ -1840,7 +1851,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, // ::= '{' (single_sharding | tuple_sharding) '}' // // tuple_sharding ::= single_sharding* (',' single_sharding)* -bool HloParser::ParseSharding(OpSharding* sharding) { +bool HloParserImpl::ParseSharding(OpSharding* sharding) { // A single sharding starts with '{' and is not followed by '{'. // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for // an empty tuple. @@ -1873,7 +1884,7 @@ bool HloParser::ParseSharding(OpSharding* sharding) { // attributes // ::= /*empty*/ // ::= attribute '=' value (',' attribute '=' value)* -bool HloParser::ParseFrontendAttributes( +bool HloParserImpl::ParseFrontendAttributes( FrontendAttributes* frontend_attributes) { CHECK(frontend_attributes != nullptr); if (!ParseToken(TokKind::kLbrace, @@ -1884,7 +1895,7 @@ bool HloParser::ParseFrontendAttributes( // empty } else { do { - string attribute; + std::string attribute; if (!ParseAttributeName(&attribute)) { return false; } @@ -1902,8 +1913,8 @@ bool HloParser::ParseFrontendAttributes( // ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? // ('devices=' ('[' dims ']')* device_list)? '}' // dims ::= int_list device_list ::= int_list -bool HloParser::ParseSingleSharding(OpSharding* sharding, - bool lbrace_pre_lexed) { +bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, + bool lbrace_pre_lexed) { if (!lbrace_pre_lexed && !ParseToken(TokKind::kLbrace, "expected '{' to start sharding attribute")) { @@ -2010,7 +2021,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, // parameter_replication ::= // '{' ('true' | 'false')* (',' ('true' | 'false'))* '}' -bool HloParser::ParseParameterReplication( +bool HloParserImpl::ParseParameterReplication( ParameterReplication* parameter_replication) { if (!ParseToken(TokKind::kLbrace, "expected '{' to start parameter_replication attribute")) { @@ -2042,7 +2053,7 @@ bool HloParser::ParseParameterReplication( // int64_elements // ::= /*empty*/ // ::= int64_val (',' int64_val)* -bool HloParser::ParseReplicaGroupsOnly( +bool HloParserImpl::ParseReplicaGroupsOnly( std::vector* replica_groups) { std::vector> result; if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, @@ -2055,9 +2066,9 @@ bool HloParser::ParseReplicaGroupsOnly( // domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ',' // 'exit=' exit_sharding '}' -bool HloParser::ParseDomain(DomainData* domain) { - std::unordered_map attrs; - optional kind; +bool HloParserImpl::ParseDomain(DomainData* domain) { + std::unordered_map attrs; + optional kind; optional entry_sharding; optional exit_sharding; attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind}; @@ -2082,7 +2093,7 @@ bool HloParser::ParseDomain(DomainData* domain) { } // '{' name+ '}' -bool HloParser::ParseInstructionNames( +bool HloParserImpl::ParseInstructionNames( std::vector* instructions) { if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of instruction name list")) { @@ -2090,7 +2101,7 @@ bool HloParser::ParseInstructionNames( } LocTy loc = lexer_.GetLoc(); do { - string name; + std::string name; if (!ParseName(&name)) { return Error(loc, "expects a instruction name"); } @@ -2105,8 +2116,9 @@ bool HloParser::ParseInstructionNames( "expects '}' at the end of instruction name list"); } -bool HloParser::SetValueInLiteral(LocTy loc, int64 value, - LinearOrMultiIndex index, Literal* literal) { +bool HloParserImpl::SetValueInLiteral(LocTy loc, int64 value, + LinearOrMultiIndex index, + Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case S8: @@ -2139,8 +2151,9 @@ bool HloParser::SetValueInLiteral(LocTy loc, int64 value, } } -bool HloParser::SetValueInLiteral(LocTy loc, double value, - LinearOrMultiIndex index, Literal* literal) { +bool HloParserImpl::SetValueInLiteral(LocTy loc, double value, + LinearOrMultiIndex index, + Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case F16: @@ -2158,8 +2171,9 @@ bool HloParser::SetValueInLiteral(LocTy loc, double value, } } -bool HloParser::SetValueInLiteral(LocTy loc, bool value, - LinearOrMultiIndex index, Literal* literal) { +bool HloParserImpl::SetValueInLiteral(LocTy loc, bool value, + LinearOrMultiIndex index, + Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case PRED: @@ -2170,8 +2184,9 @@ bool HloParser::SetValueInLiteral(LocTy loc, bool value, } } -bool HloParser::SetValueInLiteral(LocTy loc, std::complex value, - LinearOrMultiIndex index, Literal* literal) { +bool HloParserImpl::SetValueInLiteral(LocTy loc, std::complex value, + LinearOrMultiIndex index, + Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case C64: @@ -2187,18 +2202,18 @@ bool HloParser::SetValueInLiteral(LocTy loc, std::complex value, } template -string StringifyValue(T val) { +std::string StringifyValue(T val) { return StrCat(val); } template <> -string StringifyValue(std::complex val) { +std::string StringifyValue(std::complex val) { return StrFormat("(%f, %f)", std::real(val), std::imag(val)); } template -bool HloParser::SetValueInLiteralHelper(LocTy loc, ParsedElemT value, - LinearOrMultiIndex index, - Literal* literal) { +bool HloParserImpl::SetValueInLiteralHelper(LocTy loc, ParsedElemT value, + LinearOrMultiIndex index, + Literal* literal) { if (!CheckParsedValueIsInRange(loc, value)) { return false; } @@ -2218,7 +2233,7 @@ bool HloParser::SetValueInLiteralHelper(LocTy loc, ParsedElemT value, auto* multi_index = absl::get_if>(&index); CHECK(multi_index != nullptr); - auto invalid_idx = [&](string msg) { + auto invalid_idx = [&](std::string msg) { return Error(loc, StrFormat("Invalid sparse index [%s]. %s", absl::StrJoin(*multi_index, ", "), msg)); }; @@ -2251,7 +2266,7 @@ bool HloParser::SetValueInLiteralHelper(LocTy loc, ParsedElemT value, // literal // ::= tuple // ::= non_tuple -bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) { +bool HloParserImpl::ParseLiteral(Literal* literal, const Shape& shape) { return shape.IsTuple() ? ParseTupleLiteral(literal, shape) : ParseNonTupleLiteral(literal, shape); } @@ -2261,7 +2276,7 @@ bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) { // literal_list // ::= /*empty*/ // ::= literal (',' literal)* -bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) { +bool HloParserImpl::ParseTupleLiteral(Literal* literal, const Shape& shape) { if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) { return false; } @@ -2291,7 +2306,7 @@ bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) { // ::= rank01 // ::= rank2345 // rank2345 ::= shape sparse_or_nested_array -bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { +bool HloParserImpl::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { if (LayoutUtil::IsSparseArray(shape)) { return ParseSparseLiteral(literal, shape); } @@ -2300,7 +2315,7 @@ bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) { return ParseDenseLiteral(literal, shape); } -bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { +bool HloParserImpl::ParseDenseLiteral(Literal* literal, const Shape& shape) { // Cast `rank` to int because we call shape.dimensions(int rank) below, and if // `rank` is an int64, that's an implicit narrowing conversion, which is // implementation-defined behavior. @@ -2319,12 +2334,12 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { // sub-array is supposed to contain exactly 3 elements, so check if // elems_seen_per_dim[1] is 3. std::vector elems_seen_per_dim(rank); - auto get_index_str = [&elems_seen_per_dim](int dim) -> string { + auto get_index_str = [&elems_seen_per_dim](int dim) -> std::string { std::vector elems_seen_until_dim(elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); return StrCat("[", StrJoin(elems_seen_until_dim, ",", - [](string* out, const int64& num_elems) { + [](std::string* out, const int64 num_elems) { StrAppend(out, num_elems - 1); }), "]"); @@ -2476,7 +2491,7 @@ bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) { return true; } -bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) { +bool HloParserImpl::ParseSparseLiteral(Literal* literal, const Shape& shape) { *literal = Literal(shape); if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of a sparse literal")) { @@ -2569,7 +2584,7 @@ bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) { } // MaxFiniteValue is a type-traits helper used by -// HloParser::CheckParsedValueIsInRange. +// HloParserImpl::CheckParsedValueIsInRange. template struct MinMaxFiniteValue { static T max() { return std::numeric_limits::max(); } @@ -2592,7 +2607,7 @@ struct MinMaxFiniteValue { }; template -bool HloParser::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) { +bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) { PrimitiveType literal_ty = primitive_util::NativeToPrimitiveType(); if (std::isnan(value) || @@ -2632,8 +2647,8 @@ bool HloParser::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) { } template -bool HloParser::CheckParsedValueIsInRange(LocTy loc, - std::complex value) { +bool HloParserImpl::CheckParsedValueIsInRange(LocTy loc, + std::complex value) { // e.g. `float` for std::complex using LiteralComplexComponentT = decltype(std::real(std::declval())); @@ -2675,7 +2690,7 @@ bool HloParser::CheckParsedValueIsInRange(LocTy loc, // ::= /*empty*/ // ::= operand (, operand)* // operand ::= (shape)? name -bool HloParser::ParseOperands(std::vector* operands) { +bool HloParserImpl::ParseOperands(std::vector* operands) { CHECK(operands != nullptr); if (!ParseToken(TokKind::kLparen, "expects '(' at the beginning of operands")) { @@ -2686,7 +2701,7 @@ bool HloParser::ParseOperands(std::vector* operands) { } else { do { LocTy loc = lexer_.GetLoc(); - string name; + std::string name; optional shape; if (CanBeShape()) { shape.emplace(); @@ -2721,8 +2736,8 @@ bool HloParser::ParseOperands(std::vector* operands) { return ParseToken(TokKind::kRparen, "expects ')' at the end of operands"); } -bool HloParser::ParseOperands(std::vector* operands, - const int expected_size) { +bool HloParserImpl::ParseOperands(std::vector* operands, + const int expected_size) { CHECK(operands != nullptr); LocTy loc = lexer_.GetLoc(); if (!ParseOperands(operands)) { @@ -2736,13 +2751,13 @@ bool HloParser::ParseOperands(std::vector* operands, } // sub_attributes ::= '{' (','? attribute)* '}' -bool HloParser::ParseSubAttributes( - const std::unordered_map& attrs) { +bool HloParserImpl::ParseSubAttributes( + const std::unordered_map& attrs) { LocTy loc = lexer_.GetLoc(); if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) { return false; } - std::unordered_set seen_attrs; + std::unordered_set seen_attrs; if (lexer_.GetKind() == TokKind::kRbrace) { // empty } else { @@ -2765,10 +2780,10 @@ bool HloParser::ParseSubAttributes( } // attributes ::= (',' attribute)* -bool HloParser::ParseAttributes( - const std::unordered_map& attrs) { +bool HloParserImpl::ParseAttributes( + const std::unordered_map& attrs) { LocTy loc = lexer_.GetLoc(); - std::unordered_set seen_attrs; + std::unordered_set seen_attrs; while (EatIfPresent(TokKind::kComma)) { if (!ParseAttributeHelper(attrs, &seen_attrs)) { return false; @@ -2785,11 +2800,11 @@ bool HloParser::ParseAttributes( return true; } -bool HloParser::ParseAttributeHelper( - const std::unordered_map& attrs, - std::unordered_set* seen_attrs) { +bool HloParserImpl::ParseAttributeHelper( + const std::unordered_map& attrs, + std::unordered_set* seen_attrs) { LocTy loc = lexer_.GetLoc(); - string name; + std::string name; if (!ParseAttributeName(&name)) { return Error(loc, "error parsing attributes"); } @@ -2799,16 +2814,17 @@ bool HloParser::ParseAttributeHelper( } auto attr_it = attrs.find(name); if (attr_it == attrs.end()) { - string allowed_attrs; + std::string allowed_attrs; if (attrs.empty()) { allowed_attrs = "No attributes are allowed here."; } else { - allowed_attrs = StrCat( - "Allowed attributes: ", - StrJoin(attrs, ", ", - [&](string* out, const std::pair& kv) { - StrAppend(out, kv.first); - })); + allowed_attrs = + StrCat("Allowed attributes: ", + StrJoin(attrs, ", ", + [&](std::string* out, + const std::pair& kv) { + StrAppend(out, kv.first); + })); } return Error(loc, StrFormat("unexpected attribute \"%s\". %s", name, allowed_attrs)); @@ -2991,11 +3007,11 @@ bool HloParser::ParseAttributeHelper( return true; } case AttrTy::kString: { - string result; + std::string result; if (!ParseString(&result)) { return false; } - static_cast*>(attr_out_ptr)->emplace(result); + static_cast*>(attr_out_ptr)->emplace(result); return true; } case AttrTy::kMetadata: { @@ -3046,18 +3062,18 @@ bool HloParser::ParseAttributeHelper( } // attributes ::= (',' attribute)* -bool HloParser::ParseAttributesAsProtoMessage( - const std::unordered_set& required_attrs, +bool HloParserImpl::ParseAttributesAsProtoMessage( + const std::unordered_set& required_attrs, tensorflow::protobuf::Message* message) { LocTy loc = lexer_.GetLoc(); - std::unordered_set seen_attrs; + std::unordered_set seen_attrs; while (EatIfPresent(TokKind::kComma)) { if (!ParseAttributeAsProtoMessageHelper(message, &seen_attrs)) { return false; } } // Check that all required attrs were seen. - for (const string& attr : required_attrs) { + for (const std::string& attr : required_attrs) { if (seen_attrs.find(attr) == seen_attrs.end()) { return Error(loc, StrFormat("attribute %s is expected but not seen", attr)); @@ -3066,11 +3082,11 @@ bool HloParser::ParseAttributesAsProtoMessage( return true; } -bool HloParser::ParseAttributeAsProtoMessageHelper( +bool HloParserImpl::ParseAttributeAsProtoMessageHelper( tensorflow::protobuf::Message* message, - std::unordered_set* seen_attrs) { + std::unordered_set* seen_attrs) { LocTy loc = lexer_.GetLoc(); - string name; + std::string name; if (!ParseAttributeName(&name)) { return Error(loc, "error parsing attributes"); } @@ -3082,7 +3098,7 @@ bool HloParser::ParseAttributeAsProtoMessageHelper( const tensorflow::protobuf::FieldDescriptor* fd = descriptor->FindFieldByName(name); if (!fd) { - string allowed_attrs = "Allowed attributes: "; + std::string allowed_attrs = "Allowed attributes: "; for (int i = 0; i < descriptor->field_count(); ++i) { if (i == 0) { @@ -3111,7 +3127,7 @@ bool HloParser::ParseAttributeAsProtoMessageHelper( return TokenError( StrFormat("expects %s type", fd->enum_type()->name())); } - string val = lexer_.GetStrVal(); + std::string val = lexer_.GetStrVal(); const tensorflow::protobuf::EnumValueDescriptor* evd = fd->enum_type()->FindValueByName(val); if (evd == nullptr) { @@ -3134,8 +3150,8 @@ bool HloParser::ParseAttributeAsProtoMessageHelper( return true; } -bool HloParser::ParseComputationName(HloComputation** value) { - string name; +bool HloParserImpl::ParseComputationName(HloComputation** value) { + std::string name; LocTy loc = lexer_.GetLoc(); if (!ParseName(&name)) { return Error(loc, "expects computation name"); @@ -3152,7 +3168,7 @@ bool HloParser::ParseComputationName(HloComputation** value) { // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}' // The subattributes can appear in any order. 'size=' is required, others are // optional. -bool HloParser::ParseWindow(Window* window, bool expect_outer_curlies) { +bool HloParserImpl::ParseWindow(Window* window, bool expect_outer_curlies) { LocTy loc = lexer_.GetLoc(); if (expect_outer_curlies && !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { @@ -3169,7 +3185,7 @@ bool HloParser::ParseWindow(Window* window, bool expect_outer_curlies) { expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof; while (lexer_.GetKind() != end_token) { LocTy attr_loc = lexer_.GetLoc(); - string field_name; + std::string field_name; if (!ParseAttributeName(&field_name)) { return Error(attr_loc, "expects sub-attributes in window"); } @@ -3236,22 +3252,22 @@ bool HloParser::ParseWindow(Window* window, bool expect_outer_curlies) { } // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString. -// The string looks like "dim_labels=0bf_0io->0bf". -bool HloParser::ParseConvolutionDimensionNumbers( +// Thestring looks like "dim_labels=0bf_0io->0bf". +bool HloParserImpl::ParseConvolutionDimensionNumbers( ConvolutionDimensionNumbers* dnums) { if (lexer_.GetKind() != TokKind::kDimLabels) { return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'"); } - string str = lexer_.GetStrVal(); + std::string str = lexer_.GetStrVal(); // The str is expected to have 3 items, lhs, rhs, out, and it must look like // lhs_rhs->out, that is, the first separator is "_" and the second is "->". - std::vector split1 = absl::StrSplit(str, '_'); + std::vector split1 = absl::StrSplit(str, '_'); if (split1.size() != 2) { LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " << str; } - std::vector split2 = absl::StrSplit(split1[1], "->"); + std::vector split2 = absl::StrSplit(split1[1], "->"); if (split2.size() != 2) { LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " << str; @@ -3269,14 +3285,14 @@ bool HloParser::ParseConvolutionDimensionNumbers( return TokenError("convolution rank must >=2"); } - auto is_unique = [](string str) -> bool { + auto is_unique = [](std::string str) -> bool { absl::c_sort(str); return std::unique(str.begin(), str.end()) == str.end(); }; // lhs { - if (!is_unique(string(lhs))) { + if (!is_unique(std::string(lhs))) { return TokenError( StrCat("expects unique lhs dimension numbers, but sees ", lhs)); } @@ -3299,7 +3315,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( } // rhs { - if (!is_unique(string(rhs))) { + if (!is_unique(std::string(rhs))) { return TokenError( StrCat("expects unique rhs dimension numbers, but sees ", rhs)); } @@ -3322,7 +3338,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( } // output { - if (!is_unique(string(out))) { + if (!is_unique(std::string(out))) { return TokenError( StrCat("expects unique output dimension numbers, but sees ", out)); } @@ -3367,7 +3383,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( // // {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}} // -bool HloParser::ParseSliceRanges(SliceRanges* result) { +bool HloParserImpl::ParseSliceRanges(SliceRanges* result) { if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) { return false; } @@ -3404,7 +3420,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { // precision_elements // ::= /*empty*/ // ::= precision_val (delim precision_val)* -bool HloParser::ParsePrecisionList( +bool HloParserImpl::ParsePrecisionList( std::vector* result) { auto parse_and_add_item = [&]() { PrecisionConfig::Precision item; @@ -3418,7 +3434,7 @@ bool HloParser::ParsePrecisionList( parse_and_add_item); } -bool HloParser::ParseHloComputation(HloComputation** result) { +bool HloParserImpl::ParseHloComputation(HloComputation** result) { if (lexer_.GetKind() == TokKind::kLbrace) { // This means it is a nested computation. return ParseInstructionList(result, /*computation_name=*/"_"); @@ -3427,7 +3443,8 @@ bool HloParser::ParseHloComputation(HloComputation** result) { return ParseComputationName(result); } -bool HloParser::ParseHloComputationList(std::vector* result) { +bool HloParserImpl::ParseHloComputationList( + std::vector* result) { auto parse_and_add_item = [&]() { HloComputation* computation; if (!ParseHloComputation(&computation)) { @@ -3445,7 +3462,7 @@ bool HloParser::ParseHloComputationList(std::vector* result) { // precision_elements // ::= /*empty*/ // ::= shape (',' shape)* -bool HloParser::ParseShapeList(std::vector* result) { +bool HloParserImpl::ParseShapeList(std::vector* result) { auto parse_and_add_item = [&]() { Shape shape; if (!ParseShape(&shape)) { @@ -3462,9 +3479,9 @@ bool HloParser::ParseShapeList(std::vector* result) { // int64_elements // ::= /*empty*/ // ::= int64_val (delim int64_val)* -bool HloParser::ParseInt64List(const TokKind start, const TokKind end, - const TokKind delim, - std::vector* result) { +bool HloParserImpl::ParseInt64List(const TokKind start, const TokKind end, + const TokKind delim, + std::vector* result) { auto parse_and_add_item = [&]() { int64 i; if (!ParseInt64(&i)) { @@ -3484,9 +3501,9 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end, // int64_elements // ::= /*empty*/ // ::= int64_val (delim int64_val)* -bool HloParser::ParseInt64ListList(const TokKind start, const TokKind end, - const TokKind delim, - std::vector>* result) { +bool HloParserImpl::ParseInt64ListList( + const TokKind start, const TokKind end, const TokKind delim, + std::vector>* result) { auto parse_and_add_item = [&]() { std::vector item; if (!ParseInt64List(start, end, delim, &item)) { @@ -3498,9 +3515,9 @@ bool HloParser::ParseInt64ListList(const TokKind start, const TokKind end, return ParseList(start, end, delim, parse_and_add_item); } -bool HloParser::ParseList(const TokKind start, const TokKind end, - const TokKind delim, - const std::function& parse_and_add_item) { +bool HloParserImpl::ParseList(const TokKind start, const TokKind end, + const TokKind delim, + const std::function& parse_and_add_item) { if (!ParseToken(start, StrCat("expects a list starting with ", TokKindToString(start)))) { return false; @@ -3519,7 +3536,7 @@ bool HloParser::ParseList(const TokKind start, const TokKind end, } // param_list_to_shape ::= param_list '->' shape -bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) { +bool HloParserImpl::ParseParamListToShape(Shape* shape, LocTy* shape_loc) { if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) { return false; } @@ -3527,7 +3544,7 @@ bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) { return ParseShape(shape); } -bool HloParser::CanBeParamListToShape() { +bool HloParserImpl::CanBeParamListToShape() { return lexer_.GetKind() == TokKind::kLparen; } @@ -3536,7 +3553,7 @@ bool HloParser::CanBeParamListToShape() { // ::= /*empty*/ // ::= param (',' param)* // param ::= name shape -bool HloParser::ParseParamList() { +bool HloParserImpl::ParseParamList() { if (!ParseToken(TokKind::kLparen, "expects '(' at the beginning of param list")) { return false; @@ -3547,7 +3564,7 @@ bool HloParser::ParseParamList() { } else { do { Shape shape; - string name; + std::string name; if (!ParseName(&name) || !ParseShape(&shape)) { return false; } @@ -3561,8 +3578,8 @@ bool HloParser::ParseParamList() { // ::= /*empty*/ // ::= <=? int64 (',' param)* // param ::= name shape -bool HloParser::ParseDimensionSizes(std::vector* dimension_sizes, - std::vector* dynamic_dimensions) { +bool HloParserImpl::ParseDimensionSizes(std::vector* dimension_sizes, + std::vector* dynamic_dimensions) { auto parse_and_add_item = [&]() { int64 i; bool is_dynamic = false; @@ -3587,7 +3604,7 @@ bool HloParser::ParseDimensionSizes(std::vector* dimension_sizes, // dim_list // ::= /*empty*/ // ::= (int64 | '*') (',' (int64 | '*'))* -bool HloParser::ParseTiles(std::vector* tiles) { +bool HloParserImpl::ParseTiles(std::vector* tiles) { auto parse_and_add_tile_dimension = [&]() { tensorflow::int64 i; if (ParseInt64(&i)) { @@ -3619,8 +3636,8 @@ bool HloParser::ParseTiles(std::vector* tiles) { // ::= 'E' | 'S' // attr_value // ::= int64 -bool HloParser::ParseLayoutIntAttribute(int64* attr_value, - absl::string_view attr_description) { +bool HloParserImpl::ParseLayoutIntAttribute( + int64* attr_value, absl::string_view attr_description) { if (!ParseToken(TokKind::kLparen, StrCat("expects ", attr_description, " to start with ", TokKindToString(TokKind::kLparen)))) { @@ -3644,7 +3661,7 @@ bool HloParser::ParseLayoutIntAttribute(int64* attr_value, // memory_space // ::= /*empty*/ // ::= 'S' '(' int64 ')' -bool HloParser::ParseLayout(Layout* layout) { +bool HloParserImpl::ParseLayout(Layout* layout) { std::vector minor_to_major; std::vector tiles; tensorflow::int64 element_size_in_bits = 0; @@ -3712,7 +3729,7 @@ bool HloParser::ParseLayout(Layout* layout) { // tuple_elements // ::= /*empty*/ // ::= shape (',' shape)* -bool HloParser::ParseShape(Shape* result) { +bool HloParserImpl::ParseShape(Shape* result) { if (EatIfPresent(TokKind::kLparen)) { // Tuple std::vector shapes; if (lexer_.GetKind() == TokKind::kRparen) { @@ -3753,7 +3770,7 @@ bool HloParser::ParseShape(Shape* result) { if (lexer_.GetKind() == TokKind::kw_sparse) { lexer_.Lex(); - const string message = + const std::string message = "expects a brace-bracketed integer for sparse layout"; int64 max_sparse_elements; if (!ParseToken(TokKind::kLbrace, message) || @@ -3794,14 +3811,14 @@ bool HloParser::ParseShape(Shape* result) { return true; } -bool HloParser::CanBeShape() { +bool HloParserImpl::CanBeShape() { // A non-tuple shape starts with a kPrimitiveType token; a tuple shape starts // with '('. return lexer_.GetKind() == TokKind::kPrimitiveType || lexer_.GetKind() == TokKind::kLparen; } -bool HloParser::ParseName(string* result) { +bool HloParserImpl::ParseName(std::string* result) { VLOG(3) << "ParseName"; if (lexer_.GetKind() != TokKind::kIdent && lexer_.GetKind() != TokKind::kName) { @@ -3812,7 +3829,7 @@ bool HloParser::ParseName(string* result) { return true; } -bool HloParser::ParseAttributeName(string* result) { +bool HloParserImpl::ParseAttributeName(std::string* result) { if (lexer_.GetKind() != TokKind::kAttributeName) { return TokenError("expects attribute name"); } @@ -3821,7 +3838,7 @@ bool HloParser::ParseAttributeName(string* result) { return true; } -bool HloParser::ParseString(string* result) { +bool HloParserImpl::ParseString(std::string* result) { VLOG(3) << "ParseString"; if (lexer_.GetKind() != TokKind::kString) { return TokenError("expects string"); @@ -3831,7 +3848,8 @@ bool HloParser::ParseString(string* result) { return true; } -bool HloParser::ParseDxD(const string& name, std::vector* result) { +bool HloParserImpl::ParseDxD(const std::string& name, + std::vector* result) { LocTy loc = lexer_.GetLoc(); if (!result->empty()) { return Error(loc, StrFormat("sub-attribute '%s=' already exists", name)); @@ -3847,7 +3865,7 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { } // 2D or higher. if (lexer_.GetKind() == TokKind::kDxD) { - string str = lexer_.GetStrVal(); + std::string str = lexer_.GetStrVal(); if (!SplitToInt64s(str, 'x', result)) { return Error(loc, StrFormat("expects sub-attribute '%s=ixj...'", name)); } @@ -3857,7 +3875,7 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { return TokenError("expects token type kInt or kDxD"); } -bool HloParser::ParseWindowPad(std::vector>* pad) { +bool HloParserImpl::ParseWindowPad(std::vector>* pad) { LocTy loc = lexer_.GetLoc(); if (!pad->empty()) { return Error(loc, "sub-attribute 'pad=' already exists"); @@ -3865,7 +3883,7 @@ bool HloParser::ParseWindowPad(std::vector>* pad) { if (lexer_.GetKind() != TokKind::kPad) { return TokenError("expects window pad pattern, e.g., '0_0x3_3'"); } - string str = lexer_.GetStrVal(); + std::string str = lexer_.GetStrVal(); for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) { std::vector low_high; if (!SplitToInt64s(padding_dim_str, '_', &low_high) || @@ -3883,12 +3901,12 @@ bool HloParser::ParseWindowPad(std::vector>* pad) { // looks like "0_0_0x3_3_1". The string is first separated by 'x', each // substring represents one PaddingConfigDimension. The substring is 3 (or 2) // numbers joined by '_'. -bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { +bool HloParserImpl::ParsePaddingConfig(PaddingConfig* padding) { if (lexer_.GetKind() != TokKind::kPad) { return TokenError("expects padding config, e.g., '0_0_0x3_3_1'"); } LocTy loc = lexer_.GetLoc(); - string str = lexer_.GetStrVal(); + std::string str = lexer_.GetStrVal(); for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) { std::vector padding_dim; if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) || @@ -3907,11 +3925,11 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { } // '{' metadata_string '}' -bool HloParser::ParseMetadata(OpMetadata* metadata) { - std::unordered_map attrs; - optional op_type; - optional op_name; - optional source_file; +bool HloParserImpl::ParseMetadata(OpMetadata* metadata) { + std::unordered_map attrs; + optional op_type; + optional op_name; + optional source_file; optional source_line; attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type}; attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name}; @@ -3935,12 +3953,12 @@ bool HloParser::ParseMetadata(OpMetadata* metadata) { return true; } -bool HloParser::ParseOpcode(HloOpcode* result) { +bool HloParserImpl::ParseOpcode(HloOpcode* result) { VLOG(3) << "ParseOpcode"; if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects opcode"); } - string val = lexer_.GetStrVal(); + std::string val = lexer_.GetStrVal(); auto status_or_result = StringToHloOpcode(val); if (!status_or_result.ok()) { return TokenError(StrFormat("expects opcode but sees: %s, error: %s", val, @@ -3951,12 +3969,12 @@ bool HloParser::ParseOpcode(HloOpcode* result) { return true; } -bool HloParser::ParseFftType(FftType* result) { +bool HloParserImpl::ParseFftType(FftType* result) { VLOG(3) << "ParseFftType"; if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects fft type"); } - string val = lexer_.GetStrVal(); + std::string val = lexer_.GetStrVal(); if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) { return TokenError(StrFormat("expects fft type but sees: %s", val)); } @@ -3964,12 +3982,12 @@ bool HloParser::ParseFftType(FftType* result) { return true; } -bool HloParser::ParseComparisonDirection(ComparisonDirection* result) { +bool HloParserImpl::ParseComparisonDirection(ComparisonDirection* result) { VLOG(1) << "ParseComparisonDirection"; if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects comparison direction"); } - string val = lexer_.GetStrVal(); + std::string val = lexer_.GetStrVal(); auto status_or_result = StringToComparisonDirection(val); if (!status_or_result.ok()) { return TokenError( @@ -3980,12 +3998,12 @@ bool HloParser::ParseComparisonDirection(ComparisonDirection* result) { return true; } -bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { +bool HloParserImpl::ParseFusionKind(HloInstruction::FusionKind* result) { VLOG(3) << "ParseFusionKind"; if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects fusion kind"); } - string val = lexer_.GetStrVal(); + std::string val = lexer_.GetStrVal(); auto status_or_result = StringToFusionKind(val); if (!status_or_result.ok()) { return TokenError(StrFormat("expects fusion kind but sees: %s, error: %s", @@ -3997,12 +4015,12 @@ bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { return true; } -bool HloParser::ParseRandomDistribution(RandomDistribution* result) { +bool HloParserImpl::ParseRandomDistribution(RandomDistribution* result) { VLOG(3) << "ParseRandomDistribution"; if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects random distribution"); } - string val = lexer_.GetStrVal(); + std::string val = lexer_.GetStrVal(); auto status_or_result = StringToRandomDistribution(val); if (!status_or_result.ok()) { return TokenError( @@ -4014,12 +4032,12 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) { return true; } -bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) { +bool HloParserImpl::ParsePrecision(PrecisionConfig::Precision* result) { VLOG(3) << "ParsePrecision"; if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects random distribution"); } - string val = lexer_.GetStrVal(); + std::string val = lexer_.GetStrVal(); auto status_or_result = StringToPrecision(val); if (!status_or_result.ok()) { return TokenError(StrFormat("expects precision but sees: %s, error: %s", @@ -4031,7 +4049,7 @@ bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) { return true; } -bool HloParser::ParseInt64(int64* result) { +bool HloParserImpl::ParseInt64(int64* result) { VLOG(3) << "ParseInt64"; if (lexer_.GetKind() != TokKind::kInt) { return TokenError("expects integer"); @@ -4041,7 +4059,7 @@ bool HloParser::ParseInt64(int64* result) { return true; } -bool HloParser::ParseDouble(double* result) { +bool HloParserImpl::ParseDouble(double* result) { switch (lexer_.GetKind()) { case TokKind::kDecimal: { double val = lexer_.GetDecimalVal(); @@ -4074,7 +4092,7 @@ bool HloParser::ParseDouble(double* result) { return true; } -bool HloParser::ParseComplex(std::complex* result) { +bool HloParserImpl::ParseComplex(std::complex* result) { if (lexer_.GetKind() != TokKind::kLparen) { return TokenError("expects '(' before complex number"); } @@ -4110,7 +4128,7 @@ bool HloParser::ParseComplex(std::complex* result) { return true; } -bool HloParser::ParseBool(bool* result) { +bool HloParserImpl::ParseBool(bool* result) { if (lexer_.GetKind() != TokKind::kw_true && lexer_.GetKind() != TokKind::kw_false) { return TokenError("expects true or false"); @@ -4120,7 +4138,7 @@ bool HloParser::ParseBool(bool* result) { return true; } -bool HloParser::ParseToken(TokKind kind, const string& msg) { +bool HloParserImpl::ParseToken(TokKind kind, const std::string& msg) { VLOG(3) << "ParseToken " << TokKindToString(kind) << " " << msg; if (lexer_.GetKind() != kind) { return TokenError(msg); @@ -4129,7 +4147,7 @@ bool HloParser::ParseToken(TokKind kind, const string& msg) { return true; } -bool HloParser::EatIfPresent(TokKind kind) { +bool HloParserImpl::EatIfPresent(TokKind kind) { if (lexer_.GetKind() != kind) { return false; } @@ -4137,8 +4155,9 @@ bool HloParser::EatIfPresent(TokKind kind) { return true; } -bool HloParser::AddInstruction(const string& name, HloInstruction* instruction, - LocTy name_loc) { +bool HloParserImpl::AddInstruction(const std::string& name, + HloInstruction* instruction, + LocTy name_loc) { auto result = current_name_table().insert({name, {instruction, name_loc}}); if (!result.second) { Error(name_loc, StrCat("instruction already exists: ", name)); @@ -4148,8 +4167,9 @@ bool HloParser::AddInstruction(const string& name, HloInstruction* instruction, return true; } -bool HloParser::AddComputation(const string& name, HloComputation* computation, - LocTy name_loc) { +bool HloParserImpl::AddComputation(const std::string& name, + HloComputation* computation, + LocTy name_loc) { auto result = computation_pool_.insert({name, {computation, name_loc}}); if (!result.second) { Error(name_loc, StrCat("computation already exists: ", name)); @@ -4159,7 +4179,7 @@ bool HloParser::AddComputation(const string& name, HloComputation* computation, return true; } -StatusOr HloParser::ParseShapeOnly() { +StatusOr HloParserImpl::ParseShapeOnly() { lexer_.Lex(); Shape shape; if (!ParseShape(&shape)) { @@ -4171,7 +4191,7 @@ StatusOr HloParser::ParseShapeOnly() { return shape; } -StatusOr HloParser::ParseShardingOnly() { +StatusOr HloParserImpl::ParseShardingOnly() { lexer_.Lex(); OpSharding op_sharding; if (!ParseSharding(&op_sharding)) { @@ -4183,7 +4203,7 @@ StatusOr HloParser::ParseShardingOnly() { return HloSharding::FromProto(op_sharding); } -StatusOr HloParser::ParseFrontendAttributesOnly() { +StatusOr HloParserImpl::ParseFrontendAttributesOnly() { lexer_.Lex(); FrontendAttributes attributes; if (!ParseFrontendAttributes(&attributes)) { @@ -4196,7 +4216,7 @@ StatusOr HloParser::ParseFrontendAttributesOnly() { return attributes; } -StatusOr> HloParser::ParseParameterReplicationOnly() { +StatusOr> HloParserImpl::ParseParameterReplicationOnly() { lexer_.Lex(); ParameterReplication parameter_replication; if (!ParseParameterReplication(¶meter_replication)) { @@ -4211,7 +4231,7 @@ StatusOr> HloParser::ParseParameterReplicationOnly() { parameter_replication.replicated_at_leaf_buffers().end()); } -StatusOr> HloParser::ParseReplicaGroupsOnly() { +StatusOr> HloParserImpl::ParseReplicaGroupsOnly() { lexer_.Lex(); std::vector replica_groups; if (!ParseReplicaGroupsOnly(&replica_groups)) { @@ -4223,7 +4243,7 @@ StatusOr> HloParser::ParseReplicaGroupsOnly() { return replica_groups; } -StatusOr HloParser::ParseWindowOnly() { +StatusOr HloParserImpl::ParseWindowOnly() { lexer_.Lex(); Window window; if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) { @@ -4236,7 +4256,7 @@ StatusOr HloParser::ParseWindowOnly() { } StatusOr -HloParser::ParseConvolutionDimensionNumbersOnly() { +HloParserImpl::ParseConvolutionDimensionNumbersOnly() { lexer_.Lex(); ConvolutionDimensionNumbers dnums; if (!ParseConvolutionDimensionNumbers(&dnums)) { @@ -4249,7 +4269,7 @@ HloParser::ParseConvolutionDimensionNumbersOnly() { return dnums; } -StatusOr HloParser::ParsePaddingConfigOnly() { +StatusOr HloParserImpl::ParsePaddingConfigOnly() { lexer_.Lex(); PaddingConfig padding_config; if (!ParsePaddingConfig(&padding_config)) { @@ -4261,7 +4281,7 @@ StatusOr HloParser::ParsePaddingConfigOnly() { return padding_config; } -bool HloParser::ParseSingleInstruction(HloModule* module) { +bool HloParserImpl::ParseSingleInstruction(HloModule* module) { if (create_missing_instruction_ != nullptr || !scoped_name_tables_.empty()) { LOG(FATAL) << "Parser state is not clean. Please do not call any other " "methods before calling ParseSingleInstruction."; @@ -4273,9 +4293,9 @@ bool HloParser::ParseSingleInstruction(HloModule* module) { int64 parameter_count = 0; create_missing_instruction_ = [this, &builder, ¶meter_count]( - const string& name, + const std::string& name, const Shape& shape) -> std::pair* { - string new_name = name.empty() ? StrCat("_", parameter_count) : name; + std::string new_name = name.empty() ? StrCat("_", parameter_count) : name; HloInstruction* parameter = builder.AddInstruction( HloInstruction::CreateParameter(parameter_count++, shape, new_name)); current_name_table()[new_name] = {parameter, lexer_.GetLoc()}; @@ -4296,7 +4316,7 @@ bool HloParser::ParseSingleInstruction(HloModule* module) { // This means that the instruction's left-hand side might exist, e.g. // // foo = f32[10] fusion(...), calls={...} - string root_name; + std::string root_name; if (!ParseInstruction(&builder, &root_name)) { return false; } @@ -4322,7 +4342,7 @@ bool HloParser::ParseSingleInstruction(HloModule* module) { StatusOr> ParseAndReturnUnverifiedModule( absl::string_view str, const HloModuleConfig& config) { auto module = absl::make_unique(/*name=*/"_", config); - HloParser parser(str); + HloParserImpl parser(str); TF_RETURN_IF_ERROR(parser.Run(module.get())); return std::move(module); } @@ -4332,53 +4352,51 @@ StatusOr> ParseAndReturnUnverifiedModule( return ParseAndReturnUnverifiedModule(str, HloModuleConfig()); } -Status ParseHloString(absl::string_view str, HloModule* module) { - TF_RET_CHECK(module->computation_count() == 0); - HloParser parser(str); - TF_RETURN_IF_ERROR(parser.Run(module)); - return Status::OK(); -} - StatusOr ParseSharding(absl::string_view str) { - HloParser parser(str); + HloParserImpl parser(str); return parser.ParseShardingOnly(); } StatusOr ParseFrontendAttributes(absl::string_view str) { - HloParser parser(str); + HloParserImpl parser(str); return parser.ParseFrontendAttributesOnly(); } StatusOr> ParseParameterReplication(absl::string_view str) { - HloParser parser(str); + HloParserImpl parser(str); return parser.ParseParameterReplicationOnly(); } StatusOr> ParseReplicaGroupsOnly( absl::string_view str) { - HloParser parser(str); + HloParserImpl parser(str); return parser.ParseReplicaGroupsOnly(); } StatusOr ParseWindow(absl::string_view str) { - HloParser parser(str); + HloParserImpl parser(str); return parser.ParseWindowOnly(); } StatusOr ParseConvolutionDimensionNumbers( absl::string_view str) { - HloParser parser(str); + HloParserImpl parser(str); return parser.ParseConvolutionDimensionNumbersOnly(); } StatusOr ParsePaddingConfig(absl::string_view str) { - HloParser parser(str); + HloParserImpl parser(str); return parser.ParsePaddingConfigOnly(); } StatusOr ParseShape(absl::string_view str) { - HloParser parser(str); + HloParserImpl parser(str); return parser.ParseShapeOnly(); } +std::unique_ptr HloParser::CreateHloParserForTests( + absl::string_view str) { + return absl::make_unique(str); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 91ce79ec982..5cc865faab3 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ +#include +#include + #include "absl/memory/memory.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -44,11 +47,6 @@ StatusOr> ParseAndReturnUnverifiedModule( StatusOr> ParseAndReturnUnverifiedModule( absl::string_view str); -// Given a string in the HloModule::ToString() format, parses the string and -// builds the HloModule in place at the given module pointer. 'module' must -// point to an empty module (no computations). -Status ParseHloString(absl::string_view str, HloModule* module); - // Parses sharding from str. str is supposed to contain the body of the // sharding, i.e. just the rhs of the "sharding={...}" attribute string, e.g., // "{replicated}". @@ -85,6 +83,19 @@ StatusOr ParseShape(absl::string_view str); StatusOr> ParseReplicaGroupsOnly( absl::string_view str); +class HloParser { + public: + // Runs the parser and constructs the resulting HLO in the given (empty) + // HloModule. Returns the error status in case an error occurred. + virtual Status Run(HloModule* module) = 0; + virtual ~HloParser() {} + + private: + static std::unique_ptr CreateHloParserForTests( + absl::string_view str); + friend class VerifiedHloModule; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 5eac250b9ea..0c06986151c 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1707,8 +1707,7 @@ class HloParameterizedParserTest /*verifier_layout_sensitive=*/false, /*allow_mixed_precision_in_hlo_verifier=*/true, ShapeUtil::ByteSizeOfElements); - TF_ASSERT_OK(ParseHloString(original, verified_module.get())); - TF_ASSERT_OK(verified_module->Verify()); + TF_ASSERT_OK(verified_module->ParseHloStringAndVerifyModule(original)); module = std::move(verified_module); } else { TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnUnverifiedModule(original)); @@ -1768,8 +1767,7 @@ class HloParserTest : public ::testing::Test { /*verifier_layout_sensitive=*/false, /*allow_mixed_precision_in_hlo_verifier=*/true, ShapeUtil::ByteSizeOfElements); - TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); - TF_RETURN_IF_ERROR(module->Verify()); + TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text)); return std::move(module); } }; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc index 2fd51fa9349..fc5f94ca790 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc @@ -67,8 +67,7 @@ std::string CompileHloConvAndGetMlir(absl::string_view hlo_text) { "Conv", hlo_config, /*verifier_layout_sensitive=*/false, /*allow_mixed_precision_in_hlo_verifier=*/true, /*shape_size_function=*/ShapeUtil::ByteSizeOfElements); - TF_CHECK_OK(xla::ParseHloString(hlo_text, &hlo_module)); - TF_CHECK_OK(hlo_module.Verify()); + TF_CHECK_OK(hlo_module.ParseHloStringAndVerifyModule(hlo_text)); xla::HloInstruction* conv = hlo_module.entry_computation()->root_instruction(); diff --git a/tensorflow/compiler/xla/service/tuple_util_test.cc b/tensorflow/compiler/xla/service/tuple_util_test.cc index 85d78af0b09..c5308be227c 100644 --- a/tensorflow/compiler/xla/service/tuple_util_test.cc +++ b/tensorflow/compiler/xla/service/tuple_util_test.cc @@ -46,8 +46,7 @@ ENTRY entry { "TupleUtilTest", HloModuleConfig(), /*verifier_layout_sensitive=*/true, /*allow_mixed_precision_in_hlo_verifier=*/false, ShapeUtil::ByteSizeOfElements); - TF_RETURN_IF_ERROR(ParseHloString(hlo_string, module.get())); - TF_RETURN_IF_ERROR(module->Verify()); + TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_string)); *entry_computation = module->entry_computation(); *param0 = (*entry_computation)->parameter_instruction(0); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index a2dadcd8d39..463e87967f6 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -114,10 +114,12 @@ cc_library( hdrs = ["verified_hlo_module.h"], deps = [ "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/core:lib", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 7cc957be0af..17e37607be1 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -136,8 +136,7 @@ HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, TestName(), config, verifier_layout_sensitive_, allow_mixed_precision_in_hlo_verifier_, backend().compiler()->ShapeSizeBytesFunction()); - TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); - TF_RETURN_IF_ERROR(module->Verify()); + TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text)); return std::move(module); } diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 1532f1b5d8d..fdb3489f450 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -223,8 +223,7 @@ LocalClientTestBase::ParseAndReturnVerifiedModule( TestName(), config, /*verifier_layout_sensitive=*/false, /*allow_mixed_precision_in_hlo_verifier=*/true, local_client_->backend().compiler()->ShapeSizeBytesFunction()); - TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); - TF_RETURN_IF_ERROR(module->Verify()); + TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text)); return std::move(module); } diff --git a/tensorflow/compiler/xla/tests/verified_hlo_module.cc b/tensorflow/compiler/xla/tests/verified_hlo_module.cc index cd0c4073a26..e3aeeacc303 100644 --- a/tensorflow/compiler/xla/tests/verified_hlo_module.cc +++ b/tensorflow/compiler/xla/tests/verified_hlo_module.cc @@ -14,25 +14,26 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/tests/verified_hlo_module.h" -#include - #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" namespace xla { -Status VerifiedHloModule::Verify() { - if (computation_count() == 0) { - // The computation was never built. Nothing to verify. - return Status::OK(); - } - return verifier_.Run(this).status(); +Status VerifiedHloModule::ParseHloStringAndVerifyModule(absl::string_view str) { + TF_RET_CHECK(computation_count() == 0); + auto parser = HloParser::CreateHloParserForTests(str); + TF_RETURN_IF_ERROR(parser->Run(this)); + return Verify(); } -void VerifiedHloModule::VerifyOrAddFailure(const string& message) { +void VerifiedHloModule::VerifyOrAddFailure(absl::string_view message) { Status status = Verify(); if (!status.ok()) { ADD_FAILURE() << "HloVerifier failed on module " << name() @@ -43,4 +44,12 @@ void VerifiedHloModule::VerifyOrAddFailure(const string& message) { } } +Status VerifiedHloModule::Verify() { + if (computation_count() == 0) { + // The computation was never built. Nothing to verify. + return Status::OK(); + } + return verifier_.Run(this).status(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/verified_hlo_module.h b/tensorflow/compiler/xla/tests/verified_hlo_module.h index 1c13773acd4..ac4a103ad7d 100644 --- a/tensorflow/compiler/xla/tests/verified_hlo_module.h +++ b/tensorflow/compiler/xla/tests/verified_hlo_module.h @@ -16,8 +16,8 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_TESTS_VERIFIED_HLO_MODULE_H_ #include -#include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" @@ -43,14 +43,20 @@ class VerifiedHloModule : public HloModule { ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } - // Verifies the module using HloVerifier and returns the status. - Status Verify(); + // Given a string in the HloModule::ToString() format, parses the string and + // builds the VerifiedHloModule in place. Before calling this method, the + // module must be empty (no computations). Finally verifies the module using + // HloVerifier and returns the status. + Status ParseHloStringAndVerifyModule(absl::string_view str); // Verifies the module and flags any error with ADD_FAILURE. 'message' is // included in the failure message. - void VerifyOrAddFailure(const string& message); + void VerifyOrAddFailure(absl::string_view message); private: + // Verifies the module using HloVerifier and returns the status. + Status Verify(); + HloVerifier verifier_; };