diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 0bc5c808ceb..d4171f46202 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -836,6 +836,8 @@ cc_library( "//tensorflow/core/lib/core:status", "@com_google_absl//absl/strings", "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:Support", ], ) @@ -849,6 +851,9 @@ tf_cc_test( "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:test", "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:Support", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc index 5805a0fd6e4..be3012415d6 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc @@ -22,8 +22,13 @@ limitations under the License. #include "absl/strings/string_view.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/device_name_utils.h" @@ -150,6 +155,36 @@ std::string GetTPUCompilationDevice(Device system_device) { } // anonymous namespace +mlir::LogicalResult GetDevicesFromAttribute( + mlir::Operation* op, + llvm::SmallVectorImpl* devices) { + auto devices_attr = op->getAttr(kDevicesAttr); + if (!devices_attr) return mlir::success(); + + auto array_attr = devices_attr.dyn_cast(); + if (!array_attr) + return op->emitOpError( + llvm::formatv("bad '{0}' attribute, not an array", kDevicesAttr)); + + devices->resize(array_attr.size()); + for (auto attr_and_idx : llvm::enumerate(array_attr)) { + const int idx = attr_and_idx.index(); + auto string_attr = attr_and_idx.value().dyn_cast(); + if (!string_attr) + return op->emitOpError(llvm::formatv( + "bad '{0}' attribute at index {1}, not a string", kDevicesAttr, idx)); + + if (!DeviceNameUtils::ParseFullName(string_attr.getValue().str(), + &(*devices)[idx])) + return op->emitOpError( + llvm::formatv("bad '{0}' attribute at index {1} with value '{2}', " + "not a valid device", + kDevicesAttr, idx, string_attr.getValue())); + } + + return mlir::success(); +} + Status GetTPUCompilationAndExecutionDevices( Devices devices, int num_replicas, int num_cores_per_replica, std::string* compilation_device, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h index 66ae02b2c97..6aaf3fb3414 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h @@ -20,11 +20,22 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { +constexpr char kDevicesAttr[] = "tf.devices"; + +// Collects devices as DeviceNameUtils::ParsedName from an op `tf.devices` +// attribute. A failure will be returned if the attribute is not an +// ArrayAttr or the devices are invalid. +mlir::LogicalResult GetDevicesFromAttribute( + mlir::Operation* op, + llvm::SmallVectorImpl* devices); + // Finds the TPU compilation device and execution devices from `devices` for a // replicated TPU computation subgraph. Compilation device is determined from // looking up all TPU_SYSTEM:0 devices and choosing the CPU device associated diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc index bda8c860f8a..c7d68d71786 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc @@ -21,6 +21,11 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/util/device_name_utils.h" @@ -142,5 +147,68 @@ TEST(TPURewriteDeviceUtilTest, NumReplicasNumTPUs) { EXPECT_EQ(execution_devices[3], "/job:worker/replica:0/task:1/device:TPU:1"); } +TEST(TPURewriteDeviceUtilTest, NoDevicesAttribute) { + mlir::MLIRContext context; + mlir::OwningModuleRef module_ref = + mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); + + llvm::SmallVector devices; + EXPECT_TRUE( + mlir::succeeded(GetDevicesFromAttribute(module_ref.get(), &devices))); +} + +TEST(TPURewriteDeviceUtilTest, BadDevicesAttributeType) { + mlir::MLIRContext context; + mlir::OwningModuleRef module_ref = + mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); + mlir::Builder builder(module_ref.get()); + module_ref->setAttr("tf.devices", builder.getBoolAttr(false)); + + llvm::SmallVector devices; + EXPECT_TRUE( + mlir::failed(GetDevicesFromAttribute(module_ref.get(), &devices))); +} + +TEST(TPURewriteDeviceUtilTest, BadDevicesAttributeArraySubtype) { + mlir::MLIRContext context; + mlir::OwningModuleRef module_ref = + mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); + mlir::Builder builder(module_ref.get()); + module_ref->setAttr("tf.devices", builder.getI32ArrayAttr({8})); + + llvm::SmallVector devices; + EXPECT_TRUE( + mlir::failed(GetDevicesFromAttribute(module_ref.get(), &devices))); +} + +TEST(TPURewriteDeviceUtilTest, BadDevicesInDevicesAttribute) { + mlir::MLIRContext context; + mlir::OwningModuleRef module_ref = + mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); + mlir::Builder builder(module_ref.get()); + module_ref->setAttr("tf.devices", builder.getStrArrayAttr({"bad_device"})); + + llvm::SmallVector devices; + EXPECT_TRUE( + mlir::failed(GetDevicesFromAttribute(module_ref.get(), &devices))); +} + +TEST(TPURewriteDeviceUtilTest, ValidDeviceInDevicesAttribute) { + mlir::MLIRContext context; + mlir::OwningModuleRef module_ref = + mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); + mlir::Builder builder(module_ref.get()); + module_ref->setAttr( + "tf.devices", + builder.getStrArrayAttr({"/job:worker/replica:0/task:0/device:CPU:0"})); + + llvm::SmallVector devices; + EXPECT_TRUE( + mlir::succeeded(GetDevicesFromAttribute(module_ref.get(), &devices))); + ASSERT_EQ(devices.size(), 1); + EXPECT_EQ(DeviceNameUtils::ParsedNameToString(devices[0]), + "/job:worker/replica:0/task:0/device:CPU:0"); +} + } // anonymous namespace } // namespace tensorflow