Raise error if trying to use Gradients when scope have device or
control dependencies specified(Gradients doesn't currently support them).
This commit is contained in:
parent
07f6ed9896
commit
a911ecf5a9
@ -16,7 +16,12 @@ limitations under the License.
|
||||
|
||||
package op
|
||||
|
||||
import tf "github.com/tensorflow/tensorflow/tensorflow/go"
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
tf "github.com/tensorflow/tensorflow/tensorflow/go"
|
||||
)
|
||||
|
||||
|
||||
// Gradients adds gradients computation ops to the graph according to scope.
|
||||
//
|
||||
@ -27,6 +32,15 @@ import tf "github.com/tensorflow/tensorflow/tensorflow/go"
|
||||
//
|
||||
// return the partial derivatives
|
||||
func Gradients(scope *Scope, y []tf.Output, x []tf.Output, dx ...tf.Output) (output []tf.Output) {
|
||||
if len(scope.controlDependencies) > 0 {
|
||||
scope.UpdateErr("Gradients", fmt.Errorf("Gradients does not currently support control dependencies (via Scope.WithControlDependencies)."))
|
||||
return
|
||||
}
|
||||
if scope.device != "" {
|
||||
scope.UpdateErr("Gradients", fmt.Errorf("Gradients does not currently support device annotations (via Scope.WithDevice)."))
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
if output, err = scope.graph.AddGradients(scope.opName("Gradients"), y, x, dx); err != nil {
|
||||
scope.UpdateErr("Gradients", err)
|
||||
|
@ -214,3 +214,33 @@ func TestValidateGradientsNames(t *testing.T) {
|
||||
t.Error("Gradients should have failed if executed more than once for scope of the same namespace")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddGradientsWithControlDependencies(t *testing.T) {
|
||||
var (
|
||||
s = NewScope()
|
||||
zero = Const(s.SubScope("zero"), int32(0))
|
||||
x = Placeholder(s.SubScope("x"), tf.Float)
|
||||
y0 = Square(s.SubScope("y0"), x)
|
||||
variable = VarHandleOp(s, tf.Int32, tf.ScalarShape())
|
||||
init = AssignVariableOp(s, variable, zero)
|
||||
readDeps = []*tf.Operation{init}
|
||||
)
|
||||
s = s.WithControlDependencies(readDeps...)
|
||||
Gradients(s, []tf.Output{y0}, []tf.Output{x})
|
||||
if err := s.Err(); err == nil {
|
||||
t.Error("Gradients should have failed when control dependencies are set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddGradientsWithDevice(t *testing.T) {
|
||||
var (
|
||||
s = NewScope()
|
||||
x = Placeholder(s.SubScope("x"), tf.Float)
|
||||
y0 = Square(s.SubScope("y0"), x)
|
||||
)
|
||||
s = s.WithDevice("/device:GPU:0")
|
||||
Gradients(s, []tf.Output{y0}, []tf.Output{x})
|
||||
if err := s.Err(); err == nil {
|
||||
t.Error("Gradients should have failed when device is set")
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user