Use mlir::OpState::operator->() to get to methods of mlir::Operation.
This is a preparation step to remove those methods from OpState. PiperOrigin-RevId: 348209062 Change-Id: I72c68635d6b47d74a16932385cadf1fa3fe5b517
This commit is contained in:
parent
9a68efbad9
commit
80e5ad4c92
@ -128,13 +128,13 @@ void CrossHostTransferPass::runOnFunction() {
|
||||
std::string key = GetNextKey();
|
||||
auto send_op =
|
||||
builder.create<tf_device::SendOp>(op->getLoc(), arg, key, dst_host);
|
||||
send_op.setAttr(kOpDeviceAttr,
|
||||
builder.getStringAttr(src_host + kCPUDevice));
|
||||
send_op->setAttr(kOpDeviceAttr,
|
||||
builder.getStringAttr(src_host + kCPUDevice));
|
||||
|
||||
auto receive_op = builder.create<tf_device::ReceiveOp>(
|
||||
op->getLoc(), arg.getType(), key, src_host);
|
||||
receive_op.setAttr(kOpDeviceAttr,
|
||||
builder.getStringAttr(dst_host + kCPUDevice));
|
||||
receive_op->setAttr(kOpDeviceAttr,
|
||||
builder.getStringAttr(dst_host + kCPUDevice));
|
||||
|
||||
transferred_value_by_host[dst_host] = receive_op.getResult();
|
||||
op->replaceUsesOfWith(arg, receive_op.getResult());
|
||||
|
@ -650,10 +650,10 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailModelParallelism) {
|
||||
llvm::SmallVector<mlir::Type, 8> result_types;
|
||||
auto cluster = builder.create<mlir::tf_device::ClusterOp>(
|
||||
mlir::UnknownLoc::get(&context), result_types);
|
||||
cluster.setAttr(kNumCoresPerReplicaAttr,
|
||||
builder.getIntegerAttr(builder.getIntegerType(64), 5));
|
||||
cluster.setAttr(kTopologyAttr, builder.getStringAttr(""));
|
||||
cluster.setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));
|
||||
cluster->setAttr(kNumCoresPerReplicaAttr,
|
||||
builder.getIntegerAttr(builder.getIntegerType(64), 5));
|
||||
cluster->setAttr(kTopologyAttr, builder.getStringAttr(""));
|
||||
cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));
|
||||
|
||||
mlir::TF::RuntimeDevices runtime_devices;
|
||||
std::string host_device;
|
||||
@ -671,9 +671,9 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingTopology) {
|
||||
llvm::SmallVector<mlir::Type, 8> result_types;
|
||||
auto cluster = builder.create<mlir::tf_device::ClusterOp>(
|
||||
mlir::UnknownLoc::get(&context), result_types);
|
||||
cluster.setAttr(kNumCoresPerReplicaAttr,
|
||||
builder.getIntegerAttr(builder.getIntegerType(64), 1));
|
||||
cluster.setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));
|
||||
cluster->setAttr(kNumCoresPerReplicaAttr,
|
||||
builder.getIntegerAttr(builder.getIntegerType(64), 1));
|
||||
cluster->setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));
|
||||
|
||||
mlir::TF::RuntimeDevices runtime_devices;
|
||||
std::string host_device;
|
||||
|
@ -69,8 +69,8 @@ class GpuKernelToBlobPass
|
||||
if (blob_or.ok()) {
|
||||
const auto& blob = blob_or.ValueOrDie();
|
||||
std::string blob_string(blob.begin(), blob.end());
|
||||
gpu_module.setAttr(blob_annotation_,
|
||||
mlir::StringAttr::get(blob_string, &getContext()));
|
||||
gpu_module->setAttr(blob_annotation_,
|
||||
mlir::StringAttr::get(blob_string, &getContext()));
|
||||
return;
|
||||
}
|
||||
// Forward the error by attaching the message to the gpu module.
|
||||
|
Loading…
x
Reference in New Issue
Block a user