Create separate general device util for adding and getting device names (from DeviceSet) to and from an op (NFC).
This moves a function from tpu_rewrite_device_util for getting devices, and a function for adding devices from mlir_bridge_pass, into its own util. PiperOrigin-RevId: 273610226
This commit is contained in:
parent
18af93acca
commit
0aaf71779a
@ -834,8 +834,6 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Support",
|
||||
],
|
||||
)
|
||||
|
||||
@ -845,11 +843,39 @@ tf_cc_test(
|
||||
srcs = ["utils/tpu_rewrite_device_util_test.cc"],
|
||||
deps = [
|
||||
":tpu_rewrite_device_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@llvm//:support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "device_util",
|
||||
srcs = ["utils/device_util.cc"],
|
||||
hdrs = ["utils/device_util.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Support",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "device_util_test",
|
||||
size = "small",
|
||||
srcs = ["utils/device_util_test.cc"],
|
||||
deps = [
|
||||
":device_util",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Support",
|
||||
],
|
||||
|
82
tensorflow/compiler/mlir/tensorflow/utils/device_util.cc
Normal file
82
tensorflow/compiler/mlir/tensorflow/utils/device_util.cc
Normal file
@ -0,0 +1,82 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
constexpr char kDevicesAttr[] = "tf.devices";
|
||||
|
||||
void AddDevicesToOp(mlir::Operation* op, const DeviceSet* device_set) {
|
||||
if (!device_set) return;
|
||||
|
||||
// Collect devices as strings in TensorFlow device name form.
|
||||
llvm::SmallVector<std::string, 8> devices;
|
||||
devices.reserve(device_set->devices().size());
|
||||
for (Device* device : device_set->devices())
|
||||
devices.push_back(
|
||||
DeviceNameUtils::ParsedNameToString(device->parsed_name()));
|
||||
|
||||
llvm::SmallVector<llvm::StringRef, 8> device_refs(devices.begin(),
|
||||
devices.end());
|
||||
mlir::Builder builder(op->getContext());
|
||||
op->setAttr(kDevicesAttr, builder.getStrArrayAttr(device_refs));
|
||||
}
|
||||
|
||||
mlir::LogicalResult GetDevicesFromOp(
|
||||
mlir::Operation* op,
|
||||
llvm::SmallVectorImpl<DeviceNameUtils::ParsedName>* devices) {
|
||||
auto devices_attr = op->getAttr(kDevicesAttr);
|
||||
if (!devices_attr) return mlir::success();
|
||||
|
||||
auto array_attr = devices_attr.dyn_cast<mlir::ArrayAttr>();
|
||||
if (!array_attr)
|
||||
return op->emitOpError(
|
||||
llvm::formatv("bad '{0}' attribute, not an array", kDevicesAttr));
|
||||
|
||||
devices->resize(array_attr.size());
|
||||
for (auto attr_and_idx : llvm::enumerate(array_attr)) {
|
||||
const int idx = attr_and_idx.index();
|
||||
auto string_attr = attr_and_idx.value().dyn_cast<mlir::StringAttr>();
|
||||
if (!string_attr)
|
||||
return op->emitOpError(llvm::formatv(
|
||||
"bad '{0}' attribute at index {1}, not a string", kDevicesAttr, idx));
|
||||
|
||||
if (!DeviceNameUtils::ParseFullName(string_attr.getValue().str(),
|
||||
&(*devices)[idx]))
|
||||
return op->emitOpError(
|
||||
llvm::formatv("bad '{0}' attribute at index {1} with value '{2}', "
|
||||
"not a valid device",
|
||||
kDevicesAttr, idx, string_attr.getValue()));
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
41
tensorflow/compiler/mlir/tensorflow/utils/device_util.h
Normal file
41
tensorflow/compiler/mlir/tensorflow/utils/device_util.h
Normal file
@ -0,0 +1,41 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DEVICE_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DEVICE_UTIL_H_
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
// Collects all devices known to the system by name and adds them as a
|
||||
// `tf.devices` array attribute of string attributes to an op. Device names
|
||||
// added are in the following form:
|
||||
// /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num>
|
||||
void AddDevicesToOp(mlir::Operation* op, const DeviceSet* device_set);
|
||||
|
||||
// Collects devices as DeviceNameUtils::ParsedName from an op `tf.devices`
|
||||
// attribute. A failure will be returned if the attribute is not an
|
||||
// ArrayAttr<StringAttr> or the devices are invalid.
|
||||
mlir::LogicalResult GetDevicesFromOp(
|
||||
mlir::Operation* op,
|
||||
llvm::SmallVectorImpl<DeviceNameUtils::ParsedName>* devices);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DEVICE_UTIL_H_
|
155
tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc
Normal file
155
tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc
Normal file
@ -0,0 +1,155 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// A fake device used to populate a DeviceSet.
|
||||
class FakeDevice : public Device {
|
||||
public:
|
||||
explicit FakeDevice(const DeviceAttributes& device_attributes)
|
||||
: Device(nullptr, device_attributes) {}
|
||||
|
||||
Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }
|
||||
|
||||
static std::unique_ptr<Device> Make(const string& name) {
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
DeviceNameUtils::ParseFullName(name, &parsed_name);
|
||||
|
||||
DeviceAttributes device_attributes;
|
||||
device_attributes.set_name(name);
|
||||
device_attributes.set_device_type(parsed_name.type);
|
||||
return std::make_unique<FakeDevice>(device_attributes);
|
||||
}
|
||||
};
|
||||
|
||||
TEST(DeviceUtilTest, AddDeviceToOp) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
|
||||
DeviceSet device_set;
|
||||
llvm::SmallVector<std::unique_ptr<Device>, 2> devices;
|
||||
devices.push_back(
|
||||
FakeDevice::Make("/job:worker/replica:0/task:0/device:CPU:0"));
|
||||
devices.push_back(
|
||||
FakeDevice::Make("/job:worker/replica:1/task:2/device:GPU:3"));
|
||||
for (auto& device : devices) device_set.AddDevice(device.get());
|
||||
|
||||
AddDevicesToOp(*module_ref, &device_set);
|
||||
auto devices_attr = module_ref->getAttrOfType<mlir::ArrayAttr>("tf.devices");
|
||||
ASSERT_NE(devices_attr, nullptr);
|
||||
ASSERT_EQ(devices_attr.size(), 2);
|
||||
auto device_attr_0 = devices_attr.getValue()[0].dyn_cast<mlir::StringAttr>();
|
||||
ASSERT_NE(device_attr_0, nullptr);
|
||||
EXPECT_EQ(device_attr_0.getValue(),
|
||||
"/job:worker/replica:0/task:0/device:CPU:0");
|
||||
auto device_attr_1 = devices_attr.getValue()[1].dyn_cast<mlir::StringAttr>();
|
||||
ASSERT_NE(device_attr_1, nullptr);
|
||||
EXPECT_EQ(device_attr_1.getValue(),
|
||||
"/job:worker/replica:1/task:2/device:GPU:3");
|
||||
}
|
||||
|
||||
TEST(DeviceUtilTest, AddDeviceToOpNullDeviceSet) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
|
||||
AddDevicesToOp(*module_ref, /*device_set=*/nullptr);
|
||||
EXPECT_EQ(module_ref->getAttr("tf.devices"), nullptr);
|
||||
}
|
||||
|
||||
TEST(DeviceUtilTest, GetDevicesFromOpNoDevicesAttribute) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
|
||||
llvm::SmallVector<DeviceNameUtils::ParsedName, 8> devices;
|
||||
EXPECT_TRUE(mlir::succeeded(GetDevicesFromOp(*module_ref, &devices)));
|
||||
}
|
||||
|
||||
TEST(DeviceUtilTest, GetDevicesFromOpBadDevicesAttributeType) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::Builder builder(*module_ref);
|
||||
module_ref->setAttr("tf.devices", builder.getBoolAttr(false));
|
||||
|
||||
llvm::SmallVector<DeviceNameUtils::ParsedName, 8> devices;
|
||||
EXPECT_TRUE(mlir::failed(GetDevicesFromOp(*module_ref, &devices)));
|
||||
}
|
||||
|
||||
TEST(DeviceUtilTest, GetDevicesFromOpBadDevicesAttributeArraySubtype) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::Builder builder(*module_ref);
|
||||
module_ref->setAttr("tf.devices", builder.getI32ArrayAttr({8}));
|
||||
|
||||
llvm::SmallVector<DeviceNameUtils::ParsedName, 8> devices;
|
||||
EXPECT_TRUE(mlir::failed(GetDevicesFromOp(*module_ref, &devices)));
|
||||
}
|
||||
|
||||
TEST(DeviceUtilTest, GetDevicesFromOpBadDevicesInDevicesAttribute) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::Builder builder(*module_ref);
|
||||
module_ref->setAttr("tf.devices", builder.getStrArrayAttr({"bad_device"}));
|
||||
|
||||
llvm::SmallVector<DeviceNameUtils::ParsedName, 8> devices;
|
||||
EXPECT_TRUE(mlir::failed(GetDevicesFromOp(*module_ref, &devices)));
|
||||
}
|
||||
|
||||
TEST(DeviceUtilTest, GetDevicesFromOpValidDeviceInDevicesAttribute) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::Builder builder(*module_ref);
|
||||
module_ref->setAttr(
|
||||
"tf.devices",
|
||||
builder.getStrArrayAttr({"/job:worker/replica:0/task:0/device:CPU:0"}));
|
||||
|
||||
llvm::SmallVector<DeviceNameUtils::ParsedName, 8> devices;
|
||||
EXPECT_TRUE(mlir::succeeded(GetDevicesFromOp(*module_ref, &devices)));
|
||||
ASSERT_EQ(devices.size(), 1);
|
||||
EXPECT_EQ(DeviceNameUtils::ParsedNameToString(devices[0]),
|
||||
"/job:worker/replica:0/task:0/device:CPU:0");
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace tensorflow
|
@ -22,15 +22,9 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -155,36 +149,6 @@ std::string GetTPUCompilationDevice(Device system_device) {
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
mlir::LogicalResult GetDevicesFromAttribute(
|
||||
mlir::Operation* op,
|
||||
llvm::SmallVectorImpl<DeviceNameUtils::ParsedName>* devices) {
|
||||
auto devices_attr = op->getAttr(kDevicesAttr);
|
||||
if (!devices_attr) return mlir::success();
|
||||
|
||||
auto array_attr = devices_attr.dyn_cast<mlir::ArrayAttr>();
|
||||
if (!array_attr)
|
||||
return op->emitOpError(
|
||||
llvm::formatv("bad '{0}' attribute, not an array", kDevicesAttr));
|
||||
|
||||
devices->resize(array_attr.size());
|
||||
for (auto attr_and_idx : llvm::enumerate(array_attr)) {
|
||||
const int idx = attr_and_idx.index();
|
||||
auto string_attr = attr_and_idx.value().dyn_cast<mlir::StringAttr>();
|
||||
if (!string_attr)
|
||||
return op->emitOpError(llvm::formatv(
|
||||
"bad '{0}' attribute at index {1}, not a string", kDevicesAttr, idx));
|
||||
|
||||
if (!DeviceNameUtils::ParseFullName(string_attr.getValue().str(),
|
||||
&(*devices)[idx]))
|
||||
return op->emitOpError(
|
||||
llvm::formatv("bad '{0}' attribute at index {1} with value '{2}', "
|
||||
"not a valid device",
|
||||
kDevicesAttr, idx, string_attr.getValue()));
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
Status GetTPUCompilationAndExecutionDevices(
|
||||
Devices devices, int num_replicas, int num_cores_per_replica,
|
||||
std::string* compilation_device,
|
||||
|
@ -20,22 +20,10 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
constexpr char kDevicesAttr[] = "tf.devices";
|
||||
|
||||
// Collects devices as DeviceNameUtils::ParsedName from an op `tf.devices`
|
||||
// attribute. A failure will be returned if the attribute is not an
|
||||
// ArrayAttr<StringAttr> or the devices are invalid.
|
||||
mlir::LogicalResult GetDevicesFromAttribute(
|
||||
mlir::Operation* op,
|
||||
llvm::SmallVectorImpl<DeviceNameUtils::ParsedName>* devices);
|
||||
|
||||
// Finds the TPU compilation device and execution devices from `devices` for a
|
||||
// replicated TPU computation subgraph. Compilation device is determined from
|
||||
// looking up all TPU_SYSTEM:0 devices and choosing the CPU device associated
|
||||
|
@ -21,11 +21,6 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
@ -147,68 +142,5 @@ TEST(TPURewriteDeviceUtilTest, NumReplicasNumTPUs) {
|
||||
EXPECT_EQ(execution_devices[3], "/job:worker/replica:0/task:1/device:TPU:1");
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, NoDevicesAttribute) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
|
||||
llvm::SmallVector<Device, 8> devices;
|
||||
EXPECT_TRUE(
|
||||
mlir::succeeded(GetDevicesFromAttribute(module_ref.get(), &devices)));
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, BadDevicesAttributeType) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::Builder builder(module_ref.get());
|
||||
module_ref->setAttr("tf.devices", builder.getBoolAttr(false));
|
||||
|
||||
llvm::SmallVector<Device, 8> devices;
|
||||
EXPECT_TRUE(
|
||||
mlir::failed(GetDevicesFromAttribute(module_ref.get(), &devices)));
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, BadDevicesAttributeArraySubtype) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::Builder builder(module_ref.get());
|
||||
module_ref->setAttr("tf.devices", builder.getI32ArrayAttr({8}));
|
||||
|
||||
llvm::SmallVector<Device, 8> devices;
|
||||
EXPECT_TRUE(
|
||||
mlir::failed(GetDevicesFromAttribute(module_ref.get(), &devices)));
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, BadDevicesInDevicesAttribute) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::Builder builder(module_ref.get());
|
||||
module_ref->setAttr("tf.devices", builder.getStrArrayAttr({"bad_device"}));
|
||||
|
||||
llvm::SmallVector<Device, 8> devices;
|
||||
EXPECT_TRUE(
|
||||
mlir::failed(GetDevicesFromAttribute(module_ref.get(), &devices)));
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, ValidDeviceInDevicesAttribute) {
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::Builder builder(module_ref.get());
|
||||
module_ref->setAttr(
|
||||
"tf.devices",
|
||||
builder.getStrArrayAttr({"/job:worker/replica:0/task:0/device:CPU:0"}));
|
||||
|
||||
llvm::SmallVector<Device, 8> devices;
|
||||
EXPECT_TRUE(
|
||||
mlir::succeeded(GetDevicesFromAttribute(module_ref.get(), &devices)));
|
||||
ASSERT_EQ(devices.size(), 1);
|
||||
EXPECT_EQ(DeviceNameUtils::ParsedNameToString(devices[0]),
|
||||
"/job:worker/replica:0/task:0/device:CPU:0");
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -590,11 +590,10 @@ cc_library(
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:device_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -17,44 +17,17 @@ limitations under the License.
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/bridge.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Collects all devices known to the system by name and adds them as an array
|
||||
// attribute of string attributes to the module. Device names added are in the
|
||||
// following form:
|
||||
// /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num>
|
||||
static void AddDevicesToModule(mlir::ModuleOp module,
|
||||
const DeviceSet* device_set) {
|
||||
if (!device_set) return;
|
||||
|
||||
// Collect devices as strings in TensorFlow device name form.
|
||||
llvm::SmallVector<std::string, 8> devices;
|
||||
devices.reserve(device_set->devices().size());
|
||||
for (Device* device : device_set->devices())
|
||||
devices.push_back(
|
||||
DeviceNameUtils::ParsedNameToString(device->parsed_name()));
|
||||
|
||||
llvm::SmallVector<llvm::StringRef, 8> device_refs(devices.begin(),
|
||||
devices.end());
|
||||
mlir::Builder builder(module);
|
||||
module.setAttr("tf.devices", builder.getStrArrayAttr(device_refs));
|
||||
}
|
||||
|
||||
// Dumps the MLIR module to disk.
|
||||
// This require the TF_DUMP_GRAPH_PREFIX to be set to a path that exist (or can
|
||||
// be created).
|
||||
@ -125,7 +98,7 @@ Status MlirBridgePass::Run(const GraphOptimizationPassOptions& options) {
|
||||
ConvertGraphToMlir(**options.graph, debug_info,
|
||||
*options.flib_def, specs, &context));
|
||||
|
||||
AddDevicesToModule(*module, options.device_set);
|
||||
AddDevicesToOp(*module, options.device_set);
|
||||
|
||||
if (VLOG_IS_ON(1)) DumpModule(*module, "mlir_bridge_before_");
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user