Move the GetDeviceCoordinates() function and its related constants from the tpu_rewrite pass to a common utility file.

PiperOrigin-RevId: 311795001
Change-Id: If86babf6656da132fb58b1a2266034f3b341e06d
This commit is contained in:
A. Unique TensorFlower 2020-05-15 13:49:39 -07:00 committed by TensorFlower Gardener
parent 27ac446be5
commit 340ac1aedb
5 changed files with 91 additions and 43 deletions

View File

@ -1279,6 +1279,7 @@ cc_library(
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
], ],
) )
@ -1293,6 +1294,7 @@ tf_cc_test(
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core/protobuf/tpu:topology_proto_cc", "//tensorflow/core/protobuf/tpu:topology_proto_cc",
"@llvm-project//llvm:support", "@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
], ],
) )

View File

@ -64,19 +64,14 @@ static llvm::cl::opt<bool> tpu_compile_metadata_debug(
"'tf._TPUCompileMlir' op as a proto debug string")); "'tf._TPUCompileMlir' op as a proto debug string"));
constexpr char kNumReplicasAttr[] = "num_replicas"; constexpr char kNumReplicasAttr[] = "num_replicas";
constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica";
constexpr char kStepMarkerLocationAttr[] = "step_marker_location"; constexpr char kStepMarkerLocationAttr[] = "step_marker_location";
constexpr char kPaddingMapAttr[] = "padding_map"; constexpr char kPaddingMapAttr[] = "padding_map";
constexpr char kTopologyAttr[] = "topology";
constexpr char kDeviceAssignmentAttr[] = "device_assignment";
constexpr char kDeviceAttr[] = "device"; constexpr char kDeviceAttr[] = "device";
constexpr char kDevicesAttr[] = "devices"; constexpr char kDevicesAttr[] = "devices";
constexpr char kVersionsAttr[] = "tf.versions"; constexpr char kVersionsAttr[] = "tf.versions";
constexpr char kBadStringArrayElementMsg[] = constexpr char kBadStringArrayElementMsg[] =
"bad '{0}' attribute at index {1}, not a string"; "bad '{0}' attribute at index {1}, not a string";
constexpr char kBadIntArrayElementMsg[] =
"bad '{0}' attribute at index {1}, not an int";
constexpr char kBadArrayElementMsg[] = constexpr char kBadArrayElementMsg[] =
"bad '{0}' attribute at index {1} with value '{2}': failed to parse to {3}"; "bad '{0}' attribute at index {1} with value '{2}': failed to parse to {3}";
constexpr char kBadArrayAttrLengthMsg[] = constexpr char kBadArrayAttrLengthMsg[] =
@ -163,32 +158,6 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
return success(); return success();
} }
// Reads the `device_assignment` array attribute of `op` and appends every
// integer element to `device_assignment`. An op error is emitted (and failure
// returned) when the attribute is absent or an element is not an IntegerAttr.
LogicalResult GetDeviceCoordinates(
    tf_device::ClusterFuncOp op,
    llvm::SmallVectorImpl<int64_t>* device_assignment) {
  ArrayAttr coordinates_attr =
      op.getAttrOfType<ArrayAttr>(kDeviceAssignmentAttr);
  if (!coordinates_attr) {
    return op.emitOpError(CreateMissingAttributeMsg(kDeviceAssignmentAttr));
  }

  device_assignment->reserve(coordinates_attr.size());

  for (auto indexed_element : llvm::enumerate(coordinates_attr)) {
    IntegerAttr coordinate =
        indexed_element.value().dyn_cast<IntegerAttr>();
    if (!coordinate) {
      // Report the position of the first non-integer element.
      return op.emitOpError(llvm::formatv(kBadIntArrayElementMsg,
                                          kDeviceAssignmentAttr,
                                          indexed_element.index()));
    }

    device_assignment->push_back(coordinate.getInt());
  }

  return success();
}
// Populates a TPUCompileMetadataProto with StepMarkerLocation from a // Populates a TPUCompileMetadataProto with StepMarkerLocation from a
// `tf_device::ClusterFuncOp`. // `tf_device::ClusterFuncOp`.
LogicalResult SetMetadataProtoStepMarkerLocation( LogicalResult SetMetadataProtoStepMarkerLocation(
@ -661,27 +630,41 @@ LogicalResult Rewrite(
: nullptr; : nullptr;
if (replicate) num_replicas = replicate.n().getLimitedValue(); if (replicate) num_replicas = replicate.n().getLimitedValue();
auto num_cores_per_replica_attr = auto num_cores_per_replica_attr = cluster_func.getAttrOfType<IntegerAttr>(
cluster_func.getAttrOfType<IntegerAttr>(kNumCoresPerReplicaAttr); tensorflow::kNumCoresPerReplicaAttr);
if (!num_cores_per_replica_attr) if (!num_cores_per_replica_attr)
return cluster_func.emitOpError( return cluster_func.emitOpError(
CreateMissingAttributeMsg(kNumCoresPerReplicaAttr)); CreateMissingAttributeMsg(tensorflow::kNumCoresPerReplicaAttr));
int num_cores_per_replica = num_cores_per_replica_attr.getInt(); int num_cores_per_replica = num_cores_per_replica_attr.getInt();
auto topology_attr = cluster_func.getAttrOfType<StringAttr>(kTopologyAttr); auto topology_attr =
cluster_func.getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
if (!topology_attr) if (!topology_attr)
return cluster_func.emitOpError(CreateMissingAttributeMsg(kTopologyAttr)); return cluster_func.emitOpError(
CreateMissingAttributeMsg(tensorflow::kTopologyAttr));
llvm::SmallVector<int64_t, 6> device_assignment; auto device_assignment_attr = cluster_func.getAttrOfType<mlir::ArrayAttr>(
if (failed(GetDeviceCoordinates(cluster_func, &device_assignment))) tensorflow::kDeviceAssignmentAttr);
return failure(); if (!device_assignment_attr)
return cluster_func.emitOpError(
llvm::formatv("requires attribute '{0}'",
tensorflow::kDeviceAssignmentAttr)
.str());
auto status_or_device_coodinates =
tensorflow::GetDeviceCoordinates(device_assignment_attr);
if (!status_or_device_coodinates.ok())
return cluster_func.emitError()
<< "error in fetching tpu device coordinates: "
<< status_or_device_coodinates.status().error_message();
// Determine compilation and execution devices. // Determine compilation and execution devices.
auto status_or_tpu_device_assignment = auto status_or_tpu_device_assignment =
tensorflow::GetTPUCompilationAndExecutionDevices( tensorflow::GetTPUCompilationAndExecutionDevices(
devices, num_replicas, num_cores_per_replica, devices, num_replicas, num_cores_per_replica,
topology_attr.getValue(), device_assignment); topology_attr.getValue(),
status_or_device_coodinates.ConsumeValueOrDie());
if (!status_or_tpu_device_assignment.ok()) if (!status_or_tpu_device_assignment.ok())
return cluster_func.emitError() return cluster_func.emitError()
<< "error in fetching TPU compilation/execution devices: " << "error in fetching TPU compilation/execution devices: "

View File

@ -26,9 +26,9 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h" #include "llvm/ADT/iterator_range.h"
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
@ -39,6 +39,12 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow { namespace tensorflow {
const char* const kTPUReplicatedHost = "TPU_REPLICATED_HOST";
const char* const kNumCoresPerReplicaAttr = "num_cores_per_replica";
const char* const kTopologyAttr = "topology";
const char* const kDeviceAssignmentAttr = "device_assignment";
// Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4 // Device coordinates are defined as (x, y, z, core), thus resulting in a rank 4
// topology. // topology.
constexpr int kTPUTopologyRank = 4; constexpr int kTPUTopologyRank = 4;
@ -46,8 +52,8 @@ constexpr int kTPUTopologyRank = 4;
constexpr char kDeviceTPUSystem[] = "TPU_SYSTEM"; constexpr char kDeviceTPUSystem[] = "TPU_SYSTEM";
constexpr char kDeviceTPU[] = "TPU"; constexpr char kDeviceTPU[] = "TPU";
constexpr char kTPUReplicatedCore[] = "TPU_REPLICATED_CORE"; constexpr char kTPUReplicatedCore[] = "TPU_REPLICATED_CORE";
constexpr char kTopologyAttr[] = "topology"; constexpr char kBadIntArrayElementMsg[] =
constexpr char kDeviceAssignmentAttr[] = "device_assignment"; "bad '{0}' attribute at index {1}, not an int";
using Device = DeviceNameUtils::ParsedName; using Device = DeviceNameUtils::ParsedName;
using Devices = llvm::ArrayRef<DeviceNameUtils::ParsedName>; using Devices = llvm::ArrayRef<DeviceNameUtils::ParsedName>;
@ -417,6 +423,27 @@ GetGeneralTPUExecutionDeviceAssignment(
} // anonymous namespace } // anonymous namespace
// Extracts device coordinates from a device assignment array attribute.
//
// Returns the elements of `device_assignment_attr`, in order, as int64 values.
// Returns an InvalidArgument error if the attribute is null or if any element
// is not an integer attribute (the error names the offending index).
StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
    mlir::ArrayAttr device_assignment_attr) {
  // A null attribute (e.g. the result of a failed getAttrOfType lookup) has no
  // backing storage, so it must be rejected before size() is queried.
  if (!device_assignment_attr)
    return errors::InvalidArgument(
        llvm::formatv("requires attribute '{0}'", kDeviceAssignmentAttr)
            .str());

  llvm::SmallVector<int64_t, 8> device_coordinates;
  device_coordinates.reserve(device_assignment_attr.size());

  for (auto device_coordinate_and_idx :
       llvm::enumerate(device_assignment_attr)) {
    auto device_coordinate =
        device_coordinate_and_idx.value().dyn_cast<mlir::IntegerAttr>();
    // dyn_cast yields a null attribute when the element is not an integer.
    if (!device_coordinate)
      return errors::InvalidArgument(
          llvm::formatv(kBadIntArrayElementMsg, kDeviceAssignmentAttr,
                        device_coordinate_and_idx.index())
              .str());

    device_coordinates.push_back(device_coordinate.getInt());
  }

  return device_coordinates;
}
StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices( StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices(
Devices devices, int num_replicas, int num_cores_per_replica, Devices devices, int num_replicas, int num_cores_per_replica,
llvm::StringRef topology_attr, llvm::StringRef topology_attr,

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "llvm/ADT/Optional.h" #include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/device_name_utils.h"
@ -30,6 +31,11 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
using stream_executor::port::StatusOr; using stream_executor::port::StatusOr;
extern const char* const kTPUReplicatedHost;
extern const char* const kNumCoresPerReplicaAttr;
extern const char* const kTopologyAttr;
extern const char* const kDeviceAssignmentAttr;
// A TPU device for execution alongside its associated host CPU device. // A TPU device for execution alongside its associated host CPU device.
struct TPUDeviceAndHost { struct TPUDeviceAndHost {
TPUDeviceAndHost() {} TPUDeviceAndHost() {}
@ -67,6 +73,10 @@ struct TPUDeviceAssignment {
llvm::Optional<xla::DeviceAssignmentProto> xla_device_assignment; llvm::Optional<xla::DeviceAssignmentProto> xla_device_assignment;
}; };
// Extracts device coordinates from a device assignment attribute.
//
// Returns the elements of `device_assignment_attr`, in order, as int64 values.
// Returns an InvalidArgument error identifying the offending index if any
// element of the array is not an integer attribute.
StatusOr<llvm::SmallVector<int64_t, 8>> GetDeviceCoordinates(
mlir::ArrayAttr device_assignment_attr);
// Finds the TPU compilation device and execution devices from `devices` for a // Finds the TPU compilation device and execution devices from `devices` for a
// TPU computation subgraph. Compilation device is determined from looking up // TPU computation subgraph. Compilation device is determined from looking up
// all TPU_SYSTEM:0 devices and choosing the CPU device associated to the first // all TPU_SYSTEM:0 devices and choosing the CPU device associated to the first

View File

@ -19,6 +19,8 @@ limitations under the License.
#include <tuple> #include <tuple>
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/tpu/topology.pb.h" #include "tensorflow/core/protobuf/tpu/topology.pb.h"
@ -596,5 +598,29 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) {
EXPECT_EQ(computation_device_2.replica_device_ids(1), 3); EXPECT_EQ(computation_device_2.replica_device_ids(1), 3);
} }
// An array attribute holding only 64-bit integers is converted into the
// matching sequence of device coordinates.
TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);

  auto coordinates_or =
      GetDeviceCoordinates(builder.getI64ArrayAttr({1, 2, 3}));
  ASSERT_TRUE(coordinates_or.ok());

  auto coordinates = coordinates_or.ConsumeValueOrDie();
  EXPECT_EQ(coordinates[0], 1);
  EXPECT_EQ(coordinates[1], 2);
  EXPECT_EQ(coordinates[2], 3);
}
// An array attribute whose elements are floats rather than integers is
// rejected, and the error reports the index of the first bad element.
TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  auto float_array_attr = builder.getF32ArrayAttr({1.0, 2.0, 3.0});

  auto coordinates_or = GetDeviceCoordinates(float_array_attr);
  ASSERT_FALSE(coordinates_or.ok());
  EXPECT_EQ(coordinates_or.status().error_message(),
            "bad 'device_assignment' attribute at index 0, not an int");
}
} // anonymous namespace } // anonymous namespace
} // namespace tensorflow } // namespace tensorflow