commit 47612c19ea
Author: TensorFlower Gardener
Date:   2019-08-27 16:48:01 -07:00

    Merge pull request #31129 from AnthonyBarbier:frontend_attributes

    PiperOrigin-RevId: 265793937

16 changed files with 525 additions and 18 deletions

tensorflow/compiler/tf2xla/BUILD

@@ -203,14 +203,15 @@ cc_library(
     visibility = [":friends"],
     deps = [
         ":common",
+        ":frontend_attributes_util",
         ":host_compute_metadata_proto",
+        ":rearrange_function_argument",
         ":sharding_util",
         ":side_effect_util",
         ":tf2xla_util",
         "//tensorflow/compiler/jit:flags",
         "//tensorflow/compiler/jit:shape_inference",
         "//tensorflow/compiler/jit:xla_cluster_util",
-        "//tensorflow/compiler/tf2xla:rearrange_function_argument",
         "//tensorflow/compiler/tf2xla/lib:util",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
@@ -271,6 +272,21 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "frontend_attributes_util",
+    srcs = ["frontend_attributes_util.cc"],
+    hdrs = ["frontend_attributes_util.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/types:optional",
+    ],
+)
+
 cc_library(
     name = "sharding_util",
     srcs = ["sharding_util.cc"],
@@ -579,6 +595,7 @@ cc_library(
         "functionalize_while.h",
     ],
     deps = [
+        ":frontend_attributes_util",
         ":functionalize_cond",
         ":functionalize_control_flow_util",
         ":tf2xla_util",

tensorflow/compiler/tf2xla/frontend_attributes_util.cc (new file)

@@ -0,0 +1,41 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
const char kXlaFrontendAttributesAttrName[] = "_XlaFrontendAttributes";
xla::StatusOr<absl::optional<xla::FrontendAttributes>>
GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs) {
const AttrValue* attr = attrs.Find(kXlaFrontendAttributesAttrName);
if (attr == nullptr) {
return xla::StatusOr<absl::optional<xla::FrontendAttributes>>(
absl::nullopt);
}
xla::FrontendAttributes attributes;
if (!attributes.ParseFromString(attr->s())) {
return errors::InvalidArgument(
"Experimental _XlaFrontendAttributes attribute was not a valid encoded "
"xla::FrontendAttributes proto.");
}
return absl::optional<xla::FrontendAttributes>(attributes);
}
} // namespace tensorflow
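
For context, a sketch of how the attribute round-trips through this helper (the NodeDef and its name are hypothetical, not part of the commit):

    xla::FrontendAttributes attrs;
    (*attrs.mutable_map())["attr_a"] = "a";

    tensorflow::NodeDef node;  // Hypothetical node carrying the attribute.
    node.set_name("example_node");
    (*node.mutable_attr())[tensorflow::kXlaFrontendAttributesAttrName].set_s(
        attrs.SerializeAsString());

    // Returns an absl::optional containing {"attr_a": "a"}; a node without
    // the attribute yields absl::nullopt rather than an error.
    auto decoded = tensorflow::GetFrontendAttributesFromAttrSlice(
        tensorflow::AttrSlice(node));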

tensorflow/compiler/tf2xla/frontend_attributes_util.h (new file)

@@ -0,0 +1,38 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_

#include <string>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/node_def_util.h"

namespace tensorflow {

// Name of the node attribute used to carry frontend attributes.
extern const char kXlaFrontendAttributesAttrName[];

// Returns the FrontendAttributes stored in the AttrSlice, if any.
//
// Returns an InvalidArgument error if the attribute is present but cannot be
// parsed.
xla::StatusOr<absl::optional<xla::FrontendAttributes>>
GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_

tensorflow/compiler/tf2xla/functionalize_while.cc

@@ -24,6 +24,7 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/jit/union_find.h"
+#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
 #include "tensorflow/compiler/tf2xla/functionalize_cond.h"
 #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@@ -494,6 +495,12 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
   builder.Attr("cond", cond_name);
   builder.Attr("body", body_name);
   string outside_compilation;
+  string frontend_attributes;
+  if (GetNodeAttr(frame->loop_cond->def(), kXlaFrontendAttributesAttrName,
+                  &frontend_attributes)
+          .ok()) {
+    builder.Attr(kXlaFrontendAttributesAttrName, frontend_attributes);
+  }
   if (GetNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName,
                   &outside_compilation)
           .ok()) {

tensorflow/compiler/tf2xla/xla_compilation_device.cc

@@ -18,6 +18,7 @@ limitations under the License.
 #include <functional>
 #include <memory>
 
+#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/sharding_util.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
@@ -98,6 +99,20 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel,
   absl::optional<xla::OpSharding> op_sharding =
       sharding_parse_result.ValueOrDie();
 
+  auto frontend_attributes_result =
+      GetFrontendAttributesFromAttrSlice(AttrSlice(op_kernel->def()));
+  OP_REQUIRES_OK(context, frontend_attributes_result.status());
+  absl::optional<xla::FrontendAttributes> attributes =
+      frontend_attributes_result.ValueOrDie();
+
+  xla::FrontendAttributes merged_attributes = b->frontend_attributes();
+  if (attributes.has_value()) {
+    merged_attributes.mutable_map()->insert(attributes.value().map().begin(),
+                                            attributes.value().map().end());
+  }
+  xla::XlaScopedFrontendAttributesAssignment assign_frontend_attributes(
+      b, std::move(merged_attributes));
+
   // If no sharding metadata is found, XLA is free to use whatever device it
   // wants. In practice this usually has the effect of placing things on device
   // 0.
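
Note the merge direction above: assuming protobuf's Map::insert follows std::map semantics (insert does not overwrite an existing key), the builder-level value wins when a per-op attribute reuses the same key. A sketch of that behavior, with illustrative values:

    xla::FrontendAttributes builder_attrs, op_attrs;
    (*builder_attrs.mutable_map())["attr_a"] = "from_builder";
    (*op_attrs.mutable_map())["attr_a"] = "from_op";

    xla::FrontendAttributes merged = builder_attrs;
    merged.mutable_map()->insert(op_attrs.map().begin(), op_attrs.map().end());
    // merged.map().at("attr_a") == "from_builder": the existing key is kept.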

tensorflow/compiler/xla/client/xla_builder.cc

@@ -289,6 +289,15 @@ Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num,
   return Status::OK();
 }
 
+Status XlaBuilder::SetInstructionFrontendAttribute(const XlaOp op,
+                                                   std::string attribute,
+                                                   std::string value) {
+  TF_ASSIGN_OR_RETURN(auto instr_proto, LookUpMutableInstruction(op));
+  auto* frontend_attributes = instr_proto->mutable_frontend_attributes();
+  (*frontend_attributes->mutable_map())[attribute] = std::move(value);
+  return Status::OK();
+}
+
 XlaComputation XlaBuilder::BuildAndNoteError() {
   DCHECK(parent_builder_ != nullptr);
   auto build_status = Build();
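
A usage sketch for the new per-instruction setter (mirrors the AddFrontendAttribute test later in this commit; the builder and constant are illustrative):

    xla::XlaBuilder b("example");
    xla::XlaOp op = xla::ConstantR0<float>(&b, 1.0f);
    // Tags only this instruction; builder-level attributes are untouched.
    TF_CHECK_OK(b.SetInstructionFrontendAttribute(op, "attr_c", "c"));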
@@ -2626,6 +2635,7 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
   if (sharding_) {
     *instr.mutable_sharding() = *sharding_;
   }
+  *instr.mutable_frontend_attributes() = frontend_attributes_;
 
   handle_to_index_[handle] = instructions_.size();
   instructions_.push_back(std::move(instr));
@@ -2683,32 +2693,67 @@ void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
   }
 }
 
-StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
-    const XlaOp& op) const {
-  TF_RETURN_IF_ERROR(first_error_);
+namespace {
 
-  if (op.builder_ == nullptr) {
+template <typename InstructionType>
+StatusOr<InstructionType> LookUpInstructionByHandleInternal(
+    const absl::flat_hash_map<int64, int64>& handle_to_index,
+    const std::vector<HloInstructionProto>& instructions, int64 handle) {
+  auto it = handle_to_index.find(handle);
+  if (it == handle_to_index.end()) {
+    return InvalidArgument("No XlaOp with handle %d", handle);
+  }
+  return const_cast<InstructionType>(&instructions.at(it->second));
+}
+
+template <typename InstructionType, typename OpBuilderType,
+          typename BuilderType, typename OpType>
+StatusOr<InstructionType> LookUpInstructionInternal(
+    const absl::flat_hash_map<int64, int64>& handle_to_index,
+    const std::vector<HloInstructionProto>& instructions,
+    OpBuilderType op_builder, BuilderType builder, OpType op_handle) {
+  if (op_builder == nullptr) {
     return InvalidArgument(
         "invalid XlaOp with handle %d; the builder of this op is freed",
-        op.handle());
+        op_handle);
   }
-  if (op.builder_ != this) {
+  if (op_builder != builder) {
     return InvalidArgument(
         "XlaOp with handle %d is built by builder '%s', but is trying to use "
        "it in builder '%s'",
-        op.handle(), op.builder_->name(), this->name());
+        op_handle, op_builder->name(), builder->name());
   }
-  return LookUpInstructionByHandle(op.handle());
+  return LookUpInstructionByHandleInternal<InstructionType>(
+      handle_to_index, instructions, op_handle);
+}
+
+}  // namespace
+
+StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
+    const XlaOp op) const {
+  TF_RETURN_IF_ERROR(first_error_);
+  return LookUpInstructionInternal<const HloInstructionProto*>(
+      handle_to_index_, instructions_, op.builder_, this, op.handle());
 }
 
 StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
     int64 handle) const {
-  auto it = handle_to_index_.find(handle);
-  if (it == handle_to_index_.end()) {
-    return InvalidArgument("No XlaOp with handle %d", handle);
-  }
-  return &instructions_[it->second];
+  return LookUpInstructionByHandleInternal<const HloInstructionProto*>(
+      handle_to_index_, instructions_, handle);
+}
+
+StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstruction(
+    const XlaOp op) {
+  TF_RETURN_IF_ERROR(first_error_);
+  return LookUpInstructionInternal<HloInstructionProto*>(
+      handle_to_index_, instructions_, op.builder_, this, op.handle());
+}
+
+StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstructionByHandle(
+    int64 handle) {
+  return LookUpInstructionByHandleInternal<HloInstructionProto*>(
+      handle_to_index_, instructions_, handle);
 }
 
 // Enqueues a "retrieve parameter value" instruction for a parameter that was

tensorflow/compiler/xla/client/xla_builder.h

@@ -147,8 +147,8 @@ class XlaBuilder {
   // Sets OpMetadata that will be added to all instructions until cleared.
   //
   // OpMetadata is often applied to a series of XLA HLO instructions. As a
-  // result, OpMetadata is set on the Computation Builder. All subsequent
-  // instructions generated via this Computation Builder will have the same
+  // result, OpMetadata is set on the computation builder. All subsequent
+  // instructions generated via this computation builder will have the same
   // OpMetadata attached until a call to ClearOpMetadata.
   void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); }
@@ -158,6 +158,35 @@ class XlaBuilder {
   // Sets an OpSharding that will be attached to all instructions until cleared.
   void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
 
+  // Sets the FrontendAttributes that will be added to all instructions until
+  // cleared.
+  //
+  // FrontendAttributes are often applied to a series of XLA HLO instructions.
+  // As a result they are set on the computation builder, and all the
+  // instructions generated via the computation builder will have the same
+  // frontend attributes attached to them.
+  void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) {
+    frontend_attributes_ = frontend_attributes;
+  }
+
+  // Swaps the passed FrontendAttributes with the ones currently set.
+  //
+  // Returns the old attributes.
+  FrontendAttributes SwapFrontendAttributes(
+      const FrontendAttributes& frontend_attributes) {
+    FrontendAttributes old_attributes = std::move(frontend_attributes_);
+    frontend_attributes_ = frontend_attributes;
+    return old_attributes;
+  }
+
+  // Returns the FrontendAttributes that will be attached to all instructions.
+  const FrontendAttributes& frontend_attributes() const {
+    return frontend_attributes_;
+  }
+
+  // Clears all the frontend attributes.
+  void ClearFrontendAttributes() { frontend_attributes_.Clear(); }
+
   // Clears the sharding. Ops will be sharded according to the default placement
   // policy.
   void ClearSharding() { sharding_ = absl::nullopt; }
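
A brief sketch of the builder-level API (mirrors the SimpleSetFrontendAttributes test later in this commit; the builder name is illustrative):

    xla::XlaBuilder b("example");
    xla::FrontendAttributes attrs;
    (*attrs.mutable_map())["attr_a"] = "a";

    b.SetFrontendAttributes(attrs);
    xla::ConstantR0<float>(&b, 0.0f);  // Tagged with attr_a=a.

    b.ClearFrontendAttributes();
    xla::ConstantR0<float>(&b, 0.0f);  // Carries no frontend attributes.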
@@ -314,6 +343,16 @@ class XlaBuilder {
     ShapeIndex param_index;
   };
 
+  // Looks up the HloInstruction and sets the frontend attribute "attribute" to
+  // "value".
+  //
+  // If the attribute already existed then its value is updated.
+  //
+  // Note: the attribute is only added to the HloInstruction, not to the
+  // builder.
+  Status SetInstructionFrontendAttribute(XlaOp op, string attribute,
+                                         string value);
+
  private:
   // Build helper which takes the id of the root operation.
   StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions);
@@ -595,9 +634,11 @@ class XlaBuilder {
   void AddCalledComputation(const XlaComputation& computation,
                             HloInstructionProto* instr);
 
-  StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
+  StatusOr<const HloInstructionProto*> LookUpInstruction(XlaOp op) const;
   StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
       int64 handle) const;
+  StatusOr<HloInstructionProto*> LookUpMutableInstruction(XlaOp op);
+  StatusOr<HloInstructionProto*> LookUpMutableInstructionByHandle(int64 handle);
 
   // Internal helper method that does the building for an arbitrary unary op.
   XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
@@ -707,6 +748,8 @@ class XlaBuilder {
   XlaBuilder* parent_builder_{nullptr};
 
+  FrontendAttributes frontend_attributes_;
+
   friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
                          const Shape& shape, const string& name,
                          const std::vector<bool>& replicated_at_leaf_buffers);
@@ -1034,6 +1077,27 @@ class XlaScopedShardingAssignment {
   absl::optional<OpSharding> prev_sharding_;
 };
 
+// RAII-style object: saves the current builder's frontend attributes and
+// replaces them with the given ones on construction (callers merge first if
+// needed); restores the original attributes on destruction.
+class XlaScopedFrontendAttributesAssignment {
+ public:
+  XlaScopedFrontendAttributesAssignment(xla::XlaBuilder* builder,
+                                        FrontendAttributes attributes)
+      : builder_(builder) {
+    saved_ = builder_->SwapFrontendAttributes(attributes);
+  }
+
+  ~XlaScopedFrontendAttributesAssignment() {
+    builder_->SetFrontendAttributes(saved_);
+  }
+
+ private:
+  xla::XlaBuilder* const builder_;
+  FrontendAttributes saved_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedFrontendAttributesAssignment);
+};
+
 // Free functions for building XlaOps. The intention is that these will
 // become the public API for building XlaOps rather than calling methods on
 // XlaBuilder directly.
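
A hedged usage sketch of the RAII helper (builder and attribute names are illustrative, not from the commit):

    xla::XlaBuilder b("example");
    xla::FrontendAttributes scoped_attrs;
    (*scoped_attrs.mutable_map())["attr_a"] = "a";
    {
      // Swaps scoped_attrs in for the duration of the scope.
      xla::XlaScopedFrontendAttributesAssignment scope(&b, scoped_attrs);
      xla::ConstantR0<float>(&b, 0.0f);  // Tagged with attr_a=a.
    }  // Destructor restores the builder's previous attributes.
    xla::ConstantR0<float>(&b, 0.0f);  // Back to the prior attribute set.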

tensorflow/compiler/xla/client/xla_builder_test.cc

@@ -978,5 +978,151 @@ TEST_F(XlaBuilderTest, CheckInputOutputAlias) {
   EXPECT_EQ(*alias_p1, ShapeIndex({0}));
 }
 
+void ExpectAttributesMatch(const FrontendAttributes& attr,
+                           const FrontendAttributes& ref) {
+  EXPECT_EQ(ref.map_size(), attr.map_size());
+  for (auto reference : ref.map()) {
+    auto other = attr.map().find(reference.first);
+    EXPECT_NE(other, attr.map().end());
+    EXPECT_EQ(other->second, reference.second);
+  }
+}
+
+void ExpectInstructionsAttributesMatch(
+    const HloModule& module, const std::vector<FrontendAttributes>& expected) {
+  ASSERT_EQ(module.computation_count(), 1);
+  auto expected_it = expected.begin();
+  for (auto inst : module.entry_computation()->instructions()) {
+    ASSERT_NE(expected_it, expected.end());
+    ExpectAttributesMatch(inst->frontend_attributes(), *expected_it);
+    expected_it++;
+  }
+  EXPECT_EQ(expected_it, expected.end());
+}
+
+TEST_F(XlaBuilderTest, SimpleSetFrontendAttributes) {
+  XlaBuilder b(TestName());
+  FrontendAttributes attributes;
+
+  ConstantR0(&b, 0);  // No attribute set.
+
+  (*attributes.mutable_map())["attr_a"] = "a";
+  b.SetFrontendAttributes(attributes);
+  ConstantR0(&b, 0);  // One attribute: { "attr_a": "a" }
+
+  b.ClearFrontendAttributes();
+  ConstantR0(&b, 0);  // No attribute set.
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+
+  std::vector<FrontendAttributes> expected{FrontendAttributes(), attributes,
+                                           FrontendAttributes()};
+  ExpectInstructionsAttributesMatch(*module, expected);
+}
+
+TEST_F(XlaBuilderTest, ComplexSetFrontendAttributes) {
+  XlaBuilder b(TestName());
+
+  ConstantR0(&b, 0);  // No attribute set.
+  std::vector<FrontendAttributes> expected{FrontendAttributes()};
+
+  {
+    FrontendAttributes attributes;
+    (*attributes.mutable_map())["attr_a"] = "a";
+    b.SetFrontendAttributes(attributes);
+    ConstantR0(&b, 0);  // One attribute: { "attr_a": "a" }
+    expected.push_back(attributes);
+  }
+
+  {
+    FrontendAttributes attributes;
+    (*attributes.mutable_map())["attr_b"] = "b";
+    b.SetFrontendAttributes(attributes);
+    ConstantR0(&b, 0);  // One attribute: { "attr_b": "b" }
+    expected.push_back(attributes);
+  }
+
+  {
+    FrontendAttributes attributes;
+    (*attributes.mutable_map())["attr_b"] = "b";
+    (*attributes.mutable_map())["attr_c"] = "c";
+    b.SetFrontendAttributes(attributes);
+    ConstantR0(&b, 0);  // Two attributes: { "attr_b": "b", "attr_c": "c" }
+    expected.push_back(attributes);
+  }
+
+  b.ClearFrontendAttributes();
+  ConstantR0(&b, 0);  // No attribute set.
+  expected.push_back(FrontendAttributes());
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  ExpectInstructionsAttributesMatch(*module, expected);
+}
+
+TEST_F(XlaBuilderTest, AddFrontendAttribute) {
+  XlaBuilder b(TestName());
+
+  ConstantR0(&b, 0);
+  std::vector<FrontendAttributes> expected{FrontendAttributes()};
+
+  // One attribute: { "attr_a": "a" }
+  {
+    FrontendAttributes attributes;
+    (*attributes.mutable_map())["attr_a"] = "a";
+    b.SetFrontendAttributes(attributes);
+    ConstantR0(&b, 0);
+    expected.push_back(attributes);
+  }
+
+  // Two attributes: { "attr_a": "a", "attr_c": "c" }
+  {
+    auto op = ConstantR0(&b, 0);
+    EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_c", "c"));
+
+    FrontendAttributes attributes;
+    (*attributes.mutable_map())["attr_a"] = "a";
+    (*attributes.mutable_map())["attr_c"] = "c";
+    expected.push_back(attributes);
+  }
+
+  // Override the value of the existing "attr_a".
+  // One attribute: { "attr_a": "a2" }
+  {
+    auto op = ConstantR0(&b, 0);
+    EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_a", "a2"));
+
+    FrontendAttributes attributes;
+    (*attributes.mutable_map())["attr_a"] = "a2";
+    expected.push_back(attributes);
+  }
+
+  // Check that "attr_a" is back to its original value.
+  // One attribute: { "attr_a": "a" }
+  {
+    auto op = ConstantR0(&b, 0);
+    (void)op;
+    FrontendAttributes attributes;
+    (*attributes.mutable_map())["attr_a"] = "a";
+    expected.push_back(attributes);
+  }
+
+  b.ClearFrontendAttributes();
+  ConstantR0(&b, 0);  // No attribute set.
+  expected.push_back(FrontendAttributes());
+
+  // One attribute: { "attr_d": "d" }
+  {
+    auto op = ConstantR0(&b, 0);
+    EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_d", "d"));
+
+    FrontendAttributes attributes;
+    (*attributes.mutable_map())["attr_d"] = "d";
+    expected.push_back(attributes);
+  }
+
+  ConstantR0(&b, 0);  // No attribute set.
+  expected.push_back(FrontendAttributes());
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  ExpectInstructionsAttributesMatch(*module, expected);
+}
+
 }  // namespace
 }  // namespace xla

tensorflow/compiler/xla/service/hlo.proto

@@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
 option cc_enable_arenas = true;
 
 // Serialization of HloInstruction.
-// Next ID: 68
+// Next ID: 69
 message HloInstructionProto {
   reserved 10;
   reserved "parameter_name";
@@ -234,6 +234,9 @@ message HloInstructionProto {
   // Specifies if the gather/scatter indices are guaranteed to be sorted by the
   // caller.
   bool indices_are_sorted = 67;
+
+  // Frontend attributes to pass to the XLA backend.
+  xla.FrontendAttributes frontend_attributes = 68;
 }
 
 // Serialization of HloComputation.
// Serialization of HloComputation. // Serialization of HloComputation.

tensorflow/compiler/xla/service/hlo_computation.cc

@@ -837,6 +837,10 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
   if (new_instruction->metadata().op_name().empty()) {
     new_instruction->set_metadata(old_instruction->metadata());
   }
+  if (new_instruction->frontend_attributes().map().empty()) {
+    new_instruction->set_frontend_attributes(
+        old_instruction->frontend_attributes());
+  }
 
   // Like the metadata above, if the user didn't specify any sharding
   // information on the new instruction we should copy the old sharding

tensorflow/compiler/xla/service/hlo_instruction.cc

@@ -674,6 +674,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
     instruction->set_sharding(sharding);
   }
 
+  if (proto.has_frontend_attributes()) {
+    instruction->set_frontend_attributes(proto.frontend_attributes());
+  }
+
   return std::move(instruction);
 }
@@ -1194,6 +1198,7 @@ HloInstruction::CreateBroadcastSequence(
     if (operand->has_sharding()) {
       broadcast->set_sharding(operand->sharding());
     }
+    broadcast->set_frontend_attributes(operand->frontend_attributes());
     return broadcast;
   }
 
   // Do explicit broadcast for degenerate broadcast.
@@ -1219,6 +1224,7 @@ HloInstruction::CreateBroadcastSequence(
     if (operand->has_sharding()) {
       reshaped_operand->set_sharding(operand->sharding());
     }
+    reshaped_operand->set_frontend_attributes(operand->frontend_attributes());
 
     // Broadcast 'reshape' up to the larger size.
     auto broadcast = HloInstruction::CreateBroadcast(
         broadcast_shape, reshaped_operand, broadcast_dimensions);
@@ -1226,6 +1232,7 @@ HloInstruction::CreateBroadcastSequence(
     if (operand->has_sharding()) {
       broadcast->set_sharding(operand->sharding());
     }
+    broadcast->set_frontend_attributes(operand->frontend_attributes());
     return broadcast;
 }
@@ -1296,6 +1303,7 @@ void HloInstruction::SetupDerivedInstruction(
     derived_instruction->clear_sharding();
   }
   derived_instruction->set_metadata(metadata_);
+  derived_instruction->set_frontend_attributes(frontend_attributes_);
 }
 
 bool HloInstruction::HasSideEffectNoRecurse() const {
@@ -2483,6 +2491,10 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
   if (has_sharding()) {
     extra.push_back(StrCat("sharding=", sharding().ToString()));
   }
+  if (!frontend_attributes_.map().empty()) {
+    extra.push_back(StrCat("frontend_attributes=",
+                           FrontendAttributesToString(frontend_attributes_)));
+  }
   if (!outer_dimension_partitions_.empty()) {
     extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}",
                                     StrJoin(outer_dimension_partitions_, ",")));
@@ -2543,6 +2555,8 @@ HloInstructionProto HloInstruction::ToProto() const {
     }
   }
 
+  *proto.mutable_frontend_attributes() = frontend_attributes_;
+
   return proto;
 }
@@ -3197,6 +3211,15 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
   return InvalidArgument("Unknown fusion kind: %s", kind_name);
 }
 
+string FrontendAttributesToString(
+    const FrontendAttributes& frontend_attributes) {
+  std::vector<std::pair<string, string>> sorted_attributes(
+      frontend_attributes.map().begin(), frontend_attributes.map().end());
+  absl::c_sort(sorted_attributes);
+  return absl::StrFormat(
+      "{%s}", absl::StrJoin(sorted_attributes, ",", absl::PairFormatter("=")));
+}
+
 string PaddingConfigToString(const PaddingConfig& padding) {
   bool has_interior_padding =
       absl::c_any_of(padding.dimensions(),

tensorflow/compiler/xla/service/hlo_instruction.h

@@ -1385,6 +1385,14 @@ class HloInstruction {
   }
   Status set_backend_config(const tensorflow::protobuf::Message& proto);
 
+  void set_frontend_attributes(FrontendAttributes frontend_attributes) {
+    frontend_attributes_ = std::move(frontend_attributes);
+  }
+
+  const FrontendAttributes& frontend_attributes() const {
+    return frontend_attributes_;
+  }
+
   // Getter/setter for raw JSON-encoded backend config. Prefer the
   // functions above that deal in proto Messages where possible.
   const string& raw_backend_config_string() const { return backend_config_; }
@@ -1879,6 +1887,18 @@ class HloInstruction {
   // HLO. See the documentation on backend_config().
   string backend_config_;
 
+  // Attributes passed from the frontend to give hints to the backend about
+  // how to compile this HLO.
+  // HLO -> HLO transforms are expected to preserve these attributes on a
+  // "best effort" basis only.
+  // For example:
+  //    x = const(10), frontend_attributes={x}
+  //    y = const(10), frontend_attributes={y}
+  //    z = add(x, y), frontend_attributes={y}
+  // could be simplified to:
+  //    z' = const(20), frontend_attributes={?}
+  FrontendAttributes frontend_attributes_;
+
   // This field is assigned to true when backend_config_ is assigned to
   // a default configuration.
   bool is_default_config_ = false;
@@ -1909,6 +1929,8 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
 // Custom (de)stringification functions for protos that live inside
 // HloInstruction.
 string PaddingConfigToString(const PaddingConfig& padding);
+string FrontendAttributesToString(
+    const FrontendAttributes& frontend_attributes);
 string OpMetadataToString(const OpMetadata& metadata);
 string RandomDistributionToString(const RandomDistribution& distribution);
 string PrecisionToString(const PrecisionConfig::Precision& precision);

tensorflow/compiler/xla/service/hlo_parser.cc

@@ -88,6 +88,7 @@ class HloParser {
   // Stand alone parsing utils for various aggregate data types.
   StatusOr<Shape> ParseShapeOnly();
   StatusOr<HloSharding> ParseShardingOnly();
+  StatusOr<FrontendAttributes> ParseFrontendAttributesOnly();
   StatusOr<std::vector<bool>> ParseParameterReplicationOnly();
   StatusOr<Window> ParseWindowOnly();
   StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
@@ -192,6 +193,7 @@ class HloParser {
     kWindow,
     kConvolutionDimensionNumbers,
     kSharding,
+    kFrontendAttributes,
     kParameterReplication,
     kInstructionList,
     kSliceRanges,
@@ -271,6 +273,7 @@ class HloParser {
   bool ParsePaddingConfig(PaddingConfig* padding);
   bool ParseMetadata(OpMetadata* metadata);
   bool ParseSharding(OpSharding* sharding);
+  bool ParseFrontendAttributes(FrontendAttributes* frontend_attributes);
   bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
   bool ParseParameterReplication(ParameterReplication* parameter_replication);
   bool ParseReplicaGroupsOnly(std::vector<ReplicaGroup>* replica_groups);
@@ -677,7 +680,10 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
   // Add optional attributes.
   std::unordered_map<string, AttrConfig> attrs;
   optional<OpSharding> sharding;
+  optional<FrontendAttributes> frontend_attributes;
   attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
+  attrs["frontend_attributes"] = {
+      /*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes};
   optional<ParameterReplication> parameter_replication;
   attrs["parameter_replication"] = {/*required=*/false,
                                     AttrTy::kParameterReplication,
@@ -1845,6 +1851,36 @@ bool HloParser::ParseSharding(OpSharding* sharding) {
   return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
 }
 
+// frontend_attributes ::= '{' attributes '}'
+// attributes
+//   ::= /*empty*/
+//   ::= attribute '=' value (',' attribute '=' value)*
+bool HloParser::ParseFrontendAttributes(
+    FrontendAttributes* frontend_attributes) {
+  CHECK(frontend_attributes != nullptr);
+  if (!ParseToken(TokKind::kLbrace,
+                  "expected '{' to start frontend attributes")) {
+    return false;
+  }
+  if (lexer_.GetKind() == TokKind::kRbrace) {
+    // empty
+  } else {
+    do {
+      string attribute;
+      if (!ParseAttributeName(&attribute)) {
+        return false;
+      }
+      if (lexer_.GetKind() != TokKind::kIdent) {
+        return false;
+      }
+      (*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal();
+      lexer_.Lex();
+    } while (EatIfPresent(TokKind::kComma));
+  }
+  return ParseToken(TokKind::kRbrace,
+                    "expects '}' at the end of frontend attributes");
+}
+
 //  ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape?
 //          ('devices=' ('[' dims ']')* device_list)? '}'
 // dims ::= int_list device_list ::= int_list
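
Per the frontend_attributes grammar above, the attribute body round-trips through the new public entry point like so (a hedged sketch; mirrors the ParseFrontendAttributes parser test below):

    // Keys are read with ParseAttributeName; values must be identifier
    // tokens (TokKind::kIdent).
    auto attrs_or = xla::ParseFrontendAttributes("{attr_a=test_a,attr_b=b}");
    TF_CHECK_OK(attrs_or.status());
    // FrontendAttributesToString(attrs_or.ValueOrDie()) reproduces the
    // input, with keys in sorted order.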
@@ -2864,6 +2900,15 @@ bool HloParser::ParseAttributeHelper(
       static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
       return true;
     }
+    case AttrTy::kFrontendAttributes: {
+      FrontendAttributes frontend_attributes;
+      if (!ParseFrontendAttributes(&frontend_attributes)) {
+        return false;
+      }
+      static_cast<optional<FrontendAttributes>*>(attr_out_ptr)
+          ->emplace(frontend_attributes);
+      return true;
+    }
     case AttrTy::kParameterReplication: {
       ParameterReplication parameter_replication;
       if (!ParseParameterReplication(&parameter_replication)) {
@@ -4120,6 +4165,19 @@ StatusOr<HloSharding> HloParser::ParseShardingOnly() {
   return HloSharding::FromProto(op_sharding);
 }
 
+StatusOr<FrontendAttributes> HloParser::ParseFrontendAttributesOnly() {
+  lexer_.Lex();
+  FrontendAttributes attributes;
+  if (!ParseFrontendAttributes(&attributes)) {
+    return InvalidArgument("Syntax error:\n%s", GetError());
+  }
+  if (lexer_.GetKind() != TokKind::kEof) {
+    return InvalidArgument(
+        "Syntax error:\nExtra content after frontend attributes");
+  }
+  return attributes;
+}
+
 StatusOr<std::vector<bool>> HloParser::ParseParameterReplicationOnly() {
   lexer_.Lex();
   ParameterReplication parameter_replication;
@@ -4268,6 +4326,11 @@ StatusOr<HloSharding> ParseSharding(absl::string_view str) {
   return parser.ParseShardingOnly();
 }
 
+StatusOr<FrontendAttributes> ParseFrontendAttributes(absl::string_view str) {
+  HloParser parser(str);
+  return parser.ParseFrontendAttributesOnly();
+}
+
 StatusOr<std::vector<bool>> ParseParameterReplication(absl::string_view str) {
   HloParser parser(str);
   return parser.ParseParameterReplicationOnly();

tensorflow/compiler/xla/service/hlo_parser.h

@@ -54,6 +54,12 @@ Status ParseHloString(absl::string_view str, HloModule* module);
 // "{replicated}".
 StatusOr<HloSharding> ParseSharding(absl::string_view str);
 
+// Parses frontend attributes from str. str is supposed to contain the body of
+// the frontend attributes, i.e. just the rhs of the
+// "frontend_attributes={...}" attribute string, e.g.,
+// "{attr_a=a,attr_b=b}".
+StatusOr<FrontendAttributes> ParseFrontendAttributes(absl::string_view str);
+
 // Parses parameter replication from str. str is supposed to contain the body of
 // the parameter replication, i.e. just the rhs of the
 // "parameter_replication={...}" attribute string, e.g., "{true, false}".

tensorflow/compiler/xla/service/hlo_parser_test.cc

@@ -2358,6 +2358,13 @@ TEST_F(HloParserTest, ParseSharding) {
   EXPECT_EQ(sharding.ToString(), original);
 }
 
+TEST_F(HloParserTest, ParseFrontendAttributes) {
+  const string original = "{attr_a=test_a,attr_b=b}";
+  TF_ASSERT_OK_AND_ASSIGN(FrontendAttributes frontend_attributes,
+                          ParseFrontendAttributes(original));
+  EXPECT_EQ(FrontendAttributesToString(frontend_attributes), original);
+}
+
 TEST_F(HloParserTest, ParseWindow) {
   Window original = window_util::MakeWindow({1, 2, 3});
   TF_ASSERT_OK_AND_ASSIGN(Window parsed,

tensorflow/compiler/xla/xla_data.proto

@@ -583,6 +583,12 @@ message CholeskyOptions {
   bool lower = 1;
 }
 
+// Generic map of attributes used to pass hints / configuration options from
+// the Python frontend to the XLA backend.
+message FrontendAttributes {
+  map<string, string> map = 1;
+}
+
 message OpSharding {
   enum Type {
     // This sharding is replicated across all devices (implies maximal,
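
For reference, a hypothetical text-format instance of the FrontendAttributes message (proto3 maps serialize as repeated key/value entries), as it might appear embedded in an HloInstructionProto:

    frontend_attributes {
      map { key: "attr_a" value: "a" }
      map { key: "attr_b" value: "b" }
    }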