Move GetDeviceCoordinates() function and related constants in tpu_rewrite pass to common utility file.
PiperOrigin-RevId: 311795001 Change-Id: If86babf6656da132fb58b1a2266034f3b341e06d
This commit is contained in:
parent
27ac446be5
commit
340ac1aedb
|
@ -1279,6 +1279,7 @@ cc_library(
|
|||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -1293,6 +1294,7 @@ tf_cc_test(
|
|||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -64,19 +64,14 @@ static llvm::cl::opt<bool> tpu_compile_metadata_debug(
|
|||
"'tf._TPUCompileMlir' op as a proto debug string"));
|
||||
|
||||
constexpr char kNumReplicasAttr[] = "num_replicas";
|
||||
constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica";
|
||||
constexpr char kStepMarkerLocationAttr[] = "step_marker_location";
|
||||
constexpr char kPaddingMapAttr[] = "padding_map";
|
||||
constexpr char kTopologyAttr[] = "topology";
|
||||
constexpr char kDeviceAssignmentAttr[] = "device_assignment";
|
||||
constexpr char kDeviceAttr[] = "device";
|
||||
constexpr char kDevicesAttr[] = "devices";
|
||||
constexpr char kVersionsAttr[] = "tf.versions";
|
||||
|
||||
constexpr char kBadStringArrayElementMsg[] =
|
||||
"bad '{0}' attribute at index {1}, not a string";
|
||||
constexpr char kBadIntArrayElementMsg[] =
|
||||
"bad '{0}' attribute at index {1}, not an int";
|
||||
constexpr char kBadArrayElementMsg[] =
|
||||
"bad '{0}' attribute at index {1} with value '{2}': failed to parse to {3}";
|
||||
constexpr char kBadArrayAttrLengthMsg[] =
|
||||
|
@ -163,32 +158,6 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
|
|||
return success();
|
||||
}
|
||||
|
||||
// Extracts device coordinates from a device assignment attribute on an op.
// Reads the `kDeviceAssignmentAttr` array attribute from `op` and appends the
// integer value of each element, in order, to `device_assignment`. If the
// attribute is absent, or an element is not an IntegerAttr, the corresponding
// op error is emitted and returned; otherwise success() is returned.
|
||||
LogicalResult GetDeviceCoordinates(
|
||||
tf_device::ClusterFuncOp op,
|
||||
llvm::SmallVectorImpl<int64_t>* device_assignment) {
|
||||
auto device_assignment_attr =
|
||||
op.getAttrOfType<ArrayAttr>(kDeviceAssignmentAttr);
|
||||
// The attribute is required: report its absence as an op error.
if (!device_assignment_attr)
|
||||
return op.emitOpError(CreateMissingAttributeMsg(kDeviceAssignmentAttr));
|
||||
|
||||
device_assignment->reserve(device_assignment_attr.size());
|
||||
|
||||
for (auto device_coordinate_and_idx :
|
||||
llvm::enumerate(device_assignment_attr)) {
|
||||
auto device_coordinate =
|
||||
device_coordinate_and_idx.value().dyn_cast<IntegerAttr>();
|
||||
// A failed cast means this element is not an integer attribute; the error
// message carries the attribute name and the element's index.
if (!device_coordinate)
|
||||
return op.emitOpError(llvm::formatv(kBadIntArrayElementMsg,
|
||||
kDeviceAssignmentAttr,
|
||||
device_coordinate_and_idx.index()));
|
||||
|
||||
device_assignment->push_back(device_coordinate.getInt());
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// Populates a TPUCompileMetadataProto with StepMarkerLocation from a
|
||||
// `tf_device::ClusterFuncOp`.
|
||||
LogicalResult SetMetadataProtoStepMarkerLocation(
|
||||
|
@ -661,27 +630,41 @@ LogicalResult Rewrite(
|
|||
: nullptr;
|
||||
if (replicate) num_replicas = replicate.n().getLimitedValue();
|
||||
|
||||
auto num_cores_per_replica_attr =
|
||||
cluster_func.getAttrOfType<IntegerAttr>(kNumCoresPerReplicaAttr);
|
||||
auto num_cores_per_replica_attr = cluster_func.getAttrOfType<IntegerAttr>(
|
||||
tensorflow::kNumCoresPerReplicaAttr);
|
||||
if (!num_cores_per_replica_attr)
|
||||
return cluster_func.emitOpError(
|
||||
CreateMissingAttributeMsg(kNumCoresPerReplicaAttr));
|
||||
CreateMissingAttributeMsg(tensorflow::kNumCoresPerReplicaAttr));
|
||||
|
||||
int num_cores_per_replica = num_cores_per_replica_attr.getInt();
|
||||
|
||||
auto topology_attr = cluster_func.getAttrOfType<StringAttr>(kTopologyAttr);
|
||||
auto topology_attr =
|
||||
cluster_func.getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
|
||||
if (!topology_attr)
|
||||
return cluster_func.emitOpError(CreateMissingAttributeMsg(kTopologyAttr));
|
||||
return cluster_func.emitOpError(
|
||||
CreateMissingAttributeMsg(tensorflow::kTopologyAttr));
|
||||
|
||||
llvm::SmallVector<int64_t, 6> device_assignment;
|
||||
if (failed(GetDeviceCoordinates(cluster_func, &device_assignment)))
|
||||
return failure();
|
||||
auto device_assignment_attr = cluster_func.getAttrOfType<mlir::ArrayAttr>(
|
||||
tensorflow::kDeviceAssignmentAttr);
|
||||
if (!device_assignment_attr)
|
||||
return cluster_func.emitOpError(
|
||||
llvm::formatv("requires attribute '{0}'",
|
||||
tensorflow::kDeviceAssignmentAttr)
|
||||
.str());
|
||||
|
||||
auto status_or_device_coodinates =
|
||||
tensorflow::GetDeviceCoordinates(device_assignment_attr);
|
||||
if (!status_or_device_coodinates.ok())
|
||||
return cluster_func.emitError()
|
||||
<< "error in fetching tpu device coordinates: "
|
||||
<< status_or_device_coodinates.status().error_message();
|
||||
|
||||
// Determine compilation and execution devices.
|
||||
auto status_or_tpu_device_assignment =
|
||||
tensorflow::GetTPUCompilationAndExecutionDevices(
|
||||
devices, num_replicas, num_cores_per_replica,
|
||||
topology_attr.getValue(), device_assignment);
|
||||
topology_attr.getValue(),
|
||||
status_or_device_coodinates.ConsumeValueOrDie());
|
||||
if (!status_or_tpu_device_assignment.ok())
|
||||
return cluster_func.emitError()
|
||||
<< "error in fetching TPU compilation/execution devices: "
|
||||
|
|
|
@ -26,9 +26,9 @@ limitations under the License.
|
|||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/xla/array4d.h"
|
||||
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
@ -39,6 +39,12 @@ limitations under the License.
|
|||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
const char* const kTPUReplicatedHost = "TPU_REPLICATED_HOST";
|
||||
const char* const kNumCoresPerReplicaAttr = "num_cores_per_replica";
|
||||
const char* const kTopologyAttr = "topology";
|
||||
const char* const kDeviceAssignmentAttr = "device_assignment";
|
||||
|
||||
// Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4
|
||||
// topology.
|
||||
constexpr int kTPUTopologyRank = 4;
|
||||
|
@ -46,8 +52,8 @@ constexpr int kTPUTopologyRank = 4;
|
|||
constexpr char kDeviceTPUSystem[] = "TPU_SYSTEM";
|
||||
constexpr char kDeviceTPU[] = "TPU";
|
||||
constexpr char kTPUReplicatedCore[] = "TPU_REPLICATED_CORE";
|
||||
constexpr char kTopologyAttr[] = "topology";
|
||||
constexpr char kDeviceAssignmentAttr[] = "device_assignment";
|
||||
constexpr char kBadIntArrayElementMsg[] =
|
||||
"bad '{0}' attribute at index {1}, not an int";
|
||||
|
||||
using Device = DeviceNameUtils::ParsedName;
|
||||
using Devices = llvm::ArrayRef<DeviceNameUtils::ParsedName>;
|
||||
|
@ -417,6 +423,27 @@ GetGeneralTPUExecutionDeviceAssignment(
|
|||
|
||||
} // anonymous namespace
|
||||
|
||||
// Converts `device_assignment_attr` into a vector of device coordinates by
// taking the integer value of each element in order. Returns an
// InvalidArgument error, formatted from kBadIntArrayElementMsg with the
// attribute name and element index, for the first element that is not an
// mlir::IntegerAttr.
StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
|
||||
mlir::ArrayAttr device_assignment_attr) {
|
||||
llvm::SmallVector<int64_t, 8> device_coordinates;
|
||||
device_coordinates.reserve(device_assignment_attr.size());
|
||||
|
||||
for (auto device_coordinate_and_idx :
|
||||
llvm::enumerate(device_assignment_attr)) {
|
||||
auto device_coordinate =
|
||||
device_coordinate_and_idx.value().dyn_cast<mlir::IntegerAttr>();
|
||||
// A failed cast means this element is not an integer attribute; stop at the
// first such element and report its index.
if (!device_coordinate)
|
||||
return errors::InvalidArgument(
|
||||
llvm::formatv(kBadIntArrayElementMsg, kDeviceAssignmentAttr,
|
||||
device_coordinate_and_idx.index())
|
||||
.str());
|
||||
|
||||
device_coordinates.push_back(device_coordinate.getInt());
|
||||
}
|
||||
|
||||
return device_coordinates;
|
||||
}
|
||||
|
||||
StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices(
|
||||
Devices devices, int num_replicas, int num_cores_per_replica,
|
||||
llvm::StringRef topology_attr,
|
||||
|
|
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
@ -30,6 +31,11 @@ limitations under the License.
|
|||
namespace tensorflow {
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
extern const char* const kTPUReplicatedHost;
|
||||
extern const char* const kNumCoresPerReplicaAttr;
|
||||
extern const char* const kTopologyAttr;
|
||||
extern const char* const kDeviceAssignmentAttr;
|
||||
|
||||
// A TPU device for execution alongside its associated host CPU device.
|
||||
struct TPUDeviceAndHost {
|
||||
TPUDeviceAndHost() {}
|
||||
|
@ -67,6 +73,10 @@ struct TPUDeviceAssignment {
|
|||
llvm::Optional<xla::DeviceAssignmentProto> xla_device_assignment;
|
||||
};
|
||||
|
||||
// Extracts device coordinates from a device assignment array attribute; returns an error if any element is not an integer attribute.
|
||||
StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
|
||||
mlir::ArrayAttr device_assignment_attr);
|
||||
|
||||
// Finds the TPU compilation device and execution devices from `devices` for a
|
||||
// TPU computation subgraph. Compilation device is determined from looking up
|
||||
// all TPU_SYSTEM:0 devices and choosing the CPU device associated to the first
|
||||
|
|
|
@ -19,6 +19,8 @@ limitations under the License.
|
|||
#include <tuple>
|
||||
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/tpu/topology.pb.h"
|
||||
|
@ -596,5 +598,29 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) {
|
|||
EXPECT_EQ(computation_device_2.replica_device_ids(1), 3);
|
||||
}
|
||||
|
||||
// Checks that GetDeviceCoordinates succeeds on an i64 array attribute holding
// {1, 2, 3} and returns those values in the same order.
TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::Builder builder(&context);
|
||||
auto device_assignment_attr = builder.getI64ArrayAttr({1, 2, 3});
|
||||
auto status_or_device_coodinates =
|
||||
GetDeviceCoordinates(device_assignment_attr);
|
||||
// ASSERT (not EXPECT) so the value is only consumed when the call succeeded.
ASSERT_TRUE(status_or_device_coodinates.ok());
|
||||
auto device_coordinates = status_or_device_coodinates.ConsumeValueOrDie();
|
||||
EXPECT_EQ(device_coordinates[0], 1);
|
||||
EXPECT_EQ(device_coordinates[1], 2);
|
||||
EXPECT_EQ(device_coordinates[2], 3);
|
||||
}
|
||||
|
||||
// Checks that GetDeviceCoordinates rejects an array attribute whose elements
// are f32 rather than integer attributes, and that the error message names the
// 'device_assignment' attribute and index 0.
TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::Builder builder(&context);
|
||||
auto device_assignment_attr = builder.getF32ArrayAttr({1.0, 2.0, 3.0});
|
||||
auto status_or_device_coodinates =
|
||||
GetDeviceCoordinates(device_assignment_attr);
|
||||
ASSERT_TRUE(!status_or_device_coodinates.ok());
|
||||
EXPECT_EQ(status_or_device_coodinates.status().error_message(),
|
||||
"bad 'device_assignment' attribute at index 0, not an int");
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace tensorflow
|
||||
|
|
Loading…
Reference in New Issue