Add util to extract device strings from an ops tf.device
attribute.
Devices are currently stored as an ArrayAttr<StringAttr> and if valid can be parsed into a DeviceNameUtils::ParsedName. PiperOrigin-RevId: 272992538
This commit is contained in:
parent
c97f235ee8
commit
b14e006f2b
@ -836,6 +836,8 @@ cc_library(
|
||||
"//tensorflow/core/lib/core:status",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Support",
|
||||
],
|
||||
)
|
||||
|
||||
@ -849,6 +851,9 @@ tf_cc_test(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/platform:test",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Support",
|
||||
],
|
||||
)
|
||||
|
@ -22,8 +22,13 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
@ -150,6 +155,36 @@ std::string GetTPUCompilationDevice(Device system_device) {
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
mlir::LogicalResult GetDevicesFromAttribute(
|
||||
mlir::Operation* op,
|
||||
llvm::SmallVectorImpl<DeviceNameUtils::ParsedName>* devices) {
|
||||
auto devices_attr = op->getAttr(kDevicesAttr);
|
||||
if (!devices_attr) return mlir::success();
|
||||
|
||||
auto array_attr = devices_attr.dyn_cast<mlir::ArrayAttr>();
|
||||
if (!array_attr)
|
||||
return op->emitOpError(
|
||||
llvm::formatv("bad '{0}' attribute, not an array", kDevicesAttr));
|
||||
|
||||
devices->resize(array_attr.size());
|
||||
for (auto attr_and_idx : llvm::enumerate(array_attr)) {
|
||||
const int idx = attr_and_idx.index();
|
||||
auto string_attr = attr_and_idx.value().dyn_cast<mlir::StringAttr>();
|
||||
if (!string_attr)
|
||||
return op->emitOpError(llvm::formatv(
|
||||
"bad '{0}' attribute at index {1}, not a string", kDevicesAttr, idx));
|
||||
|
||||
if (!DeviceNameUtils::ParseFullName(string_attr.getValue().str(),
|
||||
&(*devices)[idx]))
|
||||
return op->emitOpError(
|
||||
llvm::formatv("bad '{0}' attribute at index {1} with value '{2}', "
|
||||
"not a valid device",
|
||||
kDevicesAttr, idx, string_attr.getValue()));
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
Status GetTPUCompilationAndExecutionDevices(
|
||||
Devices devices, int num_replicas, int num_cores_per_replica,
|
||||
std::string* compilation_device,
|
||||
|
@ -20,11 +20,22 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
constexpr char kDevicesAttr[] = "tf.devices";
|
||||
|
||||
// Collects devices as DeviceNameUtils::ParsedName from an op `tf.devices`
|
||||
// attribute. A failure will be returned if the attribute is not an
|
||||
// ArrayAttr<StringAttr> or the devices are invalid.
|
||||
mlir::LogicalResult GetDevicesFromAttribute(
|
||||
mlir::Operation* op,
|
||||
llvm::SmallVectorImpl<DeviceNameUtils::ParsedName>* devices);
|
||||
|
||||
// Finds the TPU compilation device and execution devices from `devices` for a
|
||||
// replicated TPU computation subgraph. Compilation device is determined from
|
||||
// looking up all TPU_SYSTEM:0 devices and choosing the CPU device associated
|
||||
|
@ -21,6 +21,11 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
@ -142,5 +147,68 @@ TEST(TPURewriteDeviceUtilTest, NumReplicasNumTPUs) {
|
||||
EXPECT_EQ(execution_devices[3], "/job:worker/replica:0/task:1/device:TPU:1");
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, NoDevicesAttribute) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
|
||||
llvm::SmallVector<Device, 8> devices;
|
||||
EXPECT_TRUE(
|
||||
mlir::succeeded(GetDevicesFromAttribute(module_ref.get(), &devices)));
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, BadDevicesAttributeType) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::Builder builder(module_ref.get());
|
||||
module_ref->setAttr("tf.devices", builder.getBoolAttr(false));
|
||||
|
||||
llvm::SmallVector<Device, 8> devices;
|
||||
EXPECT_TRUE(
|
||||
mlir::failed(GetDevicesFromAttribute(module_ref.get(), &devices)));
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, BadDevicesAttributeArraySubtype) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::Builder builder(module_ref.get());
|
||||
module_ref->setAttr("tf.devices", builder.getI32ArrayAttr({8}));
|
||||
|
||||
llvm::SmallVector<Device, 8> devices;
|
||||
EXPECT_TRUE(
|
||||
mlir::failed(GetDevicesFromAttribute(module_ref.get(), &devices)));
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, BadDevicesInDevicesAttribute) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::Builder builder(module_ref.get());
|
||||
module_ref->setAttr("tf.devices", builder.getStrArrayAttr({"bad_device"}));
|
||||
|
||||
llvm::SmallVector<Device, 8> devices;
|
||||
EXPECT_TRUE(
|
||||
mlir::failed(GetDevicesFromAttribute(module_ref.get(), &devices)));
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, ValidDeviceInDevicesAttribute) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::Builder builder(module_ref.get());
|
||||
module_ref->setAttr(
|
||||
"tf.devices",
|
||||
builder.getStrArrayAttr({"/job:worker/replica:0/task:0/device:CPU:0"}));
|
||||
|
||||
llvm::SmallVector<Device, 8> devices;
|
||||
EXPECT_TRUE(
|
||||
mlir::succeeded(GetDevicesFromAttribute(module_ref.get(), &devices)));
|
||||
ASSERT_EQ(devices.size(), 1);
|
||||
EXPECT_EQ(DeviceNameUtils::ParsedNameToString(devices[0]),
|
||||
"/job:worker/replica:0/task:0/device:CPU:0");
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user