Add legalization for TF.StopGradient to TF.Identity

PiperOrigin-RevId: 261734406
Karim Nosir 2019-08-05 11:50:50 -07:00 committed by TensorFlower Gardener
parent 6edf8c57c7
commit a3672fefae
3 changed files with 46 additions and 0 deletions


@@ -287,3 +287,11 @@ func @matmulNoTransposeB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32>
// CHECK: %7 = "tf.Transpose"(%arg1, %6) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
// CHECK: %8 = "tf.MatMul"(%3, %7) {transpose_a = false, transpose_b = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
}
func @stop_gradient(%arg0: tensor<3xi32>) -> tensor<3xi32> {
%0 = "tf.StopGradient"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
// StopGradient should be converted to Identity, and the Identity then folded to its input value
// CHECK-LABEL: stop_gradient
// CHECK: return %arg0 : tensor<3xi32>
}


@@ -81,6 +81,8 @@ def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrTrue, $bt),
/*delta=*/(ConstantOp TFi32<-1>)), (ConstantOp TFi32<1>))), $b,
ConstBoolAttrFalse, $bt)>;
def : Pat<(TF_StopGradientOp $arg), (TF_IdentityOp $arg)>;
//===----------------------------------------------------------------------===//
// Op removal patterns.
//===----------------------------------------------------------------------===//
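
As a rough sketch of what the new pattern does when applied (the function name below is hypothetical and not part of this change): a tf.StopGradient op is rewritten to tf.Identity with the same operand and result type, and the existing Identity folding then forwards the operand directly, which is what the test above checks.

// Hypothetical input before the pattern is applied:
func @stop_gradient_sketch(%arg0: tensor<3xi32>) -> tensor<3xi32> {
  %0 = "tf.StopGradient"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
  return %0 : tensor<3xi32>
}

// After legalization, the StopGradient becomes an Identity:
//   %0 = "tf.Identity"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
// and once the Identity is folded away, the function returns %arg0 directly.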


@@ -3138,6 +3138,42 @@ shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_StopGradientOp : TF_Op<"StopGradient", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Stops gradient computation.";
let description = [{
When executed in a graph, this op outputs its input tensor as-is.
When building ops to compute gradients, this op prevents the contribution of
its inputs from being taken into account. Normally, the gradient generator adds ops
to a graph to compute the derivatives of a specified 'loss' by recursively
finding the inputs that contributed to its computation. If you insert this op
in the graph, its inputs are masked from the gradient generator. They are not
taken into account for computing gradients.
This is useful any time you want to compute a value with TensorFlow but need
to pretend that the value was a constant. Some examples include:
* The *EM* algorithm where the *M-step* should not involve backpropagation
through the output of the *E-step*.
* Contrastive divergence training of Boltzmann machines where, when
differentiating the energy function, the training must not backpropagate
through the graph that generated the samples from the model.
* Adversarial training, where no backprop should happen through the adversarial
example generation process.
}];
let arguments = (ins
TF_Tensor:$input
);
let results = (outs
TF_Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
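
To make the description above concrete, here is a hedged TF-dialect sketch (the function and value names are hypothetical, not part of this change): %frozen carries the same runtime value as %scaled, but because it is produced by tf.StopGradient, a gradient generator would treat it as a constant and would not backpropagate through the op that produced it.

func @treat_as_constant(%x: tensor<4xf32>, %w: tensor<4xf32>) -> tensor<4xf32> {
  // The forward value is unchanged by tf.StopGradient.
  %scaled = "tf.Mul"(%x, %w) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  %frozen = "tf.StopGradient"(%scaled) : (tensor<4xf32>) -> tensor<4xf32>
  // Gradients flow to %w through this tf.Add, but not through %frozen.
  %out = "tf.Add"(%frozen, %w) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  return %out : tensor<4xf32>
}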
def TF_StridedSliceOp : TF_Op<"StridedSlice", [NoSideEffect]> {
let summary = "Return a strided slice from `input`.";