Add a pass to parallelize TPU embedding params ops on different shards
This pass moves LoadTPUEmbedding* and corresponding ReadVariable ops to different regions using the parallel_execute op. This parallel_execute op is later broken in different islands by ParallelExecuteToIslands pass and these islands can progress in parallel. This pass is required to avoid control dependencies between ops on different shards during export to the tf_executor dialect. Also, added this pass to the pass pipeline. PiperOrigin-RevId: 323426936 Change-Id: If250a57dfdd137ba25e265581a653dfa104323d3
This commit is contained in:
parent
c9ec28eebe
commit
f27f2de368
@ -734,6 +734,7 @@ cc_library(
|
||||
"transforms/materialize_mlir_passthrough_op.cc",
|
||||
"transforms/optimize.cc",
|
||||
"transforms/parallel_execute_to_islands.cc",
|
||||
"transforms/parallelize_embedding_params_ops_pass.cc",
|
||||
"transforms/promote_resources_to_args.cc",
|
||||
"transforms/readonly_references_to_resources.cc",
|
||||
"transforms/region_control_flow_to_functional.cc",
|
||||
@ -808,6 +809,7 @@ cc_library(
|
||||
"//tensorflow/core/platform:random",
|
||||
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
|
||||
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc",
|
||||
"//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
|
@ -0,0 +1,96 @@
|
||||
// RUN: tf-opt %s -tf-parallize-embedding-params-ops -verify-diagnostics -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @two_shards
|
||||
func @two_shards(%arg0: tensor<*x!tf.resource<tensor<8xf32>>>, %arg1: tensor<*x!tf.resource<tensor<8xf32>>>, %arg2: tensor<*x!tf.resource<tensor<8xf32>>>, %arg3: tensor<*x!tf.resource<tensor<8xf32>>>) {
|
||||
tf_executor.graph {
|
||||
%control = tf_executor.island {
|
||||
// CHECK: "tf_device.parallel_execute"
|
||||
// CHECK: "tf.ReadVariableOp"
|
||||
// CHECK: "tf.ReadVariableOp"
|
||||
// CHECK: "tf.LoadTPUEmbeddingAdagradParameters"
|
||||
// CHECK: tf_device.return
|
||||
// CHECK: "tf.ReadVariableOp"
|
||||
// CHECK: "tf.ReadVariableOp"
|
||||
// CHECK: "tf.LoadTPUEmbeddingAdagradParameters"
|
||||
// CHECK: tf_device.return
|
||||
%0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
|
||||
%1 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
|
||||
%2 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
|
||||
%3 = "tf.ReadVariableOp"(%arg3) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
|
||||
"tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:0/device:CPU:0", num_shards = 2 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> ()
|
||||
"tf.LoadTPUEmbeddingAdagradParameters"(%2, %3) {config = "", device = "/job:worker/replica:0/task:1/device:CPU:0", num_shards = 2 : i64, shard_id = 1 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> ()
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch %control : !tf_executor.control
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Verifies that resource reads shared across two shards are kept outside the
|
||||
// parallel_execute op.
|
||||
|
||||
// CHECK-LABEL: func @shared_reads
|
||||
func @shared_reads(%arg0: tensor<*x!tf.resource<tensor<8xf32>>>, %arg1: tensor<*x!tf.resource<tensor<8xf32>>>) {
|
||||
tf_executor.graph {
|
||||
%control = tf_executor.island {
|
||||
// CHECK: "tf.ReadVariableOp"
|
||||
%0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
|
||||
// CHECK: "tf.ReadVariableOp"
|
||||
%1 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
|
||||
|
||||
// CHECK: "tf_device.parallel_execute"
|
||||
// CHECK: "tf.LoadTPUEmbeddingAdagradParameters"
|
||||
// CHECK: tf_device.return
|
||||
// CHECK: "tf.LoadTPUEmbeddingAdagradParameters"
|
||||
// CHECK: tf_device.return
|
||||
"tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:0/device:CPU:0", num_shards = 2 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> ()
|
||||
"tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:1/device:CPU:0", num_shards = 2 : i64, shard_id = 1 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> ()
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch %control : !tf_executor.control
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Verifies that if the resource variables are used in ops other than read
|
||||
// variable op whose semantics are not known then the function is kept
|
||||
// unchanged.
|
||||
|
||||
// CHECK-LABEL: func @update_var
|
||||
func @update_var(%arg0: tensor<*x!tf.resource<tensor<8xf32>>>, %arg1: tensor<*x!tf.resource<tensor<8xf32>>>, %arg2: tensor<*x!tf.resource<tensor<8xf32>>>) {
|
||||
tf_executor.graph {
|
||||
// CHECK-NOT: tf_device.parallel_execute
|
||||
%control = tf_executor.island {
|
||||
%0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
|
||||
%1 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
|
||||
"tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:0/device:CPU:0", num_shards = 2 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> ()
|
||||
|
||||
%2 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
|
||||
%zeros = "tf.Const"() {value = dense<1.0> : tensor<8xf32>} : () -> tensor<8xf32>
|
||||
"tf.AssignVariableOp"(%arg2, %zeros) : (tensor<*x!tf.resource<tensor<8xf32>>>, tensor<8xf32>) -> ()
|
||||
%3 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
|
||||
"tf.LoadTPUEmbeddingAdagradParameters"(%2, %3) {config = "", device = "/job:worker/replica:0/task:1/device:CPU:0", num_shards = 2 : i64, shard_id = 1 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> ()
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch %control : !tf_executor.control
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @invalid_shard_range(%arg0: tensor<*x!tf.resource<tensor<8xf32>>>, %arg1: tensor<*x!tf.resource<tensor<8xf32>>>) {
|
||||
tf_executor.graph {
|
||||
%control = tf_executor.island {
|
||||
// expected-error @-1 {{require continuous range of shards}}
|
||||
%0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
|
||||
%1 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
|
||||
|
||||
"tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:0/device:CPU:0", num_shards = 3 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> ()
|
||||
"tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:1/device:CPU:0", num_shards = 3 : i64, shard_id = 3 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> ()
|
||||
tf_executor.yield
|
||||
}
|
||||
tf_executor.fetch %control : !tf_executor.control
|
||||
}
|
||||
return
|
||||
}
|
@ -41,6 +41,7 @@ namespace TFTPU {
|
||||
namespace {
|
||||
void AddGraphExportLoweringPasses(OpPassManager &pm) {
|
||||
pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
|
||||
pm.addNestedPass<FuncOp>(TFDevice::CreateParallelizeEmbeddingParamsOpsPass());
|
||||
pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
|
||||
pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateToIslandPass());
|
||||
pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
|
||||
|
@ -0,0 +1,152 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation parallelizes TPU embedding params assigned to different
|
||||
// shards using the parallel execute op. This is useful to avoid introducing
|
||||
// control dependency between these ops that are known to be independent.
|
||||
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Block.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFDevice {
|
||||
|
||||
namespace {
|
||||
|
||||
struct ParallelizeEmbeddingParamsOpsPass
|
||||
: public PassWrapper<ParallelizeEmbeddingParamsOpsPass, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
bool IsLoadTPUEmbeddingParmasOp(Operation& op) {
|
||||
static const auto* algorithms = []() {
|
||||
auto* algorithms = new llvm::SmallSet<std::string, 16>();
|
||||
for (tensorflow::tpu::OptimizationAlgorithm alg :
|
||||
tensorflow::tpu::GetOptimizationAlgorithms()) {
|
||||
const auto alg_name = tensorflow::tpu::GetOptimizationAlgorithmName(alg);
|
||||
algorithms->insert(alg_name);
|
||||
}
|
||||
return algorithms;
|
||||
}();
|
||||
StringRef op_name = op.getName().getStringRef();
|
||||
return op_name.consume_front("tf.LoadTPUEmbedding") &&
|
||||
op_name.consume_back("Parameters") &&
|
||||
algorithms->contains(op_name.str());
|
||||
}
|
||||
|
||||
static LogicalResult RunOnIsland(tf_executor::IslandOp island) {
|
||||
Block* block = island.getBody();
|
||||
|
||||
// Map from op to the id of the shard it is assigned for ops that can execute
|
||||
// in parallel across shards.
|
||||
llvm::SmallMapVector<Operation*, int64_t, 4> assigned_shard;
|
||||
llvm::SmallVector<Value, 8> resources;
|
||||
llvm::SmallSet<int64_t, 16> shard_ids;
|
||||
for (Operation& op : llvm::reverse(*block)) {
|
||||
int64_t shard = -1;
|
||||
if (IsLoadTPUEmbeddingParmasOp(op)) {
|
||||
auto shard_id = op.getAttrOfType<mlir::IntegerAttr>("shard_id");
|
||||
if (!shard_id) {
|
||||
return op.emitOpError("requires 'shard_id' integer attribute");
|
||||
}
|
||||
shard = shard_id.getInt();
|
||||
shard_ids.insert(shard);
|
||||
} else if (auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(op)) {
|
||||
if (assigned_shard.empty()) continue;
|
||||
|
||||
for (Operation* user : op.getUsers()) {
|
||||
auto iter = assigned_shard.find(user);
|
||||
if (iter == assigned_shard.end() ||
|
||||
(shard != -1 && shard != iter->second)) {
|
||||
shard = -1;
|
||||
break;
|
||||
}
|
||||
shard = iter->second;
|
||||
}
|
||||
if (shard != -1) resources.push_back(read_op.resource());
|
||||
}
|
||||
|
||||
if (shard != -1) assigned_shard.insert(std::make_pair(&op, shard));
|
||||
}
|
||||
|
||||
// No transformations are required.
|
||||
int num_shards = shard_ids.size();
|
||||
if (num_shards <= 1) return success();
|
||||
|
||||
// If the resources are used for ops other than read variable op, then moving
|
||||
// read variable ops to the parallel_execute may not preserve the semantics.
|
||||
for (Value resource : resources) {
|
||||
for (Operation* user : resource.getUsers())
|
||||
if (!llvm::isa<TF::ReadVariableOp>(*user)) return success();
|
||||
}
|
||||
|
||||
// Create parallel_execute op at the end of the block and move operations
|
||||
// to their corresponding shard.
|
||||
auto builder = OpBuilder::atBlockTerminator(block);
|
||||
auto parallel_execute_op = builder.create<tf_device::ParallelExecuteOp>(
|
||||
island.getLoc(), num_shards, llvm::ArrayRef<Type>());
|
||||
for (int shard_id = 0; shard_id < num_shards; ++shard_id) {
|
||||
mlir::Block& b = parallel_execute_op.GetRegionBlockWithIndex(shard_id);
|
||||
builder.setInsertionPointToStart(&b);
|
||||
builder.create<tf_device::ReturnOp>(island.getLoc());
|
||||
}
|
||||
|
||||
for (auto op_shard : assigned_shard) {
|
||||
int64_t shard = op_shard.second;
|
||||
if (shard >= num_shards) {
|
||||
return island.emitOpError(
|
||||
"load tpu embedding ops require continuous range of shards");
|
||||
}
|
||||
mlir::Block& b = parallel_execute_op.GetRegionBlockWithIndex(shard);
|
||||
op_shard.first->moveBefore(&b, b.begin());
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void ParallelizeEmbeddingParamsOpsPass::runOnFunction() {
|
||||
getFunction().walk([&](tf_executor::IslandOp island) {
|
||||
if (failed(RunOnIsland(island))) {
|
||||
signalPassFailure();
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateParallelizeEmbeddingParamsOpsPass() {
|
||||
return std::make_unique<ParallelizeEmbeddingParamsOpsPass>();
|
||||
}
|
||||
} // namespace TFDevice
|
||||
} // namespace mlir
|
||||
|
||||
static mlir::PassRegistration<mlir::TFDevice::ParallelizeEmbeddingParamsOpsPass>
|
||||
pass("tf-parallize-embedding-params-ops",
|
||||
"Parallelizes TPU embedding params assigned to different shards using "
|
||||
"the parallel_execte op");
|
@ -242,6 +242,11 @@ std::unique_ptr<OperationPass<FuncOp>> CreateReplicateToIslandPass();
|
||||
// `tf_device.parallel_execute` island.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateParallelExecuteToIslandsPass();
|
||||
|
||||
// Create a pass to parallelize TPU embedding params assigned to different
|
||||
// shards using the parallel_execte op.
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateParallelizeEmbeddingParamsOpsPass();
|
||||
|
||||
// Creates a pass that annotates whether a LaunchFuncOp's parameters have the
|
||||
// same data across replicas.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
|
Loading…
x
Reference in New Issue
Block a user