Raise error if trying to use Gradients when scope have device or

control dependencies specified(Gradients doesn't currently support them).
This commit is contained in:
Cibifang 2018-11-21 13:34:07 +08:00
parent 07f6ed9896
commit a911ecf5a9
2 changed files with 45 additions and 1 deletions

View File

@ -16,7 +16,12 @@ limitations under the License.
package op
import tf "github.com/tensorflow/tensorflow/tensorflow/go"
import (
"fmt"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
// Gradients adds gradients computation ops to the graph according to scope.
//
@ -27,6 +32,15 @@ import tf "github.com/tensorflow/tensorflow/tensorflow/go"
//
// return the partial derivatives
func Gradients(scope *Scope, y []tf.Output, x []tf.Output, dx ...tf.Output) (output []tf.Output) {
if len(scope.controlDependencies) > 0 {
scope.UpdateErr("Gradients", fmt.Errorf("Gradients does not currently support control dependencies (via Scope.WithControlDependencies)."))
return
}
if scope.device != "" {
scope.UpdateErr("Gradients", fmt.Errorf("Gradients does not currently support device annotations (via Scope.WithDevice)."))
return
}
var err error
if output, err = scope.graph.AddGradients(scope.opName("Gradients"), y, x, dx); err != nil {
scope.UpdateErr("Gradients", err)

View File

@ -214,3 +214,33 @@ func TestValidateGradientsNames(t *testing.T) {
t.Error("Gradients should have failed if executed more than once for scope of the same namespace")
}
}
func TestAddGradientsWithControlDependencies(t *testing.T) {
var (
s = NewScope()
zero = Const(s.SubScope("zero"), int32(0))
x = Placeholder(s.SubScope("x"), tf.Float)
y0 = Square(s.SubScope("y0"), x)
variable = VarHandleOp(s, tf.Int32, tf.ScalarShape())
init = AssignVariableOp(s, variable, zero)
readDeps = []*tf.Operation{init}
)
s = s.WithControlDependencies(readDeps...)
Gradients(s, []tf.Output{y0}, []tf.Output{x})
if err := s.Err(); err == nil {
t.Error("Gradients should have failed when control dependencies are set")
}
}
func TestAddGradientsWithDevice(t *testing.T) {
var (
s = NewScope()
x = Placeholder(s.SubScope("x"), tf.Float)
y0 = Square(s.SubScope("y0"), x)
)
s = s.WithDevice("/device:GPU:0")
Gradients(s, []tf.Output{y0}, []tf.Output{x})
if err := s.Err(); err == nil {
t.Error("Gradients should have failed when device is set")
}
}