From 12d6b450b2be345b3848efd8d623b1507a2c630f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 8 Nov 2017 15:24:01 -0800 Subject: [PATCH] Hlo parser: support window and convolution. Also, to make the text format easier to write and unambiguous: - Print "window={}" around the window attribute; rename the "window" sub attribute to "size"; - Print the dim_labels in logical order, instead of physical order. PiperOrigin-RevId: 175074526 --- .../compiler/xla/service/hlo_instruction.cc | 10 +- .../compiler/xla/tools/parser/README.md | 16 +- .../compiler/xla/tools/parser/hlo_lexer.cc | 65 +- .../compiler/xla/tools/parser/hlo_lexer.h | 6 +- .../compiler/xla/tools/parser/hlo_parser.cc | 589 ++++++++++++++---- .../xla/tools/parser/hlo_parser_test.cc | 120 ++++ .../compiler/xla/tools/parser/hlo_token.h | 3 + tensorflow/compiler/xla/window_util.cc | 26 +- 8 files changed, 690 insertions(+), 145 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 5107ac782d7..ee98c3fabc5 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1850,7 +1850,7 @@ std::vector HloInstruction::ExtraAttributesToString() const { extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}")); } if (window_ != nullptr) { - extra.push_back(window_util::ToString(*window_)); + extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); } if (padding_config_ != nullptr) { extra.push_back(StrCat("padding=", padding_config_->ShortDebugString())); @@ -2856,13 +2856,7 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { const auto append_dims = [&](const std::vector& dims, const Shape& shape) { CHECK_EQ(dims.size(), ShapeUtil::Rank(shape)); - for (int64 logical = 0; logical < dims.size(); ++logical) { - int64 physical = logical; - if (!shape.layout().minor_to_major().empty()) { - physical = 
LayoutUtil::Major(shape.layout(), logical); - } - result += dims[physical]; - } + StrAppend(&result, Join(dims, "")); }; // lhs_dims[i] is the symbol of the logical dimension i for the lhs diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md index 2c864d77a20..986041caf61 100644 --- a/tensorflow/compiler/xla/tools/parser/README.md +++ b/tensorflow/compiler/xla/tools/parser/README.md @@ -43,14 +43,22 @@ operand : shape name ; -extra_attributes +attributes : /*empty*/ - | ',' extra_attribute - | ',' extra_attribute extra_attributes + | ',' attribute + | ',' attribute attributes ; -extra_attribute +attribute : attribute_name attribute_value ; +attribute_value + : kInt + | kName + | [0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,} /*dim_labels_pattern*/ + | [0-9]+(x[0-9]+)+ /*dxd_pattern*/ + | [0-9]+_[0-9]+(x[0-9]+_[0-9]+)* /*window_pad_pattern*/ + | '{' sub_attributes '}' + ; param_list : '(' param_list1 ')' diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc index d104ff34601..f70386411cf 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc @@ -122,7 +122,7 @@ TokKind HloLexer::LexToken() { current_ptr_++; return TokKind::kArrow; } - return LexDigitOrNegative(); + return LexNumberOrPattern(); case '=': return TokKind::kEqual; case ',': @@ -149,12 +149,15 @@ TokKind HloLexer::LexToken() { } } -// Lex a shape, name, keyword, or opcode. +// Lex a shape, name, keyword, opcode, attribute name, or the dim labels +// pattern. +// // shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})? // name ::= [a-zA-Z_][a-zA-Z0-9_.-]*: // keyword ::= HloModule, ENTRY, ... // opcode ::= add, greater-than, ... // attribute_name ::= condition, body, dimensions, ... 
+// dim_labels_pattern ::= [0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,} TokKind HloLexer::LexIdentifier() { { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); @@ -220,6 +223,16 @@ TokKind HloLexer::LexIdentifier() { return TokKind::kOpcode; } + { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + static LazyRE2 dim_labels_pattern = { + R"([0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,})"}; + if (RE2::Consume(&consumable, *dim_labels_pattern)) { + current_ptr_ = consumable.begin(); + str_val_.assign(token_start_, current_ptr_); + return TokKind::kDimLabels; + } + } current_ptr_ = token_start_ + 1; return TokKind::kError; } @@ -240,15 +253,20 @@ TokKind HloLexer::LexPercent() { return TokKind::kError; } -// Lex integer and floating-point values, and -inf. -// int [-]?[0-9]+ -// fp with exp [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+) -// fp without exp [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+) -// negative inf -inf -TokKind HloLexer::LexDigitOrNegative() { +// Lex integer and floating-point values, -inf, and patterns for dim labels, +// dxd (e.g. 1x2x3), and window pad. 
+// +// fp with exp ::= [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+) +// fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+) +// dim_labels_pattern ::= [0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,} +// dxd_pattern ::= [0-9]+(x[0-9]+)+ +// window_pad_pattern ::= [0-9]+_[0-9]+(x[0-9]+_[0-9]+)* +// int ::= [-]?[0-9]+ +// negative inf ::= '-inf' +TokKind HloLexer::LexNumberOrPattern() { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); static LazyRE2 float_pattern = { - R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|(\d+[.]\d*|\d*[.]\d+))"}; + R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"}; if (RE2::Consume(&consumable, *float_pattern)) { current_ptr_ = consumable.begin(); tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(), @@ -256,6 +274,29 @@ TokKind HloLexer::LexDigitOrNegative() { return TokKind::kDecimal; } + static LazyRE2 dim_labels_pattern = { + R"([0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,})"}; + static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"}; + static LazyRE2 pad_pattern = {R"([0-9]+_[0-9]+(x[0-9]+_[0-9]+)*)"}; + + if (RE2::Consume(&consumable, *dim_labels_pattern)) { + current_ptr_ = consumable.begin(); + str_val_.assign(token_start_, current_ptr_); + return TokKind::kDimLabels; + } + + if (RE2::Consume(&consumable, *dxd_pattern)) { + current_ptr_ = consumable.begin(); + str_val_.assign(token_start_, current_ptr_); + return TokKind::kDxD; + } + + if (RE2::Consume(&consumable, *pad_pattern)) { + current_ptr_ = consumable.begin(); + str_val_.assign(token_start_, current_ptr_); + return TokKind::kWindowPad; + } + static LazyRE2 int_pattern = {R"([-]?\d+)"}; if (RE2::Consume(&consumable, *int_pattern)) { current_ptr_ = consumable.begin(); @@ -350,6 +391,12 @@ string TokKindToString(TokKind kind) { return "kName"; case TokKind::kAttributeName: return "kAttributeName"; + case TokKind::kDimLabels: + return "kDimLabels"; + case TokKind::kDxD: + return "kDxD"; + case 
TokKind::kWindowPad: + return "kWindowPad"; case TokKind::kShape: return "kShape"; case TokKind::kOpcode: diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h index 3b9efcb92d0..74e6829180a 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h @@ -37,11 +37,15 @@ class HloLexer { } TokKind Lex() { return current_kind_ = LexToken(); } + TokKind GetKind() const { return current_kind_; } string GetStrVal() const { switch (GetKind()) { case TokKind::kName: case TokKind::kAttributeName: + case TokKind::kDimLabels: + case TokKind::kDxD: + case TokKind::kWindowPad: return str_val_; default: LOG(FATAL) << "This token does not have string value"; @@ -92,7 +96,7 @@ class HloLexer { TokKind LexPercent(); TokKind LexShape(); TokKind LexConstant(); - TokKind LexDigitOrNegative(); + TokKind LexNumberOrPattern(); TokKind LexComment(); const tensorflow::StringPiece buf_; diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 6c2e37e3b5c..f1e987cb15c 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -28,6 +28,9 @@ namespace tools { namespace { using tensorflow::StringPiece; +using tensorflow::gtl::optional; +using tensorflow::str_util::Split; +using tensorflow::str_util::SplitAndParseAsInts; using tensorflow::strings::Printf; using tensorflow::strings::StrAppend; using tensorflow::strings::StrCat; @@ -57,8 +60,6 @@ class HloParser { bool ParseInstructionList(HloComputation::Builder* builder, string* root_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); - bool ParseSharding(HloInstruction* instruction); - bool ParseControlPredecessors(HloInstruction* instruction); bool ParseLiteral(std::unique_ptr* literal, const Shape& shape); bool ParseTupleLiteral(std::unique_ptr* literal, const Shape& 
shape); bool ParseNonTupleLiteral(std::unique_ptr* literal, @@ -78,10 +79,55 @@ class HloParser { bool ParseOperands(std::vector* operands, const int expected_size); - template - bool ParseExtraAttribute(T* value, const string& expected_attribute); - template - bool ParseAttributeValue(T* value); + // Types of attributes. + enum class AttrTy { + kInt64, + kHloComputation, + kWindow, + kConvolutionDimensionNumbers, + kSharding, + kInstructionList, + }; + + struct AttrConfig { + bool required; // whether it's required or optional + AttrTy attr_type; // what type it is + void* result; // where to store the parsed result. + }; + + // Parses attributes given names and configs of the attributes. Each parsed + // result is passed back through the result pointer in corresponding + // AttrConfig. Note that the result pointer must point to an optional-typed + // variable which outlives this function. Returns false on error. You should + // not use any of the results if this function failed. + // + // Example usage: + // + // std::unordered_map attrs; + // optional foo; + // attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo}; + // optional bar; + // attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar}; + // if (!ParseAttributes(attrs)) { + // return false; // Do not use 'foo' 'bar' if failed. + // } + // // Do something with 'bar'. + // if (foo) { // If attr foo is seen, do something with 'foo'. } + // + bool ParseAttributes(const std::unordered_map& attrs); + + // Parses a name and finds the corresponding hlo computation. + bool ParseComputationName(HloComputation** value); + // Parses a list of names and finds the corresponding hlo instructions. + bool ParseInstructionNames(std::vector* instructions); + bool ParseWindow(Window* window); + bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums); + bool ParseSharding(OpSharding* sharding); + + // Parses a sub-attribute of the window attribute, e.g., size=1x2x3. 
+ bool ParseDxD(const string& name, std::vector* result); + // Parses window's pad sub-attribute, e.g., pad=0_0x3_3. + bool ParseWindowPad(std::vector>* pad); bool ParseParamList(); bool ParseName(string* result); @@ -214,7 +260,7 @@ bool HloParser::ParseInstructionList(HloComputation::Builder* builder, "expects '}' at the end of instruction list."); } -// instruction ::= ('ROOT')? name '=' shape opcode operands (extra_attribute)* +// instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)* bool HloParser::ParseInstruction(HloComputation::Builder* builder, string* root_name) { string name; @@ -230,6 +276,15 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (is_root) { *root_name = name; } + + // Add optional attributes. + std::unordered_map attrs; + optional sharding; + attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding}; + optional> predecessors; + attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList, + &predecessors}; + HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { @@ -237,7 +292,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!ParseToken(TokKind::kLparen, "expects '(' before parameter number") || !ParseInt64(&parameter_number) || - !ParseToken(TokKind::kRparen, "expects ')' after parameter number")) { + !ParseToken(TokKind::kRparen, "expects ')' after parameter number") || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -249,7 +305,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!ParseToken(TokKind::kLparen, "expects '(' before constant literal") || !ParseLiteral(&literal, shape) || - !ParseToken(TokKind::kRparen, "expects ')' after constant literal")) { + !ParseToken(TokKind::kRparen, "expects ')' after constant literal") || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -275,7 +332,8 @@ bool 
HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kSin: case HloOpcode::kSort: case HloOpcode::kTanh: { - if (!ParseOperands(&operands, /*expected_size=*/1)) { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -305,7 +363,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: { - if (!ParseOperands(&operands, /*expected_size=*/2)) { + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateBinary( @@ -315,7 +374,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, // Ternary ops. case HloOpcode::kClamp: case HloOpcode::kSelect: { - if (!ParseOperands(&operands, /*expected_size=*/3)) { + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateTernary( @@ -324,7 +384,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } // Other supported ops. 
case HloOpcode::kConvert: { - if (!ParseOperands(&operands, /*expected_size=*/1)) { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -332,7 +393,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kCrossReplicaSum: { - if (!ParseOperands(&operands, /*expected_size=*/1)) { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -340,7 +402,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReshape: { - if (!ParseOperands(&operands, /*expected_size=*/1)) { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -348,7 +411,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kTuple: { - if (!ParseOperands(&operands)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } instruction = @@ -356,70 +419,99 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kWhile: { - HloComputation* condition; - HloComputation* body; + optional condition; + optional body; + attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation, + &condition}; + attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body}; if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseExtraAttribute(&condition, - /*expected_attribute=*/"condition") || - !ParseExtraAttribute(&body, /*expected_attribute=*/"body")) { + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateWhile( - shape, condition, body, /*init=*/operands[0])); + shape, *condition, *body, /*init=*/operands[0])); break; } case HloOpcode::kRecv: { - int64 channel_id; + optional channel_id; + attrs["channel_id"] = 
{/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/0) || - !ParseExtraAttribute(&channel_id, - /*expected_attribute=*/"channel_id")) { + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateRecv(shape, channel_id)); + HloInstruction::CreateRecv(shape, *channel_id)); break; } case HloOpcode::kSend: { - int64 channel_id; + optional channel_id; + attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseExtraAttribute(&channel_id, - /*expected_attribute=*/"channel_id")) { + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateSend(operands[0], channel_id)); + HloInstruction::CreateSend(operands[0], *channel_id)); break; } case HloOpcode::kGetTupleElement: { - int64 index; + optional index; + attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index}; if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseExtraAttribute(&index, /*expected_attribute=*/"index")) { + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateGetTupleElement(shape, operands[0], index)); + HloInstruction::CreateGetTupleElement(shape, operands[0], *index)); break; } case HloOpcode::kCall: { - HloComputation* to_apply; - if (!ParseOperands(&operands) || - !ParseExtraAttribute(&to_apply, - /*expected_attribute=*/"to_apply")) { + optional to_apply; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &to_apply}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateCall(shape, operands, to_apply)); + HloInstruction::CreateCall(shape, operands, *to_apply)); + break; + } + case HloOpcode::kReduceWindow: { + optional reduce_computation; + optional window; + attrs["window"] = {/*required=*/true, AttrTy::kWindow, 
&window}; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &reduce_computation}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow( + shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window, + *reduce_computation)); + break; + } + case HloOpcode::kConvolution: { + optional window; + optional dnums; + attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window}; + attrs["dim_labels"] = {/*required=*/true, + AttrTy::kConvolutionDimensionNumbers, &dnums}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateConvolve( + shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums)); break; } case HloOpcode::kBroadcast: case HloOpcode::kCustomCall: case HloOpcode::kConcatenate: case HloOpcode::kReducePrecision: - case HloOpcode::kConvolution: case HloOpcode::kMap: case HloOpcode::kPad: case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: case HloOpcode::kSelectAndScatter: case HloOpcode::kReverse: case HloOpcode::kRng: @@ -438,43 +530,27 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloOpcodeString(opcode))); } - bool has_sharding = false; - bool has_control = false; - while (EatIfPresent(TokKind::kComma)) { - string attribute_name; - if (!ParseAttributeName(&attribute_name)) { - return TokenError("expects ', sharding=' or ', control-predecessors='"); - } - - if (attribute_name == "sharding") { - // Parse "sharding=". - if (has_sharding) { - return TokenError("expects at most 1 'sharding='"); + // Add common attrs (sharding, control predecessors) to the instruction, if + // they were seen. 
+ if (sharding) { + instruction->set_sharding( + HloSharding::FromProto(sharding.value()).ValueOrDie()); + } + if (predecessors) { + for (auto* pre : *predecessors) { + Status status = pre->AddControlDependencyTo(instruction); + if (!status.ok()) { + return TokenError(StrCat("error adding control dependency for: ", name, + " status: ", status.ToString())); } - has_sharding = true; - if (!ParseSharding(instruction)) { - return false; - } - } else if (attribute_name == "control-predecessors") { - // Parse "control-predecessors" - if (has_control) { - return TokenError("expects at most 1 'control-predecessors='"); - } - has_control = true; - if (!ParseControlPredecessors(instruction)) { - return false; - } - } else { - return TokenError(StrCat("unexpected attribute: ", attribute_name)); } } - return AddInstruction(name, instruction); } // ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? ('devices=' ('[' // dims ']')* device_list)? '}' dims ::= int_list device_list ::= int_list -bool HloParser::ParseSharding(HloInstruction* instruction) { +bool HloParser::ParseSharding(OpSharding* sharding) { if (!ParseToken(TokKind::kLbrace, "expected '{' to start sharding attribute")) { return false; @@ -545,7 +621,6 @@ bool HloParser::ParseSharding(HloInstruction* instruction) { } } - OpSharding sharding; if (replicated) { if (!devices.empty()) { return TokenError( @@ -555,7 +630,7 @@ bool HloParser::ParseSharding(HloInstruction* instruction) { return TokenError( "replicated shardings should not have any tile shape set"); } - sharding.set_type(OpSharding::Type::OpSharding_Type_REPLICATED); + sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED); } else if (maximal) { if (devices.size() != 1) { return TokenError( @@ -564,8 +639,8 @@ bool HloParser::ParseSharding(HloInstruction* instruction) { if (!ShapeUtil::Equal(tile_shape, Shape())) { return TokenError("maximal shardings should not have any tile shape set"); } - 
sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); - sharding.add_tile_assignment_devices(devices[0]); + sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); + sharding->add_tile_assignment_devices(devices[0]); } else { if (devices.size() <= 1) { return TokenError( @@ -579,47 +654,43 @@ bool HloParser::ParseSharding(HloInstruction* instruction) { "non-maximal shardings must have a tile assignment list including " "dimensions"); } - sharding.set_type(OpSharding::Type::OpSharding_Type_OTHER); - *sharding.mutable_tile_shape() = tile_shape; + sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); + *sharding->mutable_tile_shape() = tile_shape; for (int64 dim : tile_assignment_dimensions) { - sharding.add_tile_assignment_dimensions(dim); + sharding->add_tile_assignment_dimensions(dim); } for (int64 device : devices) { - sharding.add_tile_assignment_devices(device); + sharding->add_tile_assignment_devices(device); } } - instruction->set_sharding(HloSharding::FromProto(sharding).ValueOrDie()); lexer_.Lex(); return true; } // '{' name+ '}' -bool HloParser::ParseControlPredecessors(HloInstruction* instruction) { +bool HloParser::ParseInstructionNames( + std::vector* instructions) { if (!ParseToken(TokKind::kLbrace, - "expects '{' at the beginning of control predecessors")) { + "expects '{' at the beginning of instruction name list")) { return false; } do { string name; if (!ParseName(&name)) { - return TokenError("expects a control predecessor"); + return TokenError("expects a instruction name"); } - HloInstruction* pre = + HloInstruction* instr = tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); - if (!pre) { + if (!instr) { return TokenError( - StrCat("control predecessor ", name, " is not defined: ")); - } - Status status = pre->AddControlDependencyTo(instruction); - if (!status.ok()) { - return TokenError(StrCat("error adding control dependency for: ", name, - " status: ", status.ToString())); + Printf("instruction '%s' is not defined", 
name.c_str())); } + instructions->push_back(instr); } while (EatIfPresent(TokKind::kComma)); return ParseToken(TokKind::kRbrace, - "expects '}' at the end of control predecessors"); + "expects '}' at the end of control instructions"); } bool HloParser::SetValueInLiteral(int64 value, int64 linear_index, @@ -957,28 +1028,95 @@ bool HloParser::ParseOperands(std::vector* operands, return true; } -// extra_attribute ::= ',' attribute_name value -template -bool HloParser::ParseExtraAttribute(T* value, - const string& expected_attribute) { - if (!ParseToken(TokKind::kComma, - "expects ',' in front of an extra attribute")) { - return false; +bool HloParser::ParseAttributes( + const std::unordered_map& attrs) { + std::unordered_set seen_attrs; + while (EatIfPresent(TokKind::kComma)) { + string name; + if (!ParseAttributeName(&name)) { + return TokenError("error parsing attributes"); + } + VLOG(1) << "Parsing attribute " << name; + if (!seen_attrs.insert(name).second) { + return TokenError(Printf("attribute %s already exists", name.c_str())); + } + auto attr_it = attrs.find(name); + if (attr_it == attrs.end()) { + return TokenError(Printf("unexpected attribute %s", name.c_str())); + } + AttrTy attr_type = attr_it->second.attr_type; + void* attr_out_ptr = attr_it->second.result; + bool success = [&] { + switch (attr_type) { + case AttrTy::kInt64: { + int64 result; + if (!ParseInt64(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kHloComputation: { + HloComputation* result; + if (!ParseComputationName(&result)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(result); + return true; + } + case AttrTy::kWindow: { + Window result; + if (!ParseWindow(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kConvolutionDimensionNumbers: { + ConvolutionDimensionNumbers result; + if (!ParseConvolutionDimensionNumbers(&result)) { + return false; + 
} + static_cast*>(attr_out_ptr) + ->emplace(result); + return true; + } + case AttrTy::kSharding: { + OpSharding sharding; + if (!ParseSharding(&sharding)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(sharding); + return true; + } + case AttrTy::kInstructionList: { + std::vector result; + if (!ParseInstructionNames(&result)) { + return false; + } + static_cast>*>(attr_out_ptr) + ->emplace(result); + return true; + } + } + }(); + if (!success) { + return TokenError(Printf("error parsing attribute %s", name.c_str())); + } } - string attribute_name; - if (!ParseAttributeName(&attribute_name) && - attribute_name != expected_attribute) { - return TokenError(StrCat("expects attribute name: ", expected_attribute)); - } - if (!ParseAttributeValue(value)) { - return TokenError( - StrCat("expects value for attribute: ", expected_attribute)); + // Check that all required attrs were seen. + for (const auto& attr_it : attrs) { + if (attr_it.second.required && + seen_attrs.find(attr_it.first) == seen_attrs.end()) { + return TokenError(Printf("attribute %s is expected but not seen", + attr_it.first.c_str())); + } } return true; } -template <> -bool HloParser::ParseAttributeValue(HloComputation** value) { +bool HloParser::ParseComputationName(HloComputation** value) { string name; if (!ParseName(&name)) { return TokenError("expects computation name"); @@ -990,9 +1128,191 @@ bool HloParser::ParseAttributeValue(HloComputation** value) { return true; } -template <> -bool HloParser::ParseAttributeValue(int64* value) { - return ParseInt64(value); +// ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}' +// The subattributes can appear in any order. 'size=' is required, others are +// optional. 
+bool HloParser::ParseWindow(Window* window) { + if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { + return false; + } + + std::vector size; + std::vector stride; + std::vector> pad; + std::vector lhs_dilate; + std::vector rhs_dilate; + while (lexer_.GetKind() != TokKind::kRbrace) { + string field_name; + if (!ParseAttributeName(&field_name)) { + return TokenError("expects sub-attributes in window"); + } + bool ok = [&] { + if (field_name == "size") { + return ParseDxD("size", &size); + } + if (field_name == "stride") { + return ParseDxD("stride", &stride); + } + if (field_name == "lhs_dilate") { + return ParseDxD("lhs_dilate", &lhs_dilate); + } + if (field_name == "rhs_dilate") { + return ParseDxD("rls_dilate", &rhs_dilate); + } + if (field_name == "pad") { + return ParseWindowPad(&pad); + } + return TokenError(StrCat("unexpected attribute name: ", field_name)); + }(); + if (!ok) { + return false; + } + } + + if (size.empty()) { + return TokenError( + "sub-attribute 'size=' is required in the window attribute"); + } + if (!stride.empty() && stride.size() != size.size()) { + return TokenError("expects 'stride=' has the same size as 'size='"); + } + if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) { + return TokenError("expects 'lhs_dilate=' has the same size as 'size='"); + } + if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) { + return TokenError("expects 'rhs_dilate=' has the same size as 'size='"); + } + if (!pad.empty() && pad.size() != size.size()) { + return TokenError("expects 'pad=' has the same size as 'size='"); + } + + for (int i = 0; i < size.size(); i++) { + window->add_dimensions()->set_size(size[i]); + if (!pad.empty()) { + window->mutable_dimensions(i)->set_padding_low(pad[i][0]); + window->mutable_dimensions(i)->set_padding_high(pad[i][1]); + } + // If some field is not present, it has the default value. + window->mutable_dimensions(i)->set_stride(stride.empty() ? 
1 : stride[i]); + window->mutable_dimensions(i)->set_base_dilation( + lhs_dilate.empty() ? 1 : lhs_dilate[i]); + window->mutable_dimensions(i)->set_window_dilation( + rhs_dilate.empty() ? 1 : rhs_dilate[i]); + } + return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); +} + +// This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString. +// The string looks like "dim_labels=0bf_0io->0bf". +bool HloParser::ParseConvolutionDimensionNumbers( + ConvolutionDimensionNumbers* dnums) { + if (lexer_.GetKind() != TokKind::kDimLabels) { + return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'"); + } + string str = lexer_.GetStrVal(); + + // The str is expected to have 3 items, lhs, rhs, out, and it must looks like + // lhs_rhs->out, that is, the first separator is "_" and the second is "->". + // So we replace the "->" with "_" and then split on "_". + str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->", + /*newsub=*/"_", + /*replace_all=*/false); + std::vector lhs_rhs_out = Split(str, "_"); + if (lhs_rhs_out.size() != 3) { + LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " + << str; + } + + const int64 rank = lhs_rhs_out[0].length(); + if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) { + return TokenError( + "convolution lhs, rhs, and output must have the same rank"); + } + if (rank < 3) { + return TokenError("convolution rank must >=3"); + } + + auto is_unique = [](string str) -> bool { + std::sort(str.begin(), str.end()); + return std::unique(str.begin(), str.end()) == str.end(); + }; + + // lhs + { + const string& lhs = lhs_rhs_out[0]; + if (!is_unique(lhs)) { + return TokenError( + StrCat("expects unique lhs dimension numbers, but sees ", lhs)); + } + for (int i = 0; i < rank - 2; i++) { + dnums->add_spatial_dimensions(-1); + } + for (int i = 0; i < rank; i++) { + char c = lhs[i]; + if (c == 'b') { + dnums->set_input_batch_dimension(i); + } else if (c == 'f') { + 
dnums->set_input_feature_dimension(i); + } else if (c < '0' + rank && c >= '0') { + dnums->set_spatial_dimensions(c - '0', i); + } else { + return TokenError( + Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1)); + } + } + } + // rhs + { + const string& rhs = lhs_rhs_out[1]; + if (!is_unique(rhs)) { + return TokenError( + StrCat("expects unique rhs dimension numbers, but sees ", rhs)); + } + for (int i = 0; i < rank - 2; i++) { + dnums->add_kernel_spatial_dimensions(-1); + } + for (int i = 0; i < rank; i++) { + char c = rhs[i]; + if (c == 'i') { + dnums->set_kernel_input_feature_dimension(i); + } else if (c == 'o') { + dnums->set_kernel_output_feature_dimension(i); + } else if (c < '0' + rank && c >= '0') { + dnums->set_kernel_spatial_dimensions(c - '0', i); + } else { + return TokenError( + Printf("expects [0-%lldio] in rhs dimension numbers", rank - 1)); + } + } + } + // output + { + const string& out = lhs_rhs_out[2]; + if (!is_unique(out)) { + return TokenError( + StrCat("expects unique output dimension numbers, but sees ", out)); + } + for (int i = 0; i < rank; i++) { + char c = out[i]; + if (c == 'b') { + dnums->set_output_batch_dimension(i); + } else if (c == 'f') { + dnums->set_output_feature_dimension(i); + } else if (c < '0' + rank && c >= '0') { + if (dnums->spatial_dimensions(c - '0') != i) { + return TokenError( + "output spatial dimensions should be the same as input spatial " + "dimensions"); + } + } else { + return TokenError( + Printf("expects [0-%lldbf] in output dimension numbers", rank - 1)); + } + } + } + + lexer_.Lex(); + return true; } // param_list ::= '(' param_list1 ')' @@ -1070,6 +1390,55 @@ bool HloParser::ParseAttributeName(string* result) { return true; } +bool HloParser::ParseDxD(const string& name, std::vector* result) { + if (!result->empty()) { + return TokenError( + Printf("sub-attribute '%s=' already exists", name.c_str())); + } + // 1D + if (lexer_.GetKind() == TokKind::kInt) { + int64 number; + if 
(!ParseInt64(&number)) { + return TokenError(Printf("expects sub-attribute '%s=i'", name.c_str())); + } + result->push_back(number); + return true; + } + // 2D or higher. + if (lexer_.GetKind() == TokKind::kDxD) { + string str = lexer_.GetStrVal(); + if (!SplitAndParseAsInts(str, 'x', result)) { + return TokenError( + Printf("expects sub-attribute '%s=ixj...'", name.c_str())); + } + lexer_.Lex(); + return true; + } + return TokenError("expects token type kInt or kDxD"); +} + +bool HloParser::ParseWindowPad(std::vector>* pad) { + if (!pad->empty()) { + return TokenError("sub-attribute 'pad=' already exists"); + } + if (lexer_.GetKind() != TokKind::kWindowPad) { + return TokenError("expects window pad pattern, e.g., '0_0x3_3'"); + } + string str = lexer_.GetStrVal(); + std::vector padding_str = Split(str, 'x'); + for (int i = 0; i < padding_str.size(); i++) { + std::vector low_high; + if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) || + low_high.size() != 2) { + return TokenError( + "expects padding_low and padding_high separated by '_'"); + } + pad->push_back(low_high); + } + lexer_.Lex(); + return true; +} + bool HloParser::ParseOpcode(HloOpcode* result) { VLOG(1) << "ParseOpcode"; if (lexer_.GetKind() != TokKind::kOpcode) { diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 359256f0646..62b4385e76f 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -25,6 +25,7 @@ namespace tools { namespace { using tensorflow::StringPiece; +using tensorflow::strings::StrCat; struct TestData { string test_name; @@ -247,6 +248,39 @@ ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] { ROOT %call = f32[] call(f32[] %constant), to_apply=%Identity.v1 } +)" +}, +// reduce window +{ +"ReduceWindow", +R"(HloModule R4UnitWindow_module: + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = 
f32[] parameter(1)
+  ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
+}
+
+ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] {
+  %operand = f32[13,12,8,15]{0,3,2,1} parameter(0)
+  %constant = f32[] constant(0)
+  ROOT %reduce-window = f32[13,3,8,15]{0,3,2,1} reduce-window(f32[13,12,8,15]{0,3,2,1} %operand, f32[] %constant), window={size=1x1x7x1 stride=1x4x1x1 pad=0_0x0_0x3_3x0_0}, to_apply=%add_F32.v3
+}
+
+)"
+},
+// convolution: exercises the window={...} and dim_labels= attributes.
+{
+"Convolution",
+R"(HloModule Convolve1D1Window_0_module:
+
+ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
+  %input = f32[1,2,1]{2,1,0} parameter(0)
+  %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
+  %filter = f32[1,1,1]{2,1,0} parameter(1)
+  ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f
+}
+
 )"
 }
 });
@@ -427,6 +461,92 @@ ENTRY %ConstantWithExp.v4 () -> f32[] {
   // printed as "300".
 }
 
+// Attribute order must not matter: compared with the "Convolution" fixture,
+// this lists sharding and dim_labels before window, and puts pad before size
+// inside window={...}. Only checks that parsing succeeds.
+// NOTE(review): the test name misspells "Attributes".
+TEST_F(HloParserTest, AttibutesAnyOrder) {
+  const string original = R"(HloModule any_order_module:
+
+ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
+  %input = f32[1,2,1]{2,1,0} parameter(0)
+  %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
+  %filter = f32[1,1,1]{2,1,0} parameter(1)
+  ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2}
+}
+
+)";
+  TF_EXPECT_OK(Parse(original).status());
+}
+
+// Appends three different malformed dim_labels values to the same convolution
+// instruction (prefix + attribute + suffix) and checks the reported error.
+TEST_F(HloParserTest, InvalidDimLabels) {
+  string prefix = R"(HloModule invalid_dim_labels_module:
+
+ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
+  %input = f32[1,2,1]{2,1,0} parameter(0)
+  %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
+  %filter = f32[1,1,1]{2,1,0} parameter(1)
+  ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1} )";
+  string suffix = R"(
+}
+
+)";
+
+  // No "->" separator, so the value is not a dim labels pattern at all.
+  ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=00_01_10", suffix))
+                      .status()
+                      .error_message(),
+                  "expects dim labels pattern");
+
+  // The rhs group has four labels while lhs and output have three.
+  ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=010_1100->010", suffix))
+                      .status()
+                      .error_message(),
+                  "must have the same rank");
+
+  // Spatial dimension 0 sits at position 0 in lhs but position 1 in output.
+  ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=0bf_io0->b0f", suffix))
+                      .status()
+                      .error_message(),
+                  "output spatial dimensions should be the same as input "
+                  "spatial dimensions");
+}
+
+// The send instruction carries a calls= attribute; the parser is expected to
+// report it as unexpected.
+TEST_F(HloParserTest, UnexpectedAttribute) {
+  const string original = R"(HloModule unexpected_attr_module:
+
+ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
+  %recv = f32[] recv(), channel_id=15
+  ROOT %constant = f32[] constant(2.1)
+  %send = () send(f32[] %constant), channel_id=16, calls=%recv
+}
+
+)";
+  ExpectHasSubstr(Parse(original).status().error_message(),
+                  "unexpected attribute calls");
+}
+
+// The send instruction has no channel_id attribute; the parser is expected to
+// report it as missing.
+TEST_F(HloParserTest, MissingAttribute) {
+  const string original = R"(HloModule missing_attr_module:
+
+ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
+  %recv = f32[] recv(), channel_id=15
+  ROOT %constant = f32[] constant(-2.1)
+  %send = () send(f32[] %constant)
+}
+
+)";
+  ExpectHasSubstr(Parse(original).status().error_message(),
+                  "attribute channel_id is expected but not seen");
+}
+
+// control-predecessors names %done, which no instruction in the module
+// defines; the parser is expected to report the undefined name.
+TEST_F(HloParserTest, PredecessorUndefined) {
+  const string original = R"(HloModule pre_not_found_module:
+
+ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
+  %recv = f32[] recv(), channel_id=15
+  ROOT %constant = f32[] constant(2.1)
+  %send = () send(f32[] %constant), channel_id=16, control-predecessors={%done}
+}
+
+)";
+  ExpectHasSubstr(Parse(original).status().error_message(),
+                  "'done' is not defined");
+}
+
 }  // namespace
 }  // namespace tools
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h
index 9c2069e7568..15ab8b1cccf 100644
---
a/tensorflow/compiler/xla/tools/parser/hlo_token.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h @@ -57,6 +57,9 @@ enum class TokKind { // Typed tokens. kName, // %foo kAttributeName, // dimensions= + kDimLabels, // [0-9bf]+_[0-9io]+->[0-9bf]+ + kDxD, // [0-9]+(x[0-9]+)+ + kWindowPad, // [0-9]+_[0-9]+(x[0-9]+_[0-9]+)* kShape, // f32[2,3]{1,0} kOpcode, // add kInt, // 42 diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 23161873a0b..6f7f1479b90 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -26,8 +26,8 @@ namespace xla { namespace window_util { /* static */ string ToString(const WindowDimension& dim) { - using tensorflow::strings::StrCat; using tensorflow::strings::StrAppend; + using tensorflow::strings::StrCat; string str = StrCat("(size=", dim.size()); if (dim.stride() != 1) { StrAppend(&str, ",stride=", dim.stride()); @@ -49,22 +49,22 @@ namespace window_util { } string ToString(const Window& window) { - using tensorflow::strings::StrCat; using tensorflow::strings::StrAppend; + using tensorflow::strings::StrCat; string str; - const auto add_field = [&]( - const char* heading, - std::function format) { - StrAppend(&str, heading, "="); - const char* prefix = ""; - for (const auto& window_dimension : window.dimensions()) { - StrAppend(&str, prefix, format(window_dimension)); - prefix = "x"; - } - }; + const auto add_field = + [&](const char* heading, + std::function format) { + StrAppend(&str, heading, "="); + const char* prefix = ""; + for (const auto& window_dimension : window.dimensions()) { + StrAppend(&str, prefix, format(window_dimension)); + prefix = "x"; + } + }; - add_field("window", + add_field("size", [](const WindowDimension& dim) { return StrCat(dim.size()); }); if (HasStride(window)) { add_field(" stride",