Merge pull request #31129 from AnthonyBarbier:frontend_attributes
PiperOrigin-RevId: 265793937
This commit is contained in:
commit
47612c19ea
@ -203,14 +203,15 @@ cc_library(
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
":common",
|
||||
":frontend_attributes_util",
|
||||
":host_compute_metadata_proto",
|
||||
":rearrange_function_argument",
|
||||
":sharding_util",
|
||||
":side_effect_util",
|
||||
":tf2xla_util",
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
"//tensorflow/compiler/jit:shape_inference",
|
||||
"//tensorflow/compiler/jit:xla_cluster_util",
|
||||
"//tensorflow/compiler/tf2xla:rearrange_function_argument",
|
||||
"//tensorflow/compiler/tf2xla/lib:util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -271,6 +272,21 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "frontend_attributes_util",
|
||||
srcs = ["frontend_attributes_util.cc"],
|
||||
hdrs = ["frontend_attributes_util.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "sharding_util",
|
||||
srcs = ["sharding_util.cc"],
|
||||
@ -579,6 +595,7 @@ cc_library(
|
||||
"functionalize_while.h",
|
||||
],
|
||||
deps = [
|
||||
":frontend_attributes_util",
|
||||
":functionalize_cond",
|
||||
":functionalize_control_flow_util",
|
||||
":tf2xla_util",
|
||||
|
41
tensorflow/compiler/tf2xla/frontend_attributes_util.cc
Normal file
41
tensorflow/compiler/tf2xla/frontend_attributes_util.cc
Normal file
@ -0,0 +1,41 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const char kXlaFrontendAttributesAttrName[] = "_XlaFrontendAttributes";
|
||||
|
||||
xla::StatusOr<absl::optional<xla::FrontendAttributes>>
|
||||
GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs) {
|
||||
const AttrValue* attr = attrs.Find(kXlaFrontendAttributesAttrName);
|
||||
if (attr == nullptr) {
|
||||
return xla::StatusOr<absl::optional<xla::FrontendAttributes>>(
|
||||
absl::nullopt);
|
||||
}
|
||||
xla::FrontendAttributes attributes;
|
||||
if (!attributes.ParseFromString(attr->s())) {
|
||||
return errors::InvalidArgument(
|
||||
"Experimental _XlaFrontendAttributes attribute was not a valid encoded "
|
||||
"xla::FrontendAttributes proto.");
|
||||
}
|
||||
return absl::optional<xla::FrontendAttributes>(attributes);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
38
tensorflow/compiler/tf2xla/frontend_attributes_util.h
Normal file
38
tensorflow/compiler/tf2xla/frontend_attributes_util.h
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Frontend Attributes Id.
|
||||
extern const char kXlaFrontendAttributesAttrName[];
|
||||
// Return the FrontendAttributes stored in the AttrSlice if there are some.
|
||||
//
|
||||
// Return an InvalidArgument error if some attributes are present but
|
||||
// cannot be parsed.
|
||||
xla::StatusOr<absl::optional<xla::FrontendAttributes>>
|
||||
GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/jit/union_find.h"
|
||||
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
|
||||
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||
@ -494,6 +495,12 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
|
||||
builder.Attr("cond", cond_name);
|
||||
builder.Attr("body", body_name);
|
||||
string outside_compilation;
|
||||
string frontend_attributes;
|
||||
if (GetNodeAttr(frame->loop_cond->def(), kXlaFrontendAttributesAttrName,
|
||||
&frontend_attributes)
|
||||
.ok()) {
|
||||
builder.Attr(kXlaFrontendAttributesAttrName, frontend_attributes);
|
||||
}
|
||||
if (GetNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName,
|
||||
&outside_compilation)
|
||||
.ok()) {
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
@ -98,6 +99,20 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel,
|
||||
absl::optional<xla::OpSharding> op_sharding =
|
||||
sharding_parse_result.ValueOrDie();
|
||||
|
||||
auto frontend_attributes_result =
|
||||
GetFrontendAttributesFromAttrSlice(AttrSlice(op_kernel->def()));
|
||||
OP_REQUIRES_OK(context, frontend_attributes_result.status());
|
||||
absl::optional<xla::FrontendAttributes> attributes =
|
||||
frontend_attributes_result.ValueOrDie();
|
||||
|
||||
xla::FrontendAttributes merged_attributes = b->frontend_attributes();
|
||||
if (attributes.has_value()) {
|
||||
merged_attributes.mutable_map()->insert(attributes.value().map().begin(),
|
||||
attributes.value().map().end());
|
||||
}
|
||||
xla::XlaScopedFrontendAttributesAssignment assign_frontend_attributes(
|
||||
b, std::move(merged_attributes));
|
||||
|
||||
// If no sharding metadata is found, XLA is free to use whatever device it
|
||||
// wants. In practice this usually has the effect of placing things on device
|
||||
// 0.
|
||||
|
@ -289,6 +289,15 @@ Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status XlaBuilder::SetInstructionFrontendAttribute(const XlaOp op,
|
||||
std::string attribute,
|
||||
std::string value) {
|
||||
TF_ASSIGN_OR_RETURN(auto instr_proto, LookUpMutableInstruction(op));
|
||||
auto* frontend_attributes = instr_proto->mutable_frontend_attributes();
|
||||
(*frontend_attributes->mutable_map())[attribute] = std::move(value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
XlaComputation XlaBuilder::BuildAndNoteError() {
|
||||
DCHECK(parent_builder_ != nullptr);
|
||||
auto build_status = Build();
|
||||
@ -2626,6 +2635,7 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
|
||||
if (sharding_) {
|
||||
*instr.mutable_sharding() = *sharding_;
|
||||
}
|
||||
*instr.mutable_frontend_attributes() = frontend_attributes_;
|
||||
|
||||
handle_to_index_[handle] = instructions_.size();
|
||||
instructions_.push_back(std::move(instr));
|
||||
@ -2683,32 +2693,67 @@ void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
|
||||
const XlaOp& op) const {
|
||||
TF_RETURN_IF_ERROR(first_error_);
|
||||
namespace {
|
||||
|
||||
if (op.builder_ == nullptr) {
|
||||
template <typename InstructionType>
|
||||
StatusOr<InstructionType> LookUpInstructionByHandleInternal(
|
||||
const absl::flat_hash_map<int64, int64>& handle_to_index,
|
||||
const std::vector<HloInstructionProto>& instructions, int64 handle) {
|
||||
auto it = handle_to_index.find(handle);
|
||||
if (it == handle_to_index.end()) {
|
||||
return InvalidArgument("No XlaOp with handle %d", handle);
|
||||
}
|
||||
return const_cast<InstructionType>(&instructions.at(it->second));
|
||||
}
|
||||
|
||||
template <typename InstructionType, typename OpBuilderType,
|
||||
typename BuilderType, typename OpType>
|
||||
StatusOr<InstructionType> LookUpInstructionInternal(
|
||||
const absl::flat_hash_map<int64, int64>& handle_to_index,
|
||||
const std::vector<HloInstructionProto>& instructions,
|
||||
OpBuilderType op_builder, BuilderType builder, OpType op_handle) {
|
||||
if (op_builder == nullptr) {
|
||||
return InvalidArgument(
|
||||
"invalid XlaOp with handle %d; the builder of this op is freed",
|
||||
op.handle());
|
||||
op_handle);
|
||||
}
|
||||
if (op.builder_ != this) {
|
||||
if (op_builder != builder) {
|
||||
return InvalidArgument(
|
||||
"XlaOp with handle %d is built by builder '%s', but is trying to use "
|
||||
"it in builder '%s'",
|
||||
op.handle(), op.builder_->name(), this->name());
|
||||
op_handle, op_builder->name(), builder->name());
|
||||
}
|
||||
|
||||
return LookUpInstructionByHandle(op.handle());
|
||||
return LookUpInstructionByHandleInternal<InstructionType>(
|
||||
handle_to_index, instructions, op_handle);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
|
||||
const XlaOp op) const {
|
||||
TF_RETURN_IF_ERROR(first_error_);
|
||||
return LookUpInstructionInternal<const HloInstructionProto*>(
|
||||
handle_to_index_, instructions_, op.builder_, this, op.handle());
|
||||
}
|
||||
|
||||
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
|
||||
int64 handle) const {
|
||||
auto it = handle_to_index_.find(handle);
|
||||
if (it == handle_to_index_.end()) {
|
||||
return InvalidArgument("No XlaOp with handle %d", handle);
|
||||
}
|
||||
return &instructions_[it->second];
|
||||
return LookUpInstructionByHandleInternal<const HloInstructionProto*>(
|
||||
handle_to_index_, instructions_, handle);
|
||||
}
|
||||
|
||||
StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstruction(
|
||||
const XlaOp op) {
|
||||
TF_RETURN_IF_ERROR(first_error_);
|
||||
return LookUpInstructionInternal<HloInstructionProto*>(
|
||||
handle_to_index_, instructions_, op.builder_, this, op.handle());
|
||||
}
|
||||
|
||||
StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstructionByHandle(
|
||||
int64 handle) {
|
||||
return LookUpInstructionByHandleInternal<HloInstructionProto*>(
|
||||
handle_to_index_, instructions_, handle);
|
||||
}
|
||||
|
||||
// Enqueues a "retrieve parameter value" instruction for a parameter that was
|
||||
|
@ -147,8 +147,8 @@ class XlaBuilder {
|
||||
// Sets OpMetadata that will be added to all instructions until cleared.
|
||||
//
|
||||
// OpMetadata is often applied to a series of XLA HLO instructions. As a
|
||||
// result, OpMetadata is set on the Computation Builder. All subsequent
|
||||
// instructions generated via this Computation Builder will have the same
|
||||
// result, OpMetadata is set on the computation builder. All subsequent
|
||||
// instructions generated via this computation builder will have the same
|
||||
// OpMetadata attached until a call to ClearOpMetadata.
|
||||
void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); }
|
||||
|
||||
@ -158,6 +158,35 @@ class XlaBuilder {
|
||||
// Sets an OpSharding that will be attached to all instructions until cleared.
|
||||
void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
|
||||
|
||||
// Sets the FrontendAttributes that will be added to all instructions until
|
||||
// cleared.
|
||||
//
|
||||
// FrontendAttributes are often applied to a series of XLA HLO instructions.
|
||||
// As a result they are set on the computation builder and all the
|
||||
// instructions generated via the computation builder will have the same
|
||||
// frontend attributes attached to them.
|
||||
void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) {
|
||||
frontend_attributes_ = frontend_attributes;
|
||||
}
|
||||
|
||||
// Swap the passed FrontendAttributes with the ones currently set.
|
||||
//
|
||||
// Return the old attributes.
|
||||
FrontendAttributes SwapFrontendAttributes(
|
||||
const FrontendAttributes& frontend_attributes) {
|
||||
FrontendAttributes old_attributes = std::move(frontend_attributes_);
|
||||
frontend_attributes_ = frontend_attributes;
|
||||
return old_attributes;
|
||||
}
|
||||
|
||||
// Returns the FrontendAttributes that will be attached to all instructions.
|
||||
const FrontendAttributes& frontend_attributes() const {
|
||||
return frontend_attributes_;
|
||||
}
|
||||
|
||||
// Clears all the frontend attributes.
|
||||
void ClearFrontendAttributes() { frontend_attributes_.Clear(); }
|
||||
|
||||
// Clears the sharding. Ops will be sharded according to the default placement
|
||||
// policy.
|
||||
void ClearSharding() { sharding_ = absl::nullopt; }
|
||||
@ -314,6 +343,16 @@ class XlaBuilder {
|
||||
ShapeIndex param_index;
|
||||
};
|
||||
|
||||
// Looks up the HloInstruction and sets the frontend attribute "attribute" to
|
||||
// "value".
|
||||
//
|
||||
// If the attribute already existed then its value is updated.
|
||||
//
|
||||
// Note: the attribute is only added to the HloInstruction, not to the
|
||||
// builder.
|
||||
Status SetInstructionFrontendAttribute(XlaOp op, string attribute,
|
||||
string value);
|
||||
|
||||
private:
|
||||
// Build helper which takes the id of the root operation..
|
||||
StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions);
|
||||
@ -595,9 +634,11 @@ class XlaBuilder {
|
||||
void AddCalledComputation(const XlaComputation& computation,
|
||||
HloInstructionProto* instr);
|
||||
|
||||
StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
|
||||
StatusOr<const HloInstructionProto*> LookUpInstruction(XlaOp op) const;
|
||||
StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
|
||||
int64 handle) const;
|
||||
StatusOr<HloInstructionProto*> LookUpMutableInstruction(XlaOp op);
|
||||
StatusOr<HloInstructionProto*> LookUpMutableInstructionByHandle(int64 handle);
|
||||
|
||||
// Internal helper method that does the building for an arbitrary unary op.
|
||||
XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
|
||||
@ -707,6 +748,8 @@ class XlaBuilder {
|
||||
|
||||
XlaBuilder* parent_builder_{nullptr};
|
||||
|
||||
FrontendAttributes frontend_attributes_;
|
||||
|
||||
friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
|
||||
const Shape& shape, const string& name,
|
||||
const std::vector<bool>& replicated_at_leaf_buffers);
|
||||
@ -1034,6 +1077,27 @@ class XlaScopedShardingAssignment {
|
||||
absl::optional<OpSharding> prev_sharding_;
|
||||
};
|
||||
|
||||
// RAII-style object: save the current builder's frontend attributes, and merge
|
||||
// them with the new ones on construction.
|
||||
// Restore the original attributes on destruction.
|
||||
class XlaScopedFrontendAttributesAssignment {
|
||||
public:
|
||||
XlaScopedFrontendAttributesAssignment(xla::XlaBuilder* builder,
|
||||
FrontendAttributes attributes)
|
||||
: builder_(builder) {
|
||||
saved_ = builder_->SwapFrontendAttributes(attributes);
|
||||
}
|
||||
|
||||
~XlaScopedFrontendAttributesAssignment() {
|
||||
builder_->SetFrontendAttributes(saved_);
|
||||
}
|
||||
|
||||
private:
|
||||
xla::XlaBuilder* const builder_;
|
||||
FrontendAttributes saved_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedFrontendAttributesAssignment);
|
||||
};
|
||||
// Free functions for building XlaOps. The intention is that these will
|
||||
// become the public API for building XlaOps rather than calling methods on
|
||||
// XlaBuilder directly.
|
||||
|
@ -978,5 +978,151 @@ TEST_F(XlaBuilderTest, CheckInputOutputAlias) {
|
||||
EXPECT_EQ(*alias_p1, ShapeIndex({0}));
|
||||
}
|
||||
|
||||
void ExpectAttributesMatch(const FrontendAttributes& attr,
|
||||
const FrontendAttributes& ref) {
|
||||
EXPECT_EQ(ref.map_size(), attr.map_size());
|
||||
for (auto reference : ref.map()) {
|
||||
auto other = attr.map().find(reference.first);
|
||||
EXPECT_NE(other, attr.map().end());
|
||||
EXPECT_EQ(other->second, reference.second);
|
||||
}
|
||||
}
|
||||
|
||||
void ExpectInstructionsAttributesMatch(
|
||||
const HloModule& module, const std::vector<FrontendAttributes>& expected) {
|
||||
ASSERT_EQ(module.computation_count(), 1);
|
||||
auto expected_it = expected.begin();
|
||||
for (auto inst : module.entry_computation()->instructions()) {
|
||||
ASSERT_NE(expected_it, expected.end());
|
||||
ExpectAttributesMatch(inst->frontend_attributes(), *expected_it);
|
||||
expected_it++;
|
||||
}
|
||||
EXPECT_EQ(expected_it, expected.end());
|
||||
}
|
||||
|
||||
TEST_F(XlaBuilderTest, SimpleSetFrontendAttributes) {
|
||||
XlaBuilder b(TestName());
|
||||
FrontendAttributes attributes;
|
||||
|
||||
ConstantR0(&b, 0); // No attribute set
|
||||
|
||||
(*attributes.mutable_map())["attr_a"] = "a";
|
||||
b.SetFrontendAttributes(attributes);
|
||||
ConstantR0(&b, 0); // One attribute: { "attr_a": "a" }
|
||||
|
||||
b.ClearFrontendAttributes();
|
||||
ConstantR0(&b, 0); // No attribute set
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
|
||||
|
||||
std::vector<FrontendAttributes> expected{FrontendAttributes(), attributes,
|
||||
FrontendAttributes()};
|
||||
ExpectInstructionsAttributesMatch(*module, expected);
|
||||
}
|
||||
|
||||
TEST_F(XlaBuilderTest, ComplexSetFrontendAttributes) {
|
||||
XlaBuilder b(TestName());
|
||||
|
||||
ConstantR0(&b, 0); // No attribute set.
|
||||
std::vector<FrontendAttributes> expected{FrontendAttributes()};
|
||||
|
||||
{
|
||||
FrontendAttributes attributes;
|
||||
(*attributes.mutable_map())["attr_a"] = "a";
|
||||
b.SetFrontendAttributes(attributes);
|
||||
ConstantR0(&b, 0); // One attribute: { "attr_a": "a" }
|
||||
expected.push_back(attributes);
|
||||
}
|
||||
|
||||
{
|
||||
FrontendAttributes attributes;
|
||||
(*attributes.mutable_map())["attr_b"] = "b";
|
||||
b.SetFrontendAttributes(attributes);
|
||||
ConstantR0(&b, 0); // One attribute: { "attr_b": "b" }
|
||||
expected.push_back(attributes);
|
||||
}
|
||||
|
||||
{
|
||||
FrontendAttributes attributes;
|
||||
(*attributes.mutable_map())["attr_b"] = "b";
|
||||
(*attributes.mutable_map())["attr_c"] = "c";
|
||||
b.SetFrontendAttributes(attributes);
|
||||
ConstantR0(&b, 0); // Two attributes: { "attr_b": "b", "attr_c": "c" }
|
||||
expected.push_back(attributes);
|
||||
}
|
||||
|
||||
b.ClearFrontendAttributes();
|
||||
ConstantR0(&b, 0); // No attribute set
|
||||
expected.push_back(FrontendAttributes());
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
|
||||
ExpectInstructionsAttributesMatch(*module, expected);
|
||||
}
|
||||
|
||||
TEST_F(XlaBuilderTest, AddFrontendAttribute) {
|
||||
XlaBuilder b(TestName());
|
||||
|
||||
ConstantR0(&b, 0);
|
||||
std::vector<FrontendAttributes> expected{FrontendAttributes()};
|
||||
|
||||
// One attribute: { "attr_a": "a" }
|
||||
{
|
||||
FrontendAttributes attributes;
|
||||
(*attributes.mutable_map())["attr_a"] = "a";
|
||||
b.SetFrontendAttributes(attributes);
|
||||
ConstantR0(&b, 0);
|
||||
expected.push_back(attributes);
|
||||
}
|
||||
|
||||
// Two attributes: {"attra": "a", "attr_c": "c"}
|
||||
{
|
||||
auto op = ConstantR0(&b, 0);
|
||||
EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_c", "c"));
|
||||
|
||||
FrontendAttributes attributes;
|
||||
(*attributes.mutable_map())["attr_a"] = "a";
|
||||
(*attributes.mutable_map())["attr_c"] = "c";
|
||||
expected.push_back(attributes);
|
||||
}
|
||||
|
||||
// Override value of existing "attr_a"
|
||||
// One attribute: { "attr_a", "a2"}
|
||||
{
|
||||
auto op = ConstantR0(&b, 0);
|
||||
EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_a", "a2"));
|
||||
FrontendAttributes attributes;
|
||||
(*attributes.mutable_map())["attr_a"] = "a2";
|
||||
expected.push_back(attributes);
|
||||
}
|
||||
|
||||
// Check "attr_a" is back to its original value
|
||||
// One attribute: { "attr_a", "a"}
|
||||
{
|
||||
auto op = ConstantR0(&b, 0);
|
||||
(void)op;
|
||||
FrontendAttributes attributes;
|
||||
(*attributes.mutable_map())["attr_a"] = "a";
|
||||
expected.push_back(attributes);
|
||||
}
|
||||
|
||||
b.ClearFrontendAttributes();
|
||||
ConstantR0(&b, 0); // No attribute set
|
||||
expected.push_back(FrontendAttributes());
|
||||
|
||||
// One attribute: { "attr_d", "d"}
|
||||
{
|
||||
auto op = ConstantR0(&b, 0);
|
||||
EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_d", "d"));
|
||||
FrontendAttributes attributes;
|
||||
(*attributes.mutable_map())["attr_d"] = "d";
|
||||
expected.push_back(attributes);
|
||||
}
|
||||
|
||||
ConstantR0(&b, 0); // No attribute set
|
||||
expected.push_back(FrontendAttributes());
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
|
||||
ExpectInstructionsAttributesMatch(*module, expected);
|
||||
}
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
|
||||
option cc_enable_arenas = true;
|
||||
|
||||
// Serialization of HloInstruction.
|
||||
// Next ID: 68
|
||||
// Next ID: 69
|
||||
message HloInstructionProto {
|
||||
reserved 10;
|
||||
reserved "parameter_name";
|
||||
@ -234,6 +234,9 @@ message HloInstructionProto {
|
||||
// Specifies if the gather/scatter indices are guaranteed to be sorted by the
|
||||
// caller.
|
||||
bool indices_are_sorted = 67;
|
||||
|
||||
// Frontend attributes to pass to the XLA backend.
|
||||
xla.FrontendAttributes frontend_attributes = 68;
|
||||
}
|
||||
|
||||
// Serialization of HloComputation.
|
||||
|
@ -837,6 +837,10 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
|
||||
if (new_instruction->metadata().op_name().empty()) {
|
||||
new_instruction->set_metadata(old_instruction->metadata());
|
||||
}
|
||||
if (new_instruction->frontend_attributes().map().empty()) {
|
||||
new_instruction->set_frontend_attributes(
|
||||
old_instruction->frontend_attributes());
|
||||
}
|
||||
|
||||
// Like the metadata above, if the user didn't specify any sharding
|
||||
// information on the new instruction we should copy the old sharding
|
||||
|
@ -674,6 +674,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
instruction->set_sharding(sharding);
|
||||
}
|
||||
|
||||
if (proto.has_frontend_attributes()) {
|
||||
instruction->set_frontend_attributes(proto.frontend_attributes());
|
||||
}
|
||||
|
||||
return std::move(instruction);
|
||||
}
|
||||
|
||||
@ -1194,6 +1198,7 @@ HloInstruction::CreateBroadcastSequence(
|
||||
if (operand->has_sharding()) {
|
||||
broadcast->set_sharding(operand->sharding());
|
||||
}
|
||||
broadcast->set_frontend_attributes(operand->frontend_attributes());
|
||||
return broadcast;
|
||||
}
|
||||
// Do explicit broadcast for degenerate broadcast.
|
||||
@ -1219,6 +1224,7 @@ HloInstruction::CreateBroadcastSequence(
|
||||
if (operand->has_sharding()) {
|
||||
reshaped_operand->set_sharding(operand->sharding());
|
||||
}
|
||||
reshaped_operand->set_frontend_attributes(operand->frontend_attributes());
|
||||
// Broadcast 'reshape' up to the larger size.
|
||||
auto broadcast = HloInstruction::CreateBroadcast(
|
||||
broadcast_shape, reshaped_operand, broadcast_dimensions);
|
||||
@ -1226,6 +1232,7 @@ HloInstruction::CreateBroadcastSequence(
|
||||
if (operand->has_sharding()) {
|
||||
broadcast->set_sharding(operand->sharding());
|
||||
}
|
||||
broadcast->set_frontend_attributes(operand->frontend_attributes());
|
||||
return broadcast;
|
||||
}
|
||||
|
||||
@ -1296,6 +1303,7 @@ void HloInstruction::SetupDerivedInstruction(
|
||||
derived_instruction->clear_sharding();
|
||||
}
|
||||
derived_instruction->set_metadata(metadata_);
|
||||
derived_instruction->set_frontend_attributes(frontend_attributes_);
|
||||
}
|
||||
|
||||
bool HloInstruction::HasSideEffectNoRecurse() const {
|
||||
@ -2483,6 +2491,10 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
|
||||
if (has_sharding()) {
|
||||
extra.push_back(StrCat("sharding=", sharding().ToString()));
|
||||
}
|
||||
if (!frontend_attributes_.map().empty()) {
|
||||
extra.push_back(StrCat("frontend_attributes=",
|
||||
FrontendAttributesToString(frontend_attributes_)));
|
||||
}
|
||||
if (!outer_dimension_partitions_.empty()) {
|
||||
extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}",
|
||||
StrJoin(outer_dimension_partitions_, ",")));
|
||||
@ -2543,6 +2555,8 @@ HloInstructionProto HloInstruction::ToProto() const {
|
||||
}
|
||||
}
|
||||
|
||||
*proto.mutable_frontend_attributes() = frontend_attributes_;
|
||||
|
||||
return proto;
|
||||
}
|
||||
|
||||
@ -3197,6 +3211,15 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
|
||||
return InvalidArgument("Unknown fusion kind: %s", kind_name);
|
||||
}
|
||||
|
||||
string FrontendAttributesToString(
|
||||
const FrontendAttributes& frontend_attributes) {
|
||||
std::vector<std::pair<string, string>> sorted_attributes(
|
||||
frontend_attributes.map().begin(), frontend_attributes.map().end());
|
||||
absl::c_sort(sorted_attributes);
|
||||
return absl::StrFormat(
|
||||
"{%s}", absl::StrJoin(sorted_attributes, ",", absl::PairFormatter("=")));
|
||||
}
|
||||
|
||||
string PaddingConfigToString(const PaddingConfig& padding) {
|
||||
bool has_interior_padding =
|
||||
absl::c_any_of(padding.dimensions(),
|
||||
|
@ -1385,6 +1385,14 @@ class HloInstruction {
|
||||
}
|
||||
Status set_backend_config(const tensorflow::protobuf::Message& proto);
|
||||
|
||||
void set_frontend_attributes(FrontendAttributes frontend_attributes) {
|
||||
frontend_attributes_ = std::move(frontend_attributes);
|
||||
}
|
||||
|
||||
const FrontendAttributes& frontend_attributes() const {
|
||||
return frontend_attributes_;
|
||||
}
|
||||
|
||||
// Getter/setter for raw JSON-encoded backend config. Prefer the
|
||||
// functions above that deal in proto Messages where possible.
|
||||
const string& raw_backend_config_string() const { return backend_config_; }
|
||||
@ -1879,6 +1887,18 @@ class HloInstruction {
|
||||
// HLO. See the documentation on backend_config().
|
||||
string backend_config_;
|
||||
|
||||
// Attributes passed from the frontend to give hints to the backend about
|
||||
// how to compile this HLO.
|
||||
// HLO -> HLO transforms are expected to preserve these attributes on a
|
||||
// "best effort" basis only.
|
||||
// For example:
|
||||
// x = const(10, frontend_attributes={x}
|
||||
// y = const(10, frontend_attributes={y}
|
||||
// z = add(x,y), frontend_attributes={y}
|
||||
// Could be simplified to:
|
||||
// z' = const(20), frontend_attributes={?}
|
||||
FrontendAttributes frontend_attributes_;
|
||||
|
||||
// This field is assigned to true when backend_config_ is assigned to
|
||||
// a default configuration.
|
||||
bool is_default_config_ = false;
|
||||
@ -1909,6 +1929,8 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
|
||||
// Custom (de)stringification functions for protos that live inside
|
||||
// HloInstruction.
|
||||
string PaddingConfigToString(const PaddingConfig& padding);
|
||||
string FrontendAttributesToString(
|
||||
const FrontendAttributes& frontend_attributes);
|
||||
string OpMetadataToString(const OpMetadata& metadata);
|
||||
string RandomDistributionToString(const RandomDistribution& distribution);
|
||||
string PrecisionToString(const PrecisionConfig::Precision& precision);
|
||||
|
@ -88,6 +88,7 @@ class HloParser {
|
||||
// Stand alone parsing utils for various aggregate data types.
|
||||
StatusOr<Shape> ParseShapeOnly();
|
||||
StatusOr<HloSharding> ParseShardingOnly();
|
||||
StatusOr<FrontendAttributes> ParseFrontendAttributesOnly();
|
||||
StatusOr<std::vector<bool>> ParseParameterReplicationOnly();
|
||||
StatusOr<Window> ParseWindowOnly();
|
||||
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
|
||||
@ -192,6 +193,7 @@ class HloParser {
|
||||
kWindow,
|
||||
kConvolutionDimensionNumbers,
|
||||
kSharding,
|
||||
kFrontendAttributes,
|
||||
kParameterReplication,
|
||||
kInstructionList,
|
||||
kSliceRanges,
|
||||
@ -271,6 +273,7 @@ class HloParser {
|
||||
bool ParsePaddingConfig(PaddingConfig* padding);
|
||||
bool ParseMetadata(OpMetadata* metadata);
|
||||
bool ParseSharding(OpSharding* sharding);
|
||||
bool ParseFrontendAttributes(FrontendAttributes* frontend_attributes);
|
||||
bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
|
||||
bool ParseParameterReplication(ParameterReplication* parameter_replication);
|
||||
bool ParseReplicaGroupsOnly(std::vector<ReplicaGroup>* replica_groups);
|
||||
@ -677,7 +680,10 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
// Add optional attributes.
|
||||
std::unordered_map<string, AttrConfig> attrs;
|
||||
optional<OpSharding> sharding;
|
||||
optional<FrontendAttributes> frontend_attributes;
|
||||
attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
|
||||
attrs["frontend_attributes"] = {
|
||||
/*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes};
|
||||
optional<ParameterReplication> parameter_replication;
|
||||
attrs["parameter_replication"] = {/*required=*/false,
|
||||
AttrTy::kParameterReplication,
|
||||
@ -1845,6 +1851,36 @@ bool HloParser::ParseSharding(OpSharding* sharding) {
|
||||
return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
|
||||
}
|
||||
|
||||
// frontend_attributes ::= '{' attributes '}'
|
||||
// attributes
|
||||
// ::= /*empty*/
|
||||
// ::= attribute '=' value (',' attribute '=' value)*
|
||||
bool HloParser::ParseFrontendAttributes(
|
||||
FrontendAttributes* frontend_attributes) {
|
||||
CHECK(frontend_attributes != nullptr);
|
||||
if (!ParseToken(TokKind::kLbrace,
|
||||
"expected '{' to start frontend attributes")) {
|
||||
return false;
|
||||
}
|
||||
if (lexer_.GetKind() == TokKind::kRbrace) {
|
||||
// empty
|
||||
} else {
|
||||
do {
|
||||
string attribute;
|
||||
if (!ParseAttributeName(&attribute)) {
|
||||
return false;
|
||||
}
|
||||
if (lexer_.GetKind() != TokKind::kIdent) {
|
||||
return false;
|
||||
}
|
||||
(*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal();
|
||||
lexer_.Lex();
|
||||
} while (EatIfPresent(TokKind::kComma));
|
||||
}
|
||||
return ParseToken(TokKind::kRbrace,
|
||||
"expects '}' at the end of frontend attributes");
|
||||
}
|
||||
|
||||
// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape?
|
||||
// ('devices=' ('[' dims ']')* device_list)? '}'
|
||||
// dims ::= int_list device_list ::= int_list
|
||||
@ -2864,6 +2900,15 @@ bool HloParser::ParseAttributeHelper(
|
||||
static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
|
||||
return true;
|
||||
}
|
||||
case AttrTy::kFrontendAttributes: {
|
||||
FrontendAttributes frontend_attributes;
|
||||
if (!ParseFrontendAttributes(&frontend_attributes)) {
|
||||
return false;
|
||||
}
|
||||
static_cast<optional<FrontendAttributes>*>(attr_out_ptr)
|
||||
->emplace(frontend_attributes);
|
||||
return true;
|
||||
}
|
||||
case AttrTy::kParameterReplication: {
|
||||
ParameterReplication parameter_replication;
|
||||
if (!ParseParameterReplication(¶meter_replication)) {
|
||||
@ -4120,6 +4165,19 @@ StatusOr<HloSharding> HloParser::ParseShardingOnly() {
|
||||
return HloSharding::FromProto(op_sharding);
|
||||
}
|
||||
|
||||
StatusOr<FrontendAttributes> HloParser::ParseFrontendAttributesOnly() {
|
||||
lexer_.Lex();
|
||||
FrontendAttributes attributes;
|
||||
if (!ParseFrontendAttributes(&attributes)) {
|
||||
return InvalidArgument("Syntax error:\n%s", GetError());
|
||||
}
|
||||
if (lexer_.GetKind() != TokKind::kEof) {
|
||||
return InvalidArgument(
|
||||
"Syntax error:\nExtra content after frontend attributes");
|
||||
}
|
||||
return attributes;
|
||||
}
|
||||
|
||||
StatusOr<std::vector<bool>> HloParser::ParseParameterReplicationOnly() {
|
||||
lexer_.Lex();
|
||||
ParameterReplication parameter_replication;
|
||||
@ -4268,6 +4326,11 @@ StatusOr<HloSharding> ParseSharding(absl::string_view str) {
|
||||
return parser.ParseShardingOnly();
|
||||
}
|
||||
|
||||
StatusOr<FrontendAttributes> ParseFrontendAttributes(absl::string_view str) {
|
||||
HloParser parser(str);
|
||||
return parser.ParseFrontendAttributesOnly();
|
||||
}
|
||||
|
||||
StatusOr<std::vector<bool>> ParseParameterReplication(absl::string_view str) {
|
||||
HloParser parser(str);
|
||||
return parser.ParseParameterReplicationOnly();
|
||||
|
@ -54,6 +54,12 @@ Status ParseHloString(absl::string_view str, HloModule* module);
|
||||
// "{replicated}".
|
||||
StatusOr<HloSharding> ParseSharding(absl::string_view str);
|
||||
|
||||
// Parses frontend attributes from str. str is supposed to contain the body of
|
||||
// the frontend attributes , i.e. just the rhs of the
|
||||
// "frontend_attributes={...}" attribute string, e.g.,
|
||||
// "{attr_a=a,attr_b=b}".
|
||||
StatusOr<FrontendAttributes> ParseFrontendAttributes(absl::string_view str);
|
||||
|
||||
// Parses parameter replication from str. str is supposed to contain the body of
|
||||
// the parameter replication, i.e. just the rhs of the
|
||||
// "parameter_replication={...}" attribute string, e.g., "{true, false}".
|
||||
|
@ -2358,6 +2358,13 @@ TEST_F(HloParserTest, ParseSharding) {
|
||||
EXPECT_EQ(sharding.ToString(), original);
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, ParseFrontendAttributes) {
|
||||
const string original = "{attr_a=test_a,attr_b=b}";
|
||||
TF_ASSERT_OK_AND_ASSIGN(FrontendAttributes frontend_attributes,
|
||||
ParseFrontendAttributes(original));
|
||||
EXPECT_EQ(FrontendAttributesToString(frontend_attributes), original);
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, ParseWindow) {
|
||||
Window original = window_util::MakeWindow({1, 2, 3});
|
||||
TF_ASSERT_OK_AND_ASSIGN(Window parsed,
|
||||
|
@ -583,6 +583,12 @@ message CholeskyOptions {
|
||||
bool lower = 1;
|
||||
}
|
||||
|
||||
// Generic map of attributes used to pass hints / configuration options from
|
||||
// the Python frontend to the XLA backend.
|
||||
message FrontendAttributes {
|
||||
map<string, string> map = 1;
|
||||
}
|
||||
|
||||
message OpSharding {
|
||||
enum Type {
|
||||
// This sharding is replicated across all devices (implies maximal,
|
||||
|
Loading…
Reference in New Issue
Block a user