Hlo parser: support window and convolution.
Also, to make the text format easier to write and unambiguous: - Print "window={}" around the window attribute; rename the "window" sub attribute to "size"; - Print the dim_lables in logical order, instead of physical order. PiperOrigin-RevId: 175074526
This commit is contained in:
parent
8bb665ae1c
commit
12d6b450b2
@ -1850,7 +1850,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
|
||||
extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}"));
|
||||
}
|
||||
if (window_ != nullptr) {
|
||||
extra.push_back(window_util::ToString(*window_));
|
||||
extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
|
||||
}
|
||||
if (padding_config_ != nullptr) {
|
||||
extra.push_back(StrCat("padding=", padding_config_->ShortDebugString()));
|
||||
@ -2856,13 +2856,7 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const {
|
||||
const auto append_dims = [&](const std::vector<string>& dims,
|
||||
const Shape& shape) {
|
||||
CHECK_EQ(dims.size(), ShapeUtil::Rank(shape));
|
||||
for (int64 logical = 0; logical < dims.size(); ++logical) {
|
||||
int64 physical = logical;
|
||||
if (!shape.layout().minor_to_major().empty()) {
|
||||
physical = LayoutUtil::Major(shape.layout(), logical);
|
||||
}
|
||||
result += dims[physical];
|
||||
}
|
||||
StrAppend(&result, Join(dims, ""));
|
||||
};
|
||||
|
||||
// lhs_dims[i] is the symbol of the logical dimension i for the lhs
|
||||
|
@ -43,14 +43,22 @@ operand
|
||||
: shape name
|
||||
;
|
||||
|
||||
extra_attributes
|
||||
attributes
|
||||
: /*empty*/
|
||||
| ',' extra_attribute
|
||||
| ',' extra_attribute extra_attributes
|
||||
| ',' attribute
|
||||
| ',' attribute attributes
|
||||
;
|
||||
extra_attribute
|
||||
attribute
|
||||
: attribute_name attribute_value
|
||||
;
|
||||
attribute_value
|
||||
: kInt
|
||||
| kName
|
||||
| [0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,} /*dim_labels_pattern*/
|
||||
| [0-9]+(x[0-9]+)+ /*dxd_pattern*/
|
||||
| [0-9]+_[0-9]+(x[0-9]+_[0-9]+)* /*window_pad_pattern*/
|
||||
| '{' sub_attributes '}'
|
||||
;
|
||||
|
||||
param_list
|
||||
: '(' param_list1 ')'
|
||||
|
@ -122,7 +122,7 @@ TokKind HloLexer::LexToken() {
|
||||
current_ptr_++;
|
||||
return TokKind::kArrow;
|
||||
}
|
||||
return LexDigitOrNegative();
|
||||
return LexNumberOrPattern();
|
||||
case '=':
|
||||
return TokKind::kEqual;
|
||||
case ',':
|
||||
@ -149,12 +149,15 @@ TokKind HloLexer::LexToken() {
|
||||
}
|
||||
}
|
||||
|
||||
// Lex a shape, name, keyword, or opcode.
|
||||
// Lex a shape, name, keyword, opcode, attribute name, or the dim labels
|
||||
// pattern.
|
||||
//
|
||||
// shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})?
|
||||
// name ::= [a-zA-Z_][a-zA-Z0-9_.-]*:
|
||||
// keyword ::= HloModule, ENTRY, ...
|
||||
// opcode ::= add, greater-than, ...
|
||||
// attribute_name ::= condition, body, dimensions, ...
|
||||
// dim_labels_pattern ::= [0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,}
|
||||
TokKind HloLexer::LexIdentifier() {
|
||||
{
|
||||
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
|
||||
@ -220,6 +223,16 @@ TokKind HloLexer::LexIdentifier() {
|
||||
return TokKind::kOpcode;
|
||||
}
|
||||
|
||||
{
|
||||
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
|
||||
static LazyRE2 dim_labels_pattern = {
|
||||
R"([0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,})"};
|
||||
if (RE2::Consume(&consumable, *dim_labels_pattern)) {
|
||||
current_ptr_ = consumable.begin();
|
||||
str_val_.assign(token_start_, current_ptr_);
|
||||
return TokKind::kDimLabels;
|
||||
}
|
||||
}
|
||||
current_ptr_ = token_start_ + 1;
|
||||
return TokKind::kError;
|
||||
}
|
||||
@ -240,15 +253,20 @@ TokKind HloLexer::LexPercent() {
|
||||
return TokKind::kError;
|
||||
}
|
||||
|
||||
// Lex integer and floating-point values, and -inf.
|
||||
// int [-]?[0-9]+
|
||||
// fp with exp [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
|
||||
// fp without exp [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
|
||||
// negative inf -inf
|
||||
TokKind HloLexer::LexDigitOrNegative() {
|
||||
// Lex integer and floating-point values, -inf, and patterns for dim labels,
|
||||
// dxd (e.g. 1x2x3), and window pad.
|
||||
//
|
||||
// fp with exp ::= [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
|
||||
// fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
|
||||
// dim_labels_pattern ::= [0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,}
|
||||
// dxd_pattern ::= [0-9]+(x[0-9]+)+
|
||||
// window_pad_pattern ::= [0-9]+_[0-9]+(x[0-9]+_[0-9]+)*
|
||||
// int ::= [-]?[0-9]+
|
||||
// negative inf ::= '-inf'
|
||||
TokKind HloLexer::LexNumberOrPattern() {
|
||||
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
|
||||
static LazyRE2 float_pattern = {
|
||||
R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|(\d+[.]\d*|\d*[.]\d+))"};
|
||||
R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"};
|
||||
if (RE2::Consume(&consumable, *float_pattern)) {
|
||||
current_ptr_ = consumable.begin();
|
||||
tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(),
|
||||
@ -256,6 +274,29 @@ TokKind HloLexer::LexDigitOrNegative() {
|
||||
return TokKind::kDecimal;
|
||||
}
|
||||
|
||||
static LazyRE2 dim_labels_pattern = {
|
||||
R"([0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,})"};
|
||||
static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"};
|
||||
static LazyRE2 pad_pattern = {R"([0-9]+_[0-9]+(x[0-9]+_[0-9]+)*)"};
|
||||
|
||||
if (RE2::Consume(&consumable, *dim_labels_pattern)) {
|
||||
current_ptr_ = consumable.begin();
|
||||
str_val_.assign(token_start_, current_ptr_);
|
||||
return TokKind::kDimLabels;
|
||||
}
|
||||
|
||||
if (RE2::Consume(&consumable, *dxd_pattern)) {
|
||||
current_ptr_ = consumable.begin();
|
||||
str_val_.assign(token_start_, current_ptr_);
|
||||
return TokKind::kDxD;
|
||||
}
|
||||
|
||||
if (RE2::Consume(&consumable, *pad_pattern)) {
|
||||
current_ptr_ = consumable.begin();
|
||||
str_val_.assign(token_start_, current_ptr_);
|
||||
return TokKind::kWindowPad;
|
||||
}
|
||||
|
||||
static LazyRE2 int_pattern = {R"([-]?\d+)"};
|
||||
if (RE2::Consume(&consumable, *int_pattern)) {
|
||||
current_ptr_ = consumable.begin();
|
||||
@ -350,6 +391,12 @@ string TokKindToString(TokKind kind) {
|
||||
return "kName";
|
||||
case TokKind::kAttributeName:
|
||||
return "kAttributeName";
|
||||
case TokKind::kDimLabels:
|
||||
return "kDimLabels";
|
||||
case TokKind::kDxD:
|
||||
return "kDxD";
|
||||
case TokKind::kWindowPad:
|
||||
return "kWindowPad";
|
||||
case TokKind::kShape:
|
||||
return "kShape";
|
||||
case TokKind::kOpcode:
|
||||
|
@ -37,11 +37,15 @@ class HloLexer {
|
||||
}
|
||||
|
||||
TokKind Lex() { return current_kind_ = LexToken(); }
|
||||
|
||||
TokKind GetKind() const { return current_kind_; }
|
||||
string GetStrVal() const {
|
||||
switch (GetKind()) {
|
||||
case TokKind::kName:
|
||||
case TokKind::kAttributeName:
|
||||
case TokKind::kDimLabels:
|
||||
case TokKind::kDxD:
|
||||
case TokKind::kWindowPad:
|
||||
return str_val_;
|
||||
default:
|
||||
LOG(FATAL) << "This token does not have string value";
|
||||
@ -92,7 +96,7 @@ class HloLexer {
|
||||
TokKind LexPercent();
|
||||
TokKind LexShape();
|
||||
TokKind LexConstant();
|
||||
TokKind LexDigitOrNegative();
|
||||
TokKind LexNumberOrPattern();
|
||||
TokKind LexComment();
|
||||
|
||||
const tensorflow::StringPiece buf_;
|
||||
|
@ -28,6 +28,9 @@ namespace tools {
|
||||
namespace {
|
||||
|
||||
using tensorflow::StringPiece;
|
||||
using tensorflow::gtl::optional;
|
||||
using tensorflow::str_util::Split;
|
||||
using tensorflow::str_util::SplitAndParseAsInts;
|
||||
using tensorflow::strings::Printf;
|
||||
using tensorflow::strings::StrAppend;
|
||||
using tensorflow::strings::StrCat;
|
||||
@ -57,8 +60,6 @@ class HloParser {
|
||||
bool ParseInstructionList(HloComputation::Builder* builder,
|
||||
string* root_name);
|
||||
bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
|
||||
bool ParseSharding(HloInstruction* instruction);
|
||||
bool ParseControlPredecessors(HloInstruction* instruction);
|
||||
bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
|
||||
bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
|
||||
bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
|
||||
@ -78,10 +79,55 @@ class HloParser {
|
||||
bool ParseOperands(std::vector<HloInstruction*>* operands,
|
||||
const int expected_size);
|
||||
|
||||
template <typename T>
|
||||
bool ParseExtraAttribute(T* value, const string& expected_attribute);
|
||||
template <typename T>
|
||||
bool ParseAttributeValue(T* value);
|
||||
// Types of attributes.
|
||||
enum class AttrTy {
|
||||
kInt64,
|
||||
kHloComputation,
|
||||
kWindow,
|
||||
kConvolutionDimensionNumbers,
|
||||
kSharding,
|
||||
kInstructionList,
|
||||
};
|
||||
|
||||
struct AttrConfig {
|
||||
bool required; // whether it's required or optional
|
||||
AttrTy attr_type; // what type it is
|
||||
void* result; // where to store the parsed result.
|
||||
};
|
||||
|
||||
// Parses attributes given names and configs of the attributes. Each parsed
|
||||
// result is passed back through the result pointer in corresponding
|
||||
// AttrConfig. Note that the result pointer must point to a optional<T> typed
|
||||
// variable which outlives this function. Returns false on error. You should
|
||||
// not use the any of the results if this function failed.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// std::unordered_map<string, AttrConfig> attrs;
|
||||
// optional<int64> foo;
|
||||
// attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo};
|
||||
// optional<Window> bar;
|
||||
// attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar};
|
||||
// if (!ParseAttribute(attrs)) {
|
||||
// return false; // Do not use 'foo' 'bar' if failed.
|
||||
// }
|
||||
// // Do something with 'bar'.
|
||||
// if (foo) { // If attr foo is seen, do something with 'foo'. }
|
||||
//
|
||||
bool ParseAttributes(const std::unordered_map<string, AttrConfig>& attrs);
|
||||
|
||||
// Parses a name and finds the corresponding hlo computation.
|
||||
bool ParseComputationName(HloComputation** value);
|
||||
// Parses a list of names and finds the corresponding hlo instructions.
|
||||
bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
|
||||
bool ParseWindow(Window* window);
|
||||
bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
|
||||
bool ParseSharding(OpSharding* sharding);
|
||||
|
||||
// Parses a sub-attribute of the window attribute, e.g.,size=1x2x3.
|
||||
bool ParseDxD(const string& name, std::vector<int64>* result);
|
||||
// Parses window's pad sub-attriute, e.g., pad=0_0x3x3.
|
||||
bool ParseWindowPad(std::vector<std::vector<int64>>* pad);
|
||||
|
||||
bool ParseParamList();
|
||||
bool ParseName(string* result);
|
||||
@ -214,7 +260,7 @@ bool HloParser::ParseInstructionList(HloComputation::Builder* builder,
|
||||
"expects '}' at the end of instruction list.");
|
||||
}
|
||||
|
||||
// instruction ::= ('ROOT')? name '=' shape opcode operands (extra_attribute)*
|
||||
// instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
|
||||
bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
string* root_name) {
|
||||
string name;
|
||||
@ -230,6 +276,15 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
if (is_root) {
|
||||
*root_name = name;
|
||||
}
|
||||
|
||||
// Add optional attributes.
|
||||
std::unordered_map<string, AttrConfig> attrs;
|
||||
optional<OpSharding> sharding;
|
||||
attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
|
||||
optional<std::vector<HloInstruction*>> predecessors;
|
||||
attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList,
|
||||
&predecessors};
|
||||
|
||||
HloInstruction* instruction;
|
||||
switch (opcode) {
|
||||
case HloOpcode::kParameter: {
|
||||
@ -237,7 +292,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
if (!ParseToken(TokKind::kLparen,
|
||||
"expects '(' before parameter number") ||
|
||||
!ParseInt64(¶meter_number) ||
|
||||
!ParseToken(TokKind::kRparen, "expects ')' after parameter number")) {
|
||||
!ParseToken(TokKind::kRparen, "expects ')' after parameter number") ||
|
||||
!ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(
|
||||
@ -249,7 +305,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
if (!ParseToken(TokKind::kLparen,
|
||||
"expects '(' before constant literal") ||
|
||||
!ParseLiteral(&literal, shape) ||
|
||||
!ParseToken(TokKind::kRparen, "expects ')' after constant literal")) {
|
||||
!ParseToken(TokKind::kRparen, "expects ')' after constant literal") ||
|
||||
!ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(
|
||||
@ -275,7 +332,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
case HloOpcode::kSin:
|
||||
case HloOpcode::kSort:
|
||||
case HloOpcode::kTanh: {
|
||||
if (!ParseOperands(&operands, /*expected_size=*/1)) {
|
||||
if (!ParseOperands(&operands, /*expected_size=*/1) ||
|
||||
!ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(
|
||||
@ -305,7 +363,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
case HloOpcode::kShiftLeft:
|
||||
case HloOpcode::kShiftRightArithmetic:
|
||||
case HloOpcode::kShiftRightLogical: {
|
||||
if (!ParseOperands(&operands, /*expected_size=*/2)) {
|
||||
if (!ParseOperands(&operands, /*expected_size=*/2) ||
|
||||
!ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(HloInstruction::CreateBinary(
|
||||
@ -315,7 +374,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
// Ternary ops.
|
||||
case HloOpcode::kClamp:
|
||||
case HloOpcode::kSelect: {
|
||||
if (!ParseOperands(&operands, /*expected_size=*/3)) {
|
||||
if (!ParseOperands(&operands, /*expected_size=*/3) ||
|
||||
!ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(HloInstruction::CreateTernary(
|
||||
@ -324,7 +384,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
}
|
||||
// Other supported ops.
|
||||
case HloOpcode::kConvert: {
|
||||
if (!ParseOperands(&operands, /*expected_size=*/1)) {
|
||||
if (!ParseOperands(&operands, /*expected_size=*/1) ||
|
||||
!ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(
|
||||
@ -332,7 +393,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kCrossReplicaSum: {
|
||||
if (!ParseOperands(&operands, /*expected_size=*/1)) {
|
||||
if (!ParseOperands(&operands, /*expected_size=*/1) ||
|
||||
!ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(
|
||||
@ -340,7 +402,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kReshape: {
|
||||
if (!ParseOperands(&operands, /*expected_size=*/1)) {
|
||||
if (!ParseOperands(&operands, /*expected_size=*/1) ||
|
||||
!ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(
|
||||
@ -348,7 +411,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kTuple: {
|
||||
if (!ParseOperands(&operands)) {
|
||||
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction =
|
||||
@ -356,70 +419,99 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kWhile: {
|
||||
HloComputation* condition;
|
||||
HloComputation* body;
|
||||
optional<HloComputation*> condition;
|
||||
optional<HloComputation*> body;
|
||||
attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation,
|
||||
&condition};
|
||||
attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body};
|
||||
if (!ParseOperands(&operands, /*expected_size=*/1) ||
|
||||
!ParseExtraAttribute(&condition,
|
||||
/*expected_attribute=*/"condition") ||
|
||||
!ParseExtraAttribute(&body, /*expected_attribute=*/"body")) {
|
||||
!ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(HloInstruction::CreateWhile(
|
||||
shape, condition, body, /*init=*/operands[0]));
|
||||
shape, *condition, *body, /*init=*/operands[0]));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kRecv: {
|
||||
int64 channel_id;
|
||||
optional<int64> channel_id;
|
||||
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
|
||||
if (!ParseOperands(&operands, /*expected_size=*/0) ||
|
||||
!ParseExtraAttribute(&channel_id,
|
||||
/*expected_attribute=*/"channel_id")) {
|
||||
!ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(
|
||||
HloInstruction::CreateRecv(shape, channel_id));
|
||||
HloInstruction::CreateRecv(shape, *channel_id));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kSend: {
|
||||
int64 channel_id;
|
||||
optional<int64> channel_id;
|
||||
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
|
||||
if (!ParseOperands(&operands, /*expected_size=*/1) ||
|
||||
!ParseExtraAttribute(&channel_id,
|
||||
/*expected_attribute=*/"channel_id")) {
|
||||
!ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(
|
||||
HloInstruction::CreateSend(operands[0], channel_id));
|
||||
HloInstruction::CreateSend(operands[0], *channel_id));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kGetTupleElement: {
|
||||
int64 index;
|
||||
optional<int64> index;
|
||||
attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
|
||||
if (!ParseOperands(&operands, /*expected_size=*/1) ||
|
||||
!ParseExtraAttribute(&index, /*expected_attribute=*/"index")) {
|
||||
!ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(shape, operands[0], index));
|
||||
HloInstruction::CreateGetTupleElement(shape, operands[0], *index));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kCall: {
|
||||
HloComputation* to_apply;
|
||||
if (!ParseOperands(&operands) ||
|
||||
!ParseExtraAttribute(&to_apply,
|
||||
/*expected_attribute=*/"to_apply")) {
|
||||
optional<HloComputation*> to_apply;
|
||||
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
|
||||
&to_apply};
|
||||
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(
|
||||
HloInstruction::CreateCall(shape, operands, to_apply));
|
||||
HloInstruction::CreateCall(shape, operands, *to_apply));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kReduceWindow: {
|
||||
optional<HloComputation*> reduce_computation;
|
||||
optional<Window> window;
|
||||
attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window};
|
||||
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
|
||||
&reduce_computation};
|
||||
if (!ParseOperands(&operands, /*expected_size=*/2) ||
|
||||
!ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
|
||||
shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window,
|
||||
*reduce_computation));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kConvolution: {
|
||||
optional<Window> window;
|
||||
optional<ConvolutionDimensionNumbers> dnums;
|
||||
attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window};
|
||||
attrs["dim_labels"] = {/*required=*/true,
|
||||
AttrTy::kConvolutionDimensionNumbers, &dnums};
|
||||
if (!ParseOperands(&operands, /*expected_size=*/2) ||
|
||||
!ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
|
||||
shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kBroadcast:
|
||||
case HloOpcode::kCustomCall:
|
||||
case HloOpcode::kConcatenate:
|
||||
case HloOpcode::kReducePrecision:
|
||||
case HloOpcode::kConvolution:
|
||||
case HloOpcode::kMap:
|
||||
case HloOpcode::kPad:
|
||||
case HloOpcode::kReduce:
|
||||
case HloOpcode::kReduceWindow:
|
||||
case HloOpcode::kSelectAndScatter:
|
||||
case HloOpcode::kReverse:
|
||||
case HloOpcode::kRng:
|
||||
@ -438,43 +530,27 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
|
||||
HloOpcodeString(opcode)));
|
||||
}
|
||||
|
||||
bool has_sharding = false;
|
||||
bool has_control = false;
|
||||
while (EatIfPresent(TokKind::kComma)) {
|
||||
string attribute_name;
|
||||
if (!ParseAttributeName(&attribute_name)) {
|
||||
return TokenError("expects ', sharding=' or ', control-predecessors='");
|
||||
}
|
||||
|
||||
if (attribute_name == "sharding") {
|
||||
// Parse "sharding=".
|
||||
if (has_sharding) {
|
||||
return TokenError("expects at most 1 'sharding='");
|
||||
// Add common attrs (sharding, control predecessors) to the instruction, if
|
||||
// they were seen.
|
||||
if (sharding) {
|
||||
instruction->set_sharding(
|
||||
HloSharding::FromProto(sharding.value()).ValueOrDie());
|
||||
}
|
||||
if (predecessors) {
|
||||
for (auto* pre : *predecessors) {
|
||||
Status status = pre->AddControlDependencyTo(instruction);
|
||||
if (!status.ok()) {
|
||||
return TokenError(StrCat("error adding control dependency for: ", name,
|
||||
" status: ", status.ToString()));
|
||||
}
|
||||
has_sharding = true;
|
||||
if (!ParseSharding(instruction)) {
|
||||
return false;
|
||||
}
|
||||
} else if (attribute_name == "control-predecessors") {
|
||||
// Parse "control-predecessors"
|
||||
if (has_control) {
|
||||
return TokenError("expects at most 1 'control-predecessors='");
|
||||
}
|
||||
has_control = true;
|
||||
if (!ParseControlPredecessors(instruction)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return TokenError(StrCat("unexpected attribute: ", attribute_name));
|
||||
}
|
||||
}
|
||||
|
||||
return AddInstruction(name, instruction);
|
||||
}
|
||||
|
||||
// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? ('devices=' ('['
|
||||
// dims ']')* device_list)? '}' dims ::= int_list device_list ::= int_list
|
||||
bool HloParser::ParseSharding(HloInstruction* instruction) {
|
||||
bool HloParser::ParseSharding(OpSharding* sharding) {
|
||||
if (!ParseToken(TokKind::kLbrace,
|
||||
"expected '{' to start sharding attribute")) {
|
||||
return false;
|
||||
@ -545,7 +621,6 @@ bool HloParser::ParseSharding(HloInstruction* instruction) {
|
||||
}
|
||||
}
|
||||
|
||||
OpSharding sharding;
|
||||
if (replicated) {
|
||||
if (!devices.empty()) {
|
||||
return TokenError(
|
||||
@ -555,7 +630,7 @@ bool HloParser::ParseSharding(HloInstruction* instruction) {
|
||||
return TokenError(
|
||||
"replicated shardings should not have any tile shape set");
|
||||
}
|
||||
sharding.set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
|
||||
sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
|
||||
} else if (maximal) {
|
||||
if (devices.size() != 1) {
|
||||
return TokenError(
|
||||
@ -564,8 +639,8 @@ bool HloParser::ParseSharding(HloInstruction* instruction) {
|
||||
if (!ShapeUtil::Equal(tile_shape, Shape())) {
|
||||
return TokenError("maximal shardings should not have any tile shape set");
|
||||
}
|
||||
sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
|
||||
sharding.add_tile_assignment_devices(devices[0]);
|
||||
sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
|
||||
sharding->add_tile_assignment_devices(devices[0]);
|
||||
} else {
|
||||
if (devices.size() <= 1) {
|
||||
return TokenError(
|
||||
@ -579,47 +654,43 @@ bool HloParser::ParseSharding(HloInstruction* instruction) {
|
||||
"non-maximal shardings must have a tile assignment list including "
|
||||
"dimensions");
|
||||
}
|
||||
sharding.set_type(OpSharding::Type::OpSharding_Type_OTHER);
|
||||
*sharding.mutable_tile_shape() = tile_shape;
|
||||
sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER);
|
||||
*sharding->mutable_tile_shape() = tile_shape;
|
||||
for (int64 dim : tile_assignment_dimensions) {
|
||||
sharding.add_tile_assignment_dimensions(dim);
|
||||
sharding->add_tile_assignment_dimensions(dim);
|
||||
}
|
||||
for (int64 device : devices) {
|
||||
sharding.add_tile_assignment_devices(device);
|
||||
sharding->add_tile_assignment_devices(device);
|
||||
}
|
||||
}
|
||||
|
||||
instruction->set_sharding(HloSharding::FromProto(sharding).ValueOrDie());
|
||||
lexer_.Lex();
|
||||
return true;
|
||||
}
|
||||
|
||||
// '{' name+ '}'
|
||||
bool HloParser::ParseControlPredecessors(HloInstruction* instruction) {
|
||||
bool HloParser::ParseInstructionNames(
|
||||
std::vector<HloInstruction*>* instructions) {
|
||||
if (!ParseToken(TokKind::kLbrace,
|
||||
"expects '{' at the beginning of control predecessors")) {
|
||||
"expects '{' at the beginning of instruction name list")) {
|
||||
return false;
|
||||
}
|
||||
do {
|
||||
string name;
|
||||
if (!ParseName(&name)) {
|
||||
return TokenError("expects a control predecessor");
|
||||
return TokenError("expects a instruction name");
|
||||
}
|
||||
HloInstruction* pre =
|
||||
HloInstruction* instr =
|
||||
tensorflow::gtl::FindPtrOrNull(instruction_pool_, name);
|
||||
if (!pre) {
|
||||
if (!instr) {
|
||||
return TokenError(
|
||||
StrCat("control predecessor ", name, " is not defined: "));
|
||||
}
|
||||
Status status = pre->AddControlDependencyTo(instruction);
|
||||
if (!status.ok()) {
|
||||
return TokenError(StrCat("error adding control dependency for: ", name,
|
||||
" status: ", status.ToString()));
|
||||
Printf("instruction '%s' is not defined", name.c_str()));
|
||||
}
|
||||
instructions->push_back(instr);
|
||||
} while (EatIfPresent(TokKind::kComma));
|
||||
|
||||
return ParseToken(TokKind::kRbrace,
|
||||
"expects '}' at the end of control predecessors");
|
||||
"expects '}' at the end of control instructions");
|
||||
}
|
||||
|
||||
bool HloParser::SetValueInLiteral(int64 value, int64 linear_index,
|
||||
@ -957,28 +1028,95 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands,
|
||||
return true;
|
||||
}
|
||||
|
||||
// extra_attribute ::= ',' attribute_name value
|
||||
template <typename T>
|
||||
bool HloParser::ParseExtraAttribute(T* value,
|
||||
const string& expected_attribute) {
|
||||
if (!ParseToken(TokKind::kComma,
|
||||
"expects ',' in front of an extra attribute")) {
|
||||
return false;
|
||||
bool HloParser::ParseAttributes(
|
||||
const std::unordered_map<string, AttrConfig>& attrs) {
|
||||
std::unordered_set<string> seen_attrs;
|
||||
while (EatIfPresent(TokKind::kComma)) {
|
||||
string name;
|
||||
if (!ParseAttributeName(&name)) {
|
||||
return TokenError("error parsing attributes");
|
||||
}
|
||||
VLOG(1) << "Parsing attribute " << name;
|
||||
if (!seen_attrs.insert(name).second) {
|
||||
return TokenError(Printf("attribute %s already exists", name.c_str()));
|
||||
}
|
||||
auto attr_it = attrs.find(name);
|
||||
if (attr_it == attrs.end()) {
|
||||
return TokenError(Printf("unexpected attribute %s", name.c_str()));
|
||||
}
|
||||
AttrTy attr_type = attr_it->second.attr_type;
|
||||
void* attr_out_ptr = attr_it->second.result;
|
||||
bool success = [&] {
|
||||
switch (attr_type) {
|
||||
case AttrTy::kInt64: {
|
||||
int64 result;
|
||||
if (!ParseInt64(&result)) {
|
||||
return false;
|
||||
}
|
||||
static_cast<optional<int64>*>(attr_out_ptr)->emplace(result);
|
||||
return true;
|
||||
}
|
||||
case AttrTy::kHloComputation: {
|
||||
HloComputation* result;
|
||||
if (!ParseComputationName(&result)) {
|
||||
return false;
|
||||
}
|
||||
static_cast<optional<HloComputation*>*>(attr_out_ptr)
|
||||
->emplace(result);
|
||||
return true;
|
||||
}
|
||||
case AttrTy::kWindow: {
|
||||
Window result;
|
||||
if (!ParseWindow(&result)) {
|
||||
return false;
|
||||
}
|
||||
static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
|
||||
return true;
|
||||
}
|
||||
case AttrTy::kConvolutionDimensionNumbers: {
|
||||
ConvolutionDimensionNumbers result;
|
||||
if (!ParseConvolutionDimensionNumbers(&result)) {
|
||||
return false;
|
||||
}
|
||||
static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr)
|
||||
->emplace(result);
|
||||
return true;
|
||||
}
|
||||
case AttrTy::kSharding: {
|
||||
OpSharding sharding;
|
||||
if (!ParseSharding(&sharding)) {
|
||||
return false;
|
||||
}
|
||||
static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
|
||||
return true;
|
||||
}
|
||||
case AttrTy::kInstructionList: {
|
||||
std::vector<HloInstruction*> result;
|
||||
if (!ParseInstructionNames(&result)) {
|
||||
return false;
|
||||
}
|
||||
static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr)
|
||||
->emplace(result);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}();
|
||||
if (!success) {
|
||||
return TokenError(Printf("error parsing attribute %s", name.c_str()));
|
||||
}
|
||||
}
|
||||
string attribute_name;
|
||||
if (!ParseAttributeName(&attribute_name) &&
|
||||
attribute_name != expected_attribute) {
|
||||
return TokenError(StrCat("expects attribute name: ", expected_attribute));
|
||||
}
|
||||
if (!ParseAttributeValue(value)) {
|
||||
return TokenError(
|
||||
StrCat("expects value for attribute: ", expected_attribute));
|
||||
// Check that all required attrs were seen.
|
||||
for (const auto& attr_it : attrs) {
|
||||
if (attr_it.second.required &&
|
||||
seen_attrs.find(attr_it.first) == seen_attrs.end()) {
|
||||
return TokenError(Printf("attribute %s is expected but not seen",
|
||||
attr_it.first.c_str()));
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool HloParser::ParseAttributeValue<HloComputation*>(HloComputation** value) {
|
||||
bool HloParser::ParseComputationName(HloComputation** value) {
|
||||
string name;
|
||||
if (!ParseName(&name)) {
|
||||
return TokenError("expects computation name");
|
||||
@ -990,9 +1128,191 @@ bool HloParser::ParseAttributeValue<HloComputation*>(HloComputation** value) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool HloParser::ParseAttributeValue<int64>(int64* value) {
|
||||
return ParseInt64(value);
|
||||
// ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}'
|
||||
// The subattributes can appear in any order. 'size=' is required, others are
|
||||
// optional.
|
||||
bool HloParser::ParseWindow(Window* window) {
|
||||
if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<int64> size;
|
||||
std::vector<int64> stride;
|
||||
std::vector<std::vector<int64>> pad;
|
||||
std::vector<int64> lhs_dilate;
|
||||
std::vector<int64> rhs_dilate;
|
||||
while (lexer_.GetKind() != TokKind::kRbrace) {
|
||||
string field_name;
|
||||
if (!ParseAttributeName(&field_name)) {
|
||||
return TokenError("expects sub-attributes in window");
|
||||
}
|
||||
bool ok = [&] {
|
||||
if (field_name == "size") {
|
||||
return ParseDxD("size", &size);
|
||||
}
|
||||
if (field_name == "stride") {
|
||||
return ParseDxD("stride", &stride);
|
||||
}
|
||||
if (field_name == "lhs_dilate") {
|
||||
return ParseDxD("lhs_dilate", &lhs_dilate);
|
||||
}
|
||||
if (field_name == "rhs_dilate") {
|
||||
return ParseDxD("rls_dilate", &rhs_dilate);
|
||||
}
|
||||
if (field_name == "pad") {
|
||||
return ParseWindowPad(&pad);
|
||||
}
|
||||
return TokenError(StrCat("unexpected attribute name: ", field_name));
|
||||
}();
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (size.empty()) {
|
||||
return TokenError(
|
||||
"sub-attribute 'size=' is required in the window attribute");
|
||||
}
|
||||
if (!stride.empty() && stride.size() != size.size()) {
|
||||
return TokenError("expects 'stride=' has the same size as 'size='");
|
||||
}
|
||||
if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) {
|
||||
return TokenError("expects 'lhs_dilate=' has the same size as 'size='");
|
||||
}
|
||||
if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) {
|
||||
return TokenError("expects 'rhs_dilate=' has the same size as 'size='");
|
||||
}
|
||||
if (!pad.empty() && pad.size() != size.size()) {
|
||||
return TokenError("expects 'pad=' has the same size as 'size='");
|
||||
}
|
||||
|
||||
for (int i = 0; i < size.size(); i++) {
|
||||
window->add_dimensions()->set_size(size[i]);
|
||||
if (!pad.empty()) {
|
||||
window->mutable_dimensions(i)->set_padding_low(pad[i][0]);
|
||||
window->mutable_dimensions(i)->set_padding_high(pad[i][1]);
|
||||
}
|
||||
// If some field is not present, it has the default value.
|
||||
window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]);
|
||||
window->mutable_dimensions(i)->set_base_dilation(
|
||||
lhs_dilate.empty() ? 1 : lhs_dilate[i]);
|
||||
window->mutable_dimensions(i)->set_window_dilation(
|
||||
rhs_dilate.empty() ? 1 : rhs_dilate[i]);
|
||||
}
|
||||
return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
|
||||
}
|
||||
|
||||
// This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
|
||||
// The string looks like "dim_labels=0bf_0io->0bf".
|
||||
bool HloParser::ParseConvolutionDimensionNumbers(
|
||||
ConvolutionDimensionNumbers* dnums) {
|
||||
if (lexer_.GetKind() != TokKind::kDimLabels) {
|
||||
return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'");
|
||||
}
|
||||
string str = lexer_.GetStrVal();
|
||||
|
||||
// The str is expected to have 3 items, lhs, rhs, out, and it must looks like
|
||||
// lhs_rhs->out, that is, the first separator is "_" and the second is "->".
|
||||
// So we replace the "->" with "_" and then split on "_".
|
||||
str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->",
|
||||
/*newsub=*/"_",
|
||||
/*replace_all=*/false);
|
||||
std::vector<string> lhs_rhs_out = Split(str, "_");
|
||||
if (lhs_rhs_out.size() != 3) {
|
||||
LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
|
||||
<< str;
|
||||
}
|
||||
|
||||
const int64 rank = lhs_rhs_out[0].length();
|
||||
if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) {
|
||||
return TokenError(
|
||||
"convolution lhs, rhs, and output must have the same rank");
|
||||
}
|
||||
if (rank < 3) {
|
||||
return TokenError("convolution rank must >=3");
|
||||
}
|
||||
|
||||
auto is_unique = [](string str) -> bool {
|
||||
std::sort(str.begin(), str.end());
|
||||
return std::unique(str.begin(), str.end()) == str.end();
|
||||
};
|
||||
|
||||
// lhs
|
||||
{
|
||||
const string& lhs = lhs_rhs_out[0];
|
||||
if (!is_unique(lhs)) {
|
||||
return TokenError(
|
||||
StrCat("expects unique lhs dimension numbers, but sees ", lhs));
|
||||
}
|
||||
for (int i = 0; i < rank - 2; i++) {
|
||||
dnums->add_spatial_dimensions(-1);
|
||||
}
|
||||
for (int i = 0; i < rank; i++) {
|
||||
char c = lhs[i];
|
||||
if (c == 'b') {
|
||||
dnums->set_input_batch_dimension(i);
|
||||
} else if (c == 'f') {
|
||||
dnums->set_input_feature_dimension(i);
|
||||
} else if (c < '0' + rank && c >= '0') {
|
||||
dnums->set_spatial_dimensions(c - '0', i);
|
||||
} else {
|
||||
return TokenError(
|
||||
Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
// rhs
|
||||
{
|
||||
const string& rhs = lhs_rhs_out[1];
|
||||
if (!is_unique(rhs)) {
|
||||
return TokenError(
|
||||
StrCat("expects unique rhs dimension numbers, but sees ", rhs));
|
||||
}
|
||||
for (int i = 0; i < rank - 2; i++) {
|
||||
dnums->add_kernel_spatial_dimensions(-1);
|
||||
}
|
||||
for (int i = 0; i < rank; i++) {
|
||||
char c = rhs[i];
|
||||
if (c == 'i') {
|
||||
dnums->set_kernel_input_feature_dimension(i);
|
||||
} else if (c == 'o') {
|
||||
dnums->set_kernel_output_feature_dimension(i);
|
||||
} else if (c < '0' + rank && c >= '0') {
|
||||
dnums->set_kernel_spatial_dimensions(c - '0', i);
|
||||
} else {
|
||||
return TokenError(
|
||||
Printf("expects [0-%lldio] in rhs dimension numbers", rank - 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
// output
|
||||
{
|
||||
const string& out = lhs_rhs_out[2];
|
||||
if (!is_unique(out)) {
|
||||
return TokenError(
|
||||
StrCat("expects unique output dimension numbers, but sees ", out));
|
||||
}
|
||||
for (int i = 0; i < rank; i++) {
|
||||
char c = out[i];
|
||||
if (c == 'b') {
|
||||
dnums->set_output_batch_dimension(i);
|
||||
} else if (c == 'f') {
|
||||
dnums->set_output_feature_dimension(i);
|
||||
} else if (c < '0' + rank && c >= '0') {
|
||||
if (dnums->spatial_dimensions(c - '0') != i) {
|
||||
return TokenError(
|
||||
"output spatial dimensions should be the same as input spatial "
|
||||
"dimensions");
|
||||
}
|
||||
} else {
|
||||
return TokenError(
|
||||
Printf("expects [0-%lldbf] in output dimension numbers", rank - 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lexer_.Lex();
|
||||
return true;
|
||||
}
|
||||
|
||||
// param_list ::= '(' param_list1 ')'
|
||||
@ -1070,6 +1390,55 @@ bool HloParser::ParseAttributeName(string* result) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HloParser::ParseDxD(const string& name, std::vector<int64>* result) {
|
||||
if (!result->empty()) {
|
||||
return TokenError(
|
||||
Printf("sub-attribute '%s=' already exists", name.c_str()));
|
||||
}
|
||||
// 1D
|
||||
if (lexer_.GetKind() == TokKind::kInt) {
|
||||
int64 number;
|
||||
if (!ParseInt64(&number)) {
|
||||
return TokenError(Printf("expects sub-attribute '%s=i'", name.c_str()));
|
||||
}
|
||||
result->push_back(number);
|
||||
return true;
|
||||
}
|
||||
// 2D or higher.
|
||||
if (lexer_.GetKind() == TokKind::kDxD) {
|
||||
string str = lexer_.GetStrVal();
|
||||
if (!SplitAndParseAsInts(str, 'x', result)) {
|
||||
return TokenError(
|
||||
Printf("expects sub-attribute '%s=ixj...'", name.c_str()));
|
||||
}
|
||||
lexer_.Lex();
|
||||
return true;
|
||||
}
|
||||
return TokenError("expects token type kInt or kDxD");
|
||||
}
|
||||
|
||||
bool HloParser::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
|
||||
if (!pad->empty()) {
|
||||
return TokenError("sub-attribute 'pad=' already exists");
|
||||
}
|
||||
if (lexer_.GetKind() != TokKind::kWindowPad) {
|
||||
return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
|
||||
}
|
||||
string str = lexer_.GetStrVal();
|
||||
std::vector<string> padding_str = Split(str, 'x');
|
||||
for (int i = 0; i < padding_str.size(); i++) {
|
||||
std::vector<int64> low_high;
|
||||
if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) ||
|
||||
low_high.size() != 2) {
|
||||
return TokenError(
|
||||
"expects padding_low and padding_high separated by '_'");
|
||||
}
|
||||
pad->push_back(low_high);
|
||||
}
|
||||
lexer_.Lex();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HloParser::ParseOpcode(HloOpcode* result) {
|
||||
VLOG(1) << "ParseOpcode";
|
||||
if (lexer_.GetKind() != TokKind::kOpcode) {
|
||||
|
@ -25,6 +25,7 @@ namespace tools {
|
||||
namespace {
|
||||
|
||||
using tensorflow::StringPiece;
|
||||
using tensorflow::strings::StrCat;
|
||||
|
||||
struct TestData {
|
||||
string test_name;
|
||||
@ -247,6 +248,39 @@ ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] {
|
||||
ROOT %call = f32[] call(f32[] %constant), to_apply=%Identity.v1
|
||||
}
|
||||
|
||||
)"
|
||||
},
|
||||
// reduce window
|
||||
{
|
||||
"ReduceWindow",
|
||||
R"(HloModule R4UnitWindow_module:
|
||||
|
||||
%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
|
||||
%lhs = f32[] parameter(0)
|
||||
%rhs = f32[] parameter(1)
|
||||
ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
|
||||
}
|
||||
|
||||
ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] {
|
||||
%operand = f32[13,12,8,15]{0,3,2,1} parameter(0)
|
||||
%constant = f32[] constant(0)
|
||||
ROOT %reduce-window = f32[13,3,8,15]{0,3,2,1} reduce-window(f32[13,12,8,15]{0,3,2,1} %operand, f32[] %constant), window={size=1x1x7x1 stride=1x4x1x1 pad=0_0x0_0x3_3x0_0}, to_apply=%add_F32.v3
|
||||
}
|
||||
|
||||
)"
|
||||
},
|
||||
// convolution
|
||||
{
|
||||
"Convolution",
|
||||
R"(HloModule Convolve1D1Window_0_module:
|
||||
|
||||
ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
|
||||
%input = f32[1,2,1]{2,1,0} parameter(0)
|
||||
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
|
||||
%filter = f32[1,1,1]{2,1,0} parameter(1)
|
||||
ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f
|
||||
}
|
||||
|
||||
)"
|
||||
}
|
||||
});
|
||||
@ -427,6 +461,92 @@ ENTRY %ConstantWithExp.v4 () -> f32[] {
|
||||
// printed as "300".
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, AttibutesAnyOrder) {
|
||||
const string original = R"(HloModule any_order_module:
|
||||
|
||||
ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
|
||||
%input = f32[1,2,1]{2,1,0} parameter(0)
|
||||
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
|
||||
%filter = f32[1,1,1]{2,1,0} parameter(1)
|
||||
ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2}
|
||||
}
|
||||
|
||||
)";
|
||||
TF_EXPECT_OK(Parse(original).status());
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, InvalidDimLabels) {
|
||||
string prefix = R"(HloModule invalid_dim_labels_module:
|
||||
|
||||
ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
|
||||
%input = f32[1,2,1]{2,1,0} parameter(0)
|
||||
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
|
||||
%filter = f32[1,1,1]{2,1,0} parameter(1)
|
||||
ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1} )";
|
||||
string suffix = R"(
|
||||
}
|
||||
|
||||
)";
|
||||
|
||||
ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=00_01_10", suffix))
|
||||
.status()
|
||||
.error_message(),
|
||||
"expects dim labels pattern");
|
||||
|
||||
ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=010_1100->010", suffix))
|
||||
.status()
|
||||
.error_message(),
|
||||
"must have the same rank");
|
||||
|
||||
ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=0bf_io0->b0f", suffix))
|
||||
.status()
|
||||
.error_message(),
|
||||
"output spatial dimensions should be the same as input "
|
||||
"spatial dimensions");
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, UnexpectedAttribute) {
|
||||
const string original = R"(HloModule unexpected_attr_module:
|
||||
|
||||
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
|
||||
%recv = f32[] recv(), channel_id=15
|
||||
ROOT %constant = f32[] constant(2.1)
|
||||
%send = () send(f32[] %constant), channel_id=16, calls=%recv
|
||||
}
|
||||
|
||||
)";
|
||||
ExpectHasSubstr(Parse(original).status().error_message(),
|
||||
"unexpected attribute calls");
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, MissingAttribute) {
|
||||
const string original = R"(HloModule missing_attr_module:
|
||||
|
||||
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
|
||||
%recv = f32[] recv(), channel_id=15
|
||||
ROOT %constant = f32[] constant(-2.1)
|
||||
%send = () send(f32[] %constant)
|
||||
}
|
||||
|
||||
)";
|
||||
ExpectHasSubstr(Parse(original).status().error_message(),
|
||||
"attribute channel_id is expected but not seen");
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, PredecessorUndefined) {
|
||||
const string original = R"(HloModule pre_not_found_module:
|
||||
|
||||
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
|
||||
%recv = f32[] recv(), channel_id=15
|
||||
ROOT %constant = f32[] constant(2.1)
|
||||
%send = () send(f32[] %constant), channel_id=16, control-predecessors={%done}
|
||||
}
|
||||
|
||||
)";
|
||||
ExpectHasSubstr(Parse(original).status().error_message(),
|
||||
"'done' is not defined");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tools
|
||||
} // namespace xla
|
||||
|
@ -57,6 +57,9 @@ enum class TokKind {
|
||||
// Typed tokens.
|
||||
kName, // %foo
|
||||
kAttributeName, // dimensions=
|
||||
kDimLabels, // [0-9bf]+_[0-9io]+->[0-9bf]+
|
||||
kDxD, // [0-9]+(x[0-9]+)+
|
||||
kWindowPad, // [0-9]+_[0-9]+(x[0-9]+_[0-9]+)*
|
||||
kShape, // f32[2,3]{1,0}
|
||||
kOpcode, // add
|
||||
kInt, // 42
|
||||
|
@ -26,8 +26,8 @@ namespace xla {
|
||||
namespace window_util {
|
||||
|
||||
/* static */ string ToString(const WindowDimension& dim) {
|
||||
using tensorflow::strings::StrCat;
|
||||
using tensorflow::strings::StrAppend;
|
||||
using tensorflow::strings::StrCat;
|
||||
string str = StrCat("(size=", dim.size());
|
||||
if (dim.stride() != 1) {
|
||||
StrAppend(&str, ",stride=", dim.stride());
|
||||
@ -49,22 +49,22 @@ namespace window_util {
|
||||
}
|
||||
|
||||
string ToString(const Window& window) {
|
||||
using tensorflow::strings::StrCat;
|
||||
using tensorflow::strings::StrAppend;
|
||||
using tensorflow::strings::StrCat;
|
||||
|
||||
string str;
|
||||
const auto add_field = [&](
|
||||
const char* heading,
|
||||
std::function<string(const WindowDimension&)> format) {
|
||||
StrAppend(&str, heading, "=");
|
||||
const char* prefix = "";
|
||||
for (const auto& window_dimension : window.dimensions()) {
|
||||
StrAppend(&str, prefix, format(window_dimension));
|
||||
prefix = "x";
|
||||
}
|
||||
};
|
||||
const auto add_field =
|
||||
[&](const char* heading,
|
||||
std::function<string(const WindowDimension&)> format) {
|
||||
StrAppend(&str, heading, "=");
|
||||
const char* prefix = "";
|
||||
for (const auto& window_dimension : window.dimensions()) {
|
||||
StrAppend(&str, prefix, format(window_dimension));
|
||||
prefix = "x";
|
||||
}
|
||||
};
|
||||
|
||||
add_field("window",
|
||||
add_field("size",
|
||||
[](const WindowDimension& dim) { return StrCat(dim.size()); });
|
||||
if (HasStride(window)) {
|
||||
add_field(" stride",
|
||||
|
Loading…
Reference in New Issue
Block a user