From 340ac1aedb082dbf3092608354c8f5a1d2d276d9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 15 May 2020 13:49:39 -0700 Subject: [PATCH] Move GetDeviceCoordinates() function and related constants in tpu_rewrite pass to common utility file. PiperOrigin-RevId: 311795001 Change-Id: If86babf6656da132fb58b1a2266034f3b341e06d --- tensorflow/compiler/mlir/tensorflow/BUILD | 2 + .../tensorflow/transforms/tpu_rewrite_pass.cc | 63 +++++++------------ .../utils/tpu_rewrite_device_util.cc | 33 +++++++++- .../utils/tpu_rewrite_device_util.h | 10 +++ .../utils/tpu_rewrite_device_util_test.cc | 26 ++++++++ 5 files changed, 91 insertions(+), 43 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index eb220a31f80..2bbdbb383a1 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1279,6 +1279,7 @@ cc_library( "//tensorflow/stream_executor/lib", "@com_google_absl//absl/strings", "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", ], ) @@ -1293,6 +1294,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core/protobuf/tpu:topology_proto_cc", "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index f5e9da915c8..986736a9502 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -64,19 +64,14 @@ static llvm::cl::opt tpu_compile_metadata_debug( "'tf._TPUCompileMlir' op as a proto debug string")); constexpr char kNumReplicasAttr[] = "num_replicas"; -constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica"; constexpr char kStepMarkerLocationAttr[] = "step_marker_location"; constexpr char kPaddingMapAttr[] = "padding_map"; -constexpr char kTopologyAttr[] = "topology"; -constexpr char kDeviceAssignmentAttr[] = "device_assignment"; constexpr char kDeviceAttr[] = "device"; constexpr char kDevicesAttr[] = "devices"; constexpr char kVersionsAttr[] = "tf.versions"; constexpr char kBadStringArrayElementMsg[] = "bad '{0}' attribute at index {1}, not a string"; -constexpr char kBadIntArrayElementMsg[] = - "bad '{0}' attribute at index {1}, not an int"; constexpr char kBadArrayElementMsg[] = "bad '{0}' attribute at index {1} with value '{2}': failed to parse to {3}"; constexpr char kBadArrayAttrLengthMsg[] = @@ -163,32 +158,6 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func, return success(); } -// Extracts device coordinates from a device assignment attribute on an op. -LogicalResult GetDeviceCoordinates( - tf_device::ClusterFuncOp op, - llvm::SmallVectorImpl* device_assignment) { - auto device_assignment_attr = - op.getAttrOfType(kDeviceAssignmentAttr); - if (!device_assignment_attr) - return op.emitOpError(CreateMissingAttributeMsg(kDeviceAssignmentAttr)); - - device_assignment->reserve(device_assignment_attr.size()); - - for (auto device_coordinate_and_idx : - llvm::enumerate(device_assignment_attr)) { - auto device_coordinate = - device_coordinate_and_idx.value().dyn_cast(); - if (!device_coordinate) - return op.emitOpError(llvm::formatv(kBadIntArrayElementMsg, - kDeviceAssignmentAttr, - device_coordinate_and_idx.index())); - - device_assignment->push_back(device_coordinate.getInt()); - } - - return success(); -} - // Populates a TPUCompileMetadataProto with StepMarkerLocation from a // `tf_device::ClusterFuncOp`. LogicalResult SetMetadataProtoStepMarkerLocation( @@ -661,27 +630,41 @@ LogicalResult Rewrite( : nullptr; if (replicate) num_replicas = replicate.n().getLimitedValue(); - auto num_cores_per_replica_attr = - cluster_func.getAttrOfType(kNumCoresPerReplicaAttr); + auto num_cores_per_replica_attr = cluster_func.getAttrOfType( + tensorflow::kNumCoresPerReplicaAttr); if (!num_cores_per_replica_attr) return cluster_func.emitOpError( - CreateMissingAttributeMsg(kNumCoresPerReplicaAttr)); + CreateMissingAttributeMsg(tensorflow::kNumCoresPerReplicaAttr)); int num_cores_per_replica = num_cores_per_replica_attr.getInt(); - auto topology_attr = cluster_func.getAttrOfType(kTopologyAttr); + auto topology_attr = + cluster_func.getAttrOfType(tensorflow::kTopologyAttr); if (!topology_attr) - return cluster_func.emitOpError(CreateMissingAttributeMsg(kTopologyAttr)); + return cluster_func.emitOpError( + CreateMissingAttributeMsg(tensorflow::kTopologyAttr)); - llvm::SmallVector device_assignment; - if (failed(GetDeviceCoordinates(cluster_func, &device_assignment))) - return failure(); + auto device_assignment_attr = cluster_func.getAttrOfType( + tensorflow::kDeviceAssignmentAttr); + if (!device_assignment_attr) + return cluster_func.emitOpError( + llvm::formatv("requires attribute '{0}'", + tensorflow::kDeviceAssignmentAttr) + .str()); + + auto status_or_device_coodinates = + tensorflow::GetDeviceCoordinates(device_assignment_attr); + if (!status_or_device_coodinates.ok()) + return cluster_func.emitError() + << "error in fetching tpu device coordinates: " + << status_or_device_coodinates.status().error_message(); // Determine compilation and execution devices. auto status_or_tpu_device_assignment = tensorflow::GetTPUCompilationAndExecutionDevices( devices, num_replicas, num_cores_per_replica, - topology_attr.getValue(), device_assignment); + topology_attr.getValue(), + status_or_device_coodinates.ConsumeValueOrDie()); if (!status_or_tpu_device_assignment.ok()) return cluster_func.emitError() << "error in fetching TPU compilation/execution devices: " diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc index 06c10c26835..282b7ad3139 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc @@ -26,9 +26,9 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Attributes.h" // from @llvm-project #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -39,6 +39,12 @@ limitations under the License. #include "tensorflow/stream_executor/lib/statusor.h" namespace tensorflow { + +const char* const kTPUReplicatedHost = "TPU_REPLICATED_HOST"; +const char* const kNumCoresPerReplicaAttr = "num_cores_per_replica"; +const char* const kTopologyAttr = "topology"; +const char* const kDeviceAssignmentAttr = "device_assignment"; + // Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4 // topology. constexpr int kTPUTopologyRank = 4; @@ -46,8 +52,8 @@ constexpr int kTPUTopologyRank = 4; constexpr char kDeviceTPUSystem[] = "TPU_SYSTEM"; constexpr char kDeviceTPU[] = "TPU"; constexpr char kTPUReplicatedCore[] = "TPU_REPLICATED_CORE"; -constexpr char kTopologyAttr[] = "topology"; -constexpr char kDeviceAssignmentAttr[] = "device_assignment"; +constexpr char kBadIntArrayElementMsg[] = + "bad '{0}' attribute at index {1}, not an int"; using Device = DeviceNameUtils::ParsedName; using Devices = llvm::ArrayRef; @@ -417,6 +423,27 @@ GetGeneralTPUExecutionDeviceAssignment( } // anonymous namespace +StatusOr> GetDeviceCoordinates( + mlir::ArrayAttr device_assignment_attr) { + llvm::SmallVector device_coordinates; + device_coordinates.reserve(device_assignment_attr.size()); + + for (auto device_coordinate_and_idx : + llvm::enumerate(device_assignment_attr)) { + auto device_coordinate = + device_coordinate_and_idx.value().dyn_cast(); + if (!device_coordinate) + return errors::InvalidArgument( + llvm::formatv(kBadIntArrayElementMsg, kDeviceAssignmentAttr, + device_coordinate_and_idx.index()) + .str()); + + device_coordinates.push_back(device_coordinate.getInt()); + } + + return device_coordinates; +} + StatusOr GetTPUCompilationAndExecutionDevices( Devices devices, int num_replicas, int num_cores_per_replica, llvm::StringRef topology_attr, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h index 5fdb6b8768b..6bb541ab683 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h @@ -22,6 +22,7 @@ limitations under the License. #include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" // from @llvm-project #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/util/device_name_utils.h" @@ -30,6 +31,11 @@ limitations under the License. namespace tensorflow { using stream_executor::port::StatusOr; +extern const char* const kTPUReplicatedHost; +extern const char* const kNumCoresPerReplicaAttr; +extern const char* const kTopologyAttr; +extern const char* const kDeviceAssignmentAttr; + // A TPU device for execution alongside its associated host CPU device. struct TPUDeviceAndHost { TPUDeviceAndHost() {} @@ -67,6 +73,10 @@ struct TPUDeviceAssignment { llvm::Optional xla_device_assignment; }; +// Extracts device coordinates from a device assignment attribute on an op. +StatusOr> GetDeviceCoordinates( + mlir::ArrayAttr device_assignment_attr); + // Finds the TPU compilation device and execution devices from `devices` for a // TPU computation subgraph. Compilation device is determined from looking up // all TPU_SYSTEM:0 devices and choosing the CPU device associated to the first diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc index 7ac5635a6e4..a70e93a0195 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/tpu/topology.pb.h" @@ -596,5 +598,29 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) { EXPECT_EQ(computation_device_2.replica_device_ids(1), 3); } +TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) { + mlir::MLIRContext context; + mlir::Builder builder(&context); + auto device_assignment_attr = builder.getI64ArrayAttr({1, 2, 3}); + auto status_or_device_coodinates = + GetDeviceCoordinates(device_assignment_attr); + ASSERT_TRUE(status_or_device_coodinates.ok()); + auto device_coordinates = status_or_device_coodinates.ConsumeValueOrDie(); + EXPECT_EQ(device_coordinates[0], 1); + EXPECT_EQ(device_coordinates[1], 2); + EXPECT_EQ(device_coordinates[2], 3); +} + +TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) { + mlir::MLIRContext context; + mlir::Builder builder(&context); + auto device_assignment_attr = builder.getF32ArrayAttr({1.0, 2.0, 3.0}); + auto status_or_device_coodinates = + GetDeviceCoordinates(device_assignment_attr); + ASSERT_TRUE(!status_or_device_coodinates.ok()); + EXPECT_EQ(status_or_device_coodinates.status().error_message(), + "bad 'device_assignment' attribute at index 0, not an int"); +} + } // anonymous namespace } // namespace tensorflow