Move GetDeviceCoordinates() function and related constants in tpu_rewrite pass to common utility file.
PiperOrigin-RevId: 311795001 Change-Id: If86babf6656da132fb58b1a2266034f3b341e06d
This commit is contained in:
parent
27ac446be5
commit
340ac1aedb
|
@ -1279,6 +1279,7 @@ cc_library(
|
||||||
"//tensorflow/stream_executor/lib",
|
"//tensorflow/stream_executor/lib",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@llvm-project//llvm:support",
|
"@llvm-project//llvm:support",
|
||||||
|
"@llvm-project//mlir:IR",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1293,6 +1294,7 @@ tf_cc_test(
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
|
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
|
||||||
"@llvm-project//llvm:support",
|
"@llvm-project//llvm:support",
|
||||||
|
"@llvm-project//mlir:IR",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -64,19 +64,14 @@ static llvm::cl::opt<bool> tpu_compile_metadata_debug(
|
||||||
"'tf._TPUCompileMlir' op as a proto debug string"));
|
"'tf._TPUCompileMlir' op as a proto debug string"));
|
||||||
|
|
||||||
constexpr char kNumReplicasAttr[] = "num_replicas";
|
constexpr char kNumReplicasAttr[] = "num_replicas";
|
||||||
constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica";
|
|
||||||
constexpr char kStepMarkerLocationAttr[] = "step_marker_location";
|
constexpr char kStepMarkerLocationAttr[] = "step_marker_location";
|
||||||
constexpr char kPaddingMapAttr[] = "padding_map";
|
constexpr char kPaddingMapAttr[] = "padding_map";
|
||||||
constexpr char kTopologyAttr[] = "topology";
|
|
||||||
constexpr char kDeviceAssignmentAttr[] = "device_assignment";
|
|
||||||
constexpr char kDeviceAttr[] = "device";
|
constexpr char kDeviceAttr[] = "device";
|
||||||
constexpr char kDevicesAttr[] = "devices";
|
constexpr char kDevicesAttr[] = "devices";
|
||||||
constexpr char kVersionsAttr[] = "tf.versions";
|
constexpr char kVersionsAttr[] = "tf.versions";
|
||||||
|
|
||||||
constexpr char kBadStringArrayElementMsg[] =
|
constexpr char kBadStringArrayElementMsg[] =
|
||||||
"bad '{0}' attribute at index {1}, not a string";
|
"bad '{0}' attribute at index {1}, not a string";
|
||||||
constexpr char kBadIntArrayElementMsg[] =
|
|
||||||
"bad '{0}' attribute at index {1}, not an int";
|
|
||||||
constexpr char kBadArrayElementMsg[] =
|
constexpr char kBadArrayElementMsg[] =
|
||||||
"bad '{0}' attribute at index {1} with value '{2}': failed to parse to {3}";
|
"bad '{0}' attribute at index {1} with value '{2}': failed to parse to {3}";
|
||||||
constexpr char kBadArrayAttrLengthMsg[] =
|
constexpr char kBadArrayAttrLengthMsg[] =
|
||||||
|
@ -163,32 +158,6 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extracts device coordinates from a device assignment attribute on an op.
|
|
||||||
LogicalResult GetDeviceCoordinates(
|
|
||||||
tf_device::ClusterFuncOp op,
|
|
||||||
llvm::SmallVectorImpl<int64_t>* device_assignment) {
|
|
||||||
auto device_assignment_attr =
|
|
||||||
op.getAttrOfType<ArrayAttr>(kDeviceAssignmentAttr);
|
|
||||||
if (!device_assignment_attr)
|
|
||||||
return op.emitOpError(CreateMissingAttributeMsg(kDeviceAssignmentAttr));
|
|
||||||
|
|
||||||
device_assignment->reserve(device_assignment_attr.size());
|
|
||||||
|
|
||||||
for (auto device_coordinate_and_idx :
|
|
||||||
llvm::enumerate(device_assignment_attr)) {
|
|
||||||
auto device_coordinate =
|
|
||||||
device_coordinate_and_idx.value().dyn_cast<IntegerAttr>();
|
|
||||||
if (!device_coordinate)
|
|
||||||
return op.emitOpError(llvm::formatv(kBadIntArrayElementMsg,
|
|
||||||
kDeviceAssignmentAttr,
|
|
||||||
device_coordinate_and_idx.index()));
|
|
||||||
|
|
||||||
device_assignment->push_back(device_coordinate.getInt());
|
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Populates a TPUCompileMetadataProto with StepMarkerLocation from a
|
// Populates a TPUCompileMetadataProto with StepMarkerLocation from a
|
||||||
// `tf_device::ClusterFuncOp`.
|
// `tf_device::ClusterFuncOp`.
|
||||||
LogicalResult SetMetadataProtoStepMarkerLocation(
|
LogicalResult SetMetadataProtoStepMarkerLocation(
|
||||||
|
@ -661,27 +630,41 @@ LogicalResult Rewrite(
|
||||||
: nullptr;
|
: nullptr;
|
||||||
if (replicate) num_replicas = replicate.n().getLimitedValue();
|
if (replicate) num_replicas = replicate.n().getLimitedValue();
|
||||||
|
|
||||||
auto num_cores_per_replica_attr =
|
auto num_cores_per_replica_attr = cluster_func.getAttrOfType<IntegerAttr>(
|
||||||
cluster_func.getAttrOfType<IntegerAttr>(kNumCoresPerReplicaAttr);
|
tensorflow::kNumCoresPerReplicaAttr);
|
||||||
if (!num_cores_per_replica_attr)
|
if (!num_cores_per_replica_attr)
|
||||||
return cluster_func.emitOpError(
|
return cluster_func.emitOpError(
|
||||||
CreateMissingAttributeMsg(kNumCoresPerReplicaAttr));
|
CreateMissingAttributeMsg(tensorflow::kNumCoresPerReplicaAttr));
|
||||||
|
|
||||||
int num_cores_per_replica = num_cores_per_replica_attr.getInt();
|
int num_cores_per_replica = num_cores_per_replica_attr.getInt();
|
||||||
|
|
||||||
auto topology_attr = cluster_func.getAttrOfType<StringAttr>(kTopologyAttr);
|
auto topology_attr =
|
||||||
|
cluster_func.getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
|
||||||
if (!topology_attr)
|
if (!topology_attr)
|
||||||
return cluster_func.emitOpError(CreateMissingAttributeMsg(kTopologyAttr));
|
return cluster_func.emitOpError(
|
||||||
|
CreateMissingAttributeMsg(tensorflow::kTopologyAttr));
|
||||||
|
|
||||||
llvm::SmallVector<int64_t, 6> device_assignment;
|
auto device_assignment_attr = cluster_func.getAttrOfType<mlir::ArrayAttr>(
|
||||||
if (failed(GetDeviceCoordinates(cluster_func, &device_assignment)))
|
tensorflow::kDeviceAssignmentAttr);
|
||||||
return failure();
|
if (!device_assignment_attr)
|
||||||
|
return cluster_func.emitOpError(
|
||||||
|
llvm::formatv("requires attribute '{0}'",
|
||||||
|
tensorflow::kDeviceAssignmentAttr)
|
||||||
|
.str());
|
||||||
|
|
||||||
|
auto status_or_device_coodinates =
|
||||||
|
tensorflow::GetDeviceCoordinates(device_assignment_attr);
|
||||||
|
if (!status_or_device_coodinates.ok())
|
||||||
|
return cluster_func.emitError()
|
||||||
|
<< "error in fetching tpu device coordinates: "
|
||||||
|
<< status_or_device_coodinates.status().error_message();
|
||||||
|
|
||||||
// Determine compilation and execution devices.
|
// Determine compilation and execution devices.
|
||||||
auto status_or_tpu_device_assignment =
|
auto status_or_tpu_device_assignment =
|
||||||
tensorflow::GetTPUCompilationAndExecutionDevices(
|
tensorflow::GetTPUCompilationAndExecutionDevices(
|
||||||
devices, num_replicas, num_cores_per_replica,
|
devices, num_replicas, num_cores_per_replica,
|
||||||
topology_attr.getValue(), device_assignment);
|
topology_attr.getValue(),
|
||||||
|
status_or_device_coodinates.ConsumeValueOrDie());
|
||||||
if (!status_or_tpu_device_assignment.ok())
|
if (!status_or_tpu_device_assignment.ok())
|
||||||
return cluster_func.emitError()
|
return cluster_func.emitError()
|
||||||
<< "error in fetching TPU compilation/execution devices: "
|
<< "error in fetching TPU compilation/execution devices: "
|
||||||
|
|
|
@ -26,9 +26,9 @@ limitations under the License.
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
|
||||||
#include "llvm/ADT/iterator_range.h"
|
#include "llvm/ADT/iterator_range.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/xla/array4d.h"
|
#include "tensorflow/compiler/xla/array4d.h"
|
||||||
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
@ -39,6 +39,12 @@ limitations under the License.
|
||||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
const char* const kTPUReplicatedHost = "TPU_REPLICATED_HOST";
|
||||||
|
const char* const kNumCoresPerReplicaAttr = "num_cores_per_replica";
|
||||||
|
const char* const kTopologyAttr = "topology";
|
||||||
|
const char* const kDeviceAssignmentAttr = "device_assignment";
|
||||||
|
|
||||||
// Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4
|
// Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4
|
||||||
// topology.
|
// topology.
|
||||||
constexpr int kTPUTopologyRank = 4;
|
constexpr int kTPUTopologyRank = 4;
|
||||||
|
@ -46,8 +52,8 @@ constexpr int kTPUTopologyRank = 4;
|
||||||
constexpr char kDeviceTPUSystem[] = "TPU_SYSTEM";
|
constexpr char kDeviceTPUSystem[] = "TPU_SYSTEM";
|
||||||
constexpr char kDeviceTPU[] = "TPU";
|
constexpr char kDeviceTPU[] = "TPU";
|
||||||
constexpr char kTPUReplicatedCore[] = "TPU_REPLICATED_CORE";
|
constexpr char kTPUReplicatedCore[] = "TPU_REPLICATED_CORE";
|
||||||
constexpr char kTopologyAttr[] = "topology";
|
constexpr char kBadIntArrayElementMsg[] =
|
||||||
constexpr char kDeviceAssignmentAttr[] = "device_assignment";
|
"bad '{0}' attribute at index {1}, not an int";
|
||||||
|
|
||||||
using Device = DeviceNameUtils::ParsedName;
|
using Device = DeviceNameUtils::ParsedName;
|
||||||
using Devices = llvm::ArrayRef<DeviceNameUtils::ParsedName>;
|
using Devices = llvm::ArrayRef<DeviceNameUtils::ParsedName>;
|
||||||
|
@ -417,6 +423,27 @@ GetGeneralTPUExecutionDeviceAssignment(
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
|
StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
|
||||||
|
mlir::ArrayAttr device_assignment_attr) {
|
||||||
|
llvm::SmallVector<int64_t, 8> device_coordinates;
|
||||||
|
device_coordinates.reserve(device_assignment_attr.size());
|
||||||
|
|
||||||
|
for (auto device_coordinate_and_idx :
|
||||||
|
llvm::enumerate(device_assignment_attr)) {
|
||||||
|
auto device_coordinate =
|
||||||
|
device_coordinate_and_idx.value().dyn_cast<mlir::IntegerAttr>();
|
||||||
|
if (!device_coordinate)
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
llvm::formatv(kBadIntArrayElementMsg, kDeviceAssignmentAttr,
|
||||||
|
device_coordinate_and_idx.index())
|
||||||
|
.str());
|
||||||
|
|
||||||
|
device_coordinates.push_back(device_coordinate.getInt());
|
||||||
|
}
|
||||||
|
|
||||||
|
return device_coordinates;
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices(
|
StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices(
|
||||||
Devices devices, int num_replicas, int num_cores_per_replica,
|
Devices devices, int num_replicas, int num_cores_per_replica,
|
||||||
llvm::StringRef topology_attr,
|
llvm::StringRef topology_attr,
|
||||||
|
|
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||||
#include "llvm/ADT/Optional.h"
|
#include "llvm/ADT/Optional.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/util/device_name_utils.h"
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
|
@ -30,6 +31,11 @@ limitations under the License.
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
using stream_executor::port::StatusOr;
|
using stream_executor::port::StatusOr;
|
||||||
|
|
||||||
|
extern const char* const kTPUReplicatedHost;
|
||||||
|
extern const char* const kNumCoresPerReplicaAttr;
|
||||||
|
extern const char* const kTopologyAttr;
|
||||||
|
extern const char* const kDeviceAssignmentAttr;
|
||||||
|
|
||||||
// A TPU device for execution alongside its associated host CPU device.
|
// A TPU device for execution alongside its associated host CPU device.
|
||||||
struct TPUDeviceAndHost {
|
struct TPUDeviceAndHost {
|
||||||
TPUDeviceAndHost() {}
|
TPUDeviceAndHost() {}
|
||||||
|
@ -67,6 +73,10 @@ struct TPUDeviceAssignment {
|
||||||
llvm::Optional<xla::DeviceAssignmentProto> xla_device_assignment;
|
llvm::Optional<xla::DeviceAssignmentProto> xla_device_assignment;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Extracts device coordinates from a device assignment attribute on an op.
|
||||||
|
StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
|
||||||
|
mlir::ArrayAttr device_assignment_attr);
|
||||||
|
|
||||||
// Finds the TPU compilation device and execution devices from `devices` for a
|
// Finds the TPU compilation device and execution devices from `devices` for a
|
||||||
// TPU computation subgraph. Compilation device is determined from looking up
|
// TPU computation subgraph. Compilation device is determined from looking up
|
||||||
// all TPU_SYSTEM:0 devices and choosing the CPU device associated to the first
|
// all TPU_SYSTEM:0 devices and choosing the CPU device associated to the first
|
||||||
|
|
|
@ -19,6 +19,8 @@ limitations under the License.
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/protobuf/tpu/topology.pb.h"
|
#include "tensorflow/core/protobuf/tpu/topology.pb.h"
|
||||||
|
@ -596,5 +598,29 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) {
|
||||||
EXPECT_EQ(computation_device_2.replica_device_ids(1), 3);
|
EXPECT_EQ(computation_device_2.replica_device_ids(1), 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) {
|
||||||
|
mlir::MLIRContext context;
|
||||||
|
mlir::Builder builder(&context);
|
||||||
|
auto device_assignment_attr = builder.getI64ArrayAttr({1, 2, 3});
|
||||||
|
auto status_or_device_coodinates =
|
||||||
|
GetDeviceCoordinates(device_assignment_attr);
|
||||||
|
ASSERT_TRUE(status_or_device_coodinates.ok());
|
||||||
|
auto device_coordinates = status_or_device_coodinates.ConsumeValueOrDie();
|
||||||
|
EXPECT_EQ(device_coordinates[0], 1);
|
||||||
|
EXPECT_EQ(device_coordinates[1], 2);
|
||||||
|
EXPECT_EQ(device_coordinates[2], 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) {
|
||||||
|
mlir::MLIRContext context;
|
||||||
|
mlir::Builder builder(&context);
|
||||||
|
auto device_assignment_attr = builder.getF32ArrayAttr({1.0, 2.0, 3.0});
|
||||||
|
auto status_or_device_coodinates =
|
||||||
|
GetDeviceCoordinates(device_assignment_attr);
|
||||||
|
ASSERT_TRUE(!status_or_device_coodinates.ok());
|
||||||
|
EXPECT_EQ(status_or_device_coodinates.status().error_message(),
|
||||||
|
"bad 'device_assignment' attribute at index 0, not an int");
|
||||||
|
}
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
Loading…
Reference in New Issue