Merge pull request #31129 from AnthonyBarbier:frontend_attributes

PiperOrigin-RevId: 265793937
This commit is contained in:
TensorFlower Gardener 2019-08-27 16:48:01 -07:00
commit 47612c19ea
16 changed files with 525 additions and 18 deletions

View File

@ -203,14 +203,15 @@ cc_library(
visibility = [":friends"],
deps = [
":common",
":frontend_attributes_util",
":host_compute_metadata_proto",
":rearrange_function_argument",
":sharding_util",
":side_effect_util",
":tf2xla_util",
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:shape_inference",
"//tensorflow/compiler/jit:xla_cluster_util",
"//tensorflow/compiler/tf2xla:rearrange_function_argument",
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@ -271,6 +272,21 @@ cc_library(
],
)
cc_library(
name = "frontend_attributes_util",
srcs = ["frontend_attributes_util.cc"],
hdrs = ["frontend_attributes_util.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
name = "sharding_util",
srcs = ["sharding_util.cc"],
@ -579,6 +595,7 @@ cc_library(
"functionalize_while.h",
],
deps = [
":frontend_attributes_util",
":functionalize_cond",
":functionalize_control_flow_util",
":tf2xla_util",

View File

@ -0,0 +1,41 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
// Node attribute name under which a serialized xla::FrontendAttributes proto
// is stored on a TensorFlow node.
const char kXlaFrontendAttributesAttrName[] = "_XlaFrontendAttributes";
// Looks up kXlaFrontendAttributesAttrName in `attrs` and decodes its value.
//
// Returns absl::nullopt when the attribute is absent (absence is not an
// error), the parsed xla::FrontendAttributes when present and valid, or an
// InvalidArgument error when the attribute exists but is not a parseable
// serialized proto.
xla::StatusOr<absl::optional<xla::FrontendAttributes>>
GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs) {
const AttrValue* attr = attrs.Find(kXlaFrontendAttributesAttrName);
if (attr == nullptr) {
// No attribute set on this node: report "nothing found" successfully.
return xla::StatusOr<absl::optional<xla::FrontendAttributes>>(
absl::nullopt);
}
xla::FrontendAttributes attributes;
// The attribute value is expected to be a binary-serialized proto (attr->s()
// is the string payload of the AttrValue).
if (!attributes.ParseFromString(attr->s())) {
return errors::InvalidArgument(
"Experimental _XlaFrontendAttributes attribute was not a valid encoded "
"xla::FrontendAttributes proto.");
}
return absl::optional<xla::FrontendAttributes>(attributes);
}
} // namespace tensorflow

View File

@ -0,0 +1,38 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_
#include <string>
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
namespace tensorflow {
// Frontend Attributes Id.
extern const char kXlaFrontendAttributesAttrName[];
// Return the FrontendAttributes stored in the AttrSlice if there are some.
//
// Return an InvalidArgument error if some attributes are present but
// cannot be parsed.
xla::StatusOr<absl::optional<xla::FrontendAttributes>>
GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@ -494,6 +495,12 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
builder.Attr("cond", cond_name);
builder.Attr("body", body_name);
string outside_compilation;
string frontend_attributes;
if (GetNodeAttr(frame->loop_cond->def(), kXlaFrontendAttributesAttrName,
&frontend_attributes)
.ok()) {
builder.Attr(kXlaFrontendAttributesAttrName, frontend_attributes);
}
if (GetNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName,
&outside_compilation)
.ok()) {

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <functional>
#include <memory>
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
@ -98,6 +99,20 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel,
absl::optional<xla::OpSharding> op_sharding =
sharding_parse_result.ValueOrDie();
auto frontend_attributes_result =
GetFrontendAttributesFromAttrSlice(AttrSlice(op_kernel->def()));
OP_REQUIRES_OK(context, frontend_attributes_result.status());
absl::optional<xla::FrontendAttributes> attributes =
frontend_attributes_result.ValueOrDie();
xla::FrontendAttributes merged_attributes = b->frontend_attributes();
if (attributes.has_value()) {
merged_attributes.mutable_map()->insert(attributes.value().map().begin(),
attributes.value().map().end());
}
xla::XlaScopedFrontendAttributesAssignment assign_frontend_attributes(
b, std::move(merged_attributes));
// If no sharding metadata is found, XLA is free to use whatever device it
// wants. In practice this usually has the effect of placing things on device
// 0.

View File

@ -289,6 +289,15 @@ Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num,
return Status::OK();
}
// Sets (or overwrites) the frontend attribute `attribute` to `value` on the
// single HLO instruction identified by `op`.
//
// Note: this only touches the instruction's own attribute map; it does not
// change the builder-level frontend_attributes_ applied to new instructions.
Status XlaBuilder::SetInstructionFrontendAttribute(const XlaOp op,
std::string attribute,
std::string value) {
TF_ASSIGN_OR_RETURN(auto instr_proto, LookUpMutableInstruction(op));
auto* frontend_attributes = instr_proto->mutable_frontend_attributes();
// map operator[] inserts the key if missing and updates it otherwise.
(*frontend_attributes->mutable_map())[attribute] = std::move(value);
return Status::OK();
}
XlaComputation XlaBuilder::BuildAndNoteError() {
DCHECK(parent_builder_ != nullptr);
auto build_status = Build();
@ -2626,6 +2635,7 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
if (sharding_) {
*instr.mutable_sharding() = *sharding_;
}
*instr.mutable_frontend_attributes() = frontend_attributes_;
handle_to_index_[handle] = instructions_.size();
instructions_.push_back(std::move(instr));
@ -2683,32 +2693,67 @@ void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
}
}
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
const XlaOp& op) const {
TF_RETURN_IF_ERROR(first_error_);
namespace {
if (op.builder_ == nullptr) {
// Shared handle-lookup helper for both the const and mutable accessors.
//
// InstructionType is instantiated as either `const HloInstructionProto*` or
// `HloInstructionProto*`. The const_cast below is what allows a single
// implementation to serve both; callers pass the builder's own instruction
// storage, so casting away const for the mutable instantiation is expected
// to be safe here.
template <typename InstructionType>
StatusOr<InstructionType> LookUpInstructionByHandleInternal(
const absl::flat_hash_map<int64, int64>& handle_to_index,
const std::vector<HloInstructionProto>& instructions, int64 handle) {
auto it = handle_to_index.find(handle);
if (it == handle_to_index.end()) {
return InvalidArgument("No XlaOp with handle %d", handle);
}
return const_cast<InstructionType>(&instructions.at(it->second));
}
template <typename InstructionType, typename OpBuilderType,
typename BuilderType, typename OpType>
StatusOr<InstructionType> LookUpInstructionInternal(
const absl::flat_hash_map<int64, int64>& handle_to_index,
const std::vector<HloInstructionProto>& instructions,
OpBuilderType op_builder, BuilderType builder, OpType op_handle) {
if (op_builder == nullptr) {
return InvalidArgument(
"invalid XlaOp with handle %d; the builder of this op is freed",
op.handle());
op_handle);
}
if (op.builder_ != this) {
if (op_builder != builder) {
return InvalidArgument(
"XlaOp with handle %d is built by builder '%s', but is trying to use "
"it in builder '%s'",
op.handle(), op.builder_->name(), this->name());
op_handle, op_builder->name(), builder->name());
}
return LookUpInstructionByHandle(op.handle());
return LookUpInstructionByHandleInternal<InstructionType>(
handle_to_index, instructions, op_handle);
}
} // namespace
// Const lookup of the instruction proto backing `op`, with validation that
// `op` belongs to this builder. Propagates any previously recorded builder
// error first.
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
const XlaOp op) const {
TF_RETURN_IF_ERROR(first_error_);
return LookUpInstructionInternal<const HloInstructionProto*>(
handle_to_index_, instructions_, op.builder_, this, op.handle());
}
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
int64 handle) const {
auto it = handle_to_index_.find(handle);
if (it == handle_to_index_.end()) {
return InvalidArgument("No XlaOp with handle %d", handle);
}
return &instructions_[it->second];
return LookUpInstructionByHandleInternal<const HloInstructionProto*>(
handle_to_index_, instructions_, handle);
}
// Mutable counterpart of LookUpInstruction: same validation, but returns a
// non-const pointer so the caller may edit the instruction proto in place.
StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstruction(
const XlaOp op) {
TF_RETURN_IF_ERROR(first_error_);
return LookUpInstructionInternal<HloInstructionProto*>(
handle_to_index_, instructions_, op.builder_, this, op.handle());
}
// Mutable lookup by raw handle; unlike the XlaOp overloads this does not
// check first_error_ or ownership — the handle is assumed to be from this
// builder.
StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstructionByHandle(
int64 handle) {
return LookUpInstructionByHandleInternal<HloInstructionProto*>(
handle_to_index_, instructions_, handle);
}
// Enqueues a "retrieve parameter value" instruction for a parameter that was

View File

@ -147,8 +147,8 @@ class XlaBuilder {
// Sets OpMetadata that will be added to all instructions until cleared.
//
// OpMetadata is often applied to a series of XLA HLO instructions. As a
// result, OpMetadata is set on the Computation Builder. All subsequent
// instructions generated via this Computation Builder will have the same
// result, OpMetadata is set on the computation builder. All subsequent
// instructions generated via this computation builder will have the same
// OpMetadata attached until a call to ClearOpMetadata.
void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); }
@ -158,6 +158,35 @@ class XlaBuilder {
// Sets an OpSharding that will be attached to all instructions until cleared.
void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
// Sets the FrontendAttributes that will be added to all instructions until
// cleared.
//
// FrontendAttributes are often applied to a series of XLA HLO instructions.
// As a result they are set on the computation builder and all the
// instructions generated via the computation builder will have the same
// frontend attributes attached to them.
void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) {
frontend_attributes_ = frontend_attributes;
}
// Swaps the passed FrontendAttributes with the ones currently set on the
// builder and returns the previous ones.
//
// The current attributes are moved out before the assignment, so no copy of
// the outgoing value is made. Used by XlaScopedFrontendAttributesAssignment
// to save/restore attributes RAII-style.
FrontendAttributes SwapFrontendAttributes(
const FrontendAttributes& frontend_attributes) {
FrontendAttributes old_attributes = std::move(frontend_attributes_);
frontend_attributes_ = frontend_attributes;
return old_attributes;
}
// Returns the FrontendAttributes that will be attached to all instructions.
const FrontendAttributes& frontend_attributes() const {
return frontend_attributes_;
}
// Clears all the frontend attributes.
void ClearFrontendAttributes() { frontend_attributes_.Clear(); }
// Clears the sharding. Ops will be sharded according to the default placement
// policy.
void ClearSharding() { sharding_ = absl::nullopt; }
@ -314,6 +343,16 @@ class XlaBuilder {
ShapeIndex param_index;
};
// Looks up the HloInstruction and sets the frontend attribute "attribute" to
// "value".
//
// If the attribute already existed then its value is updated.
//
// Note: the attribute is only added to the HloInstruction, not to the
// builder.
Status SetInstructionFrontendAttribute(XlaOp op, string attribute,
string value);
private:
// Build helper which takes the id of the root operation..
StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions);
@ -595,9 +634,11 @@ class XlaBuilder {
void AddCalledComputation(const XlaComputation& computation,
HloInstructionProto* instr);
StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
StatusOr<const HloInstructionProto*> LookUpInstruction(XlaOp op) const;
StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
int64 handle) const;
StatusOr<HloInstructionProto*> LookUpMutableInstruction(XlaOp op);
StatusOr<HloInstructionProto*> LookUpMutableInstructionByHandle(int64 handle);
// Internal helper method that does the building for an arbitrary unary op.
XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
@ -707,6 +748,8 @@ class XlaBuilder {
XlaBuilder* parent_builder_{nullptr};
FrontendAttributes frontend_attributes_;
friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
const Shape& shape, const string& name,
const std::vector<bool>& replicated_at_leaf_buffers);
@ -1034,6 +1077,27 @@ class XlaScopedShardingAssignment {
absl::optional<OpSharding> prev_sharding_;
};
// RAII-style object: save the current builder's frontend attributes, and merge
// them with the new ones on construction.
// Restore the original attributes on destruction.
class XlaScopedFrontendAttributesAssignment {
public:
// Installs `attributes` on `builder`, remembering the attributes that were
// previously in effect.
XlaScopedFrontendAttributesAssignment(xla::XlaBuilder* builder,
FrontendAttributes attributes)
: builder_(builder) {
saved_ = builder_->SwapFrontendAttributes(attributes);
}
// Restores the attributes captured at construction time.
~XlaScopedFrontendAttributesAssignment() {
builder_->SetFrontendAttributes(saved_);
}
private:
xla::XlaBuilder* const builder_; // Not owned; must outlive this object.
FrontendAttributes saved_; // Attributes to restore on destruction.
TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedFrontendAttributesAssignment);
};
// Free functions for building XlaOps. The intention is that these will
// become the public API for building XlaOps rather than calling methods on
// XlaBuilder directly.

View File

@ -978,5 +978,151 @@ TEST_F(XlaBuilderTest, CheckInputOutputAlias) {
EXPECT_EQ(*alias_p1, ShapeIndex({0}));
}
// Checks that `attr` holds exactly the same key/value pairs as `ref`.
//
// The size comparison plus a per-key lookup yields an exact-equality check
// while still pointing gtest at the first offending key on failure.
void ExpectAttributesMatch(const FrontendAttributes& attr,
const FrontendAttributes& ref) {
EXPECT_EQ(ref.map_size(), attr.map_size());
// Iterate by const reference: each map entry is a pair<string, string>,
// and the original by-value loop copied both strings every iteration.
for (const auto& reference : ref.map()) {
auto other = attr.map().find(reference.first);
EXPECT_NE(other, attr.map().end());
EXPECT_EQ(other->second, reference.second);
}
}
// Walks the single computation of `module` in instruction order and checks
// each instruction's frontend attributes against the corresponding entry of
// `expected`. Also verifies the instruction count equals expected.size().
void ExpectInstructionsAttributesMatch(
const HloModule& module, const std::vector<FrontendAttributes>& expected) {
ASSERT_EQ(module.computation_count(), 1);
auto expected_it = expected.begin();
for (auto inst : module.entry_computation()->instructions()) {
// Fail fast if the module has more instructions than expectations.
ASSERT_NE(expected_it, expected.end());
ExpectAttributesMatch(inst->frontend_attributes(), *expected_it);
expected_it++;
}
// All expectations must have been consumed (no missing instructions).
EXPECT_EQ(expected_it, expected.end());
}
// Verifies the basic Set/Clear cycle: instructions created before Set and
// after Clear carry no attributes; the one created in between carries the
// attributes that were set on the builder.
TEST_F(XlaBuilderTest, SimpleSetFrontendAttributes) {
XlaBuilder b(TestName());
FrontendAttributes attributes;
ConstantR0(&b, 0); // No attribute set
(*attributes.mutable_map())["attr_a"] = "a";
b.SetFrontendAttributes(attributes);
ConstantR0(&b, 0); // One attribute: { "attr_a": "a" }
b.ClearFrontendAttributes();
ConstantR0(&b, 0); // No attribute set
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
std::vector<FrontendAttributes> expected{FrontendAttributes(), attributes,
FrontendAttributes()};
ExpectInstructionsAttributesMatch(*module, expected);
}
// Verifies that successive SetFrontendAttributes calls replace (not merge)
// the builder-level attributes, and that ClearFrontendAttributes resets them.
TEST_F(XlaBuilderTest, ComplexSetFrontendAttributes) {
XlaBuilder b(TestName());
ConstantR0(&b, 0); // No attribute set.
std::vector<FrontendAttributes> expected{FrontendAttributes()};
{
FrontendAttributes attributes;
(*attributes.mutable_map())["attr_a"] = "a";
b.SetFrontendAttributes(attributes);
ConstantR0(&b, 0); // One attribute: { "attr_a": "a" }
expected.push_back(attributes);
}
{
FrontendAttributes attributes;
(*attributes.mutable_map())["attr_b"] = "b";
b.SetFrontendAttributes(attributes);
ConstantR0(&b, 0); // One attribute: { "attr_b": "b" }
expected.push_back(attributes);
}
{
FrontendAttributes attributes;
(*attributes.mutable_map())["attr_b"] = "b";
(*attributes.mutable_map())["attr_c"] = "c";
b.SetFrontendAttributes(attributes);
ConstantR0(&b, 0); // Two attributes: { "attr_b": "b", "attr_c": "c" }
expected.push_back(attributes);
}
b.ClearFrontendAttributes();
ConstantR0(&b, 0); // No attribute set
expected.push_back(FrontendAttributes());
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
ExpectInstructionsAttributesMatch(*module, expected);
}
// Verifies per-instruction SetInstructionFrontendAttribute: it merges with
// the builder-level attributes for that instruction only, can override an
// inherited key, and does not leak onto later instructions.
TEST_F(XlaBuilderTest, AddFrontendAttribute) {
XlaBuilder b(TestName());
ConstantR0(&b, 0);
std::vector<FrontendAttributes> expected{FrontendAttributes()};
// One attribute: { "attr_a": "a" }
{
FrontendAttributes attributes;
(*attributes.mutable_map())["attr_a"] = "a";
b.SetFrontendAttributes(attributes);
ConstantR0(&b, 0);
expected.push_back(attributes);
}
// Two attributes: {"attr_a": "a", "attr_c": "c"}
{
auto op = ConstantR0(&b, 0);
EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_c", "c"));
FrontendAttributes attributes;
(*attributes.mutable_map())["attr_a"] = "a";
(*attributes.mutable_map())["attr_c"] = "c";
expected.push_back(attributes);
}
// Override value of existing "attr_a"
// One attribute: { "attr_a", "a2"}
{
auto op = ConstantR0(&b, 0);
EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_a", "a2"));
FrontendAttributes attributes;
(*attributes.mutable_map())["attr_a"] = "a2";
expected.push_back(attributes);
}
// Check "attr_a" is back to its original value
// One attribute: { "attr_a", "a"}
{
auto op = ConstantR0(&b, 0);
(void)op;
FrontendAttributes attributes;
(*attributes.mutable_map())["attr_a"] = "a";
expected.push_back(attributes);
}
b.ClearFrontendAttributes();
ConstantR0(&b, 0); // No attribute set
expected.push_back(FrontendAttributes());
// One attribute: { "attr_d", "d"}
{
auto op = ConstantR0(&b, 0);
EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_d", "d"));
FrontendAttributes attributes;
(*attributes.mutable_map())["attr_d"] = "d";
expected.push_back(attributes);
}
ConstantR0(&b, 0); // No attribute set
expected.push_back(FrontendAttributes());
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
ExpectInstructionsAttributesMatch(*module, expected);
}
} // namespace
} // namespace xla

View File

@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
// Next ID: 68
// Next ID: 69
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@ -234,6 +234,9 @@ message HloInstructionProto {
// Specifies if the gather/scatter indices are guaranteed to be sorted by the
// caller.
bool indices_are_sorted = 67;
// Frontend attributes to pass to the XLA backend.
xla.FrontendAttributes frontend_attributes = 68;
}
// Serialization of HloComputation.

View File

@ -837,6 +837,10 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
if (new_instruction->metadata().op_name().empty()) {
new_instruction->set_metadata(old_instruction->metadata());
}
if (new_instruction->frontend_attributes().map().empty()) {
new_instruction->set_frontend_attributes(
old_instruction->frontend_attributes());
}
// Like the metadata above, if the user didn't specify any sharding
// information on the new instruction we should copy the old sharding

View File

@ -674,6 +674,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->set_sharding(sharding);
}
if (proto.has_frontend_attributes()) {
instruction->set_frontend_attributes(proto.frontend_attributes());
}
return std::move(instruction);
}
@ -1194,6 +1198,7 @@ HloInstruction::CreateBroadcastSequence(
if (operand->has_sharding()) {
broadcast->set_sharding(operand->sharding());
}
broadcast->set_frontend_attributes(operand->frontend_attributes());
return broadcast;
}
// Do explicit broadcast for degenerate broadcast.
@ -1219,6 +1224,7 @@ HloInstruction::CreateBroadcastSequence(
if (operand->has_sharding()) {
reshaped_operand->set_sharding(operand->sharding());
}
reshaped_operand->set_frontend_attributes(operand->frontend_attributes());
// Broadcast 'reshape' up to the larger size.
auto broadcast = HloInstruction::CreateBroadcast(
broadcast_shape, reshaped_operand, broadcast_dimensions);
@ -1226,6 +1232,7 @@ HloInstruction::CreateBroadcastSequence(
if (operand->has_sharding()) {
broadcast->set_sharding(operand->sharding());
}
broadcast->set_frontend_attributes(operand->frontend_attributes());
return broadcast;
}
@ -1296,6 +1303,7 @@ void HloInstruction::SetupDerivedInstruction(
derived_instruction->clear_sharding();
}
derived_instruction->set_metadata(metadata_);
derived_instruction->set_frontend_attributes(frontend_attributes_);
}
bool HloInstruction::HasSideEffectNoRecurse() const {
@ -2483,6 +2491,10 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
if (has_sharding()) {
extra.push_back(StrCat("sharding=", sharding().ToString()));
}
if (!frontend_attributes_.map().empty()) {
extra.push_back(StrCat("frontend_attributes=",
FrontendAttributesToString(frontend_attributes_)));
}
if (!outer_dimension_partitions_.empty()) {
extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}",
StrJoin(outer_dimension_partitions_, ",")));
@ -2543,6 +2555,8 @@ HloInstructionProto HloInstruction::ToProto() const {
}
}
*proto.mutable_frontend_attributes() = frontend_attributes_;
return proto;
}
@ -3197,6 +3211,15 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
return InvalidArgument("Unknown fusion kind: %s", kind_name);
}
// Renders `frontend_attributes` as "{key1=value1,key2=value2}".
//
// The entries are copied into a vector and sorted by key so the textual form
// is deterministic regardless of proto map iteration order (this matters for
// HLO text round-tripping and for tests comparing printed output).
string FrontendAttributesToString(
const FrontendAttributes& frontend_attributes) {
std::vector<std::pair<string, string>> sorted_attributes(
frontend_attributes.map().begin(), frontend_attributes.map().end());
absl::c_sort(sorted_attributes);
return absl::StrFormat(
"{%s}", absl::StrJoin(sorted_attributes, ",", absl::PairFormatter("=")));
}
string PaddingConfigToString(const PaddingConfig& padding) {
bool has_interior_padding =
absl::c_any_of(padding.dimensions(),

View File

@ -1385,6 +1385,14 @@ class HloInstruction {
}
Status set_backend_config(const tensorflow::protobuf::Message& proto);
// Replaces this instruction's frontend attributes. Takes the proto by value
// so callers can move in without an extra copy.
void set_frontend_attributes(FrontendAttributes frontend_attributes) {
frontend_attributes_ = std::move(frontend_attributes);
}
// Read-only access to the frontend attributes attached to this instruction.
const FrontendAttributes& frontend_attributes() const {
return frontend_attributes_;
}
// Getter/setter for raw JSON-encoded backend config. Prefer the
// functions above that deal in proto Messages where possible.
const string& raw_backend_config_string() const { return backend_config_; }
@ -1879,6 +1887,18 @@ class HloInstruction {
// HLO. See the documentation on backend_config().
string backend_config_;
// Attributes passed from the frontend to give hints to the backend about
// how to compile this HLO.
// HLO -> HLO transforms are expected to preserve these attributes on a
// "best effort" basis only.
// For example:
// x = const(10), frontend_attributes={x}
// y = const(10), frontend_attributes={y}
// z = add(x,y), frontend_attributes={y}
// Could be simplified to:
// z' = const(20), frontend_attributes={?}
FrontendAttributes frontend_attributes_;
// This field is assigned to true when backend_config_ is assigned to
// a default configuration.
bool is_default_config_ = false;
@ -1909,6 +1929,8 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
// Custom (de)stringification functions for protos that live inside
// HloInstruction.
string PaddingConfigToString(const PaddingConfig& padding);
string FrontendAttributesToString(
const FrontendAttributes& frontend_attributes);
string OpMetadataToString(const OpMetadata& metadata);
string RandomDistributionToString(const RandomDistribution& distribution);
string PrecisionToString(const PrecisionConfig::Precision& precision);

View File

@ -88,6 +88,7 @@ class HloParser {
// Stand alone parsing utils for various aggregate data types.
StatusOr<Shape> ParseShapeOnly();
StatusOr<HloSharding> ParseShardingOnly();
StatusOr<FrontendAttributes> ParseFrontendAttributesOnly();
StatusOr<std::vector<bool>> ParseParameterReplicationOnly();
StatusOr<Window> ParseWindowOnly();
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
@ -192,6 +193,7 @@ class HloParser {
kWindow,
kConvolutionDimensionNumbers,
kSharding,
kFrontendAttributes,
kParameterReplication,
kInstructionList,
kSliceRanges,
@ -271,6 +273,7 @@ class HloParser {
bool ParsePaddingConfig(PaddingConfig* padding);
bool ParseMetadata(OpMetadata* metadata);
bool ParseSharding(OpSharding* sharding);
bool ParseFrontendAttributes(FrontendAttributes* frontend_attributes);
bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
bool ParseParameterReplication(ParameterReplication* parameter_replication);
bool ParseReplicaGroupsOnly(std::vector<ReplicaGroup>* replica_groups);
@ -677,7 +680,10 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
// Add optional attributes.
std::unordered_map<string, AttrConfig> attrs;
optional<OpSharding> sharding;
optional<FrontendAttributes> frontend_attributes;
attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
attrs["frontend_attributes"] = {
/*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes};
optional<ParameterReplication> parameter_replication;
attrs["parameter_replication"] = {/*required=*/false,
AttrTy::kParameterReplication,
@ -1845,6 +1851,36 @@ bool HloParser::ParseSharding(OpSharding* sharding) {
return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
}
// frontend_attributes ::= '{' attributes '}'
// attributes
// ::= /*empty*/
// ::= attribute '=' value (',' attribute '=' value)*
//
// Parses the frontend_attributes attribute body into `frontend_attributes`.
// Returns false (with the lexer left at the failing token) on syntax errors.
bool HloParser::ParseFrontendAttributes(
FrontendAttributes* frontend_attributes) {
CHECK(frontend_attributes != nullptr);
if (!ParseToken(TokKind::kLbrace,
"expected '{' to start frontend attributes")) {
return false;
}
if (lexer_.GetKind() == TokKind::kRbrace) {
// empty
} else {
do {
string attribute;
if (!ParseAttributeName(&attribute)) {
return false;
}
// Each value must be a bare identifier token.
// NOTE(review): this path returns false without recording an error
// message, unlike the ParseToken-based failures — confirm whether an
// explicit error should be emitted here.
if (lexer_.GetKind() != TokKind::kIdent) {
return false;
}
(*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal();
lexer_.Lex();
} while (EatIfPresent(TokKind::kComma));
}
return ParseToken(TokKind::kRbrace,
"expects '}' at the end of frontend attributes");
}
// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape?
// ('devices=' ('[' dims ']')* device_list)? '}'
// dims ::= int_list device_list ::= int_list
@ -2864,6 +2900,15 @@ bool HloParser::ParseAttributeHelper(
static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
return true;
}
case AttrTy::kFrontendAttributes: {
FrontendAttributes frontend_attributes;
if (!ParseFrontendAttributes(&frontend_attributes)) {
return false;
}
static_cast<optional<FrontendAttributes>*>(attr_out_ptr)
->emplace(frontend_attributes);
return true;
}
case AttrTy::kParameterReplication: {
ParameterReplication parameter_replication;
if (!ParseParameterReplication(&parameter_replication)) {
@ -4120,6 +4165,19 @@ StatusOr<HloSharding> HloParser::ParseShardingOnly() {
return HloSharding::FromProto(op_sharding);
}
// Parses an input consisting solely of a frontend_attributes body (e.g.
// "{attr_a=a,attr_b=b}"). Fails with InvalidArgument on syntax errors or if
// any tokens remain after the closing brace.
StatusOr<FrontendAttributes> HloParser::ParseFrontendAttributesOnly() {
// Prime the lexer with the first token before delegating.
lexer_.Lex();
FrontendAttributes attributes;
if (!ParseFrontendAttributes(&attributes)) {
return InvalidArgument("Syntax error:\n%s", GetError());
}
if (lexer_.GetKind() != TokKind::kEof) {
return InvalidArgument(
"Syntax error:\nExtra content after frontend attributes");
}
return attributes;
}
StatusOr<std::vector<bool>> HloParser::ParseParameterReplicationOnly() {
lexer_.Lex();
ParameterReplication parameter_replication;
@ -4268,6 +4326,11 @@ StatusOr<HloSharding> ParseSharding(absl::string_view str) {
return parser.ParseShardingOnly();
}
// Public entry point: parses `str` as a stand-alone frontend_attributes body
// by constructing a throwaway HloParser over it.
StatusOr<FrontendAttributes> ParseFrontendAttributes(absl::string_view str) {
HloParser parser(str);
return parser.ParseFrontendAttributesOnly();
}
StatusOr<std::vector<bool>> ParseParameterReplication(absl::string_view str) {
HloParser parser(str);
return parser.ParseParameterReplicationOnly();

View File

@ -54,6 +54,12 @@ Status ParseHloString(absl::string_view str, HloModule* module);
// "{replicated}".
StatusOr<HloSharding> ParseSharding(absl::string_view str);
// Parses frontend attributes from str. str is supposed to contain the body of
// the frontend attributes, i.e. just the rhs of the
// "frontend_attributes={...}" attribute string, e.g.,
// "{attr_a=a,attr_b=b}".
StatusOr<FrontendAttributes> ParseFrontendAttributes(absl::string_view str);
// Parses parameter replication from str. str is supposed to contain the body of
// the parameter replication, i.e. just the rhs of the
// "parameter_replication={...}" attribute string, e.g., "{true, false}".

View File

@ -2358,6 +2358,13 @@ TEST_F(HloParserTest, ParseSharding) {
EXPECT_EQ(sharding.ToString(), original);
}
// Round-trip check: parsing a frontend_attributes string and re-printing it
// yields the original text (relies on FrontendAttributesToString sorting
// entries by key, which the input already is).
TEST_F(HloParserTest, ParseFrontendAttributes) {
const string original = "{attr_a=test_a,attr_b=b}";
TF_ASSERT_OK_AND_ASSIGN(FrontendAttributes frontend_attributes,
ParseFrontendAttributes(original));
EXPECT_EQ(FrontendAttributesToString(frontend_attributes), original);
}
TEST_F(HloParserTest, ParseWindow) {
Window original = window_util::MakeWindow({1, 2, 3});
TF_ASSERT_OK_AND_ASSIGN(Window parsed,

View File

@ -583,6 +583,12 @@ message CholeskyOptions {
bool lower = 1;
}
// Generic map of attributes used to pass hints / configuration options from
// the Python frontend to the XLA backend.
message FrontendAttributes {
map<string, string> map = 1;
}
message OpSharding {
enum Type {
// This sharding is replicated across all devices (implies maximal,