diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index 80e2c1132fd..461c12bf548 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -86,7 +86,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) { %arg_tuple.1 = (f32[], f32[]) parameter(0) %get-tuple-element.2 = f32[] get-tuple-element((f32[], f32[]) %arg_tuple.1), index=0 %get-tuple-element.3 = f32[] get-tuple-element((f32[], f32[]) %arg_tuple.1), index=1 - %add.4 = f32[] add(f32[] %get-tuple-element.2, f32[] %get-tuple-element.3) + %add.4 = f32[] add(f32[] %get-tuple-element.2, f32[] %get-tuple-element.3), metadata={source_file="-" source_line=4} ROOT %tuple.5 = (f32[]) tuple(f32[] %add.4) } @@ -143,7 +143,7 @@ TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) { ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) { %Arg_0.1 = f32[] parameter(0) %Arg_1.2 = f32[] parameter(1) - %add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2) + %add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2), metadata={source_file="-" source_line=4} ROOT %tuple.4 = (f32[]) tuple(f32[] %add.3) } @@ -215,7 +215,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) { %arg_tuple.1 = (f32[10,19]{1,0}, f32[19,10]{1,0}) parameter(0), parameter_replication={false,true} %get-tuple-element.2 = f32[10,19]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %arg_tuple.1), index=0 %get-tuple-element.3 = f32[19,10]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %arg_tuple.1), index=1 - %reshape.4 = f32[10,19]{1,0} reshape(f32[19,10]{1,0} %get-tuple-element.3) + %reshape.4 = f32[10,19]{1,0} reshape(f32[19,10]{1,0} %get-tuple-element.3), metadata={source_file="-" source_line=5} ROOT %tuple.5 = (f32[10,19]{1,0}) tuple(f32[10,19]{1,0} %reshape.4) } diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index ec98d9d29e5..37b16a1d372 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -237,6 +237,7 @@ cc_library( hdrs = ["mlir_hlo_to_hlo.h"], deps = [ ":type_to_shape", + "//tensorflow/compiler/mlir:name_utils", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/mlir/tensorflow:error_util", diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 5398cd70777..f4c68de91e6 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -41,6 +41,7 @@ limitations under the License. #include "mlir/IR/UseDefLists.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/utils/name_utils.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" @@ -430,6 +431,27 @@ static xla::FrontendAttributes CreateOpFrontendAttributesFromAttribute( return frontend_attributes; } +// Returns a OpMetadata proto based on the location of the op. If the location +// is unknown, an empty proto is returned. `op_name` are populated with the op +// location (converted). FileLineColLoc locations are populated by taking the +// file name and line number, and populating `source_file` and `source_line` +// respectively. +static xla::OpMetadata CreateOpMetadataFromLocation(mlir::Operation* op) { + xla::OpMetadata metadata; + if (op->getLoc().isa()) return metadata; + + std::string name = mlir::GetNameFromLoc(op->getLoc()); + mlir::LegalizeNodeName(name); + metadata.set_op_name(name); + + if (auto file_line_col_loc = op->getLoc().dyn_cast()) { + metadata.set_source_file(file_line_col_loc.getFilename().str()); + metadata.set_source_line(file_line_col_loc.getLine()); + } + + return metadata; +} + // Checks if all shardings are set. static bool AllOptionalShardingsAreSet( llvm::ArrayRef> shardings) { diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc index 407a7d3da38..801c04496f0 100644 --- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc +++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc @@ -165,6 +165,11 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) { "frontend_attributes(lowering_context.builder, " "CreateOpFrontendAttributesFromAttribute(op));\n\n"; + // Create a scoped object to assign op metadata to generated XLA ops. + os << " xla::XlaScopedOpMetadataAssignment " + "op_metadata(lowering_context.builder, " + "CreateOpMetadataFromLocation(op));\n\n"; + // Retrieve all the definitions derived from HLO_Op and sort by record name. for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) { // Skip operations that have a custom exporter. diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 316eda4c4aa..ff1bcadda7b 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -362,7 +362,9 @@ func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<5x5xf32>) -> tensor<1x2x3xf32> // CHECK: [[VAL_1:%.*]] = f32[2,3] parameter(0) // CHECK: [[VAL_2:%.*]] = f32[5,5] parameter(1) // CHECK: ROOT -// CHECK-SAME: f32[1,2,3] custom-call(f32[2,3] [[VAL_1]], f32[5,5] [[VAL_2]]), custom_call_target="foo", backend_config="bar" +// CHECK-SAME: f32[1,2,3] custom-call(f32[2,3] [[VAL_1]], f32[5,5] [[VAL_2]]) +// CHECK-SAME: custom_call_target="foo" +// CHECK-SAME: backend_config="bar" // ----- diff --git a/tensorflow/compiler/mlir/xla/tests/translate/location_to_op_metadata.mlir b/tensorflow/compiler/mlir/xla/tests/translate/location_to_op_metadata.mlir new file mode 100644 index 00000000000..2182ce6106d --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/location_to_op_metadata.mlir @@ -0,0 +1,43 @@ +// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s --dump-input=always + +// CHECK-LABEL: %main +func @main(%arg0: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc(unknown) + return %0 : !mhlo.token +} + +// CHECK: after-all +// CHECK-NOT: metadata + +// ----- + +// CHECK-LABEL: %main +func @main(%arg0: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc("AfterAll") + return %0 : !mhlo.token +} + +// CHECK: after-all +// CHECK-SAME: metadata={op_name="AfterAll"} + +// ----- + +// CHECK-LABEL: %main +func @main(%arg0: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc("name@function") + return %0 : !mhlo.token +} + +// CHECK: after-all +// CHECK-SAME: metadata={op_name="name"} + +// ----- + +// CHECK-LABEL: %main +func @main(%arg0: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc("file_name":2:8) + return %0 : !mhlo.token +} + +// CHECK: after-all +// CHECK-SAME: metadata={source_file="file_name" source_line=2} diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index f841a1a75a0..cd9809c2a20 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -164,6 +164,15 @@ class XlaBuilder { // OpMetadata attached until a call to ClearOpMetadata. void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); } + // Swaps the passed op metadata with the ones currently set. + // + // Returns the old op metadata. + OpMetadata SwapOpMetadata(OpMetadata metadata) { + OpMetadata old_metadata = std::move(metadata_); + metadata_ = std::move(metadata); + return old_metadata; + } + // Similar to SetOpMetadata, but only set the metadata for the next op. void SetOneShotOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); @@ -1339,6 +1348,25 @@ class XlaScopedFrontendAttributesAssignment { TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedFrontendAttributesAssignment); }; + +// RAII-style object: sets the current op metadata in builder on construction, +// and sets back to the previous assignment on destruction. +class XlaScopedOpMetadataAssignment { + public: + XlaScopedOpMetadataAssignment(xla::XlaBuilder* builder, OpMetadata metadata) + : builder_(builder) { + saved_ = builder_->SwapOpMetadata(metadata); + } + + ~XlaScopedOpMetadataAssignment() { builder_->SwapOpMetadata(saved_); } + + private: + xla::XlaBuilder* const builder_; + OpMetadata saved_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedOpMetadataAssignment); +}; + // Free functions for building XlaOps. The intention is that these will // become the public API for building XlaOps rather than calling methods on // XlaBuilder directly.