Go: Fix Scope.WithControlDependencies array-copying behavior.

The test fails with the old code, and passes with the new code.

PiperOrigin-RevId: 184021596
This commit is contained in:
Todd Wang 2018-01-31 11:32:35 -08:00 committed by Michael Case
parent d418a14176
commit b79c3b2d1e
2 changed files with 23 additions and 9 deletions

View File

@ -109,15 +109,20 @@ func (s *Scope) SubScope(namespace string) *Scope {
// added to the graph to execute only after all the provided operations have // added to the graph to execute only after all the provided operations have
// executed first (in addition to any other control dependencies in s). // executed first (in addition to any other control dependencies in s).
func (s *Scope) WithControlDependencies(ops ...*tf.Operation) *Scope { func (s *Scope) WithControlDependencies(ops ...*tf.Operation) *Scope {
// Force a copy of the control dependencies into a new underlying array on
// every call. We cannot alias the same underlying array as `ops`, otherwise
// the user could modify that array after calling s.WithControlDependencies,
// which would be confusing. We cannot alias the same underlying array as the
// original `s.controlDependencies`, since Scopes form a logical tree, and
// other calls to s.WithControlDependencies could stomp on each other.
deps := make([]*tf.Operation, 0, len(s.controlDependencies)+len(ops))
deps = append(deps, s.controlDependencies...)
deps = append(deps, ops...)
return &Scope{ return &Scope{
graph: s.graph, graph: s.graph,
namemap: s.namemap, namemap: s.namemap,
namespace: s.namespace, namespace: s.namespace,
// append(ops, s.controlDependencies) and not the other way controlDependencies: deps,
// around so that we end up with a copy of the underlying array
// (and other calls to s.WithControlDependencies() do not stomp
// on each other).
controlDependencies: append(ops, s.controlDependencies...),
err: s.err, err: s.err,
} }
} }

View File

@ -77,8 +77,17 @@ func TestControlDependencies(t *testing.T) {
variable = VarHandleOp(s, tf.Int32, tf.ScalarShape()) variable = VarHandleOp(s, tf.Int32, tf.ScalarShape())
init = AssignVariableOp(s, variable, zero) init = AssignVariableOp(s, variable, zero)
update = AssignAddVariableOp(s, variable, one) update = AssignAddVariableOp(s, variable, one)
read = ReadVariableOp(s.WithControlDependencies(update), variable, tf.Int32) readDeps = []*tf.Operation{update}
) )
// We intend for `read` to have a control dependency on `update`.
s = s.WithControlDependencies(readDeps...)
// Ensure that Scope.WithControlDependencies makes a copy of the underlying
// array, rather than just holding a slice reference to the same user-supplied
// underlying array. If the copy is correctly performed, overwriting
// readDeps[0] should have no effect on control dependencies for `read`.
readDeps[0] = init
read := ReadVariableOp(s, variable, tf.Int32)
graph, err := s.Finalize() graph, err := s.Finalize()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)