diff --git a/tensorflow/go/op/scope.go b/tensorflow/go/op/scope.go index 2cf1f30187c..13de4294dc2 100644 --- a/tensorflow/go/op/scope.go +++ b/tensorflow/go/op/scope.go @@ -109,15 +109,20 @@ func (s *Scope) SubScope(namespace string) *Scope { // added to the graph to execute only after all the provided operations have // executed first (in addition to any other control dependencies in s). func (s *Scope) WithControlDependencies(ops ...*tf.Operation) *Scope { + // Force a copy of the control dependencies into a new underlying array on + // every call. We cannot alias the same underlying array as `ops`, otherwise + // the user could modify that array after calling s.WithControlDependencies, + // which would be confusing. We cannot alias the same underlying array as the + // original `s.controlDependencies`, since Scopes form a logical tree, and + // other calls to s.WithControlDependencies could stomp on each other. + deps := make([]*tf.Operation, 0, len(s.controlDependencies)+len(ops)) + deps = append(deps, s.controlDependencies...) + deps = append(deps, ops...) return &Scope{ - graph: s.graph, - namemap: s.namemap, - namespace: s.namespace, - // append(ops, s.controlDependencies) and not the other way - // around so that we end up with a copy of the underlying array - // (and other calls to s.WithControlDependencies() do not stomp - // on each other). - controlDependencies: append(ops, s.controlDependencies...), + graph: s.graph, + namemap: s.namemap, + namespace: s.namespace, + controlDependencies: deps, err: s.err, } } diff --git a/tensorflow/go/op/scope_test.go b/tensorflow/go/op/scope_test.go index 4f533881d06..b58a61de98b 100644 --- a/tensorflow/go/op/scope_test.go +++ b/tensorflow/go/op/scope_test.go @@ -77,8 +77,17 @@ func TestControlDependencies(t *testing.T) { variable = VarHandleOp(s, tf.Int32, tf.ScalarShape()) init = AssignVariableOp(s, variable, zero) update = AssignAddVariableOp(s, variable, one) - read = ReadVariableOp(s.WithControlDependencies(update), variable, tf.Int32) + readDeps = []*tf.Operation{update} ) + // We intend for `read` to have a control dependency on `update`. + s = s.WithControlDependencies(readDeps...) + // Ensure that Scope.WithControlDependencies makes a copy of the underlying + // array, rather than just holding a slice reference to the same user-supplied + // underlying array. If the copy is correctly performed, overwriting + // readDeps[0] should have no effect on control dependencies for `read`. + readDeps[0] = init + read := ReadVariableOp(s, variable, tf.Int32) + graph, err := s.Finalize() if err != nil { t.Fatal(err)