From 6ae2b94083a331cf72234e9b2263934163e212fd Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Mon, 29 Jul 2019 11:04:19 +0100 Subject: [PATCH 01/14] Introduce the concept of Frontend Attributes. Summary: Frontend Attributes can be set by the user or the frontend and are passed through to the XLA backend as a dictionary of strings where they can be used to modify the way the HLO instructions are executed. XLA Development discussion: https://groups.google.com/d/msg/xla-dev/9TM0-1N_JlM/Q2R8o2RgBwAJ Test Plan: bazel test returned: INFO: Executed 522 out of 522 tests: 522 tests pass. INFO: There were tests whose specified size is too big. Use the --test_verbose_timeout_warnings command line option to see which ones these are. INFO: Build completed successfully, 4027 total actions SUCCESS! --- tensorflow/compiler/tf2xla/BUILD | 13 ++ .../tf2xla/frontend_attributes_util.cc | 43 ++++++ .../tf2xla/frontend_attributes_util.h | 32 ++++ .../compiler/tf2xla/xla_compilation_device.cc | 10 ++ tensorflow/compiler/xla/client/xla_builder.cc | 72 +++++++-- tensorflow/compiler/xla/client/xla_builder.h | 49 ++++++ .../compiler/xla/client/xla_builder_test.cc | 145 ++++++++++++++++++ tensorflow/compiler/xla/service/hlo.proto | 5 +- .../compiler/xla/service/hlo_instruction.cc | 4 + .../compiler/xla/service/hlo_instruction.h | 12 ++ tensorflow/compiler/xla/xla_data.proto | 6 + 11 files changed, 377 insertions(+), 14 deletions(-) create mode 100644 tensorflow/compiler/tf2xla/frontend_attributes_util.cc create mode 100644 tensorflow/compiler/tf2xla/frontend_attributes_util.h diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 9aea4570cc7..1e4f2e23ef3 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -202,6 +202,7 @@ cc_library( visibility = [":friends"], deps = [ ":common", + ":frontend_attributes_util", ":host_compute_metadata_proto", ":sharding_util", ":side_effect_util", @@ -270,6 +271,18 @@ cc_library( ], ) +cc_library( + name = "frontend_attributes_util", + srcs = ["frontend_attributes_util.cc"], + hdrs = ["frontend_attributes_util.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:framework", + ], +) + cc_library( name = "sharding_util", srcs = ["sharding_util.cc"], diff --git a/tensorflow/compiler/tf2xla/frontend_attributes_util.cc b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc new file mode 100644 index 00000000000..96e6187fc63 --- /dev/null +++ b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc @@ -0,0 +1,43 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h" + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +namespace { +const char kFrontendAttributesAttribute[] = "_XlaFrontendAttributes"; +} // namespace + +xla::StatusOr> +GetFrontendAttributesFromNodeDef(const NodeDef& node_def) { + if (!HasNodeAttr(node_def, kFrontendAttributesAttribute)) { + return absl::optional(); + } + string value; + xla::FrontendAttributes attributes; + TF_RETURN_IF_ERROR( + GetNodeAttr(node_def, kFrontendAttributesAttribute, &value)); + if (!attributes.ParseFromString(value)) { + return errors::InvalidArgument( + "Experimental _XlaFrontendAttributes attribute was not a valid encoded " + "xla::FrontendAttributes proto."); + } + return absl::optional(attributes); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/frontend_attributes_util.h b/tensorflow/compiler/tf2xla/frontend_attributes_util.h new file mode 100644 index 00000000000..fc9df12eeec --- /dev/null +++ b/tensorflow/compiler/tf2xla/frontend_attributes_util.h @@ -0,0 +1,32 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_ + +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace tensorflow { + +xla::StatusOr> +GetFrontendAttributesFromNodeDef(const NodeDef& node_def); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index c14519c3ade..86e3f99afdb 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -98,6 +99,15 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, absl::optional op_sharding = sharding_parse_result.ValueOrDie(); + auto frontend_attributes_result = + GetFrontendAttributesFromNodeDef(op_kernel->def()); + OP_REQUIRES_OK(context, frontend_attributes_result.status()); + absl::optional frontend_attributes = + frontend_attributes_result.ValueOrDie(); + + xla::XlaScopedFrontendAttributesAssignment assign_frontend_attributes( + b, frontend_attributes); + // If no sharding metadata is found, XLA is free to use whatever device it // wants. In practice this usually has the effect of placing things on device // 0. diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 318d5f3be35..5e33984d57f 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -289,6 +289,14 @@ Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num, return Status::OK(); } +Status XlaBuilder::AddFrontendAttribute(const XlaOp& op, std::string attribute, + std::string value) { + TF_ASSIGN_OR_RETURN(auto instr_proto, LookUpMutableInstruction(op)); + auto* frontend_attributes = instr_proto->mutable_frontend_attributes(); + (*frontend_attributes->mutable_map())[attribute] = value; + return Status::OK(); +} + XlaComputation XlaBuilder::BuildAndNoteError() { DCHECK(parent_builder_ != nullptr); auto build_status = Build(); @@ -2662,6 +2670,7 @@ StatusOr XlaBuilder::AddInstruction(HloInstructionProto&& instr, if (sharding_) { *instr.mutable_sharding() = *sharding_; } + instr.mutable_frontend_attributes()->CopyFrom(frontend_attributes_); handle_to_index_[handle] = instructions_.size(); instructions_.push_back(std::move(instr)); @@ -2719,32 +2728,69 @@ void XlaBuilder::AddCalledComputation(const XlaComputation& computation, } } -StatusOr XlaBuilder::LookUpInstruction( - const XlaOp& op) const { - TF_RETURN_IF_ERROR(first_error_); +namespace { - if (op.builder_ == nullptr) { +template +StatusOr LookUpInstructionByHandleInternal( + HandleToIndexType& handle_to_index, + InstructionProtoVectorType& instructions, int64 handle) { + auto it = handle_to_index.find(handle); + if (it == handle_to_index.end()) { + return InvalidArgument("No XlaOp with handle %d", handle); + } + return &instructions[it->second]; +} + +template +StatusOr LookUpInstructionInternal( + HandleToIndexType& handle_to_index, + InstructionProtoVectorType& instructions, OpBuilderType op_builder, + BuilderType builder, OpType op_handle) { + if (op_builder == nullptr) { return InvalidArgument( "invalid XlaOp with handle %d; the builder of this op is freed", - op.handle()); + op_handle); } - if (op.builder_ != this) { + if (op_builder != builder) { return InvalidArgument( "XlaOp with handle %d is built by builder '%s', but is trying to use " "it in builder '%s'", - op.handle(), op.builder_->name(), this->name()); + op_handle, op_builder->name(), builder->name()); } - return LookUpInstructionByHandle(op.handle()); + return LookUpInstructionByHandleInternal( + handle_to_index, instructions, op_handle); +} + +} // namespace + +StatusOr XlaBuilder::LookUpInstruction( + const XlaOp& op) const { + TF_RETURN_IF_ERROR(first_error_); + return LookUpInstructionInternal( + handle_to_index_, instructions_, 
op.builder_, this, op.handle()); } StatusOr XlaBuilder::LookUpInstructionByHandle( int64 handle) const { - auto it = handle_to_index_.find(handle); - if (it == handle_to_index_.end()) { - return InvalidArgument("No XlaOp with handle %d", handle); - } - return &instructions_[it->second]; + return LookUpInstructionByHandleInternal( + handle_to_index_, instructions_, handle); +} + +StatusOr XlaBuilder::LookUpMutableInstruction( + const XlaOp& op) { + TF_RETURN_IF_ERROR(first_error_); + return LookUpInstructionInternal( + handle_to_index_, instructions_, op.builder_, this, op.handle()); +} + +StatusOr XlaBuilder::LookUpMutableInstructionByHandle( + int64 handle) { + return LookUpInstructionByHandleInternal( + handle_to_index_, instructions_, handle); } // Enqueues a "retrieve parameter value" instruction for a parameter that was diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 89e8be7de1e..cdb31c6ca1c 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -158,6 +158,16 @@ class XlaBuilder { // Sets an OpSharding that will be attached to all instructions until cleared. void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } + void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) { + frontend_attributes_ = frontend_attributes; + } + + const FrontendAttributes& frontend_attributes() const { + return frontend_attributes_; + } + + void ClearFrontendAttributes() { frontend_attributes_.Clear(); } + // Clears the sharding. Ops will be sharded according to the default placement // policy. void ClearSharding() { sharding_ = absl::nullopt; } @@ -314,6 +324,10 @@ class XlaBuilder { ShapeIndex param_index; }; + // Looks up the HloInstruction and sets the frontend attribute "attribute" to + // "value". + Status AddFrontendAttribute(const XlaOp& op, string attribute, string value); + private: // Build helper which takes the id of the root operation.. StatusOr Build(int64 root_id, bool remove_dynamic_dimensions); @@ -596,6 +610,8 @@ class XlaBuilder { StatusOr LookUpInstruction(const XlaOp& op) const; StatusOr LookUpInstructionByHandle( int64 handle) const; + StatusOr LookUpMutableInstruction(const XlaOp& op); + StatusOr LookUpMutableInstructionByHandle(int64 handle); // Internal helper method that does the building for an arbitrary unary op. XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand); @@ -713,6 +729,8 @@ class XlaBuilder { XlaBuilder* parent_builder_{nullptr}; + FrontendAttributes frontend_attributes_; + friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape, const string& name, const std::vector& replicated_at_leaf_buffers); @@ -1038,6 +1056,37 @@ class XlaScopedShardingAssignment { absl::optional prev_sharding_; }; +// RAII-style object: sets the current frontend attributes in builder on +// construction, and clears it on destruction. 
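A minimal usage sketch of the scoped helper declared below (the int template argument on ConstantR0 is an assumption, since this diff elides template arguments, and the attribute key/value are hypothetical):

    // #include "tensorflow/compiler/xla/client/xla_builder.h"
    xla::XlaBuilder b("example");
    xla::FrontendAttributes attrs;
    (*attrs.mutable_map())["some_key"] = "some_value";  // hypothetical attribute
    {
      // Instructions built while the scope is alive pick up these attributes.
      xla::XlaScopedFrontendAttributesAssignment scope(&b, attrs);
      xla::ConstantR0<int>(&b, 0);  // carries {"some_key": "some_value"}
    }
    xla::ConstantR0<int>(&b, 0);  // after the scope: no frontend attributes
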
+class XlaScopedFrontendAttributesAssignment { + public: + XlaScopedFrontendAttributesAssignment( + xla::XlaBuilder* builder, absl::optional attributes) + : builder_(builder) { + SetFrontendAttributes(attributes); + } + + XlaScopedFrontendAttributesAssignment( + const XlaScopedFrontendAttributesAssignment&) = delete; + XlaScopedFrontendAttributesAssignment& operator=( + const XlaScopedFrontendAttributesAssignment&) = delete; + + ~XlaScopedFrontendAttributesAssignment() { + SetFrontendAttributes(absl::nullopt); + } + + private: + void SetFrontendAttributes( + const absl::optional& attributes) { + if (attributes.has_value()) { + builder_->SetFrontendAttributes(attributes.value()); + } else { + builder_->ClearFrontendAttributes(); + } + } + + xla::XlaBuilder* const builder_; +}; // Free functions for building XlaOps. The intention is that these will // become the public API for building XlaOps rather than calling methods on // XlaBuilder directly. diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 12656a89943..2bc79f5db66 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -978,5 +978,150 @@ TEST_F(XlaBuilderTest, CheckInputOutputAlias) { EXPECT_EQ(*alias_p1, ShapeIndex({0})); } +void CheckAttributesMatch(const FrontendAttributes& attr, + const FrontendAttributes& ref) { + EXPECT_EQ(ref.map_size(), attr.map_size()); + for (auto reference : ref.map()) { + auto other = attr.map().find(reference.first); + EXPECT_NE(other, attr.map().end()); + EXPECT_EQ(other->second, reference.second); + } +} + +void CheckInstructionsAttributesMatch( + HloModule& module, const std::vector& expected) { + ASSERT_EQ(module.computation_count(), 1); + auto expected_it = expected.begin(); + for (auto inst : module.mutable_computation(0)->instructions()) { + ASSERT_NE(expected_it, expected.end()); + CheckAttributesMatch(inst->frontend_attributes(), *expected_it); + expected_it++; + } + EXPECT_EQ(expected_it, expected.end()); +} + +TEST_F(XlaBuilderTest, SimpleSetFrontendAttributes) { + XlaBuilder b(TestName()); + FrontendAttributes attributes; + + ConstantR0(&b, 0); // No attribute set + + (*attributes.mutable_map())["attr_a"] = "a"; + b.SetFrontendAttributes(attributes); + ConstantR0(&b, 0); // One attribute: { "attr_a": "a" } + + b.ClearFrontendAttributes(); + ConstantR0(&b, 0); // No attribute set + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + + std::vector expected{FrontendAttributes(), attributes, + FrontendAttributes()}; + CheckInstructionsAttributesMatch(*module, expected); +} + +TEST_F(XlaBuilderTest, ComplexSetFrontendAttributes) { + XlaBuilder b(TestName()); + + ConstantR0(&b, 0); // No attribute set. 
+ std::vector expected{FrontendAttributes()}; + + { + FrontendAttributes attributes; + (*attributes.mutable_map())["attr_a"] = "a"; + b.SetFrontendAttributes(attributes); + ConstantR0(&b, 0); // One attribute: { "attr_a": "a" } + expected.push_back(attributes); + } + + { + FrontendAttributes attributes; + (*attributes.mutable_map())["attr_b"] = "b"; + b.SetFrontendAttributes(attributes); + ConstantR0(&b, 0); // One attribute: { "attr_b": "b" } + expected.push_back(attributes); + } + + { + FrontendAttributes attributes; + (*attributes.mutable_map())["attr_b"] = "b"; + (*attributes.mutable_map())["attr_c"] = "c"; + b.SetFrontendAttributes(attributes); + ConstantR0(&b, 0); // Two attributes: { "attr_b": "b", "attr_c": "c" } + expected.push_back(attributes); + } + + b.ClearFrontendAttributes(); + ConstantR0(&b, 0); // No attribute set + expected.push_back(FrontendAttributes()); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + CheckInstructionsAttributesMatch(*module, expected); +} + +TEST_F(XlaBuilderTest, AddFrontendAttribute) { + XlaBuilder b(TestName()); + + ConstantR0(&b, 0); + std::vector expected{FrontendAttributes()}; + + // One attribute: { "attr_a": "a" } + { + FrontendAttributes attributes; + (*attributes.mutable_map())["attr_a"] = "a"; + b.SetFrontendAttributes(attributes); + ConstantR0(&b, 0); + expected.push_back(attributes); + } + + // Two attributes: {"attra": "a", "attr_c": "c"} + { + auto op = ConstantR0(&b, 0); + EXPECT_IS_OK(b.AddFrontendAttribute(op, "attr_c", "c")); + + FrontendAttributes attributes; + (*attributes.mutable_map())["attr_a"] = "a"; + (*attributes.mutable_map())["attr_c"] = "c"; + expected.push_back(attributes); + } + + // Override value of existing "attr_a" + // One attribute: { "attr_a", "a2"} + { + auto op = ConstantR0(&b, 0); + EXPECT_IS_OK(b.AddFrontendAttribute(op, "attr_a", "a2")); + FrontendAttributes attributes; + (*attributes.mutable_map())["attr_a"] = "a2"; + expected.push_back(attributes); + } + + // Check "attr_a" is back to its original value + // One attribute: { "attr_a", "a"} + { + auto op = ConstantR0(&b, 0); + FrontendAttributes attributes; + (*attributes.mutable_map())["attr_a"] = "a"; + expected.push_back(attributes); + } + + b.ClearFrontendAttributes(); + ConstantR0(&b, 0); // No attribute set + expected.push_back(FrontendAttributes()); + + // One attribute: { "attr_d", "d"} + { + auto op = ConstantR0(&b, 0); + EXPECT_IS_OK(b.AddFrontendAttribute(op, "attr_d", "d")); + FrontendAttributes attributes; + (*attributes.mutable_map())["attr_d"] = "d"; + expected.push_back(attributes); + } + + ConstantR0(&b, 0); // No attribute set + expected.push_back(FrontendAttributes()); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + CheckInstructionsAttributesMatch(*module, expected); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 331bbcb7836..ba02b9aed6c 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 67 +// Next ID: 68 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -230,6 +230,9 @@ message HloInstructionProto { // The delta value for kRngGetAndUpdateState. int64 delta = 66; + + // Frontend attributes to pass to the XLA backend. 
+ xla.FrontendAttributes frontend_attributes = 67; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index f7d36fca7b7..236ac143a76 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -672,6 +672,10 @@ StatusOr> HloInstruction::CreateFromProto( instruction->set_sharding(sharding); } + if (proto.has_frontend_attributes()) { + instruction->set_frontend_attributes(proto.frontend_attributes()); + } + return std::move(instruction); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 78128a766b0..467dd292108 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1384,6 +1384,14 @@ class HloInstruction { } Status set_backend_config(const tensorflow::protobuf::Message& proto); + void set_frontend_attributes(FrontendAttributes frontend_attributes) { + frontend_attributes_ = std::move(frontend_attributes); + } + + const FrontendAttributes& frontend_attributes() const { + return frontend_attributes_; + } + // Getter/setter for raw JSON-encoded backend config. Prefer the // functions above that deal in proto Messages where possible. const string& raw_backend_config_string() const { return backend_config_; } @@ -1878,6 +1886,10 @@ class HloInstruction { // HLO. See the documentation on backend_config(). string backend_config_; + // Attributes passed from the frontend to give hints to the backend about + // how to compile this HLO. + FrontendAttributes frontend_attributes_; + // This field is assigned to true when backend_config_ is assigned to // a default configuration. bool is_default_config_ = false; diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 1bd6db2662e..d0c9d10c36f 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -579,6 +579,12 @@ message CholeskyOptions { bool lower = 1; } +// Generic map of attributes used to pass hints / configuration options from +// the Python frontend to the XLA backend. 
+message FrontendAttributes { + map map = 1; +} + message OpSharding { enum Type { // This sharding is replicated across all devices (implies maximal, From aba481ec721718c40ee3f6d9cabf06766543b261 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Thu, 1 Aug 2019 14:58:46 +0100 Subject: [PATCH 02/14] Addressed sanjoy's comments --- .../tf2xla/frontend_attributes_util.cc | 13 +++--- .../tf2xla/frontend_attributes_util.h | 4 +- .../compiler/tf2xla/xla_compilation_device.cc | 2 +- tensorflow/compiler/xla/client/xla_builder.cc | 7 +-- tensorflow/compiler/xla/client/xla_builder.h | 43 ++++++++++++++++--- .../compiler/xla/service/hlo_instruction.cc | 13 ++++++ .../compiler/xla/service/hlo_instruction.h | 8 ++++ 7 files changed, 70 insertions(+), 20 deletions(-) diff --git a/tensorflow/compiler/tf2xla/frontend_attributes_util.cc b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc index 96e6187fc63..1dd0d3a4caa 100644 --- a/tensorflow/compiler/tf2xla/frontend_attributes_util.cc +++ b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc @@ -24,15 +24,14 @@ const char kFrontendAttributesAttribute[] = "_XlaFrontendAttributes"; } // namespace xla::StatusOr> -GetFrontendAttributesFromNodeDef(const NodeDef& node_def) { - if (!HasNodeAttr(node_def, kFrontendAttributesAttribute)) { - return absl::optional(); +GetFrontendAttributesFromNodeDef(const AttrSlice& attrs) { + auto attr = attrs.Find(kFrontendAttributesAttribute); + if (attr == nullptr) { + return xla::StatusOr>( + absl::nullopt); } - string value; xla::FrontendAttributes attributes; - TF_RETURN_IF_ERROR( - GetNodeAttr(node_def, kFrontendAttributesAttribute, &value)); - if (!attributes.ParseFromString(value)) { + if (!attributes.ParseFromString(attr->s())) { return errors::InvalidArgument( "Experimental _XlaFrontendAttributes attribute was not a valid encoded " "xla::FrontendAttributes proto."); diff --git a/tensorflow/compiler/tf2xla/frontend_attributes_util.h b/tensorflow/compiler/tf2xla/frontend_attributes_util.h index fc9df12eeec..2beaa2fd760 100644 --- a/tensorflow/compiler/tf2xla/frontend_attributes_util.h +++ b/tensorflow/compiler/tf2xla/frontend_attributes_util.h @@ -20,12 +20,12 @@ limitations under the License. 
#include "absl/types/optional.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" namespace tensorflow { xla::StatusOr> -GetFrontendAttributesFromNodeDef(const NodeDef& node_def); +GetFrontendAttributesFromNodeDef(const AttrSlice& attrs); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index 86e3f99afdb..35a2e63f323 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -100,7 +100,7 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, sharding_parse_result.ValueOrDie(); auto frontend_attributes_result = - GetFrontendAttributesFromNodeDef(op_kernel->def()); + GetFrontendAttributesFromNodeDef(AttrSlice(op_kernel->def())); OP_REQUIRES_OK(context, frontend_attributes_result.status()); absl::optional frontend_attributes = frontend_attributes_result.ValueOrDie(); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 5e33984d57f..b2d375bdf76 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -289,11 +289,12 @@ Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num, return Status::OK(); } -Status XlaBuilder::AddFrontendAttribute(const XlaOp& op, std::string attribute, - std::string value) { +Status XlaBuilder::SetInstructionFrontendAttribute(const XlaOp& op, + std::string attribute, + std::string value) { TF_ASSIGN_OR_RETURN(auto instr_proto, LookUpMutableInstruction(op)); auto* frontend_attributes = instr_proto->mutable_frontend_attributes(); - (*frontend_attributes->mutable_map())[attribute] = value; + (*frontend_attributes->mutable_map())[attribute] = std::move(value); return Status::OK(); } diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index cdb31c6ca1c..8c013da42d3 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -158,14 +158,31 @@ class XlaBuilder { // Sets an OpSharding that will be attached to all instructions until cleared. void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } + // Sets the FrontendAttributes that will be added to all instructions until + // cleared. + // + // FrontendAttributes are often applied to a serie of XLA HLO instructions. + // As a result they are set on the Computation Builder and all the + // instructions generated via the builder will have the same frontend + // attributes attached to them. void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) { frontend_attributes_ = frontend_attributes; } + // Merge the passed FrontendAttributes with the ones already set. + // + // In case of duplicates the new attributes take precedence. + void MergeFrontendAttributes(const FrontendAttributes& frontend_attributes) { + frontend_attributes_.mutable_map()->insert( + frontend_attributes.map().begin(), frontend_attributes.map().end()); + } + + // Returns the FrontendAttributes that will be attached to all instructions. const FrontendAttributes& frontend_attributes() const { return frontend_attributes_; } + // Clears all the frontend attributes. void ClearFrontendAttributes() { frontend_attributes_.Clear(); } // Clears the sharding. 
Ops will be sharded according to the default placement @@ -326,7 +343,13 @@ class XlaBuilder { // Looks up the HloInstruction and sets the frontend attribute "attribute" to // "value". - Status AddFrontendAttribute(const XlaOp& op, string attribute, string value); + // + // If the attribute already existed then its value is updated. + // + // Note: the attribute is only added to the HloInstruction, not to the + // builder. + Status SetInstructionFrontendAttribute(const XlaOp& op, string attribute, + string value); private: // Build helper which takes the id of the root operation.. @@ -610,8 +633,8 @@ class XlaBuilder { StatusOr LookUpInstruction(const XlaOp& op) const; StatusOr LookUpInstructionByHandle( int64 handle) const; - StatusOr LookUpMutableInstruction(const XlaOp& op); - StatusOr LookUpMutableInstructionByHandle(int64 handle); + StatusOr LookUpMutableInstruction(const XlaOp& op); + StatusOr LookUpMutableInstructionByHandle(int64 handle); // Internal helper method that does the building for an arbitrary unary op. XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand); @@ -1056,8 +1079,9 @@ class XlaScopedShardingAssignment { absl::optional prev_sharding_; }; -// RAII-style object: sets the current frontend attributes in builder on -// construction, and clears it on destruction. +// RAII-style object: save the current builder's frontend attributes, and merge +// them with the new ones on construction. +// Restore the original attributes on destruction. class XlaScopedFrontendAttributesAssignment { public: XlaScopedFrontendAttributesAssignment( @@ -1079,13 +1103,18 @@ class XlaScopedFrontendAttributesAssignment { void SetFrontendAttributes( const absl::optional& attributes) { if (attributes.has_value()) { - builder_->SetFrontendAttributes(attributes.value()); + // Save the existing attributes: + saved_ = builder_->frontend_attributes(); + // Merge the existring attributes with the new ones. + builder_->MergeFrontendAttributes(attributes.value()); } else { - builder_->ClearFrontendAttributes(); + builder_->SetFrontendAttributes(saved_); + saved_.Clear(); } } xla::XlaBuilder* const builder_; + FrontendAttributes saved_; }; // Free functions for building XlaOps. The intention is that these will // become the public API for building XlaOps rather than calling methods on diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 236ac143a76..9cd41163c7c 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1196,6 +1196,7 @@ HloInstruction::CreateBroadcastSequence( if (operand->has_sharding()) { broadcast->set_sharding(operand->sharding()); } + broadcast->set_frontend_attributes(operand->frontend_attributes()); return broadcast; } // Do explicit broadcast for degenerate broadcast. @@ -1221,6 +1222,7 @@ HloInstruction::CreateBroadcastSequence( if (operand->has_sharding()) { reshaped_operand->set_sharding(operand->sharding()); } + reshaped_operand->set_frontend_attributes(operand->frontend_attributes()); // Broadcast 'reshape' up to the larger size. 
auto broadcast = HloInstruction::CreateBroadcast( broadcast_shape, reshaped_operand, broadcast_dimensions); @@ -1228,6 +1230,7 @@ HloInstruction::CreateBroadcastSequence( if (operand->has_sharding()) { broadcast->set_sharding(operand->sharding()); } + broadcast->set_frontend_attributes(operand->frontend_attributes()); return broadcast; } @@ -1298,6 +1301,7 @@ void HloInstruction::SetupDerivedInstruction( derived_instruction->clear_sharding(); } derived_instruction->set_metadata(metadata_); + derived_instruction->set_frontend_attributes(frontend_attributes_); } bool HloInstruction::HasSideEffectNoRecurse() const { @@ -2483,6 +2487,12 @@ std::vector HloInstruction::ExtraAttributesToString( if (has_sharding()) { extra.push_back(StrCat("sharding=", sharding().ToString())); } + if (!frontend_attributes_.map().empty()) { + extra.push_back( + absl::StrFormat("frontend_attributes={%s}", + absl::StrJoin(frontend_attributes_.map(), ",", + absl::PairFormatter("=")))); + } if (!outer_dimension_partitions_.empty()) { extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}", StrJoin(outer_dimension_partitions_, ","))); @@ -2542,6 +2552,9 @@ HloInstructionProto HloInstruction::ToProto() const { proto.mutable_outer_dimension_partitions()->Add(idx); } } + if (!frontend_attributes_.map().empty()) { + proto.mutable_frontend_attributes()->CopyFrom(frontend_attributes_); + } return proto; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 467dd292108..cf175024e81 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1888,6 +1888,14 @@ class HloInstruction { // Attributes passed from the frontend to give hints to the backend about // how to compile this HLO. + // HLO -> HLO transforms are expected to preserve these attributes on a + // "best effort" basis only. 
+ // For example: + // x = const(10, frontend_attributes={x} + // y = const(10, frontend_attributes={y} + // z = add(x,y), frontend_attributes={y} + // Could be simplified to: + // z' = const(20), frontend_attributes={?} FrontendAttributes frontend_attributes_; // This field is assigned to true when backend_config_ is assigned to From 4585895c7802909a661e30b842580bb7aa1031b1 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Thu, 1 Aug 2019 15:12:21 +0100 Subject: [PATCH 03/14] Renamed GetFrontendAttributesFromNodeDef -> GetFrontendAttributesFromAttrSlice and added documentation --- tensorflow/compiler/tf2xla/frontend_attributes_util.cc | 2 +- tensorflow/compiler/tf2xla/frontend_attributes_util.h | 6 +++++- tensorflow/compiler/tf2xla/xla_compilation_device.cc | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/tf2xla/frontend_attributes_util.cc b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc index 1dd0d3a4caa..54e16a86883 100644 --- a/tensorflow/compiler/tf2xla/frontend_attributes_util.cc +++ b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc @@ -24,7 +24,7 @@ const char kFrontendAttributesAttribute[] = "_XlaFrontendAttributes"; } // namespace xla::StatusOr> -GetFrontendAttributesFromNodeDef(const AttrSlice& attrs) { +GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs) { auto attr = attrs.Find(kFrontendAttributesAttribute); if (attr == nullptr) { return xla::StatusOr>( diff --git a/tensorflow/compiler/tf2xla/frontend_attributes_util.h b/tensorflow/compiler/tf2xla/frontend_attributes_util.h index 2beaa2fd760..1c2b1d8c1c5 100644 --- a/tensorflow/compiler/tf2xla/frontend_attributes_util.h +++ b/tensorflow/compiler/tf2xla/frontend_attributes_util.h @@ -24,8 +24,12 @@ limitations under the License. namespace tensorflow { +// Return the FrontendAttributes stored in the AttrSlice if there are some. +// +// Return an InvalidArgument error if some attributes are present but +// cannot be parsed. 
xla::StatusOr> -GetFrontendAttributesFromNodeDef(const AttrSlice& attrs); +GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index 35a2e63f323..d7e2934cde8 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -100,7 +100,7 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, sharding_parse_result.ValueOrDie(); auto frontend_attributes_result = - GetFrontendAttributesFromNodeDef(AttrSlice(op_kernel->def())); + GetFrontendAttributesFromAttrSlice(AttrSlice(op_kernel->def())); OP_REQUIRES_OK(context, frontend_attributes_result.status()); absl::optional frontend_attributes = frontend_attributes_result.ValueOrDie(); From 522ddb018c86d57785ff0d39c8565bed69183e0f Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Thu, 1 Aug 2019 17:11:49 +0100 Subject: [PATCH 04/14] Rework #2 - Remove auto - Remove templates / keep parameters const - Fix comments --- .../compiler/tf2xla/frontend_attributes_util.cc | 2 +- tensorflow/compiler/xla/client/xla_builder.cc | 17 ++++++++--------- tensorflow/compiler/xla/client/xla_builder.h | 6 +++--- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/tensorflow/compiler/tf2xla/frontend_attributes_util.cc b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc index 54e16a86883..7c2564ffa99 100644 --- a/tensorflow/compiler/tf2xla/frontend_attributes_util.cc +++ b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc @@ -25,7 +25,7 @@ const char kFrontendAttributesAttribute[] = "_XlaFrontendAttributes"; xla::StatusOr> GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs) { - auto attr = attrs.Find(kFrontendAttributesAttribute); + const AttrValue *attr = attrs.Find(kFrontendAttributesAttribute); if (attr == nullptr) { return xla::StatusOr>( absl::nullopt); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index b2d375bdf76..a167e258298 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -2731,24 +2731,23 @@ void XlaBuilder::AddCalledComputation(const XlaComputation& computation, namespace { -template +template StatusOr LookUpInstructionByHandleInternal( - HandleToIndexType& handle_to_index, - InstructionProtoVectorType& instructions, int64 handle) { + const absl::flat_hash_map& handle_to_index, + const std::vector& instructions, int64 handle) { auto it = handle_to_index.find(handle); if (it == handle_to_index.end()) { return InvalidArgument("No XlaOp with handle %d", handle); } - return &instructions[it->second]; + return const_cast(&instructions.at(it->second)); } -template StatusOr LookUpInstructionInternal( - HandleToIndexType& handle_to_index, - InstructionProtoVectorType& instructions, OpBuilderType op_builder, + const absl::flat_hash_map& handle_to_index, + const std::vector& instructions, OpBuilderType op_builder, BuilderType builder, OpType op_handle) { if (op_builder == nullptr) { return InvalidArgument( diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 8c013da42d3..33070ee4069 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -161,10 +161,10 @@ class XlaBuilder { // Sets the FrontendAttributes that will be added to all instructions until // cleared. 
// - // FrontendAttributes are often applied to a serie of XLA HLO instructions. + // FrontendAttributes are often applied to a series of XLA HLO instructions. // As a result they are set on the Computation Builder and all the - // instructions generated via the builder will have the same frontend - // attributes attached to them. + // instructions generated via the Computation Builder will have the same + // frontend attributes attached to them. void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) { frontend_attributes_ = frontend_attributes; } From bee64e95b8ff8bacee280783e823f5eb4d5c9268 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Thu, 1 Aug 2019 17:39:12 +0100 Subject: [PATCH 05/14] Fixed case: Computation Builder -> computation builder --- tensorflow/compiler/xla/client/xla_builder.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 33070ee4069..b1523d9d50f 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -147,8 +147,8 @@ class XlaBuilder { // Sets OpMetadata that will be added to all instructions until cleared. // // OpMetadata is often applied to a series of XLA HLO instructions. As a - // result, OpMetadata is set on the Computation Builder. All subsequent - // instructions generated via this Computation Builder will have the same + // result, OpMetadata is set on the computation builder. All subsequent + // instructions generated via this computation builder will have the same // OpMetadata attached until a call to ClearOpMetadata. void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); } @@ -162,8 +162,8 @@ class XlaBuilder { // cleared. // // FrontendAttributes are often applied to a series of XLA HLO instructions. - // As a result they are set on the Computation Builder and all the - // instructions generated via the Computation Builder will have the same + // As a result they are set on the computation builder and all the + // instructions generated via the computation builder will have the same // frontend attributes attached to them. 
void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) { frontend_attributes_ = frontend_attributes; From 1e5459afacd7b0578ef970df2bd2f400b90a762c Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Fri, 2 Aug 2019 10:20:35 +0100 Subject: [PATCH 06/14] Moved the attributes merging out of the ScopeAssignment class and changed MergeFrontendAttributes into SwapFrontendAttributes --- .../compiler/tf2xla/xla_compilation_device.cc | 8 +++-- tensorflow/compiler/xla/client/xla_builder.h | 30 ++++++------------- 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index d7e2934cde8..d4c62a2a226 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -102,11 +102,15 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, auto frontend_attributes_result = GetFrontendAttributesFromAttrSlice(AttrSlice(op_kernel->def())); OP_REQUIRES_OK(context, frontend_attributes_result.status()); - absl::optional frontend_attributes = + absl::optional attributes = frontend_attributes_result.ValueOrDie(); + xla::FrontendAttributes merged_attributes = b->frontend_attributes(); + if (attributes.has_value()) { + merged_attributes.mutable_map()->insert(attributes.value().map().begin(), attributes.value().map().end()); + } xla::XlaScopedFrontendAttributesAssignment assign_frontend_attributes( - b, frontend_attributes); + b, std::move(merged_attributes)); // If no sharding metadata is found, XLA is free to use whatever device it // wants. In practice this usually has the effect of placing things on device diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index b1523d9d50f..c3c663873e5 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -169,12 +169,13 @@ class XlaBuilder { frontend_attributes_ = frontend_attributes; } - // Merge the passed FrontendAttributes with the ones already set. + // Swap the passed FrontendAttributes with the ones currently set. // - // In case of duplicates the new attributes take precedence. - void MergeFrontendAttributes(const FrontendAttributes& frontend_attributes) { - frontend_attributes_.mutable_map()->insert( - frontend_attributes.map().begin(), frontend_attributes.map().end()); + // Return the old attributes. + FrontendAttributes SwapFrontendAttributes(const FrontendAttributes& frontend_attributes) { + FrontendAttributes old_attributes = std::move(frontend_attributes_); + frontend_attributes_ = std::move(frontend_attributes); + return old_attributes; } // Returns the FrontendAttributes that will be attached to all instructions. 
@@ -1085,9 +1086,9 @@ class XlaScopedShardingAssignment { class XlaScopedFrontendAttributesAssignment { public: XlaScopedFrontendAttributesAssignment( - xla::XlaBuilder* builder, absl::optional attributes) + xla::XlaBuilder* builder, FrontendAttributes attributes) : builder_(builder) { - SetFrontendAttributes(attributes); + saved_ = builder_->SwapFrontendAttributes(std::move(attributes)); } XlaScopedFrontendAttributesAssignment( @@ -1096,23 +1097,10 @@ class XlaScopedFrontendAttributesAssignment { const XlaScopedFrontendAttributesAssignment&) = delete; ~XlaScopedFrontendAttributesAssignment() { - SetFrontendAttributes(absl::nullopt); + builder_->SetFrontendAttributes(std::move(saved_)); } private: - void SetFrontendAttributes( - const absl::optional& attributes) { - if (attributes.has_value()) { - // Save the existing attributes: - saved_ = builder_->frontend_attributes(); - // Merge the existring attributes with the new ones. - builder_->MergeFrontendAttributes(attributes.value()); - } else { - builder_->SetFrontendAttributes(saved_); - saved_.Clear(); - } - } - xla::XlaBuilder* const builder_; FrontendAttributes saved_; }; From 3d2ad7f6c87e1b2eae0d7616f6dfde6b90955048 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Fri, 2 Aug 2019 16:48:15 +0100 Subject: [PATCH 07/14] Added serializers / parsers and test --- .../compiler/xla/service/hlo_instruction.cc | 11 ++-- .../compiler/xla/service/hlo_instruction.h | 1 + tensorflow/compiler/xla/service/hlo_parser.cc | 60 +++++++++++++++++++ tensorflow/compiler/xla/service/hlo_parser.h | 5 ++ .../compiler/xla/service/hlo_parser_test.cc | 6 ++ 5 files changed, 79 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 9cd41163c7c..a0ac73323a5 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -2488,10 +2488,7 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back(StrCat("sharding=", sharding().ToString())); } if (!frontend_attributes_.map().empty()) { - extra.push_back( - absl::StrFormat("frontend_attributes={%s}", - absl::StrJoin(frontend_attributes_.map(), ",", - absl::PairFormatter("=")))); + extra.push_back(StrCat("frontend_attributes=", FrontendAttributesToString(frontend_attributes_))); } if (!outer_dimension_partitions_.empty()) { extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}", @@ -3207,6 +3204,12 @@ StatusOr StringToFusionKind( return InvalidArgument("Unknown fusion kind: %s", kind_name); } +string FrontendAttributesToString(const FrontendAttributes& frontend_attributes){ + return absl::StrFormat("{%s}", + absl::StrJoin(frontend_attributes.map(), ",", + absl::PairFormatter("="))); +} + string PaddingConfigToString(const PaddingConfig& padding) { bool has_interior_padding = absl::c_any_of(padding.dimensions(), diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index cf175024e81..e5f22aa3146 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1928,6 +1928,7 @@ StatusOr StringToFusionKind( // Custom (de)stringification functions for protos that live inside // HloInstruction. 
string PaddingConfigToString(const PaddingConfig& padding); +string FrontendAttributesToString(const FrontendAttributes& frontend_attributes); string OpMetadataToString(const OpMetadata& metadata); string RandomDistributionToString(const RandomDistribution& distribution); string PrecisionToString(const PrecisionConfig::Precision& precision); diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 2589de633d0..834fa69499f 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -88,6 +88,7 @@ class HloParser { // Stand alone parsing utils for various aggregate data types. StatusOr ParseShapeOnly(); StatusOr ParseShardingOnly(); + StatusOr ParseFrontendAttributesOnly(); StatusOr> ParseParameterReplicationOnly(); StatusOr ParseWindowOnly(); StatusOr ParseConvolutionDimensionNumbersOnly(); @@ -192,6 +193,7 @@ class HloParser { kWindow, kConvolutionDimensionNumbers, kSharding, + kFrontendAttributes, kParameterReplication, kInstructionList, kSliceRanges, @@ -271,6 +273,7 @@ class HloParser { bool ParsePaddingConfig(PaddingConfig* padding); bool ParseMetadata(OpMetadata* metadata); bool ParseSharding(OpSharding* sharding); + bool ParseFrontendAttributes(FrontendAttributes *frontend_attributes); bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); bool ParseParameterReplication(ParameterReplication* parameter_replication); bool ParseReplicaGroupsOnly(std::vector* replica_groups); @@ -677,7 +680,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, // Add optional attributes. std::unordered_map attrs; optional sharding; + optional frontend_attributes; attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding}; + attrs["frontend_attributes"] = {/*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes}; optional parameter_replication; attrs["parameter_replication"] = {/*required=*/false, AttrTy::kParameterReplication, @@ -1838,6 +1843,36 @@ bool HloParser::ParseSharding(OpSharding* sharding) { return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute"); } +// frontend_attributes ::= '{' attributes '}' +// attributes +// ::= /*empty*/ +// ::= attribute '=' value (',' attribute '=' value)* +bool HloParser::ParseFrontendAttributes(FrontendAttributes *frontend_attributes) +{ + CHECK(frontend_attributes!= nullptr); + if (!ParseToken(TokKind::kLbrace, + "expected '{' to start frontend attributes")) { + return false; + } + if (lexer_.GetKind() == TokKind::kRbrace) { + // empty + } else { + do { + LocTy loc = lexer_.GetLoc(); + string attribute; + if(!ParseAttributeName(&attribute)){ + return false; + } + if (lexer_.GetKind() != TokKind::kIdent) { + return false; + } + (*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal(); + lexer_.Lex(); + } while (EatIfPresent(TokKind::kComma)); + } + return ParseToken(TokKind::kRbrace, "expects '}' at the end of frontend attributes"); +} + // ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? // ('devices=' ('[' dims ']')* device_list)? 
'}' // dims ::= int_list device_list ::= int_list @@ -2857,6 +2892,14 @@ bool HloParser::ParseAttributeHelper( static_cast*>(attr_out_ptr)->emplace(sharding); return true; } + case AttrTy::kFrontendAttributes: { + FrontendAttributes frontend_attributes; + if(!ParseFrontendAttributes(&frontend_attributes)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(frontend_attributes); + return true; + } case AttrTy::kParameterReplication: { ParameterReplication parameter_replication; if (!ParseParameterReplication(¶meter_replication)) { @@ -4113,6 +4156,18 @@ StatusOr HloParser::ParseShardingOnly() { return HloSharding::FromProto(op_sharding); } +StatusOr HloParser::ParseFrontendAttributesOnly() { + lexer_.Lex(); + FrontendAttributes attributes; + if (!ParseFrontendAttributes(&attributes)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after frontend attributes"); + } + return attributes; +} + StatusOr> HloParser::ParseParameterReplicationOnly() { lexer_.Lex(); ParameterReplication parameter_replication; @@ -4261,6 +4316,11 @@ StatusOr ParseSharding(absl::string_view str) { return parser.ParseShardingOnly(); } +StatusOr ParseFrontendAttributes(absl::string_view str) { + HloParser parser(str); + return parser.ParseFrontendAttributesOnly(); +} + StatusOr> ParseParameterReplication(absl::string_view str) { HloParser parser(str); return parser.ParseParameterReplicationOnly(); diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index e4214c1e6b5..e643d9d4c0b 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -54,6 +54,11 @@ Status ParseHloString(absl::string_view str, HloModule* module); // "{replicated}". StatusOr ParseSharding(absl::string_view str); +// Parses frontend attributes from str. str is supposed to contain the body of the +// frontend attributes , i.e. just the rhs of the "frontend_attributes={...}" attribute string, e.g., +// "{attr_a=a,attr_b=b}". +StatusOr ParseFrontendAttributes(absl::string_view str); + // Parses parameter replication from str. str is supposed to contain the body of // the parameter replication, i.e. just the rhs of the // "parameter_replication={...}" attribute string, e.g., "{true, false}". 
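A round-trip sketch of the two new helpers (the StatusOr payload type is elided by this diff and assumed to be FrontendAttributes; note that, as written, each attribute value must lex as an identifier token):

    // #include "tensorflow/compiler/xla/service/hlo_instruction.h"  // FrontendAttributesToString
    // #include "tensorflow/compiler/xla/service/hlo_parser.h"       // ParseFrontendAttributes
    xla::FrontendAttributes attrs =
        xla::ParseFrontendAttributes("{attr_a=a,attr_b=b}").ValueOrDie();
    // FrontendAttributesToString(attrs) renders the same "{attr_a=a,attr_b=b}"
    // form, which is what the parser test below relies on.
    std::string text = xla::FrontendAttributesToString(attrs);
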
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index b9a017ada43..36ffafcc338 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -2327,6 +2327,12 @@ TEST_F(HloParserTest, ParseSharding) { EXPECT_EQ(sharding.ToString(), original); } +TEST_F(HloParserTest, ParseFrontendAttributes) { + const string original = "{attr_a=test_a,attr_b=b}"; + TF_ASSERT_OK_AND_ASSIGN(FrontendAttributes frontend_attributes, ParseFrontendAttributes(original)); + EXPECT_EQ(FrontendAttributesToString(frontend_attributes), original); +} + TEST_F(HloParserTest, ParseWindow) { Window original = window_util::MakeWindow({1, 2, 3}); TF_ASSERT_OK_AND_ASSIGN(Window parsed, From 4b463f04e50463b5e05df5ab1c3fcc8534a0009d Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Fri, 2 Aug 2019 17:50:02 +0100 Subject: [PATCH 08/14] Apply linter --- .../compiler/xla/service/hlo_instruction.cc | 11 +++++---- .../compiler/xla/service/hlo_instruction.h | 3 ++- tensorflow/compiler/xla/service/hlo_parser.cc | 24 +++++++++++-------- tensorflow/compiler/xla/service/hlo_parser.h | 5 ++-- .../compiler/xla/service/hlo_parser_test.cc | 3 ++- 5 files changed, 27 insertions(+), 19 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a0ac73323a5..3325069cad8 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -2488,7 +2488,8 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back(StrCat("sharding=", sharding().ToString())); } if (!frontend_attributes_.map().empty()) { - extra.push_back(StrCat("frontend_attributes=", FrontendAttributesToString(frontend_attributes_))); + extra.push_back(StrCat("frontend_attributes=", + FrontendAttributesToString(frontend_attributes_))); } if (!outer_dimension_partitions_.empty()) { extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}", @@ -3204,10 +3205,10 @@ StatusOr StringToFusionKind( return InvalidArgument("Unknown fusion kind: %s", kind_name); } -string FrontendAttributesToString(const FrontendAttributes& frontend_attributes){ - return absl::StrFormat("{%s}", - absl::StrJoin(frontend_attributes.map(), ",", - absl::PairFormatter("="))); +string FrontendAttributesToString( + const FrontendAttributes& frontend_attributes) { + return absl::StrFormat("{%s}", absl::StrJoin(frontend_attributes.map(), ",", + absl::PairFormatter("="))); } string PaddingConfigToString(const PaddingConfig& padding) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index e5f22aa3146..2d4235f6a0b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1928,7 +1928,8 @@ StatusOr StringToFusionKind( // Custom (de)stringification functions for protos that live inside // HloInstruction. 
string PaddingConfigToString(const PaddingConfig& padding); -string FrontendAttributesToString(const FrontendAttributes& frontend_attributes); +string FrontendAttributesToString( + const FrontendAttributes& frontend_attributes); string OpMetadataToString(const OpMetadata& metadata); string RandomDistributionToString(const RandomDistribution& distribution); string PrecisionToString(const PrecisionConfig::Precision& precision); diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 834fa69499f..e0dcb2ce9d1 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -273,7 +273,7 @@ class HloParser { bool ParsePaddingConfig(PaddingConfig* padding); bool ParseMetadata(OpMetadata* metadata); bool ParseSharding(OpSharding* sharding); - bool ParseFrontendAttributes(FrontendAttributes *frontend_attributes); + bool ParseFrontendAttributes(FrontendAttributes* frontend_attributes); bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); bool ParseParameterReplication(ParameterReplication* parameter_replication); bool ParseReplicaGroupsOnly(std::vector* replica_groups); @@ -682,7 +682,8 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional sharding; optional frontend_attributes; attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding}; - attrs["frontend_attributes"] = {/*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes}; + attrs["frontend_attributes"] = { + /*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes}; optional parameter_replication; attrs["parameter_replication"] = {/*required=*/false, AttrTy::kParameterReplication, @@ -1847,9 +1848,9 @@ bool HloParser::ParseSharding(OpSharding* sharding) { // attributes // ::= /*empty*/ // ::= attribute '=' value (',' attribute '=' value)* -bool HloParser::ParseFrontendAttributes(FrontendAttributes *frontend_attributes) -{ - CHECK(frontend_attributes!= nullptr); +bool HloParser::ParseFrontendAttributes( + FrontendAttributes* frontend_attributes) { + CHECK(frontend_attributes != nullptr); if (!ParseToken(TokKind::kLbrace, "expected '{' to start frontend attributes")) { return false; @@ -1860,7 +1861,7 @@ bool HloParser::ParseFrontendAttributes(FrontendAttributes *frontend_attributes) do { LocTy loc = lexer_.GetLoc(); string attribute; - if(!ParseAttributeName(&attribute)){ + if (!ParseAttributeName(&attribute)) { return false; } if (lexer_.GetKind() != TokKind::kIdent) { @@ -1870,7 +1871,8 @@ bool HloParser::ParseFrontendAttributes(FrontendAttributes *frontend_attributes) lexer_.Lex(); } while (EatIfPresent(TokKind::kComma)); } - return ParseToken(TokKind::kRbrace, "expects '}' at the end of frontend attributes"); + return ParseToken(TokKind::kRbrace, + "expects '}' at the end of frontend attributes"); } // ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? 
@@ -2894,10 +2896,11 @@ bool HloParser::ParseAttributeHelper( } case AttrTy::kFrontendAttributes: { FrontendAttributes frontend_attributes; - if(!ParseFrontendAttributes(&frontend_attributes)) { + if (!ParseFrontendAttributes(&frontend_attributes)) { return false; } - static_cast*>(attr_out_ptr)->emplace(frontend_attributes); + static_cast*>(attr_out_ptr) + ->emplace(frontend_attributes); return true; } case AttrTy::kParameterReplication: { @@ -4163,7 +4166,8 @@ StatusOr HloParser::ParseFrontendAttributesOnly() { return InvalidArgument("Syntax error:\n%s", GetError()); } if (lexer_.GetKind() != TokKind::kEof) { - return InvalidArgument("Syntax error:\nExtra content after frontend attributes"); + return InvalidArgument( + "Syntax error:\nExtra content after frontend attributes"); } return attributes; } diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index e643d9d4c0b..91ce79ec982 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -54,8 +54,9 @@ Status ParseHloString(absl::string_view str, HloModule* module); // "{replicated}". StatusOr ParseSharding(absl::string_view str); -// Parses frontend attributes from str. str is supposed to contain the body of the -// frontend attributes , i.e. just the rhs of the "frontend_attributes={...}" attribute string, e.g., +// Parses frontend attributes from str. str is supposed to contain the body of +// the frontend attributes , i.e. just the rhs of the +// "frontend_attributes={...}" attribute string, e.g., // "{attr_a=a,attr_b=b}". StatusOr ParseFrontendAttributes(absl::string_view str); diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 36ffafcc338..a5ed289721c 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -2329,7 +2329,8 @@ TEST_F(HloParserTest, ParseSharding) { TEST_F(HloParserTest, ParseFrontendAttributes) { const string original = "{attr_a=test_a,attr_b=b}"; - TF_ASSERT_OK_AND_ASSIGN(FrontendAttributes frontend_attributes, ParseFrontendAttributes(original)); + TF_ASSERT_OK_AND_ASSIGN(FrontendAttributes frontend_attributes, + ParseFrontendAttributes(original)); EXPECT_EQ(FrontendAttributesToString(frontend_attributes), original); } From 53e79b073eeb7813e3cfb497b98ece7a7cb71dad Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Mon, 5 Aug 2019 11:46:54 +0100 Subject: [PATCH 09/14] Renamed Check -> Expect, used macro to disable copy / assignment --- tensorflow/compiler/xla/client/xla_builder.h | 16 +++++++--------- .../compiler/xla/client/xla_builder_test.cc | 14 +++++++------- .../compiler/xla/service/hlo_instruction.cc | 5 ++--- 3 files changed, 16 insertions(+), 19 deletions(-) diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index c3c663873e5..34716f6ed78 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -172,7 +172,8 @@ class XlaBuilder { // Swap the passed FrontendAttributes with the ones currently set. // // Return the old attributes. 
- FrontendAttributes SwapFrontendAttributes(const FrontendAttributes& frontend_attributes) { + FrontendAttributes SwapFrontendAttributes( + const FrontendAttributes& frontend_attributes) { FrontendAttributes old_attributes = std::move(frontend_attributes_); frontend_attributes_ = std::move(frontend_attributes); return old_attributes; @@ -1085,17 +1086,12 @@ class XlaScopedShardingAssignment { // Restore the original attributes on destruction. class XlaScopedFrontendAttributesAssignment { public: - XlaScopedFrontendAttributesAssignment( - xla::XlaBuilder* builder, FrontendAttributes attributes) + XlaScopedFrontendAttributesAssignment(xla::XlaBuilder* builder, + FrontendAttributes attributes) : builder_(builder) { - saved_ = builder_->SwapFrontendAttributes(std::move(attributes)); + saved_ = builder_->SwapFrontendAttributes(std::move(attributes)); } - XlaScopedFrontendAttributesAssignment( - const XlaScopedFrontendAttributesAssignment&) = delete; - XlaScopedFrontendAttributesAssignment& operator=( - const XlaScopedFrontendAttributesAssignment&) = delete; - ~XlaScopedFrontendAttributesAssignment() { builder_->SetFrontendAttributes(std::move(saved_)); } @@ -1103,6 +1099,8 @@ class XlaScopedFrontendAttributesAssignment { private: xla::XlaBuilder* const builder_; FrontendAttributes saved_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedFrontendAttributesAssignment); }; // Free functions for building XlaOps. The intention is that these will // become the public API for building XlaOps rather than calling methods on diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 2bc79f5db66..08b28c051f5 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -978,8 +978,8 @@ TEST_F(XlaBuilderTest, CheckInputOutputAlias) { EXPECT_EQ(*alias_p1, ShapeIndex({0})); } -void CheckAttributesMatch(const FrontendAttributes& attr, - const FrontendAttributes& ref) { +void ExpectAttributesMatch(const FrontendAttributes& attr, + const FrontendAttributes& ref) { EXPECT_EQ(ref.map_size(), attr.map_size()); for (auto reference : ref.map()) { auto other = attr.map().find(reference.first); @@ -988,13 +988,13 @@ void CheckAttributesMatch(const FrontendAttributes& attr, } } -void CheckInstructionsAttributesMatch( +void ExpectInstructionsAttributesMatch( HloModule& module, const std::vector& expected) { ASSERT_EQ(module.computation_count(), 1); auto expected_it = expected.begin(); for (auto inst : module.mutable_computation(0)->instructions()) { ASSERT_NE(expected_it, expected.end()); - CheckAttributesMatch(inst->frontend_attributes(), *expected_it); + ExpectAttributesMatch(inst->frontend_attributes(), *expected_it); expected_it++; } EXPECT_EQ(expected_it, expected.end()); @@ -1017,7 +1017,7 @@ TEST_F(XlaBuilderTest, SimpleSetFrontendAttributes) { std::vector expected{FrontendAttributes(), attributes, FrontendAttributes()}; - CheckInstructionsAttributesMatch(*module, expected); + ExpectInstructionsAttributesMatch(*module, expected); } TEST_F(XlaBuilderTest, ComplexSetFrontendAttributes) { @@ -1056,7 +1056,7 @@ TEST_F(XlaBuilderTest, ComplexSetFrontendAttributes) { expected.push_back(FrontendAttributes()); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - CheckInstructionsAttributesMatch(*module, expected); + ExpectInstructionsAttributesMatch(*module, expected); } TEST_F(XlaBuilderTest, AddFrontendAttribute) { @@ -1121,7 +1121,7 @@ TEST_F(XlaBuilderTest, AddFrontendAttribute) { 
expected.push_back(FrontendAttributes()); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); - CheckInstructionsAttributesMatch(*module, expected); + ExpectInstructionsAttributesMatch(*module, expected); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 3325069cad8..e946e56c82c 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -2550,9 +2550,8 @@ HloInstructionProto HloInstruction::ToProto() const { proto.mutable_outer_dimension_partitions()->Add(idx); } } - if (!frontend_attributes_.map().empty()) { - proto.mutable_frontend_attributes()->CopyFrom(frontend_attributes_); - } + + proto.mutable_frontend_attributes()->CopyFrom(frontend_attributes_); return proto; } From d333176f5a615f23f225be7bc906a2c5d9f56b51 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Mon, 19 Aug 2019 11:40:47 +0100 Subject: [PATCH 10/14] - Fix: Save frontend attributes in while loop - Fix: save backend / frontend attributes in ReplaceInstruction --- tensorflow/compiler/tf2xla/BUILD | 1 + tensorflow/compiler/tf2xla/frontend_attributes_util.cc | 6 ++---- tensorflow/compiler/tf2xla/frontend_attributes_util.h | 2 ++ tensorflow/compiler/tf2xla/functionalize_while.cc | 7 +++++++ tensorflow/compiler/xla/service/hlo_computation.cc | 6 ++++++ 5 files changed, 18 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 1e4f2e23ef3..329c706c763 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -594,6 +594,7 @@ cc_library( ":functionalize_cond", ":functionalize_control_flow_util", ":tf2xla_util", + ":frontend_attributes_util", "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", diff --git a/tensorflow/compiler/tf2xla/frontend_attributes_util.cc b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc index 7c2564ffa99..b088001f287 100644 --- a/tensorflow/compiler/tf2xla/frontend_attributes_util.cc +++ b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc @@ -19,13 +19,11 @@ limitations under the License. namespace tensorflow { -namespace { -const char kFrontendAttributesAttribute[] = "_XlaFrontendAttributes"; -} // namespace +const char kXlaFrontendAttributesAttrName[] = "_XlaFrontendAttributes"; xla::StatusOr> GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs) { - const AttrValue *attr = attrs.Find(kFrontendAttributesAttribute); + const AttrValue* attr = attrs.Find(kXlaFrontendAttributesAttrName); if (attr == nullptr) { return xla::StatusOr>( absl::nullopt); diff --git a/tensorflow/compiler/tf2xla/frontend_attributes_util.h b/tensorflow/compiler/tf2xla/frontend_attributes_util.h index 1c2b1d8c1c5..421f21e71d1 100644 --- a/tensorflow/compiler/tf2xla/frontend_attributes_util.h +++ b/tensorflow/compiler/tf2xla/frontend_attributes_util.h @@ -24,6 +24,8 @@ limitations under the License. namespace tensorflow { +// Frontend Attributes Id. +extern const char kXlaFrontendAttributesAttrName[]; // Return the FrontendAttributes stored in the AttrSlice if there are some. 
// // Return an InvalidArgument error if some attributes are present but diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index e4a21f90598..d3d2f2ff79a 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h" #include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" @@ -530,6 +531,12 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, builder.Attr("cond", cond_name); builder.Attr("body", body_name); string outside_compilation; + string frontend_attributes; + if (GetNodeAttr(frame->loop_cond->def(), kXlaFrontendAttributesAttrName, + &frontend_attributes) + .ok()) { + builder.Attr(kXlaFrontendAttributesAttrName, frontend_attributes); + } if (GetNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName, &outside_compilation) .ok()) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 6fe91e492ed..fce60bc430e 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -837,6 +837,12 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, if (new_instruction->metadata().op_name().empty()) { new_instruction->set_metadata(old_instruction->metadata()); } + new_instruction->set_raw_backend_config_string( + old_instruction->raw_backend_config_string()); + if (new_instruction->frontend_attributes().map().empty()) { + new_instruction->set_frontend_attributes( + old_instruction->frontend_attributes()); + } // Like the metadata above, if the user didn't specify any sharding // information on the new instruction we should copy the old sharding From b2b2424f09e4540afe0a6086c8c833303bbed39e Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Mon, 19 Aug 2019 11:05:11 +0100 Subject: [PATCH 11/14] Sort map before serializing it --- tensorflow/compiler/xla/service/hlo_instruction.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index e946e56c82c..1ba87a6da6a 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -3206,8 +3206,11 @@ StatusOr StringToFusionKind( string FrontendAttributesToString( const FrontendAttributes& frontend_attributes) { - return absl::StrFormat("{%s}", absl::StrJoin(frontend_attributes.map(), ",", - absl::PairFormatter("="))); + std::vector> sorted_attributes( + frontend_attributes.map().begin(), frontend_attributes.map().end()); + std::sort(sorted_attributes.begin(), sorted_attributes.end()); + return absl::StrFormat( + "{%s}", absl::StrJoin(sorted_attributes, ",", absl::PairFormatter("="))); } string PaddingConfigToString(const PaddingConfig& padding) { From e65f01242bb51e258305357765473f2f9c6e624b Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Tue, 20 Aug 2019 11:17:38 +0100 Subject: [PATCH 12/14] Use absl::c_sort instead of std::c_sort. 
Best-effort to copy the backend_config over when replacing HloInstructions (But only if the new one is empty) --- tensorflow/compiler/xla/service/hlo_computation.cc | 6 ++++-- tensorflow/compiler/xla/service/hlo_instruction.cc | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index fce60bc430e..c27a0a86ef8 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -837,8 +837,10 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, if (new_instruction->metadata().op_name().empty()) { new_instruction->set_metadata(old_instruction->metadata()); } - new_instruction->set_raw_backend_config_string( - old_instruction->raw_backend_config_string()); + if (new_instruction->raw_backend_config_string().empty()) { + new_instruction->set_raw_backend_config_string( + old_instruction->raw_backend_config_string()); + } if (new_instruction->frontend_attributes().map().empty()) { new_instruction->set_frontend_attributes( old_instruction->frontend_attributes()); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 1ba87a6da6a..47b76331d76 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -3208,7 +3208,7 @@ string FrontendAttributesToString( const FrontendAttributes& frontend_attributes) { std::vector> sorted_attributes( frontend_attributes.map().begin(), frontend_attributes.map().end()); - std::sort(sorted_attributes.begin(), sorted_attributes.end()); + absl::c_sort(sorted_attributes); return absl::StrFormat( "{%s}", absl::StrJoin(sorted_attributes, ",", absl::PairFormatter("="))); } From 4db6d9a5f061386377c4d02d17cdda608f8d84ad Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Tue, 20 Aug 2019 14:36:59 +0100 Subject: [PATCH 13/14] Actually there is no need to copy over the backend config at the Hlo level --- tensorflow/compiler/xla/service/hlo_computation.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index c27a0a86ef8..0ff103cf5fc 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -837,10 +837,6 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, if (new_instruction->metadata().op_name().empty()) { new_instruction->set_metadata(old_instruction->metadata()); } - if (new_instruction->raw_backend_config_string().empty()) { - new_instruction->set_raw_backend_config_string( - old_instruction->raw_backend_config_string()); - } if (new_instruction->frontend_attributes().map().empty()) { new_instruction->set_frontend_attributes( old_instruction->frontend_attributes()); From 782d3d05ea62848dc46ce8b5e5336ce8f0d81f3f Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Tue, 20 Aug 2019 17:28:26 +0100 Subject: [PATCH 14/14] Ran buildifier to fix formatting issues in BUILD files --- tensorflow/compiler/tf2xla/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 329c706c763..877c00d1115 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -591,10 +591,10 @@ cc_library( "functionalize_while.h", ], deps = [ + 
":frontend_attributes_util", ":functionalize_cond", ":functionalize_control_flow_util", ":tf2xla_util", - ":frontend_attributes_util", "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros",