From a3672fefaeca40ab0f05a1eded1349b26ee100f7 Mon Sep 17 00:00:00 2001
From: Karim Nosir
Date: Mon, 5 Aug 2019 11:50:50 -0700
Subject: [PATCH] Add legalization for TF.StopGradient to TF.Identity

PiperOrigin-RevId: 261734406
---
 .../compiler/mlir/lite/tests/prepare-tf.mlir  |  8 +++++
 .../mlir/lite/transforms/prepare_patterns.td  |  2 ++
 .../mlir/tensorflow/ir/tf_generated_ops.td    | 36 +++++++++++++++++++
 3 files changed, 46 insertions(+)

diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
index 8b0baaf5804..324e37d7f81 100644
--- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir
@@ -287,3 +287,11 @@ func @matmulNoTransposeB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32>
 // CHECK: %7 = "tf.Transpose"(%arg1, %6) : (tensor<1280x1000xf32>, tensor) -> tensor<*xf32>
 // CHECK: %8 = "tf.MatMul"(%3, %7) {transpose_a = false, transpose_b = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
 }
+
+func @stop_gradient(%arg0: tensor<3xi32>) -> tensor<3xi32> {
+  %0 = "tf.StopGradient"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
+  return %0 : tensor<3xi32>
+  // Should be converted to Identity and then from Identity to value
+  // CHECK-LABEL: stop_gradient
+  // CHECK: return %arg0 : tensor<3xi32>
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td
index 98248f2de90..a9263df9e79 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td
@@ -81,6 +81,8 @@ def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrTrue, $bt),
                         /*delta=*/(ConstantOp TFi32<-1>)),
                         (ConstantOp TFi32<1>))),
           $b, ConstBoolAttrFalse, $bt)>;
+def : Pat<(TF_StopGradientOp $arg), (TF_IdentityOp $arg)>;
+
 //===----------------------------------------------------------------------===//
 // Op removal patterns.
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 40157b17429..4d6ad6ad19e 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -3138,6 +3138,42 @@ shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_StopGradientOp : TF_Op<"StopGradient", [NoSideEffect, SameOperandsAndResultType]> {
+  let summary = "Stops gradient computation.";
+
+  let description = [{
+When executed in a graph, this op outputs its input tensor as-is.
+
+When building ops to compute gradients, this op prevents the contribution of
+its inputs from being taken into account. Normally, the gradient generator
+adds ops to a graph to compute the derivatives of a specified 'loss' by
+recursively finding the inputs that contributed to its computation. If you
+insert this op in the graph, its inputs are masked from the gradient
+generator. They are not taken into account for computing gradients.
+
+This is useful any time you want to compute a value with TensorFlow but need
+to pretend that the value was a constant. Some examples include:
+
+* The *EM* algorithm where the *M-step* should not involve backpropagation
+  through the output of the *E-step*.
+* Contrastive divergence training of Boltzmann machines where, when
+  differentiating the energy function, the training must not backpropagate
+  through the graph that generated the samples from the model.
+* Adversarial training, where no backprop should happen through the adversarial
+  example generation process.
+  }];
+
+  let arguments = (ins
+    TF_Tensor:$input
+  );
+
+  let results = (outs
+    TF_Tensor:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_StridedSliceOp : TF_Op<"StridedSlice", [NoSideEffect]> {
   let summary = "Return a strided slice from `input`.";
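
A minimal sketch of the IR rewrite the new stop_gradient test exercises, assuming (as the test comment suggests) that the existing Identity removal patterns under "Op removal patterns" in prepare_patterns.td also fire during prepare-tf; the new pattern itself only turns tf.StopGradient into tf.Identity:

  // Input IR:
  func @stop_gradient(%arg0: tensor<3xi32>) -> tensor<3xi32> {
    %0 = "tf.StopGradient"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
    return %0 : tensor<3xi32>
  }

  // After the new pattern (TF_StopGradientOp -> TF_IdentityOp):
  func @stop_gradient(%arg0: tensor<3xi32>) -> tensor<3xi32> {
    %0 = "tf.Identity"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
    return %0 : tensor<3xi32>
  }

  // After Identity removal, matching the CHECK lines in the test:
  func @stop_gradient(%arg0: tensor<3xi32>) -> tensor<3xi32> {
    return %arg0 : tensor<3xi32>
  }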