From 9c9850058de839a05b4832e3ce21489527024c2a Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Mon, 31 Aug 2020 18:22:05 -0700 Subject: [PATCH] Add initial support for populating XLA OpMetadata with operation location when exporting MLIR HLO to HLO proto. Different location types are handled. Currently, unknown locations are ignored, name locations populate the `op_name` field, and file name line column locations populate the `source_file` and `source_line` fields respectively. A RAII-style object is created for setting OpMetadata, similar to frontend attributes and sharding, to match the rest of the export flow. PiperOrigin-RevId: 329418320 Change-Id: Idd410ee624440c902e65101d41be9eef5d012f63 --- .../utils/compile_mlir_util_test.cc | 6 +-- tensorflow/compiler/mlir/xla/BUILD | 1 + .../compiler/mlir/xla/mlir_hlo_to_hlo.cc | 22 ++++++++++ .../compiler/mlir/xla/operator_writer_gen.cc | 5 +++ .../mlir/xla/tests/translate/export.mlir | 4 +- .../translate/location_to_op_metadata.mlir | 43 +++++++++++++++++++ tensorflow/compiler/xla/client/xla_builder.h | 28 ++++++++++++ 7 files changed, 105 insertions(+), 4 deletions(-) create mode 100644 tensorflow/compiler/mlir/xla/tests/translate/location_to_op_metadata.mlir diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index 80e2c1132fd..461c12bf548 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -86,7 +86,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) { %arg_tuple.1 = (f32[], f32[]) parameter(0) %get-tuple-element.2 = f32[] get-tuple-element((f32[], f32[]) %arg_tuple.1), index=0 %get-tuple-element.3 = f32[] get-tuple-element((f32[], f32[]) %arg_tuple.1), index=1 - %add.4 = f32[] add(f32[] %get-tuple-element.2, f32[] %get-tuple-element.3) + %add.4 = f32[] add(f32[] %get-tuple-element.2, f32[] %get-tuple-element.3), metadata={source_file="-" source_line=4} ROOT %tuple.5 = (f32[]) tuple(f32[] %add.4) } @@ -143,7 +143,7 @@ TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) { ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) { %Arg_0.1 = f32[] parameter(0) %Arg_1.2 = f32[] parameter(1) - %add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2) + %add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2), metadata={source_file="-" source_line=4} ROOT %tuple.4 = (f32[]) tuple(f32[] %add.3) } @@ -215,7 +215,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) { %arg_tuple.1 = (f32[10,19]{1,0}, f32[19,10]{1,0}) parameter(0), parameter_replication={false,true} %get-tuple-element.2 = f32[10,19]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %arg_tuple.1), index=0 %get-tuple-element.3 = f32[19,10]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %arg_tuple.1), index=1 - %reshape.4 = f32[10,19]{1,0} reshape(f32[19,10]{1,0} %get-tuple-element.3) + %reshape.4 = f32[10,19]{1,0} reshape(f32[19,10]{1,0} %get-tuple-element.3), metadata={source_file="-" source_line=5} ROOT %tuple.5 = (f32[10,19]{1,0}) tuple(f32[10,19]{1,0} %reshape.4) } diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index ec98d9d29e5..37b16a1d372 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -237,6 +237,7 @@ cc_library( hdrs = ["mlir_hlo_to_hlo.h"], deps = [ ":type_to_shape", + "//tensorflow/compiler/mlir:name_utils", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/mlir/tensorflow:error_util", diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 5398cd70777..f4c68de91e6 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -41,6 +41,7 @@ limitations under the License. #include "mlir/IR/UseDefLists.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/utils/name_utils.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" @@ -430,6 +431,27 @@ static xla::FrontendAttributes CreateOpFrontendAttributesFromAttribute( return frontend_attributes; } +// Returns a OpMetadata proto based on the location of the op. If the location +// is unknown, an empty proto is returned. `op_name` are populated with the op +// location (converted). FileLineColLoc locations are populated by taking the +// file name and line number, and populating `source_file` and `source_line` +// respectively. +static xla::OpMetadata CreateOpMetadataFromLocation(mlir::Operation* op) { + xla::OpMetadata metadata; + if (op->getLoc().isa()) return metadata; + + std::string name = mlir::GetNameFromLoc(op->getLoc()); + mlir::LegalizeNodeName(name); + metadata.set_op_name(name); + + if (auto file_line_col_loc = op->getLoc().dyn_cast()) { + metadata.set_source_file(file_line_col_loc.getFilename().str()); + metadata.set_source_line(file_line_col_loc.getLine()); + } + + return metadata; +} + // Checks if all shardings are set. static bool AllOptionalShardingsAreSet( llvm::ArrayRef> shardings) { diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc index 407a7d3da38..801c04496f0 100644 --- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc +++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc @@ -165,6 +165,11 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) { "frontend_attributes(lowering_context.builder, " "CreateOpFrontendAttributesFromAttribute(op));\n\n"; + // Create a scoped object to assign op metadata to generated XLA ops. + os << " xla::XlaScopedOpMetadataAssignment " + "op_metadata(lowering_context.builder, " + "CreateOpMetadataFromLocation(op));\n\n"; + // Retrieve all the definitions derived from HLO_Op and sort by record name. for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) { // Skip operations that have a custom exporter. diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 316eda4c4aa..ff1bcadda7b 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -362,7 +362,9 @@ func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<5x5xf32>) -> tensor<1x2x3xf32> // CHECK: [[VAL_1:%.*]] = f32[2,3] parameter(0) // CHECK: [[VAL_2:%.*]] = f32[5,5] parameter(1) // CHECK: ROOT -// CHECK-SAME: f32[1,2,3] custom-call(f32[2,3] [[VAL_1]], f32[5,5] [[VAL_2]]), custom_call_target="foo", backend_config="bar" +// CHECK-SAME: f32[1,2,3] custom-call(f32[2,3] [[VAL_1]], f32[5,5] [[VAL_2]]) +// CHECK-SAME: custom_call_target="foo" +// CHECK-SAME: backend_config="bar" // ----- diff --git a/tensorflow/compiler/mlir/xla/tests/translate/location_to_op_metadata.mlir b/tensorflow/compiler/mlir/xla/tests/translate/location_to_op_metadata.mlir new file mode 100644 index 00000000000..2182ce6106d --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/location_to_op_metadata.mlir @@ -0,0 +1,43 @@ +// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s --dump-input=always + +// CHECK-LABEL: %main +func @main(%arg0: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc(unknown) + return %0 : !mhlo.token +} + +// CHECK: after-all +// CHECK-NOT: metadata + +// ----- + +// CHECK-LABEL: %main +func @main(%arg0: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc("AfterAll") + return %0 : !mhlo.token +} + +// CHECK: after-all +// CHECK-SAME: metadata={op_name="AfterAll"} + +// ----- + +// CHECK-LABEL: %main +func @main(%arg0: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc("name@function") + return %0 : !mhlo.token +} + +// CHECK: after-all +// CHECK-SAME: metadata={op_name="name"} + +// ----- + +// CHECK-LABEL: %main +func @main(%arg0: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc("file_name":2:8) + return %0 : !mhlo.token +} + +// CHECK: after-all +// CHECK-SAME: metadata={source_file="file_name" source_line=2} diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index f841a1a75a0..cd9809c2a20 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -164,6 +164,15 @@ class XlaBuilder { // OpMetadata attached until a call to ClearOpMetadata. void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); } + // Swaps the passed op metadata with the ones currently set. + // + // Returns the old op metadata. + OpMetadata SwapOpMetadata(OpMetadata metadata) { + OpMetadata old_metadata = std::move(metadata_); + metadata_ = std::move(metadata); + return old_metadata; + } + // Similar to SetOpMetadata, but only set the metadata for the next op. void SetOneShotOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); @@ -1339,6 +1348,25 @@ class XlaScopedFrontendAttributesAssignment { TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedFrontendAttributesAssignment); }; + +// RAII-style object: sets the current op metadata in builder on construction, +// and sets back to the previous assignment on destruction. +class XlaScopedOpMetadataAssignment { + public: + XlaScopedOpMetadataAssignment(xla::XlaBuilder* builder, OpMetadata metadata) + : builder_(builder) { + saved_ = builder_->SwapOpMetadata(metadata); + } + + ~XlaScopedOpMetadataAssignment() { builder_->SwapOpMetadata(saved_); } + + private: + xla::XlaBuilder* const builder_; + OpMetadata saved_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedOpMetadataAssignment); +}; + // Free functions for building XlaOps. The intention is that these will // become the public API for building XlaOps rather than calling methods on // XlaBuilder directly.