Create separate general device util for adding and getting device names (from DeviceSet) to and from an op (NFC).

This moves a function from tpu_rewrite_device_util for getting devices, and a function for adding devices from mlir_bridge_pass, into its own util.

PiperOrigin-RevId: 273610226
This commit is contained in:
Andy Ly 2019-10-08 14:48:50 -07:00 committed by TensorFlower Gardener
parent 18af93acca
commit 0aaf71779a
9 changed files with 310 additions and 150 deletions

View File

@ -834,8 +834,6 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Support",
],
)
@ -845,11 +843,39 @@ tf_cc_test(
srcs = ["utils/tpu_rewrite_device_util_test.cc"],
deps = [
":tpu_rewrite_device_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm//:support",
],
)
cc_library(
name = "device_util",
srcs = ["utils/device_util.cc"],
hdrs = ["utils/device_util.h"],
deps = [
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Support",
],
)
tf_cc_test(
name = "device_util_test",
size = "small",
srcs = ["utils/device_util_test.cc"],
deps = [
":device_util",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Support",
],

View File

@ -0,0 +1,82 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include <string>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
constexpr char kDevicesAttr[] = "tf.devices";
void AddDevicesToOp(mlir::Operation* op, const DeviceSet* device_set) {
if (!device_set) return;
// Collect devices as strings in TensorFlow device name form.
llvm::SmallVector<std::string, 8> devices;
devices.reserve(device_set->devices().size());
for (Device* device : device_set->devices())
devices.push_back(
DeviceNameUtils::ParsedNameToString(device->parsed_name()));
llvm::SmallVector<llvm::StringRef, 8> device_refs(devices.begin(),
devices.end());
mlir::Builder builder(op->getContext());
op->setAttr(kDevicesAttr, builder.getStrArrayAttr(device_refs));
}
mlir::LogicalResult GetDevicesFromOp(
mlir::Operation* op,
llvm::SmallVectorImpl<DeviceNameUtils::ParsedName>* devices) {
auto devices_attr = op->getAttr(kDevicesAttr);
if (!devices_attr) return mlir::success();
auto array_attr = devices_attr.dyn_cast<mlir::ArrayAttr>();
if (!array_attr)
return op->emitOpError(
llvm::formatv("bad '{0}' attribute, not an array", kDevicesAttr));
devices->resize(array_attr.size());
for (auto attr_and_idx : llvm::enumerate(array_attr)) {
const int idx = attr_and_idx.index();
auto string_attr = attr_and_idx.value().dyn_cast<mlir::StringAttr>();
if (!string_attr)
return op->emitOpError(llvm::formatv(
"bad '{0}' attribute at index {1}, not a string", kDevicesAttr, idx));
if (!DeviceNameUtils::ParseFullName(string_attr.getValue().str(),
&(*devices)[idx]))
return op->emitOpError(
llvm::formatv("bad '{0}' attribute at index {1} with value '{2}', "
"not a valid device",
kDevicesAttr, idx, string_attr.getValue()));
}
return mlir::success();
}
} // namespace tensorflow

View File

@ -0,0 +1,41 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DEVICE_UTIL_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DEVICE_UTIL_H_
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
// Collects all devices known to the system by name and adds them as a
// `tf.devices` array attribute of string attributes to an op. Device names
// added are in the following form:
// /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num>
void AddDevicesToOp(mlir::Operation* op, const DeviceSet* device_set);
// Collects devices as DeviceNameUtils::ParsedName from an op `tf.devices`
// attribute. A failure will be returned if the attribute is not an
// ArrayAttr<StringAttr> or the devices are invalid.
mlir::LogicalResult GetDevicesFromOp(
mlir::Operation* op,
llvm::SmallVectorImpl<DeviceNameUtils::ParsedName>* devices);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DEVICE_UTIL_H_

View File

@ -0,0 +1,155 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
namespace {
// A fake device used to populate a DeviceSet.
class FakeDevice : public Device {
public:
explicit FakeDevice(const DeviceAttributes& device_attributes)
: Device(nullptr, device_attributes) {}
Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }
static std::unique_ptr<Device> Make(const string& name) {
DeviceNameUtils::ParsedName parsed_name;
DeviceNameUtils::ParseFullName(name, &parsed_name);
DeviceAttributes device_attributes;
device_attributes.set_name(name);
device_attributes.set_device_type(parsed_name.type);
return std::make_unique<FakeDevice>(device_attributes);
}
};
TEST(DeviceUtilTest, AddDeviceToOp) {
mlir::MLIRContext context;
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
DeviceSet device_set;
llvm::SmallVector<std::unique_ptr<Device>, 2> devices;
devices.push_back(
FakeDevice::Make("/job:worker/replica:0/task:0/device:CPU:0"));
devices.push_back(
FakeDevice::Make("/job:worker/replica:1/task:2/device:GPU:3"));
for (auto& device : devices) device_set.AddDevice(device.get());
AddDevicesToOp(*module_ref, &device_set);
auto devices_attr = module_ref->getAttrOfType<mlir::ArrayAttr>("tf.devices");
ASSERT_NE(devices_attr, nullptr);
ASSERT_EQ(devices_attr.size(), 2);
auto device_attr_0 = devices_attr.getValue()[0].dyn_cast<mlir::StringAttr>();
ASSERT_NE(device_attr_0, nullptr);
EXPECT_EQ(device_attr_0.getValue(),
"/job:worker/replica:0/task:0/device:CPU:0");
auto device_attr_1 = devices_attr.getValue()[1].dyn_cast<mlir::StringAttr>();
ASSERT_NE(device_attr_1, nullptr);
EXPECT_EQ(device_attr_1.getValue(),
"/job:worker/replica:1/task:2/device:GPU:3");
}
TEST(DeviceUtilTest, AddDeviceToOpNullDeviceSet) {
mlir::MLIRContext context;
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
AddDevicesToOp(*module_ref, /*device_set=*/nullptr);
EXPECT_EQ(module_ref->getAttr("tf.devices"), nullptr);
}
TEST(DeviceUtilTest, GetDevicesFromOpNoDevicesAttribute) {
mlir::MLIRContext context;
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
llvm::SmallVector<DeviceNameUtils::ParsedName, 8> devices;
EXPECT_TRUE(mlir::succeeded(GetDevicesFromOp(*module_ref, &devices)));
}
TEST(DeviceUtilTest, GetDevicesFromOpBadDevicesAttributeType) {
mlir::MLIRContext context;
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::Builder builder(*module_ref);
module_ref->setAttr("tf.devices", builder.getBoolAttr(false));
llvm::SmallVector<DeviceNameUtils::ParsedName, 8> devices;
EXPECT_TRUE(mlir::failed(GetDevicesFromOp(*module_ref, &devices)));
}
TEST(DeviceUtilTest, GetDevicesFromOpBadDevicesAttributeArraySubtype) {
mlir::MLIRContext context;
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::Builder builder(*module_ref);
module_ref->setAttr("tf.devices", builder.getI32ArrayAttr({8}));
llvm::SmallVector<DeviceNameUtils::ParsedName, 8> devices;
EXPECT_TRUE(mlir::failed(GetDevicesFromOp(*module_ref, &devices)));
}
TEST(DeviceUtilTest, GetDevicesFromOpBadDevicesInDevicesAttribute) {
mlir::MLIRContext context;
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::Builder builder(*module_ref);
module_ref->setAttr("tf.devices", builder.getStrArrayAttr({"bad_device"}));
llvm::SmallVector<DeviceNameUtils::ParsedName, 8> devices;
EXPECT_TRUE(mlir::failed(GetDevicesFromOp(*module_ref, &devices)));
}
TEST(DeviceUtilTest, GetDevicesFromOpValidDeviceInDevicesAttribute) {
mlir::MLIRContext context;
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::Builder builder(*module_ref);
module_ref->setAttr(
"tf.devices",
builder.getStrArrayAttr({"/job:worker/replica:0/task:0/device:CPU:0"}));
llvm::SmallVector<DeviceNameUtils::ParsedName, 8> devices;
EXPECT_TRUE(mlir::succeeded(GetDevicesFromOp(*module_ref, &devices)));
ASSERT_EQ(devices.size(), 1);
EXPECT_EQ(DeviceNameUtils::ParsedNameToString(devices[0]),
"/job:worker/replica:0/task:0/device:CPU:0");
}
} // anonymous namespace
} // namespace tensorflow

View File

@ -22,15 +22,9 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@ -155,36 +149,6 @@ std::string GetTPUCompilationDevice(Device system_device) {
} // anonymous namespace
mlir::LogicalResult GetDevicesFromAttribute(
mlir::Operation* op,
llvm::SmallVectorImpl<DeviceNameUtils::ParsedName>* devices) {
auto devices_attr = op->getAttr(kDevicesAttr);
if (!devices_attr) return mlir::success();
auto array_attr = devices_attr.dyn_cast<mlir::ArrayAttr>();
if (!array_attr)
return op->emitOpError(
llvm::formatv("bad '{0}' attribute, not an array", kDevicesAttr));
devices->resize(array_attr.size());
for (auto attr_and_idx : llvm::enumerate(array_attr)) {
const int idx = attr_and_idx.index();
auto string_attr = attr_and_idx.value().dyn_cast<mlir::StringAttr>();
if (!string_attr)
return op->emitOpError(llvm::formatv(
"bad '{0}' attribute at index {1}, not a string", kDevicesAttr, idx));
if (!DeviceNameUtils::ParseFullName(string_attr.getValue().str(),
&(*devices)[idx]))
return op->emitOpError(
llvm::formatv("bad '{0}' attribute at index {1} with value '{2}', "
"not a valid device",
kDevicesAttr, idx, string_attr.getValue()));
}
return mlir::success();
}
Status GetTPUCompilationAndExecutionDevices(
Devices devices, int num_replicas, int num_cores_per_replica,
std::string* compilation_device,

View File

@ -20,22 +20,10 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
constexpr char kDevicesAttr[] = "tf.devices";
// Collects devices as DeviceNameUtils::ParsedName from an op `tf.devices`
// attribute. A failure will be returned if the attribute is not an
// ArrayAttr<StringAttr> or the devices are invalid.
mlir::LogicalResult GetDevicesFromAttribute(
mlir::Operation* op,
llvm::SmallVectorImpl<DeviceNameUtils::ParsedName>* devices);
// Finds the TPU compilation device and execution devices from `devices` for a
// replicated TPU computation subgraph. Compilation device is determined from
// looking up all TPU_SYSTEM:0 devices and choosing the CPU device associated

View File

@ -21,11 +21,6 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/device_name_utils.h"
@ -147,68 +142,5 @@ TEST(TPURewriteDeviceUtilTest, NumReplicasNumTPUs) {
EXPECT_EQ(execution_devices[3], "/job:worker/replica:0/task:1/device:TPU:1");
}
TEST(TPURewriteDeviceUtilTest, NoDevicesAttribute) {
mlir::MLIRContext context;
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
llvm::SmallVector<Device, 8> devices;
EXPECT_TRUE(
mlir::succeeded(GetDevicesFromAttribute(module_ref.get(), &devices)));
}
TEST(TPURewriteDeviceUtilTest, BadDevicesAttributeType) {
mlir::MLIRContext context;
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::Builder builder(module_ref.get());
module_ref->setAttr("tf.devices", builder.getBoolAttr(false));
llvm::SmallVector<Device, 8> devices;
EXPECT_TRUE(
mlir::failed(GetDevicesFromAttribute(module_ref.get(), &devices)));
}
TEST(TPURewriteDeviceUtilTest, BadDevicesAttributeArraySubtype) {
mlir::MLIRContext context;
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::Builder builder(module_ref.get());
module_ref->setAttr("tf.devices", builder.getI32ArrayAttr({8}));
llvm::SmallVector<Device, 8> devices;
EXPECT_TRUE(
mlir::failed(GetDevicesFromAttribute(module_ref.get(), &devices)));
}
TEST(TPURewriteDeviceUtilTest, BadDevicesInDevicesAttribute) {
mlir::MLIRContext context;
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::Builder builder(module_ref.get());
module_ref->setAttr("tf.devices", builder.getStrArrayAttr({"bad_device"}));
llvm::SmallVector<Device, 8> devices;
EXPECT_TRUE(
mlir::failed(GetDevicesFromAttribute(module_ref.get(), &devices)));
}
TEST(TPURewriteDeviceUtilTest, ValidDeviceInDevicesAttribute) {
mlir::MLIRContext context;
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::Builder builder(module_ref.get());
module_ref->setAttr(
"tf.devices",
builder.getStrArrayAttr({"/job:worker/replica:0/task:0/device:CPU:0"}));
llvm::SmallVector<Device, 8> devices;
EXPECT_TRUE(
mlir::succeeded(GetDevicesFromAttribute(module_ref.get(), &devices)));
ASSERT_EQ(devices.size(), 1);
EXPECT_EQ(DeviceNameUtils::ParsedNameToString(devices[0]),
"/job:worker/replica:0/task:0/device:CPU:0");
}
} // anonymous namespace
} // namespace tensorflow

View File

@ -590,11 +590,10 @@ cc_library(
deps = [
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:device_util",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"@llvm//:support",
"@local_config_mlir//:IR",
],
alwayslink = 1,
)

View File

@ -17,44 +17,17 @@ limitations under the License.
#include <string>
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_os_ostream.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
// Collects all devices known to the system by name and adds them as an array
// attribute of string attributes to the module. Device names added are in the
// following form:
// /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num>
static void AddDevicesToModule(mlir::ModuleOp module,
const DeviceSet* device_set) {
if (!device_set) return;
// Collect devices as strings in TensorFlow device name form.
llvm::SmallVector<std::string, 8> devices;
devices.reserve(device_set->devices().size());
for (Device* device : device_set->devices())
devices.push_back(
DeviceNameUtils::ParsedNameToString(device->parsed_name()));
llvm::SmallVector<llvm::StringRef, 8> device_refs(devices.begin(),
devices.end());
mlir::Builder builder(module);
module.setAttr("tf.devices", builder.getStrArrayAttr(device_refs));
}
// Dumps the MLIR module to disk.
// This require the TF_DUMP_GRAPH_PREFIX to be set to a path that exist (or can
// be created).
@ -125,7 +98,7 @@ Status MlirBridgePass::Run(const GraphOptimizationPassOptions& options) {
ConvertGraphToMlir(**options.graph, debug_info,
*options.flib_def, specs, &context));
AddDevicesToModule(*module, options.device_set);
AddDevicesToOp(*module, options.device_set);
if (VLOG_IS_ON(1)) DumpModule(*module, "mlir_bridge_before_");