Add util to extract device strings from an ops tf.device attribute.

Devices are currently stored as an ArrayAttr<StringAttr> and if valid can be parsed into a DeviceNameUtils::ParsedName.

PiperOrigin-RevId: 272992538
This commit is contained in:
Andy Ly 2019-10-04 17:44:41 -07:00 committed by TensorFlower Gardener
parent c97f235ee8
commit b14e006f2b
4 changed files with 119 additions and 0 deletions

View File

@ -836,6 +836,8 @@ cc_library(
"//tensorflow/core/lib/core:status",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Support",
],
)
@ -849,6 +851,9 @@ tf_cc_test(
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:test",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Support",
],
)

View File

@ -22,8 +22,13 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/device_name_utils.h"
@ -150,6 +155,36 @@ std::string GetTPUCompilationDevice(Device system_device) {
} // anonymous namespace
mlir::LogicalResult GetDevicesFromAttribute(
mlir::Operation* op,
llvm::SmallVectorImpl<DeviceNameUtils::ParsedName>* devices) {
auto devices_attr = op->getAttr(kDevicesAttr);
if (!devices_attr) return mlir::success();
auto array_attr = devices_attr.dyn_cast<mlir::ArrayAttr>();
if (!array_attr)
return op->emitOpError(
llvm::formatv("bad '{0}' attribute, not an array", kDevicesAttr));
devices->resize(array_attr.size());
for (auto attr_and_idx : llvm::enumerate(array_attr)) {
const int idx = attr_and_idx.index();
auto string_attr = attr_and_idx.value().dyn_cast<mlir::StringAttr>();
if (!string_attr)
return op->emitOpError(llvm::formatv(
"bad '{0}' attribute at index {1}, not a string", kDevicesAttr, idx));
if (!DeviceNameUtils::ParseFullName(string_attr.getValue().str(),
&(*devices)[idx]))
return op->emitOpError(
llvm::formatv("bad '{0}' attribute at index {1} with value '{2}', "
"not a valid device",
kDevicesAttr, idx, string_attr.getValue()));
}
return mlir::success();
}
Status GetTPUCompilationAndExecutionDevices(
Devices devices, int num_replicas, int num_cores_per_replica,
std::string* compilation_device,

View File

@ -20,11 +20,22 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
constexpr char kDevicesAttr[] = "tf.devices";
// Collects devices as DeviceNameUtils::ParsedName from an op `tf.devices`
// attribute. A failure will be returned if the attribute is not an
// ArrayAttr<StringAttr> or the devices are invalid.
mlir::LogicalResult GetDevicesFromAttribute(
mlir::Operation* op,
llvm::SmallVectorImpl<DeviceNameUtils::ParsedName>* devices);
// Finds the TPU compilation device and execution devices from `devices` for a
// replicated TPU computation subgraph. Compilation device is determined from
// looking up all TPU_SYSTEM:0 devices and choosing the CPU device associated

View File

@ -21,6 +21,11 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/device_name_utils.h"
@ -142,5 +147,68 @@ TEST(TPURewriteDeviceUtilTest, NumReplicasNumTPUs) {
EXPECT_EQ(execution_devices[3], "/job:worker/replica:0/task:1/device:TPU:1");
}
TEST(TPURewriteDeviceUtilTest, NoDevicesAttribute) {
mlir::MLIRContext context;
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
llvm::SmallVector<Device, 8> devices;
EXPECT_TRUE(
mlir::succeeded(GetDevicesFromAttribute(module_ref.get(), &devices)));
}
TEST(TPURewriteDeviceUtilTest, BadDevicesAttributeType) {
mlir::MLIRContext context;
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::Builder builder(module_ref.get());
module_ref->setAttr("tf.devices", builder.getBoolAttr(false));
llvm::SmallVector<Device, 8> devices;
EXPECT_TRUE(
mlir::failed(GetDevicesFromAttribute(module_ref.get(), &devices)));
}
TEST(TPURewriteDeviceUtilTest, BadDevicesAttributeArraySubtype) {
mlir::MLIRContext context;
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::Builder builder(module_ref.get());
module_ref->setAttr("tf.devices", builder.getI32ArrayAttr({8}));
llvm::SmallVector<Device, 8> devices;
EXPECT_TRUE(
mlir::failed(GetDevicesFromAttribute(module_ref.get(), &devices)));
}
TEST(TPURewriteDeviceUtilTest, BadDevicesInDevicesAttribute) {
mlir::MLIRContext context;
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::Builder builder(module_ref.get());
module_ref->setAttr("tf.devices", builder.getStrArrayAttr({"bad_device"}));
llvm::SmallVector<Device, 8> devices;
EXPECT_TRUE(
mlir::failed(GetDevicesFromAttribute(module_ref.get(), &devices)));
}
TEST(TPURewriteDeviceUtilTest, ValidDeviceInDevicesAttribute) {
mlir::MLIRContext context;
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::Builder builder(module_ref.get());
module_ref->setAttr(
"tf.devices",
builder.getStrArrayAttr({"/job:worker/replica:0/task:0/device:CPU:0"}));
llvm::SmallVector<Device, 8> devices;
EXPECT_TRUE(
mlir::succeeded(GetDevicesFromAttribute(module_ref.get(), &devices)));
ASSERT_EQ(devices.size(), 1);
EXPECT_EQ(DeviceNameUtils::ParsedNameToString(devices[0]),
"/job:worker/replica:0/task:0/device:CPU:0");
}
} // anonymous namespace
} // namespace tensorflow