Merge pull request #31129 from AnthonyBarbier:frontend_attributes
PiperOrigin-RevId: 265793937
This commit is contained in:
commit
47612c19ea
@ -203,14 +203,15 @@ cc_library(
|
|||||||
visibility = [":friends"],
|
visibility = [":friends"],
|
||||||
deps = [
|
deps = [
|
||||||
":common",
|
":common",
|
||||||
|
":frontend_attributes_util",
|
||||||
":host_compute_metadata_proto",
|
":host_compute_metadata_proto",
|
||||||
|
":rearrange_function_argument",
|
||||||
":sharding_util",
|
":sharding_util",
|
||||||
":side_effect_util",
|
":side_effect_util",
|
||||||
":tf2xla_util",
|
":tf2xla_util",
|
||||||
"//tensorflow/compiler/jit:flags",
|
"//tensorflow/compiler/jit:flags",
|
||||||
"//tensorflow/compiler/jit:shape_inference",
|
"//tensorflow/compiler/jit:shape_inference",
|
||||||
"//tensorflow/compiler/jit:xla_cluster_util",
|
"//tensorflow/compiler/jit:xla_cluster_util",
|
||||||
"//tensorflow/compiler/tf2xla:rearrange_function_argument",
|
|
||||||
"//tensorflow/compiler/tf2xla/lib:util",
|
"//tensorflow/compiler/tf2xla/lib:util",
|
||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
@ -271,6 +272,21 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "frontend_attributes_util",
|
||||||
|
srcs = ["frontend_attributes_util.cc"],
|
||||||
|
hdrs = ["frontend_attributes_util.h"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/compiler/xla:statusor",
|
||||||
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "sharding_util",
|
name = "sharding_util",
|
||||||
srcs = ["sharding_util.cc"],
|
srcs = ["sharding_util.cc"],
|
||||||
@ -579,6 +595,7 @@ cc_library(
|
|||||||
"functionalize_while.h",
|
"functionalize_while.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":frontend_attributes_util",
|
||||||
":functionalize_cond",
|
":functionalize_cond",
|
||||||
":functionalize_control_flow_util",
|
":functionalize_control_flow_util",
|
||||||
":tf2xla_util",
|
":tf2xla_util",
|
||||||
|
41
tensorflow/compiler/tf2xla/frontend_attributes_util.cc
Normal file
41
tensorflow/compiler/tf2xla/frontend_attributes_util.cc
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
const char kXlaFrontendAttributesAttrName[] = "_XlaFrontendAttributes";
|
||||||
|
|
||||||
|
xla::StatusOr<absl::optional<xla::FrontendAttributes>>
|
||||||
|
GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs) {
|
||||||
|
const AttrValue* attr = attrs.Find(kXlaFrontendAttributesAttrName);
|
||||||
|
if (attr == nullptr) {
|
||||||
|
return xla::StatusOr<absl::optional<xla::FrontendAttributes>>(
|
||||||
|
absl::nullopt);
|
||||||
|
}
|
||||||
|
xla::FrontendAttributes attributes;
|
||||||
|
if (!attributes.ParseFromString(attr->s())) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Experimental _XlaFrontendAttributes attribute was not a valid encoded "
|
||||||
|
"xla::FrontendAttributes proto.");
|
||||||
|
}
|
||||||
|
return absl::optional<xla::FrontendAttributes>(attributes);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
38
tensorflow/compiler/tf2xla/frontend_attributes_util.h
Normal file
38
tensorflow/compiler/tf2xla/frontend_attributes_util.h
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_
|
||||||
|
#define TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/types/optional.h"
|
||||||
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Frontend Attributes Id.
|
||||||
|
extern const char kXlaFrontendAttributesAttrName[];
|
||||||
|
// Return the FrontendAttributes stored in the AttrSlice if there are some.
|
||||||
|
//
|
||||||
|
// Return an InvalidArgument error if some attributes are present but
|
||||||
|
// cannot be parsed.
|
||||||
|
xla::StatusOr<absl::optional<xla::FrontendAttributes>>
|
||||||
|
GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
#include "tensorflow/compiler/jit/union_find.h"
|
#include "tensorflow/compiler/jit/union_find.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
|
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
|
||||||
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
|
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
||||||
@ -494,6 +495,12 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
|
|||||||
builder.Attr("cond", cond_name);
|
builder.Attr("cond", cond_name);
|
||||||
builder.Attr("body", body_name);
|
builder.Attr("body", body_name);
|
||||||
string outside_compilation;
|
string outside_compilation;
|
||||||
|
string frontend_attributes;
|
||||||
|
if (GetNodeAttr(frame->loop_cond->def(), kXlaFrontendAttributesAttrName,
|
||||||
|
&frontend_attributes)
|
||||||
|
.ok()) {
|
||||||
|
builder.Attr(kXlaFrontendAttributesAttrName, frontend_attributes);
|
||||||
|
}
|
||||||
if (GetNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName,
|
if (GetNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName,
|
||||||
&outside_compilation)
|
&outside_compilation)
|
||||||
.ok()) {
|
.ok()) {
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include <functional>
|
#include <functional>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||||
@ -98,6 +99,20 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel,
|
|||||||
absl::optional<xla::OpSharding> op_sharding =
|
absl::optional<xla::OpSharding> op_sharding =
|
||||||
sharding_parse_result.ValueOrDie();
|
sharding_parse_result.ValueOrDie();
|
||||||
|
|
||||||
|
auto frontend_attributes_result =
|
||||||
|
GetFrontendAttributesFromAttrSlice(AttrSlice(op_kernel->def()));
|
||||||
|
OP_REQUIRES_OK(context, frontend_attributes_result.status());
|
||||||
|
absl::optional<xla::FrontendAttributes> attributes =
|
||||||
|
frontend_attributes_result.ValueOrDie();
|
||||||
|
|
||||||
|
xla::FrontendAttributes merged_attributes = b->frontend_attributes();
|
||||||
|
if (attributes.has_value()) {
|
||||||
|
merged_attributes.mutable_map()->insert(attributes.value().map().begin(),
|
||||||
|
attributes.value().map().end());
|
||||||
|
}
|
||||||
|
xla::XlaScopedFrontendAttributesAssignment assign_frontend_attributes(
|
||||||
|
b, std::move(merged_attributes));
|
||||||
|
|
||||||
// If no sharding metadata is found, XLA is free to use whatever device it
|
// If no sharding metadata is found, XLA is free to use whatever device it
|
||||||
// wants. In practice this usually has the effect of placing things on device
|
// wants. In practice this usually has the effect of placing things on device
|
||||||
// 0.
|
// 0.
|
||||||
|
@ -289,6 +289,15 @@ Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status XlaBuilder::SetInstructionFrontendAttribute(const XlaOp op,
|
||||||
|
std::string attribute,
|
||||||
|
std::string value) {
|
||||||
|
TF_ASSIGN_OR_RETURN(auto instr_proto, LookUpMutableInstruction(op));
|
||||||
|
auto* frontend_attributes = instr_proto->mutable_frontend_attributes();
|
||||||
|
(*frontend_attributes->mutable_map())[attribute] = std::move(value);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
XlaComputation XlaBuilder::BuildAndNoteError() {
|
XlaComputation XlaBuilder::BuildAndNoteError() {
|
||||||
DCHECK(parent_builder_ != nullptr);
|
DCHECK(parent_builder_ != nullptr);
|
||||||
auto build_status = Build();
|
auto build_status = Build();
|
||||||
@ -2626,6 +2635,7 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
|
|||||||
if (sharding_) {
|
if (sharding_) {
|
||||||
*instr.mutable_sharding() = *sharding_;
|
*instr.mutable_sharding() = *sharding_;
|
||||||
}
|
}
|
||||||
|
*instr.mutable_frontend_attributes() = frontend_attributes_;
|
||||||
|
|
||||||
handle_to_index_[handle] = instructions_.size();
|
handle_to_index_[handle] = instructions_.size();
|
||||||
instructions_.push_back(std::move(instr));
|
instructions_.push_back(std::move(instr));
|
||||||
@ -2683,32 +2693,67 @@ void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
|
namespace {
|
||||||
const XlaOp& op) const {
|
|
||||||
TF_RETURN_IF_ERROR(first_error_);
|
|
||||||
|
|
||||||
if (op.builder_ == nullptr) {
|
template <typename InstructionType>
|
||||||
|
StatusOr<InstructionType> LookUpInstructionByHandleInternal(
|
||||||
|
const absl::flat_hash_map<int64, int64>& handle_to_index,
|
||||||
|
const std::vector<HloInstructionProto>& instructions, int64 handle) {
|
||||||
|
auto it = handle_to_index.find(handle);
|
||||||
|
if (it == handle_to_index.end()) {
|
||||||
|
return InvalidArgument("No XlaOp with handle %d", handle);
|
||||||
|
}
|
||||||
|
return const_cast<InstructionType>(&instructions.at(it->second));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename InstructionType, typename OpBuilderType,
|
||||||
|
typename BuilderType, typename OpType>
|
||||||
|
StatusOr<InstructionType> LookUpInstructionInternal(
|
||||||
|
const absl::flat_hash_map<int64, int64>& handle_to_index,
|
||||||
|
const std::vector<HloInstructionProto>& instructions,
|
||||||
|
OpBuilderType op_builder, BuilderType builder, OpType op_handle) {
|
||||||
|
if (op_builder == nullptr) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"invalid XlaOp with handle %d; the builder of this op is freed",
|
"invalid XlaOp with handle %d; the builder of this op is freed",
|
||||||
op.handle());
|
op_handle);
|
||||||
}
|
}
|
||||||
if (op.builder_ != this) {
|
if (op_builder != builder) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"XlaOp with handle %d is built by builder '%s', but is trying to use "
|
"XlaOp with handle %d is built by builder '%s', but is trying to use "
|
||||||
"it in builder '%s'",
|
"it in builder '%s'",
|
||||||
op.handle(), op.builder_->name(), this->name());
|
op_handle, op_builder->name(), builder->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
return LookUpInstructionByHandle(op.handle());
|
return LookUpInstructionByHandleInternal<InstructionType>(
|
||||||
|
handle_to_index, instructions, op_handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
|
||||||
|
const XlaOp op) const {
|
||||||
|
TF_RETURN_IF_ERROR(first_error_);
|
||||||
|
return LookUpInstructionInternal<const HloInstructionProto*>(
|
||||||
|
handle_to_index_, instructions_, op.builder_, this, op.handle());
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
|
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
|
||||||
int64 handle) const {
|
int64 handle) const {
|
||||||
auto it = handle_to_index_.find(handle);
|
return LookUpInstructionByHandleInternal<const HloInstructionProto*>(
|
||||||
if (it == handle_to_index_.end()) {
|
handle_to_index_, instructions_, handle);
|
||||||
return InvalidArgument("No XlaOp with handle %d", handle);
|
|
||||||
}
|
}
|
||||||
return &instructions_[it->second];
|
|
||||||
|
StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstruction(
|
||||||
|
const XlaOp op) {
|
||||||
|
TF_RETURN_IF_ERROR(first_error_);
|
||||||
|
return LookUpInstructionInternal<HloInstructionProto*>(
|
||||||
|
handle_to_index_, instructions_, op.builder_, this, op.handle());
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstructionByHandle(
|
||||||
|
int64 handle) {
|
||||||
|
return LookUpInstructionByHandleInternal<HloInstructionProto*>(
|
||||||
|
handle_to_index_, instructions_, handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enqueues a "retrieve parameter value" instruction for a parameter that was
|
// Enqueues a "retrieve parameter value" instruction for a parameter that was
|
||||||
|
@ -147,8 +147,8 @@ class XlaBuilder {
|
|||||||
// Sets OpMetadata that will be added to all instructions until cleared.
|
// Sets OpMetadata that will be added to all instructions until cleared.
|
||||||
//
|
//
|
||||||
// OpMetadata is often applied to a series of XLA HLO instructions. As a
|
// OpMetadata is often applied to a series of XLA HLO instructions. As a
|
||||||
// result, OpMetadata is set on the Computation Builder. All subsequent
|
// result, OpMetadata is set on the computation builder. All subsequent
|
||||||
// instructions generated via this Computation Builder will have the same
|
// instructions generated via this computation builder will have the same
|
||||||
// OpMetadata attached until a call to ClearOpMetadata.
|
// OpMetadata attached until a call to ClearOpMetadata.
|
||||||
void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); }
|
void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); }
|
||||||
|
|
||||||
@ -158,6 +158,35 @@ class XlaBuilder {
|
|||||||
// Sets an OpSharding that will be attached to all instructions until cleared.
|
// Sets an OpSharding that will be attached to all instructions until cleared.
|
||||||
void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
|
void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
|
||||||
|
|
||||||
|
// Sets the FrontendAttributes that will be added to all instructions until
|
||||||
|
// cleared.
|
||||||
|
//
|
||||||
|
// FrontendAttributes are often applied to a series of XLA HLO instructions.
|
||||||
|
// As a result they are set on the computation builder and all the
|
||||||
|
// instructions generated via the computation builder will have the same
|
||||||
|
// frontend attributes attached to them.
|
||||||
|
void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) {
|
||||||
|
frontend_attributes_ = frontend_attributes;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Swap the passed FrontendAttributes with the ones currently set.
|
||||||
|
//
|
||||||
|
// Return the old attributes.
|
||||||
|
FrontendAttributes SwapFrontendAttributes(
|
||||||
|
const FrontendAttributes& frontend_attributes) {
|
||||||
|
FrontendAttributes old_attributes = std::move(frontend_attributes_);
|
||||||
|
frontend_attributes_ = frontend_attributes;
|
||||||
|
return old_attributes;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the FrontendAttributes that will be attached to all instructions.
|
||||||
|
const FrontendAttributes& frontend_attributes() const {
|
||||||
|
return frontend_attributes_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clears all the frontend attributes.
|
||||||
|
void ClearFrontendAttributes() { frontend_attributes_.Clear(); }
|
||||||
|
|
||||||
// Clears the sharding. Ops will be sharded according to the default placement
|
// Clears the sharding. Ops will be sharded according to the default placement
|
||||||
// policy.
|
// policy.
|
||||||
void ClearSharding() { sharding_ = absl::nullopt; }
|
void ClearSharding() { sharding_ = absl::nullopt; }
|
||||||
@ -314,6 +343,16 @@ class XlaBuilder {
|
|||||||
ShapeIndex param_index;
|
ShapeIndex param_index;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Looks up the HloInstruction and sets the frontend attribute "attribute" to
|
||||||
|
// "value".
|
||||||
|
//
|
||||||
|
// If the attribute already existed then its value is updated.
|
||||||
|
//
|
||||||
|
// Note: the attribute is only added to the HloInstruction, not to the
|
||||||
|
// builder.
|
||||||
|
Status SetInstructionFrontendAttribute(XlaOp op, string attribute,
|
||||||
|
string value);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Build helper which takes the id of the root operation..
|
// Build helper which takes the id of the root operation..
|
||||||
StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions);
|
StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions);
|
||||||
@ -595,9 +634,11 @@ class XlaBuilder {
|
|||||||
void AddCalledComputation(const XlaComputation& computation,
|
void AddCalledComputation(const XlaComputation& computation,
|
||||||
HloInstructionProto* instr);
|
HloInstructionProto* instr);
|
||||||
|
|
||||||
StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
|
StatusOr<const HloInstructionProto*> LookUpInstruction(XlaOp op) const;
|
||||||
StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
|
StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
|
||||||
int64 handle) const;
|
int64 handle) const;
|
||||||
|
StatusOr<HloInstructionProto*> LookUpMutableInstruction(XlaOp op);
|
||||||
|
StatusOr<HloInstructionProto*> LookUpMutableInstructionByHandle(int64 handle);
|
||||||
|
|
||||||
// Internal helper method that does the building for an arbitrary unary op.
|
// Internal helper method that does the building for an arbitrary unary op.
|
||||||
XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
|
XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
|
||||||
@ -707,6 +748,8 @@ class XlaBuilder {
|
|||||||
|
|
||||||
XlaBuilder* parent_builder_{nullptr};
|
XlaBuilder* parent_builder_{nullptr};
|
||||||
|
|
||||||
|
FrontendAttributes frontend_attributes_;
|
||||||
|
|
||||||
friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
|
friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
|
||||||
const Shape& shape, const string& name,
|
const Shape& shape, const string& name,
|
||||||
const std::vector<bool>& replicated_at_leaf_buffers);
|
const std::vector<bool>& replicated_at_leaf_buffers);
|
||||||
@ -1034,6 +1077,27 @@ class XlaScopedShardingAssignment {
|
|||||||
absl::optional<OpSharding> prev_sharding_;
|
absl::optional<OpSharding> prev_sharding_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// RAII-style object: save the current builder's frontend attributes, and merge
|
||||||
|
// them with the new ones on construction.
|
||||||
|
// Restore the original attributes on destruction.
|
||||||
|
class XlaScopedFrontendAttributesAssignment {
|
||||||
|
public:
|
||||||
|
XlaScopedFrontendAttributesAssignment(xla::XlaBuilder* builder,
|
||||||
|
FrontendAttributes attributes)
|
||||||
|
: builder_(builder) {
|
||||||
|
saved_ = builder_->SwapFrontendAttributes(attributes);
|
||||||
|
}
|
||||||
|
|
||||||
|
~XlaScopedFrontendAttributesAssignment() {
|
||||||
|
builder_->SetFrontendAttributes(saved_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
xla::XlaBuilder* const builder_;
|
||||||
|
FrontendAttributes saved_;
|
||||||
|
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedFrontendAttributesAssignment);
|
||||||
|
};
|
||||||
// Free functions for building XlaOps. The intention is that these will
|
// Free functions for building XlaOps. The intention is that these will
|
||||||
// become the public API for building XlaOps rather than calling methods on
|
// become the public API for building XlaOps rather than calling methods on
|
||||||
// XlaBuilder directly.
|
// XlaBuilder directly.
|
||||||
|
@ -978,5 +978,151 @@ TEST_F(XlaBuilderTest, CheckInputOutputAlias) {
|
|||||||
EXPECT_EQ(*alias_p1, ShapeIndex({0}));
|
EXPECT_EQ(*alias_p1, ShapeIndex({0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ExpectAttributesMatch(const FrontendAttributes& attr,
|
||||||
|
const FrontendAttributes& ref) {
|
||||||
|
EXPECT_EQ(ref.map_size(), attr.map_size());
|
||||||
|
for (auto reference : ref.map()) {
|
||||||
|
auto other = attr.map().find(reference.first);
|
||||||
|
EXPECT_NE(other, attr.map().end());
|
||||||
|
EXPECT_EQ(other->second, reference.second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ExpectInstructionsAttributesMatch(
|
||||||
|
const HloModule& module, const std::vector<FrontendAttributes>& expected) {
|
||||||
|
ASSERT_EQ(module.computation_count(), 1);
|
||||||
|
auto expected_it = expected.begin();
|
||||||
|
for (auto inst : module.entry_computation()->instructions()) {
|
||||||
|
ASSERT_NE(expected_it, expected.end());
|
||||||
|
ExpectAttributesMatch(inst->frontend_attributes(), *expected_it);
|
||||||
|
expected_it++;
|
||||||
|
}
|
||||||
|
EXPECT_EQ(expected_it, expected.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(XlaBuilderTest, SimpleSetFrontendAttributes) {
|
||||||
|
XlaBuilder b(TestName());
|
||||||
|
FrontendAttributes attributes;
|
||||||
|
|
||||||
|
ConstantR0(&b, 0); // No attribute set
|
||||||
|
|
||||||
|
(*attributes.mutable_map())["attr_a"] = "a";
|
||||||
|
b.SetFrontendAttributes(attributes);
|
||||||
|
ConstantR0(&b, 0); // One attribute: { "attr_a": "a" }
|
||||||
|
|
||||||
|
b.ClearFrontendAttributes();
|
||||||
|
ConstantR0(&b, 0); // No attribute set
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
|
||||||
|
|
||||||
|
std::vector<FrontendAttributes> expected{FrontendAttributes(), attributes,
|
||||||
|
FrontendAttributes()};
|
||||||
|
ExpectInstructionsAttributesMatch(*module, expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(XlaBuilderTest, ComplexSetFrontendAttributes) {
|
||||||
|
XlaBuilder b(TestName());
|
||||||
|
|
||||||
|
ConstantR0(&b, 0); // No attribute set.
|
||||||
|
std::vector<FrontendAttributes> expected{FrontendAttributes()};
|
||||||
|
|
||||||
|
{
|
||||||
|
FrontendAttributes attributes;
|
||||||
|
(*attributes.mutable_map())["attr_a"] = "a";
|
||||||
|
b.SetFrontendAttributes(attributes);
|
||||||
|
ConstantR0(&b, 0); // One attribute: { "attr_a": "a" }
|
||||||
|
expected.push_back(attributes);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
FrontendAttributes attributes;
|
||||||
|
(*attributes.mutable_map())["attr_b"] = "b";
|
||||||
|
b.SetFrontendAttributes(attributes);
|
||||||
|
ConstantR0(&b, 0); // One attribute: { "attr_b": "b" }
|
||||||
|
expected.push_back(attributes);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
FrontendAttributes attributes;
|
||||||
|
(*attributes.mutable_map())["attr_b"] = "b";
|
||||||
|
(*attributes.mutable_map())["attr_c"] = "c";
|
||||||
|
b.SetFrontendAttributes(attributes);
|
||||||
|
ConstantR0(&b, 0); // Two attributes: { "attr_b": "b", "attr_c": "c" }
|
||||||
|
expected.push_back(attributes);
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ClearFrontendAttributes();
|
||||||
|
ConstantR0(&b, 0); // No attribute set
|
||||||
|
expected.push_back(FrontendAttributes());
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
|
||||||
|
ExpectInstructionsAttributesMatch(*module, expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(XlaBuilderTest, AddFrontendAttribute) {
|
||||||
|
XlaBuilder b(TestName());
|
||||||
|
|
||||||
|
ConstantR0(&b, 0);
|
||||||
|
std::vector<FrontendAttributes> expected{FrontendAttributes()};
|
||||||
|
|
||||||
|
// One attribute: { "attr_a": "a" }
|
||||||
|
{
|
||||||
|
FrontendAttributes attributes;
|
||||||
|
(*attributes.mutable_map())["attr_a"] = "a";
|
||||||
|
b.SetFrontendAttributes(attributes);
|
||||||
|
ConstantR0(&b, 0);
|
||||||
|
expected.push_back(attributes);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Two attributes: {"attra": "a", "attr_c": "c"}
|
||||||
|
{
|
||||||
|
auto op = ConstantR0(&b, 0);
|
||||||
|
EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_c", "c"));
|
||||||
|
|
||||||
|
FrontendAttributes attributes;
|
||||||
|
(*attributes.mutable_map())["attr_a"] = "a";
|
||||||
|
(*attributes.mutable_map())["attr_c"] = "c";
|
||||||
|
expected.push_back(attributes);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override value of existing "attr_a"
|
||||||
|
// One attribute: { "attr_a", "a2"}
|
||||||
|
{
|
||||||
|
auto op = ConstantR0(&b, 0);
|
||||||
|
EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_a", "a2"));
|
||||||
|
FrontendAttributes attributes;
|
||||||
|
(*attributes.mutable_map())["attr_a"] = "a2";
|
||||||
|
expected.push_back(attributes);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check "attr_a" is back to its original value
|
||||||
|
// One attribute: { "attr_a", "a"}
|
||||||
|
{
|
||||||
|
auto op = ConstantR0(&b, 0);
|
||||||
|
(void)op;
|
||||||
|
FrontendAttributes attributes;
|
||||||
|
(*attributes.mutable_map())["attr_a"] = "a";
|
||||||
|
expected.push_back(attributes);
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ClearFrontendAttributes();
|
||||||
|
ConstantR0(&b, 0); // No attribute set
|
||||||
|
expected.push_back(FrontendAttributes());
|
||||||
|
|
||||||
|
// One attribute: { "attr_d", "d"}
|
||||||
|
{
|
||||||
|
auto op = ConstantR0(&b, 0);
|
||||||
|
EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_d", "d"));
|
||||||
|
FrontendAttributes attributes;
|
||||||
|
(*attributes.mutable_map())["attr_d"] = "d";
|
||||||
|
expected.push_back(attributes);
|
||||||
|
}
|
||||||
|
|
||||||
|
ConstantR0(&b, 0); // No attribute set
|
||||||
|
expected.push_back(FrontendAttributes());
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
|
||||||
|
ExpectInstructionsAttributesMatch(*module, expected);
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
|
|||||||
option cc_enable_arenas = true;
|
option cc_enable_arenas = true;
|
||||||
|
|
||||||
// Serialization of HloInstruction.
|
// Serialization of HloInstruction.
|
||||||
// Next ID: 68
|
// Next ID: 69
|
||||||
message HloInstructionProto {
|
message HloInstructionProto {
|
||||||
reserved 10;
|
reserved 10;
|
||||||
reserved "parameter_name";
|
reserved "parameter_name";
|
||||||
@ -234,6 +234,9 @@ message HloInstructionProto {
|
|||||||
// Specifies if the gather/scatter indices are guaranteed to be sorted by the
|
// Specifies if the gather/scatter indices are guaranteed to be sorted by the
|
||||||
// caller.
|
// caller.
|
||||||
bool indices_are_sorted = 67;
|
bool indices_are_sorted = 67;
|
||||||
|
|
||||||
|
// Frontend attributes to pass to the XLA backend.
|
||||||
|
xla.FrontendAttributes frontend_attributes = 68;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serialization of HloComputation.
|
// Serialization of HloComputation.
|
||||||
|
@ -837,6 +837,10 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
|
|||||||
if (new_instruction->metadata().op_name().empty()) {
|
if (new_instruction->metadata().op_name().empty()) {
|
||||||
new_instruction->set_metadata(old_instruction->metadata());
|
new_instruction->set_metadata(old_instruction->metadata());
|
||||||
}
|
}
|
||||||
|
if (new_instruction->frontend_attributes().map().empty()) {
|
||||||
|
new_instruction->set_frontend_attributes(
|
||||||
|
old_instruction->frontend_attributes());
|
||||||
|
}
|
||||||
|
|
||||||
// Like the metadata above, if the user didn't specify any sharding
|
// Like the metadata above, if the user didn't specify any sharding
|
||||||
// information on the new instruction we should copy the old sharding
|
// information on the new instruction we should copy the old sharding
|
||||||
|
@ -674,6 +674,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
|||||||
instruction->set_sharding(sharding);
|
instruction->set_sharding(sharding);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (proto.has_frontend_attributes()) {
|
||||||
|
instruction->set_frontend_attributes(proto.frontend_attributes());
|
||||||
|
}
|
||||||
|
|
||||||
return std::move(instruction);
|
return std::move(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1194,6 +1198,7 @@ HloInstruction::CreateBroadcastSequence(
|
|||||||
if (operand->has_sharding()) {
|
if (operand->has_sharding()) {
|
||||||
broadcast->set_sharding(operand->sharding());
|
broadcast->set_sharding(operand->sharding());
|
||||||
}
|
}
|
||||||
|
broadcast->set_frontend_attributes(operand->frontend_attributes());
|
||||||
return broadcast;
|
return broadcast;
|
||||||
}
|
}
|
||||||
// Do explicit broadcast for degenerate broadcast.
|
// Do explicit broadcast for degenerate broadcast.
|
||||||
@ -1219,6 +1224,7 @@ HloInstruction::CreateBroadcastSequence(
|
|||||||
if (operand->has_sharding()) {
|
if (operand->has_sharding()) {
|
||||||
reshaped_operand->set_sharding(operand->sharding());
|
reshaped_operand->set_sharding(operand->sharding());
|
||||||
}
|
}
|
||||||
|
reshaped_operand->set_frontend_attributes(operand->frontend_attributes());
|
||||||
// Broadcast 'reshape' up to the larger size.
|
// Broadcast 'reshape' up to the larger size.
|
||||||
auto broadcast = HloInstruction::CreateBroadcast(
|
auto broadcast = HloInstruction::CreateBroadcast(
|
||||||
broadcast_shape, reshaped_operand, broadcast_dimensions);
|
broadcast_shape, reshaped_operand, broadcast_dimensions);
|
||||||
@ -1226,6 +1232,7 @@ HloInstruction::CreateBroadcastSequence(
|
|||||||
if (operand->has_sharding()) {
|
if (operand->has_sharding()) {
|
||||||
broadcast->set_sharding(operand->sharding());
|
broadcast->set_sharding(operand->sharding());
|
||||||
}
|
}
|
||||||
|
broadcast->set_frontend_attributes(operand->frontend_attributes());
|
||||||
return broadcast;
|
return broadcast;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1296,6 +1303,7 @@ void HloInstruction::SetupDerivedInstruction(
|
|||||||
derived_instruction->clear_sharding();
|
derived_instruction->clear_sharding();
|
||||||
}
|
}
|
||||||
derived_instruction->set_metadata(metadata_);
|
derived_instruction->set_metadata(metadata_);
|
||||||
|
derived_instruction->set_frontend_attributes(frontend_attributes_);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HloInstruction::HasSideEffectNoRecurse() const {
|
bool HloInstruction::HasSideEffectNoRecurse() const {
|
||||||
@ -2483,6 +2491,10 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
|
|||||||
if (has_sharding()) {
|
if (has_sharding()) {
|
||||||
extra.push_back(StrCat("sharding=", sharding().ToString()));
|
extra.push_back(StrCat("sharding=", sharding().ToString()));
|
||||||
}
|
}
|
||||||
|
if (!frontend_attributes_.map().empty()) {
|
||||||
|
extra.push_back(StrCat("frontend_attributes=",
|
||||||
|
FrontendAttributesToString(frontend_attributes_)));
|
||||||
|
}
|
||||||
if (!outer_dimension_partitions_.empty()) {
|
if (!outer_dimension_partitions_.empty()) {
|
||||||
extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}",
|
extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}",
|
||||||
StrJoin(outer_dimension_partitions_, ",")));
|
StrJoin(outer_dimension_partitions_, ",")));
|
||||||
@ -2543,6 +2555,8 @@ HloInstructionProto HloInstruction::ToProto() const {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
*proto.mutable_frontend_attributes() = frontend_attributes_;
|
||||||
|
|
||||||
return proto;
|
return proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3197,6 +3211,15 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
|
|||||||
return InvalidArgument("Unknown fusion kind: %s", kind_name);
|
return InvalidArgument("Unknown fusion kind: %s", kind_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string FrontendAttributesToString(
|
||||||
|
const FrontendAttributes& frontend_attributes) {
|
||||||
|
std::vector<std::pair<string, string>> sorted_attributes(
|
||||||
|
frontend_attributes.map().begin(), frontend_attributes.map().end());
|
||||||
|
absl::c_sort(sorted_attributes);
|
||||||
|
return absl::StrFormat(
|
||||||
|
"{%s}", absl::StrJoin(sorted_attributes, ",", absl::PairFormatter("=")));
|
||||||
|
}
|
||||||
|
|
||||||
string PaddingConfigToString(const PaddingConfig& padding) {
|
string PaddingConfigToString(const PaddingConfig& padding) {
|
||||||
bool has_interior_padding =
|
bool has_interior_padding =
|
||||||
absl::c_any_of(padding.dimensions(),
|
absl::c_any_of(padding.dimensions(),
|
||||||
|
@ -1385,6 +1385,14 @@ class HloInstruction {
|
|||||||
}
|
}
|
||||||
Status set_backend_config(const tensorflow::protobuf::Message& proto);
|
Status set_backend_config(const tensorflow::protobuf::Message& proto);
|
||||||
|
|
||||||
|
void set_frontend_attributes(FrontendAttributes frontend_attributes) {
|
||||||
|
frontend_attributes_ = std::move(frontend_attributes);
|
||||||
|
}
|
||||||
|
|
||||||
|
const FrontendAttributes& frontend_attributes() const {
|
||||||
|
return frontend_attributes_;
|
||||||
|
}
|
||||||
|
|
||||||
// Getter/setter for raw JSON-encoded backend config. Prefer the
|
// Getter/setter for raw JSON-encoded backend config. Prefer the
|
||||||
// functions above that deal in proto Messages where possible.
|
// functions above that deal in proto Messages where possible.
|
||||||
const string& raw_backend_config_string() const { return backend_config_; }
|
const string& raw_backend_config_string() const { return backend_config_; }
|
||||||
@ -1879,6 +1887,18 @@ class HloInstruction {
|
|||||||
// HLO. See the documentation on backend_config().
|
// HLO. See the documentation on backend_config().
|
||||||
string backend_config_;
|
string backend_config_;
|
||||||
|
|
||||||
|
// Attributes passed from the frontend to give hints to the backend about
|
||||||
|
// how to compile this HLO.
|
||||||
|
// HLO -> HLO transforms are expected to preserve these attributes on a
|
||||||
|
// "best effort" basis only.
|
||||||
|
// For example:
|
||||||
|
// x = const(10, frontend_attributes={x}
|
||||||
|
// y = const(10, frontend_attributes={y}
|
||||||
|
// z = add(x,y), frontend_attributes={y}
|
||||||
|
// Could be simplified to:
|
||||||
|
// z' = const(20), frontend_attributes={?}
|
||||||
|
FrontendAttributes frontend_attributes_;
|
||||||
|
|
||||||
// This field is assigned to true when backend_config_ is assigned to
|
// This field is assigned to true when backend_config_ is assigned to
|
||||||
// a default configuration.
|
// a default configuration.
|
||||||
bool is_default_config_ = false;
|
bool is_default_config_ = false;
|
||||||
@ -1909,6 +1929,8 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
|
|||||||
// Custom (de)stringification functions for protos that live inside
|
// Custom (de)stringification functions for protos that live inside
|
||||||
// HloInstruction.
|
// HloInstruction.
|
||||||
string PaddingConfigToString(const PaddingConfig& padding);
|
string PaddingConfigToString(const PaddingConfig& padding);
|
||||||
|
string FrontendAttributesToString(
|
||||||
|
const FrontendAttributes& frontend_attributes);
|
||||||
string OpMetadataToString(const OpMetadata& metadata);
|
string OpMetadataToString(const OpMetadata& metadata);
|
||||||
string RandomDistributionToString(const RandomDistribution& distribution);
|
string RandomDistributionToString(const RandomDistribution& distribution);
|
||||||
string PrecisionToString(const PrecisionConfig::Precision& precision);
|
string PrecisionToString(const PrecisionConfig::Precision& precision);
|
||||||
|
@ -88,6 +88,7 @@ class HloParser {
|
|||||||
// Stand alone parsing utils for various aggregate data types.
|
// Stand alone parsing utils for various aggregate data types.
|
||||||
StatusOr<Shape> ParseShapeOnly();
|
StatusOr<Shape> ParseShapeOnly();
|
||||||
StatusOr<HloSharding> ParseShardingOnly();
|
StatusOr<HloSharding> ParseShardingOnly();
|
||||||
|
StatusOr<FrontendAttributes> ParseFrontendAttributesOnly();
|
||||||
StatusOr<std::vector<bool>> ParseParameterReplicationOnly();
|
StatusOr<std::vector<bool>> ParseParameterReplicationOnly();
|
||||||
StatusOr<Window> ParseWindowOnly();
|
StatusOr<Window> ParseWindowOnly();
|
||||||
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
|
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
|
||||||
@ -192,6 +193,7 @@ class HloParser {
|
|||||||
kWindow,
|
kWindow,
|
||||||
kConvolutionDimensionNumbers,
|
kConvolutionDimensionNumbers,
|
||||||
kSharding,
|
kSharding,
|
||||||
|
kFrontendAttributes,
|
||||||
kParameterReplication,
|
kParameterReplication,
|
||||||
kInstructionList,
|
kInstructionList,
|
||||||
kSliceRanges,
|
kSliceRanges,
|
||||||
@ -271,6 +273,7 @@ class HloParser {
|
|||||||
bool ParsePaddingConfig(PaddingConfig* padding);
|
bool ParsePaddingConfig(PaddingConfig* padding);
|
||||||
bool ParseMetadata(OpMetadata* metadata);
|
bool ParseMetadata(OpMetadata* metadata);
|
||||||
bool ParseSharding(OpSharding* sharding);
|
bool ParseSharding(OpSharding* sharding);
|
||||||
|
bool ParseFrontendAttributes(FrontendAttributes* frontend_attributes);
|
||||||
bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
|
bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
|
||||||
bool ParseParameterReplication(ParameterReplication* parameter_replication);
|
bool ParseParameterReplication(ParameterReplication* parameter_replication);
|
||||||
bool ParseReplicaGroupsOnly(std::vector<ReplicaGroup>* replica_groups);
|
bool ParseReplicaGroupsOnly(std::vector<ReplicaGroup>* replica_groups);
|
||||||
@ -677,7 +680,10 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
|
|||||||
// Add optional attributes.
|
// Add optional attributes.
|
||||||
std::unordered_map<string, AttrConfig> attrs;
|
std::unordered_map<string, AttrConfig> attrs;
|
||||||
optional<OpSharding> sharding;
|
optional<OpSharding> sharding;
|
||||||
|
optional<FrontendAttributes> frontend_attributes;
|
||||||
attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
|
attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
|
||||||
|
attrs["frontend_attributes"] = {
|
||||||
|
/*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes};
|
||||||
optional<ParameterReplication> parameter_replication;
|
optional<ParameterReplication> parameter_replication;
|
||||||
attrs["parameter_replication"] = {/*required=*/false,
|
attrs["parameter_replication"] = {/*required=*/false,
|
||||||
AttrTy::kParameterReplication,
|
AttrTy::kParameterReplication,
|
||||||
@ -1845,6 +1851,36 @@ bool HloParser::ParseSharding(OpSharding* sharding) {
|
|||||||
return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
|
return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// frontend_attributes ::= '{' attributes '}'
|
||||||
|
// attributes
|
||||||
|
// ::= /*empty*/
|
||||||
|
// ::= attribute '=' value (',' attribute '=' value)*
|
||||||
|
bool HloParser::ParseFrontendAttributes(
|
||||||
|
FrontendAttributes* frontend_attributes) {
|
||||||
|
CHECK(frontend_attributes != nullptr);
|
||||||
|
if (!ParseToken(TokKind::kLbrace,
|
||||||
|
"expected '{' to start frontend attributes")) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (lexer_.GetKind() == TokKind::kRbrace) {
|
||||||
|
// empty
|
||||||
|
} else {
|
||||||
|
do {
|
||||||
|
string attribute;
|
||||||
|
if (!ParseAttributeName(&attribute)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (lexer_.GetKind() != TokKind::kIdent) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
(*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal();
|
||||||
|
lexer_.Lex();
|
||||||
|
} while (EatIfPresent(TokKind::kComma));
|
||||||
|
}
|
||||||
|
return ParseToken(TokKind::kRbrace,
|
||||||
|
"expects '}' at the end of frontend attributes");
|
||||||
|
}
|
||||||
|
|
||||||
// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape?
|
// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape?
|
||||||
// ('devices=' ('[' dims ']')* device_list)? '}'
|
// ('devices=' ('[' dims ']')* device_list)? '}'
|
||||||
// dims ::= int_list device_list ::= int_list
|
// dims ::= int_list device_list ::= int_list
|
||||||
@ -2864,6 +2900,15 @@ bool HloParser::ParseAttributeHelper(
|
|||||||
static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
|
static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
case AttrTy::kFrontendAttributes: {
|
||||||
|
FrontendAttributes frontend_attributes;
|
||||||
|
if (!ParseFrontendAttributes(&frontend_attributes)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
static_cast<optional<FrontendAttributes>*>(attr_out_ptr)
|
||||||
|
->emplace(frontend_attributes);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
case AttrTy::kParameterReplication: {
|
case AttrTy::kParameterReplication: {
|
||||||
ParameterReplication parameter_replication;
|
ParameterReplication parameter_replication;
|
||||||
if (!ParseParameterReplication(¶meter_replication)) {
|
if (!ParseParameterReplication(¶meter_replication)) {
|
||||||
@ -4120,6 +4165,19 @@ StatusOr<HloSharding> HloParser::ParseShardingOnly() {
|
|||||||
return HloSharding::FromProto(op_sharding);
|
return HloSharding::FromProto(op_sharding);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<FrontendAttributes> HloParser::ParseFrontendAttributesOnly() {
|
||||||
|
lexer_.Lex();
|
||||||
|
FrontendAttributes attributes;
|
||||||
|
if (!ParseFrontendAttributes(&attributes)) {
|
||||||
|
return InvalidArgument("Syntax error:\n%s", GetError());
|
||||||
|
}
|
||||||
|
if (lexer_.GetKind() != TokKind::kEof) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"Syntax error:\nExtra content after frontend attributes");
|
||||||
|
}
|
||||||
|
return attributes;
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<std::vector<bool>> HloParser::ParseParameterReplicationOnly() {
|
StatusOr<std::vector<bool>> HloParser::ParseParameterReplicationOnly() {
|
||||||
lexer_.Lex();
|
lexer_.Lex();
|
||||||
ParameterReplication parameter_replication;
|
ParameterReplication parameter_replication;
|
||||||
@ -4268,6 +4326,11 @@ StatusOr<HloSharding> ParseSharding(absl::string_view str) {
|
|||||||
return parser.ParseShardingOnly();
|
return parser.ParseShardingOnly();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<FrontendAttributes> ParseFrontendAttributes(absl::string_view str) {
|
||||||
|
HloParser parser(str);
|
||||||
|
return parser.ParseFrontendAttributesOnly();
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<std::vector<bool>> ParseParameterReplication(absl::string_view str) {
|
StatusOr<std::vector<bool>> ParseParameterReplication(absl::string_view str) {
|
||||||
HloParser parser(str);
|
HloParser parser(str);
|
||||||
return parser.ParseParameterReplicationOnly();
|
return parser.ParseParameterReplicationOnly();
|
||||||
|
@ -54,6 +54,12 @@ Status ParseHloString(absl::string_view str, HloModule* module);
|
|||||||
// "{replicated}".
|
// "{replicated}".
|
||||||
StatusOr<HloSharding> ParseSharding(absl::string_view str);
|
StatusOr<HloSharding> ParseSharding(absl::string_view str);
|
||||||
|
|
||||||
|
// Parses frontend attributes from str. str is supposed to contain the body of
|
||||||
|
// the frontend attributes , i.e. just the rhs of the
|
||||||
|
// "frontend_attributes={...}" attribute string, e.g.,
|
||||||
|
// "{attr_a=a,attr_b=b}".
|
||||||
|
StatusOr<FrontendAttributes> ParseFrontendAttributes(absl::string_view str);
|
||||||
|
|
||||||
// Parses parameter replication from str. str is supposed to contain the body of
|
// Parses parameter replication from str. str is supposed to contain the body of
|
||||||
// the parameter replication, i.e. just the rhs of the
|
// the parameter replication, i.e. just the rhs of the
|
||||||
// "parameter_replication={...}" attribute string, e.g., "{true, false}".
|
// "parameter_replication={...}" attribute string, e.g., "{true, false}".
|
||||||
|
@ -2358,6 +2358,13 @@ TEST_F(HloParserTest, ParseSharding) {
|
|||||||
EXPECT_EQ(sharding.ToString(), original);
|
EXPECT_EQ(sharding.ToString(), original);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(HloParserTest, ParseFrontendAttributes) {
|
||||||
|
const string original = "{attr_a=test_a,attr_b=b}";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(FrontendAttributes frontend_attributes,
|
||||||
|
ParseFrontendAttributes(original));
|
||||||
|
EXPECT_EQ(FrontendAttributesToString(frontend_attributes), original);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(HloParserTest, ParseWindow) {
|
TEST_F(HloParserTest, ParseWindow) {
|
||||||
Window original = window_util::MakeWindow({1, 2, 3});
|
Window original = window_util::MakeWindow({1, 2, 3});
|
||||||
TF_ASSERT_OK_AND_ASSIGN(Window parsed,
|
TF_ASSERT_OK_AND_ASSIGN(Window parsed,
|
||||||
|
@ -583,6 +583,12 @@ message CholeskyOptions {
|
|||||||
bool lower = 1;
|
bool lower = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generic map of attributes used to pass hints / configuration options from
|
||||||
|
// the Python frontend to the XLA backend.
|
||||||
|
message FrontendAttributes {
|
||||||
|
map<string, string> map = 1;
|
||||||
|
}
|
||||||
|
|
||||||
message OpSharding {
|
message OpSharding {
|
||||||
enum Type {
|
enum Type {
|
||||||
// This sharding is replicated across all devices (implies maximal,
|
// This sharding is replicated across all devices (implies maximal,
|
||||||
|
Loading…
Reference in New Issue
Block a user