Hlo parser: support window and convolution.

Also, to make the text format easier to write and unambiguous:
- Print "window={}" around the window attribute; rename the "window" sub attribute to "size";
- Print the dim_labels in logical order, instead of physical order.

PiperOrigin-RevId: 175074526
This commit is contained in:
A. Unique TensorFlower 2017-11-08 15:24:01 -08:00 committed by TensorFlower Gardener
parent 8bb665ae1c
commit 12d6b450b2
8 changed files with 690 additions and 145 deletions

View File

@ -1850,7 +1850,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}"));
}
if (window_ != nullptr) {
extra.push_back(window_util::ToString(*window_));
extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
}
if (padding_config_ != nullptr) {
extra.push_back(StrCat("padding=", padding_config_->ShortDebugString()));
@ -2856,13 +2856,7 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const {
const auto append_dims = [&](const std::vector<string>& dims,
const Shape& shape) {
CHECK_EQ(dims.size(), ShapeUtil::Rank(shape));
for (int64 logical = 0; logical < dims.size(); ++logical) {
int64 physical = logical;
if (!shape.layout().minor_to_major().empty()) {
physical = LayoutUtil::Major(shape.layout(), logical);
}
result += dims[physical];
}
StrAppend(&result, Join(dims, ""));
};
// lhs_dims[i] is the symbol of the logical dimension i for the lhs

View File

@ -43,14 +43,22 @@ operand
: shape name
;
extra_attributes
attributes
: /*empty*/
| ',' extra_attribute
| ',' extra_attribute extra_attributes
| ',' attribute
| ',' attribute attributes
;
extra_attribute
attribute
: attribute_name attribute_value
;
attribute_value
: kInt
| kName
| [0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,} /*dim_labels_pattern*/
| [0-9]+(x[0-9]+)+ /*dxd_pattern*/
| [0-9]+_[0-9]+(x[0-9]+_[0-9]+)* /*window_pad_pattern*/
| '{' sub_attributes '}'
;
param_list
: '(' param_list1 ')'

View File

@ -122,7 +122,7 @@ TokKind HloLexer::LexToken() {
current_ptr_++;
return TokKind::kArrow;
}
return LexDigitOrNegative();
return LexNumberOrPattern();
case '=':
return TokKind::kEqual;
case ',':
@ -149,12 +149,15 @@ TokKind HloLexer::LexToken() {
}
}
// Lex a shape, name, keyword, or opcode.
// Lex a shape, name, keyword, opcode, attribute name, or the dim labels
// pattern.
//
// shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})?
// name ::= [a-zA-Z_][a-zA-Z0-9_.-]*:
// keyword ::= HloModule, ENTRY, ...
// opcode ::= add, greater-than, ...
// attribute_name ::= condition, body, dimensions, ...
// dim_labels_pattern ::= [0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,}
TokKind HloLexer::LexIdentifier() {
{
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
@ -220,6 +223,16 @@ TokKind HloLexer::LexIdentifier() {
return TokKind::kOpcode;
}
{
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
static LazyRE2 dim_labels_pattern = {
R"([0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,})"};
if (RE2::Consume(&consumable, *dim_labels_pattern)) {
current_ptr_ = consumable.begin();
str_val_.assign(token_start_, current_ptr_);
return TokKind::kDimLabels;
}
}
current_ptr_ = token_start_ + 1;
return TokKind::kError;
}
@ -240,15 +253,20 @@ TokKind HloLexer::LexPercent() {
return TokKind::kError;
}
// Lex integer and floating-point values, and -inf.
// int [-]?[0-9]+
// fp with exp [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
// fp without exp [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
// negative inf -inf
TokKind HloLexer::LexDigitOrNegative() {
// Lex integer and floating-point values, -inf, and patterns for dim labels,
// dxd (e.g. 1x2x3), and window pad.
//
// fp with exp ::= [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
// fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
// dim_labels_pattern ::= [0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,}
// dxd_pattern ::= [0-9]+(x[0-9]+)+
// window_pad_pattern ::= [0-9]+_[0-9]+(x[0-9]+_[0-9]+)*
// int ::= [-]?[0-9]+
// negative inf ::= '-inf'
TokKind HloLexer::LexNumberOrPattern() {
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
static LazyRE2 float_pattern = {
R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|(\d+[.]\d*|\d*[.]\d+))"};
R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"};
if (RE2::Consume(&consumable, *float_pattern)) {
current_ptr_ = consumable.begin();
tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(),
@ -256,6 +274,29 @@ TokKind HloLexer::LexDigitOrNegative() {
return TokKind::kDecimal;
}
static LazyRE2 dim_labels_pattern = {
R"([0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,})"};
static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"};
static LazyRE2 pad_pattern = {R"([0-9]+_[0-9]+(x[0-9]+_[0-9]+)*)"};
if (RE2::Consume(&consumable, *dim_labels_pattern)) {
current_ptr_ = consumable.begin();
str_val_.assign(token_start_, current_ptr_);
return TokKind::kDimLabels;
}
if (RE2::Consume(&consumable, *dxd_pattern)) {
current_ptr_ = consumable.begin();
str_val_.assign(token_start_, current_ptr_);
return TokKind::kDxD;
}
if (RE2::Consume(&consumable, *pad_pattern)) {
current_ptr_ = consumable.begin();
str_val_.assign(token_start_, current_ptr_);
return TokKind::kWindowPad;
}
static LazyRE2 int_pattern = {R"([-]?\d+)"};
if (RE2::Consume(&consumable, *int_pattern)) {
current_ptr_ = consumable.begin();
@ -350,6 +391,12 @@ string TokKindToString(TokKind kind) {
return "kName";
case TokKind::kAttributeName:
return "kAttributeName";
case TokKind::kDimLabels:
return "kDimLabels";
case TokKind::kDxD:
return "kDxD";
case TokKind::kWindowPad:
return "kWindowPad";
case TokKind::kShape:
return "kShape";
case TokKind::kOpcode:

View File

@ -37,11 +37,15 @@ class HloLexer {
}
TokKind Lex() { return current_kind_ = LexToken(); }
TokKind GetKind() const { return current_kind_; }
string GetStrVal() const {
switch (GetKind()) {
case TokKind::kName:
case TokKind::kAttributeName:
case TokKind::kDimLabels:
case TokKind::kDxD:
case TokKind::kWindowPad:
return str_val_;
default:
LOG(FATAL) << "This token does not have string value";
@ -92,7 +96,7 @@ class HloLexer {
TokKind LexPercent();
TokKind LexShape();
TokKind LexConstant();
TokKind LexDigitOrNegative();
TokKind LexNumberOrPattern();
TokKind LexComment();
const tensorflow::StringPiece buf_;

View File

@ -28,6 +28,9 @@ namespace tools {
namespace {
using tensorflow::StringPiece;
using tensorflow::gtl::optional;
using tensorflow::str_util::Split;
using tensorflow::str_util::SplitAndParseAsInts;
using tensorflow::strings::Printf;
using tensorflow::strings::StrAppend;
using tensorflow::strings::StrCat;
@ -57,8 +60,6 @@ class HloParser {
bool ParseInstructionList(HloComputation::Builder* builder,
string* root_name);
bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
bool ParseSharding(HloInstruction* instruction);
bool ParseControlPredecessors(HloInstruction* instruction);
bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
@ -78,10 +79,55 @@ class HloParser {
bool ParseOperands(std::vector<HloInstruction*>* operands,
const int expected_size);
template <typename T>
bool ParseExtraAttribute(T* value, const string& expected_attribute);
template <typename T>
bool ParseAttributeValue(T* value);
// Types of attributes.
enum class AttrTy {
kInt64,
kHloComputation,
kWindow,
kConvolutionDimensionNumbers,
kSharding,
kInstructionList,
};
struct AttrConfig {
bool required; // whether it's required or optional
AttrTy attr_type; // what type it is
void* result; // where to store the parsed result.
};
// Parses attributes given names and configs of the attributes. Each parsed
// result is passed back through the result pointer in the corresponding
// AttrConfig. Note that the result pointer must point to an optional<T> typed
// variable which outlives this function. Returns false on error. You should
// not use any of the results if this function failed.
//
// Example usage:
//
// std::unordered_map<string, AttrConfig> attrs;
// optional<int64> foo;
// attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo};
// optional<Window> bar;
// attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar};
// if (!ParseAttributes(attrs)) {
// return false; // Do not use 'foo' 'bar' if failed.
// }
// // Do something with 'bar'.
// if (foo) { // If attr foo is seen, do something with 'foo'. }
//
bool ParseAttributes(const std::unordered_map<string, AttrConfig>& attrs);
// Parses a name and finds the corresponding hlo computation.
bool ParseComputationName(HloComputation** value);
// Parses a list of names and finds the corresponding hlo instructions.
bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
bool ParseWindow(Window* window);
bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
bool ParseSharding(OpSharding* sharding);
// Parses a sub-attribute of the window attribute, e.g., size=1x2x3.
bool ParseDxD(const string& name, std::vector<int64>* result);
// Parses window's pad sub-attribute, e.g., pad=0_0x3_3.
bool ParseWindowPad(std::vector<std::vector<int64>>* pad);
bool ParseParamList();
bool ParseName(string* result);
@ -214,7 +260,7 @@ bool HloParser::ParseInstructionList(HloComputation::Builder* builder,
"expects '}' at the end of instruction list.");
}
// instruction ::= ('ROOT')? name '=' shape opcode operands (extra_attribute)*
// instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
bool HloParser::ParseInstruction(HloComputation::Builder* builder,
string* root_name) {
string name;
@ -230,6 +276,15 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (is_root) {
*root_name = name;
}
// Add optional attributes.
std::unordered_map<string, AttrConfig> attrs;
optional<OpSharding> sharding;
attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
optional<std::vector<HloInstruction*>> predecessors;
attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList,
&predecessors};
HloInstruction* instruction;
switch (opcode) {
case HloOpcode::kParameter: {
@ -237,7 +292,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (!ParseToken(TokKind::kLparen,
"expects '(' before parameter number") ||
!ParseInt64(&parameter_number) ||
!ParseToken(TokKind::kRparen, "expects ')' after parameter number")) {
!ParseToken(TokKind::kRparen, "expects ')' after parameter number") ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
@ -249,7 +305,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (!ParseToken(TokKind::kLparen,
"expects '(' before constant literal") ||
!ParseLiteral(&literal, shape) ||
!ParseToken(TokKind::kRparen, "expects ')' after constant literal")) {
!ParseToken(TokKind::kRparen, "expects ')' after constant literal") ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
@ -275,7 +332,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kSin:
case HloOpcode::kSort:
case HloOpcode::kTanh: {
if (!ParseOperands(&operands, /*expected_size=*/1)) {
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
@ -305,7 +363,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical: {
if (!ParseOperands(&operands, /*expected_size=*/2)) {
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateBinary(
@ -315,7 +374,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
// Ternary ops.
case HloOpcode::kClamp:
case HloOpcode::kSelect: {
if (!ParseOperands(&operands, /*expected_size=*/3)) {
if (!ParseOperands(&operands, /*expected_size=*/3) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateTernary(
@ -324,7 +384,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
// Other supported ops.
case HloOpcode::kConvert: {
if (!ParseOperands(&operands, /*expected_size=*/1)) {
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
@ -332,7 +393,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kCrossReplicaSum: {
if (!ParseOperands(&operands, /*expected_size=*/1)) {
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
@ -340,7 +402,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kReshape: {
if (!ParseOperands(&operands, /*expected_size=*/1)) {
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
@ -348,7 +411,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kTuple: {
if (!ParseOperands(&operands)) {
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
instruction =
@ -356,70 +419,99 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kWhile: {
HloComputation* condition;
HloComputation* body;
optional<HloComputation*> condition;
optional<HloComputation*> body;
attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation,
&condition};
attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseExtraAttribute(&condition,
/*expected_attribute=*/"condition") ||
!ParseExtraAttribute(&body, /*expected_attribute=*/"body")) {
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateWhile(
shape, condition, body, /*init=*/operands[0]));
shape, *condition, *body, /*init=*/operands[0]));
break;
}
case HloOpcode::kRecv: {
int64 channel_id;
optional<int64> channel_id;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
if (!ParseOperands(&operands, /*expected_size=*/0) ||
!ParseExtraAttribute(&channel_id,
/*expected_attribute=*/"channel_id")) {
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateRecv(shape, channel_id));
HloInstruction::CreateRecv(shape, *channel_id));
break;
}
case HloOpcode::kSend: {
int64 channel_id;
optional<int64> channel_id;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseExtraAttribute(&channel_id,
/*expected_attribute=*/"channel_id")) {
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateSend(operands[0], channel_id));
HloInstruction::CreateSend(operands[0], *channel_id));
break;
}
case HloOpcode::kGetTupleElement: {
int64 index;
optional<int64> index;
attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseExtraAttribute(&index, /*expected_attribute=*/"index")) {
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateGetTupleElement(shape, operands[0], index));
HloInstruction::CreateGetTupleElement(shape, operands[0], *index));
break;
}
case HloOpcode::kCall: {
HloComputation* to_apply;
if (!ParseOperands(&operands) ||
!ParseExtraAttribute(&to_apply,
/*expected_attribute=*/"to_apply")) {
optional<HloComputation*> to_apply;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateCall(shape, operands, to_apply));
HloInstruction::CreateCall(shape, operands, *to_apply));
break;
}
case HloOpcode::kReduceWindow: {
optional<HloComputation*> reduce_computation;
optional<Window> window;
attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window};
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&reduce_computation};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window,
*reduce_computation));
break;
}
case HloOpcode::kConvolution: {
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window};
attrs["dim_labels"] = {/*required=*/true,
AttrTy::kConvolutionDimensionNumbers, &dnums};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums));
break;
}
case HloOpcode::kBroadcast:
case HloOpcode::kCustomCall:
case HloOpcode::kConcatenate:
case HloOpcode::kReducePrecision:
case HloOpcode::kConvolution:
case HloOpcode::kMap:
case HloOpcode::kPad:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kReverse:
case HloOpcode::kRng:
@ -438,43 +530,27 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
HloOpcodeString(opcode)));
}
bool has_sharding = false;
bool has_control = false;
while (EatIfPresent(TokKind::kComma)) {
string attribute_name;
if (!ParseAttributeName(&attribute_name)) {
return TokenError("expects ', sharding=' or ', control-predecessors='");
}
if (attribute_name == "sharding") {
// Parse "sharding=".
if (has_sharding) {
return TokenError("expects at most 1 'sharding='");
// Add common attrs (sharding, control predecessors) to the instruction, if
// they were seen.
if (sharding) {
instruction->set_sharding(
HloSharding::FromProto(sharding.value()).ValueOrDie());
}
if (predecessors) {
for (auto* pre : *predecessors) {
Status status = pre->AddControlDependencyTo(instruction);
if (!status.ok()) {
return TokenError(StrCat("error adding control dependency for: ", name,
" status: ", status.ToString()));
}
has_sharding = true;
if (!ParseSharding(instruction)) {
return false;
}
} else if (attribute_name == "control-predecessors") {
// Parse "control-predecessors"
if (has_control) {
return TokenError("expects at most 1 'control-predecessors='");
}
has_control = true;
if (!ParseControlPredecessors(instruction)) {
return false;
}
} else {
return TokenError(StrCat("unexpected attribute: ", attribute_name));
}
}
return AddInstruction(name, instruction);
}
// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? ('devices=' ('['
// dims ']')* device_list)? '}' dims ::= int_list device_list ::= int_list
bool HloParser::ParseSharding(HloInstruction* instruction) {
bool HloParser::ParseSharding(OpSharding* sharding) {
if (!ParseToken(TokKind::kLbrace,
"expected '{' to start sharding attribute")) {
return false;
@ -545,7 +621,6 @@ bool HloParser::ParseSharding(HloInstruction* instruction) {
}
}
OpSharding sharding;
if (replicated) {
if (!devices.empty()) {
return TokenError(
@ -555,7 +630,7 @@ bool HloParser::ParseSharding(HloInstruction* instruction) {
return TokenError(
"replicated shardings should not have any tile shape set");
}
sharding.set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
} else if (maximal) {
if (devices.size() != 1) {
return TokenError(
@ -564,8 +639,8 @@ bool HloParser::ParseSharding(HloInstruction* instruction) {
if (!ShapeUtil::Equal(tile_shape, Shape())) {
return TokenError("maximal shardings should not have any tile shape set");
}
sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
sharding.add_tile_assignment_devices(devices[0]);
sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
sharding->add_tile_assignment_devices(devices[0]);
} else {
if (devices.size() <= 1) {
return TokenError(
@ -579,47 +654,43 @@ bool HloParser::ParseSharding(HloInstruction* instruction) {
"non-maximal shardings must have a tile assignment list including "
"dimensions");
}
sharding.set_type(OpSharding::Type::OpSharding_Type_OTHER);
*sharding.mutable_tile_shape() = tile_shape;
sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER);
*sharding->mutable_tile_shape() = tile_shape;
for (int64 dim : tile_assignment_dimensions) {
sharding.add_tile_assignment_dimensions(dim);
sharding->add_tile_assignment_dimensions(dim);
}
for (int64 device : devices) {
sharding.add_tile_assignment_devices(device);
sharding->add_tile_assignment_devices(device);
}
}
instruction->set_sharding(HloSharding::FromProto(sharding).ValueOrDie());
lexer_.Lex();
return true;
}
// '{' name+ '}'
bool HloParser::ParseControlPredecessors(HloInstruction* instruction) {
bool HloParser::ParseInstructionNames(
std::vector<HloInstruction*>* instructions) {
if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of control predecessors")) {
"expects '{' at the beginning of instruction name list")) {
return false;
}
do {
string name;
if (!ParseName(&name)) {
return TokenError("expects a control predecessor");
return TokenError("expects a instruction name");
}
HloInstruction* pre =
HloInstruction* instr =
tensorflow::gtl::FindPtrOrNull(instruction_pool_, name);
if (!pre) {
if (!instr) {
return TokenError(
StrCat("control predecessor ", name, " is not defined: "));
}
Status status = pre->AddControlDependencyTo(instruction);
if (!status.ok()) {
return TokenError(StrCat("error adding control dependency for: ", name,
" status: ", status.ToString()));
Printf("instruction '%s' is not defined", name.c_str()));
}
instructions->push_back(instr);
} while (EatIfPresent(TokKind::kComma));
return ParseToken(TokKind::kRbrace,
"expects '}' at the end of control predecessors");
"expects '}' at the end of control instructions");
}
bool HloParser::SetValueInLiteral(int64 value, int64 linear_index,
@ -957,28 +1028,95 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands,
return true;
}
// extra_attribute ::= ',' attribute_name value
template <typename T>
bool HloParser::ParseExtraAttribute(T* value,
const string& expected_attribute) {
if (!ParseToken(TokKind::kComma,
"expects ',' in front of an extra attribute")) {
return false;
bool HloParser::ParseAttributes(
const std::unordered_map<string, AttrConfig>& attrs) {
std::unordered_set<string> seen_attrs;
while (EatIfPresent(TokKind::kComma)) {
string name;
if (!ParseAttributeName(&name)) {
return TokenError("error parsing attributes");
}
VLOG(1) << "Parsing attribute " << name;
if (!seen_attrs.insert(name).second) {
return TokenError(Printf("attribute %s already exists", name.c_str()));
}
auto attr_it = attrs.find(name);
if (attr_it == attrs.end()) {
return TokenError(Printf("unexpected attribute %s", name.c_str()));
}
AttrTy attr_type = attr_it->second.attr_type;
void* attr_out_ptr = attr_it->second.result;
bool success = [&] {
switch (attr_type) {
case AttrTy::kInt64: {
int64 result;
if (!ParseInt64(&result)) {
return false;
}
static_cast<optional<int64>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kHloComputation: {
HloComputation* result;
if (!ParseComputationName(&result)) {
return false;
}
static_cast<optional<HloComputation*>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kWindow: {
Window result;
if (!ParseWindow(&result)) {
return false;
}
static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
return true;
}
case AttrTy::kConvolutionDimensionNumbers: {
ConvolutionDimensionNumbers result;
if (!ParseConvolutionDimensionNumbers(&result)) {
return false;
}
static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr)
->emplace(result);
return true;
}
case AttrTy::kSharding: {
OpSharding sharding;
if (!ParseSharding(&sharding)) {
return false;
}
static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
return true;
}
case AttrTy::kInstructionList: {
std::vector<HloInstruction*> result;
if (!ParseInstructionNames(&result)) {
return false;
}
static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr)
->emplace(result);
return true;
}
}
}();
if (!success) {
return TokenError(Printf("error parsing attribute %s", name.c_str()));
}
}
string attribute_name;
if (!ParseAttributeName(&attribute_name) &&
attribute_name != expected_attribute) {
return TokenError(StrCat("expects attribute name: ", expected_attribute));
}
if (!ParseAttributeValue(value)) {
return TokenError(
StrCat("expects value for attribute: ", expected_attribute));
// Check that all required attrs were seen.
for (const auto& attr_it : attrs) {
if (attr_it.second.required &&
seen_attrs.find(attr_it.first) == seen_attrs.end()) {
return TokenError(Printf("attribute %s is expected but not seen",
attr_it.first.c_str()));
}
}
return true;
}
template <>
bool HloParser::ParseAttributeValue<HloComputation*>(HloComputation** value) {
bool HloParser::ParseComputationName(HloComputation** value) {
string name;
if (!ParseName(&name)) {
return TokenError("expects computation name");
@ -990,9 +1128,191 @@ bool HloParser::ParseAttributeValue<HloComputation*>(HloComputation** value) {
return true;
}
template <>
bool HloParser::ParseAttributeValue<int64>(int64* value) {
return ParseInt64(value);
// ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}'
// The sub-attributes can appear in any order. 'size=' is required; the others
// are optional and, when absent, fall back to per-dimension defaults
// (stride/dilation 1, padding left at the proto default).
bool HloParser::ParseWindow(Window* window) {
  if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
    return false;
  }

  std::vector<int64> size;
  std::vector<int64> stride;
  std::vector<std::vector<int64>> pad;  // pad[i] == {padding_low, padding_high}
  std::vector<int64> lhs_dilate;
  std::vector<int64> rhs_dilate;
  while (lexer_.GetKind() != TokKind::kRbrace) {
    string field_name;
    if (!ParseAttributeName(&field_name)) {
      return TokenError("expects sub-attributes in window");
    }
    bool ok = [&] {
      if (field_name == "size") {
        return ParseDxD("size", &size);
      }
      if (field_name == "stride") {
        return ParseDxD("stride", &stride);
      }
      if (field_name == "lhs_dilate") {
        return ParseDxD("lhs_dilate", &lhs_dilate);
      }
      if (field_name == "rhs_dilate") {
        // Fixed: this previously passed the misspelled name "rls_dilate",
        // which surfaced in ParseDxD's error messages (e.g. "expects
        // sub-attribute 'rls_dilate=i'") for a user who wrote "rhs_dilate".
        return ParseDxD("rhs_dilate", &rhs_dilate);
      }
      if (field_name == "pad") {
        return ParseWindowPad(&pad);
      }
      return TokenError(StrCat("unexpected attribute name: ", field_name));
    }();
    if (!ok) {
      return false;
    }
  }

  // 'size=' determines the window rank; every other sub-attribute must either
  // be absent or have exactly that many entries.
  if (size.empty()) {
    return TokenError(
        "sub-attribute 'size=' is required in the window attribute");
  }
  if (!stride.empty() && stride.size() != size.size()) {
    return TokenError("expects 'stride=' has the same size as 'size='");
  }
  if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) {
    return TokenError("expects 'lhs_dilate=' has the same size as 'size='");
  }
  if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) {
    return TokenError("expects 'rhs_dilate=' has the same size as 'size='");
  }
  if (!pad.empty() && pad.size() != size.size()) {
    return TokenError("expects 'pad=' has the same size as 'size='");
  }

  for (int i = 0; i < size.size(); i++) {
    window->add_dimensions()->set_size(size[i]);
    if (!pad.empty()) {
      window->mutable_dimensions(i)->set_padding_low(pad[i][0]);
      window->mutable_dimensions(i)->set_padding_high(pad[i][1]);
    }
    // If some field is not present, it has the default value.
    window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]);
    window->mutable_dimensions(i)->set_base_dilation(
        lhs_dilate.empty() ? 1 : lhs_dilate[i]);
    window->mutable_dimensions(i)->set_window_dilation(
        rhs_dilate.empty() ? 1 : rhs_dilate[i]);
  }
  return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
}
// This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
// The string looks like "dim_labels=0bf_0io->0bf": three equal-rank label
// strings (lhs, rhs, output) where 'b'/'f' mark the batch/feature dimension,
// 'i'/'o' mark the kernel input/output feature dimension, and a digit d means
// "spatial dimension d lives at this position". Fills in *dnums and consumes
// the token on success; returns false via TokenError on malformed input.
bool HloParser::ParseConvolutionDimensionNumbers(
    ConvolutionDimensionNumbers* dnums) {
  if (lexer_.GetKind() != TokKind::kDimLabels) {
    return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'");
  }
  string str = lexer_.GetStrVal();
  // The str is expected to have 3 items, lhs, rhs, out, and it must look like
  // lhs_rhs->out, that is, the first separator is "_" and the second is "->".
  // So we replace the "->" with "_" and then split on "_".
  str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->",
                                            /*newsub=*/"_",
                                            /*replace_all=*/false);
  std::vector<string> lhs_rhs_out = Split(str, "_");
  if (lhs_rhs_out.size() != 3) {
    // The lexer's dim_labels regex ([0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,})
    // can't produce anything else, so this is an internal invariant
    // violation rather than recoverable user input.
    LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
               << str;
  }
  const int64 rank = lhs_rhs_out[0].length();
  if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) {
    return TokenError(
        "convolution lhs, rhs, and output must have the same rank");
  }
  if (rank < 3) {
    return TokenError("convolution rank must >=3");
  }
  // True iff no character appears more than once in str (takes a copy so the
  // caller's string is not reordered).
  auto is_unique = [](string str) -> bool {
    std::sort(str.begin(), str.end());
    return std::unique(str.begin(), str.end()) == str.end();
  };
  // lhs
  {
    const string& lhs = lhs_rhs_out[0];
    if (!is_unique(lhs)) {
      return TokenError(
          StrCat("expects unique lhs dimension numbers, but sees ", lhs));
    }
    // Pre-size the spatial dimension list with -1 placeholders; each digit
    // label below overwrites its slot.
    for (int i = 0; i < rank - 2; i++) {
      dnums->add_spatial_dimensions(-1);
    }
    for (int i = 0; i < rank; i++) {
      char c = lhs[i];
      if (c == 'b') {
        dnums->set_input_batch_dimension(i);
      } else if (c == 'f') {
        dnums->set_input_feature_dimension(i);
      } else if (c < '0' + rank && c >= '0') {
        // Digit label: spatial dimension number (c - '0') is at position i.
        dnums->set_spatial_dimensions(c - '0', i);
      } else {
        return TokenError(
            Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1));
      }
    }
  }
  // rhs
  {
    const string& rhs = lhs_rhs_out[1];
    if (!is_unique(rhs)) {
      return TokenError(
          StrCat("expects unique rhs dimension numbers, but sees ", rhs));
    }
    // Same -1 placeholder scheme as the lhs spatial dimensions above.
    for (int i = 0; i < rank - 2; i++) {
      dnums->add_kernel_spatial_dimensions(-1);
    }
    for (int i = 0; i < rank; i++) {
      char c = rhs[i];
      if (c == 'i') {
        dnums->set_kernel_input_feature_dimension(i);
      } else if (c == 'o') {
        dnums->set_kernel_output_feature_dimension(i);
      } else if (c < '0' + rank && c >= '0') {
        dnums->set_kernel_spatial_dimensions(c - '0', i);
      } else {
        return TokenError(
            Printf("expects [0-%lldio] in rhs dimension numbers", rank - 1));
      }
    }
  }
  // output
  {
    const string& out = lhs_rhs_out[2];
    if (!is_unique(out)) {
      return TokenError(
          StrCat("expects unique output dimension numbers, but sees ", out));
    }
    for (int i = 0; i < rank; i++) {
      char c = out[i];
      if (c == 'b') {
        dnums->set_output_batch_dimension(i);
      } else if (c == 'f') {
        dnums->set_output_feature_dimension(i);
      } else if (c < '0' + rank && c >= '0') {
        // Output spatial dims are not stored separately in dnums; they must
        // match the input spatial dims recorded from the lhs labels above.
        if (dnums->spatial_dimensions(c - '0') != i) {
          return TokenError(
              "output spatial dimensions should be the same as input spatial "
              "dimensions");
        }
      } else {
        return TokenError(
            Printf("expects [0-%lldbf] in output dimension numbers", rank - 1));
      }
    }
  }
  lexer_.Lex();
  return true;
}
// param_list ::= '(' param_list1 ')'
@ -1070,6 +1390,55 @@ bool HloParser::ParseAttributeName(string* result) {
return true;
}
bool HloParser::ParseDxD(const string& name, std::vector<int64>* result) {
if (!result->empty()) {
return TokenError(
Printf("sub-attribute '%s=' already exists", name.c_str()));
}
// 1D
if (lexer_.GetKind() == TokKind::kInt) {
int64 number;
if (!ParseInt64(&number)) {
return TokenError(Printf("expects sub-attribute '%s=i'", name.c_str()));
}
result->push_back(number);
return true;
}
// 2D or higher.
if (lexer_.GetKind() == TokKind::kDxD) {
string str = lexer_.GetStrVal();
if (!SplitAndParseAsInts(str, 'x', result)) {
return TokenError(
Printf("expects sub-attribute '%s=ixj...'", name.c_str()));
}
lexer_.Lex();
return true;
}
return TokenError("expects token type kInt or kDxD");
}
bool HloParser::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
if (!pad->empty()) {
return TokenError("sub-attribute 'pad=' already exists");
}
if (lexer_.GetKind() != TokKind::kWindowPad) {
return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
}
string str = lexer_.GetStrVal();
std::vector<string> padding_str = Split(str, 'x');
for (int i = 0; i < padding_str.size(); i++) {
std::vector<int64> low_high;
if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) ||
low_high.size() != 2) {
return TokenError(
"expects padding_low and padding_high separated by '_'");
}
pad->push_back(low_high);
}
lexer_.Lex();
return true;
}
bool HloParser::ParseOpcode(HloOpcode* result) {
VLOG(1) << "ParseOpcode";
if (lexer_.GetKind() != TokKind::kOpcode) {

View File

@ -25,6 +25,7 @@ namespace tools {
namespace {
using tensorflow::StringPiece;
using tensorflow::strings::StrCat;
struct TestData {
string test_name;
@ -247,6 +248,39 @@ ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] {
ROOT %call = f32[] call(f32[] %constant), to_apply=%Identity.v1
}
)"
},
// reduce window
{
"ReduceWindow",
R"(HloModule R4UnitWindow_module:
%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}
ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] {
%operand = f32[13,12,8,15]{0,3,2,1} parameter(0)
%constant = f32[] constant(0)
ROOT %reduce-window = f32[13,3,8,15]{0,3,2,1} reduce-window(f32[13,12,8,15]{0,3,2,1} %operand, f32[] %constant), window={size=1x1x7x1 stride=1x4x1x1 pad=0_0x0_0x3_3x0_0}, to_apply=%add_F32.v3
}
)"
},
// convolution
{
"Convolution",
R"(HloModule Convolve1D1Window_0_module:
ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f
}
)"
}
});
@ -427,6 +461,92 @@ ENTRY %ConstantWithExp.v4 () -> f32[] {
// printed as "300".
}
// Verifies that instruction attributes may be written in any order: here
// sharding, dim_labels, and window appear out of their canonical printed
// order (and window's sub-attributes are reordered too), yet the module
// still parses. (Test name typo fixed: Attibutes -> Attributes.)
TEST_F(HloParserTest, AttributesAnyOrder) {
const string original = R"(HloModule any_order_module:
ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2}
}
)";
TF_EXPECT_OK(Parse(original).status());
}
// Verifies that malformed dim_labels attributes are rejected with
// descriptive parse errors. The prefix/suffix strings sandwich each bad
// dim_labels value into an otherwise valid convolution instruction.
TEST_F(HloParserTest, InvalidDimLabels) {
string prefix = R"(HloModule invalid_dim_labels_module:
ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1} )";
string suffix = R"(
}
)";
// No "->" separator, so it does not match the dim_labels pattern at all.
ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=00_01_10", suffix))
.status()
.error_message(),
"expects dim labels pattern");
// lhs labels have rank 3 but rhs labels have rank 4.
ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=010_1100->010", suffix))
.status()
.error_message(),
"must have the same rank");
// Spatial dimension '0' sits at a different position in the output
// labels than in the input labels.
ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=0bf_io0->b0f", suffix))
.status()
.error_message(),
"output spatial dimensions should be the same as input "
"spatial dimensions");
}
// Verifies that an attribute that is not valid for an instruction (here,
// "calls=" on a send) produces an "unexpected attribute" parse error.
TEST_F(HloParserTest, UnexpectedAttribute) {
const string original = R"(HloModule unexpected_attr_module:
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
%recv = f32[] recv(), channel_id=15
ROOT %constant = f32[] constant(2.1)
%send = () send(f32[] %constant), channel_id=16, calls=%recv
}
)";
ExpectHasSubstr(Parse(original).status().error_message(),
"unexpected attribute calls");
}
// Verifies that omitting a required attribute (the send instruction's
// channel_id) produces an "expected but not seen" parse error.
TEST_F(HloParserTest, MissingAttribute) {
const string original = R"(HloModule missing_attr_module:
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
%recv = f32[] recv(), channel_id=15
ROOT %constant = f32[] constant(-2.1)
%send = () send(f32[] %constant)
}
)";
ExpectHasSubstr(Parse(original).status().error_message(),
"attribute channel_id is expected but not seen");
}
// Verifies that a control-predecessors entry naming an instruction that
// was never defined in the computation ("%done") is reported as an error.
TEST_F(HloParserTest, PredecessorUndefined) {
const string original = R"(HloModule pre_not_found_module:
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
%recv = f32[] recv(), channel_id=15
ROOT %constant = f32[] constant(2.1)
%send = () send(f32[] %constant), channel_id=16, control-predecessors={%done}
}
)";
ExpectHasSubstr(Parse(original).status().error_message(),
"'done' is not defined");
}
} // namespace
} // namespace tools
} // namespace xla

View File

@ -57,6 +57,9 @@ enum class TokKind {
// Typed tokens.
kName, // %foo
kAttributeName, // dimensions=
kDimLabels, // [0-9bf]+_[0-9io]+->[0-9bf]+
kDxD, // [0-9]+(x[0-9]+)+
kWindowPad, // [0-9]+_[0-9]+(x[0-9]+_[0-9]+)*
kShape, // f32[2,3]{1,0}
kOpcode, // add
kInt, // 42

View File

@ -26,8 +26,8 @@ namespace xla {
namespace window_util {
/* static */ string ToString(const WindowDimension& dim) {
using tensorflow::strings::StrCat;
using tensorflow::strings::StrAppend;
using tensorflow::strings::StrCat;
string str = StrCat("(size=", dim.size());
if (dim.stride() != 1) {
StrAppend(&str, ",stride=", dim.stride());
@ -49,22 +49,22 @@ namespace window_util {
}
string ToString(const Window& window) {
using tensorflow::strings::StrCat;
using tensorflow::strings::StrAppend;
using tensorflow::strings::StrCat;
string str;
const auto add_field = [&](
const char* heading,
std::function<string(const WindowDimension&)> format) {
StrAppend(&str, heading, "=");
const char* prefix = "";
for (const auto& window_dimension : window.dimensions()) {
StrAppend(&str, prefix, format(window_dimension));
prefix = "x";
}
};
const auto add_field =
[&](const char* heading,
std::function<string(const WindowDimension&)> format) {
StrAppend(&str, heading, "=");
const char* prefix = "";
for (const auto& window_dimension : window.dimensions()) {
StrAppend(&str, prefix, format(window_dimension));
prefix = "x";
}
};
add_field("window",
add_field("size",
[](const WindowDimension& dim) { return StrCat(dim.size()); });
if (HasStride(window)) {
add_field(" stride",