From f27f2de368f56136035c3d3630616c8a82bc2b21 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Mon, 27 Jul 2020 13:19:53 -0700 Subject: [PATCH] Add a pass to parallelize TPU embedding params ops on different shards This pass moves LoadTPUEmbedding* and corresponding ReadVariable ops to different regions using the parallel_execute op. This parallel_execute op is later broken in different islands by ParallelExecuteToIslands pass and these islands can progress in parallel. This pass is required to avoid control dependencies between ops on different shards during export to the tf_executor dialect. Also, added this pass to the pass pipeline. PiperOrigin-RevId: 323426936 Change-Id: If250a57dfdd137ba25e265581a653dfa104323d3 --- tensorflow/compiler/mlir/tensorflow/BUILD | 2 + ...parallelize_embedding_params_ops_pass.mlir | 96 +++++++++++ .../mlir/tensorflow/transforms/bridge.cc | 1 + .../parallelize_embedding_params_ops_pass.cc | 152 ++++++++++++++++++ .../mlir/tensorflow/transforms/passes.h | 5 + 5 files changed, 256 insertions(+) create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/parallelize_embedding_params_ops_pass.mlir create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 26c47e580e8..2a800cfc8c4 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -734,6 +734,7 @@ cc_library( "transforms/materialize_mlir_passthrough_op.cc", "transforms/optimize.cc", "transforms/parallel_execute_to_islands.cc", + "transforms/parallelize_embedding_params_ops_pass.cc", "transforms/promote_resources_to_args.cc", "transforms/readonly_references_to_resources.cc", "transforms/region_control_flow_to_functional.cc", @@ -808,6 +809,7 @@ cc_library( "//tensorflow/core/platform:random", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc", + "//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", diff --git a/tensorflow/compiler/mlir/tensorflow/tests/parallelize_embedding_params_ops_pass.mlir b/tensorflow/compiler/mlir/tensorflow/tests/parallelize_embedding_params_ops_pass.mlir new file mode 100644 index 00000000000..e1cfaba5dcc --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/parallelize_embedding_params_ops_pass.mlir @@ -0,0 +1,96 @@ +// RUN: tf-opt %s -tf-parallize-embedding-params-ops -verify-diagnostics -split-input-file | FileCheck %s + +// CHECK-LABEL: func @two_shards +func @two_shards(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>, %arg2: tensor<*x!tf.resource>>, %arg3: tensor<*x!tf.resource>>) { + tf_executor.graph { + %control = tf_executor.island { + // CHECK: "tf_device.parallel_execute" + // CHECK: "tf.ReadVariableOp" + // CHECK: "tf.ReadVariableOp" + // CHECK: "tf.LoadTPUEmbeddingAdagradParameters" + // CHECK: tf_device.return + // CHECK: "tf.ReadVariableOp" + // CHECK: "tf.ReadVariableOp" + // CHECK: "tf.LoadTPUEmbeddingAdagradParameters" + // CHECK: tf_device.return + %0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + %1 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + %2 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + %3 = "tf.ReadVariableOp"(%arg3) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + "tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:0/device:CPU:0", num_shards = 2 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> () + "tf.LoadTPUEmbeddingAdagradParameters"(%2, %3) {config = "", device = "/job:worker/replica:0/task:1/device:CPU:0", num_shards = 2 : i64, shard_id = 1 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> () + tf_executor.yield + } + tf_executor.fetch %control : !tf_executor.control + } + return +} + +// Verifies that resource reads shared across two shards are kept outside the +// parallel_execute op. + +// CHECK-LABEL: func @shared_reads +func @shared_reads(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>) { + tf_executor.graph { + %control = tf_executor.island { + // CHECK: "tf.ReadVariableOp" + %0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + // CHECK: "tf.ReadVariableOp" + %1 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + + // CHECK: "tf_device.parallel_execute" + // CHECK: "tf.LoadTPUEmbeddingAdagradParameters" + // CHECK: tf_device.return + // CHECK: "tf.LoadTPUEmbeddingAdagradParameters" + // CHECK: tf_device.return + "tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:0/device:CPU:0", num_shards = 2 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> () + "tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:1/device:CPU:0", num_shards = 2 : i64, shard_id = 1 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> () + tf_executor.yield + } + tf_executor.fetch %control : !tf_executor.control + } + return +} + +// Verifies that if the resource variables are used in ops other than read +// variable op whose semantics are not known then the function is kept +// unchanged. + +// CHECK-LABEL: func @update_var +func @update_var(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>, %arg2: tensor<*x!tf.resource>>) { + tf_executor.graph { + // CHECK-NOT: tf_device.parallel_execute + %control = tf_executor.island { + %0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + %1 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + "tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:0/device:CPU:0", num_shards = 2 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> () + + %2 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + %zeros = "tf.Const"() {value = dense<1.0> : tensor<8xf32>} : () -> tensor<8xf32> + "tf.AssignVariableOp"(%arg2, %zeros) : (tensor<*x!tf.resource>>, tensor<8xf32>) -> () + %3 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + "tf.LoadTPUEmbeddingAdagradParameters"(%2, %3) {config = "", device = "/job:worker/replica:0/task:1/device:CPU:0", num_shards = 2 : i64, shard_id = 1 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> () + tf_executor.yield + } + tf_executor.fetch %control : !tf_executor.control + } + return +} + +// ----- + +func @invalid_shard_range(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>) { + tf_executor.graph { + %control = tf_executor.island { + // expected-error @-1 {{require continuous range of shards}} + %0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + %1 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + + "tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:0/device:CPU:0", num_shards = 3 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> () + "tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:1/device:CPU:0", num_shards = 3 : i64, shard_id = 3 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> () + tf_executor.yield + } + tf_executor.fetch %control : !tf_executor.control + } + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index 35ffabb9131..783664960bc 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -41,6 +41,7 @@ namespace TFTPU { namespace { void AddGraphExportLoweringPasses(OpPassManager &pm) { pm.addNestedPass(CreateFunctionalToExecutorDialectConversionPass()); + pm.addNestedPass(TFDevice::CreateParallelizeEmbeddingParamsOpsPass()); pm.addNestedPass(CreateBreakUpIslandsPass()); pm.addNestedPass(TFDevice::CreateReplicateToIslandPass()); pm.addNestedPass(CreateBreakUpIslandsPass()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc new file mode 100644 index 00000000000..527af0934ea --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc @@ -0,0 +1,152 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This transformation parallelizes TPU embedding params assigned to different +// shards using the parallel execute op. This is useful to avoid introducing +// control dependency between these ops that are known to be independent. + +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h" + +namespace mlir { +namespace TFDevice { + +namespace { + +struct ParallelizeEmbeddingParamsOpsPass + : public PassWrapper { + void runOnFunction() override; +}; + +bool IsLoadTPUEmbeddingParmasOp(Operation& op) { + static const auto* algorithms = []() { + auto* algorithms = new llvm::SmallSet(); + for (tensorflow::tpu::OptimizationAlgorithm alg : + tensorflow::tpu::GetOptimizationAlgorithms()) { + const auto alg_name = tensorflow::tpu::GetOptimizationAlgorithmName(alg); + algorithms->insert(alg_name); + } + return algorithms; + }(); + StringRef op_name = op.getName().getStringRef(); + return op_name.consume_front("tf.LoadTPUEmbedding") && + op_name.consume_back("Parameters") && + algorithms->contains(op_name.str()); +} + +static LogicalResult RunOnIsland(tf_executor::IslandOp island) { + Block* block = island.getBody(); + + // Map from op to the id of the shard it is assigned for ops that can execute + // in parallel across shards. + llvm::SmallMapVector assigned_shard; + llvm::SmallVector resources; + llvm::SmallSet shard_ids; + for (Operation& op : llvm::reverse(*block)) { + int64_t shard = -1; + if (IsLoadTPUEmbeddingParmasOp(op)) { + auto shard_id = op.getAttrOfType("shard_id"); + if (!shard_id) { + return op.emitOpError("requires 'shard_id' integer attribute"); + } + shard = shard_id.getInt(); + shard_ids.insert(shard); + } else if (auto read_op = llvm::dyn_cast(op)) { + if (assigned_shard.empty()) continue; + + for (Operation* user : op.getUsers()) { + auto iter = assigned_shard.find(user); + if (iter == assigned_shard.end() || + (shard != -1 && shard != iter->second)) { + shard = -1; + break; + } + shard = iter->second; + } + if (shard != -1) resources.push_back(read_op.resource()); + } + + if (shard != -1) assigned_shard.insert(std::make_pair(&op, shard)); + } + + // No transformations are required. + int num_shards = shard_ids.size(); + if (num_shards <= 1) return success(); + + // If the resources are used for ops other than read variable op, then moving + // read variable ops to the parallel_execute may not preserve the semantics. + for (Value resource : resources) { + for (Operation* user : resource.getUsers()) + if (!llvm::isa(*user)) return success(); + } + + // Create parallel_execute op at the end of the block and move operations + // to their corresponding shard. + auto builder = OpBuilder::atBlockTerminator(block); + auto parallel_execute_op = builder.create( + island.getLoc(), num_shards, llvm::ArrayRef()); + for (int shard_id = 0; shard_id < num_shards; ++shard_id) { + mlir::Block& b = parallel_execute_op.GetRegionBlockWithIndex(shard_id); + builder.setInsertionPointToStart(&b); + builder.create(island.getLoc()); + } + + for (auto op_shard : assigned_shard) { + int64_t shard = op_shard.second; + if (shard >= num_shards) { + return island.emitOpError( + "load tpu embedding ops require continuous range of shards"); + } + mlir::Block& b = parallel_execute_op.GetRegionBlockWithIndex(shard); + op_shard.first->moveBefore(&b, b.begin()); + } + return success(); +} + +void ParallelizeEmbeddingParamsOpsPass::runOnFunction() { + getFunction().walk([&](tf_executor::IslandOp island) { + if (failed(RunOnIsland(island))) { + signalPassFailure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); +} + +} // namespace + +std::unique_ptr> +CreateParallelizeEmbeddingParamsOpsPass() { + return std::make_unique(); +} +} // namespace TFDevice +} // namespace mlir + +static mlir::PassRegistration + pass("tf-parallize-embedding-params-ops", + "Parallelizes TPU embedding params assigned to different shards using " + "the parallel_execte op"); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 68bc9d09e91..9c8790afa1d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -242,6 +242,11 @@ std::unique_ptr> CreateReplicateToIslandPass(); // `tf_device.parallel_execute` island. std::unique_ptr> CreateParallelExecuteToIslandsPass(); +// Create a pass to parallelize TPU embedding params assigned to different +// shards using the parallel_execte op. +std::unique_ptr> +CreateParallelizeEmbeddingParamsOpsPass(); + // Creates a pass that annotates whether a LaunchFuncOp's parameters have the // same data across replicas. std::unique_ptr>