Use mlir::OpState::operator->() to get to methods of mlir::Operation.
This is a preparation step to remove those methods from OpState. PiperOrigin-RevId: 351850594 Change-Id: I5ba53e37f283b8f5e028be043119859c487097e4
This commit is contained in:
parent
07592d94b6
commit
cfa9e61e4d
@ -44,7 +44,7 @@ FuncOp createMaxUnpoolingFunc(
|
|||||||
|
|
||||||
func.addEntryBlock();
|
func.addEntryBlock();
|
||||||
mlir::StringAttr attr_value = builder->getStringAttr("MaxUnpooling2D");
|
mlir::StringAttr attr_value = builder->getStringAttr("MaxUnpooling2D");
|
||||||
func.setAttr("tf._implements", attr_value);
|
func->setAttr("tf._implements", attr_value);
|
||||||
return func;
|
return func;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,8 +73,8 @@ Type GetResourceSubtype(Value resource) {
|
|||||||
// resource reads and writes per associated partitioned resource handle.
|
// resource reads and writes per associated partitioned resource handle.
|
||||||
void PartitionResourceReadsWrites(tf_device::ClusterFuncOp cluster_func) {
|
void PartitionResourceReadsWrites(tf_device::ClusterFuncOp cluster_func) {
|
||||||
bool use_spmd = false;
|
bool use_spmd = false;
|
||||||
if (auto use_spmd_attr =
|
if (auto use_spmd_attr = cluster_func->getAttrOfType<BoolAttr>(
|
||||||
cluster_func.getAttrOfType<BoolAttr>("use_spmd_for_xla_partitioning"))
|
"use_spmd_for_xla_partitioning"))
|
||||||
use_spmd = use_spmd_attr.getValue();
|
use_spmd = use_spmd_attr.getValue();
|
||||||
|
|
||||||
if (!use_spmd) return;
|
if (!use_spmd) return;
|
||||||
|
@ -691,9 +691,9 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingDeviceAssignment) {
|
|||||||
llvm::SmallVector<mlir::Type, 8> result_types;
|
llvm::SmallVector<mlir::Type, 8> result_types;
|
||||||
auto cluster = builder.create<mlir::tf_device::ClusterOp>(
|
auto cluster = builder.create<mlir::tf_device::ClusterOp>(
|
||||||
mlir::UnknownLoc::get(&context), result_types);
|
mlir::UnknownLoc::get(&context), result_types);
|
||||||
cluster.setAttr(kNumCoresPerReplicaAttr,
|
cluster->setAttr(kNumCoresPerReplicaAttr,
|
||||||
builder.getIntegerAttr(builder.getIntegerType(64), 1));
|
builder.getIntegerAttr(builder.getIntegerType(64), 1));
|
||||||
cluster.setAttr(kTopologyAttr, builder.getStringAttr(""));
|
cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
|
||||||
|
|
||||||
mlir::TF::RuntimeDevices runtime_devices;
|
mlir::TF::RuntimeDevices runtime_devices;
|
||||||
std::string host_device;
|
std::string host_device;
|
||||||
@ -711,12 +711,12 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceAssignment) {
|
|||||||
llvm::SmallVector<mlir::Type, 8> result_types;
|
llvm::SmallVector<mlir::Type, 8> result_types;
|
||||||
auto cluster = builder.create<mlir::tf_device::ClusterOp>(
|
auto cluster = builder.create<mlir::tf_device::ClusterOp>(
|
||||||
mlir::UnknownLoc::get(&context), result_types);
|
mlir::UnknownLoc::get(&context), result_types);
|
||||||
cluster.setAttr(kNumCoresPerReplicaAttr,
|
cluster->setAttr(kNumCoresPerReplicaAttr,
|
||||||
builder.getIntegerAttr(builder.getIntegerType(64), 1));
|
builder.getIntegerAttr(builder.getIntegerType(64), 1));
|
||||||
cluster.setAttr(kTopologyAttr, builder.getStringAttr(""));
|
cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
|
||||||
cluster.setAttr(kDeviceAssignmentAttr,
|
cluster->setAttr(kDeviceAssignmentAttr,
|
||||||
builder.getStrArrayAttr(llvm::ArrayRef<llvm::StringRef>(
|
builder.getStrArrayAttr(llvm::ArrayRef<llvm::StringRef>(
|
||||||
{"bad_device_assigment"})));
|
{"bad_device_assigment"})));
|
||||||
|
|
||||||
mlir::TF::RuntimeDevices runtime_devices;
|
mlir::TF::RuntimeDevices runtime_devices;
|
||||||
std::string host_device;
|
std::string host_device;
|
||||||
@ -737,10 +737,10 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceName) {
|
|||||||
llvm::SmallVector<mlir::Type, 8> result_types;
|
llvm::SmallVector<mlir::Type, 8> result_types;
|
||||||
auto cluster = builder.create<mlir::tf_device::ClusterOp>(
|
auto cluster = builder.create<mlir::tf_device::ClusterOp>(
|
||||||
mlir::UnknownLoc::get(&context), result_types);
|
mlir::UnknownLoc::get(&context), result_types);
|
||||||
cluster.setAttr(kNumCoresPerReplicaAttr,
|
cluster->setAttr(kNumCoresPerReplicaAttr,
|
||||||
builder.getIntegerAttr(builder.getIntegerType(64), 1));
|
builder.getIntegerAttr(builder.getIntegerType(64), 1));
|
||||||
cluster.setAttr(kTopologyAttr, builder.getStringAttr(""));
|
cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
|
||||||
cluster.setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));
|
cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));
|
||||||
|
|
||||||
mlir::TF::RuntimeDevices runtime_devices;
|
mlir::TF::RuntimeDevices runtime_devices;
|
||||||
GetDevicesFromOp(*module_ref, &runtime_devices);
|
GetDevicesFromOp(*module_ref, &runtime_devices);
|
||||||
@ -791,10 +791,10 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceNotReplicated) {
|
|||||||
llvm::SmallVector<mlir::Type, 8> result_types;
|
llvm::SmallVector<mlir::Type, 8> result_types;
|
||||||
auto cluster = builder.create<mlir::tf_device::ClusterOp>(
|
auto cluster = builder.create<mlir::tf_device::ClusterOp>(
|
||||||
mlir::UnknownLoc::get(&context), result_types);
|
mlir::UnknownLoc::get(&context), result_types);
|
||||||
cluster.setAttr(kNumCoresPerReplicaAttr,
|
cluster->setAttr(kNumCoresPerReplicaAttr,
|
||||||
builder.getIntegerAttr(builder.getIntegerType(64), 1));
|
builder.getIntegerAttr(builder.getIntegerType(64), 1));
|
||||||
cluster.setAttr(kTopologyAttr, builder.getStringAttr(""));
|
cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
|
||||||
cluster.setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));
|
cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));
|
||||||
|
|
||||||
mlir::TF::RuntimeDevices runtime_devices;
|
mlir::TF::RuntimeDevices runtime_devices;
|
||||||
GetDevicesFromOp(*module_ref, &runtime_devices);
|
GetDevicesFromOp(*module_ref, &runtime_devices);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user