Add initial support for populating XLA OpMetadata with operation location when exporting MLIR HLO to HLO proto.
Different location types are handled. Currently, unknown locations are ignored, name locations populate the `op_name` field, and file name line column locations populate the `source_file` and `source_line` fields respectively. A RAII-style object is created for setting OpMetadata, similar to frontend attributes and sharding, to match the rest of the export flow. PiperOrigin-RevId: 329418320 Change-Id: Idd410ee624440c902e65101d41be9eef5d012f63
This commit is contained in:
parent
dbb4fe3fe1
commit
9c9850058d
@ -86,7 +86,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) {
|
||||
%arg_tuple.1 = (f32[], f32[]) parameter(0)
|
||||
%get-tuple-element.2 = f32[] get-tuple-element((f32[], f32[]) %arg_tuple.1), index=0
|
||||
%get-tuple-element.3 = f32[] get-tuple-element((f32[], f32[]) %arg_tuple.1), index=1
|
||||
%add.4 = f32[] add(f32[] %get-tuple-element.2, f32[] %get-tuple-element.3)
|
||||
%add.4 = f32[] add(f32[] %get-tuple-element.2, f32[] %get-tuple-element.3), metadata={source_file="-" source_line=4}
|
||||
ROOT %tuple.5 = (f32[]) tuple(f32[] %add.4)
|
||||
}
|
||||
|
||||
@ -143,7 +143,7 @@ TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) {
|
||||
ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) {
|
||||
%Arg_0.1 = f32[] parameter(0)
|
||||
%Arg_1.2 = f32[] parameter(1)
|
||||
%add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2)
|
||||
%add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2), metadata={source_file="-" source_line=4}
|
||||
ROOT %tuple.4 = (f32[]) tuple(f32[] %add.3)
|
||||
}
|
||||
|
||||
@ -215,7 +215,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) {
|
||||
%arg_tuple.1 = (f32[10,19]{1,0}, f32[19,10]{1,0}) parameter(0), parameter_replication={false,true}
|
||||
%get-tuple-element.2 = f32[10,19]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %arg_tuple.1), index=0
|
||||
%get-tuple-element.3 = f32[19,10]{1,0} get-tuple-element((f32[10,19]{1,0}, f32[19,10]{1,0}) %arg_tuple.1), index=1
|
||||
%reshape.4 = f32[10,19]{1,0} reshape(f32[19,10]{1,0} %get-tuple-element.3)
|
||||
%reshape.4 = f32[10,19]{1,0} reshape(f32[19,10]{1,0} %get-tuple-element.3), metadata={source_file="-" source_line=5}
|
||||
ROOT %tuple.5 = (f32[10,19]{1,0}) tuple(f32[10,19]{1,0} %reshape.4)
|
||||
}
|
||||
|
||||
|
@ -237,6 +237,7 @@ cc_library(
|
||||
hdrs = ["mlir_hlo_to_hlo.h"],
|
||||
deps = [
|
||||
":type_to_shape",
|
||||
"//tensorflow/compiler/mlir:name_utils",
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_type",
|
||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||
|
@ -41,6 +41,7 @@ limitations under the License.
|
||||
#include "mlir/IR/UseDefLists.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||
#include "tensorflow/compiler/mlir/utils/name_utils.h"
|
||||
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||
@ -430,6 +431,27 @@ static xla::FrontendAttributes CreateOpFrontendAttributesFromAttribute(
|
||||
return frontend_attributes;
|
||||
}
|
||||
|
||||
// Returns a OpMetadata proto based on the location of the op. If the location
|
||||
// is unknown, an empty proto is returned. `op_name` are populated with the op
|
||||
// location (converted). FileLineColLoc locations are populated by taking the
|
||||
// file name and line number, and populating `source_file` and `source_line`
|
||||
// respectively.
|
||||
static xla::OpMetadata CreateOpMetadataFromLocation(mlir::Operation* op) {
|
||||
xla::OpMetadata metadata;
|
||||
if (op->getLoc().isa<mlir::UnknownLoc>()) return metadata;
|
||||
|
||||
std::string name = mlir::GetNameFromLoc(op->getLoc());
|
||||
mlir::LegalizeNodeName(name);
|
||||
metadata.set_op_name(name);
|
||||
|
||||
if (auto file_line_col_loc = op->getLoc().dyn_cast<mlir::FileLineColLoc>()) {
|
||||
metadata.set_source_file(file_line_col_loc.getFilename().str());
|
||||
metadata.set_source_line(file_line_col_loc.getLine());
|
||||
}
|
||||
|
||||
return metadata;
|
||||
}
|
||||
|
||||
// Checks if all shardings are set.
|
||||
static bool AllOptionalShardingsAreSet(
|
||||
llvm::ArrayRef<absl::optional<xla::OpSharding>> shardings) {
|
||||
|
@ -165,6 +165,11 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
|
||||
"frontend_attributes(lowering_context.builder, "
|
||||
"CreateOpFrontendAttributesFromAttribute(op));\n\n";
|
||||
|
||||
// Create a scoped object to assign op metadata to generated XLA ops.
|
||||
os << " xla::XlaScopedOpMetadataAssignment "
|
||||
"op_metadata(lowering_context.builder, "
|
||||
"CreateOpMetadataFromLocation(op));\n\n";
|
||||
|
||||
// Retrieve all the definitions derived from HLO_Op and sort by record name.
|
||||
for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) {
|
||||
// Skip operations that have a custom exporter.
|
||||
|
@ -362,7 +362,9 @@ func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<5x5xf32>) -> tensor<1x2x3xf32>
|
||||
// CHECK: [[VAL_1:%.*]] = f32[2,3] parameter(0)
|
||||
// CHECK: [[VAL_2:%.*]] = f32[5,5] parameter(1)
|
||||
// CHECK: ROOT
|
||||
// CHECK-SAME: f32[1,2,3] custom-call(f32[2,3] [[VAL_1]], f32[5,5] [[VAL_2]]), custom_call_target="foo", backend_config="bar"
|
||||
// CHECK-SAME: f32[1,2,3] custom-call(f32[2,3] [[VAL_1]], f32[5,5] [[VAL_2]])
|
||||
// CHECK-SAME: custom_call_target="foo"
|
||||
// CHECK-SAME: backend_config="bar"
|
||||
|
||||
// -----
|
||||
|
||||
|
@ -0,0 +1,43 @@
|
||||
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s --dump-input=always
|
||||
|
||||
// CHECK-LABEL: %main
|
||||
func @main(%arg0: !mhlo.token) -> !mhlo.token {
|
||||
%0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc(unknown)
|
||||
return %0 : !mhlo.token
|
||||
}
|
||||
|
||||
// CHECK: after-all
|
||||
// CHECK-NOT: metadata
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: %main
|
||||
func @main(%arg0: !mhlo.token) -> !mhlo.token {
|
||||
%0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc("AfterAll")
|
||||
return %0 : !mhlo.token
|
||||
}
|
||||
|
||||
// CHECK: after-all
|
||||
// CHECK-SAME: metadata={op_name="AfterAll"}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: %main
|
||||
func @main(%arg0: !mhlo.token) -> !mhlo.token {
|
||||
%0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc("name@function")
|
||||
return %0 : !mhlo.token
|
||||
}
|
||||
|
||||
// CHECK: after-all
|
||||
// CHECK-SAME: metadata={op_name="name"}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: %main
|
||||
func @main(%arg0: !mhlo.token) -> !mhlo.token {
|
||||
%0 = "mhlo.after_all"(%arg0) : (!mhlo.token) -> !mhlo.token loc("file_name":2:8)
|
||||
return %0 : !mhlo.token
|
||||
}
|
||||
|
||||
// CHECK: after-all
|
||||
// CHECK-SAME: metadata={source_file="file_name" source_line=2}
|
@ -164,6 +164,15 @@ class XlaBuilder {
|
||||
// OpMetadata attached until a call to ClearOpMetadata.
|
||||
void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); }
|
||||
|
||||
// Swaps the passed op metadata with the ones currently set.
|
||||
//
|
||||
// Returns the old op metadata.
|
||||
OpMetadata SwapOpMetadata(OpMetadata metadata) {
|
||||
OpMetadata old_metadata = std::move(metadata_);
|
||||
metadata_ = std::move(metadata);
|
||||
return old_metadata;
|
||||
}
|
||||
|
||||
// Similar to SetOpMetadata, but only set the metadata for the next op.
|
||||
void SetOneShotOpMetadata(OpMetadata metadata) {
|
||||
metadata_ = std::move(metadata);
|
||||
@ -1339,6 +1348,25 @@ class XlaScopedFrontendAttributesAssignment {
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedFrontendAttributesAssignment);
|
||||
};
|
||||
|
||||
// RAII-style object: sets the current op metadata in builder on construction,
|
||||
// and sets back to the previous assignment on destruction.
|
||||
class XlaScopedOpMetadataAssignment {
|
||||
public:
|
||||
XlaScopedOpMetadataAssignment(xla::XlaBuilder* builder, OpMetadata metadata)
|
||||
: builder_(builder) {
|
||||
saved_ = builder_->SwapOpMetadata(metadata);
|
||||
}
|
||||
|
||||
~XlaScopedOpMetadataAssignment() { builder_->SwapOpMetadata(saved_); }
|
||||
|
||||
private:
|
||||
xla::XlaBuilder* const builder_;
|
||||
OpMetadata saved_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedOpMetadataAssignment);
|
||||
};
|
||||
|
||||
// Free functions for building XlaOps. The intention is that these will
|
||||
// become the public API for building XlaOps rather than calling methods on
|
||||
// XlaBuilder directly.
|
||||
|
Loading…
x
Reference in New Issue
Block a user