STT-tensorflow/tensorflow/compiler/xla/service/hlo_parser.cc
TensorFlower Gardener 47612c19ea Merge pull request from AnthonyBarbier:frontend_attributes
PiperOrigin-RevId: 265793937
2019-08-27 16:48:01 -07:00

4367 lines
154 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include <type_traits>
#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_lexer.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
namespace xla {
namespace {
using absl::nullopt;
using absl::optional;
using absl::StrAppend;
using absl::StrCat;
using absl::StrFormat;
using absl::StrJoin;
// Creates and returns a schedule created using the order of the instructions in
// the HloComputation::instructions() vectors in the module.
// Builds a schedule for `module` that simply follows the order in which
// instructions appear in each HloComputation::instructions() vector.
// Fusion computations are skipped: they are never scheduled directly.
HloSchedule ScheduleFromInstructionOrder(HloModule* module) {
  HloSchedule schedule(module);
  for (HloComputation* comp : module->computations()) {
    if (comp->IsFusionComputation()) {
      continue;
    }
    auto& sequence = schedule.GetOrCreateSequence(comp);
    for (HloInstruction* instr : comp->instructions()) {
      sequence.push_back(instr);
    }
  }
  return schedule;
}
// Some functions accept either a linear index or a multi-dimensional index
// (used for indexing into sparse literals).
using LinearOrMultiIndex = absl::variant<int64, absl::Span<const int64>>;
// Parser for the HloModule::ToString() format text.
// Parser for the HloModule::ToString() format text.
class HloParser {
 public:
  using LocTy = HloLexer::LocTy;

  explicit HloParser(absl::string_view str) : lexer_(str) {}

  // Runs the parser and constructs the resulting HLO in the given (empty)
  // HloModule. Returns false if an error occurred.
  Status Run(HloModule* module);

  // Returns the error information.
  string GetError() const { return StrJoin(error_, "\n"); }

  // Stand alone parsing utils for various aggregate data types.
  StatusOr<Shape> ParseShapeOnly();
  StatusOr<HloSharding> ParseShardingOnly();
  StatusOr<FrontendAttributes> ParseFrontendAttributesOnly();
  StatusOr<std::vector<bool>> ParseParameterReplicationOnly();
  StatusOr<Window> ParseWindowOnly();
  StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
  StatusOr<PaddingConfig> ParsePaddingConfigOnly();
  StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly();

 private:
  using InstrNameTable =
      std::unordered_map<string, std::pair<HloInstruction*, LocTy>>;

  // Returns the map from the instruction name to the instruction itself and its
  // location in the current scope.
  InstrNameTable& current_name_table() { return scoped_name_tables_.back(); }

  // Locates an instruction with the given name in the current_name_table() or
  // returns nullptr.
  //
  // When the name is not found or name is empty, if create_missing_instruction_
  // hook is registered and a "shape" is provided, the hook will be called to
  // create an instruction. This is useful when we reify parameters as they're
  // resolved; i.e. for ParseSingleInstruction.
  std::pair<HloInstruction*, LocTy>* FindInstruction(
      const string& name, const optional<Shape>& shape = nullopt);

  // Parse a single instruction worth of text.
  bool ParseSingleInstruction(HloModule* module);

  // Parses a module, returning false if an error occurred.
  bool ParseHloModule(HloModule* module);

  bool ParseComputations(HloModule* module);
  bool ParseComputation(HloComputation** entry_computation);
  bool ParseInstructionList(HloComputation** computation,
                            const string& computation_name);
  bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
  bool ParseInstructionRhs(HloComputation::Builder* builder, const string& name,
                           LocTy name_loc);
  bool ParseControlPredecessors(HloInstruction* instruction);
  bool ParseLiteral(Literal* literal, const Shape& shape);
  bool ParseTupleLiteral(Literal* literal, const Shape& shape);
  bool ParseNonTupleLiteral(Literal* literal, const Shape& shape);
  bool ParseDenseLiteral(Literal* literal, const Shape& shape);
  bool ParseSparseLiteral(Literal* literal, const Shape& shape);

  // Sets the sub-value of literal at the given linear or sparse index to the
  // given value. If the literal is dense, it must have the default layout.
  //
  // `loc` should be the source location of the value.
  bool SetValueInLiteral(LocTy loc, int64 value, LinearOrMultiIndex index,
                         Literal* literal);
  bool SetValueInLiteral(LocTy loc, double value, LinearOrMultiIndex index,
                         Literal* literal);
  bool SetValueInLiteral(LocTy loc, bool value, LinearOrMultiIndex index,
                         Literal* literal);
  bool SetValueInLiteral(LocTy loc, std::complex<double> value,
                         LinearOrMultiIndex index, Literal* literal);
  // `loc` should be the source location of the value.
  template <typename LiteralNativeT, typename ParsedElemT>
  bool SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
                               LinearOrMultiIndex index, Literal* literal);

  // Checks whether the given value is within the range of LiteralNativeT.
  // `loc` should be the source location of the value.
  template <typename LiteralNativeT, typename ParsedElemT>
  bool CheckParsedValueIsInRange(LocTy loc, ParsedElemT value);
  template <typename LiteralNativeT>
  bool CheckParsedValueIsInRange(LocTy loc, std::complex<double> value);

  bool ParseOperands(std::vector<HloInstruction*>* operands);
  // Fills parsed operands into 'operands' and expects a certain number of
  // operands.
  bool ParseOperands(std::vector<HloInstruction*>* operands,
                     const int expected_size);

  // Describes the start, limit, and stride on every dimension of the operand
  // being sliced.
  struct SliceRanges {
    std::vector<int64> starts;
    std::vector<int64> limits;
    std::vector<int64> strides;
  };

  // The data parsed for the kDomain instruction.
  struct DomainData {
    std::unique_ptr<DomainMetadata> entry_metadata;
    std::unique_ptr<DomainMetadata> exit_metadata;
  };

  // Types of attributes.
  enum class AttrTy {
    kBool,
    kInt64,
    kInt32,
    kFloat,
    kString,
    kBracedInt64List,
    kBracedInt64ListList,
    kHloComputation,
    kBracedHloComputationList,
    kFftType,
    kComparisonDirection,
    kWindow,
    kConvolutionDimensionNumbers,
    kSharding,
    kFrontendAttributes,
    kParameterReplication,
    kInstructionList,
    kSliceRanges,
    kPaddingConfig,
    kMetadata,
    kFusionKind,
    kDistribution,
    kDomain,
    kPrecisionList,
    kShapeList
  };

  struct AttrConfig {
    bool required;     // whether it's required or optional
    AttrTy attr_type;  // what type it is
    void* result;      // where to store the parsed result.
  };

  // attributes ::= (',' attribute)*
  //
  // Parses attributes given names and configs of the attributes. Each parsed
  // result is passed back through the result pointer in corresponding
  // AttrConfig. Note that the result pointer must point to an optional<T> typed
  // variable which outlives this function. Returns false on error. You should
  // not use any of the results if this function failed.
  //
  // Example usage:
  //
  //  std::unordered_map<string, AttrConfig> attrs;
  //  optional<int64> foo;
  //  attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo};
  //  optional<Window> bar;
  //  attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar};
  //  if (!ParseAttributes(attrs)) {
  //    return false; // Do not use 'foo' 'bar' if failed.
  //  }
  //  // Do something with 'bar'.
  //  if (foo) { // If attr foo is seen, do something with 'foo'. }
  //
  bool ParseAttributes(const std::unordered_map<string, AttrConfig>& attrs);

  // sub_attributes ::= '{' (','? attribute)* '}'
  //
  // Usage is the same as ParseAttributes. See immediately above.
  bool ParseSubAttributes(const std::unordered_map<string, AttrConfig>& attrs);

  // Parses one attribute. If it has already been seen, return error. Returns
  // true and adds to seen_attrs on success.
  //
  // Do not call this except in ParseAttributes or ParseSubAttributes.
  bool ParseAttributeHelper(const std::unordered_map<string, AttrConfig>& attrs,
                            std::unordered_set<string>* seen_attrs);

  // Parses an attribute string into a protocol buffer `message`.
  // Since proto3 has no notion of mandatory fields, `required_attrs` gives the
  // set of mandatory attributes.
  bool ParseAttributesAsProtoMessage(
      const std::unordered_set<string>& required_attrs,
      tensorflow::protobuf::Message* message);

  // Parses one attribute. If it has already been seen, return error. Returns
  // true and adds to seen_attrs on success.
  //
  // Do not call this except in ParseAttributesAsProtoMessage.
  bool ParseAttributeAsProtoMessageHelper(
      tensorflow::protobuf::Message* message,
      std::unordered_set<string>* seen_attrs);

  // Parses a name and finds the corresponding hlo computation.
  bool ParseComputationName(HloComputation** value);
  // Parses a list of names and finds the corresponding hlo instructions.
  bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
  // Pass expect_outer_curlies == true when parsing a Window in the context of a
  // larger computation. Pass false when parsing a stand-alone Window string.
  bool ParseWindow(Window* window, bool expect_outer_curlies);
  bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
  bool ParsePaddingConfig(PaddingConfig* padding);
  bool ParseMetadata(OpMetadata* metadata);
  bool ParseSharding(OpSharding* sharding);
  bool ParseFrontendAttributes(FrontendAttributes* frontend_attributes);
  bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
  bool ParseParameterReplication(ParameterReplication* parameter_replication);
  bool ParseReplicaGroupsOnly(std::vector<ReplicaGroup>* replica_groups);

  // Parses the metadata behind a kDomain instruction.
  bool ParseDomain(DomainData* domain);

  // Parses a sub-attribute of the window attribute, e.g., size=1x2x3.
  bool ParseDxD(const string& name, std::vector<int64>* result);
  // Parses window's pad sub-attribute, e.g., pad=0_0x3x3.
  bool ParseWindowPad(std::vector<std::vector<int64>>* pad);

  bool ParseSliceRanges(SliceRanges* result);
  bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result);
  bool ParseHloComputation(HloComputation** result);
  bool ParseHloComputationList(std::vector<HloComputation*>* result);
  bool ParseShapeList(std::vector<Shape>* result);
  bool ParseInt64List(const TokKind start, const TokKind end,
                      const TokKind delim, std::vector<int64>* result);
  bool ParseInt64ListList(const TokKind start, const TokKind end,
                          const TokKind delim,
                          std::vector<std::vector<int64>>* result);
  // 'parse_and_add_item' is a lambda to parse an element in the list and add
  // the parsed element to the result. It's supposed to capture the result.
  bool ParseList(const TokKind start, const TokKind end, const TokKind delim,
                 const std::function<bool()>& parse_and_add_item);

  bool ParseParamListToShape(Shape* shape, LocTy* shape_loc);
  bool ParseParamList();
  bool ParseName(string* result);
  bool ParseAttributeName(string* result);
  bool ParseString(string* result);
  bool ParseDimensionSizes(std::vector<int64>* dimension_sizes,
                           std::vector<bool>* dynamic_dimensions);
  bool ParseShape(Shape* result);
  bool ParseLayout(Layout* layout);
  bool ParseLayoutIntAttribute(int64* attr_value,
                               absl::string_view attr_description);
  bool ParseTiles(std::vector<Tile>* tiles);
  bool ParseOpcode(HloOpcode* result);
  bool ParseFftType(FftType* result);
  bool ParseComparisonDirection(ComparisonDirection* result);
  bool ParseFusionKind(HloInstruction::FusionKind* result);
  bool ParseRandomDistribution(RandomDistribution* result);
  bool ParsePrecision(PrecisionConfig::Precision* result);
  bool ParseInt64(int64* result);
  bool ParseDouble(double* result);
  bool ParseComplex(std::complex<double>* result);
  bool ParseBool(bool* result);
  bool ParseToken(TokKind kind, const string& msg);

  // Returns true if the current token is the beginning of a shape.
  bool CanBeShape();
  // Returns true if the current token is the beginning of a
  // param_list_to_shape.
  bool CanBeParamListToShape();

  // Logs the current parsing line and the given message. Always returns false.
  bool TokenError(absl::string_view msg);
  bool Error(LocTy loc, absl::string_view msg);

  // If the current token is 'kind', eats it (i.e. lexes the next token) and
  // returns true.
  bool EatIfPresent(TokKind kind);

  // Adds the instruction to the pool. Returns false and emits an error if the
  // instruction already exists.
  bool AddInstruction(const string& name, HloInstruction* instruction,
                      LocTy name_loc);
  // Adds the computation to the pool. Returns false and emits an error if the
  // computation already exists.
  bool AddComputation(const string& name, HloComputation* computation,
                      LocTy name_loc);

  // Lexer over the input text being parsed.
  HloLexer lexer_;

  // A stack for the instruction names. The top of the stack stores the
  // instruction name table for the current scope.
  //
  // An instruction's name is unique among its scope (i.e. its parent
  // computation), but it's not necessarily unique among all computations in the
  // module. When there are multiple levels of nested computations, the same
  // name could appear in both an outer computation and an inner computation. So
  // we need a stack to make sure a name is only visible within its scope.
  std::vector<InstrNameTable> scoped_name_tables_;

  // A helper class which pushes and pops to an InstrNameTable stack via RAII.
  class Scope {
   public:
    explicit Scope(std::vector<InstrNameTable>* scoped_name_tables)
        : scoped_name_tables_(scoped_name_tables) {
      scoped_name_tables_->emplace_back();
    }
    ~Scope() { scoped_name_tables_->pop_back(); }

   private:
    std::vector<InstrNameTable>* scoped_name_tables_;
  };

  // Map from the computation name to the computation itself and its location.
  std::unordered_map<string, std::pair<HloComputation*, LocTy>>
      computation_pool_;

  // Owns every computation parsed so far, in parse order.
  std::vector<std::unique_ptr<HloComputation>> computations_;
  // Accumulated error messages; see GetError().
  std::vector<string> error_;

  // When an operand name cannot be resolved, this function is called to create
  // a parameter instruction with the given name and shape. It registers the
  // name, instruction, and a placeholder location in the name table. It returns
  // the newly-created instruction and the placeholder location. If `name` is
  // empty, this should create the parameter with a generated name. This is
  // supposed to be set and used only in ParseSingleInstruction.
  std::function<std::pair<HloInstruction*, LocTy>*(const string& name,
                                                   const Shape& shape)>
      create_missing_instruction_;
};
// Splits `s` on `delim` and appends each piece, parsed as an int64, to *out.
// Returns false as soon as any piece fails to parse as an integer.
bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) {
  const std::vector<absl::string_view> pieces = absl::StrSplit(s, delim);
  for (absl::string_view piece : pieces) {
    int64 parsed;
    if (!absl::SimpleAtoi(piece, &parsed)) {
      return false;
    }
    out->push_back(parsed);
  }
  return true;
}
// Creates replica groups from the provided nested array. groups[i] represents
// the replica ids for group 'i'.
std::vector<ReplicaGroup> CreateReplicaGroups(
    absl::Span<const std::vector<int64>> groups) {
  std::vector<ReplicaGroup> replica_groups;
  replica_groups.reserve(groups.size());
  for (const std::vector<int64>& ids : groups) {
    ReplicaGroup group;
    for (int64 id : ids) {
      group.add_replica_ids(id);
    }
    replica_groups.push_back(std::move(group));
  }
  return replica_groups;
}
// Records a parse error at `loc`: the message, the offending source line, and
// a caret marking the column. Always returns false so callers can propagate
// failure directly.
bool HloParser::Error(LocTy loc, absl::string_view msg) {
  const auto line_col = lexer_.GetLineAndColumn(loc);
  const unsigned line = line_col.first;
  const unsigned col = line_col.second;
  std::vector<string> error_lines;
  error_lines.push_back(
      StrCat("was parsing ", line, ":", col, ": error: ", msg));
  error_lines.emplace_back(lexer_.GetLine(loc));
  // Point a caret under the offending column; emit nothing when the column
  // is unknown (col == 0).
  if (col == 0) {
    error_lines.push_back("");
  } else {
    error_lines.push_back(StrCat(string(col - 1, ' '), "^"));
  }
  error_.push_back(StrJoin(error_lines, "\n"));
  VLOG(1) << "Error: " << error_.back();
  return false;
}
// Reports an error at the lexer's current location. Always returns false.
bool HloParser::TokenError(absl::string_view msg) {
  return Error(lexer_.GetLoc(), msg);
}
// Entry point: parses the text either as a full module (when it starts with
// the 'HloModule' keyword) or as a single instruction, populating `module`.
Status HloParser::Run(HloModule* module) {
  lexer_.Lex();
  if (lexer_.GetKind() == TokKind::kw_HloModule) {
    // The text contains a full HLO module.
    if (ParseHloModule(module)) {
      return Status::OK();
    }
    return InvalidArgument(
        "Syntax error when trying to parse the text as a HloModule:\n%s",
        GetError());
  }
  // Otherwise the text is a single HLO instruction.
  if (ParseSingleInstruction(module)) {
    return Status::OK();
  }
  return InvalidArgument(
      "Syntax error when trying to parse the text as a single "
      "HloInstruction:\n%s",
      GetError());
}
// Looks `name` up in the innermost name table; may invoke the
// create_missing_instruction_ hook (only at the outermost scope) when the
// name is absent, and validates shape compatibility when `shape` is given.
std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
    const string& name, const optional<Shape>& shape) {
  std::pair<HloInstruction*, LocTy>* entry =
      name.empty() ? nullptr
                   : tensorflow::gtl::FindOrNull(current_name_table(), name);
  // Potentially call the missing instruction hook.
  if (entry == nullptr && create_missing_instruction_ != nullptr &&
      scoped_name_tables_.size() == 1) {
    if (!shape.has_value()) {
      Error(lexer_.GetLoc(),
            "Operand had no shape in HLO text; cannot create parameter for "
            "single-instruction module.");
      return nullptr;
    }
    return create_missing_instruction_(name, *shape);
  }
  // A declared operand shape must be compatible with the instruction found.
  if (entry != nullptr && shape.has_value() &&
      !ShapeUtil::Compatible(entry->first->shape(), shape.value())) {
    Error(
        lexer_.GetLoc(),
        StrCat("The declared operand shape ",
               ShapeUtil::HumanStringWithLayout(shape.value()),
               " is not compatible with the shape of the operand instruction ",
               ShapeUtil::HumanStringWithLayout(entry->first->shape()), "."));
    return nullptr;
  }
  return entry;
}
// ::= 'HloModule' name computations
bool HloParser::ParseHloModule(HloModule* module) {
if (lexer_.GetKind() != TokKind::kw_HloModule) {
return TokenError("expects HloModule");
}
// Eat 'HloModule'
lexer_.Lex();
string name;
if (!ParseName(&name)) {
return false;
}
absl::optional<bool> is_scheduled;
std::unordered_map<string, AttrConfig> attrs;
attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled};
if (!ParseAttributes(attrs)) {
return false;
}
module->set_name(name);
if (!ParseComputations(module)) {
return false;
}
if (is_scheduled.has_value() && *is_scheduled) {
TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
}
return true;
}
// computations ::= (computation)+
bool HloParser::ParseComputations(HloModule* module) {
HloComputation* entry_computation = nullptr;
do {
if (!ParseComputation(&entry_computation)) {
return false;
}
} while (lexer_.GetKind() != TokKind::kEof);
for (int i = 0; i < computations_.size(); i++) {
// If entry_computation is not nullptr, it means the computation it pointed
// to is marked with "ENTRY"; otherwise, no computation is marked with
// "ENTRY", and we use the last computation as the entry computation. We
// add the non-entry computations as embedded computations to the module.
if ((entry_computation != nullptr &&
computations_[i].get() != entry_computation) ||
(entry_computation == nullptr && i != computations_.size() - 1)) {
module->AddEmbeddedComputation(std::move(computations_[i]));
continue;
}
auto computation = module->AddEntryComputation(std::move(computations_[i]));
// The parameters and result layouts were set to default layout. Here we
// set the layouts to what the hlo text says.
for (int p = 0; p < computation->num_parameters(); p++) {
const Shape& param_shape = computation->parameter_instruction(p)->shape();
TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_parameter_layout(p)
->CopyLayoutFromShape(param_shape));
}
const Shape& result_shape = computation->root_instruction()->shape();
TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_result_layout()
->CopyLayoutFromShape(result_shape));
}
return true;
}
// computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list
bool HloParser::ParseComputation(HloComputation** entry_computation) {
LocTy maybe_entry_loc = lexer_.GetLoc();
const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY);
string name;
LocTy name_loc = lexer_.GetLoc();
if (!ParseName(&name)) {
return false;
}
LocTy shape_loc = nullptr;
Shape shape;
if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) {
return false;
}
HloComputation* computation = nullptr;
if (!ParseInstructionList(&computation, name)) {
return false;
}
// If param_list_to_shape was present, check compatibility.
if (shape_loc != nullptr &&
!ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) {
return Error(
shape_loc,
StrCat(
"Shape of computation ", name, ", ", ShapeUtil::HumanString(shape),
", is not compatible with that of its root instruction ",
computation->root_instruction()->name(), ", ",
ShapeUtil::HumanString(computation->root_instruction()->shape())));
}
if (is_entry_computation) {
if (*entry_computation != nullptr) {
return Error(maybe_entry_loc, "expects only one ENTRY");
}
*entry_computation = computation;
}
return AddComputation(name, computation, name_loc);
}
// instruction_list ::= '{' instruction_list1 '}'
// instruction_list1 ::= (instruction)+
// instruction_list ::= '{' instruction_list1 '}'
// instruction_list1 ::= (instruction)+
bool HloParser::ParseInstructionList(HloComputation** computation,
                                     const string& computation_name) {
  // Open a fresh name scope for this computation's instructions (RAII).
  Scope scope(&scoped_name_tables_);
  HloComputation::Builder builder(computation_name);

  if (!ParseToken(TokKind::kLbrace,
                  "expects '{' at the beginning of instruction list.")) {
    return false;
  }
  string root_name;
  do {
    if (!ParseInstruction(&builder, &root_name)) {
      return false;
    }
  } while (lexer_.GetKind() != TokKind::kRbrace);
  if (!ParseToken(TokKind::kRbrace,
                  "expects '}' at the end of instruction list.")) {
    return false;
  }

  HloInstruction* root = nullptr;
  if (!root_name.empty()) {
    std::pair<HloInstruction*, LocTy>* root_entry =
        tensorflow::gtl::FindOrNull(current_name_table(), root_name);
    // ParseInstruction registered the ROOT name, so a lookup miss here is a
    // parser bug, not a user error.
    if (root_entry == nullptr) {
      LOG(FATAL) << "instruction " << root_name
                 << " was marked as ROOT but the parser has not seen it before";
    }
    root = root_entry->first;
  }

  // A null root makes Builder use the last added instruction as the root.
  computations_.emplace_back(builder.Build(root));
  *computation = computations_.back().get();
  return true;
}
// instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
bool HloParser::ParseInstruction(HloComputation::Builder* builder,
string* root_name) {
string name;
LocTy maybe_root_loc = lexer_.GetLoc();
bool is_root = EatIfPresent(TokKind::kw_ROOT);
const LocTy name_loc = lexer_.GetLoc();
if (!ParseName(&name) ||
!ParseToken(TokKind::kEqual, "expects '=' in instruction")) {
return false;
}
if (is_root) {
if (!root_name->empty()) {
return Error(maybe_root_loc, "one computation should have only one ROOT");
}
*root_name = name;
}
return ParseInstructionRhs(builder, name, name_loc);
}
bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
const string& name, LocTy name_loc) {
Shape shape;
HloOpcode opcode;
std::vector<HloInstruction*> operands;
if (!ParseShape(&shape) || !ParseOpcode(&opcode)) {
return false;
}
// Add optional attributes.
std::unordered_map<string, AttrConfig> attrs;
optional<OpSharding> sharding;
optional<FrontendAttributes> frontend_attributes;
attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
attrs["frontend_attributes"] = {
/*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes};
optional<ParameterReplication> parameter_replication;
attrs["parameter_replication"] = {/*required=*/false,
AttrTy::kParameterReplication,
&parameter_replication};
optional<std::vector<HloInstruction*>> predecessors;
attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList,
&predecessors};
optional<OpMetadata> metadata;
attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata};
optional<string> backend_config;
attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
&backend_config};
optional<std::vector<int64>> outer_dimension_partitions;
attrs["outer_dimension_partitions"] = {/*required=*/false,
AttrTy::kBracedInt64List,
&outer_dimension_partitions};
HloInstruction* instruction;
switch (opcode) {
case HloOpcode::kParameter: {
int64 parameter_number;
if (!ParseToken(TokKind::kLparen,
"expects '(' before parameter number") ||
!ParseInt64(&parameter_number)) {
return false;
}
if (parameter_number < 0) {
Error(lexer_.GetLoc(), "parameter number must be >= 0");
return false;
}
if (!ParseToken(TokKind::kRparen, "expects ')' after parameter number") ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateParameter(parameter_number, shape, name));
break;
}
case HloOpcode::kConstant: {
Literal literal;
if (!ParseToken(TokKind::kLparen,
"expects '(' before constant literal") ||
!ParseLiteral(&literal, shape) ||
!ParseToken(TokKind::kRparen, "expects ')' after constant literal") ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
break;
}
case HloOpcode::kIota: {
optional<int64> iota_dimension;
attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64,
&iota_dimension};
if (!ParseOperands(&operands, /*expected_size=*/0) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateIota(shape, *iota_dimension));
break;
}
// Unary ops.
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kBitcast:
case HloOpcode::kCeil:
case HloOpcode::kClz:
case HloOpcode::kCopy:
case HloOpcode::kCopyStart:
case HloOpcode::kCopyDone:
case HloOpcode::kCos:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kImag:
case HloOpcode::kIsFinite:
case HloOpcode::kFloor:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kPopulationCount:
case HloOpcode::kReal:
case HloOpcode::kRsqrt:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
case HloOpcode::kTanh: {
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateUnary(shape, opcode, operands[0]));
break;
}
// Binary ops.
case HloOpcode::kAdd:
case HloOpcode::kDivide:
case HloOpcode::kMultiply:
case HloOpcode::kSubtract:
case HloOpcode::kAtan2:
case HloOpcode::kComplex:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kAnd:
case HloOpcode::kOr:
case HloOpcode::kXor:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical: {
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateBinary(
shape, opcode, operands[0], operands[1]));
break;
}
// Ternary ops.
case HloOpcode::kClamp:
case HloOpcode::kSelect:
case HloOpcode::kTupleSelect: {
if (!ParseOperands(&operands, /*expected_size=*/3) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateTernary(
shape, opcode, operands[0], operands[1], operands[2]));
break;
}
// Other supported ops.
case HloOpcode::kConvert: {
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateConvert(shape, operands[0]));
break;
}
case HloOpcode::kBitcastConvert: {
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateBitcastConvert(shape, operands[0]));
break;
}
case HloOpcode::kAllReduce: {
optional<std::vector<std::vector<int64>>> tmp_groups;
optional<HloComputation*> to_apply;
optional<std::vector<int64>> replica_group_ids;
optional<int64> channel_id;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
attrs["replica_groups"] = {/*required=*/false,
AttrTy::kBracedInt64ListList, &tmp_groups};
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
std::vector<ReplicaGroup> replica_groups;
if (tmp_groups) {
replica_groups = CreateReplicaGroups(*tmp_groups);
}
instruction = builder->AddInstruction(HloInstruction::CreateAllReduce(
shape, operands, *to_apply, replica_groups, channel_id));
break;
}
case HloOpcode::kAllToAll: {
optional<std::vector<std::vector<int64>>> tmp_groups;
attrs["replica_groups"] = {/*required=*/false,
AttrTy::kBracedInt64ListList, &tmp_groups};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
std::vector<ReplicaGroup> replica_groups;
if (tmp_groups) {
replica_groups = CreateReplicaGroups(*tmp_groups);
}
instruction = builder->AddInstruction(
HloInstruction::CreateAllToAll(shape, operands, replica_groups));
break;
}
case HloOpcode::kCollectivePermute: {
optional<std::vector<std::vector<int64>>> source_targets;
attrs["source_target_pairs"] = {
/*required=*/true, AttrTy::kBracedInt64ListList, &source_targets};
optional<int64> channel_id;
attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
std::vector<std::pair<int64, int64>> pairs(source_targets->size());
for (int i = 0; i < pairs.size(); i++) {
if ((*source_targets)[i].size() != 2) {
return TokenError(
"expects 'source_target_pairs=' to be a list of pairs");
}
pairs[i].first = (*source_targets)[i][0];
pairs[i].second = (*source_targets)[i][1];
}
instruction =
builder->AddInstruction(HloInstruction::CreateCollectivePermute(
shape, operands[0], pairs, channel_id));
break;
}
case HloOpcode::kReplicaId: {
if (!ParseOperands(&operands, /*expected_size=*/0) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateReplicaId());
break;
}
case HloOpcode::kPartitionId: {
if (!ParseOperands(&operands, /*expected_size=*/0) ||
!ParseAttributes(attrs)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreatePartitionId());
break;
}
case HloOpcode::kReshape: {
optional<int64> inferred_dimension;
attrs["inferred_dimension"] = {/*required=*/false, AttrTy::kInt64,
&inferred_dimension};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateReshape(
shape, operands[0], inferred_dimension.value_or(-1)));
break;
}
case HloOpcode::kAfterAll: {
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
if (operands.empty()) {
instruction = builder->AddInstruction(HloInstruction::CreateToken());
} else {
instruction =
builder->AddInstruction(HloInstruction::CreateAfterAll(operands));
}
break;
}
case HloOpcode::kAddDependency: {
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateAddDependency(operands[0], operands[1]));
break;
}
case HloOpcode::kSort: {
optional<std::vector<int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
optional<bool> is_stable = false;
attrs["is_stable"] = {/*required=*/false, AttrTy::kBool, &is_stable};
optional<HloComputation*> to_apply;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
dimensions->size() != 1) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateSort(shape, dimensions->at(0), operands,
to_apply.value(), is_stable.value()));
break;
}
case HloOpcode::kTuple: {
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateTuple(operands));
break;
}
case HloOpcode::kWhile: {
optional<HloComputation*> condition;
optional<HloComputation*> body;
attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation,
&condition};
attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateWhile(
shape, *condition, *body, /*init=*/operands[0]));
break;
}
case HloOpcode::kRecv: {
optional<int64> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
// If the is_host_transfer attribute is not present then default to false.
instruction = builder->AddInstruction(HloInstruction::CreateRecv(
shape.tuple_shapes(0), operands[0], *channel_id, *is_host_transfer));
break;
}
case HloOpcode::kRecvDone: {
optional<int64> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
if (channel_id != operands[0]->channel_id()) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateRecvDone(operands[0], *is_host_transfer));
break;
}
case HloOpcode::kSend: {
optional<int64> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateSend(
operands[0], operands[1], *channel_id, *is_host_transfer));
break;
}
case HloOpcode::kSendDone: {
optional<int64> channel_id;
// If the is_host_transfer attribute is not present then default to false.
optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
&is_host_transfer};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
if (channel_id != operands[0]->channel_id()) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateSendDone(operands[0], *is_host_transfer));
break;
}
case HloOpcode::kGetTupleElement: {
optional<int64> index;
attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateGetTupleElement(shape, operands[0], *index));
break;
}
case HloOpcode::kCall: {
optional<HloComputation*> to_apply;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateCall(shape, operands, *to_apply));
break;
}
case HloOpcode::kReduceWindow: {
optional<HloComputation*> reduce_computation;
optional<Window> window;
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&reduce_computation};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
if (!window) {
window.emplace();
}
instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window,
*reduce_computation));
break;
}
case HloOpcode::kConvolution: {
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
optional<int64> feature_group_count;
optional<int64> batch_group_count;
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["dim_labels"] = {/*required=*/true,
AttrTy::kConvolutionDimensionNumbers, &dnums};
attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
&feature_group_count};
attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
&batch_group_count};
optional<std::vector<PrecisionConfig::Precision>> operand_precision;
attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
&operand_precision};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
if (!window) {
window.emplace();
}
if (!feature_group_count) {
feature_group_count = 1;
}
if (!batch_group_count) {
batch_group_count = 1;
}
PrecisionConfig precision_config;
if (operand_precision) {
*precision_config.mutable_operand_precision() = {
operand_precision->begin(), operand_precision->end()};
} else {
precision_config.mutable_operand_precision()->Resize(
operands.size(), PrecisionConfig::DEFAULT);
}
instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
shape, /*lhs=*/operands[0], /*rhs=*/operands[1],
feature_group_count.value(), batch_group_count.value(), *window,
*dnums, precision_config));
break;
}
case HloOpcode::kFft: {
optional<FftType> fft_type;
optional<std::vector<int64>> fft_length;
attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type};
attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List,
&fft_length};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateFft(
shape, operands[0], *fft_type, *fft_length));
break;
}
case HloOpcode::kTriangularSolve: {
TriangularSolveOptions options;
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributesAsProtoMessage(
/*required_attrs=*/std::unordered_set<string>(), &options)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateTriangularSolve(
shape, operands[0], operands[1], options));
break;
}
case HloOpcode::kCompare: {
optional<ComparisonDirection> direction;
attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection,
&direction};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateCompare(
shape, operands[0], operands[1], *direction));
break;
}
case HloOpcode::kCholesky: {
CholeskyOptions options;
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributesAsProtoMessage(
/*required_attrs=*/std::unordered_set<string>(), &options)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateCholesky(shape, operands[0], options));
break;
}
case HloOpcode::kBroadcast: {
optional<std::vector<int64>> broadcast_dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&broadcast_dimensions};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateBroadcast(
shape, operands[0], *broadcast_dimensions));
break;
}
case HloOpcode::kConcatenate: {
optional<std::vector<int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
dimensions->size() != 1) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateConcatenate(
shape, operands, dimensions->at(0)));
break;
}
case HloOpcode::kMap: {
optional<HloComputation*> to_apply;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
optional<std::vector<int64>> dimensions;
attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateMap(shape, operands, *to_apply));
break;
}
case HloOpcode::kReduce: {
auto loc = lexer_.GetLoc();
optional<HloComputation*> reduce_computation;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&reduce_computation};
optional<std::vector<int64>> dimensions_to_reduce;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions_to_reduce};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
if (operands.size() % 2) {
return Error(loc, StrCat("expects an even number of operands, but has ",
operands.size(), " operands"));
}
instruction = builder->AddInstruction(HloInstruction::CreateReduce(
shape, /*operands=*/
absl::Span<HloInstruction* const>(operands).subspan(
0, operands.size() / 2),
/*init_values=*/
absl::Span<HloInstruction* const>(operands).subspan(operands.size() /
2),
*dimensions_to_reduce, *reduce_computation));
break;
}
case HloOpcode::kReverse: {
optional<std::vector<int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateReverse(shape, operands[0], *dimensions));
break;
}
case HloOpcode::kSelectAndScatter: {
optional<HloComputation*> select;
attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select};
optional<HloComputation*> scatter;
attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter};
optional<Window> window;
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
if (!ParseOperands(&operands, /*expected_size=*/3) ||
!ParseAttributes(attrs)) {
return false;
}
if (!window) {
window.emplace();
}
instruction =
builder->AddInstruction(HloInstruction::CreateSelectAndScatter(
shape, /*operand=*/operands[0], *select, *window,
/*source=*/operands[1], /*init_value=*/operands[2], *scatter));
break;
}
case HloOpcode::kSlice: {
optional<SliceRanges> slice_ranges;
attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateSlice(
shape, operands[0], slice_ranges->starts, slice_ranges->limits,
slice_ranges->strides));
break;
}
case HloOpcode::kDynamicSlice: {
optional<std::vector<int64>> dynamic_slice_sizes;
attrs["dynamic_slice_sizes"] = {
/*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes};
LocTy loc = lexer_.GetLoc();
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
if (operands.empty()) {
return Error(loc, "Expected at least one operand.");
}
if (!(operands.size() == 2 && operands[1]->shape().rank() == 1) &&
operands.size() != 1 + operands[0]->shape().rank()) {
return Error(loc, "Wrong number of operands.");
}
instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice(
shape, /*operand=*/operands[0],
/*start_indices=*/absl::MakeSpan(operands).subspan(1),
*dynamic_slice_sizes));
break;
}
case HloOpcode::kDynamicUpdateSlice: {
LocTy loc = lexer_.GetLoc();
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
if (operands.size() < 2) {
return Error(loc, "Expected at least two operands.");
}
if (!(operands.size() == 3 && operands[2]->shape().rank() == 1) &&
operands.size() != 2 + operands[0]->shape().rank()) {
return Error(loc, "Wrong number of operands.");
}
instruction =
builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
shape, /*operand=*/operands[0], /*update=*/operands[1],
/*start_indices=*/absl::MakeSpan(operands).subspan(2)));
break;
}
case HloOpcode::kTranspose: {
optional<std::vector<int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateTranspose(shape, operands[0], *dimensions));
break;
}
case HloOpcode::kBatchNormTraining: {
optional<float> epsilon;
attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
optional<int64> feature_index;
attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
&feature_index};
if (!ParseOperands(&operands, /*expected_size=*/3) ||
!ParseAttributes(attrs)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateBatchNormTraining(
shape, /*operand=*/operands[0], /*scale=*/operands[1],
/*offset=*/operands[2], *epsilon, *feature_index));
break;
}
case HloOpcode::kBatchNormInference: {
optional<float> epsilon;
attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
optional<int64> feature_index;
attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
&feature_index};
if (!ParseOperands(&operands, /*expected_size=*/5) ||
!ParseAttributes(attrs)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateBatchNormInference(
shape, /*operand=*/operands[0], /*scale=*/operands[1],
/*offset=*/operands[2], /*mean=*/operands[3],
/*variance=*/operands[4], *epsilon, *feature_index));
break;
}
case HloOpcode::kBatchNormGrad: {
optional<float> epsilon;
attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
optional<int64> feature_index;
attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
&feature_index};
if (!ParseOperands(&operands, /*expected_size=*/5) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad(
shape, /*operand=*/operands[0], /*scale=*/operands[1],
/*mean=*/operands[2], /*variance=*/operands[3],
/*grad_output=*/operands[4], *epsilon, *feature_index));
break;
}
case HloOpcode::kPad: {
optional<PaddingConfig> padding;
attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreatePad(
shape, operands[0], /*padding_value=*/operands[1], *padding));
break;
}
case HloOpcode::kFusion: {
optional<HloComputation*> fusion_computation;
attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation,
&fusion_computation};
optional<HloInstruction::FusionKind> fusion_kind;
attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateFusion(
shape, *fusion_kind, operands, *fusion_computation));
break;
}
case HloOpcode::kInfeed: {
optional<string> config;
attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
// We need to know the infeed data shape to construct the infeed
// instruction. This is the zero-th element of the tuple-shaped output of
// the infeed instruction. ShapeUtil::GetTupleElementShape will check fail
// if the shape is not a non-empty tuple, so add guard so an error message
// can be emitted instead of a check fail
if (!shape.IsTuple() && !ShapeUtil::IsEmptyTuple(shape)) {
return Error(lexer_.GetLoc(),
"infeed must have a non-empty tuple shape");
}
instruction = builder->AddInstruction(HloInstruction::CreateInfeed(
ShapeUtil::GetTupleElementShape(shape, 0), operands[0],
config ? *config : ""));
break;
}
case HloOpcode::kOutfeed: {
optional<string> config;
attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0],
operands[1], config ? *config : ""));
break;
}
case HloOpcode::kRng: {
optional<RandomDistribution> distribution;
attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution,
&distribution};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateRng(shape, *distribution, operands));
break;
}
case HloOpcode::kRngGetAndUpdateState: {
optional<int64> delta;
attrs["delta"] = {/*required=*/true, AttrTy::kInt64, &delta};
if (!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateRngGetAndUpdateState(shape, *delta));
break;
}
case HloOpcode::kReducePrecision: {
optional<int64> exponent_bits;
optional<int64> mantissa_bits;
attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64,
&exponent_bits};
attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64,
&mantissa_bits};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateReducePrecision(
shape, operands[0], static_cast<int>(*exponent_bits),
static_cast<int>(*mantissa_bits)));
break;
}
case HloOpcode::kConditional: {
optional<HloComputation*> true_computation;
optional<HloComputation*> false_computation;
optional<std::vector<HloComputation*>> branch_computations;
if (!ParseOperands(&operands)) {
return false;
}
if (!ShapeUtil::IsScalar(operands[0]->shape())) {
return Error(lexer_.GetLoc(), "The first operand must be a scalar");
}
const bool branch_index_is_bool =
operands[0]->shape().element_type() == PRED;
if (branch_index_is_bool) {
attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation,
&true_computation};
attrs["false_computation"] = {
/*required=*/true, AttrTy::kHloComputation, &false_computation};
} else {
if (operands[0]->shape().element_type() != S32) {
return Error(lexer_.GetLoc(),
"The first operand must be a scalar of PRED or S32");
}
attrs["branch_computations"] = {/*required=*/true,
AttrTy::kBracedHloComputationList,
&branch_computations};
}
if (!ParseAttributes(attrs)) {
return false;
}
if (branch_index_is_bool) {
branch_computations.emplace({*true_computation, *false_computation});
}
if (branch_computations->empty() ||
operands.size() != branch_computations->size() + 1) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateConditional(
shape, /*branch_index=*/operands[0],
absl::MakeSpan(*branch_computations),
absl::MakeSpan(operands).subspan(1)));
break;
}
case HloOpcode::kCustomCall: {
optional<string> custom_call_target;
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
optional<int64> feature_group_count;
optional<int64> batch_group_count;
optional<std::vector<Shape>> operand_layout_constraints;
optional<bool> custom_call_has_side_effect;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["dim_labels"] = {/*required=*/false,
AttrTy::kConvolutionDimensionNumbers, &dnums};
attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
&feature_group_count};
attrs["batch_group_count"] = {/*required=*/false, AttrTy::kInt64,
&batch_group_count};
attrs["operand_layout_constraints"] = {
/*required=*/false, AttrTy::kShapeList, &operand_layout_constraints};
attrs["custom_call_has_side_effect"] = {/*required=*/false, AttrTy::kBool,
&custom_call_has_side_effect};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
if (operand_layout_constraints.has_value()) {
if (!LayoutUtil::HasLayout(shape)) {
return Error(lexer_.GetLoc(),
"Layout must be set on layout-constrained custom call");
}
if (operands.size() != operand_layout_constraints->size()) {
return Error(lexer_.GetLoc(),
StrCat("Expected ", operands.size(),
" operand layout constraints, ",
operand_layout_constraints->size(), " given"));
}
for (int64 i = 0; i < operands.size(); ++i) {
const Shape& operand_shape_with_layout =
(*operand_layout_constraints)[i];
if (!LayoutUtil::HasLayout(operand_shape_with_layout)) {
return Error(lexer_.GetLoc(),
StrCat("Operand layout constraint shape ",
ShapeUtil::HumanStringWithLayout(
operand_shape_with_layout),
" for operand ", i, " does not have a layout"));
}
if (!ShapeUtil::Compatible(operand_shape_with_layout,
operands[i]->shape())) {
return Error(
lexer_.GetLoc(),
StrCat(
"Operand layout constraint shape ",
ShapeUtil::HumanStringWithLayout(operand_shape_with_layout),
" for operand ", i,
" is not compatible with operand shape ",
ShapeUtil::HumanStringWithLayout(operands[i]->shape())));
}
}
instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
shape, operands, *custom_call_target, *operand_layout_constraints,
backend_config ? *backend_config : ""));
} else {
instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
shape, operands, *custom_call_target,
backend_config ? *backend_config : ""));
}
auto custom_call_instr = Cast<HloCustomCallInstruction>(instruction);
if (window.has_value()) {
custom_call_instr->set_window(*window);
}
if (dnums.has_value()) {
custom_call_instr->set_convolution_dimension_numbers(*dnums);
}
if (feature_group_count.has_value()) {
custom_call_instr->set_feature_group_count(*feature_group_count);
}
if (batch_group_count.has_value()) {
custom_call_instr->set_batch_group_count(*batch_group_count);
}
if (custom_call_has_side_effect.has_value()) {
custom_call_instr->set_custom_call_has_side_effect(
*custom_call_has_side_effect);
}
break;
}
case HloOpcode::kDot: {
optional<std::vector<int64>> lhs_contracting_dims;
attrs["lhs_contracting_dims"] = {
/*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims};
optional<std::vector<int64>> rhs_contracting_dims;
attrs["rhs_contracting_dims"] = {
/*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims};
optional<std::vector<int64>> lhs_batch_dims;
attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
&lhs_batch_dims};
optional<std::vector<int64>> rhs_batch_dims;
attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
&rhs_batch_dims};
optional<std::vector<PrecisionConfig::Precision>> operand_precision;
attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
&operand_precision};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
DotDimensionNumbers dnum;
if (lhs_contracting_dims) {
*dnum.mutable_lhs_contracting_dimensions() = {
lhs_contracting_dims->begin(), lhs_contracting_dims->end()};
}
if (rhs_contracting_dims) {
*dnum.mutable_rhs_contracting_dimensions() = {
rhs_contracting_dims->begin(), rhs_contracting_dims->end()};
}
if (lhs_batch_dims) {
*dnum.mutable_lhs_batch_dimensions() = {lhs_batch_dims->begin(),
lhs_batch_dims->end()};
}
if (rhs_batch_dims) {
*dnum.mutable_rhs_batch_dimensions() = {rhs_batch_dims->begin(),
rhs_batch_dims->end()};
}
PrecisionConfig precision_config;
if (operand_precision) {
*precision_config.mutable_operand_precision() = {
operand_precision->begin(), operand_precision->end()};
} else {
precision_config.mutable_operand_precision()->Resize(
operands.size(), PrecisionConfig::DEFAULT);
}
instruction = builder->AddInstruction(HloInstruction::CreateDot(
shape, operands[0], operands[1], dnum, precision_config));
break;
}
case HloOpcode::kGather: {
optional<std::vector<int64>> offset_dims;
attrs["offset_dims"] = {/*required=*/true, AttrTy::kBracedInt64List,
&offset_dims};
optional<std::vector<int64>> collapsed_slice_dims;
attrs["collapsed_slice_dims"] = {
/*required=*/true, AttrTy::kBracedInt64List, &collapsed_slice_dims};
optional<std::vector<int64>> start_index_map;
attrs["start_index_map"] = {/*required=*/true, AttrTy::kBracedInt64List,
&start_index_map};
optional<int64> index_vector_dim;
attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
&index_vector_dim};
optional<std::vector<int64>> slice_sizes;
attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List,
&slice_sizes};
optional<bool> indices_are_sorted = false;
attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
&indices_are_sorted};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
GatherDimensionNumbers dim_numbers =
HloGatherInstruction::MakeGatherDimNumbers(
/*offset_dims=*/*offset_dims,
/*collapsed_slice_dims=*/*collapsed_slice_dims,
/*start_index_map=*/*start_index_map,
/*index_vector_dim=*/*index_vector_dim);
instruction = builder->AddInstruction(HloInstruction::CreateGather(
shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
dim_numbers, *slice_sizes, indices_are_sorted.value()));
break;
}
case HloOpcode::kScatter: {
optional<std::vector<int64>> update_window_dims;
attrs["update_window_dims"] = {
/*required=*/true, AttrTy::kBracedInt64List, &update_window_dims};
optional<std::vector<int64>> inserted_window_dims;
attrs["inserted_window_dims"] = {
/*required=*/true, AttrTy::kBracedInt64List, &inserted_window_dims};
optional<std::vector<int64>> scatter_dims_to_operand_dims;
attrs["scatter_dims_to_operand_dims"] = {/*required=*/true,
AttrTy::kBracedInt64List,
&scatter_dims_to_operand_dims};
optional<int64> index_vector_dim;
attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
&index_vector_dim};
optional<HloComputation*> update_computation;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&update_computation};
optional<bool> indices_are_sorted = false;
attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool,
&indices_are_sorted};
if (!ParseOperands(&operands, /*expected_size=*/3) ||
!ParseAttributes(attrs)) {
return false;
}
ScatterDimensionNumbers dim_numbers =
HloScatterInstruction::MakeScatterDimNumbers(
/*update_window_dims=*/*update_window_dims,
/*inserted_window_dims=*/*inserted_window_dims,
/*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims,
/*index_vector_dim=*/*index_vector_dim);
instruction = builder->AddInstruction(HloInstruction::CreateScatter(
shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1],
/*updates=*/operands[2], *update_computation, dim_numbers,
indices_are_sorted.value()));
break;
}
case HloOpcode::kDomain: {
DomainData domain;
attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateDomain(
shape, operands[0], std::move(domain.exit_metadata),
std::move(domain.entry_metadata)));
break;
}
case HloOpcode::kTrace:
return TokenError(StrCat("parsing not yet implemented for op: ",
HloOpcodeString(opcode)));
case HloOpcode::kGetDimensionSize:
optional<std::vector<int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateGetDimensionSize(
shape, operands[0], (*dimensions)[0]));
break;
}
instruction->SetAndSanitizeName(name);
if (instruction->name() != name) {
return Error(name_loc,
StrCat("illegal instruction name: ", name,
"; suggest renaming to: ", instruction->name()));
}
// Add shared attributes like metadata to the instruction, if they were seen.
if (sharding) {
instruction->set_sharding(
HloSharding::FromProto(sharding.value()).ValueOrDie());
}
if (parameter_replication) {
int leaf_count = ShapeUtil::GetLeafCount(instruction->shape());
const auto& replicated =
parameter_replication->replicated_at_leaf_buffers();
if (leaf_count != replicated.size()) {
return Error(lexer_.GetLoc(),
StrCat("parameter has ", leaf_count,
" leaf buffers, but parameter_replication has ",
replicated.size(), " elements."));
}
instruction->set_parameter_replicated_at_leaf_buffers(replicated);
}
if (predecessors) {
for (auto* pre : *predecessors) {
Status status = pre->AddControlDependencyTo(instruction);
if (!status.ok()) {
return Error(name_loc, StrCat("error adding control dependency for: ",
name, " status: ", status.ToString()));
}
}
}
if (metadata) {
instruction->set_metadata(*metadata);
}
if (backend_config) {
instruction->set_raw_backend_config_string(std::move(*backend_config));
}
if (outer_dimension_partitions) {
instruction->set_outer_dimension_partitions(*outer_dimension_partitions);
}
return AddInstruction(name, instruction, name_loc);
} // NOLINT(readability/fn_size)
// sharding ::= '{' (single_sharding | tuple_sharding) '}'
//
// tuple_sharding ::= single_sharding (',' single_sharding)*
bool HloParser::ParseSharding(OpSharding* sharding) {
  // After the opening '{': a nested '{' (or an immediate '}') means this is a
  // tuple sharding (possibly empty); any other token means a single sharding
  // whose opening brace has already been consumed.
  if (!ParseToken(TokKind::kLbrace,
                  "expected '{' to start sharding attribute")) {
    return false;
  }
  const TokKind next_kind = lexer_.GetKind();
  const bool is_tuple =
      next_kind == TokKind::kLbrace || next_kind == TokKind::kRbrace;
  if (!is_tuple) {
    return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true);
  }
  // Tuple sharding; an empty element list is allowed.
  if (next_kind != TokKind::kRbrace) {
    do {
      if (!ParseSingleSharding(sharding->add_tuple_shardings(),
                               /*lbrace_pre_lexed=*/false)) {
        return false;
      }
    } while (EatIfPresent(TokKind::kComma));
  }
  sharding->set_type(OpSharding::TUPLE);
  return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
}
// frontend_attributes ::= '{' attributes '}'
// attributes
//   ::= /*empty*/
//   ::= attribute '=' value (',' attribute '=' value)*
//
// Parses the frontend_attributes map into `frontend_attributes`. Returns
// false (after emitting a diagnostic) on malformed input.
bool HloParser::ParseFrontendAttributes(
    FrontendAttributes* frontend_attributes) {
  CHECK(frontend_attributes != nullptr);
  if (!ParseToken(TokKind::kLbrace,
                  "expected '{' to start frontend attributes")) {
    return false;
  }
  // An immediate '}' is an empty attribute map.
  if (lexer_.GetKind() != TokKind::kRbrace) {
    do {
      string attribute;
      if (!ParseAttributeName(&attribute)) {
        return false;
      }
      // The value must be a bare identifier token. Emit a diagnostic here
      // instead of failing silently like the other error paths do not.
      if (lexer_.GetKind() != TokKind::kIdent) {
        return TokenError("expects identifier value for frontend attribute");
      }
      (*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal();
      lexer_.Lex();
    } while (EatIfPresent(TokKind::kComma));
  }
  return ParseToken(TokKind::kRbrace,
                    "expects '}' at the end of frontend attributes");
}
// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape?
//          ('devices=' ('[' dims ']')* device_list)? '}'
// dims ::= int_list  device_list ::= int_list
//
// Parses one (non-tuple) sharding into `sharding`. `lbrace_pre_lexed` is true
// when the caller (ParseSharding) has already consumed the opening '{' while
// deciding between a single and a tuple sharding. Returns false on error.
bool HloParser::ParseSingleSharding(OpSharding* sharding,
                                    bool lbrace_pre_lexed) {
  if (!lbrace_pre_lexed &&
      !ParseToken(TokKind::kLbrace,
                  "expected '{' to start sharding attribute")) {
    return false;
  }
  // Remember the location for error messages; the validation below runs after
  // the whole body has been consumed.
  LocTy loc = lexer_.GetLoc();
  bool maximal = false;
  bool replicated = false;
  std::vector<int64> devices;
  std::vector<int64> tile_assignment_dimensions;
  while (lexer_.GetKind() != TokKind::kRbrace) {
    switch (lexer_.GetKind()) {
      case TokKind::kw_maximal:
        maximal = true;
        lexer_.Lex();
        break;
      case TokKind::kw_replicated:
        replicated = true;
        lexer_.Lex();
        break;
      case TokKind::kAttributeName: {
        if (lexer_.GetStrVal() == "device") {
          // device=<int>: a single maximal-style device assignment.
          if (lexer_.Lex() != TokKind::kInt) {
            return TokenError("device= attribute must be an integer");
          }
          devices = {lexer_.GetInt64Val()};
          lexer_.Lex();
        } else if (lexer_.GetStrVal() == "devices") {
          // devices=[d0,d1,...]i0,i1,...: the bracketed list is the tile
          // assignment shape, the trailing list is the flat device ids.
          lexer_.Lex();
          if (!ParseToken(TokKind::kLsquare,
                          "expected '[' to start sharding devices shape")) {
            return false;
          }
          do {
            int64 dim;
            if (!ParseInt64(&dim)) {
              return false;
            }
            tile_assignment_dimensions.push_back(dim);
          } while (EatIfPresent(TokKind::kComma));
          if (!ParseToken(TokKind::kRsquare,
                          "expected ']' to start sharding devices shape")) {
            return false;
          }
          do {
            int64 device;
            if (!ParseInt64(&device)) {
              return false;
            }
            devices.push_back(device);
          } while (EatIfPresent(TokKind::kComma));
        } else {
          return TokenError(
              "unknown attribute in sharding: expected device= or devices=");
        }
        break;
      }
      case TokKind::kRbrace:
        // Defensive: the loop condition already excludes kRbrace, so this
        // branch should be unreachable.
        break;
      default:
        return TokenError("unexpected token");
    }
  }
  // Validate the collected pieces and build the proto. The three shapes are
  // mutually exclusive: replicated, maximal (one device), or tiled (OTHER).
  if (replicated) {
    if (!devices.empty()) {
      return Error(loc,
                   "replicated shardings should not have any devices assigned");
    }
    sharding->set_type(OpSharding::REPLICATED);
  } else if (maximal) {
    if (devices.size() != 1) {
      return Error(loc,
                   "maximal shardings should have exactly one device assigned");
    }
    sharding->set_type(OpSharding::MAXIMAL);
    sharding->add_tile_assignment_devices(devices[0]);
  } else {
    if (devices.size() <= 1) {
      return Error(
          loc, "non-maximal shardings must have more than one device assigned");
    }
    if (tile_assignment_dimensions.empty()) {
      return Error(
          loc,
          "non-maximal shardings must have a tile assignment list including "
          "dimensions");
    }
    sharding->set_type(OpSharding::OTHER);
    for (int64 dim : tile_assignment_dimensions) {
      sharding->add_tile_assignment_dimensions(dim);
    }
    for (int64 device : devices) {
      sharding->add_tile_assignment_devices(device);
    }
  }
  // Consume the closing '}'.
  lexer_.Lex();
  return true;
}
// parameter_replication ::=
// '{' ('true' | 'false')* (',' ('true' | 'false'))* '}'
bool HloParser::ParseParameterReplication(
    ParameterReplication* parameter_replication) {
  // Parses a braced, comma-separated list of booleans, appending each to the
  // leaf-buffer replication list of `parameter_replication`.
  if (!ParseToken(TokKind::kLbrace,
                  "expected '{' to start parameter_replication attribute")) {
    return false;
  }
  if (lexer_.GetKind() != TokKind::kRbrace) {
    do {
      const TokKind kind = lexer_.GetKind();
      // Only the literal tokens `true` and `false` are legal list elements.
      if (kind != TokKind::kw_true && kind != TokKind::kw_false) {
        return false;
      }
      parameter_replication->add_replicated_at_leaf_buffers(
          kind == TokKind::kw_true);
      lexer_.Lex();
    } while (EatIfPresent(TokKind::kComma));
  }
  return ParseToken(TokKind::kRbrace,
                    "expected '}' to end parameter_replication attribute");
}
// replica_groups ::='{' int64list_elements '}'
// int64list_elements
// ::= /*empty*/
// ::= int64list (',' int64list)*
// int64list ::= '{' int64_elements '}'
// int64_elements
// ::= /*empty*/
// ::= int64_val (',' int64_val)*
bool HloParser::ParseReplicaGroupsOnly(
    std::vector<ReplicaGroup>* replica_groups) {
  // First parse the raw nested int64 lists, then convert them into
  // ReplicaGroup protos.
  std::vector<std::vector<int64>> raw_groups;
  if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
                          &raw_groups)) {
    return false;
  }
  *replica_groups = CreateReplicaGroups(raw_groups);
  return true;
}
// domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ','
// 'exit=' exit_sharding '}'
bool HloParser::ParseDomain(DomainData* domain) {
  std::unordered_map<string, AttrConfig> attrs;
  optional<string> kind;
  optional<OpSharding> entry_sharding;
  optional<OpSharding> exit_sharding;
  // All three sub-attributes are mandatory.
  attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind};
  attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding};
  attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding};
  if (!ParseSubAttributes(attrs)) {
    return false;
  }
  // Only sharding domains are supported here; convert the parsed OpSharding
  // protos into HloSharding-backed entry/exit metadata.
  if (*kind == ShardingMetadata::KindName()) {
    // NOTE(review): FromProto failures abort via ValueOrDie instead of
    // surfacing a parse error — presumably malformed shardings were already
    // rejected by ParseSharding; confirm user input cannot reach this path.
    auto entry_sharding_ptr = absl::make_unique<HloSharding>(
        HloSharding::FromProto(*entry_sharding).ValueOrDie());
    auto exit_sharding_ptr = absl::make_unique<HloSharding>(
        HloSharding::FromProto(*exit_sharding).ValueOrDie());
    domain->entry_metadata =
        absl::make_unique<ShardingMetadata>(std::move(entry_sharding_ptr));
    domain->exit_metadata =
        absl::make_unique<ShardingMetadata>(std::move(exit_sharding_ptr));
  } else {
    return TokenError(StrCat("unsupported domain kind: ", *kind));
  }
  return true;
}
// '{' name+ '}'
// Parses a braced, comma-separated list of previously-defined instruction
// names and appends the corresponding HloInstruction pointers to
// `instructions`. Fails if a name does not refer to a known instruction.
bool HloParser::ParseInstructionNames(
    std::vector<HloInstruction*>* instructions) {
  if (!ParseToken(TokKind::kLbrace,
                  "expects '{' at the beginning of instruction name list")) {
    return false;
  }
  LocTy loc = lexer_.GetLoc();
  do {
    string name;
    if (!ParseName(&name)) {
      // Fixed grammar in the error message ("a instruction" -> "an
      // instruction").
      return Error(loc, "expects an instruction name");
    }
    std::pair<HloInstruction*, LocTy>* instr = FindInstruction(name);
    if (!instr) {
      return TokenError(StrFormat("instruction '%s' is not defined", name));
    }
    instructions->push_back(instr->first);
  } while (EatIfPresent(TokKind::kComma));
  return ParseToken(TokKind::kRbrace,
                    "expects '}' at the end of instruction name list");
}
// Writes an integer `value` into `literal` at `index`, dispatching to the
// helper instantiated for the literal's integral (or PRED) element type.
// LOG(FATAL)s on non-integral element types, which callers are expected to
// have screened out already.
bool HloParser::SetValueInLiteral(LocTy loc, int64 value,
                                  LinearOrMultiIndex index, Literal* literal) {
  const Shape& shape = literal->shape();
  switch (shape.element_type()) {
    case S8:
      return SetValueInLiteralHelper<int8>(loc, value, index, literal);
    case S16:
      return SetValueInLiteralHelper<int16>(loc, value, index, literal);
    case S32:
      return SetValueInLiteralHelper<int32>(loc, value, index, literal);
    case S64:
      return SetValueInLiteralHelper<int64>(loc, value, index, literal);
    case U8:
      return SetValueInLiteralHelper<tensorflow::uint8>(loc, value, index,
                                                        literal);
    case U16:
      return SetValueInLiteralHelper<tensorflow::uint16>(loc, value, index,
                                                         literal);
    case U32:
      return SetValueInLiteralHelper<tensorflow::uint32>(loc, value, index,
                                                         literal);
    case U64:
      return SetValueInLiteralHelper<tensorflow::uint64>(loc, value, index,
                                                         literal);
    case PRED:
      // Bool type literals with rank >= 1 are printed in 0s and 1s.
      return SetValueInLiteralHelper<bool>(loc, static_cast<bool>(value), index,
                                           literal);
    default:
      LOG(FATAL) << "unknown integral primitive type "
                 << PrimitiveType_Name(shape.element_type());
  }
}
bool HloParser::SetValueInLiteral(LocTy loc, double value,
                                  LinearOrMultiIndex index, Literal* literal) {
  // Dispatch on the literal's floating-point element type to the
  // correctly-typed helper.
  const PrimitiveType element_type = literal->shape().element_type();
  switch (element_type) {
    case F32:
      return SetValueInLiteralHelper<float>(loc, value, index, literal);
    case F64:
      return SetValueInLiteralHelper<double>(loc, value, index, literal);
    case F16:
      return SetValueInLiteralHelper<Eigen::half>(loc, value, index, literal);
    case BF16:
      return SetValueInLiteralHelper<tensorflow::bfloat16>(loc, value, index,
                                                           literal);
    default:
      // Non-floating-point element types are a programming error here.
      LOG(FATAL) << "unknown floating point primitive type "
                 << PrimitiveType_Name(element_type);
  }
}
bool HloParser::SetValueInLiteral(LocTy loc, bool value,
                                  LinearOrMultiIndex index, Literal* literal) {
  // Boolean values may only be written into PRED literals.
  const PrimitiveType element_type = literal->shape().element_type();
  if (element_type == PRED) {
    return SetValueInLiteralHelper<bool>(loc, value, index, literal);
  }
  LOG(FATAL) << PrimitiveType_Name(element_type) << " is not PRED type";
}
// Writes a complex `value` into `literal` at `index`, dispatching to the
// helper instantiated for the literal's complex element type. LOG(FATAL)s on
// non-complex element types.
bool HloParser::SetValueInLiteral(LocTy loc, std::complex<double> value,
                                  LinearOrMultiIndex index, Literal* literal) {
  const Shape& shape = literal->shape();
  switch (shape.element_type()) {
    case C64:
      return SetValueInLiteralHelper<std::complex<float>>(loc, value, index,
                                                          literal);
    case C128:
      return SetValueInLiteralHelper<std::complex<double>>(loc, value, index,
                                                           literal);
    default:
      // Fixed duplicated word in the fatal message ("type type" -> "type").
      LOG(FATAL) << PrimitiveType_Name(shape.element_type())
                 << " is not a complex type";
  }
}
// Renders a parsed scalar value for use in parser error messages.
template <typename T>
string StringifyValue(T val) {
  return StrCat(val);
}
// Complex values are rendered as "(real, imag)".
template <>
string StringifyValue(std::complex<double> val) {
  return StrFormat("(%f, %f)", std::real(val), std::imag(val));
}
// Writes `value`, converted to LiteralNativeT, into `literal`. `index` is
// either a dense linear index (bounds-checked against the element count) or a
// sparse multi-index (checked against the shape's rank and per-dimension
// bounds before appending a sparse element). Returns false, with an error
// recorded, on any out-of-range value or index.
template <typename LiteralNativeT, typename ParsedElemT>
bool HloParser::SetValueInLiteralHelper(LocTy loc, ParsedElemT value,
                                        LinearOrMultiIndex index,
                                        Literal* literal) {
  // Reject values that cannot be represented by LiteralNativeT.
  if (!CheckParsedValueIsInRange<LiteralNativeT>(loc, value)) {
    return false;
  }
  // Check that the index is in range and assign into the literal
  if (auto* linear_index = absl::get_if<int64>(&index)) {
    if (*linear_index >= ShapeUtil::ElementsIn(literal->shape())) {
      // Fixed typo in the error message ("trys" -> "tries").
      return Error(loc, StrCat("tries to set value ", StringifyValue(value),
                               " to a literal in shape ",
                               ShapeUtil::HumanString(literal->shape()),
                               " at linear index ", *linear_index,
                               ", but the index is out of range"));
    }
    literal->data<LiteralNativeT>().at(*linear_index) =
        static_cast<LiteralNativeT>(value);
  } else {
    auto* multi_index = absl::get_if<absl::Span<const int64>>(&index);
    CHECK(multi_index != nullptr);
    auto invalid_idx = [&](string msg) {
      return Error(loc, StrFormat("Invalid sparse index [%s]. %s",
                                  absl::StrJoin(*multi_index, ", "), msg));
    };
    const auto& shape = literal->shape();
    // The multi-index must have exactly one sub-index per dimension.
    if (shape.rank() != multi_index->size()) {
      return invalid_idx(
          StrFormat("Has rank %d, but constant has shape %s, which has rank %d",
                    multi_index->size(), shape.ToString(), shape.rank()));
    }
    // Each sub-index must lie within [0, dimension size).
    for (int64 i = 0; i < shape.rank(); ++i) {
      auto idx = (*multi_index)[i];
      if (idx < 0) {
        return invalid_idx(StrFormat(
            "Sub-index value at %d, namely %d, cannot be negative.", i, idx));
      }
      if (idx >= shape.dimensions(i)) {
        return invalid_idx(
            StrFormat("Sub-index at %d, namely %d, doesn't fit within shape "
                      "dimension %d in %s",
                      i, idx, shape.dimensions(i), shape.ToString()));
      }
    }
    literal->AppendSparseElement(*multi_index,
                                 static_cast<LiteralNativeT>(value));
  }
  return true;
}
// literal
// ::= tuple
// ::= non_tuple
bool HloParser::ParseLiteral(Literal* literal, const Shape& shape) {
  // Tuples and non-tuples use distinct syntaxes; dispatch on the shape.
  if (shape.IsTuple()) {
    return ParseTupleLiteral(literal, shape);
  }
  return ParseNonTupleLiteral(literal, shape);
}
// tuple
// ::= shape '(' literal_list ')'
// literal_list
// ::= /*empty*/
// ::= literal (',' literal)*
// Parses a tuple literal of the form '(' literal_list ')' whose element count
// and element shapes are dictated by `shape`, and assembles the result into
// `literal`.
bool HloParser::ParseTupleLiteral(Literal* literal, const Shape& shape) {
  if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
    return false;
  }
  std::vector<Literal> elements(ShapeUtil::TupleElementCount(shape));
  if (lexer_.GetKind() == TokKind::kRparen) {
    // empty
  } else {
    // literal, (',' literal)*
    for (int i = 0; i < elements.size(); i++) {
      if (i > 0) {
        // Propagate the failure if the separating comma is missing (the
        // return value was previously ignored, swallowing the error). Also
        // fixed the "exepcts" typo in the message.
        if (!ParseToken(TokKind::kComma,
                        "expects ',' to separate tuple elements")) {
          return false;
        }
      }
      if (!ParseLiteral(&elements[i],
                        ShapeUtil::GetTupleElementShape(shape, i))) {
        return TokenError(StrCat("expects the ", i, "th element"));
      }
    }
  }
  *literal = LiteralUtil::MakeTupleOwned(std::move(elements));
  // Added the missing space before "elements" in the message.
  return ParseToken(TokKind::kRparen,
                    StrCat("expects ')' at the end of the tuple with ",
                           ShapeUtil::TupleElementCount(shape), " elements"));
}
// non_tuple
// ::= rank01
// ::= rank2345
// rank2345 ::= shape sparse_or_nested_array
bool HloParser::ParseNonTupleLiteral(Literal* literal, const Shape& shape) {
  // Sparse arrays use their own "{index: value, ...}" syntax; everything
  // else must be a dense array.
  if (LayoutUtil::IsSparseArray(shape)) {
    return ParseSparseLiteral(literal, shape);
  }
  CHECK(LayoutUtil::IsDenseArray(shape)) << shape.ToString(true);
  return ParseDenseLiteral(literal, shape);
}
// Parses a dense array literal (possibly nested in braces) of the given
// `shape` into `literal`. Verifies that the brace nesting matches the rank
// and that each dimension receives exactly the expected number of elements.
bool HloParser::ParseDenseLiteral(Literal* literal, const Shape& shape) {
  // Cast `rank` to int because we call shape.dimensions(int rank) below, and if
  // `rank` is an int64, that's an implicit narrowing conversion, which is
  // implementation-defined behavior.
  const int rank = static_cast<int>(shape.rank());
  // Create a literal with the given shape in default layout.
  *literal = LiteralUtil::CreateFromDimensions(
      shape.element_type(), AsInt64Slice(shape.dimensions()));
  int64 nest_level = 0;
  int64 linear_index = 0;
  // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for
  // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}},
  // when we are parsing the 2nd '{' (right before '1'), we are seeing a
  // sub-array of the dimension 0, so elems_seen_per_dim[0]++. When we are at
  // the first '}' (right after '3'), it means the sub-array ends, and the
  // sub-array is supposed to contain exactly 3 elements, so check if
  // elems_seen_per_dim[1] is 3.
  std::vector<int64> elems_seen_per_dim(rank);
  // Renders the current multi-dimensional position for error messages.
  auto get_index_str = [&elems_seen_per_dim](int dim) -> string {
    std::vector<int64> elems_seen_until_dim(elems_seen_per_dim.begin(),
                                            elems_seen_per_dim.begin() + dim);
    return StrCat("[",
                  StrJoin(elems_seen_until_dim, ",",
                          [](string* out, const int64& num_elems) {
                            StrAppend(out, num_elems - 1);
                          }),
                  "]");
  };
  // Records one scalar element at the current position, validating the
  // nesting depth and the minor-most dimension's element count.
  auto add_one_elem_seen = [&] {
    if (rank > 0) {
      if (nest_level != rank) {
        return TokenError(absl::StrFormat(
            "expects nested array in rank %d, but sees %d", rank, nest_level));
      }
      elems_seen_per_dim[rank - 1]++;
      if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) {
        return TokenError(absl::StrFormat(
            "expects %d elements on the minor-most dimension, but "
            "sees more",
            shape.dimensions(rank - 1)));
      }
    }
    return true;
  };
  do {
    switch (lexer_.GetKind()) {
      default:
        return TokenError("unexpected token type in a literal");
      case TokKind::kLbrace: {
        nest_level++;
        if (nest_level > rank) {
          return TokenError(absl::StrFormat(
              "expects nested array in rank %d, but sees larger", rank));
        }
        if (nest_level > 1) {
          elems_seen_per_dim[nest_level - 2]++;
          if (elems_seen_per_dim[nest_level - 2] >
              shape.dimensions(nest_level - 2)) {
            return TokenError(absl::StrFormat(
                "expects %d elements in the %sth element, but sees more",
                shape.dimensions(nest_level - 2),
                get_index_str(nest_level - 2)));
          }
        }
        lexer_.Lex();
        break;
      }
      case TokKind::kRbrace: {
        // Guard against a '}' with no matching '{', which would otherwise
        // underflow nest_level and index elems_seen_per_dim with -1.
        if (nest_level == 0) {
          return TokenError("unexpected '}' in literal");
        }
        nest_level--;
        if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) {
          return TokenError(absl::StrFormat(
              "expects %d elements in the %sth element, but sees %d",
              shape.dimensions(nest_level), get_index_str(nest_level),
              elems_seen_per_dim[nest_level]));
        }
        elems_seen_per_dim[nest_level] = 0;
        lexer_.Lex();
        break;
      }
      case TokKind::kLparen: {
        if (!primitive_util::IsComplexType(shape.element_type())) {
          return TokenError(
              absl::StrFormat("unexpected '(' in literal. Parens are only "
                              "valid for complex literals"));
        }
        std::complex<double> value;
        LocTy loc = lexer_.GetLoc();
        if (!add_one_elem_seen() || !ParseComplex(&value) ||
            !SetValueInLiteral(loc, value, linear_index++, literal)) {
          return false;
        }
        break;
      }
      case TokKind::kDots: {
        if (nest_level != 1) {
          return TokenError(absl::StrFormat(
              "expects `...` at nest level 1, but sees it at nest level %d",
              nest_level));
        }
        elems_seen_per_dim[0] = shape.dimensions(0);
        lexer_.Lex();
        // Fill data with deterministic (garbage) values. Use static to avoid
        // creating identical constants which could potentially got CSE'ed
        // away. This is a best-effort approach to make sure replaying a HLO
        // gives us same optimized HLO graph.
        static uint32 data = 0;
        uint32* raw_data = static_cast<uint32*>(literal->untyped_data());
        for (int64 i = 0; i < literal->size_bytes() / 4; ++i) {
          raw_data[i] = data++;
        }
        uint8* raw_data_int8 = static_cast<uint8*>(literal->untyped_data());
        static uint8 data_int8 = 0;
        for (int64 i = 0; i < literal->size_bytes() % 4; ++i) {
          raw_data_int8[literal->size_bytes() / 4 + i] = data_int8++;
        }
        break;
      }
      case TokKind::kComma:
        // Skip.
        lexer_.Lex();
        break;
      case TokKind::kw_true:
      case TokKind::kw_false:
      case TokKind::kInt:
      case TokKind::kDecimal:
      case TokKind::kw_nan:
      case TokKind::kw_inf:
      case TokKind::kNegInf: {
        // Propagate failures from the element-count check; the return value
        // was previously ignored, so the recorded error was silently dropped
        // and parsing continued.
        if (!add_one_elem_seen()) {
          return false;
        }
        if (lexer_.GetKind() == TokKind::kw_true ||
            lexer_.GetKind() == TokKind::kw_false) {
          if (!SetValueInLiteral(lexer_.GetLoc(),
                                 lexer_.GetKind() == TokKind::kw_true,
                                 linear_index++, literal)) {
            return false;
          }
          lexer_.Lex();
        } else if (primitive_util::IsIntegralType(shape.element_type()) ||
                   shape.element_type() == PRED) {
          LocTy loc = lexer_.GetLoc();
          int64 value;
          if (!ParseInt64(&value)) {
            return Error(loc, StrCat("expects integer for primitive type: ",
                                     PrimitiveType_Name(shape.element_type())));
          }
          if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
            return false;
          }
        } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
          LocTy loc = lexer_.GetLoc();
          double value;
          if (!ParseDouble(&value)) {
            return Error(
                loc, StrCat("expect floating point value for primitive type: ",
                            PrimitiveType_Name(shape.element_type())));
          }
          if (!SetValueInLiteral(loc, value, linear_index++, literal)) {
            return false;
          }
        } else {
          return TokenError(StrCat("unsupported primitive type ",
                                   PrimitiveType_Name(shape.element_type())));
        }
        break;
      }
    }  // end of switch
  } while (nest_level > 0);
  // Restore the layout requested by `shape` (parsing used default layout).
  *literal = literal->Relayout(shape.layout());
  return true;
}
// Parses a sparse literal of the form '{' (index ':' value)* '}' into
// `literal`. Each index is either a bare integer (rank-1) or a bracketed
// multi-index; each value's token type must match the shape's element type.
bool HloParser::ParseSparseLiteral(Literal* literal, const Shape& shape) {
  *literal = Literal(shape);
  if (!ParseToken(TokKind::kLbrace,
                  "expects '{' at the beginning of a sparse literal")) {
    return false;
  }
  for (;;) {
    if (lexer_.GetKind() == TokKind::kRbrace) {
      lexer_.Lex();
      break;
    }
    std::vector<int64> index;
    if (lexer_.GetKind() == TokKind::kInt) {
      // A bare integer is shorthand for a single-dimension index.
      int64 single_index = lexer_.GetInt64Val();
      lexer_.Lex();
      index.push_back(single_index);
    } else {
      if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma,
                          &index)) {
        return false;
      }
    }
    // Fixed duplicated word in the error message ("after after" -> "after").
    if (!ParseToken(TokKind::kColon,
                    "expects ':' after the sparse array index and before "
                    "the sparse array value")) {
      return false;
    }
    LocTy value_loc = lexer_.GetLoc();
    if (lexer_.GetKind() == TokKind::kw_true ||
        lexer_.GetKind() == TokKind::kw_false) {
      bool value = lexer_.GetKind() == TokKind::kw_true;
      if (!SetValueInLiteral(lexer_.GetLoc(), value, index, literal)) {
        return false;
      }
      lexer_.Lex();
    } else if (primitive_util::IsIntegralType(shape.element_type())) {
      int64 value;
      if (!ParseInt64(&value)) {
        return Error(value_loc,
                     StrCat("expects integer for primitive type: ",
                            PrimitiveType_Name(shape.element_type())));
      }
      if (!SetValueInLiteral(value_loc, value, index, literal)) {
        return false;
      }
    } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
      double value;
      if (!ParseDouble(&value)) {
        return Error(value_loc,
                     StrCat("expects floating point value for primitive type: ",
                            PrimitiveType_Name(shape.element_type())));
      }
      if (!SetValueInLiteral(value_loc, value, index, literal)) {
        return false;
      }
    } else if (primitive_util::IsComplexType(shape.element_type())) {
      std::complex<double> value;
      if (!ParseComplex(&value)) {
        return Error(value_loc,
                     StrCat("expects complex value for primitive type: ",
                            PrimitiveType_Name(shape.element_type())));
      }
      if (!SetValueInLiteral(value_loc, value, index, literal)) {
        return false;
      }
    } else {
      LOG(FATAL) << "Unexpected element type: "
                 << PrimitiveType_Name(shape.element_type());
    }
    if (lexer_.GetKind() != TokKind::kRbrace &&
        !ParseToken(TokKind::kComma,
                    "expects ',' separator between sparse array elements")) {
      return false;
    }
    // NOTE(review): this rejects the literal once count+1 == max, i.e. after
    // max-1 elements — looks like an off-by-one; confirm the intended limit
    // before changing.
    if (literal->sparse_element_count() + 1 ==
        LayoutUtil::MaxSparseElements(shape.layout())) {
      return Error(
          lexer_.GetLoc(),
          StrCat("number of sparse elements exceeds maximum for layout: ",
                 ShapeUtil::HumanStringWithLayout(shape)));
    }
  }
  literal->SortSparseElements();
  return true;
}
// MaxFiniteValue is a type-traits helper used by
// HloParser::CheckParsedValueIsInRange.
// Primary template: delegates to std::numeric_limits for types it models
// correctly.
template <typename T>
struct MinMaxFiniteValue {
  static T max() { return std::numeric_limits<T>::max(); }
  static T min() { return std::numeric_limits<T>::lowest(); }
};
// Half-precision types return their finite bounds as double so callers can
// compare against parsed double values directly.
template <>
struct MinMaxFiniteValue<Eigen::half> {
  static double max() {
    // Sadly this is not constexpr, so this forces `value` to be a method.
    return static_cast<double>(Eigen::NumTraits<Eigen::half>::highest());
  }
  static double min() { return -max(); }
};
template <>
struct MinMaxFiniteValue<bfloat16> {
  static double max() { return static_cast<double>(bfloat16::highest()); }
  static double min() { return -max(); }
};
// Returns true if `value` fits into LiteralNativeT; otherwise records an
// error at `loc` and returns false. NaN and +/-inf are always accepted.
template <typename LiteralNativeT, typename ParsedElemT>
bool HloParser::CheckParsedValueIsInRange(LocTy loc, ParsedElemT value) {
  PrimitiveType literal_ty =
      primitive_util::NativeToPrimitiveType<LiteralNativeT>();
  if (std::isnan(value) ||
      (std::numeric_limits<ParsedElemT>::has_infinity &&
       (std::numeric_limits<ParsedElemT>::infinity() == value ||
        -std::numeric_limits<ParsedElemT>::infinity() == value))) {
    // Skip range checking for non-finite value.
  } else if (std::is_unsigned<LiteralNativeT>::value) {
    // Unsigned targets: only int64 and bool sources are supported; the valid
    // range is [0, upper_bound].
    CHECK((std::is_same<ParsedElemT, int64>::value ||
           std::is_same<ParsedElemT, bool>::value))
        << "Unimplemented checking for ParsedElemT";
    // The upper bound is whichever of the two types' maxima is representable
    // in ParsedElemT without overflow.
    ParsedElemT upper_bound;
    if (sizeof(LiteralNativeT) >= sizeof(ParsedElemT)) {
      upper_bound = std::numeric_limits<ParsedElemT>::max();
    } else {
      upper_bound =
          static_cast<ParsedElemT>(std::numeric_limits<LiteralNativeT>::max());
    }
    if (value > upper_bound || value < 0) {
      // Value is out of range for LiteralNativeT.
      return Error(loc, StrCat("value ", value,
                               " is out of range for literal's primitive type ",
                               PrimitiveType_Name(literal_ty), " namely [0, ",
                               upper_bound, "]."));
    }
  } else if (value > MinMaxFiniteValue<LiteralNativeT>::max() ||
             value < MinMaxFiniteValue<LiteralNativeT>::min()) {
    // Value is out of range for LiteralNativeT.
    return Error(loc, StrCat("value ", value,
                             " is out of range for literal's primitive type ",
                             PrimitiveType_Name(literal_ty), " namely [",
                             MinMaxFiniteValue<LiteralNativeT>::min(), ", ",
                             MinMaxFiniteValue<LiteralNativeT>::max(), "]."));
  }
  return true;
}
// Complex overload: range-checks the real and imaginary parts independently
// against the component type of LiteralNativeT, producing a per-component
// error message on failure.
template <typename LiteralNativeT>
bool HloParser::CheckParsedValueIsInRange(LocTy loc,
                                          std::complex<double> value) {
  // e.g. `float` for std::complex<float>
  using LiteralComplexComponentT =
      decltype(std::real(std::declval<LiteralNativeT>()));
  // We could do simply
  //
  //   return CheckParsedValueIsInRange<LiteralNativeT>(std::real(value)) &&
  //          CheckParsedValueIsInRange<LiteralNativeT>(std::imag(value));
  //
  // but this would give bad error messages on failure.
  auto check_component = [&](absl::string_view name, double v) {
    if (std::isnan(v) || v == std::numeric_limits<double>::infinity() ||
        v == -std::numeric_limits<double>::infinity()) {
      // Skip range-checking for non-finite values.
      return true;
    }
    double min = MinMaxFiniteValue<LiteralComplexComponentT>::min();
    double max = MinMaxFiniteValue<LiteralComplexComponentT>::max();
    if (v < min || v > max) {
      // Value is out of range for LiteralComplexComponentT.
      return Error(
          loc,
          StrCat(name, " part ", v,
                 " is out of range for literal's primitive type ",
                 PrimitiveType_Name(
                     primitive_util::NativeToPrimitiveType<LiteralNativeT>()),
                 ", namely [", min, ", ", max, "]."));
    }
    return true;
  };
  return check_component("real", std::real(value)) &&
         check_component("imaginary", std::imag(value));
}
// operands ::= '(' operands1 ')'
// operands1
// ::= /*empty*/
// ::= operand (, operand)*
// operand ::= (shape)? name
bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
  CHECK(operands != nullptr);
  if (!ParseToken(TokKind::kLparen,
                  "expects '(' at the beginning of operands")) {
    return false;
  }
  if (lexer_.GetKind() == TokKind::kRparen) {
    // empty
  } else {
    do {
      LocTy loc = lexer_.GetLoc();
      string name;
      optional<Shape> shape;
      // An operand may be prefixed by an optional shape.
      if (CanBeShape()) {
        shape.emplace();
        if (!ParseShape(&shape.value())) {
          return false;
        }
      }
      if (!ParseName(&name)) {
        // When parsing a single instruction (as opposed to a whole module), an
        // HLO may have one or more operands with a shape but no name:
        //
        //  foo = add(f32[10], f32[10])
        //
        // create_missing_instruction_ is always non-null when parsing a single
        // instruction, and is responsible for creating kParameter instructions
        // for these operands.
        if (shape.has_value() && create_missing_instruction_ != nullptr &&
            scoped_name_tables_.size() == 1) {
          name = "";
        } else {
          return false;
        }
      }
      // An empty name triggers on-demand creation inside FindInstruction
      // (via create_missing_instruction_) when the shape is known.
      std::pair<HloInstruction*, LocTy>* instruction =
          FindInstruction(name, shape);
      if (instruction == nullptr) {
        return Error(loc, StrCat("instruction does not exist: ", name));
      }
      operands->push_back(instruction->first);
    } while (EatIfPresent(TokKind::kComma));
  }
  return ParseToken(TokKind::kRparen, "expects ')' at the end of operands");
}
bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands,
                              const int expected_size) {
  CHECK(operands != nullptr);
  const LocTy loc = lexer_.GetLoc();
  // Delegate the actual parsing, then validate the operand count.
  if (!ParseOperands(operands)) {
    return false;
  }
  if (operands->size() == expected_size) {
    return true;
  }
  return Error(loc, StrCat("expects ", expected_size, " operands, but has ",
                           operands->size(), " operands"));
}
// sub_attributes ::= '{' (','? attribute)* '}'
bool HloParser::ParseSubAttributes(
const std::unordered_map<string, AttrConfig>& attrs) {
LocTy loc = lexer_.GetLoc();
if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) {
return false;
}
std::unordered_set<string> seen_attrs;
if (lexer_.GetKind() == TokKind::kRbrace) {
// empty
} else {
do {
EatIfPresent(TokKind::kComma);
if (!ParseAttributeHelper(attrs, &seen_attrs)) {
return false;
}
} while (lexer_.GetKind() != TokKind::kRbrace);
}
// Check that all required attrs were seen.
for (const auto& attr_it : attrs) {
if (attr_it.second.required &&
seen_attrs.find(attr_it.first) == seen_attrs.end()) {
return Error(loc, StrFormat("sub-attribute %s is expected but not seen",
attr_it.first));
}
}
return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes");
}
// attributes ::= (',' attribute)*
bool HloParser::ParseAttributes(
const std::unordered_map<string, AttrConfig>& attrs) {
LocTy loc = lexer_.GetLoc();
std::unordered_set<string> seen_attrs;
while (EatIfPresent(TokKind::kComma)) {
if (!ParseAttributeHelper(attrs, &seen_attrs)) {
return false;
}
}
// Check that all required attrs were seen.
for (const auto& attr_it : attrs) {
if (attr_it.second.required &&
seen_attrs.find(attr_it.first) == seen_attrs.end()) {
return Error(loc, StrFormat("attribute %s is expected but not seen",
attr_it.first));
}
}
return true;
}
// Parses a single "name=value" attribute. `name` must be a key in `attrs`;
// the parsed value is written through the matching AttrConfig::result
// pointer, which is a type-erased optional<T>* whose T is determined by
// AttrConfig::attr_type. Inserts `name` into `seen_attrs` and rejects
// duplicates.
bool HloParser::ParseAttributeHelper(
    const std::unordered_map<string, AttrConfig>& attrs,
    std::unordered_set<string>* seen_attrs) {
  LocTy loc = lexer_.GetLoc();
  string name;
  if (!ParseAttributeName(&name)) {
    return Error(loc, "error parsing attributes");
  }
  VLOG(3) << "Parsing attribute " << name;
  // Each attribute may appear at most once.
  if (!seen_attrs->insert(name).second) {
    return Error(loc, StrFormat("attribute %s already exists", name));
  }
  auto attr_it = attrs.find(name);
  if (attr_it == attrs.end()) {
    // Unknown attribute: build a message listing the ones legal here.
    string allowed_attrs;
    if (attrs.empty()) {
      allowed_attrs = "No attributes are allowed here.";
    } else {
      allowed_attrs = StrCat(
          "Allowed attributes: ",
          StrJoin(attrs, ", ",
                  [&](string* out, const std::pair<string, AttrConfig>& kv) {
                    StrAppend(out, kv.first);
                  }));
    }
    return Error(loc, StrFormat("unexpected attribute \"%s\". %s", name,
                                allowed_attrs));
  }
  AttrTy attr_type = attr_it->second.attr_type;
  void* attr_out_ptr = attr_it->second.result;
  // Dispatch on the declared attribute type. Each case parses a value of
  // that type and stores it into the type-erased output slot; the cast must
  // match the optional<T> the caller registered for this AttrTy.
  bool success = [&] {
    LocTy attr_loc = lexer_.GetLoc();
    switch (attr_type) {
      case AttrTy::kBool: {
        bool result;
        if (!ParseBool(&result)) {
          return false;
        }
        static_cast<optional<bool>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kInt64: {
        int64 result;
        if (!ParseInt64(&result)) {
          return false;
        }
        static_cast<optional<int64>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kInt32: {
        // Parsed as int64, then range-checked before narrowing.
        int64 result;
        if (!ParseInt64(&result)) {
          return false;
        }
        if (result != static_cast<int32>(result)) {
          return Error(attr_loc, "value out of range for int32");
        }
        static_cast<optional<int32>*>(attr_out_ptr)
            ->emplace(static_cast<int32>(result));
        return true;
      }
      case AttrTy::kFloat: {
        // Parsed as double, then range-checked before narrowing.
        double result;
        if (!ParseDouble(&result)) {
          return false;
        }
        if (result > std::numeric_limits<float>::max() ||
            result < std::numeric_limits<float>::lowest()) {
          return Error(attr_loc, "value out of range for float");
        }
        static_cast<optional<float>*>(attr_out_ptr)
            ->emplace(static_cast<float>(result));
        return true;
      }
      case AttrTy::kHloComputation: {
        HloComputation* result = nullptr;
        if (!ParseHloComputation(&result)) {
          return false;
        }
        static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kBracedHloComputationList: {
        std::vector<HloComputation*> result;
        if (!ParseHloComputationList(&result)) {
          return false;
        }
        static_cast<optional<std::vector<HloComputation*>>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kFftType: {
        FftType result;
        if (!ParseFftType(&result)) {
          return false;
        }
        static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kComparisonDirection: {
        ComparisonDirection result;
        if (!ParseComparisonDirection(&result)) {
          return false;
        }
        static_cast<optional<ComparisonDirection>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kWindow: {
        Window result;
        if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) {
          return false;
        }
        static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kConvolutionDimensionNumbers: {
        ConvolutionDimensionNumbers result;
        if (!ParseConvolutionDimensionNumbers(&result)) {
          return false;
        }
        static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kSharding: {
        OpSharding sharding;
        if (!ParseSharding(&sharding)) {
          return false;
        }
        static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
        return true;
      }
      case AttrTy::kFrontendAttributes: {
        FrontendAttributes frontend_attributes;
        if (!ParseFrontendAttributes(&frontend_attributes)) {
          return false;
        }
        static_cast<optional<FrontendAttributes>*>(attr_out_ptr)
            ->emplace(frontend_attributes);
        return true;
      }
      case AttrTy::kParameterReplication: {
        ParameterReplication parameter_replication;
        if (!ParseParameterReplication(&parameter_replication)) {
          return false;
        }
        static_cast<optional<ParameterReplication>*>(attr_out_ptr)
            ->emplace(parameter_replication);
        return true;
      }
      case AttrTy::kInstructionList: {
        std::vector<HloInstruction*> result;
        if (!ParseInstructionNames(&result)) {
          return false;
        }
        static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kFusionKind: {
        HloInstruction::FusionKind result;
        if (!ParseFusionKind(&result)) {
          return false;
        }
        static_cast<optional<HloInstruction::FusionKind>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kBracedInt64List: {
        std::vector<int64> result;
        if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
                            &result)) {
          return false;
        }
        static_cast<optional<std::vector<int64>>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kBracedInt64ListList: {
        std::vector<std::vector<int64>> result;
        if (!ParseInt64ListList(TokKind::kLbrace, TokKind::kRbrace,
                                TokKind::kComma, &result)) {
          return false;
        }
        static_cast<optional<std::vector<std::vector<int64>>>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kSliceRanges: {
        SliceRanges result;
        if (!ParseSliceRanges(&result)) {
          return false;
        }
        static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kPaddingConfig: {
        PaddingConfig result;
        if (!ParsePaddingConfig(&result)) {
          return false;
        }
        static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kString: {
        string result;
        if (!ParseString(&result)) {
          return false;
        }
        static_cast<optional<string>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kMetadata: {
        OpMetadata result;
        if (!ParseMetadata(&result)) {
          return false;
        }
        static_cast<optional<OpMetadata>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kDistribution: {
        RandomDistribution result;
        if (!ParseRandomDistribution(&result)) {
          return false;
        }
        static_cast<optional<RandomDistribution>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kDomain: {
        // kDomain writes directly into a DomainData, not an optional<T>.
        return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
      }
      case AttrTy::kPrecisionList: {
        std::vector<PrecisionConfig::Precision> result;
        if (!ParsePrecisionList(&result)) {
          return false;
        }
        static_cast<optional<std::vector<PrecisionConfig::Precision>>*>(
            attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kShapeList: {
        std::vector<Shape> result;
        if (!ParseShapeList(&result)) {
          return false;
        }
        static_cast<optional<std::vector<Shape>>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
    }
  }();
  if (!success) {
    return Error(loc, StrFormat("error parsing attribute %s", name));
  }
  return true;
}
// attributes ::= (',' attribute)*
bool HloParser::ParseAttributesAsProtoMessage(
const std::unordered_set<string>& required_attrs,
tensorflow::protobuf::Message* message) {
LocTy loc = lexer_.GetLoc();
std::unordered_set<string> seen_attrs;
while (EatIfPresent(TokKind::kComma)) {
if (!ParseAttributeAsProtoMessageHelper(message, &seen_attrs)) {
return false;
}
}
// Check that all required attrs were seen.
for (const string& attr : required_attrs) {
if (seen_attrs.find(attr) == seen_attrs.end()) {
return Error(loc,
StrFormat("attribute %s is expected but not seen", attr));
}
}
return true;
}
// Parses a single "name=value" attribute where `name` must match a singular
// field of `message`; the value is written into the proto via reflection.
// Only bool and enum fields are currently supported.
bool HloParser::ParseAttributeAsProtoMessageHelper(
    tensorflow::protobuf::Message* message,
    std::unordered_set<string>* seen_attrs) {
  LocTy loc = lexer_.GetLoc();
  string name;
  if (!ParseAttributeName(&name)) {
    return Error(loc, "error parsing attributes");
  }
  VLOG(3) << "Parsing attribute " << name;
  // Each attribute may appear at most once.
  if (!seen_attrs->insert(name).second) {
    return Error(loc, StrFormat("attribute %s already exists", name));
  }
  const tensorflow::protobuf::Descriptor* descriptor = message->GetDescriptor();
  const tensorflow::protobuf::FieldDescriptor* fd =
      descriptor->FindFieldByName(name);
  if (!fd) {
    // Unknown field: list the proto's fields in the error message.
    string allowed_attrs = "Allowed attributes: ";
    for (int i = 0; i < descriptor->field_count(); ++i) {
      if (i == 0) {
        absl::StrAppend(&allowed_attrs, descriptor->field(i)->name());
      } else {
        absl::StrAppend(&allowed_attrs, ", ", descriptor->field(i)->name());
      }
    }
    return Error(loc, StrFormat("unexpected attribute \"%s\". %s", name,
                                allowed_attrs));
  }
  const tensorflow::protobuf::Reflection* reflection = message->GetReflection();
  CHECK(!fd->is_repeated());  // Repeated fields not implemented.
  bool success = [&] {
    switch (fd->type()) {
      case tensorflow::protobuf::FieldDescriptor::TYPE_BOOL: {
        bool result;
        if (!ParseBool(&result)) {
          return false;
        }
        reflection->SetBool(message, fd, result);
        return true;
      }
      case tensorflow::protobuf::FieldDescriptor::TYPE_ENUM: {
        // Enum values are spelled as bare identifiers matching the enum
        // value names.
        if (lexer_.GetKind() != TokKind::kIdent) {
          return TokenError(
              StrFormat("expects %s type", fd->enum_type()->name()));
        }
        string val = lexer_.GetStrVal();
        const tensorflow::protobuf::EnumValueDescriptor* evd =
            fd->enum_type()->FindValueByName(val);
        if (evd == nullptr) {
          return TokenError(StrFormat("expects %s type but sees: %s",
                                      fd->enum_type()->name(), val));
        }
        reflection->SetEnum(message, fd, evd);
        lexer_.Lex();
        return true;
      }
      default:
        LOG(ERROR) << "Unimplemented protocol buffer type "
                   << fd->DebugString();
        return false;
    }
  }();
  if (!success) {
    return Error(loc, StrFormat("error parsing attribute %s", name));
  }
  return true;
}
// Parses a computation name and resolves it against the computations parsed so
// far; stores the resolved computation in `value`.
bool HloParser::ParseComputationName(HloComputation** value) {
  const LocTy name_loc = lexer_.GetLoc();
  string computation_name;
  if (!ParseName(&computation_name)) {
    return Error(name_loc, "expects computation name");
  }
  auto* entry =
      tensorflow::gtl::FindOrNull(computation_pool_, computation_name);
  if (entry == nullptr) {
    return Error(name_loc,
                 StrCat("computation does not exist: ", computation_name));
  }
  *value = entry->first;
  return true;
}
// ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}'
// The subattributes can appear in any order. 'size=' is required, others are
// optional. Fills `window` with one WindowDimension per entry of 'size='.
bool HloParser::ParseWindow(Window* window, bool expect_outer_curlies) {
  LocTy loc = lexer_.GetLoc();
  if (expect_outer_curlies &&
      !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
    return false;
  }
  std::vector<int64> size;
  std::vector<int64> stride;
  std::vector<std::vector<int64>> pad;
  std::vector<int64> lhs_dilate;
  std::vector<int64> rhs_dilate;
  std::vector<int64> rhs_reversal;
  const auto end_token =
      expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof;
  while (lexer_.GetKind() != end_token) {
    LocTy attr_loc = lexer_.GetLoc();
    string field_name;
    if (!ParseAttributeName(&field_name)) {
      return Error(attr_loc, "expects sub-attributes in window");
    }
    bool ok = [&] {
      if (field_name == "size") {
        return ParseDxD("size", &size);
      }
      if (field_name == "stride") {
        return ParseDxD("stride", &stride);
      }
      if (field_name == "lhs_dilate") {
        return ParseDxD("lhs_dilate", &lhs_dilate);
      }
      if (field_name == "rhs_dilate") {
        // Bug fix: this previously passed the misspelled name "rls_dilate",
        // which leaked into user-facing error messages.
        return ParseDxD("rhs_dilate", &rhs_dilate);
      }
      if (field_name == "pad") {
        return ParseWindowPad(&pad);
      }
      if (field_name == "rhs_reversal") {
        return ParseDxD("rhs_reversal", &rhs_reversal);
      }
      return Error(attr_loc, StrCat("unexpected attribute name: ", field_name));
    }();
    if (!ok) {
      return false;
    }
  }
  // 'size=' is mandatory; every other sub-attribute, when present, must have
  // the same number of entries as 'size='.
  if (size.empty()) {
    return Error(loc,
                 "sub-attribute 'size=' is required in the window attribute");
  }
  if (!stride.empty() && stride.size() != size.size()) {
    return Error(loc, "expects 'stride=' has the same size as 'size='");
  }
  if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) {
    return Error(loc, "expects 'lhs_dilate=' has the same size as 'size='");
  }
  if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) {
    return Error(loc, "expects 'rhs_dilate=' has the same size as 'size='");
  }
  if (!pad.empty() && pad.size() != size.size()) {
    return Error(loc, "expects 'pad=' has the same size as 'size='");
  }
  for (int i = 0; i < size.size(); i++) {
    window->add_dimensions()->set_size(size[i]);
    if (!pad.empty()) {
      window->mutable_dimensions(i)->set_padding_low(pad[i][0]);
      window->mutable_dimensions(i)->set_padding_high(pad[i][1]);
    }
    // If some field is not present, it has the default value.
    window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]);
    window->mutable_dimensions(i)->set_base_dilation(
        lhs_dilate.empty() ? 1 : lhs_dilate[i]);
    window->mutable_dimensions(i)->set_window_dilation(
        rhs_dilate.empty() ? 1 : rhs_dilate[i]);
    window->mutable_dimensions(i)->set_window_reversal(
        rhs_reversal.empty() ? false : (rhs_reversal[i] == 1));
  }
  return !expect_outer_curlies ||
         ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
}
// This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
// The string looks like "dim_labels=0bf_0io->0bf".
bool HloParser::ParseConvolutionDimensionNumbers(
    ConvolutionDimensionNumbers* dnums) {
  if (lexer_.GetKind() != TokKind::kDimLabels) {
    return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'");
  }
  string str = lexer_.GetStrVal();
  // The str is expected to have 3 items, lhs, rhs, out, and it must look like
  // lhs_rhs->out, that is, the first separator is "_" and the second is "->".
  // Bug fix: malformed input used to hit LOG(FATAL) and crash the process;
  // report a parse error instead, consistent with the rest of this parser.
  std::vector<string> split1 = absl::StrSplit(str, '_');
  if (split1.size() != 2) {
    return TokenError(
        StrCat("expects 3 items: lhs, rhs, and output dims, but sees ", str));
  }
  std::vector<string> split2 = absl::StrSplit(split1[1], "->");
  if (split2.size() != 2) {
    return TokenError(
        StrCat("expects 3 items: lhs, rhs, and output dims, but sees ", str));
  }
  absl::string_view lhs = split1[0];
  absl::string_view rhs = split2[0];
  absl::string_view out = split2[1];
  // All three label strings must name the same number of dimensions.
  const int64 rank = lhs.length();
  if (rank != rhs.length() || rank != out.length()) {
    return TokenError(
        "convolution lhs, rhs, and output must have the same rank");
  }
  if (rank < 2) {
    return TokenError("convolution rank must >=2");
  }
  // Each dimension label may appear at most once per operand.
  auto is_unique = [](string str) -> bool {
    absl::c_sort(str);
    return std::unique(str.begin(), str.end()) == str.end();
  };
  // lhs
  {
    if (!is_unique(string(lhs))) {
      return TokenError(
          StrCat("expects unique lhs dimension numbers, but sees ", lhs));
    }
    // Pre-size the spatial dims; they are filled by label position below.
    for (int i = 0; i < rank - 2; i++) {
      dnums->add_input_spatial_dimensions(-1);
    }
    for (int i = 0; i < rank; i++) {
      char c = lhs[i];
      if (c == 'b') {
        dnums->set_input_batch_dimension(i);
      } else if (c == 'f') {
        dnums->set_input_feature_dimension(i);
      } else if (c < '0' + rank && c >= '0') {
        dnums->set_input_spatial_dimensions(c - '0', i);
      } else {
        return TokenError(
            StrFormat("expects [0-%dbf] in lhs dimension numbers", rank - 1));
      }
    }
  }
  // rhs
  {
    if (!is_unique(string(rhs))) {
      return TokenError(
          StrCat("expects unique rhs dimension numbers, but sees ", rhs));
    }
    for (int i = 0; i < rank - 2; i++) {
      dnums->add_kernel_spatial_dimensions(-1);
    }
    for (int i = 0; i < rank; i++) {
      char c = rhs[i];
      if (c == 'i') {
        dnums->set_kernel_input_feature_dimension(i);
      } else if (c == 'o') {
        dnums->set_kernel_output_feature_dimension(i);
      } else if (c < '0' + rank && c >= '0') {
        dnums->set_kernel_spatial_dimensions(c - '0', i);
      } else {
        return TokenError(
            StrFormat("expects [0-%dio] in rhs dimension numbers", rank - 1));
      }
    }
  }
  // output
  {
    if (!is_unique(string(out))) {
      return TokenError(
          StrCat("expects unique output dimension numbers, but sees ", out));
    }
    for (int i = 0; i < rank - 2; i++) {
      dnums->add_output_spatial_dimensions(-1);
    }
    for (int i = 0; i < rank; i++) {
      char c = out[i];
      if (c == 'b') {
        dnums->set_output_batch_dimension(i);
      } else if (c == 'f') {
        dnums->set_output_feature_dimension(i);
      } else if (c < '0' + rank && c >= '0') {
        dnums->set_output_spatial_dimensions(c - '0', i);
      } else {
        return TokenError(StrFormat(
            "expects [0-%dbf] in output dimension numbers", rank - 1));
      }
    }
  }
  lexer_.Lex();
  return true;
}
// ::= '{' ranges '}'
//   ::= /*empty*/
//   ::= range (',' range)*
// range ::= '[' start ':' limit (':' stride)? ']'
//
// The slice ranges are printed as:
//
//  {[dim0_start:dim0_limit:dim0stride], [dim1_start:dim1_limit], ...}
//
// This function extracts the starts, limits, and strides as 3 vectors to the
// result. If stride is not present, stride is 1. For example, if the slice
// ranges is printed as:
//
//  {[2:3:4], [5:6:7], [8:9]}
//
// The parsed result will be:
//
//  {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}}
//
bool HloParser::ParseSliceRanges(SliceRanges* result) {
  if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) {
    return false;
  }
  std::vector<std::vector<int64>> ranges;
  if (lexer_.GetKind() == TokKind::kRbrace) {
    // empty
    return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
  }
  do {
    LocTy loc = lexer_.GetLoc();
    ranges.emplace_back();
    // Each range is a ':'-separated list inside square brackets.
    if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon,
                        &ranges.back())) {
      return false;
    }
    const auto& range = ranges.back();
    // Only [start:limit] and [start:limit:stride] are legal.
    if (range.size() != 2 && range.size() != 3) {
      return Error(loc,
                   StrFormat("expects [start:limit:step] or [start:limit], "
                             "but sees %d elements.",
                             range.size()));
    }
  } while (EatIfPresent(TokKind::kComma));
  // Transpose the parsed per-dimension triples into the three output vectors.
  for (const auto& range : ranges) {
    result->starts.push_back(range[0]);
    result->limits.push_back(range[1]);
    result->strides.push_back(range.size() == 3 ? range[2] : 1);
  }
  return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
}
// precisionlist ::= start precision_elements end
// precision_elements
//   ::= /*empty*/
//   ::= precision_val (delim precision_val)*
bool HloParser::ParsePrecisionList(
    std::vector<PrecisionConfig::Precision>* result) {
  // Parses one precision value and appends it to the output list.
  auto append_one = [&] {
    PrecisionConfig::Precision precision;
    if (!ParsePrecision(&precision)) {
      return false;
    }
    result->push_back(precision);
    return true;
  };
  return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
                   append_one);
}
// Parses either a nested computation (delimited by '{' ... '}') or a reference
// to a previously-defined computation by name.
bool HloParser::ParseHloComputation(HloComputation** result) {
  if (lexer_.GetKind() != TokKind::kLbrace) {
    // A bare identifier refers to an existing computation by name.
    return ParseComputationName(result);
  }
  // Otherwise parse the nested instruction list in-line.
  return ParseInstructionList(result, /*computation_name=*/"_");
}
// Parses a '{'-delimited, comma-separated list of computations (nested or
// by name) into `result`.
bool HloParser::ParseHloComputationList(std::vector<HloComputation*>* result) {
  auto parse_and_add_item = [&]() {
    HloComputation* computation;
    if (!ParseHloComputation(&computation)) {
      return false;
    }
    // Log fix: this was an unconditional LOG(INFO), which spammed the log for
    // every computation parsed; use VLOG(3) like the other parse helpers.
    VLOG(3) << "parsed computation " << computation->name();
    result->push_back(computation);
    return true;
  };
  return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
                   parse_and_add_item);
}
// shapelist ::= '{' shapes '}'
// precision_elements
// ::= /*empty*/
// ::= shape (',' shape)*
bool HloParser::ParseShapeList(std::vector<Shape>* result) {
auto parse_and_add_item = [&]() {
Shape shape;
if (!ParseShape(&shape)) {
return false;
}
result->push_back(std::move(shape));
return true;
};
return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
parse_and_add_item);
}
// int64list ::= start int64_elements end
// int64_elements
// ::= /*empty*/
// ::= int64_val (delim int64_val)*
bool HloParser::ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim,
std::vector<int64>* result) {
auto parse_and_add_item = [&]() {
int64 i;
if (!ParseInt64(&i)) {
return false;
}
result->push_back(i);
return true;
};
return ParseList(start, end, delim, parse_and_add_item);
}
// int64listlist ::= start int64list_elements end
// int64list_elements
//   ::= /*empty*/
//   ::= int64list (delim int64list)*
// int64list ::= start int64_elements end
// int64_elements
//   ::= /*empty*/
//   ::= int64_val (delim int64_val)*
bool HloParser::ParseInt64ListList(const TokKind start, const TokKind end,
                                   const TokKind delim,
                                   std::vector<std::vector<int64>>* result) {
  // Parses one inner int64 list and moves it into the output list.
  auto append_sublist = [&] {
    std::vector<int64> sublist;
    if (!ParseInt64List(start, end, delim, &sublist)) {
      return false;
    }
    result->push_back(std::move(sublist));
    return true;
  };
  return ParseList(start, end, delim, append_sublist);
}
// Generic list parser: consumes `start`, then zero or more items separated by
// `delim` (each handled by `parse_and_add_item`), then `end`.
bool HloParser::ParseList(const TokKind start, const TokKind end,
                          const TokKind delim,
                          const std::function<bool()>& parse_and_add_item) {
  if (!ParseToken(start, StrCat("expects a list starting with ",
                                TokKindToString(start)))) {
    return false;
  }
  if (lexer_.GetKind() != end) {
    // Non-empty list: parse items until no delimiter follows.
    do {
      if (!parse_and_add_item()) {
        return false;
      }
    } while (EatIfPresent(delim));
  }
  return ParseToken(
      end, StrCat("expects a list to end with ", TokKindToString(end)));
}
// param_list_to_shape ::= param_list '->' shape
bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) {
  if (!ParseParamList()) {
    return false;
  }
  if (!ParseToken(TokKind::kArrow, "expects '->'")) {
    return false;
  }
  // Remember where the result shape starts, for later error reporting.
  *shape_loc = lexer_.GetLoc();
  return ParseShape(shape);
}
// Returns true if the upcoming tokens could be a param-list-to-shape, which
// always begins with '('.
bool HloParser::CanBeParamListToShape() {
  return lexer_.GetKind() == TokKind::kLparen;
}
// param_list ::= '(' param_list1 ')'
// param_list1
// ::= /*empty*/
// ::= param (',' param)*
// param ::= name shape
bool HloParser::ParseParamList() {
if (!ParseToken(TokKind::kLparen,
"expects '(' at the beginning of param list")) {
return false;
}
if (lexer_.GetKind() == TokKind::kRparen) {
// empty
} else {
do {
Shape shape;
string name;
if (!ParseName(&name) || !ParseShape(&shape)) {
return false;
}
} while (EatIfPresent(TokKind::kComma));
}
return ParseToken(TokKind::kRparen, "expects ')' at the end of param list");
}
// dimension_sizes ::= '[' dimension_list ']'
// dimension_list
// ::= /*empty*/
// ::= <=? int64 (',' param)*
// param ::= name shape
bool HloParser::ParseDimensionSizes(std::vector<int64>* dimension_sizes,
std::vector<bool>* dynamic_dimensions) {
auto parse_and_add_item = [&]() {
int64 i;
bool is_dynamic = false;
if (lexer_.GetKind() == TokKind::kLeq) {
is_dynamic = true;
lexer_.Lex();
}
if (!ParseInt64(&i)) {
return false;
}
dimension_sizes->push_back(i);
dynamic_dimensions->push_back(is_dynamic);
return true;
};
return ParseList(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma,
parse_and_add_item);
}
// tiles
// ::= /*empty*/
// ::= 'T' '(' dim_list ')'
// dim_list
// ::= /*empty*/
// ::= (int64 | '*') (',' (int64 | '*'))*
bool HloParser::ParseTiles(std::vector<Tile>* tiles) {
auto parse_and_add_tile_dimension = [&]() {
tensorflow::int64 i;
if (ParseInt64(&i)) {
tiles->back().add_dimensions(i);
return true;
}
if (lexer_.GetKind() == TokKind::kAsterisk) {
tiles->back().add_dimensions(Tile::kCombineDimension);
lexer_.Lex();
return true;
}
return false;
};
do {
tiles->push_back(Tile());
if (!ParseList(TokKind::kLparen, TokKind::kRparen, TokKind::kComma,
parse_and_add_tile_dimension)) {
return false;
}
} while (lexer_.GetKind() == TokKind::kLparen);
return true;
}
// int_attribute
//   ::= /*empty*/
//   ::= attr_token '(' attr_value ')'
// attr_token
//   ::= 'E' | 'S'
// attr_value
//   ::= int64
bool HloParser::ParseLayoutIntAttribute(int64* attr_value,
                                        absl::string_view attr_description) {
  // `attr_description` only feeds the error messages.
  if (!ParseToken(TokKind::kLparen,
                  StrCat("expects ", attr_description, " to start with ",
                         TokKindToString(TokKind::kLparen)))) {
    return false;
  }
  if (!ParseInt64(attr_value)) {
    return false;
  }
  return ParseToken(TokKind::kRparen,
                    StrCat("expects ", attr_description, " to end with ",
                           TokKindToString(TokKind::kRparen)));
}
// layout ::= '{' int64_list (':' tiles element_size_in_bits memory_space)? '}'
// element_size_in_bits
//   ::= /*empty*/
//   ::= 'E' '(' int64 ')'
// memory_space
//   ::= /*empty*/
//   ::= 'S' '(' int64 ')'
bool HloParser::ParseLayout(Layout* layout) {
  std::vector<int64> minor_to_major;
  std::vector<Tile> tiles;
  tensorflow::int64 element_size_in_bits = 0;
  tensorflow::int64 memory_space = 0;
  // Appends one minor-to-major dimension index.
  auto parse_and_add_item = [&]() {
    int64 i;
    if (!ParseInt64(&i)) {
      return false;
    }
    minor_to_major.push_back(i);
    return true;
  };
  if (!ParseToken(TokKind::kLbrace,
                  StrCat("expects layout to start with ",
                         TokKindToString(TokKind::kLbrace)))) {
    return false;
  }
  if (lexer_.GetKind() != TokKind::kRbrace) {
    if (lexer_.GetKind() == TokKind::kInt) {
      // Parse minor to major.
      do {
        if (!parse_and_add_item()) {
          return false;
        }
      } while (EatIfPresent(TokKind::kComma));
    }
    // Optional ':'-prefixed attributes: tiles ('T'), element size in bits
    // ('E'), and memory space ('S'), in that order.
    if (lexer_.GetKind() == TokKind::kColon) {
      lexer_.Lex();
      if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "T") {
        lexer_.Lex();
        // Bug fix: the return values of these sub-parsers used to be ignored,
        // silently swallowing parse errors; propagate failures instead.
        if (!ParseTiles(&tiles)) {
          return false;
        }
      }
      if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "E") {
        lexer_.Lex();
        if (!ParseLayoutIntAttribute(&element_size_in_bits,
                                     "element size in bits")) {
          return false;
        }
      }
      if (lexer_.GetKind() == TokKind::kIdent && lexer_.GetStrVal() == "S") {
        lexer_.Lex();
        if (!ParseLayoutIntAttribute(&memory_space, "memory space")) {
          return false;
        }
      }
    }
  }
  if (!ParseToken(TokKind::kRbrace,
                  StrCat("expects layout to end with ",
                         TokKindToString(TokKind::kRbrace)))) {
    return false;
  }
  // `tiles` is already a std::vector<Tile>; pass it directly instead of the
  // previous redundant element-by-element copy into a second vector.
  *layout = LayoutUtil::MakeLayout(minor_to_major, tiles, element_size_in_bits,
                                   memory_space);
  return true;
}
// shape ::= shape_val_
// shape ::= '(' tuple_elements ')'
// tuple_elements
//   ::= /*empty*/
//   ::= shape (',' shape)*
bool HloParser::ParseShape(Shape* result) {
  if (EatIfPresent(TokKind::kLparen)) {  // Tuple
    std::vector<Shape> shapes;
    if (lexer_.GetKind() == TokKind::kRparen) {
      /*empty*/
    } else {
      // shape (',' shape)*
      do {
        shapes.emplace_back();
        if (!ParseShape(&shapes.back())) {
          return false;
        }
      } while (EatIfPresent(TokKind::kComma));
    }
    *result = ShapeUtil::MakeTupleShape(shapes);
    return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple.");
  }

  // Non-tuple shape: primitive type followed by dimension sizes.
  if (lexer_.GetKind() != TokKind::kPrimitiveType) {
    return TokenError(absl::StrCat("expected primitive type, saw ",
                                   TokKindToString(lexer_.GetKind())));
  }
  PrimitiveType primitive_type = lexer_.GetPrimitiveTypeVal();
  lexer_.Lex();

  // Each element contains a dimension size and a bool indicating whether this
  // is a dynamic dimension.
  std::vector<int64> dimension_sizes;
  std::vector<bool> dynamic_dimensions;
  if (!ParseDimensionSizes(&dimension_sizes, &dynamic_dimensions)) {
    return false;
  }
  result->set_element_type(primitive_type);
  for (int i = 0; i < dimension_sizes.size(); ++i) {
    result->add_dimensions(dimension_sizes[i]);
    result->set_dynamic_dimension(i, dynamic_dimensions[i]);
  }
  // Start from the default layout; an explicit sparse or dense layout below
  // overrides it.
  LayoutUtil::SetToDefaultLayout(result);

  if (lexer_.GetKind() == TokKind::kw_sparse) {
    lexer_.Lex();
    const string message =
        "expects a brace-bracketed integer for sparse layout";
    int64 max_sparse_elements;
    if (!ParseToken(TokKind::kLbrace, message) ||
        !ParseInt64(&max_sparse_elements) ||
        !ParseToken(TokKind::kRbrace, message)) {
      return false;
    }
    *result->mutable_layout() =
        LayoutUtil::MakeSparseLayout(max_sparse_elements);
    return true;
  }

  // We need to lookahead to see if a following open brace is the start of a
  // layout. The specific problematic case is:
  //
  // ENTRY %foo (x: f32[42]) -> f32[123] {
  //  ...
  // }
  //
  // The open brace could either be the start of a computation or the start of a
  // layout for the f32[123] shape. We consider it the start of a layout if the
  // next token after the open brace is an integer or a colon.
  if (lexer_.GetKind() == TokKind::kLbrace &&
      (lexer_.LookAhead() == TokKind::kInt ||
       lexer_.LookAhead() == TokKind::kColon)) {
    Layout layout;
    if (!ParseLayout(&layout)) {
      return false;
    }
    // The layout must cover exactly the dimensions parsed above.
    if (layout.minor_to_major_size() != result->rank()) {
      return Error(
          lexer_.GetLoc(),
          StrFormat("Dimensions size is %ld, but minor to major size is %ld.",
                    result->rank(), layout.minor_to_major_size()));
    }
    *result->mutable_layout() = layout;
  }
  return true;
}
// Returns true if the upcoming tokens could be a shape.
bool HloParser::CanBeShape() {
  // A non-tuple shape starts with a kPrimitiveType token; a tuple shape starts
  // with '('.
  return lexer_.GetKind() == TokKind::kPrimitiveType ||
         lexer_.GetKind() == TokKind::kLparen;
}
// Parses an identifier or a %-prefixed name token into `result`.
bool HloParser::ParseName(string* result) {
  VLOG(3) << "ParseName";
  const TokKind kind = lexer_.GetKind();
  if (kind != TokKind::kIdent && kind != TokKind::kName) {
    return TokenError("expects name");
  }
  *result = lexer_.GetStrVal();
  lexer_.Lex();
  return true;
}
// Parses an attribute-name token ("name=") and stores the name in `result`.
bool HloParser::ParseAttributeName(string* result) {
  if (lexer_.GetKind() == TokKind::kAttributeName) {
    *result = lexer_.GetStrVal();
    lexer_.Lex();
    return true;
  }
  return TokenError("expects attribute name");
}
// Parses a quoted string token into `result`.
bool HloParser::ParseString(string* result) {
  VLOG(3) << "ParseString";
  if (lexer_.GetKind() == TokKind::kString) {
    *result = lexer_.GetStrVal();
    lexer_.Lex();
    return true;
  }
  return TokenError("expects string");
}
// Parses either a single integer ("3") or an 'x'-separated list ("2x3x4")
// into `result`. `name` is used only for error messages.
bool HloParser::ParseDxD(const string& name, std::vector<int64>* result) {
  LocTy loc = lexer_.GetLoc();
  if (!result->empty()) {
    return Error(loc, StrFormat("sub-attribute '%s=' already exists", name));
  }
  switch (lexer_.GetKind()) {
    // 1D: a plain integer.
    case TokKind::kInt: {
      int64 number;
      if (!ParseInt64(&number)) {
        return Error(loc, StrFormat("expects sub-attribute '%s=i'", name));
      }
      result->push_back(number);
      return true;
    }
    // 2D or higher: an 'x'-separated pattern.
    case TokKind::kDxD: {
      string str = lexer_.GetStrVal();
      if (!SplitToInt64s(str, 'x', result)) {
        return Error(loc, StrFormat("expects sub-attribute '%s=ixj...'", name));
      }
      lexer_.Lex();
      return true;
    }
    default:
      return TokenError("expects token type kInt or kDxD");
  }
}
bool HloParser::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
LocTy loc = lexer_.GetLoc();
if (!pad->empty()) {
return Error(loc, "sub-attribute 'pad=' already exists");
}
if (lexer_.GetKind() != TokKind::kPad) {
return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
}
string str = lexer_.GetStrVal();
for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) {
std::vector<int64> low_high;
if (!SplitToInt64s(padding_dim_str, '_', &low_high) ||
low_high.size() != 2) {
return Error(loc,
"expects padding_low and padding_high separated by '_'");
}
pad->push_back(low_high);
}
lexer_.Lex();
return true;
}
// This is the inverse xla::ToString(PaddingConfig). The padding config string
// looks like "0_0_0x3_3_1". The string is first separated by 'x', each
// substring represents one PaddingConfigDimension. The substring is 3 (or 2)
// numbers joined by '_'.
bool HloParser::ParsePaddingConfig(PaddingConfig* padding) {
  if (lexer_.GetKind() != TokKind::kPad) {
    return TokenError("expects padding config, e.g., '0_0_0x3_3_1'");
  }
  LocTy loc = lexer_.GetLoc();
  const string config_str = lexer_.GetStrVal();
  for (const auto& dim_str : absl::StrSplit(config_str, 'x')) {
    std::vector<int64> values;
    const bool parsed = SplitToInt64s(dim_str, '_', &values) &&
                        (values.size() == 2 || values.size() == 3);
    if (!parsed) {
      return Error(loc,
                   "expects padding config pattern like 'low_high_interior' or "
                   "'low_high'");
    }
    auto* dim = padding->add_dimensions();
    dim->set_edge_padding_low(values[0]);
    dim->set_edge_padding_high(values[1]);
    // Interior padding defaults to 0 when only 'low_high' is given.
    dim->set_interior_padding(values.size() == 3 ? values[2] : 0);
  }
  lexer_.Lex();
  return true;
}
// '{' metadata_string '}'
// Parses op metadata sub-attributes (op_type, op_name, source_file,
// source_line); all four are optional.
bool HloParser::ParseMetadata(OpMetadata* metadata) {
  optional<string> op_type;
  optional<string> op_name;
  optional<string> source_file;
  optional<int32> source_line;
  std::unordered_map<string, AttrConfig> attrs;
  attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type};
  attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name};
  attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file};
  attrs["source_line"] = {/*required=*/false, AttrTy::kInt32, &source_line};
  if (!ParseSubAttributes(attrs)) {
    return false;
  }
  // Only fields that were present in the text are written to the proto.
  if (op_type.has_value()) {
    metadata->set_op_type(*op_type);
  }
  if (op_name.has_value()) {
    metadata->set_op_name(*op_name);
  }
  if (source_file.has_value()) {
    metadata->set_source_file(*source_file);
  }
  if (source_line.has_value()) {
    metadata->set_source_line(*source_line);
  }
  return true;
}
// Parses an opcode identifier into an HloOpcode.
bool HloParser::ParseOpcode(HloOpcode* result) {
  VLOG(3) << "ParseOpcode";
  if (lexer_.GetKind() != TokKind::kIdent) {
    return TokenError("expects opcode");
  }
  const string opcode_name = lexer_.GetStrVal();
  auto parsed = StringToHloOpcode(opcode_name);
  if (!parsed.ok()) {
    return TokenError(StrFormat("expects opcode but sees: %s, error: %s",
                                opcode_name,
                                parsed.status().error_message()));
  }
  *result = parsed.ValueOrDie();
  lexer_.Lex();
  return true;
}
// Parses an FftType enum name into `result`.
bool HloParser::ParseFftType(FftType* result) {
  VLOG(3) << "ParseFftType";
  if (lexer_.GetKind() != TokKind::kIdent) {
    return TokenError("expects fft type");
  }
  const string fft_name = lexer_.GetStrVal();
  if (!FftType_Parse(fft_name, result) || !FftType_IsValid(*result)) {
    return TokenError(StrFormat("expects fft type but sees: %s", fft_name));
  }
  lexer_.Lex();
  return true;
}
// Parses a comparison direction identifier into `result`.
bool HloParser::ParseComparisonDirection(ComparisonDirection* result) {
  // Consistency fix: use VLOG(3) like every other Parse* helper in this file
  // (this was VLOG(1), making this one helper unusually noisy).
  VLOG(3) << "ParseComparisonDirection";
  if (lexer_.GetKind() != TokKind::kIdent) {
    return TokenError("expects comparison direction");
  }
  string val = lexer_.GetStrVal();
  auto status_or_result = StringToComparisonDirection(val);
  if (!status_or_result.ok()) {
    return TokenError(
        StrFormat("expects comparison direction but sees: %s", val));
  }
  *result = status_or_result.ValueOrDie();
  lexer_.Lex();
  return true;
}
// Parses a fusion kind identifier into `result`.
bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) {
  VLOG(3) << "ParseFusionKind";
  if (lexer_.GetKind() != TokKind::kIdent) {
    return TokenError("expects fusion kind");
  }
  const string kind_name = lexer_.GetStrVal();
  auto parsed = StringToFusionKind(kind_name);
  if (!parsed.ok()) {
    return TokenError(StrFormat("expects fusion kind but sees: %s, error: %s",
                                kind_name,
                                parsed.status().error_message()));
  }
  *result = parsed.ValueOrDie();
  lexer_.Lex();
  return true;
}
// Parses a random distribution identifier into `result`.
bool HloParser::ParseRandomDistribution(RandomDistribution* result) {
  VLOG(3) << "ParseRandomDistribution";
  if (lexer_.GetKind() != TokKind::kIdent) {
    return TokenError("expects random distribution");
  }
  const string distribution_name = lexer_.GetStrVal();
  auto parsed = StringToRandomDistribution(distribution_name);
  if (!parsed.ok()) {
    return TokenError(
        StrFormat("expects random distribution but sees: %s, error: %s",
                  distribution_name, parsed.status().error_message()));
  }
  *result = parsed.ValueOrDie();
  lexer_.Lex();
  return true;
}
// Parses a PrecisionConfig::Precision identifier into `result`.
bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) {
  VLOG(3) << "ParsePrecision";
  if (lexer_.GetKind() != TokKind::kIdent) {
    // Bug fix: the message here was copy-pasted from ParseRandomDistribution
    // and wrongly said "expects random distribution".
    return TokenError("expects precision");
  }
  string val = lexer_.GetStrVal();
  auto status_or_result = StringToPrecision(val);
  if (!status_or_result.ok()) {
    return TokenError(StrFormat("expects precision but sees: %s, error: %s",
                                val,
                                status_or_result.status().error_message()));
  }
  *result = status_or_result.ValueOrDie();
  lexer_.Lex();
  return true;
}
// Parses a single integer literal into `result`.
bool HloParser::ParseInt64(int64* result) {
  VLOG(3) << "ParseInt64";
  if (lexer_.GetKind() == TokKind::kInt) {
    *result = lexer_.GetInt64Val();
    lexer_.Lex();
    return true;
  }
  return TokenError("expects integer");
}
// Parses a floating-point literal into `result`. Accepts decimal literals,
// integer literals (converted to double), and the keywords nan/inf/-inf.
bool HloParser::ParseDouble(double* result) {
  switch (lexer_.GetKind()) {
    case TokKind::kDecimal: {
      double val = lexer_.GetDecimalVal();
      // If GetDecimalVal returns +/-inf, that means that we overflowed
      // `double`.
      if (std::isinf(val)) {
        return TokenError(StrCat("Constant is out of range for double (+/-",
                                 std::numeric_limits<double>::max(),
                                 ") and so is unparsable."));
      }
      *result = val;
      break;
    }
    case TokKind::kInt:
      *result = static_cast<double>(lexer_.GetInt64Val());
      break;
    case TokKind::kw_nan:
      *result = std::numeric_limits<double>::quiet_NaN();
      break;
    case TokKind::kw_inf:
      *result = std::numeric_limits<double>::infinity();
      break;
    case TokKind::kNegInf:
      *result = -std::numeric_limits<double>::infinity();
      break;
    default:
      return TokenError("expects decimal or integer");
  }
  // Only consume the token on success; error paths above return before this.
  lexer_.Lex();
  return true;
}
// Parses a complex literal of the form '(' real ',' imag ')' into `result`.
bool HloParser::ParseComplex(std::complex<double>* result) {
  if (lexer_.GetKind() != TokKind::kLparen) {
    return TokenError("expects '(' before complex number");
  }
  lexer_.Lex();

  double real;
  LocTy loc = lexer_.GetLoc();
  if (!ParseDouble(&real)) {
    return Error(loc,
                 "expect floating-point value for real part of complex number");
  }

  if (lexer_.GetKind() != TokKind::kComma) {
    return TokenError(
        absl::StrFormat("expect comma after real part of complex literal"));
  }
  lexer_.Lex();

  double imag;
  loc = lexer_.GetLoc();
  if (!ParseDouble(&imag)) {
    return Error(
        loc,
        "expect floating-point value for imaginary part of complex number");
  }

  if (lexer_.GetKind() != TokKind::kRparen) {
    return TokenError(absl::StrFormat("expect ')' after complex number"));
  }

  *result = std::complex<double>(real, imag);
  // Consume the closing ')'.
  lexer_.Lex();
  return true;
}
bool HloParser::ParseBool(bool* result) {
if (lexer_.GetKind() != TokKind::kw_true &&
lexer_.GetKind() != TokKind::kw_false) {
return TokenError("expects true or false");
}
*result = lexer_.GetKind() == TokKind::kw_true;
lexer_.Lex();
return true;
}
// Consumes one token of the given kind, or reports `msg` as an error.
bool HloParser::ParseToken(TokKind kind, const string& msg) {
  VLOG(3) << "ParseToken " << TokKindToString(kind) << " " << msg;
  if (lexer_.GetKind() == kind) {
    lexer_.Lex();
    return true;
  }
  return TokenError(msg);
}
// Consumes the next token iff it has the given kind; returns whether it did.
bool HloParser::EatIfPresent(TokKind kind) {
  if (lexer_.GetKind() == kind) {
    lexer_.Lex();
    return true;
  }
  return false;
}
// Registers `instruction` under `name` in the innermost name table; on a
// duplicate, reports errors pointing at both definitions.
bool HloParser::AddInstruction(const string& name, HloInstruction* instruction,
                               LocTy name_loc) {
  auto insertion = current_name_table().insert({name, {instruction, name_loc}});
  if (insertion.second) {
    return true;
  }
  Error(name_loc, StrCat("instruction already exists: ", name));
  return Error(/*loc=*/insertion.first->second.second,
               "instruction previously defined here");
}
// Registers `computation` under `name` in the computation pool; on a
// duplicate, reports errors pointing at both definitions.
bool HloParser::AddComputation(const string& name, HloComputation* computation,
                               LocTy name_loc) {
  auto insertion = computation_pool_.insert({name, {computation, name_loc}});
  if (insertion.second) {
    return true;
  }
  Error(name_loc, StrCat("computation already exists: ", name));
  return Error(/*loc=*/insertion.first->second.second,
               "computation previously defined here");
}
// Parses the entire input as a single shape; anything after it is an error.
StatusOr<Shape> HloParser::ParseShapeOnly() {
  lexer_.Lex();  // Prime the first token.
  Shape parsed_shape;
  if (!ParseShape(&parsed_shape)) {
    return InvalidArgument("Syntax error:\n%s", GetError());
  }
  if (lexer_.GetKind() != TokKind::kEof) {
    return InvalidArgument("Syntax error:\nExtra content after shape");
  }
  return parsed_shape;
}
// Parses the entire input as a sharding attribute and converts it to an
// HloSharding; anything after it is an error.
StatusOr<HloSharding> HloParser::ParseShardingOnly() {
  lexer_.Lex();  // Prime the first token.
  OpSharding sharding_proto;
  if (!ParseSharding(&sharding_proto)) {
    return InvalidArgument("Syntax error:\n%s", GetError());
  }
  if (lexer_.GetKind() != TokKind::kEof) {
    return InvalidArgument("Syntax error:\nExtra content after sharding");
  }
  return HloSharding::FromProto(sharding_proto);
}
// Parses the entire input as a frontend-attributes map; anything after it is
// an error.
StatusOr<FrontendAttributes> HloParser::ParseFrontendAttributesOnly() {
  lexer_.Lex();  // Prime the first token.
  FrontendAttributes parsed_attributes;
  if (!ParseFrontendAttributes(&parsed_attributes)) {
    return InvalidArgument("Syntax error:\n%s", GetError());
  }
  if (lexer_.GetKind() != TokKind::kEof) {
    return InvalidArgument(
        "Syntax error:\nExtra content after frontend attributes");
  }
  return parsed_attributes;
}
// Parses the entire input as a parameter-replication attribute, returning the
// per-leaf-buffer replication flags; anything after it is an error.
StatusOr<std::vector<bool>> HloParser::ParseParameterReplicationOnly() {
  lexer_.Lex();  // Prime the first token.
  ParameterReplication replication_proto;
  if (!ParseParameterReplication(&replication_proto)) {
    return InvalidArgument("Syntax error:\n%s", GetError());
  }
  if (lexer_.GetKind() != TokKind::kEof) {
    return InvalidArgument(
        "Syntax error:\nExtra content after parameter replication");
  }
  const auto& flags = replication_proto.replicated_at_leaf_buffers();
  return std::vector<bool>(flags.begin(), flags.end());
}
// Parses the entire input as a replica-group list; anything after it is an
// error.
StatusOr<std::vector<ReplicaGroup>> HloParser::ParseReplicaGroupsOnly() {
  lexer_.Lex();  // Prime the first token.
  std::vector<ReplicaGroup> groups;
  if (!ParseReplicaGroupsOnly(&groups)) {
    return InvalidArgument("Syntax error:\n%s", GetError());
  }
  if (lexer_.GetKind() != TokKind::kEof) {
    return InvalidArgument("Syntax error:\nExtra content after replica groups");
  }
  return groups;
}
// Parses the entire input as a window attribute (without outer curlies);
// anything after it is an error.
StatusOr<Window> HloParser::ParseWindowOnly() {
  lexer_.Lex();  // Prime the first token.
  Window parsed_window;
  if (!ParseWindow(&parsed_window, /*expect_outer_curlies=*/false)) {
    return InvalidArgument("Syntax error:\n%s", GetError());
  }
  if (lexer_.GetKind() != TokKind::kEof) {
    return InvalidArgument("Syntax error:\nExtra content after window");
  }
  return parsed_window;
}
// Parses the entire input as convolution dimension numbers; anything after
// them is an error.
StatusOr<ConvolutionDimensionNumbers>
HloParser::ParseConvolutionDimensionNumbersOnly() {
  lexer_.Lex();  // Prime the first token.
  ConvolutionDimensionNumbers parsed_dnums;
  if (!ParseConvolutionDimensionNumbers(&parsed_dnums)) {
    return InvalidArgument("Syntax error:\n%s", GetError());
  }
  if (lexer_.GetKind() != TokKind::kEof) {
    return InvalidArgument(
        "Syntax error:\nExtra content after convolution dnums");
  }
  return parsed_dnums;
}
// Parses the entire input as a padding config; anything after it is an error.
StatusOr<PaddingConfig> HloParser::ParsePaddingConfigOnly() {
  lexer_.Lex();  // Prime the first token.
  PaddingConfig parsed_config;
  if (!ParsePaddingConfig(&parsed_config)) {
    return InvalidArgument("Syntax error:\n%s", GetError());
  }
  if (lexer_.GetKind() != TokKind::kEof) {
    return InvalidArgument("Syntax error:\nExtra content after PaddingConfig");
  }
  return parsed_config;
}
// Parses a single instruction (with or without a left-hand side) into
// `module`'s entry computation. Operands that are referenced but not defined
// are materialized on the fly as parameters via the missing-instruction hook.
bool HloParser::ParseSingleInstruction(HloModule* module) {
  if (create_missing_instruction_ != nullptr || !scoped_name_tables_.empty()) {
    LOG(FATAL) << "Parser state is not clean. Please do not call any other "
                  "methods before calling ParseSingleInstruction.";
  }
  HloComputation::Builder builder(module->name());

  // The missing instruction hook we register creates the shaped instruction on
  // the fly as a parameter and returns it.
  int64 parameter_count = 0;
  create_missing_instruction_ =
      [this, &builder, &parameter_count](
          const string& name,
          const Shape& shape) -> std::pair<HloInstruction*, LocTy>* {
    // Unnamed operands get synthetic names "_0", "_1", ...
    string new_name = name.empty() ? StrCat("_", parameter_count) : name;
    HloInstruction* parameter = builder.AddInstruction(
        HloInstruction::CreateParameter(parameter_count++, shape, new_name));
    current_name_table()[new_name] = {parameter, lexer_.GetLoc()};
    return tensorflow::gtl::FindOrNull(current_name_table(), new_name);
  };

  // Parse the instruction with the registered hook.
  Scope scope(&scoped_name_tables_);
  if (CanBeShape()) {
    // This means that the instruction's left-hand side is probably omitted,
    // e.g.
    //
    //  f32[10] fusion(...), calls={...}
    if (!ParseInstructionRhs(&builder, module->name(), lexer_.GetLoc())) {
      return false;
    }
  } else {
    // This means that the instruction's left-hand side might exist, e.g.
    //
    //  foo = f32[10] fusion(...), calls={...}
    string root_name;
    if (!ParseInstruction(&builder, &root_name)) {
      return false;
    }
  }

  if (lexer_.GetKind() != TokKind::kEof) {
    Error(
        lexer_.GetLoc(),
        "Syntax error:\nExpected eof after parsing single instruction. Did "
        "you mean to write an HLO module and forget the \"HloModule\" header?");
    return false;
  }

  // Install the parsed instruction as the entry computation, plus any nested
  // computations that were parsed along the way.
  module->AddEntryComputation(builder.Build());
  for (auto& comp : computations_) {
    module->AddEmbeddedComputation(std::move(comp));
  }
  return true;
}
} // namespace
// Parses `str` as HLO text into a freshly allocated HloModule using `config`.
// The returned module is not run through the HLO verifier.
StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
    absl::string_view str, const HloModuleConfig& config) {
  auto result = absl::make_unique<HloModule>(/*name=*/"_", config);
  TF_RETURN_IF_ERROR(HloParser(str).Run(result.get()));
  return std::move(result);
}
// Convenience overload that parses `str` with a default-constructed
// HloModuleConfig.
StatusOr<std::unique_ptr<HloModule>> ParseAndReturnUnverifiedModule(
    absl::string_view str) {
  const HloModuleConfig default_config;
  return ParseAndReturnUnverifiedModule(str, default_config);
}
// Parses `str` as HLO text into the given (empty) module in place. It is an
// error for `module` to already contain computations.
Status ParseHloString(absl::string_view str, HloModule* module) {
  // The destination module must not have been populated yet.
  TF_RET_CHECK(module->computation_count() == 0);
  HloParser hlo_parser(str);
  return hlo_parser.Run(module);
}
// Parses `str` as a textual HloSharding attribute.
StatusOr<HloSharding> ParseSharding(absl::string_view str) {
  return HloParser(str).ParseShardingOnly();
}
// Parses `str` as a textual FrontendAttributes attribute.
StatusOr<FrontendAttributes> ParseFrontendAttributes(absl::string_view str) {
  return HloParser(str).ParseFrontendAttributesOnly();
}
// Parses `str` as a textual parameter-replication attribute (a list of
// booleans).
StatusOr<std::vector<bool>> ParseParameterReplication(absl::string_view str) {
  return HloParser(str).ParseParameterReplicationOnly();
}
// Parses `str` as a textual list of replica groups.
StatusOr<std::vector<ReplicaGroup>> ParseReplicaGroupsOnly(
    absl::string_view str) {
  return HloParser(str).ParseReplicaGroupsOnly();
}
// Parses `str` as a textual Window attribute.
StatusOr<Window> ParseWindow(absl::string_view str) {
  return HloParser(str).ParseWindowOnly();
}
// Parses `str` as a textual ConvolutionDimensionNumbers attribute.
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
    absl::string_view str) {
  return HloParser(str).ParseConvolutionDimensionNumbersOnly();
}
// Parses `str` as a textual PaddingConfig attribute.
StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
  return HloParser(str).ParsePaddingConfigOnly();
}
// Parses `str` as a textual Shape, e.g. as printed by ShapeUtil.
StatusOr<Shape> ParseShape(absl::string_view str) {
  return HloParser(str).ParseShapeOnly();
}
} // namespace xla