Add legalization for TF.StopGradient to TF.Identity
PiperOrigin-RevId: 261734406
This commit is contained in:
parent
6edf8c57c7
commit
a3672fefae
@ -287,3 +287,11 @@ func @matmulNoTransposeB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32>
|
|||||||
// CHECK: %7 = "tf.Transpose"(%arg1, %6) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
|
// CHECK: %7 = "tf.Transpose"(%arg1, %6) : (tensor<1280x1000xf32>, tensor<?xi32>) -> tensor<*xf32>
|
||||||
// CHECK: %8 = "tf.MatMul"(%3, %7) {transpose_a = false, transpose_b = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
|
// CHECK: %8 = "tf.MatMul"(%3, %7) {transpose_a = false, transpose_b = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @stop_gradient(%arg0: tensor<3xi32>) -> tensor<3xi32> {
|
||||||
|
%0 = "tf.StopGradient"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
|
||||||
|
return %0 : tensor<3xi32>
|
||||||
|
// Should be converted to Identity and then from Identity to value
|
||||||
|
// CHECK-LABEL: stop_gradient
|
||||||
|
// CHECK: return %arg0 : tensor<3xi32>
|
||||||
|
}
|
||||||
|
@ -81,6 +81,8 @@ def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrTrue, $bt),
|
|||||||
/*delta=*/(ConstantOp TFi32<-1>)), (ConstantOp TFi32<1>))), $b,
|
/*delta=*/(ConstantOp TFi32<-1>)), (ConstantOp TFi32<1>))), $b,
|
||||||
ConstBoolAttrFalse, $bt)>;
|
ConstBoolAttrFalse, $bt)>;
|
||||||
|
|
||||||
|
def : Pat<(TF_StopGradientOp $arg), (TF_IdentityOp $arg)>;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Op removal patterns.
|
// Op removal patterns.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -3138,6 +3138,42 @@ shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]
|
|||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TF_StopGradientOp : TF_Op<"StopGradient", [NoSideEffect, SameOperandsAndResultType]> {
|
||||||
|
let summary = "Stops gradient computation.";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
When executed in a graph, this op outputs its input tensor as-is.
|
||||||
|
|
||||||
|
When building ops to compute gradients, this op prevents the contribution of
|
||||||
|
its inputs to be taken into account. Normally, the gradient generator adds ops
|
||||||
|
to a graph to compute the derivatives of a specified 'loss' by recursively
|
||||||
|
finding out inputs that contributed to its computation. If you insert this op
|
||||||
|
in the graph it inputs are masked from the gradient generator. They are not
|
||||||
|
taken into account for computing gradients.
|
||||||
|
|
||||||
|
This is useful any time you want to compute a value with TensorFlow but need
|
||||||
|
to pretend that the value was a constant. Some examples include:
|
||||||
|
|
||||||
|
* The *EM* algorithm where the *M-step* should not involve backpropagation
|
||||||
|
through the output of the *E-step*.
|
||||||
|
* Contrastive divergence training of Boltzmann machines where, when
|
||||||
|
differentiating the energy function, the training must not backpropagate
|
||||||
|
through the graph that generated the samples from the model.
|
||||||
|
* Adversarial training, where no backprop should happen through the adversarial
|
||||||
|
example generation process.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TF_Tensor:$input
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
TF_Tensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
|
}
|
||||||
|
|
||||||
def TF_StridedSliceOp : TF_Op<"StridedSlice", [NoSideEffect]> {
|
def TF_StridedSliceOp : TF_Op<"StridedSlice", [NoSideEffect]> {
|
||||||
let summary = "Return a strided slice from `input`.";
|
let summary = "Return a strided slice from `input`.";
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user