Update TPUUpdateEmbeddingEnqueueInput pass to create mode constants rather than
depend on inputs from the SelectV2 op. The SelectV2 op may be constant folded away if the conditional value comes from a const op. As such, create a mode constant ("train" or "inference") based on the presence of a gradient op. PiperOrigin-RevId: 322151534 Change-Id: I6dbbafe2173af0270e95237fcc30b0e818cbf3ba
This commit is contained in:
parent
7feafa7d70
commit
cb1119ba71
@ -9,16 +9,15 @@
|
||||
// CHECK-SAME: %[[ARG_5:[a-z0-9]*]]: tensor<?xi32>
|
||||
// CHECK-SAME: %[[ARG_6:[a-z0-9]*]]: tensor<!tf.string>
|
||||
// CHECK-SAME: %[[ARG_7:[a-z0-9]*]]: tensor<!tf.string>
|
||||
// CHECK-SAME: %[[ARG_8:[a-z0-9]*]]: tensor<i1>
|
||||
func @check_enqueue_ops_update_for_eval(%arg0: tensor<?x2xi32>, %arg1: tensor<?x2xi32>,
|
||||
%arg2 :tensor<?x2xi32>, %arg3: tensor<?xi32>, %arg4: tensor<?xi32>, %arg5: tensor<?xi32>,
|
||||
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>, %arg8: tensor<i1>) -> () {
|
||||
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>) -> () {
|
||||
// CHECK: %[[CONST_0:[a-z0-9]*]] = "tf.Const"()
|
||||
%0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
|
||||
%1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>) -> tensor<!tf.string>
|
||||
|
||||
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[ARG_7]])
|
||||
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
|
||||
// CHECK: %[[CONST_MODE:[a-z0-9]*]] = "tf.Const"() {value = dense<"inference"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[CONST_MODE]])
|
||||
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %arg7) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
|
||||
%2:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>)
|
||||
return
|
||||
}
|
||||
@ -34,20 +33,19 @@ func @check_enqueue_ops_update_for_eval(%arg0: tensor<?x2xi32>, %arg1: tensor<?x
|
||||
// CHECK-SAME: %[[ARG_5:[a-z0-9]*]]: tensor<?xi32>
|
||||
// CHECK-SAME: %[[ARG_6:[a-z0-9]*]]: tensor<!tf.string>
|
||||
// CHECK-SAME: %[[ARG_7:[a-z0-9]*]]: tensor<!tf.string>
|
||||
// CHECK-SAME: %[[ARG_8:[a-z0-9]*]]: tensor<i1>
|
||||
func @check_enqueue_ops_update_for_training(%arg0: tensor<?x2xi32>, %arg1: tensor<?x2xi32>,
|
||||
%arg2 :tensor<?x2xi32>, %arg3: tensor<?xi32>, %arg4: tensor<?xi32>, %arg5: tensor<?xi32>,
|
||||
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>, %arg8: tensor<i1>) -> () {
|
||||
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>) -> () {
|
||||
// CHECK: %[[CONST_0:[a-z0-9]*]] = "tf.Const"()
|
||||
%0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
|
||||
%1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>) -> tensor<!tf.string>
|
||||
|
||||
%2 = "tf.Const"() {value = dense<0.0> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
|
||||
%3 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
|
||||
"tf.SendTPUEmbeddingGradients"(%2, %3) {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D", operand_segment_sizes = dense<[2, 0]> : vector<2xi32>} : (tensor<2x2xf32>, tensor<4x4xf32>) -> ()
|
||||
|
||||
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[ARG_6]])
|
||||
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
|
||||
// CHECK: %[[CONST_MODE:[a-z0-9]*]] = "tf.Const"() {value = dense<"train"> : tensor<!tf.string>} : () -> tensor<!tf.string>
|
||||
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[CONST_MODE]])
|
||||
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %arg7) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
|
||||
%4:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>)
|
||||
return
|
||||
}
|
||||
|
@ -13,18 +13,22 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Block.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -86,7 +90,8 @@ LogicalResult FindTPUEmbeddingOps(
|
||||
LogicalResult UpdateEmbeddingEnqueueOpInput(
|
||||
const llvm::StringMap<Operation*>& enqueue_op_map,
|
||||
const llvm::StringMap<Operation*>& recv_activation_op_map,
|
||||
const llvm::StringMap<Operation*>& send_gradient_op_map) {
|
||||
const llvm::StringMap<Operation*>& send_gradient_op_map,
|
||||
OpBuilder* builder) {
|
||||
for (const auto& it : enqueue_op_map) {
|
||||
const auto& embedding_attr = it.getKey();
|
||||
Operation* embedding_op = it.second;
|
||||
@ -96,21 +101,29 @@ LogicalResult UpdateEmbeddingEnqueueOpInput(
|
||||
<< TF::RecvTPUEmbeddingActivationsOp::getOperationName() << "' op";
|
||||
|
||||
// TPU Embedding enqueue ops take different inputs depending on whether
|
||||
// graph is in training mode or in eval/prediction mode. The inputs to the
|
||||
// enqueue ops are present/listed as operands to SelectV2 op. Then branch
|
||||
// operand of the SelectV2 op represents input to take during training
|
||||
// and else branch operand represents input to take during
|
||||
// prediction/evaluation. If SendTPUEmbeddingGradients op exists in the
|
||||
// graph, then graph is in training mode, so correctly forward the input
|
||||
// of SelectV2 op as operand to the TPU embedding enqueue op.
|
||||
// graph is in training mode or in eval/prediction mode. During training,
|
||||
// the mode parameter for TPUEmbeddingEnqueue op must be `train` and for
|
||||
// evaluation or prediction, mode must be set to `inference`.
|
||||
// If SendTPUEmbeddingGradients op exists in the graph, then graph is
|
||||
// in training mode, so create a const op with value `train` and use the
|
||||
// output value of the constant as an operand to the TPU embedding
|
||||
// enqueue op.
|
||||
bool is_training = send_gradient_op_map.count(embedding_attr);
|
||||
for (auto enqueue_operand : embedding_op->getOperands()) {
|
||||
if (auto select = llvm::dyn_cast_or_null<TF::SelectV2Op>(
|
||||
enqueue_operand.getDefiningOp())) {
|
||||
enqueue_operand.replaceAllUsesWith(is_training ? select.t()
|
||||
: select.e());
|
||||
}
|
||||
}
|
||||
|
||||
// The last operand of TPUEmbeddingEnqueue ops is the mode which
|
||||
// represents whether graph is in training mode or in evaluation mode.
|
||||
auto& mode_enqueue_operand =
|
||||
embedding_op->getOpOperand(embedding_op->getNumOperands() - 1);
|
||||
|
||||
llvm::SmallVector<StringRef, 1> mode_string_value;
|
||||
mode_string_value.emplace_back(is_training ? "train" : "inference");
|
||||
builder->setInsertionPoint(embedding_op);
|
||||
auto enqueue_mode = builder->create<TF::ConstOp>(
|
||||
embedding_op->getLoc(),
|
||||
DenseStringElementsAttr::get(
|
||||
RankedTensorType::get({}, builder->getType<TF::StringType>()),
|
||||
mode_string_value));
|
||||
mode_enqueue_operand.set(enqueue_mode);
|
||||
}
|
||||
|
||||
return success();
|
||||
@ -140,8 +153,9 @@ void TPUUpdateEmbeddingEnqueueOpInputs::runOnFunction() {
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
if (failed(UpdateEmbeddingEnqueueOpInput(
|
||||
enqueue_op_map, recv_activation_op_map, send_gradient_op_map)))
|
||||
if (failed(UpdateEmbeddingEnqueueOpInput(enqueue_op_map,
|
||||
recv_activation_op_map,
|
||||
send_gradient_op_map, &builder)))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user