Use mlir::OpState::operator->() to get to methods of mlir::Operation.

This is a preparation step to remove those methods from OpState.

PiperOrigin-RevId: 348642408
Change-Id: I2982f5516216e4956b1d0f7023445ef7622a545a
This commit is contained in:
Christian Sigg 2020-12-22 09:01:28 -08:00 committed by TensorFlower Gardener
parent 8908b62a48
commit 7d8e31edb0

View File

@ -77,7 +77,7 @@ TEST(DeviceUtilTest, AddDeviceToOp) {
AddDevicesToOp(*module_ref, &device_set);
auto devices_attr =
module_ref->getAttrOfType<mlir::DictionaryAttr>("tf.devices");
(*module_ref)->getAttrOfType<mlir::DictionaryAttr>("tf.devices");
ASSERT_NE(devices_attr, nullptr);
ASSERT_EQ(devices_attr.size(), 3);
@ -105,7 +105,7 @@ TEST(DeviceUtilTest, AddDeviceToOpNullDeviceSet) {
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
AddDevicesToOp(*module_ref, /*device_set=*/nullptr);
EXPECT_EQ(module_ref->getAttr("tf.devices"), nullptr);
EXPECT_EQ((*module_ref)->getAttr("tf.devices"), nullptr);
}
TEST(DeviceUtilTest, GetDevicesFromOpNoDevicesAttribute) {
@ -122,7 +122,7 @@ TEST(DeviceUtilTest, GetDevicesFromOpBadDevicesAttributeType) {
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::Builder builder(*module_ref);
module_ref->setAttr("tf.devices", builder.getBoolAttr(false));
(*module_ref)->setAttr("tf.devices", builder.getBoolAttr(false));
mlir::TF::RuntimeDevices devices;
EXPECT_TRUE(mlir::failed(GetDevicesFromOp(*module_ref, &devices)));
@ -133,7 +133,7 @@ TEST(DeviceUtilTest, GetDevicesFromOpBadDevicesAttributeArraySubtype) {
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::Builder builder(*module_ref);
module_ref->setAttr("tf.devices", builder.getI32ArrayAttr({8}));
(*module_ref)->setAttr("tf.devices", builder.getI32ArrayAttr({8}));
mlir::TF::RuntimeDevices devices;
EXPECT_TRUE(mlir::failed(GetDevicesFromOp(*module_ref, &devices)));
@ -144,9 +144,10 @@ TEST(DeviceUtilTest, GetDevicesFromOpBadDevicesInDevicesAttribute) {
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::Builder builder(*module_ref);
module_ref->setAttr("tf.devices",
builder.getDictionaryAttr(builder.getNamedAttr(
"bad_device", builder.getDictionaryAttr({}))));
(*module_ref)
->setAttr("tf.devices",
builder.getDictionaryAttr(builder.getNamedAttr(
"bad_device", builder.getDictionaryAttr({}))));
mlir::TF::RuntimeDevices devices;
EXPECT_TRUE(mlir::failed(GetDevicesFromOp(*module_ref, &devices)));
@ -161,7 +162,7 @@ TEST(DeviceUtilTest, GetDevicesFromOpValidDeviceInDevicesAttribute) {
auto device_dict = builder.getDictionaryAttr(
{builder.getNamedAttr("/job:worker/replica:0/task:0/device:CPU:0",
builder.getDictionaryAttr({}))});
module_ref->setAttr("tf.devices", device_dict);
(*module_ref)->setAttr("tf.devices", device_dict);
mlir::TF::RuntimeDevices devices;
EXPECT_TRUE(mlir::succeeded(GetDevicesFromOp(*module_ref, &devices)));
@ -188,7 +189,7 @@ TEST(DeviceUtilTest, GetGpuDeviceMetadata) {
builder.getI32IntegerAttr(2),
module_ref->getContext())));
module_ref->setAttr("tf.devices", builder.getDictionaryAttr(metadata));
(*module_ref)->setAttr("tf.devices", builder.getDictionaryAttr(metadata));
mlir::TF::RuntimeDevices devices;
EXPECT_TRUE(mlir::succeeded(GetDevicesFromOp(*module_ref, &devices)));