Update TPUUpdateEmbeddingEnqueueInput pass to create mode constants rather than
depend on inputs from SelectV2 op.

SelectV2 op may be constant folded away if the conditional value is from a
const op. As such, create the mode constant ("train" or "inference") based on the
presence of a gradient op.

PiperOrigin-RevId: 322151534
Change-Id: I6dbbafe2173af0270e95237fcc30b0e818cbf3ba
This commit is contained in:
A. Unique TensorFlower 2020-07-20 08:43:51 -07:00 committed by TensorFlower Gardener
parent 7feafa7d70
commit cb1119ba71
2 changed files with 39 additions and 27 deletions

View File

@@ -9,16 +9,15 @@
// CHECK-SAME: %[[ARG_5:[a-z0-9]*]]: tensor<?xi32>
// CHECK-SAME: %[[ARG_6:[a-z0-9]*]]: tensor<!tf.string>
// CHECK-SAME: %[[ARG_7:[a-z0-9]*]]: tensor<!tf.string>
// CHECK-SAME: %[[ARG_8:[a-z0-9]*]]: tensor<i1>
func @check_enqueue_ops_update_for_eval(%arg0: tensor<?x2xi32>, %arg1: tensor<?x2xi32>,
%arg2 :tensor<?x2xi32>, %arg3: tensor<?xi32>, %arg4: tensor<?xi32>, %arg5: tensor<?xi32>,
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>, %arg8: tensor<i1>) -> () {
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>) -> () {
// CHECK: %[[CONST_0:[a-z0-9]*]] = "tf.Const"()
%0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
%1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>) -> tensor<!tf.string>
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[ARG_7]])
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
// CHECK: %[[CONST_MODE:[a-z0-9]*]] = "tf.Const"() {value = dense<"inference"> : tensor<!tf.string>} : () -> tensor<!tf.string>
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[CONST_MODE]])
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %arg7) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
%2:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>)
return
}
@@ -34,20 +33,19 @@ func @check_enqueue_ops_update_for_eval(%arg0: tensor<?x2xi32>, %arg1: tensor<?x
// CHECK-SAME: %[[ARG_5:[a-z0-9]*]]: tensor<?xi32>
// CHECK-SAME: %[[ARG_6:[a-z0-9]*]]: tensor<!tf.string>
// CHECK-SAME: %[[ARG_7:[a-z0-9]*]]: tensor<!tf.string>
// CHECK-SAME: %[[ARG_8:[a-z0-9]*]]: tensor<i1>
func @check_enqueue_ops_update_for_training(%arg0: tensor<?x2xi32>, %arg1: tensor<?x2xi32>,
%arg2 :tensor<?x2xi32>, %arg3: tensor<?xi32>, %arg4: tensor<?xi32>, %arg5: tensor<?xi32>,
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>, %arg8: tensor<i1>) -> () {
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>) -> () {
// CHECK: %[[CONST_0:[a-z0-9]*]] = "tf.Const"()
%0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
%1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>) -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<0.0> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
%3 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
"tf.SendTPUEmbeddingGradients"(%2, %3) {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D", operand_segment_sizes = dense<[2, 0]> : vector<2xi32>} : (tensor<2x2xf32>, tensor<4x4xf32>) -> ()
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[ARG_6]])
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
// CHECK: %[[CONST_MODE:[a-z0-9]*]] = "tf.Const"() {value = dense<"train"> : tensor<!tf.string>} : () -> tensor<!tf.string>
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[CONST_MODE]])
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %arg7) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
%4:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>)
return
}

View File

@@ -13,18 +13,22 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
namespace mlir {
@@ -86,7 +90,8 @@ LogicalResult FindTPUEmbeddingOps(
LogicalResult UpdateEmbeddingEnqueueOpInput(
const llvm::StringMap<Operation*>& enqueue_op_map,
const llvm::StringMap<Operation*>& recv_activation_op_map,
const llvm::StringMap<Operation*>& send_gradient_op_map) {
const llvm::StringMap<Operation*>& send_gradient_op_map,
OpBuilder* builder) {
for (const auto& it : enqueue_op_map) {
const auto& embedding_attr = it.getKey();
Operation* embedding_op = it.second;
@@ -96,21 +101,29 @@ LogicalResult UpdateEmbeddingEnqueueOpInput(
<< TF::RecvTPUEmbeddingActivationsOp::getOperationName() << "' op";
// TPU Embedding enqueue ops take different inputs depending on whether
// graph is in training mode or in eval/prediction mode. The inputs to the
// enqueue ops are present/listed as operands to SelectV2 op. Then branch
// operand of the SelectV2 op represents input to take during training
// and else branch operand represents input to take during
// prediction/evaluation. If SendTPUEmbeddingGradients op exists in the
// graph, then graph is in training mode, so correctly forward the input
// of SelectV2 op as operand to the TPU embedding enqueue op.
// graph is in training mode or in eval/prediction mode. During training,
// the mode parameter for TPUEmbeddingEnqueue op must be `train` and for
// evaluation or prediction, mode must be set to `inference`.
// If SendTPUEmbeddingGradients op exists in the graph, then graph is
// in training mode, so create a const op with value `train` use the
// output value of the constant as an operand to the TPU embedding
// enqueue op.
bool is_training = send_gradient_op_map.count(embedding_attr);
for (auto enqueue_operand : embedding_op->getOperands()) {
if (auto select = llvm::dyn_cast_or_null<TF::SelectV2Op>(
enqueue_operand.getDefiningOp())) {
enqueue_operand.replaceAllUsesWith(is_training ? select.t()
: select.e());
}
}
// The last operand of TPUEmbeddingEnqueue ops is the mode which
// represents whether graph is in training mode or in evaluation mode.
auto& mode_enqueue_operand =
embedding_op->getOpOperand(embedding_op->getNumOperands() - 1);
llvm::SmallVector<StringRef, 1> mode_string_value;
mode_string_value.emplace_back(is_training ? "train" : "inference");
builder->setInsertionPoint(embedding_op);
auto enqueue_mode = builder->create<TF::ConstOp>(
embedding_op->getLoc(),
DenseStringElementsAttr::get(
RankedTensorType::get({}, builder->getType<TF::StringType>()),
mode_string_value));
mode_enqueue_operand.set(enqueue_mode);
}
return success();
@@ -140,8 +153,9 @@ void TPUUpdateEmbeddingEnqueueOpInputs::runOnFunction() {
return signalPassFailure();
}
if (failed(UpdateEmbeddingEnqueueOpInput(
enqueue_op_map, recv_activation_op_map, send_gradient_op_map)))
if (failed(UpdateEmbeddingEnqueueOpInput(enqueue_op_map,
recv_activation_op_map,
send_gradient_op_map, &builder)))
return signalPassFailure();
}