Add a pass to parallelize TPU embedding params ops on different shards

This pass moves LoadTPUEmbedding* and corresponding ReadVariable ops to different regions using the parallel_execute op. This parallel_execute op is later broken in different islands by ParallelExecuteToIslands pass and these islands can progress in parallel.

This pass is required to avoid control dependencies between ops on different shards during export to the tf_executor dialect.

Also, added this pass to the pass pipeline.

PiperOrigin-RevId: 323426936
Change-Id: If250a57dfdd137ba25e265581a653dfa104323d3
This commit is contained in:
Smit Hinsu 2020-07-27 13:19:53 -07:00 committed by TensorFlower Gardener
parent c9ec28eebe
commit f27f2de368
5 changed files with 256 additions and 0 deletions

View File

@ -734,6 +734,7 @@ cc_library(
"transforms/materialize_mlir_passthrough_op.cc",
"transforms/optimize.cc",
"transforms/parallel_execute_to_islands.cc",
"transforms/parallelize_embedding_params_ops_pass.cc",
"transforms/promote_resources_to_args.cc",
"transforms/readonly_references_to_resources.cc",
"transforms/region_control_flow_to_functional.cc",
@ -808,6 +809,7 @@ cc_library(
"//tensorflow/core/platform:random",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc",
"//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",

View File

@ -0,0 +1,96 @@
// RUN: tf-opt %s -tf-parallize-embedding-params-ops -verify-diagnostics -split-input-file | FileCheck %s
// CHECK-LABEL: func @two_shards
func @two_shards(%arg0: tensor<*x!tf.resource<tensor<8xf32>>>, %arg1: tensor<*x!tf.resource<tensor<8xf32>>>, %arg2: tensor<*x!tf.resource<tensor<8xf32>>>, %arg3: tensor<*x!tf.resource<tensor<8xf32>>>) {
tf_executor.graph {
%control = tf_executor.island {
// CHECK: "tf_device.parallel_execute"
// CHECK: "tf.ReadVariableOp"
// CHECK: "tf.ReadVariableOp"
// CHECK: "tf.LoadTPUEmbeddingAdagradParameters"
// CHECK: tf_device.return
// CHECK: "tf.ReadVariableOp"
// CHECK: "tf.ReadVariableOp"
// CHECK: "tf.LoadTPUEmbeddingAdagradParameters"
// CHECK: tf_device.return
%0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
%1 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
%2 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
%3 = "tf.ReadVariableOp"(%arg3) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
"tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:0/device:CPU:0", num_shards = 2 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> ()
"tf.LoadTPUEmbeddingAdagradParameters"(%2, %3) {config = "", device = "/job:worker/replica:0/task:1/device:CPU:0", num_shards = 2 : i64, shard_id = 1 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> ()
tf_executor.yield
}
tf_executor.fetch %control : !tf_executor.control
}
return
}
// Verifies that resource reads shared across two shards are kept outside the
// parallel_execute op.
// CHECK-LABEL: func @shared_reads
func @shared_reads(%arg0: tensor<*x!tf.resource<tensor<8xf32>>>, %arg1: tensor<*x!tf.resource<tensor<8xf32>>>) {
tf_executor.graph {
%control = tf_executor.island {
// CHECK: "tf.ReadVariableOp"
%0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
// CHECK: "tf.ReadVariableOp"
%1 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
// CHECK: "tf_device.parallel_execute"
// CHECK: "tf.LoadTPUEmbeddingAdagradParameters"
// CHECK: tf_device.return
// CHECK: "tf.LoadTPUEmbeddingAdagradParameters"
// CHECK: tf_device.return
"tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:0/device:CPU:0", num_shards = 2 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> ()
"tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:1/device:CPU:0", num_shards = 2 : i64, shard_id = 1 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> ()
tf_executor.yield
}
tf_executor.fetch %control : !tf_executor.control
}
return
}
// Verifies that if the resource variables are used in ops other than read
// variable op whose semantics are not known then the function is kept
// unchanged.
// CHECK-LABEL: func @update_var
func @update_var(%arg0: tensor<*x!tf.resource<tensor<8xf32>>>, %arg1: tensor<*x!tf.resource<tensor<8xf32>>>, %arg2: tensor<*x!tf.resource<tensor<8xf32>>>) {
tf_executor.graph {
// CHECK-NOT: tf_device.parallel_execute
%control = tf_executor.island {
%0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
%1 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
"tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:0/device:CPU:0", num_shards = 2 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> ()
%2 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
%zeros = "tf.Const"() {value = dense<1.0> : tensor<8xf32>} : () -> tensor<8xf32>
"tf.AssignVariableOp"(%arg2, %zeros) : (tensor<*x!tf.resource<tensor<8xf32>>>, tensor<8xf32>) -> ()
%3 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
"tf.LoadTPUEmbeddingAdagradParameters"(%2, %3) {config = "", device = "/job:worker/replica:0/task:1/device:CPU:0", num_shards = 2 : i64, shard_id = 1 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> ()
tf_executor.yield
}
tf_executor.fetch %control : !tf_executor.control
}
return
}
// -----
func @invalid_shard_range(%arg0: tensor<*x!tf.resource<tensor<8xf32>>>, %arg1: tensor<*x!tf.resource<tensor<8xf32>>>) {
tf_executor.graph {
%control = tf_executor.island {
// expected-error @-1 {{require continuous range of shards}}
%0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
%1 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
"tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:0/device:CPU:0", num_shards = 3 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> ()
"tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:1/device:CPU:0", num_shards = 3 : i64, shard_id = 3 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> ()
tf_executor.yield
}
tf_executor.fetch %control : !tf_executor.control
}
return
}

View File

@ -41,6 +41,7 @@ namespace TFTPU {
namespace {
void AddGraphExportLoweringPasses(OpPassManager &pm) {
pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass());
pm.addNestedPass<FuncOp>(TFDevice::CreateParallelizeEmbeddingParamsOpsPass());
pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateToIslandPass());
pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());

View File

@ -0,0 +1,152 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation parallelizes TPU embedding params assigned to different
// shards using the parallel execute op. This is useful to avoid introducing
// control dependency between these ops that are known to be independent.
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"
namespace mlir {
namespace TFDevice {
namespace {
struct ParallelizeEmbeddingParamsOpsPass
: public PassWrapper<ParallelizeEmbeddingParamsOpsPass, FunctionPass> {
void runOnFunction() override;
};
bool IsLoadTPUEmbeddingParmasOp(Operation& op) {
static const auto* algorithms = []() {
auto* algorithms = new llvm::SmallSet<std::string, 16>();
for (tensorflow::tpu::OptimizationAlgorithm alg :
tensorflow::tpu::GetOptimizationAlgorithms()) {
const auto alg_name = tensorflow::tpu::GetOptimizationAlgorithmName(alg);
algorithms->insert(alg_name);
}
return algorithms;
}();
StringRef op_name = op.getName().getStringRef();
return op_name.consume_front("tf.LoadTPUEmbedding") &&
op_name.consume_back("Parameters") &&
algorithms->contains(op_name.str());
}
static LogicalResult RunOnIsland(tf_executor::IslandOp island) {
Block* block = island.getBody();
// Map from op to the id of the shard it is assigned for ops that can execute
// in parallel across shards.
llvm::SmallMapVector<Operation*, int64_t, 4> assigned_shard;
llvm::SmallVector<Value, 8> resources;
llvm::SmallSet<int64_t, 16> shard_ids;
for (Operation& op : llvm::reverse(*block)) {
int64_t shard = -1;
if (IsLoadTPUEmbeddingParmasOp(op)) {
auto shard_id = op.getAttrOfType<mlir::IntegerAttr>("shard_id");
if (!shard_id) {
return op.emitOpError("requires 'shard_id' integer attribute");
}
shard = shard_id.getInt();
shard_ids.insert(shard);
} else if (auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(op)) {
if (assigned_shard.empty()) continue;
for (Operation* user : op.getUsers()) {
auto iter = assigned_shard.find(user);
if (iter == assigned_shard.end() ||
(shard != -1 && shard != iter->second)) {
shard = -1;
break;
}
shard = iter->second;
}
if (shard != -1) resources.push_back(read_op.resource());
}
if (shard != -1) assigned_shard.insert(std::make_pair(&op, shard));
}
// No transformations are required.
int num_shards = shard_ids.size();
if (num_shards <= 1) return success();
// If the resources are used for ops other than read variable op, then moving
// read variable ops to the parallel_execute may not preserve the semantics.
for (Value resource : resources) {
for (Operation* user : resource.getUsers())
if (!llvm::isa<TF::ReadVariableOp>(*user)) return success();
}
// Create parallel_execute op at the end of the block and move operations
// to their corresponding shard.
auto builder = OpBuilder::atBlockTerminator(block);
auto parallel_execute_op = builder.create<tf_device::ParallelExecuteOp>(
island.getLoc(), num_shards, llvm::ArrayRef<Type>());
for (int shard_id = 0; shard_id < num_shards; ++shard_id) {
mlir::Block& b = parallel_execute_op.GetRegionBlockWithIndex(shard_id);
builder.setInsertionPointToStart(&b);
builder.create<tf_device::ReturnOp>(island.getLoc());
}
for (auto op_shard : assigned_shard) {
int64_t shard = op_shard.second;
if (shard >= num_shards) {
return island.emitOpError(
"load tpu embedding ops require continuous range of shards");
}
mlir::Block& b = parallel_execute_op.GetRegionBlockWithIndex(shard);
op_shard.first->moveBefore(&b, b.begin());
}
return success();
}
void ParallelizeEmbeddingParamsOpsPass::runOnFunction() {
getFunction().walk([&](tf_executor::IslandOp island) {
if (failed(RunOnIsland(island))) {
signalPassFailure();
return WalkResult::interrupt();
}
return WalkResult::advance();
});
}
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
CreateParallelizeEmbeddingParamsOpsPass() {
return std::make_unique<ParallelizeEmbeddingParamsOpsPass>();
}
} // namespace TFDevice
} // namespace mlir
static mlir::PassRegistration<mlir::TFDevice::ParallelizeEmbeddingParamsOpsPass>
pass("tf-parallize-embedding-params-ops",
"Parallelizes TPU embedding params assigned to different shards using "
"the parallel_execte op");

View File

@ -242,6 +242,11 @@ std::unique_ptr<OperationPass<FuncOp>> CreateReplicateToIslandPass();
// `tf_device.parallel_execute` island.
std::unique_ptr<OperationPass<FuncOp>> CreateParallelExecuteToIslandsPass();
// Create a pass to parallelize TPU embedding params assigned to different
// shards using the parallel_execte op.
std::unique_ptr<OperationPass<FuncOp>>
CreateParallelizeEmbeddingParamsOpsPass();
// Creates a pass that annotates whether a LaunchFuncOp's parameters have the
// same data across replicas.
std::unique_ptr<OperationPass<ModuleOp>>