diff --git a/tensorflow/go/op/gradients.go b/tensorflow/go/op/gradients.go index 2eaa7e70ab2..9f892e1da6c 100644 --- a/tensorflow/go/op/gradients.go +++ b/tensorflow/go/op/gradients.go @@ -16,7 +16,12 @@ limitations under the License. package op -import tf "github.com/tensorflow/tensorflow/tensorflow/go" +import ( + "fmt" + + tf "github.com/tensorflow/tensorflow/tensorflow/go" +) + // Gradients adds gradients computation ops to the graph according to scope. // @@ -27,6 +32,15 @@ import tf "github.com/tensorflow/tensorflow/tensorflow/go" // // return the partial derivatives func Gradients(scope *Scope, y []tf.Output, x []tf.Output, dx ...tf.Output) (output []tf.Output) { + if len(scope.controlDependencies) > 0 { + scope.UpdateErr("Gradients", fmt.Errorf("Gradients does not currently support control dependencies (via Scope.WithControlDependencies).")) + return + } + if scope.device != "" { + scope.UpdateErr("Gradients", fmt.Errorf("Gradients does not currently support device annotations (via Scope.WithDevice).")) + return + } + var err error if output, err = scope.graph.AddGradients(scope.opName("Gradients"), y, x, dx); err != nil { scope.UpdateErr("Gradients", err) diff --git a/tensorflow/go/op/gradients_test.go b/tensorflow/go/op/gradients_test.go index 1febd083663..3d1d57b77ea 100644 --- a/tensorflow/go/op/gradients_test.go +++ b/tensorflow/go/op/gradients_test.go @@ -214,3 +214,33 @@ func TestValidateGradientsNames(t *testing.T) { t.Error("Gradients should have failed if executed more than once for scope of the same namespace") } } + +func TestAddGradientsWithControlDependencies(t *testing.T) { + var ( + s = NewScope() + zero = Const(s.SubScope("zero"), int32(0)) + x = Placeholder(s.SubScope("x"), tf.Float) + y0 = Square(s.SubScope("y0"), x) + variable = VarHandleOp(s, tf.Int32, tf.ScalarShape()) + init = AssignVariableOp(s, variable, zero) + readDeps = []*tf.Operation{init} + ) + s = s.WithControlDependencies(readDeps...) + Gradients(s, []tf.Output{y0}, []tf.Output{x}) + if err := s.Err(); err == nil { + t.Error("Gradients should have failed when control dependencies are set") + } +} + +func TestAddGradientsWithDevice(t *testing.T) { + var ( + s = NewScope() + x = Placeholder(s.SubScope("x"), tf.Float) + y0 = Square(s.SubScope("y0"), x) + ) + s = s.WithDevice("/device:GPU:0") + Gradients(s, []tf.Output{y0}, []tf.Output{x}) + if err := s.Err(); err == nil { + t.Error("Gradients should have failed when device is set") + } +}