diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 1f8df23c18a..f6bf672d6a0 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -203,14 +203,15 @@ cc_library(
     visibility = [":friends"],
     deps = [
         ":common",
+        ":frontend_attributes_util",
         ":host_compute_metadata_proto",
+        ":rearrange_function_argument",
         ":sharding_util",
         ":side_effect_util",
         ":tf2xla_util",
         "//tensorflow/compiler/jit:flags",
         "//tensorflow/compiler/jit:shape_inference",
         "//tensorflow/compiler/jit:xla_cluster_util",
-        "//tensorflow/compiler/tf2xla:rearrange_function_argument",
         "//tensorflow/compiler/tf2xla/lib:util",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
@@ -271,6 +272,21 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "frontend_attributes_util",
+    srcs = ["frontend_attributes_util.cc"],
+    hdrs = ["frontend_attributes_util.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/types:optional",
+    ],
+)
+
 cc_library(
     name = "sharding_util",
     srcs = ["sharding_util.cc"],
@@ -579,6 +595,7 @@ cc_library(
         "functionalize_while.h",
     ],
     deps = [
+        ":frontend_attributes_util",
         ":functionalize_cond",
         ":functionalize_control_flow_util",
         ":tf2xla_util",
diff --git a/tensorflow/compiler/tf2xla/frontend_attributes_util.cc b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc
new file mode 100644
index 00000000000..e0c70b81771
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc
@@ -0,0 +1,41 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+const char kXlaFrontendAttributesAttrName[] = "_XlaFrontendAttributes";
+
+xla::StatusOr<absl::optional<xla::FrontendAttributes>>
+GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs) {
+  const AttrValue* attr = attrs.Find(kXlaFrontendAttributesAttrName);
+  if (attr == nullptr) {
+    return xla::StatusOr<absl::optional<xla::FrontendAttributes>>(
+        absl::nullopt);
+  }
+  xla::FrontendAttributes attributes;
+  if (!attributes.ParseFromString(attr->s())) {
+    return errors::InvalidArgument(
+        "Experimental _XlaFrontendAttributes attribute was not a valid "
+        "encoded xla::FrontendAttributes proto.");
+  }
+  return absl::optional<xla::FrontendAttributes>(attributes);
+}
+
+}  // namespace tensorflow
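The helper above reads the experimental `_XlaFrontendAttributes` node attribute, which carries an `xla::FrontendAttributes` proto in serialized form. A minimal round-trip sketch, assuming a hand-built NodeDef (the function name and wiring here are illustrative, not part of this change):

    #include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
    #include "tensorflow/core/framework/node_def.pb.h"

    void RoundTripExample() {
      xla::FrontendAttributes attrs;
      (*attrs.mutable_map())["attr_a"] = "a";

      // The node attribute holds the proto in serialized (binary) form.
      tensorflow::NodeDef node_def;
      (*node_def.mutable_attr())[tensorflow::kXlaFrontendAttributesAttrName]
          .set_s(attrs.SerializeAsString());

      // Yields absl::nullopt when the attribute is absent, and an
      // InvalidArgument status when the bytes do not parse as the proto.
      auto result = tensorflow::GetFrontendAttributesFromAttrSlice(
          tensorflow::AttrSlice(node_def));
    }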
diff --git a/tensorflow/compiler/tf2xla/frontend_attributes_util.h b/tensorflow/compiler/tf2xla/frontend_attributes_util.h
new file mode 100644
index 00000000000..421f21e71d1
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/frontend_attributes_util.h
@@ -0,0 +1,38 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_
+
+#include <string>
+
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+
+namespace tensorflow {
+
+// Name of the node attribute that carries the frontend attributes.
+extern const char kXlaFrontendAttributesAttrName[];
+// Returns the FrontendAttributes stored in the AttrSlice, if any.
+//
+// Returns an InvalidArgument error if the attribute is present but
+// cannot be parsed.
+xla::StatusOr<absl::optional<xla::FrontendAttributes>>
+GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc
index 87c7ea82998..74790f9ee4d 100644
--- a/tensorflow/compiler/tf2xla/functionalize_while.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_while.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/jit/union_find.h"
+#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
 #include "tensorflow/compiler/tf2xla/functionalize_cond.h"
 #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@@ -494,6 +495,12 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
   builder.Attr("cond", cond_name);
   builder.Attr("body", body_name);
   string outside_compilation;
+  string frontend_attributes;
+  if (GetNodeAttr(frame->loop_cond->def(), kXlaFrontendAttributesAttrName,
+                  &frontend_attributes)
+          .ok()) {
+    builder.Attr(kXlaFrontendAttributesAttrName, frontend_attributes);
+  }
   if (GetNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName,
                   &outside_compilation)
           .ok()) {
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
index c14519c3ade..06423019f23 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include <functional>
 #include <memory>
 
+#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/sharding_util.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
@@ -98,6 +99,20 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel,
   absl::optional<xla::OpSharding> op_sharding =
       sharding_parse_result.ValueOrDie();
 
+  auto frontend_attributes_result =
+      GetFrontendAttributesFromAttrSlice(AttrSlice(op_kernel->def()));
+  OP_REQUIRES_OK(context, frontend_attributes_result.status());
+  absl::optional<xla::FrontendAttributes> attributes =
+      frontend_attributes_result.ValueOrDie();
+
+  xla::FrontendAttributes merged_attributes = b->frontend_attributes();
+  if (attributes.has_value()) {
+    merged_attributes.mutable_map()->insert(attributes.value().map().begin(),
+                                            attributes.value().map().end());
+  }
+  xla::XlaScopedFrontendAttributesAssignment assign_frontend_attributes(
+      b, std::move(merged_attributes));
+
   // If no sharding metadata is found, XLA is free to use whatever device it
   // wants. In practice this usually has the effect of placing things on device
   // 0.
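One detail worth calling out in `Compute` above: protobuf's `Map::insert`, like `std::map::insert`, does not overwrite existing keys, so on a key collision the attribute already set on the builder wins over the one carried by the op. A small sketch of that merge semantics (names are illustrative):

    #include "tensorflow/compiler/xla/xla_data.pb.h"

    void MergeSemanticsExample() {
      xla::FrontendAttributes builder_attrs;  // stands in for b->frontend_attributes()
      (*builder_attrs.mutable_map())["k"] = "builder";

      xla::FrontendAttributes op_attrs;  // stands in for the op's parsed attributes
      (*op_attrs.mutable_map())["k"] = "op";
      (*op_attrs.mutable_map())["extra"] = "e";

      builder_attrs.mutable_map()->insert(op_attrs.map().begin(),
                                          op_attrs.map().end());
      // builder_attrs.map() is now {"k": "builder", "extra": "e"}: the
      // pre-existing "k" entry was kept, only "extra" was inserted.
    }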
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 0a0459db8dd..dccdec22fb9 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -289,6 +289,15 @@ Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num,
   return Status::OK();
 }
 
+Status XlaBuilder::SetInstructionFrontendAttribute(const XlaOp op,
+                                                   std::string attribute,
+                                                   std::string value) {
+  TF_ASSIGN_OR_RETURN(auto instr_proto, LookUpMutableInstruction(op));
+  auto* frontend_attributes = instr_proto->mutable_frontend_attributes();
+  (*frontend_attributes->mutable_map())[attribute] = std::move(value);
+  return Status::OK();
+}
+
 XlaComputation XlaBuilder::BuildAndNoteError() {
   DCHECK(parent_builder_ != nullptr);
   auto build_status = Build();
@@ -2626,6 +2635,7 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
   if (sharding_) {
     *instr.mutable_sharding() = *sharding_;
   }
+  *instr.mutable_frontend_attributes() = frontend_attributes_;
 
   handle_to_index_[handle] = instructions_.size();
   instructions_.push_back(std::move(instr));
@@ -2683,32 +2693,67 @@ void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
   }
 }
 
-StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
-    const XlaOp& op) const {
-  TF_RETURN_IF_ERROR(first_error_);
+namespace {
 
-  if (op.builder_ == nullptr) {
+template <typename InstructionType>
+StatusOr<InstructionType> LookUpInstructionByHandleInternal(
+    const absl::flat_hash_map<int64, int64>& handle_to_index,
+    const std::vector<HloInstructionProto>& instructions, int64 handle) {
+  auto it = handle_to_index.find(handle);
+  if (it == handle_to_index.end()) {
+    return InvalidArgument("No XlaOp with handle %d", handle);
+  }
+  return const_cast<InstructionType>(&instructions.at(it->second));
+}
+
+template <typename InstructionType, typename OpBuilderType,
+          typename BuilderType, typename OpType>
+StatusOr<InstructionType> LookUpInstructionInternal(
+    const absl::flat_hash_map<int64, int64>& handle_to_index,
+    const std::vector<HloInstructionProto>& instructions,
+    OpBuilderType op_builder, BuilderType builder, OpType op_handle) {
+  if (op_builder == nullptr) {
     return InvalidArgument(
         "invalid XlaOp with handle %d; the builder of this op is freed",
-        op.handle());
+        op_handle);
   }
-  if (op.builder_ != this) {
+  if (op_builder != builder) {
     return InvalidArgument(
         "XlaOp with handle %d is built by builder '%s', but is trying to use "
         "it in builder '%s'",
-        op.handle(), op.builder_->name(), this->name());
+        op_handle, op_builder->name(), builder->name());
   }
 
-  return LookUpInstructionByHandle(op.handle());
+  return LookUpInstructionByHandleInternal<InstructionType>(
+      handle_to_index, instructions, op_handle);
+}
+
+}  // namespace
+
+StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
+    const XlaOp op) const {
+  TF_RETURN_IF_ERROR(first_error_);
+  return LookUpInstructionInternal<const HloInstructionProto*>(
+      handle_to_index_, instructions_, op.builder_, this, op.handle());
 }
 
 StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
     int64 handle) const {
-  auto it = handle_to_index_.find(handle);
-  if (it == handle_to_index_.end()) {
-    return InvalidArgument("No XlaOp with handle %d", handle);
-  }
-  return &instructions_[it->second];
+  return LookUpInstructionByHandleInternal<const HloInstructionProto*>(
+      handle_to_index_, instructions_, handle);
+}
+
+StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstruction(
+    const XlaOp op) {
+  TF_RETURN_IF_ERROR(first_error_);
+  return LookUpInstructionInternal<HloInstructionProto*>(
+      handle_to_index_, instructions_, op.builder_, this, op.handle());
+}
+
+StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstructionByHandle(
+    int64 handle) {
+  return LookUpInstructionByHandleInternal<HloInstructionProto*>(
+      handle_to_index_, instructions_, handle);
 }
 
 // Enqueues a "retrieve parameter value" instruction for a parameter that was
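A minimal usage sketch for the new per-instruction setter (the builder name and constant are illustrative, and the usual TF_CHECK_OK macro is assumed to be available). The attribute lands on the looked-up HloInstructionProto only; the builder-level attribute state is untouched, as the header comment below notes:

    #include "tensorflow/compiler/xla/client/xla_builder.h"
    #include "tensorflow/core/lib/core/status.h"

    void SetInstructionAttributeExample() {
      xla::XlaBuilder b("example");
      xla::XlaOp op = xla::ConstantR0<float>(&b, 1.0f);

      // Adds "attr_a" to this one instruction.
      TF_CHECK_OK(b.SetInstructionFrontendAttribute(op, "attr_a", "a"));
      // Setting the same key again overwrites the value.
      TF_CHECK_OK(b.SetInstructionFrontendAttribute(op, "attr_a", "a2"));
    }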
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 3279a8bbb64..5c28e8b5150 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -147,8 +147,8 @@ class XlaBuilder {
   // Sets OpMetadata that will be added to all instructions until cleared.
   //
   // OpMetadata is often applied to a series of XLA HLO instructions. As a
-  // result, OpMetadata is set on the Computation Builder. All subsequent
-  // instructions generated via this Computation Builder will have the same
+  // result, OpMetadata is set on the computation builder. All subsequent
+  // instructions generated via this computation builder will have the same
   // OpMetadata attached until a call to ClearOpMetadata.
   void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); }
 
@@ -158,6 +158,35 @@ class XlaBuilder {
   // Sets an OpSharding that will be attached to all instructions until cleared.
   void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
 
+  // Sets the FrontendAttributes that will be added to all instructions until
+  // cleared.
+  //
+  // FrontendAttributes are often applied to a series of XLA HLO instructions.
+  // As a result they are set on the computation builder, and all instructions
+  // generated via the computation builder will have the same frontend
+  // attributes attached to them.
+  void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) {
+    frontend_attributes_ = frontend_attributes;
+  }
+
+  // Swaps the passed FrontendAttributes with the ones currently set.
+  //
+  // Returns the old attributes.
+  FrontendAttributes SwapFrontendAttributes(
+      const FrontendAttributes& frontend_attributes) {
+    FrontendAttributes old_attributes = std::move(frontend_attributes_);
+    frontend_attributes_ = frontend_attributes;
+    return old_attributes;
+  }
+
+  // Returns the FrontendAttributes that will be attached to all instructions.
+  const FrontendAttributes& frontend_attributes() const {
+    return frontend_attributes_;
+  }
+
+  // Clears all the frontend attributes.
+  void ClearFrontendAttributes() { frontend_attributes_.Clear(); }
+
   // Clears the sharding. Ops will be sharded according to the default placement
   // policy.
   void ClearSharding() { sharding_ = absl::nullopt; }
@@ -314,6 +343,16 @@ class XlaBuilder {
     ShapeIndex param_index;
   };
 
+  // Looks up the HloInstruction and sets the frontend attribute "attribute" to
+  // "value".
+  //
+  // If the attribute already exists, its value is updated.
+  //
+  // Note: the attribute is only added to the HloInstruction, not to the
+  // builder.
+  Status SetInstructionFrontendAttribute(XlaOp op, string attribute,
+                                         string value);
+
  private:
   // Build helper which takes the id of the root operation.
   StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions);
@@ -595,9 +634,11 @@ class XlaBuilder {
   void AddCalledComputation(const XlaComputation& computation,
                             HloInstructionProto* instr);
 
-  StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
+  StatusOr<const HloInstructionProto*> LookUpInstruction(XlaOp op) const;
   StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
       int64 handle) const;
+  StatusOr<HloInstructionProto*> LookUpMutableInstruction(XlaOp op);
+  StatusOr<HloInstructionProto*> LookUpMutableInstructionByHandle(int64 handle);
 
   // Internal helper method that does the building for an arbitrary unary op.
   XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
@@ -707,6 +748,8 @@ class XlaBuilder {
 
   XlaBuilder* parent_builder_{nullptr};
 
+  FrontendAttributes frontend_attributes_;
+
   friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
                          const Shape& shape, const string& name,
                          const std::vector<bool>& replicated_at_leaf_buffers);
@@ -1034,6 +1077,27 @@ class XlaScopedShardingAssignment {
   absl::optional<OpSharding> prev_sharding_;
 };
 
+// RAII-style object: saves the current builder's frontend attributes on
+// construction, replaces them with the given (typically pre-merged)
+// attributes, and restores the saved attributes on destruction.
+class XlaScopedFrontendAttributesAssignment {
+ public:
+  XlaScopedFrontendAttributesAssignment(xla::XlaBuilder* builder,
+                                        FrontendAttributes attributes)
+      : builder_(builder) {
+    saved_ = builder_->SwapFrontendAttributes(attributes);
+  }
+
+  ~XlaScopedFrontendAttributesAssignment() {
+    builder_->SetFrontendAttributes(saved_);
+  }
+
+ private:
+  xla::XlaBuilder* const builder_;
+  FrontendAttributes saved_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedFrontendAttributesAssignment);
+};
+
 // Free functions for building XlaOps. The intention is that these will
 // become the public API for building XlaOps rather than calling methods on
 // XlaBuilder directly.
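Taken together with the builder-level setters, the RAII helper gives scoped attributes. A hedged sketch (names are illustrative): instructions built inside the scope pick up the given attributes, and instructions built after it revert to the previous builder state.

    #include "tensorflow/compiler/xla/client/xla_builder.h"

    void ScopedAttributesExample() {
      xla::XlaBuilder b("example");

      xla::FrontendAttributes attrs;
      (*attrs.mutable_map())["attr_a"] = "a";
      {
        xla::XlaScopedFrontendAttributesAssignment scope(&b, attrs);
        xla::ConstantR0<float>(&b, 0.0f);  // built with {"attr_a": "a"}
      }  // previous (empty) attributes restored here
      xla::ConstantR0<float>(&b, 0.0f);    // built with no frontend attributes
    }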
diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc
index 12656a89943..701729b94f3 100644
--- a/tensorflow/compiler/xla/client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_builder_test.cc
@@ -978,5 +978,151 @@ TEST_F(XlaBuilderTest, CheckInputOutputAlias) {
   EXPECT_EQ(*alias_p1, ShapeIndex({0}));
 }
 
+void ExpectAttributesMatch(const FrontendAttributes& attr,
+                           const FrontendAttributes& ref) {
+  EXPECT_EQ(ref.map_size(), attr.map_size());
+  for (auto reference : ref.map()) {
+    auto other = attr.map().find(reference.first);
+    EXPECT_NE(other, attr.map().end());
+    EXPECT_EQ(other->second, reference.second);
+  }
+}
+
+void ExpectInstructionsAttributesMatch(
+    const HloModule& module, const std::vector<FrontendAttributes>& expected) {
+  ASSERT_EQ(module.computation_count(), 1);
+  auto expected_it = expected.begin();
+  for (auto inst : module.entry_computation()->instructions()) {
+    ASSERT_NE(expected_it, expected.end());
+    ExpectAttributesMatch(inst->frontend_attributes(), *expected_it);
+    expected_it++;
+  }
+  EXPECT_EQ(expected_it, expected.end());
+}
+
+TEST_F(XlaBuilderTest, SimpleSetFrontendAttributes) {
+  XlaBuilder b(TestName());
+  FrontendAttributes attributes;
+
+  ConstantR0(&b, 0);  // No attribute set
+
+  (*attributes.mutable_map())["attr_a"] = "a";
+  b.SetFrontendAttributes(attributes);
+  ConstantR0(&b, 0);  // One attribute: { "attr_a": "a" }
+
+  b.ClearFrontendAttributes();
+  ConstantR0(&b, 0);  // No attribute set
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+
+  std::vector<FrontendAttributes> expected{FrontendAttributes(), attributes,
+                                           FrontendAttributes()};
+  ExpectInstructionsAttributesMatch(*module, expected);
+}
+
+TEST_F(XlaBuilderTest, ComplexSetFrontendAttributes) {
+  XlaBuilder b(TestName());
+
+  ConstantR0(&b, 0);  // No attribute set.
+  std::vector<FrontendAttributes> expected{FrontendAttributes()};
+
+  {
+    FrontendAttributes attributes;
+    (*attributes.mutable_map())["attr_a"] = "a";
+    b.SetFrontendAttributes(attributes);
+    ConstantR0(&b, 0);  // One attribute: { "attr_a": "a" }
+    expected.push_back(attributes);
+  }
+
+  {
+    FrontendAttributes attributes;
+    (*attributes.mutable_map())["attr_b"] = "b";
+    b.SetFrontendAttributes(attributes);
+    ConstantR0(&b, 0);  // One attribute: { "attr_b": "b" }
+    expected.push_back(attributes);
+  }
+
+  {
+    FrontendAttributes attributes;
+    (*attributes.mutable_map())["attr_b"] = "b";
+    (*attributes.mutable_map())["attr_c"] = "c";
+    b.SetFrontendAttributes(attributes);
+    ConstantR0(&b, 0);  // Two attributes: { "attr_b": "b", "attr_c": "c" }
+    expected.push_back(attributes);
+  }
+
+  b.ClearFrontendAttributes();
+  ConstantR0(&b, 0);  // No attribute set
+  expected.push_back(FrontendAttributes());
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  ExpectInstructionsAttributesMatch(*module, expected);
+}
+
+TEST_F(XlaBuilderTest, AddFrontendAttribute) {
+  XlaBuilder b(TestName());
+
+  ConstantR0(&b, 0);
+  std::vector<FrontendAttributes> expected{FrontendAttributes()};
+
+  // One attribute: { "attr_a": "a" }
+  {
+    FrontendAttributes attributes;
+    (*attributes.mutable_map())["attr_a"] = "a";
+    b.SetFrontendAttributes(attributes);
+    ConstantR0(&b, 0);
+    expected.push_back(attributes);
+  }
+
+  // Two attributes: { "attr_a": "a", "attr_c": "c" }
+  {
+    auto op = ConstantR0(&b, 0);
+    EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_c", "c"));
+
+    FrontendAttributes attributes;
+    (*attributes.mutable_map())["attr_a"] = "a";
+    (*attributes.mutable_map())["attr_c"] = "c";
+    expected.push_back(attributes);
+  }
+
+  // Overrides the value of the existing "attr_a".
+  // One attribute: { "attr_a": "a2" }
+  {
+    auto op = ConstantR0(&b, 0);
+    EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_a", "a2"));
+    FrontendAttributes attributes;
+    (*attributes.mutable_map())["attr_a"] = "a2";
+    expected.push_back(attributes);
+  }
+
+  // Checks that "attr_a" is back to its original value.
+  // One attribute: { "attr_a": "a" }
+  {
+    auto op = ConstantR0(&b, 0);
+    (void)op;
+    FrontendAttributes attributes;
+    (*attributes.mutable_map())["attr_a"] = "a";
+    expected.push_back(attributes);
+  }
+
+  b.ClearFrontendAttributes();
+  ConstantR0(&b, 0);  // No attribute set
+  expected.push_back(FrontendAttributes());
+
+  // One attribute: { "attr_d": "d" }
+  {
+    auto op = ConstantR0(&b, 0);
+    EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_d", "d"));
+    FrontendAttributes attributes;
+    (*attributes.mutable_map())["attr_d"] = "d";
+    expected.push_back(attributes);
+  }
+
+  ConstantR0(&b, 0);  // No attribute set
+  expected.push_back(FrontendAttributes());
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+  ExpectInstructionsAttributesMatch(*module, expected);
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 4dd6d096750..61e562c7eda 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
 option cc_enable_arenas = true;
 
 // Serialization of HloInstruction.
-// Next ID: 68
+// Next ID: 69
 message HloInstructionProto {
   reserved 10;
   reserved "parameter_name";
@@ -234,6 +234,9 @@ message HloInstructionProto {
   // Specifies if the gather/scatter indices are guaranteed to be sorted by the
   // caller.
   bool indices_are_sorted = 67;
+
+  // Frontend attributes to pass to the XLA backend.
+  xla.FrontendAttributes frontend_attributes = 68;
 }
 
 // Serialization of HloComputation.
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 24928f474f2..cbdada0b46b 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -837,6 +837,10 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
   if (new_instruction->metadata().op_name().empty()) {
     new_instruction->set_metadata(old_instruction->metadata());
   }
+  if (new_instruction->frontend_attributes().map().empty()) {
+    new_instruction->set_frontend_attributes(
+        old_instruction->frontend_attributes());
+  }
 
   // Like the metadata above, if the user didn't specify any sharding
   // information on the new instruction we should copy the old sharding
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 9ca5fac4524..dabd7ab2836 100755
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -674,6 +674,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
     instruction->set_sharding(sharding);
   }
 
+  if (proto.has_frontend_attributes()) {
+    instruction->set_frontend_attributes(proto.frontend_attributes());
+  }
+
   return std::move(instruction);
 }
 
@@ -1194,6 +1198,7 @@ HloInstruction::CreateBroadcastSequence(
     if (operand->has_sharding()) {
       broadcast->set_sharding(operand->sharding());
     }
+    broadcast->set_frontend_attributes(operand->frontend_attributes());
     return broadcast;
   }
   // Do explicit broadcast for degenerate broadcast.
@@ -1219,6 +1224,7 @@ HloInstruction::CreateBroadcastSequence(
   if (operand->has_sharding()) {
     reshaped_operand->set_sharding(operand->sharding());
   }
+  reshaped_operand->set_frontend_attributes(operand->frontend_attributes());
 
   // Broadcast 'reshape' up to the larger size.
   auto broadcast = HloInstruction::CreateBroadcast(
       broadcast_shape, reshaped_operand, broadcast_dimensions);
@@ -1226,6 +1232,7 @@ HloInstruction::CreateBroadcastSequence(
   if (operand->has_sharding()) {
     broadcast->set_sharding(operand->sharding());
   }
+  broadcast->set_frontend_attributes(operand->frontend_attributes());
   return broadcast;
 }
 
@@ -1296,6 +1303,7 @@ void HloInstruction::SetupDerivedInstruction(
     derived_instruction->clear_sharding();
   }
   derived_instruction->set_metadata(metadata_);
+  derived_instruction->set_frontend_attributes(frontend_attributes_);
 }
 
 bool HloInstruction::HasSideEffectNoRecurse() const {
@@ -2483,6 +2491,10 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
   if (has_sharding()) {
     extra.push_back(StrCat("sharding=", sharding().ToString()));
   }
+  if (!frontend_attributes_.map().empty()) {
+    extra.push_back(StrCat("frontend_attributes=",
+                           FrontendAttributesToString(frontend_attributes_)));
+  }
   if (!outer_dimension_partitions_.empty()) {
     extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}",
                                     StrJoin(outer_dimension_partitions_, ",")));
@@ -2543,6 +2555,8 @@ HloInstructionProto HloInstruction::ToProto() const {
     }
   }
 
+  *proto.mutable_frontend_attributes() = frontend_attributes_;
+
   return proto;
 }
 
@@ -3197,6 +3211,15 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
   return InvalidArgument("Unknown fusion kind: %s", kind_name);
 }
 
+string FrontendAttributesToString(
+    const FrontendAttributes& frontend_attributes) {
+  std::vector<std::pair<string, string>> sorted_attributes(
+      frontend_attributes.map().begin(), frontend_attributes.map().end());
+  absl::c_sort(sorted_attributes);
+  return absl::StrFormat(
+      "{%s}", absl::StrJoin(sorted_attributes, ",", absl::PairFormatter("=")));
+}
+
 string PaddingConfigToString(const PaddingConfig& padding) {
   bool has_interior_padding =
       absl::c_any_of(padding.dimensions(),
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index c513a95e8a0..3119b52e377 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -1385,6 +1385,14 @@ class HloInstruction {
   }
   Status set_backend_config(const tensorflow::protobuf::Message& proto);
 
+  void set_frontend_attributes(FrontendAttributes frontend_attributes) {
+    frontend_attributes_ = std::move(frontend_attributes);
+  }
+
+  const FrontendAttributes& frontend_attributes() const {
+    return frontend_attributes_;
+  }
+
   // Getter/setter for raw JSON-encoded backend config. Prefer the
   // functions above that deal in proto Messages where possible.
   const string& raw_backend_config_string() const { return backend_config_; }
@@ -1879,6 +1887,18 @@ class HloInstruction {
   // HLO. See the documentation on backend_config().
   string backend_config_;
 
+  // Attributes passed from the frontend to give hints to the backend about
+  // how to compile this HLO.
+  // HLO -> HLO transforms are expected to preserve these attributes on a
+  // "best effort" basis only.
+  // For example:
+  //   x = const(10), frontend_attributes={x}
+  //   y = const(10), frontend_attributes={y}
+  //   z = add(x, y), frontend_attributes={y}
+  // could be simplified to:
+  //   z' = const(20), frontend_attributes={?}
+  FrontendAttributes frontend_attributes_;
+
   // This field is assigned to true when backend_config_ is assigned to
   // a default configuration.
   bool is_default_config_ = false;
@@ -1909,6 +1929,8 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
 // Custom (de)stringification functions for protos that live inside
 // HloInstruction.
 string PaddingConfigToString(const PaddingConfig& padding);
+string FrontendAttributesToString(
+    const FrontendAttributes& frontend_attributes);
 string OpMetadataToString(const OpMetadata& metadata);
 string RandomDistributionToString(const RandomDistribution& distribution);
 string PrecisionToString(const PrecisionConfig::Precision& precision);
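Given the implementation above, `FrontendAttributesToString` sorts entries by key and joins them with `=` and `,`; the parser round-trip test later in this change relies on exactly that format. A short illustrative example:

    #include <iostream>
    #include "tensorflow/compiler/xla/service/hlo_instruction.h"

    void ToStringExample() {
      xla::FrontendAttributes fa;
      (*fa.mutable_map())["attr_b"] = "b";
      (*fa.mutable_map())["attr_a"] = "a";
      // Entries are emitted sorted by key: prints "{attr_a=a,attr_b=b}".
      std::cout << xla::FrontendAttributesToString(fa) << std::endl;
    }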
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index f0f175488e5..c96bfb15187 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -88,6 +88,7 @@ class HloParser {
   // Stand-alone parsing utils for various aggregate data types.
   StatusOr<Shape> ParseShapeOnly();
   StatusOr<HloSharding> ParseShardingOnly();
+  StatusOr<FrontendAttributes> ParseFrontendAttributesOnly();
   StatusOr<std::vector<bool>> ParseParameterReplicationOnly();
   StatusOr<Window> ParseWindowOnly();
   StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
@@ -192,6 +193,7 @@ class HloParser {
     kWindow,
     kConvolutionDimensionNumbers,
     kSharding,
+    kFrontendAttributes,
     kParameterReplication,
     kInstructionList,
     kSliceRanges,
@@ -271,6 +273,7 @@ class HloParser {
   bool ParsePaddingConfig(PaddingConfig* padding);
   bool ParseMetadata(OpMetadata* metadata);
   bool ParseSharding(OpSharding* sharding);
+  bool ParseFrontendAttributes(FrontendAttributes* frontend_attributes);
   bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
   bool ParseParameterReplication(ParameterReplication* parameter_replication);
   bool ParseReplicaGroupsOnly(std::vector<ReplicaGroup>* replica_groups);
@@ -677,7 +680,10 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
   // Add optional attributes.
   std::unordered_map<string, AttrConfig> attrs;
   optional<OpSharding> sharding;
+  optional<FrontendAttributes> frontend_attributes;
   attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
+  attrs["frontend_attributes"] = {
+      /*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes};
   optional<ParameterReplication> parameter_replication;
   attrs["parameter_replication"] = {/*required=*/false,
                                     AttrTy::kParameterReplication,
@@ -1845,6 +1851,36 @@ bool HloParser::ParseSharding(OpSharding* sharding) {
   return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
 }
 
+// frontend_attributes ::= '{' attributes '}'
+// attributes
+//   ::= /*empty*/
+//   ::= attribute '=' value (',' attribute '=' value)*
+bool HloParser::ParseFrontendAttributes(
+    FrontendAttributes* frontend_attributes) {
+  CHECK(frontend_attributes != nullptr);
+  if (!ParseToken(TokKind::kLbrace,
+                  "expected '{' to start frontend attributes")) {
+    return false;
+  }
+  if (lexer_.GetKind() == TokKind::kRbrace) {
+    // empty
+  } else {
+    do {
+      string attribute;
+      if (!ParseAttributeName(&attribute)) {
+        return false;
+      }
+      if (lexer_.GetKind() != TokKind::kIdent) {
+        return false;
+      }
+      (*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal();
+      lexer_.Lex();
+    } while (EatIfPresent(TokKind::kComma));
+  }
+  return ParseToken(TokKind::kRbrace,
+                    "expects '}' at the end of frontend attributes");
+}
+
 //  ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape?
//          ('devices=' ('[' dims ']')* device_list)? '}'
// dims ::= int_list
// device_list ::= int_list
@@ -2864,6 +2900,15 @@ bool HloParser::ParseAttributeHelper(
       static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
       return true;
     }
+    case AttrTy::kFrontendAttributes: {
+      FrontendAttributes frontend_attributes;
+      if (!ParseFrontendAttributes(&frontend_attributes)) {
+        return false;
+      }
+      static_cast<optional<FrontendAttributes>*>(attr_out_ptr)
+          ->emplace(frontend_attributes);
+      return true;
+    }
     case AttrTy::kParameterReplication: {
       ParameterReplication parameter_replication;
       if (!ParseParameterReplication(&parameter_replication)) {
@@ -4120,6 +4165,19 @@ StatusOr<HloSharding> HloParser::ParseShardingOnly() {
   return HloSharding::FromProto(op_sharding);
 }
 
+StatusOr<FrontendAttributes> HloParser::ParseFrontendAttributesOnly() {
+  lexer_.Lex();
+  FrontendAttributes attributes;
+  if (!ParseFrontendAttributes(&attributes)) {
+    return InvalidArgument("Syntax error:\n%s", GetError());
+  }
+  if (lexer_.GetKind() != TokKind::kEof) {
+    return InvalidArgument(
+        "Syntax error:\nExtra content after frontend attributes");
+  }
+  return attributes;
+}
+
 StatusOr<std::vector<bool>> HloParser::ParseParameterReplicationOnly() {
   lexer_.Lex();
   ParameterReplication parameter_replication;
@@ -4268,6 +4326,11 @@ StatusOr<HloSharding> ParseSharding(absl::string_view str) {
   return parser.ParseShardingOnly();
 }
 
+StatusOr<FrontendAttributes> ParseFrontendAttributes(absl::string_view str) {
+  HloParser parser(str);
+  return parser.ParseFrontendAttributesOnly();
+}
+
 StatusOr<std::vector<bool>> ParseParameterReplication(absl::string_view str) {
   HloParser parser(str);
   return parser.ParseParameterReplicationOnly();
diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h
index e4214c1e6b5..91ce79ec982 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.h
+++ b/tensorflow/compiler/xla/service/hlo_parser.h
@@ -54,6 +54,12 @@ Status ParseHloString(absl::string_view str, HloModule* module);
 // "{replicated}".
 StatusOr<HloSharding> ParseSharding(absl::string_view str);
 
+// Parses frontend attributes from str. str is supposed to contain the body of
+// the frontend attributes, i.e. just the rhs of the
+// "frontend_attributes={...}" attribute string, e.g.,
+// "{attr_a=a,attr_b=b}".
+StatusOr<FrontendAttributes> ParseFrontendAttributes(absl::string_view str);
+
 // Parses parameter replication from str. str is supposed to contain the body
 // of the parameter replication, i.e. just the rhs of the
 // "parameter_replication={...}" attribute string, e.g., "{true, false}".
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index cb0c7c64b52..c913784cd13 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -2358,6 +2358,13 @@ TEST_F(HloParserTest, ParseSharding) {
   EXPECT_EQ(sharding.ToString(), original);
 }
 
+TEST_F(HloParserTest, ParseFrontendAttributes) {
+  const string original = "{attr_a=test_a,attr_b=b}";
+  TF_ASSERT_OK_AND_ASSIGN(FrontendAttributes frontend_attributes,
+                          ParseFrontendAttributes(original));
+  EXPECT_EQ(FrontendAttributesToString(frontend_attributes), original);
+}
+
 TEST_F(HloParserTest, ParseWindow) {
   Window original = window_util::MakeWindow({1, 2, 3});
   TF_ASSERT_OK_AND_ASSIGN(Window parsed,
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 120be3d86c3..f5218ad4d8c 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -583,6 +583,12 @@ message CholeskyOptions {
   bool lower = 1;
 }
 
+// Generic map of attributes used to pass hints / configuration options from
+// the Python frontend to the XLA backend.
+message FrontendAttributes {
+  map<string, string> map = 1;
+}
+
 message OpSharding {
   enum Type {
     // This sharding is replicated across all devices (implies maximal,
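With the parser and printer changes in place, the attribute should also round-trip through full HLO module text. A hedged end-to-end sketch (the module text and expected behavior are illustrative, not a test from this change):

    #include "tensorflow/compiler/xla/service/hlo_parser.h"

    constexpr char kHloText[] = R"(
    HloModule m

    ENTRY main {
      ROOT c = f32[] constant(0), frontend_attributes={attr_a=a}
    }
    )";

    // ParseHloString(kHloText) should yield a module whose root instruction
    // reports frontend_attributes() == {"attr_a": "a"}, and printing the
    // module should reproduce the frontend_attributes={attr_a=a} annotation.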