3620 lines
132 KiB
Python
3620 lines
132 KiB
Python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Tests for tensorflow.python.framework.ops."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from absl.testing import parameterized
|
|
import gc
|
|
import numpy as np
|
|
import os
|
|
import threading
|
|
import weakref
|
|
|
|
from tensorflow.core.framework import attr_value_pb2
|
|
from tensorflow.core.protobuf import config_pb2
|
|
from tensorflow.python.autograph.core import ag_ctx
|
|
from tensorflow.python.client import session
|
|
from tensorflow.python.eager import backprop
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.eager import def_function
|
|
from tensorflow.python.eager import function as eager_function
|
|
from tensorflow.python.eager import wrap_function
|
|
from tensorflow.python.framework import config
|
|
from tensorflow.python.framework import composite_tensor
|
|
from tensorflow.python.framework import constant_op
|
|
from tensorflow.python.framework import device as pydev
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import errors
|
|
from tensorflow.python.framework import function
|
|
from tensorflow.python.framework import indexed_slices
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import sparse_tensor
|
|
from tensorflow.python.framework import tensor_shape
|
|
from tensorflow.python.framework import tensor_spec
|
|
from tensorflow.python.framework import tensor_util
|
|
from tensorflow.python.framework import test_ops
|
|
from tensorflow.python.framework import test_util
|
|
from tensorflow.python.framework import type_spec
|
|
from tensorflow.python.framework import versions
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import control_flow_ops
|
|
from tensorflow.python.ops import math_ops
|
|
from tensorflow.python.ops import resource_variable_ops
|
|
from tensorflow.python.ops import resources
|
|
from tensorflow.python.ops import special_math_ops
|
|
from tensorflow.python.ops import variable_scope
|
|
from tensorflow.python.ops import variables
|
|
import tensorflow.python.ops.gradients # pylint: disable=unused-import
|
|
from tensorflow.python.platform import googletest
|
|
from tensorflow.python.util import compat
|
|
|
|
|
|
class ResourceTest(test_util.TensorFlowTestCase):
  """Graph-mode checks for creating and initializing stub resources."""

  @test_util.run_deprecated_v1
  def testBuildGraph(self):
    """Builds a stub resource handle and runs the create op on it."""
    with self.cached_session():
      resource_handle = test_ops.stub_resource_handle_op(
          container="a", shared_name="b")
      create_op = test_ops.resource_create_op(resource_handle)
      create_op.run()

  @test_util.run_deprecated_v1
  def testInitialize(self):
    """A registered resource is reported as uninitialized until initialized."""
    with self.cached_session():
      resource_handle = test_ops.stub_resource_handle_op(
          container="a", shared_name="b")
      resources.register_resource(
          handle=resource_handle,
          create_op=test_ops.resource_create_op(resource_handle),
          is_initialized_op=test_ops.resource_initialized_op(resource_handle))

      # Before initialization exactly one resource is reported.
      pending = resources.report_uninitialized_resources(
          resources.shared_resources()).eval()
      self.assertEqual(len(pending), 1)

      resources.initialize_resources(resources.shared_resources()).run()

      # After initialization nothing is reported.
      pending = resources.report_uninitialized_resources(
          resources.shared_resources()).eval()
      self.assertEqual(len(pending), 0)
|
|
|
|
|
|
class TensorAndShapeTest(test_util.TensorFlowTestCase):
  """Tests for Tensor shapes, iteration/bool errors, refs and bitwise ops."""

  def testShape(self):
    """An op output starts with unknown shape; set_shape() then pins it."""
    node_def = ops._NodeDef("FloatOutput", "myop")
    operation = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32])
    output = operation.outputs[0]
    self.assertEqual(tensor_shape.unknown_shape(), output.get_shape())
    output.set_shape([1, 2, 3])
    self.assertEqual([1, 2, 3], output.get_shape())

  def testIterable(self):
    """When executing eagerly, iter() on the op output raises TypeError."""
    if not context.executing_eagerly():
      self.skipTest("Eager-mode test")
    node_def = ops._NodeDef("FloatOutput", "myop")
    operation = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32])
    output = operation.outputs[0]
    with self.assertRaisesRegex(TypeError, "Cannot iterate"):
      iter(output)

  def testIterableGraph(self):
    """In graph mode the iteration error depends on the AutoGraph status."""
    if context.executing_eagerly():
      self.skipTest("Graph-mode test")

    node_def = ops._NodeDef("FloatOutput", "myop")
    operation = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32])
    output = operation.outputs[0]

    # No explicit AutoGraph status.
    with self.assertRaisesRegex(TypeError, "iterating.*not allowed in Graph"):
      next(iter(output))

    # AutoGraph explicitly enabled.
    with self.assertRaisesRegex(TypeError, "iterating.*AutoGraph did convert"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED):
        next(iter(output))

    # AutoGraph explicitly disabled.
    with self.assertRaisesRegex(TypeError, "iterating.*AutoGraph is disabled"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.DISABLED):
        next(iter(output))

  def testImplicitBool(self):
    """bool() on a graph tensor raises; the message tracks AutoGraph status."""
    node_def = ops._NodeDef("FloatOutput", "myop")
    operation = ops.Operation(node_def, ops.Graph(), [], [dtypes.bool])
    output = operation.outputs[0]

    # No explicit AutoGraph status.
    with self.assertRaisesRegex(TypeError,
                                "using.*as a.*bool.*not allowed in Graph"):
      bool(output)

    # AutoGraph explicitly enabled.
    with self.assertRaisesRegex(TypeError,
                                "using.*as a.*bool.*AutoGraph did convert"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED):
        bool(output)

    # AutoGraph explicitly disabled.
    with self.assertRaisesRegex(TypeError,
                                "using.*as a.*bool.*AutoGraph is disabled"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.DISABLED):
        bool(output)

  def testAddShape(self):
    """Adding a [2, 3] tensor and a [1, 3] tensor yields shape [2, 3]."""
    with self.cached_session():
      lhs = array_ops.zeros([2, 3])
      rhs = array_ops.ones([1, 3])
      total = lhs + rhs
      self.assertEqual([2, 3], total.shape)

  @test_util.run_deprecated_v1
  def testUnknownDim(self):
    """An unknown middle dimension is preserved by addition."""
    with self.cached_session():
      lhs = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
      rhs = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
      total = lhs + rhs
      self.assertEqual([2, None, 3], total.shape.as_list())

  @test_util.run_deprecated_v1
  def testUnknownShape(self):
    """Adding a placeholder of unknown shape gives an unknown shape."""
    with self.cached_session():
      lhs = array_ops.placeholder(dtype=dtypes.float32, shape=None)
      rhs = array_ops.ones([1, 3])
      total = lhs + rhs
      self.assertEqual(tensor_shape.unknown_shape(), total.shape)

  @test_util.run_deprecated_v1
  def testScalarShape(self):
    """Adding two scalars yields a scalar shape."""
    with self.cached_session():
      lhs = array_ops.placeholder(dtype=dtypes.float32, shape=[])
      rhs = array_ops.ones([])
      total = lhs + rhs
      self.assertEqual(tensor_shape.TensorShape([]), total.shape)

  @test_util.run_deprecated_v1
  def testShapeFunctionError(self):
    """Adding incompatible shapes raises ValueError naming both shapes."""
    with self.cached_session():
      lhs = array_ops.ones([1, 2, 3])
      rhs = array_ops.ones([4, 5, 6])
      with self.assertRaisesRegex(
          ValueError, r"Dimensions must be equal, but are 2 and 5 for .*add"
          r".*Add(V2)?.* with input shapes: \[1,2,3\], \[4,5,6\]."):
        _ = lhs + rhs

  def testNumpyArray(self):
    """Graph tensors reject np.array()/len(); eager tensors support both."""
    with ops.Graph().as_default():
      x = array_ops.ones((3, 4), name="test_ones")

    with self.assertRaisesRegex(NotImplementedError,
                                r"Cannot convert a symbolic.+test_ones"):
      np.array(x)

    with self.assertRaisesRegex(TypeError, "not well defined.+test_ones"):
      len(x)

    # EagerTensors should still behave as numpy arrays.
    with context.eager_mode():
      x = array_ops.ones((3, 4))

    self.assertAllEqual(x, np.ones((3, 4)))
    self.assertAllEqual(np.array(x), np.ones((3, 4)))
    self.assertEqual(len(x), 3)

  def testRef(self):
    """ref() compares equal for the same object and unequal otherwise."""
    first = constant_op.constant(3)
    alias = first
    same_value = constant_op.constant(3)
    vector = constant_op.constant([6, 10])
    variable = variables.Variable(5)

    # References to the same underlying object compare equal.
    for lhs, rhs in ((first, first), (alias, alias), (first, alias),
                     (same_value, same_value), (vector, vector),
                     (variable, variable)):
      self.assertEqual(lhs.ref(), rhs.ref())

    # References to distinct objects differ, even for equal constants.
    for lhs, rhs in ((first, same_value), (first, vector), (first, variable),
                     (same_value, vector), (same_value, variable),
                     (vector, variable)):
      self.assertNotEqual(lhs.ref(), rhs.ref())

  def testRefDeref(self):
    """ref().deref() hands back the very object the ref was taken from."""
    first = constant_op.constant(3)
    alias = first
    same_value = constant_op.constant(3)
    vector = constant_op.constant([6, 10])
    variable = variables.Variable(5)

    # Dereferencing returns the identical object (aliases included).
    for expected, source in ((first, first), (alias, alias), (first, alias),
                             (alias, first), (same_value, same_value),
                             (vector, vector)):
      self.assertIs(expected, source.ref().deref())

    # Dereferencing never returns a different object.
    for other, source in ((first, same_value), (first, vector),
                          (first, variable), (same_value, vector),
                          (same_value, variable), (vector, variable)):
      self.assertIsNot(other, source.ref().deref())

  def testRefInSet(self):
    """Refs are hashable; refs of aliased tensors collapse to one member."""
    first = constant_op.constant(3)
    alias = first
    same_value = constant_op.constant(3)
    vector = constant_op.constant([6, 10])
    variable = variables.Variable(5)

    self.assertEqual(first.ref(), alias.ref())

    tensor_set = {
        first.ref(),
        alias.ref(),
        same_value.ref(),
        vector.ref(),
        variable.ref(),
    }

    # Five refs were inserted, but two of them refer to the same tensor.
    self.assertEqual(len(tensor_set), 4)
    for member in (first, alias, same_value, vector, variable):
      self.assertIn(member.ref(), tensor_set)

  def testRefInDict(self):
    """Refs work as dict keys; an aliased ref overwrites the same entry."""
    first = constant_op.constant(3)
    alias = first
    same_value = constant_op.constant(3)
    vector = constant_op.constant([6, 10])
    variable = variables.Variable(5)

    self.assertEqual(first.ref(), alias.ref())

    tensor_dict = {
        first.ref(): "x1",
        same_value.ref(): "y",
        vector.ref(): "z",
        variable.ref(): "w",
    }

    self.assertEqual(len(tensor_dict), 4)

    # Writing through the alias replaces the entry keyed by the original.
    tensor_dict[alias.ref()] = "x2"
    self.assertEqual(len(tensor_dict), 4)

    for key_tensor, expected in ((first, "x2"), (alias, "x2"),
                                 (same_value, "y"), (vector, "z"),
                                 (variable, "w")):
      self.assertEqual(tensor_dict[key_tensor.ref()], expected)

  def testTensorRefStrong(self):
    """deref() still returns an object after the tensor's name is deleted."""
    tensor = constant_op.constant(1.)
    tensor_ref = tensor.ref()
    del tensor
    self.assertIsNotNone(tensor_ref.deref())

  def testVariableRefStrong(self):
    """deref() still returns an object after the variable's name is deleted."""
    variable = variables.Variable(1.)
    variable_ref = variable.ref()
    del variable
    self.assertIsNotNone(variable_ref.deref())

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseAndNumeric(self):
    """`&` on integer tensors is elementwise bitwise AND."""
    lhs = constant_op.constant([0, 1, 3])
    rhs = constant_op.constant([1, 1, 1])

    result = lhs & rhs

    self.assertAllEqual(result, [0, 1, 1])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseAndBool(self):
    """`&` on bool tensors is elementwise logical AND."""
    lhs = constant_op.constant([False, False, True, True])
    rhs = constant_op.constant([False, True, False, True])

    result = lhs & rhs

    self.assertAllEqual(result, [False, False, False, True])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseAndErrors(self):
    """`&` with mismatched or string operands raises."""
    x_int = constant_op.constant(0)
    x_bool = constant_op.constant(True)

    # The expected exception type differs between eager and graph mode.
    expected_errtype = (
        errors.InvalidArgumentError
        if context.executing_eagerly() else TypeError)

    with self.assertRaises(expected_errtype):
      _ = x_int & x_bool
    with self.assertRaises(expected_errtype):
      _ = x_int & constant_op.constant("a")

    with self.assertRaises(expected_errtype):
      _ = x_bool & x_int
    with self.assertRaises(expected_errtype):
      _ = x_bool & constant_op.constant("a")

    with self.assertRaises(expected_errtype):
      _ = constant_op.constant("a") & constant_op.constant("b")

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseOrNumeric(self):
    """`|` on integer tensors is elementwise bitwise OR."""
    lhs = constant_op.constant([0, 1, 2])
    rhs = constant_op.constant([1, 1, 1])

    result = lhs | rhs

    self.assertAllEqual(result, [1, 1, 3])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseOrBool(self):
    """`|` on bool tensors is elementwise logical OR."""
    lhs = constant_op.constant([False, False, True, True])
    rhs = constant_op.constant([False, True, False, True])

    result = lhs | rhs

    self.assertAllEqual(result, [False, True, True, True])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseOrErrors(self):
    """`|` with mismatched or string operands raises."""
    x_int = constant_op.constant(0)
    x_bool = constant_op.constant(True)

    # The expected exception type differs between eager and graph mode.
    expected_errtype = (
        errors.InvalidArgumentError
        if context.executing_eagerly() else TypeError)

    with self.assertRaises(expected_errtype):
      _ = x_int | x_bool
    with self.assertRaises(expected_errtype):
      _ = x_int | constant_op.constant("a")

    with self.assertRaises(expected_errtype):
      _ = x_bool | x_int
    with self.assertRaises(expected_errtype):
      _ = x_bool | constant_op.constant("a")

    with self.assertRaises(expected_errtype):
      _ = constant_op.constant("a") | constant_op.constant("b")

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseXorNumeric(self):
    """`^` on integer tensors is elementwise bitwise XOR."""
    lhs = constant_op.constant([0, 1, 3])
    rhs = constant_op.constant([1, 1, 1])

    result = lhs ^ rhs

    self.assertAllEqual(result, [1, 0, 2])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseXorBool(self):
    """`^` on bool tensors is elementwise logical XOR."""
    lhs = constant_op.constant([False, False, True, True])
    rhs = constant_op.constant([False, True, False, True])

    result = lhs ^ rhs

    self.assertAllEqual(result, [False, True, True, False])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseXorErrors(self):
    """`^` with mismatched or string operands raises."""
    x_int = constant_op.constant(0)
    x_bool = constant_op.constant(True)

    # The expected exception type differs between eager and graph mode.
    expected_errtype = (
        errors.InvalidArgumentError
        if context.executing_eagerly() else TypeError)

    with self.assertRaises(expected_errtype):
      _ = x_int ^ x_bool
    with self.assertRaises(expected_errtype):
      _ = x_int ^ constant_op.constant("a")

    with self.assertRaises(expected_errtype):
      _ = x_bool ^ x_int
    with self.assertRaises(expected_errtype):
      _ = x_bool ^ constant_op.constant("a")

    with self.assertRaises(expected_errtype):
      _ = constant_op.constant("a") ^ constant_op.constant("b")

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseNotNumeric(self):
    """`~` on an integer tensor is elementwise bitwise inversion."""
    operand = constant_op.constant([0, dtypes.int32.min, 1])

    inverted = ~operand

    self.assertAllEqual(inverted, [-1, dtypes.int32.max, -2])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseNotBool(self):
    """`~` on a bool tensor is elementwise logical negation."""
    operand = constant_op.constant([False, True])

    inverted = ~operand

    self.assertAllEqual(inverted, [True, False])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseNotErrors(self):
    """`~` on a string tensor raises."""
    # The expected exception type differs between eager and graph mode.
    expected_errtype = (
        errors.InvalidArgumentError
        if context.executing_eagerly() else TypeError)

    with self.assertRaises(expected_errtype):
      _ = ~constant_op.constant("a")
|
|
|
|
|
|
@test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesTest(test_util.TensorFlowTestCase):
  """Tests conversion and arithmetic on `ops.IndexedSlices`."""

  def testToTensor(self):
    """Conversion to a dense tensor needs dense_shape; without it, it raises."""
    slice_values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
    slice_indices = constant_op.constant([0, 2])

    # Without a dense shape the conversion fails and the shape is unknown.
    shapeless = ops.IndexedSlices(slice_values, slice_indices)
    with self.assertRaises(ValueError):
      ops.convert_to_tensor(shapeless, name="tensor")
    self.assertEqual(tensor_shape.TensorShape(None), shapeless.shape)

    # With a dense shape the slices are scattered into a dense tensor.
    dense_shape = constant_op.constant([3, 2])
    shaped = ops.IndexedSlices(slice_values, slice_indices, dense_shape)
    dense = ops.convert_to_tensor(shaped, name="tensor")
    self.assertAllEqual(dense.shape, shaped.shape)
    self.assertAllEqual(self.evaluate(dense), [[2, 3], [0, 0], [5, 7]])

  @test_util.run_gpu_only
  def testEagerCopy(self):
    """The gradient through nested gathers has values of shape [4, 1]."""
    with context.eager_mode():
      var = variables.Variable([[0.0], [0.0], [0.0], [0.0]], name="tensor")
      with backprop.GradientTape() as tape:
        first_half = array_ops.gather(array_ops.gather(var, [0, 1]), [0, 1])
        second_half = array_ops.gather(array_ops.gather(var, [2, 3]), [0, 1])
        product = special_math_ops.einsum("ij,ij->i", first_half, second_half)
      grad = tape.gradient(product, [var])[0]
      # The gradient may come back either sparse or dense.
      if isinstance(grad, ops.IndexedSlices):
        grad_values = grad.values
      else:
        grad_values = grad
      self.assertAllEqual(grad_values.get_shape(), [4, 1])

  def testNegation(self):
    """Unary minus negates the values and leaves the indices unchanged."""
    slice_values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
    slice_indices = constant_op.constant([0, 2])
    negated = -ops.IndexedSlices(slice_values, slice_indices)
    self.assertAllEqual(negated.values, [[-2, -3], [-5, -7]])
    self.assertAllEqual(negated.indices, [0, 2])

  def testScalarMul(self):
    """scalar_mul scales the values and leaves the indices unchanged."""
    slice_values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
    slice_indices = constant_op.constant([0, 2])
    slices = ops.IndexedSlices(slice_values, slice_indices)
    scaled = math_ops.scalar_mul(-2, slices)
    self.assertAllEqual(scaled.values, [[-4, -6], [-10, -14]])
    self.assertAllEqual(scaled.indices, [0, 2])
|
|
|
|
|
|
@test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesSpecTest(test_util.TensorFlowTestCase,
                            parameterized.TestCase):
  """Tests for `indexed_slices.IndexedSlicesSpec`."""

  def assertAllTensorsEqual(self, list1, list2):
    """Asserts both sequences have equal length and pairwise-equal items.

    Args:
      list1: sequence of actual values.
      list2: sequence of expected values.
    """
    self.assertLen(list1, len(list2))
    for (t1, t2) in zip(list1, list2):
      self.assertAllEqual(t1, t2)

  def testConstruction(self):
    """Checks the spec's private fields for default and explicit arguments."""
    spec1 = indexed_slices.IndexedSlicesSpec()
    self.assertEqual(spec1._shape.rank, None)
    self.assertEqual(spec1._values_dtype, dtypes.float32)
    self.assertEqual(spec1._indices_dtype, dtypes.int64)
    self.assertEqual(spec1._dense_shape_dtype, None)
    self.assertEqual(spec1._indices_shape.as_list(), [None])

    spec2 = indexed_slices.IndexedSlicesSpec([None, None], dtypes.string,
                                             dtypes.int32, dtypes.int64, [10])
    self.assertEqual(spec2._shape.as_list(), [None, None])
    self.assertEqual(spec2._values_dtype, dtypes.string)
    self.assertEqual(spec2._indices_dtype, dtypes.int32)
    self.assertEqual(spec2._dense_shape_dtype, dtypes.int64)
    self.assertEqual(spec2._indices_shape.as_list(), [10])

  def testValueType(self):
    """The spec's value_type is `ops.IndexedSlices`."""
    spec1 = indexed_slices.IndexedSlicesSpec()
    self.assertEqual(spec1.value_type, ops.IndexedSlices)

  @parameterized.parameters([
      (indexed_slices.IndexedSlicesSpec(),
       (tensor_shape.TensorShape(None), dtypes.float32, dtypes.int64, None,
        tensor_shape.TensorShape([None]))),
      (indexed_slices.IndexedSlicesSpec(shape=[5, None, None]),
       (tensor_shape.TensorShape([5, None, None]), dtypes.float32,
        dtypes.int64, None, tensor_shape.TensorShape([None]))),
      (indexed_slices.IndexedSlicesSpec(
          dtype=dtypes.int32, dense_shape_dtype=dtypes.int64),
       (tensor_shape.TensorShape(None), dtypes.int32, dtypes.int64,
        dtypes.int64, tensor_shape.TensorShape([None]))),
      (indexed_slices.IndexedSlicesSpec(indices_shape=[100]),
       (tensor_shape.TensorShape(None), dtypes.float32, dtypes.int64, None,
        tensor_shape.TensorShape([100]))),
  ])  # pyformat: disable
  def testSerialize(self, spec, expected):
    """Compares `spec._serialize()` against the expected tuple via repr()."""
    serialization = spec._serialize()
    # TensorShape has an unconventional definition of equality, so we can't use
    # assertEqual directly here.  But repr() is deterministic and lossless for
    # the expected values, so we can use that instead.
    self.assertEqual(repr(serialization), repr(expected))

  @parameterized.parameters([
      (indexed_slices.IndexedSlicesSpec(dtype=dtypes.string), (
          tensor_spec.TensorSpec(None, dtypes.string),
          tensor_spec.TensorSpec([None], dtypes.int64),
      )),
      (indexed_slices.IndexedSlicesSpec(
          dtype=dtypes.string, dense_shape_dtype=dtypes.int32), (
              tensor_spec.TensorSpec(None, dtypes.string),
              tensor_spec.TensorSpec([None], dtypes.int64),
              tensor_spec.TensorSpec([None], dtypes.int32),
          )),
      (indexed_slices.IndexedSlicesSpec(
          shape=[5, 10, 15], dense_shape_dtype=dtypes.int32), (
              tensor_spec.TensorSpec([None, 10, 15], dtypes.float32),
              tensor_spec.TensorSpec([None], dtypes.int64),
              tensor_spec.TensorSpec([3], dtypes.int32),
          )),
      (indexed_slices.IndexedSlicesSpec(
          shape=[5, 10, 15], dense_shape_dtype=dtypes.int32,
          indices_shape=[20]), (
              tensor_spec.TensorSpec([20, 10, 15], dtypes.float32),
              tensor_spec.TensorSpec([20], dtypes.int64),
              tensor_spec.TensorSpec([3], dtypes.int32),
          )),
  ])
  def testComponentSpecs(self, spec, expected):
    """`_component_specs` lists the values, indices, then dense_shape specs."""
    self.assertEqual(spec._component_specs, expected)

  @parameterized.parameters([
      {
          "spec": indexed_slices.IndexedSlicesSpec(),
          "values": [3.0, 5.0],
          "indices": [5, 10]
      },
      {
          "spec":
              indexed_slices.IndexedSlicesSpec(dense_shape_dtype=dtypes.int32),
          "values": [3.0, 5.0],
          "indices": [5, 10],
          "dense_shape": [100]
      },
  ])
  def testToFromComponents(self, spec, indices, values, dense_shape=None):
    """Round-trips an IndexedSlices through `_to_components`/`_from_components`.

    Args:
      spec: the `IndexedSlicesSpec` under test.
      indices: indices for the slices.
      values: values for the slices.
      dense_shape: optional dense shape; None when the spec has none.
    """
    # IndexedSlices is constructed as (values, indices, dense_shape), matching
    # every other construction in this file; the components come back in the
    # same values-first order that testComponentSpecs expects.
    x = ops.IndexedSlices(values, indices, dense_shape)
    actual_components = spec._to_components(x)
    if dense_shape is None:
      self.assertAllTensorsEqual(actual_components, [values, indices])
    else:
      self.assertAllTensorsEqual(actual_components,
                                 [values, indices, dense_shape])
    st_reconstructed = spec._from_components(actual_components)
    self.assertAllEqual(x.indices, st_reconstructed.indices)
    self.assertAllEqual(x.values, st_reconstructed.values)
    if dense_shape is None:
      self.assertIs(st_reconstructed.dense_shape, None)
    else:
      self.assertAllEqual(x.dense_shape, st_reconstructed.dense_shape)

  @test_util.run_v1_only("IndexedSlicesValue is deprecated in v2")
  def testFromNumpyComponents(self):
    """NumPy components are rebuilt into an `IndexedSlicesValue`."""
    indices = np.array([3, 8])
    values = np.array([1.0, 9.0])
    dense_shape = np.array([100])

    # Spec with a dense shape: three components.
    spec1 = indexed_slices.IndexedSlicesSpec(dense_shape_dtype=dtypes.int32)
    st1 = spec1._from_components((values, indices, dense_shape))
    self.assertIsInstance(st1, indexed_slices.IndexedSlicesValue)
    self.assertAllEqual(st1.indices, indices)
    self.assertAllEqual(st1.values, values)
    self.assertAllEqual(st1.dense_shape, dense_shape)

    # Spec without a dense shape: two components, dense_shape stays None.
    spec2 = indexed_slices.IndexedSlicesSpec()
    st2 = spec2._from_components((values, indices))
    self.assertIsInstance(st2, indexed_slices.IndexedSlicesValue)
    self.assertAllEqual(st2.indices, indices)
    self.assertAllEqual(st2.values, values)
    self.assertIs(st2.dense_shape, None)
|
|
|
|
|
|
class NodeDefConstructorTest(test_util.TensorFlowTestCase):
  """Tests for the `ops._NodeDef` helper."""

  def testNoArgs(self):
    """A NodeDef built from only op type and name carries just those fields."""
    built = ops._NodeDef("None", "bar")
    self.assertProtoEquals("op: 'None' name: 'bar'", built)
|
|
|
|
|
|
def _apply_op(g, *args, **kwargs):
  """Creates an op on `g` and returns its output(s).

  Args:
    g: object exposing `create_op`; all remaining arguments are forwarded.
    *args: positional arguments for `g.create_op`.
    **kwargs: keyword arguments for `g.create_op`.

  Returns:
    The single output when the op has exactly one output, otherwise the
    op's whole `outputs` collection.
  """
  created = g.create_op(*args, **kwargs)
  # Unwrap the single-output case so callers get the output directly.
  return created.outputs[0] if len(created.outputs) == 1 else created.outputs
|
|
|
|
|
|
class OperationTest(test_util.TensorFlowTestCase):
|
|
|
|
@test_util.run_deprecated_v1
|
|
def testNoInputs(self):
  """An input-less op exposes two outputs with the expected metadata."""
  producer = test_ops.float_output_string_output(name="myop").a.op
  self.assertEqual(2, len(producer.values()))
  self.assertEqual(0, len(producer.inputs))
  self.assertEqual("myop", producer.name)

  float_out, string_out = producer.values()

  # First output: float32, value index 0, referenced by the bare op name.
  self.assertEqual(dtypes.float32, float_out.dtype)
  self.assertEqual(producer, float_out.op)
  self.assertEqual(0, float_out._value_index)
  self.assertEqual(0, len(float_out.consumers()))
  self.assertEqual("myop", float_out._as_node_def_input())

  # Second output: string, value index 1, referenced as "<name>:1".
  self.assertEqual(dtypes.string, string_out.dtype)
  self.assertEqual(producer, string_out.op)
  self.assertEqual(1, string_out._value_index)
  self.assertEqual(0, len(string_out.consumers()))
  self.assertEqual("myop:1", string_out._as_node_def_input())

  self.assertProtoEquals("op:'FloatOutputStringOutput' name:'myop'",
                         producer.node_def)
|
|
|
|
@test_util.run_deprecated_v1
|
|
def testNoOutputs(self):
  """An output-less op records its input and is listed as its consumer."""
  producer = test_ops.float_output(name="myop1").op
  float_out, = producer.values()
  consumer = test_ops.float_input(float_out, name="myop2")
  self.assertEqual(0, len(consumer.values()))
  self.assertEqual(1, len(consumer.inputs))
  self.assertIs(float_out, consumer.inputs[0])

  # The producer's output is consumed exactly once, by the second op.
  self.assertEqual(1, len(float_out.consumers()))
  self.assertEqual(consumer, float_out.consumers()[0])

  self.assertProtoEquals("op:'FloatOutput' name:'myop1'", producer.node_def)
  self.assertProtoEquals("op:'FloatInput' name:'myop2' input:'myop1'",
                         consumer.node_def)
|
|
|
|
@test_util.run_deprecated_v1
|
|
def testInputsAndOutputs(self):
  """Consumers are tracked per use when an op has inputs and outputs."""
  float_producer = test_ops.float_output(name="myop1").op
  self.assertEqual(1, len(float_producer.values()))
  float1_out, = float_producer.values()

  pair_producer = test_ops.float_output_string_output(name="myop2").a.op
  self.assertEqual(2, len(pair_producer.values()))
  float2_out, string2_out = pair_producer.values()

  # Note that the string output is consumed twice here.
  sink = test_ops.foo2(
      float1_out, string2_out, string2_out, name="myop3").d.op
  self.assertEqual(2, len(sink.values()))

  self.assertEqual(1, len(float1_out.consumers()))
  self.assertEqual(sink, float1_out.consumers()[0])

  # The float output of the second producer is never used.
  self.assertEqual(0, len(float2_out.consumers()))

  # One consumer entry per use of the string output.
  self.assertEqual(2, len(string2_out.consumers()))
  self.assertEqual(sink, string2_out.consumers()[0])
  self.assertEqual(sink, string2_out.consumers()[1])

  self.assertProtoEquals("""
  op:'Foo2' name:'myop3'
  input:'myop1' input:'myop2:1' input:'myop2:1'
  """, sink.node_def)
|
|
|
|
def testDeviceObject(self):
  """_set_device accepts both a device string and a DeviceSpec."""
  # Device given as a string.
  string_device_op = ops.Operation(
      ops._NodeDef("None", "myop"), ops.Graph(), [], [])
  string_device_op._set_device("/job:goo/device:GPU:0")
  self.assertProtoEquals(
      "op:'None' name:'myop' device:'/job:goo/device:GPU:0' ",
      string_device_op.node_def)

  # Device given as a DeviceSpec object.
  spec_device_op = ops.Operation(
      ops._NodeDef("None", "op2"), ops.Graph(), [], [])
  device_spec = pydev.DeviceSpec(
      job="muu", device_type="CPU", device_index=0)
  spec_device_op._set_device(device_spec)
  self.assertProtoEquals(
      "op:'None' name:'op2' device:'/job:muu/device:CPU:0'",
      spec_device_op.node_def)
|
|
|
|
def testReferenceInput(self):
  """Ops can take ref-typed and non-ref outputs of another op as inputs."""
  graph = ops.Graph()
  producer = ops.Operation(
      ops._NodeDef("RefOutputFloatOutput", "op1"), graph, [],
      [dtypes.float32_ref, dtypes.float32])
  self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'",
                         producer.node_def)
  self.assertEqual([], list(producer.inputs))
  ref_out, plain_out = producer.values()

  # NOTE(mrry): Must specify input_types to preserve ref-typed input.
  typed_consumer = ops.Operation(
      ops._NodeDef("RefInputFloatInput", "op2"),
      graph, [ref_out, plain_out], [],
      input_types=[dtypes.float32_ref, dtypes.float32])
  self.assertProtoEquals(
      "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
      typed_consumer.node_def)
  self.assertEqual([ref_out, plain_out], list(typed_consumer.inputs))

  # Same inputs, this time without explicit input_types.
  untyped_consumer = ops.Operation(
      ops._NodeDef("TwoFloatInputs", "op3"), graph, [ref_out, plain_out], [])
  self.assertProtoEquals(
      "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
      untyped_consumer.node_def)
|
|
|
|
def testInvalidNames(self):
  """Constructing an Operation with a malformed name raises ValueError."""
  graph = ops.Graph()
  # Empty, leading '_', leading '-', leading '/', and a tensor-style suffix.
  for bad_name in ("", "_invalid", "-invalid", "/invalid", "invalid:0"):
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", bad_name), graph)
|
|
|
|
@test_util.run_deprecated_v1
|
|
def testNoShapeFunction(self):
  """The result of `test_ops.a()` reports an unknown shape."""
  result = test_ops.a()
  self.assertEqual(tensor_shape.unknown_shape(), result.get_shape())
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testConvertToTensorNestedArray(self):
  """A nested list converts to a tensor with matching shape and values."""
  nested = [[2], [3], [5], [7]]
  converted = ops.convert_to_tensor(nested)
  self.assertAllEqual((4, 1), converted.get_shape().as_list())
  self.assertAllEqual(nested, self.evaluate(converted))
|
|
|
|
def testShapeTuple(self):
  """A scalar constant has the empty tuple as its shape tuple."""
  with self.cached_session():
    scalar = constant_op.constant(1)
    shape_tuple = scalar._shape_tuple()  # pylint: disable=protected-access
    self.assertEqual(shape_tuple, ())
|
|
|
|
def testConvertToTensorEager(self):
  """In eager mode, constants and conversions all yield EagerTensors."""
  with context.eager_mode():
    t = constant_op.constant(1)
    # assertIsInstance reports the offending type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(t, ops.EagerTensor)
    # Converting an existing EagerTensor.
    converted = ops.convert_to_tensor(t)
    self.assertIsInstance(converted, ops.EagerTensor)
    # Converting a plain Python value.
    converted = ops.convert_to_tensor(1)
    self.assertIsInstance(converted, ops.EagerTensor)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testConvertToTensorNestedTuple(self):
  """A nested tuple converts to a tensor with matching shape and values."""
  nested = ((2,), (3,), (5,), (7,))
  converted = ops.convert_to_tensor(nested)
  self.assertAllEqual((4, 1), converted.get_shape().as_list())
  # The values check runs on a second, fresh conversion of the same input.
  reconverted = ops.convert_to_tensor(nested)
  self.assertAllEqual(nested, self.evaluate(reconverted))
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testConvertToTensorNestedTensors(self):
  """Lists of tensors, and lists of lists of tensors, pack into one tensor."""
  expected = ((2,), (3,), (5,), (7,))

  # One tensor per row.
  row_tensors = [constant_op.constant(row) for row in expected]
  packed = ops.convert_to_tensor(row_tensors)
  self.assertAllEqual((4, 1), packed.get_shape().as_list())
  self.assertAllEqual(expected, self.evaluate(packed))

  # One tensor per element.
  element_tensors = [
      [constant_op.constant(v) for v in row] for row in expected
  ]
  packed = ops.convert_to_tensor(element_tensors)
  self.assertAllEqual((4, 1), packed.get_shape().as_list())
  self.assertAllEqual(expected, self.evaluate(packed))
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testConvertToTensorNestedMix(self):
  """A mix of lists, tuples and tensors converts into a single tensor."""
  mixed = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7]))
  converted = ops.convert_to_tensor(mixed)
  self.assertAllEqual((4, 1), converted.get_shape().as_list())
  self.assertAllEqual(((2,), (3,), (5,), (7,)), self.evaluate(converted))
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testConvertToTensorPreferred(self):
  """preferred_dtype is applied when compatible and ignored otherwise."""
  # Integers can be represented in the preferred float32 dtype.
  converted = ops.convert_to_tensor(
      [2, 3, 5, 7], preferred_dtype=dtypes.float32)
  self.assertEqual(dtypes.float32, converted.dtype)

  # Convert empty tensor to anything.
  converted = ops.convert_to_tensor([], preferred_dtype=dtypes.int64)
  self.assertEqual(dtypes.int64, converted.dtype)

  # The preferred dtype is a type error and will convert to
  # float32 instead.
  converted = ops.convert_to_tensor([1.23], preferred_dtype=dtypes.int64)
  self.assertEqual(dtypes.float32, converted.dtype)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testConvertToInvalidTensorType(self):
  """Forcing a float value into int64 via `dtype` raises TypeError."""
  with self.assertRaises(TypeError):
    # Forcing an invalid dtype should fail with a type error.
    float_values = [1.23]
    ops.convert_to_tensor(float_values, dtype=dtypes.int64)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testConvertToLongLongTensorType(self):
  """The np.prod of a shape tuple converts to an int64 tensor."""
  # Get a numpy array of dtype NPY_LONGLONG
  shape_product = np.prod(constant_op.constant([1])._shape_tuple())
  converted = ops.convert_to_tensor(shape_product, dtype=dtypes.int64)
  self.assertEqual(dtypes.int64, converted.dtype)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testConvertToTensorFromInvalidTensor(self):
  """Converting a float32 tensor with dtype=int32 raises ValueError."""
  float_tensor = constant_op.constant(42.0, dtype=dtypes.float32)
  with self.assertRaises(ValueError):
    ops.convert_to_tensor(float_tensor, dtype=dtypes.int32)
|
|
|
|
@test_util.run_in_graph_and_eager_modes
|
|
def testConvertToTensorProtocol(self):
  """Objects defining `__tf_tensor__` are accepted by convert_to_tensor."""

  class TensorCompatible:
    """Minimal object that converts itself to the constant (1, 2, 3)."""

    def __tf_tensor__(self, dtype=None, name=None):
      return constant_op.constant((1, 2, 3), dtype=dtype, name=name)

  compatible = TensorCompatible()

  converted = ops.convert_to_tensor(compatible, dtype=dtypes.int32)
  self.assertEqual(converted.dtype, dtypes.int32)
  self.assertAllEqual((1, 2, 3), self.evaluate(converted))
|
|
|
|
@test_util.run_deprecated_v1
|
|
def testNoConvert(self):
  """Passing an Operation to convert_to_tensor raises TypeError."""
  # Operation cannot be converted to Tensor.
  no_op = control_flow_ops.no_op()
  expected_message = "can't convert Operation '.+' to Tensor"
  with self.assertRaisesRegex(TypeError, expected_message):
    ops.convert_to_tensor(no_op)
|
|
|
|
def testStr(self):
  """str() of an Operation equals str() of the NodeDef it was built from."""
  source_def = ops._NodeDef("None", "op1")
  operation = ops.Operation(source_def, ops.Graph(), [], [dtypes.float32])
  self.assertEqual(str(source_def), str(operation))
|
|
|
|
def testRepr(self):
  """repr(op) reports the op's name and type."""
  node_def = ops._NodeDef("None", "op1")
  operation = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32])
  self.assertEqual("<tf.Operation 'op1' type=None>", repr(operation))
|
|
|
|
@test_util.run_deprecated_v1
def testGetAttr(self):
  """Checks `get_attr` for each attr of `test_ops.default_attrs()`.

  Value attrs are compared by equality; dtype attrs are also checked by
  class, and a function attr and a missing attr are covered at the end.
  """
  op = test_ops.default_attrs()

  # (attr name, expected Python value) pairs, checked in this order.
  expected_by_name = (
      ("string_val", b"abc"),
      ("string_list_val", [b"abc", b""]),
      ("int_val", 123),
      ("int_list_val", [1, 2, 3]),
      ("float_val", 10.0),
      ("float_list_val", [10.0]),
      ("bool_val", True),
      ("bool_list_val", [True, False]),
      ("shape_val", tensor_shape.as_shape([2, 1]).as_proto()),
      ("shape_list_val", [
          tensor_shape.as_shape([]).as_proto(),
          tensor_shape.as_shape([1]).as_proto(),
      ]),
      ("tensor_val", tensor_util.make_tensor_proto(1, dtypes.int32)),
      ("tensor_list_val", [tensor_util.make_tensor_proto(1, dtypes.int32)]),
  )
  for attr_name, expected in expected_by_name:
    self.assertEqual(op.get_attr(attr_name), expected)

  # First check that type_val is a DType, because the assertEqual will work
  # no matter what since DType overrides __eq__
  type_val = op.get_attr("type_val")
  self.assertIsInstance(type_val, dtypes.DType)
  self.assertEqual(type_val, dtypes.int32)

  type_list_val = op.get_attr("type_list_val")
  self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val))
  self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32])

  @function.Defun(dtypes.float32, func_name="MyFunc")
  def func(x):
    return x

  # A function-valued attr comes back as a NameAttrList naming the function.
  op = test_ops.func_attr(func)
  expected_func_attr = attr_value_pb2.NameAttrList(name="MyFunc")
  self.assertEqual(op.get_attr("f"), expected_func_attr)

  # Try fetching missing attr
  missing_attr_message = "Operation 'FuncAttr' has no attr named 'FakeAttr'."
  with self.assertRaisesRegex(ValueError, missing_attr_message):
    op.get_attr("FakeAttr")
|
|
|
|
# TODO(b/65162920): remove this test when users who are directly mutating the
|
|
# node_def have been updated to proper usage.
|
|
@test_util.run_deprecated_v1
def testSetAttr(self):
  """A value stored with `_set_attr` can be read back with `get_attr`."""
  int_attr_op = test_ops.int_attr().op
  new_value = attr_value_pb2.AttrValue(i=2)
  int_attr_op._set_attr("foo", new_value)
  # TODO(skyewm): add node_def check
  self.assertEqual(int_attr_op.get_attr("foo"), 2)
|
|
|
|
# TODO(nolivia): test all error cases
|
|
def testAddControlInput(self):
  """Control inputs are de-duplicated and mirrored in control outputs."""
  # pylint: disable=protected-access
  with ops.Graph().as_default():
    first, second, target = [
        constant_op.constant(value).op for value in (1, 2, 3)
    ]

    target._add_control_input(first)
    self.assertEqual(target.control_inputs, [first])

    # Adding the same op a second time must leave the list unchanged.
    target._add_control_input(first)
    self.assertEqual(target.control_inputs, [first])

    # Duplicates within a single batch are collapsed as well.
    target._add_control_inputs([first, second, second])
    self.assertEqual(target.control_inputs, [first, second])
    self.assertEqual(first._control_outputs, [target])
  # pylint: enable=protected-access
|
|
|
|
@test_util.run_deprecated_v1
def testRemoveAllControlInputs(self):
  """`_remove_all_control_inputs` drops control edges, keeps data edges."""
  a = constant_op.constant(1)
  with ops.control_dependencies([a]):
    b = constant_op.constant(2)
  c = constant_op.constant(3)
  d = constant_op.constant(4)
  e = constant_op.constant(5)
  with ops.control_dependencies([a, c]):
    total = d + e

  # Starting state: only the ops built under control_dependencies have any.
  initial_control_inputs = (
      (a, []),
      (b, [a.op]),
      (total, [a.op, c.op]),
  )
  for tensor, expected in initial_control_inputs:
    self.assertEqual(tensor.op.control_inputs, expected)

  # Removal works whether the op had zero, one or several control inputs.
  for tensor in (a, b, total):
    tensor.op._remove_all_control_inputs()  # pylint: disable=protected-access
    self.assertEqual(tensor.op.control_inputs, [])

  self.assertEqual(list(total.op.inputs), [d, e])
|
|
|
|
@test_util.run_deprecated_v1
def testControlInputCycle(self):
  """Evaluating an op that sits on a control-dependency cycle fails."""
  graph = ops.Graph()
  with graph.as_default():
    z = constant_op.constant(0)
    x = constant_op.constant(1)
    y = constant_op.constant(2)
    y.op._add_control_input(z.op)  # pylint: disable=protected-access
    y.op._add_control_input(x.op)  # pylint: disable=protected-access
    # x now depends on y while y already depends on x: a two-node cycle.
    x.op._add_control_input(y.op)  # pylint: disable=protected-access
  # The session object itself is never used (evaluation goes through
  # self.evaluate), so it is not bound to a name.
  with self.session(graph=graph):
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Graph is invalid, contains a cycle with 2 nodes"):
      self.evaluate(x)
|
|
|
|
def testUpdateInput(self):
  """`_update_input` rewires a data edge and keeps consumers in sync."""
  g = ops.Graph()
  with g.as_default():
    x = constant_op.constant(1)
    y = constant_op.constant(2)
    z = x + y

  def check_state(expected_inputs, x_consumers, y_consumers, expected_value):
    """Asserts z's inputs, the consumers of x and y, and z's value."""
    self.assertEqual(list(z.op.inputs), expected_inputs)
    self.assertEqual(x.consumers(), x_consumers)
    self.assertEqual(y.consumers(), y_consumers)
    # The session is only needed as the default session for self.evaluate,
    # so it is not bound to a name.
    with session.Session(graph=g):
      self.assertEqual(self.evaluate(z), expected_value)

  # Replace x with y on input 0: z becomes y + y.
  z.op._update_input(0, y)  # pylint: disable=protected-access
  check_state([y, y], [], [z.op, z.op], 4)

  # Restore x on input 0: back to x + y.
  z.op._update_input(0, x)  # pylint: disable=protected-access
  check_state([x, y], [z.op], [z.op], 3)

  # Re-setting input 1 to the tensor it already has changes nothing.
  z.op._update_input(1, y)  # pylint: disable=protected-access
  check_state([x, y], [z.op], [z.op], 3)
|
|
|
|
def testUpdateInputGraphError(self):
  """`_update_input` rejects a tensor that lives in a different graph."""
  source_graph = ops.Graph()
  target_graph = ops.Graph()
  with source_graph.as_default():
    foreign_tensor = constant_op.constant(1)
  with target_graph.as_default():
    local_tensor = constant_op.constant(2)
    product = local_tensor * 2
    with self.assertRaisesRegex(ValueError, "must be from the same graph"):
      product.op._update_input(0, foreign_tensor)  # pylint: disable=protected-access
|
|
|
|
def testUpdateInputTypeError(self):
  """A dtype-mismatched input is only reported when the graph is run."""
  g = ops.Graph()
  with g.as_default():
    # Creation order matters: the error message below names the string
    # constant as "Const_1" and the addition as "add".
    w = constant_op.constant(0)
    x = constant_op.constant("")
    y = constant_op.constant(1)
    z = y + w
    z.op._update_input(0, x)  # pylint: disable=protected-access
  # The session is only needed as the default for self.evaluate, so it is
  # not bound to a name.
  with session.Session(graph=g):
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Input 0 of node add was passed string from Const_1:0 incompatible "
        "with expected int32"):
      self.evaluate(z)
|
|
|
|
def testUpdateInputShapeError(self):
  """`_update_input` rejects a tensor whose shape is incompatible."""
  graph = ops.Graph()
  with graph.as_default():
    w = constant_op.constant(2, shape=[3, 1])
    x = constant_op.constant(0, shape=[3, 1])
    mismatched = constant_op.constant(1, shape=[2, 2])
    total = w + x
  expected_message = (
      r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]")
  with self.assertRaisesRegex(errors.InvalidArgumentError, expected_message):
    total.op._update_input(0, mismatched)  # pylint: disable=protected-access
|
|
|
|
def testUpdateInputOutOfRange(self):
  """Updating input index 1 on an op with no inputs is out of range."""
  graph = ops.Graph()
  with graph.as_default():
    constant = constant_op.constant(1)
  expected_message = (
      r"Cannot update edge. Input index \[1\] is greater than the number of "
      r"total inputs \[0\].")
  with self.assertRaisesRegex(errors.OutOfRangeError, expected_message):
    constant.op._update_input(1, constant)  # pylint: disable=protected-access
|
|
|
|
@test_util.enable_control_flow_v2
@test_util.run_v1_only("b/120545219")
def testAddWhileInput(self):
  """`_add_while_inputs` appends inputs up to what attr "T" allows."""

  @eager_function.defun
  def test():
    loop_result = control_flow_ops.while_loop(
        lambda x: x < 3, lambda x: x + 1, [1])
    while_op = loop_result.op
    self.assertEqual(while_op.type, "StatelessWhile")
    orig_num_inputs = len(while_op.inputs)

    # Make sure we can handle the while op having a control input.
    control_source = constant_op.constant(0).op
    while_op._add_control_input(control_source)

    float_input = constant_op.constant(1.0)
    bool_input = constant_op.constant(True)

    # Clear output shapes to bypass shape checking.
    while_op._set_shape_list_attr("output_shapes", [])
    # Declare room in "T" for exactly the two inputs added below.
    input_types = [t.dtype for t in while_op.inputs]
    input_types += [float_input.dtype, bool_input.dtype]
    while_op._set_type_list_attr("T", input_types)

    while_op._add_while_inputs([float_input, bool_input])
    # Can't add an edge beyond what's specified by "T"
    with self.assertRaises(errors.OutOfRangeError):
      while_op._add_while_inputs([bool_input])
    self.assertEqual(len(while_op.inputs), orig_num_inputs + 2)  # pylint: disable=g-deprecated-assert

  test()
|
|
|
|
@test_util.run_deprecated_v1
def testOpDef(self):
  """`op_def` exposes the registered name and input/output arg counts."""
  zero = constant_op.constant(0)
  one = constant_op.constant(1)
  total = zero + one

  const_def = zero.op.op_def
  self.assertEqual(const_def.name, "Const")
  self.assertEqual(len(const_def.input_arg), 0)
  self.assertEqual(len(const_def.output_arg), 1)

  # The addition may be registered under either "Add" or "AddV2".
  add_def = total.op.op_def
  self.assertRegex(add_def.name, "Add(V2)?")
  self.assertEqual(len(add_def.input_arg), 2)
  self.assertEqual(len(add_def.output_arg), 1)
|
|
|
|
def testInputFromDifferentGraphError(self):
  """Combining tensors from two different graphs raises ValueError."""
  first_graph = ops.Graph()
  second_graph = ops.Graph()
  with first_graph.as_default():
    foreign = constant_op.constant(1)
  with second_graph.as_default():
    local = constant_op.constant(2)
    with self.assertRaisesRegex(ValueError, "must be from the same graph"):
      local * foreign  # pylint: disable=pointless-statement
|
|
|
|
def testInputsAreImmutable(self):
  """`op.inputs` has no `append`, so inputs cannot be added in place."""
  graph = ops.Graph()
  with graph.as_default():
    source = test_ops.int_output()
    consumer = test_ops.int_input_int_output(source, name="myop").op
  expected_message = "'tuple' object has no attribute 'append'"
  with self.assertRaisesRegex(AttributeError, expected_message):
    consumer.inputs.append(None)
|
|
|
|
|
|
class CreateOpTest(test_util.TensorFlowTestCase):
  """Tests for `Graph.create_op`."""

  def testNodeDefArgs(self):
    """Names, inputs and devices end up in each op's NodeDef."""
    graph = ops.Graph()
    op1 = graph.create_op(
        "FloatOutput", [], [dtypes.float32], None, name="myop1")
    with graph.device("/device:GPU:0"):
      op2 = graph.create_op(
          "FloatOutputStringOutput", [], [dtypes.float32, dtypes.string],
          None, name="myop2")
    op3_inputs = [
        list(op1.values())[0],
        list(op2.values())[1],
        list(op2.values())[0],
    ]
    op3 = graph.create_op(
        "Foo3", op3_inputs, [dtypes.float32, dtypes.int32], None,
        name="myop3")

    # Only the op created inside the device scope carries a device.
    self.assertDeviceEqual(None, op1.device)
    self.assertDeviceEqual("/device:GPU:0", op2.device)
    self.assertDeviceEqual(None, op3.device)

    self.assertProtoEquals("name:'myop1' op:'FloatOutput'", op1.node_def)
    self.assertProtoEquals(
        "name:'myop2' op:'FloatOutputStringOutput' device:'/device:GPU:0'",
        op2.node_def)
    self.assertProtoEquals(
        "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo3'",
        op3.node_def)

  def testReferenceInput(self):
    """Ref and non-ref outputs can feed ref-typed and plain inputs."""
    graph = ops.Graph()
    producer = graph.create_op(
        "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
        name="op1")
    self.assertProtoEquals(
        "op:'RefOutputFloatOutput' name:'op1'", producer.node_def)
    ref_t, nonref_t = producer.values()

    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    ref_consumer = graph.create_op(
        "RefInputFloatInput", [ref_t, nonref_t], [],
        input_types=[dtypes.float32_ref, dtypes.float32],
        name="op2")
    self.assertProtoEquals(
        "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
        ref_consumer.node_def)

    plain_consumer = graph.create_op(
        "TwoFloatInputs", [ref_t, nonref_t], [], name="op3")
    self.assertProtoEquals(
        "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
        plain_consumer.node_def)

  def testFinalized(self):
    """A finalized graph refuses new ops until it is unfinalized."""
    graph = ops.Graph()
    graph.finalize()

    def add_op():
      graph.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")

    with self.assertRaises(RuntimeError):
      add_op()

    # Test unfinalize.
    graph._unsafe_unfinalize()
    add_op()
|
|
|
|
|
|
# NOTE(skyewm): these cases test the private Graph._create_op_from_tf_operation
|
|
# method. Arguably we should only test the public APIs that depend on this
|
|
# method. However, this logic is complex and tricky, and it can be difficult to
|
|
# ascertain if we have adequate coverage (e.g. a graph may run successfully if
|
|
# the control flow context isn't set properly, but a more complicated use case
|
|
# that might not be obvious to test will fail). Thus we instead explicitly test
|
|
# the low-level behavior.
|
|
class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase):
  """Tests for `Graph._create_op_from_tf_operation` and related helpers."""

  @test_util.run_deprecated_v1
  def testBasic(self):
    """A wrapped C op exposes the expected name, type, inputs and graph."""
    graph = ops.Graph()
    with graph.as_default():
      x = test_ops.int_output()
      node_def = ops._NodeDef("IntInputIntOutput", "myop")
      c_op = ops._create_c_op(graph, node_def, [x], [])
      op = graph._create_op_from_tf_operation(c_op)

    self.assertEqual(op.name, "myop")
    self.assertEqual(op.type, "IntInputIntOutput")
    self.assertEqual(len(op.outputs), 1)
    self.assertEqual(op.outputs[0].shape, tensor_shape.unknown_shape())
    self.assertEqual(list(op.inputs), [x])
    self.assertEqual(op.control_inputs, [])
    self.assertEqual(op.graph, graph)
    self.assertEqual(x.consumers(), [op])
    self.assertIsNotNone(op.traceback)
    # The op and its output are also reachable through name lookups.
    self.assertEqual(graph.get_operation_by_name("myop"), op)
    self.assertEqual(graph.get_tensor_by_name("myop:0"), op.outputs[0])

  def testShape(self):
    """The wrapped op's output carries the input's static shape."""
    graph = ops.Graph()
    with graph.as_default():
      x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
      node_def = ops._NodeDef("Identity", "myop")
      c_op = ops._create_c_op(graph, node_def, [x], [])
      op = graph._create_op_from_tf_operation(c_op)

    self.assertEqual(op.name, "myop")
    self.assertEqual(op.type, "Identity")
    self.assertEqual(len(op.outputs), 1)
    self.assertEqual(op.outputs[0].shape, tensor_shape.TensorShape([2, 3]))

  def testUniqueName(self):
    """Names taken by wrapped C ops are avoided by later Python ops."""
    graph = ops.Graph()
    with graph.as_default():
      c_op = ops._create_c_op(
          graph, ops._NodeDef("IntOutput", "myop"), [], [])
      c_op2 = ops._create_c_op(
          graph, ops._NodeDef("IntOutput", "myop_1"), [], [])
      op = graph._create_op_from_tf_operation(c_op)
      op2 = graph._create_op_from_tf_operation(c_op2)

      # Create ops with same names as op1 and op2. We expect the new names to
      # be uniquified.
      op3 = test_ops.int_output(name="myop").op
      op4 = test_ops.int_output(name="myop_1").op

    actual_names = [created.name for created in (op, op2, op3, op4)]
    expected_names = ["myop", "myop_1", "myop_2", "myop_1_1"]
    for actual, expected in zip(actual_names, expected_names):
      self.assertEqual(actual, expected)

  @test_util.run_v1_only("b/120545219")
  def testCond(self):
    """An op added inside a cond branch gets the cond's control flow context."""
    graph = ops.Graph()
    with graph.as_default():
      x = test_ops.int_output()

      def true_fn():
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "cond/myop"), [x], [])
        new_ops = graph._add_new_tf_operations()
        self.assertEqual(len(new_ops), 1)
        return x

      control_flow_ops.cond(x < 10, true_fn, lambda: x)

    op = graph.get_operation_by_name("cond/myop")
    self.assertIsNotNone(op)
    self.assertEqual(op.name, "cond/myop")
    self.assertEqual(op.type, "IntInput")
    self.assertEqual(op.outputs, [])
    # The input is expected to come from a Switch fed by x, not x directly.
    switch_op = op.inputs[0].op
    self.assertEqual(switch_op.type, "Switch")
    self.assertEqual(switch_op.inputs[0], x)
    self.assertEqual(op.graph, graph)
    # pylint: disable=protected-access
    context_of_op = op._get_control_flow_context()
    self.assertIsNotNone(context_of_op)
    self.assertEqual(op._get_control_flow_context().name, "cond/cond_text")
    # pylint: enable=protected-access

  @test_util.run_v1_only("b/120545219")
  def testWhileLoop(self):
    """An op added inside a loop body gets the loop's control flow context."""
    graph = ops.Graph()
    with graph.as_default():
      x = test_ops.int_output()

      def body(i):
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "myloop/myop"), [x], [])
        new_ops = graph._add_new_tf_operations()
        self.assertEqual(len(new_ops), 1)
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = graph.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    self.assertEqual(op.name, "myloop/myop")
    self.assertEqual(op.type, "IntInput")
    self.assertEqual(op.outputs, [])
    # The input is expected to come from an Enter op fed by x.
    enter_op = op.inputs[0].op
    self.assertEqual(enter_op.type, "Enter")
    self.assertEqual(list(enter_op.inputs), [x])
    self.assertEqual(op.graph, graph)
    # pylint: disable=protected-access
    context_of_op = op._get_control_flow_context()
    self.assertIsNotNone(context_of_op)
    self.assertEqual(op._get_control_flow_context().name,
                     "myloop/while_context")
    # pylint: enable=protected-access

  @test_util.run_v1_only("b/120545219")
  def testWhileLoopWithInternalControlDep(self):
    """A control dependency on an op inside the loop body is kept as-is."""
    graph = ops.Graph()
    with graph.as_default():
      x = test_ops.int_output()

      def body(i):
        c = constant_op.constant(1.0, name="c")
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "myloop/myop"), [x], [])
        with ops.control_dependencies([c]):
          new_ops = graph._add_new_tf_operations()
          self.assertEqual(len(new_ops), 1)
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = graph.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    inner_constant = graph.get_operation_by_name("myloop/c")
    self.assertIsNotNone(inner_constant)
    # Internal control dep is preserved
    self.assertEqual(op.control_inputs, [inner_constant])

  @test_util.run_v1_only("b/120545219")
  def testWhileLoopWithExternalControlDep(self):
    """A control dependency on an op outside the loop is replaced."""
    graph = ops.Graph()
    with graph.as_default():
      x = test_ops.int_output()
      outer_constant = constant_op.constant(1.0)

      def body(i):
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "myloop/myop"), [x], [])
        with ops.control_dependencies([outer_constant]):
          new_ops = graph._add_new_tf_operations()
          self.assertEqual(len(new_ops), 1)
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = graph.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    # External control dep is removed and replaced with internal control dep
    first_control_input = op.control_inputs[0]
    self.assertNotEqual(first_control_input, outer_constant.op)
    self.assertIsNotNone(first_control_input._get_control_flow_context())
|
|
|
|
|
|
class ApplyOpTest(test_util.TensorFlowTestCase):
  """Tests for the module-level `_apply_op` helper."""

  def testNodeDefArgs(self):
    """Checks return types, node-def input names and NodeDefs of the ops."""
    g = ops.Graph()
    t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
    with g.device("/device:GPU:0"):
      t2 = _apply_op(
          g, "TwoIntOutputs", [], [dtypes.int32, dtypes.int32], name="myop2")
    t3 = _apply_op(
        g,
        "Foo1", [t1, t2[1], t2[0]], [dtypes.float32, dtypes.int32],
        name="myop3")
    # assertIsInstance reports the offending object and expected type on
    # failure, unlike assertTrue(isinstance(...)).
    self.assertIsInstance(t1, ops.Tensor)
    self.assertIsInstance(t2, list)
    self.assertIsInstance(t3, list)
    self.assertIsInstance(t3[0], ops.Tensor)
    self.assertEqual("myop1", t1._as_node_def_input())
    self.assertEqual("myop2", t2[0]._as_node_def_input())
    self.assertEqual("myop2:1", t2[1]._as_node_def_input())
    self.assertEqual("myop3", t3[0]._as_node_def_input())
    # Validate that we got the right ops as well
    self.assertProtoEquals("name:'myop1' op:'FloatOutput'", t1.op.node_def)
    self.assertProtoEquals(
        "name:'myop2' op:'TwoIntOutputs' device:'/device:GPU:0'",
        t2[0].op.node_def)
    self.assertProtoEquals(
        "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo1'",
        t3[0].op.node_def)

  def testReferenceInput(self):
    """Ref and non-ref tensors can feed ref-typed and plain inputs."""
    g = ops.Graph()
    ref_t, nonref_t = _apply_op(
        g, "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
        name="op1")
    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'",
                           ref_t.op.node_def)
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    out_2 = _apply_op(
        g,
        "RefInputFloatInputIntOutput", [ref_t, nonref_t], [dtypes.int32],
        input_types=[dtypes.float32_ref, dtypes.float32],
        name="op2")
    self.assertProtoEquals(
        "op:'RefInputFloatInputIntOutput' name:'op2' input:'op1' input:'op1:1'",
        out_2.op.node_def)
    out_3 = _apply_op(
        g, "TwoFloatInputsIntOutput", [ref_t, nonref_t], [dtypes.int32],
        name="op3")
    self.assertProtoEquals(
        "op:'TwoFloatInputsIntOutput' name:'op3' input:'op1' input:'op1:1'",
        out_3.op.node_def)
|
|
|
|
|
|
class NameStackTest(test_util.TensorFlowTestCase):
  """Tests for `Graph.unique_name` and `Graph.name_scope`."""

  def testBasics(self):
    """Walks unique_name through top-level, nested and reset name scopes."""
    g = ops.Graph()

    def expect_next(expected, base):
      """Asserts `expected` is previewed, then claimed, for `base`."""
      self.assertEqual(expected, g.unique_name(base, mark_as_used=False))
      self.assertEqual(expected, g.unique_name(base))

    # Previewing with mark_as_used=False does not consume the name.
    self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
    expect_next("foo", "foo")
    expect_next("foo_1", "foo")
    expect_next("foo_2", "foo")
    expect_next("foo_1_1", "foo_1")
    expect_next("foo_1_2", "foo_1")
    expect_next("foo_1_2_1", "foo_1_2")
    with g.name_scope("bar"):
      expect_next("bar/foo", "foo")
      expect_next("bar/foo_1", "foo")
      with g.name_scope(None):
        expect_next("foo_3", "foo")
      with g.name_scope("baz"):
        expect_next("bar/baz/foo", "foo")
        expect_next("bar/baz/foo_1", "foo")
      with g.name_scope("baz"):
        expect_next("bar/baz_1/foo", "foo")
        expect_next("bar/baz_1/foo_1", "foo")
    with g.name_scope("quux"):
      expect_next("quux/foo", "foo")
    with g.name_scope("bar"):
      with g.name_scope("baz"):
        expect_next("bar_1/baz/foo", "foo")
    expect_next("foo_4", "foo")
    expect_next("bar_2", "bar")

  def testBackslashAndDashRegex(self):
    """Scope names containing a backslash and a dash are accepted."""
    # GitHub issue 39019, all should pass
    tricky_name = "n_CatCntc-campaign\\c_campaign"
    g = ops.Graph()
    with g.name_scope(tricky_name):
      pass
    with g.name_scope("foo"):
      with g.name_scope(tricky_name):
        pass
    with g.name_scope(tricky_name):
      with g.name_scope("foo"):
        pass

  @test_util.run_deprecated_v1
  def testNameAndVariableScope(self):
    """Name scopes opened inside a variable scope nest under it."""
    with self.cached_session() as sess:
      graph = sess.graph

      def expect_next_foo(expected):
        """Asserts `expected` is previewed, then claimed, for "foo"."""
        self.assertEqual(
            expected, graph.unique_name("foo", mark_as_used=False))
        self.assertEqual(expected, graph.unique_name("foo"))

      with graph.name_scope("l0"):
        with variable_scope.variable_scope("l1"):
          with graph.name_scope("l1") as scope:
            self.assertEqual("l0/l1/l1/", scope)
            expect_next_foo("l0/l1/l1/foo")
          with graph.name_scope("l2") as scope:
            self.assertEqual("l0/l1/l2/", scope)
            expect_next_foo("l0/l1/l2/foo")

  def testOutOfOrderUniqueName(self):
    """A suffixed name claimed first is skipped by later numbering."""
    g = ops.Graph()
    requests = (
        ("foo_2", "foo_2"),
        ("foo", "foo"),
        ("foo_1", "foo"),
        ("foo_3", "foo"),
    )
    for expected, requested in requests:
      self.assertEqual(expected, g.unique_name(requested))

  def testUniqueNameCaseInsensitivity(self):
    """Names and scopes differing only in case still get uniquified."""
    g = ops.Graph()
    self.assertEqual("foo", g.unique_name("foo"))
    self.assertEqual("Foo_1", g.unique_name("Foo"))
    with g.name_scope("bar"):
      self.assertEqual("bar/foo", g.unique_name("foo"))
    with g.name_scope("Bar"):
      self.assertEqual("Bar_1/foo", g.unique_name("foo"))

  def testInvalidNameRaisesError(self):
    """Invalid top-level scope names raise ValueError."""
    g = ops.Graph()
    with g.name_scope(""):  # Should not raise
      pass
    with g.name_scope("foo/"):  # Should not raise
      with g.name_scope("_bar"):  # Should not raise
        pass
    # "foo:0" is rejected, and so is "_bar" when used at the top level.
    for invalid_name in ("foo:0", "_bar"):
      with self.assertRaises(ValueError):
        with g.name_scope(invalid_name):
          pass
|
|
|
|
|
|
class NameTest(test_util.TensorFlowTestCase):
  """Tests for how ops and their outputs are named."""

  def testGenerateName(self):
    """Default names come from the op type; explicit names are used as given."""
    g = ops.Graph()

    two_outputs = g.create_op(
        "TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
    self.assertEqual("TwoFloatOutputs", two_outputs.name)
    self.assertEqual("TwoFloatOutputs:0", two_outputs.outputs[0].name)
    self.assertEqual("TwoFloatOutputs:1", two_outputs.outputs[1].name)

    first = g.create_op("FloatOutput", [], [dtypes.float32])
    self.assertEqual("FloatOutput", first.name)
    self.assertEqual("FloatOutput:0", first.outputs[0].name)

    # A second op of the same type gets a numeric suffix.
    second = g.create_op("FloatOutput", [], [dtypes.float32])
    self.assertEqual("FloatOutput_1", second.name)
    self.assertEqual("FloatOutput_1:0", second.outputs[0].name)

    named = g.create_op("FloatOutput", [], [dtypes.float32], name="my_op")
    self.assertEqual("my_op", named.name)
    self.assertEqual("my_op:0", named.outputs[0].name)

  def testNameScope(self):
    """Checks scope strings and the op names created under each scope."""
    g = ops.Graph()

    def new_op_name(**kwargs):
      """Creates a FloatOutput op in `g` and returns its name."""
      return g.create_op("FloatOutput", [], [dtypes.float32], **kwargs).name

    with g.name_scope("foo") as foo:
      self.assertEqual("foo/", foo)
      with g.name_scope("foo2") as foo2:
        self.assertEqual("foo/foo2/", foo2)
      with g.name_scope(None) as empty1:
        self.assertEqual("", empty1)
        with g.name_scope("foo3") as foo3:
          self.assertEqual("foo3/", foo3)
      with g.name_scope("") as empty2:
        self.assertEqual("", empty2)

    self.assertEqual("FloatOutput", new_op_name())
    with g.name_scope("bar") as scope:
      self.assertEqual("bar/FloatOutput", new_op_name())
      self.assertEqual("bar/FloatOutput_1", new_op_name())
      # If you use the value from "with .. as", that value is used as-is.
      self.assertEqual("bar", new_op_name(name=scope))
    with g.name_scope("baz") as scope:
      with g.name_scope("quux"):
        self.assertEqual("baz/quux/FloatOutput", new_op_name())
      # If you use the value from the enclosing "with .. as", nothing is
      # pushed.
      with g.name_scope(scope):
        self.assertEqual("baz/FloatOutput", new_op_name())
        self.assertEqual("baz", new_op_name(name=scope))
        self.assertEqual("trailing", new_op_name(name="trailing/"))
    with g.name_scope("bar"):
      self.assertEqual("bar_1/FloatOutput", new_op_name())
    with g.name_scope("bar/"):
      self.assertEqual("bar/FloatOutput_2", new_op_name())
|
|
|
|
|
|
class DeviceTest(test_util.TensorFlowTestCase):
|
|
|
|
def testNoDevice(self):
  """An op created outside any device scope has no device set."""
  graph = ops.Graph()
  created = graph.create_op("FloatOutput", [], [dtypes.float32])
  self.assertDeviceEqual(None, created.device)
  graph_def = graph.as_graph_def()
  self.assertProtoEqualsVersion("""
    node { name: "FloatOutput" op: "FloatOutput" }
  """, graph_def)
|
|
|
|
def testEagerBackingDevice(self):
  """An eager tensor reports the CPU as its device and backing device."""
  cpu_device = "/device:CPU:0"
  with context.eager_mode():
    with ops.device(cpu_device):
      tensor = constant_op.constant(1.0)
      for reported in (tensor.device, tensor.backing_device):
        self.assertRegex(reported, "/device:CPU:0")
|
|
|
|
def testDevicePartialString(self):
  """A partial device string is recorded on the node unchanged."""
  graph = ops.Graph()
  with graph.device("/job:worker/replica:2"):
    graph.create_op("FloatOutput", [], [dtypes.float32])
  graph_def = graph.as_graph_def()
  self.assertProtoEqualsVersion("""
    node { name: "FloatOutput" op: "FloatOutput"
           device: "/job:worker/replica:2" }
  """, graph_def)
|
|
|
|
def testDeviceFull(self):
  """A fully specified DeviceSpec is rendered as the full device string."""
  graph = ops.Graph()
  full_spec = pydev.DeviceSpec(
      job="worker", replica=2, task=0, device_type="CPU", device_index=3)
  with graph.device(full_spec):
    graph.create_op("FloatOutput", [], [dtypes.float32])
  graph_def = graph.as_graph_def()
  self.assertProtoEqualsVersion("""
    node { name: "FloatOutput" op: "FloatOutput"
           device: "/job:worker/replica:2/task:0/device:CPU:3" }
  """, graph_def)
|
|
|
|
def testNesting(self):
  """An inner device scope applies only while it is active."""
  graph = ops.Graph()

  def add_float_output():
    """Adds one FloatOutput op under the currently active device scope."""
    graph.create_op("FloatOutput", [], [dtypes.float32])

  with graph.device("/job:worker/replica:2"):
    add_float_output()
    with graph.device("/job:worker/replica:3/task:0"):
      add_float_output()
    # Back in the outer scope once the inner one exits.
    add_float_output()
  graph_def = graph.as_graph_def()
  self.assertProtoEqualsVersion("""
    node { name: "FloatOutput" op: "FloatOutput"
           device: "/job:worker/replica:2" }
    node { name: "FloatOutput_1" op: "FloatOutput"
           device: "/job:worker/replica:3/task:0" }
    node { name: "FloatOutput_2" op: "FloatOutput"
           device: "/job:worker/replica:2" }
  """, graph_def)
|
|
|
|
def testNestingString(self):
  """Nested string device scopes: the inner one wins while active."""
  graph = ops.Graph()
  outer_device = "/job:worker/replica:2"
  inner_device = "/job:worker/replica:3/task:0"
  with graph.device(outer_device):
    graph.create_op("FloatOutput", [], [dtypes.float32])
    with graph.device(inner_device):
      graph.create_op("FloatOutput", [], [dtypes.float32])
    graph.create_op("FloatOutput", [], [dtypes.float32])
  graph_def = graph.as_graph_def()
  self.assertProtoEqualsVersion("""
    node { name: "FloatOutput" op: "FloatOutput"
           device: "/job:worker/replica:2" }
    node { name: "FloatOutput_1" op: "FloatOutput"
           device: "/job:worker/replica:3/task:0" }
    node { name: "FloatOutput_2" op: "FloatOutput"
           device: "/job:worker/replica:2" }
  """, graph_def)
|
|
|
|
def testNestingOverrideGpuCpu(self):
  """An inner GPU scope overrides the outer CPU scope only while active."""
  graph = ops.Graph()
  cpu_device = "/job:worker/replica:2/device:CPU:1"
  gpu_device = "/job:worker/replica:2/device:GPU:2"
  with graph.device(cpu_device):
    graph.create_op("FloatOutput", [], [dtypes.float32])
    with graph.device(gpu_device):
      graph.create_op("FloatOutput", [], [dtypes.float32])
    graph.create_op("FloatOutput", [], [dtypes.float32])
  graph_def = graph.as_graph_def()
  self.assertProtoEqualsVersion("""
    node { name: "FloatOutput" op: "FloatOutput"
           device: "/job:worker/replica:2/device:CPU:1" }
    node { name: "FloatOutput_1" op: "FloatOutput"
           device: "/job:worker/replica:2/device:GPU:2" }
    node { name: "FloatOutput_2" op: "FloatOutput"
           device: "/job:worker/replica:2/device:CPU:1" }
  """, graph_def)
|
|
|
|
def testNestingWithMergeDeviceFunction(self):
  """Nested `merge_device` scopes combine fields with the enclosing ones."""
  graph = ops.Graph()

  def add_float_output():
    """Adds one FloatOutput op under the currently active device scopes."""
    graph.create_op("FloatOutput", [], [dtypes.float32])

  with graph.device(pydev.merge_device("/device:GPU:0")):
    add_float_output()
    with graph.device(pydev.merge_device("/job:worker")):
      add_float_output()
      with graph.device(pydev.merge_device("/device:CPU:0")):
        add_float_output()
        with graph.device(pydev.merge_device("/job:ps")):
          add_float_output()
          # Merging None leaves the enclosing device untouched.
          with graph.device(pydev.merge_device(None)):
            add_float_output()

  graph_def = graph.as_graph_def()
  self.assertProtoEqualsVersion("""
    node { name: "FloatOutput" op: "FloatOutput"
           device: "/device:GPU:0" }
    node { name: "FloatOutput_1" op: "FloatOutput"
           device: "/job:worker/device:GPU:0" }
    node { name: "FloatOutput_2" op: "FloatOutput"
           device: "/job:worker/device:CPU:0" }
    node { name: "FloatOutput_3" op: "FloatOutput"
           device: "/job:ps/device:CPU:0" }
    node { name: "FloatOutput_4" op: "FloatOutput"
           device: "/job:ps/device:CPU:0" }
  """, graph_def)
|
|
|
|
def testNestingWithDeviceStrings(self):
  """Nested device strings combine component by component."""
  graph = ops.Graph()

  def add_float_op():
    graph.create_op("FloatOutput", [], [dtypes.float32])

  # One op is created at every nesting level, right after the scope opens.
  with graph.device("/device:GPU:0"):
    add_float_op()
    with graph.device("/job:worker"):
      add_float_op()
      with graph.device("/device:CPU:0"):
        add_float_op()
        with graph.device("/job:ps"):
          add_float_op()
          with graph.device(""):
            add_float_op()

  expected = """
    node { name: "FloatOutput" op: "FloatOutput"
           device: "/device:GPU:0" }
    node { name: "FloatOutput_1" op: "FloatOutput"
           device: "/job:worker/device:GPU:0" }
    node { name: "FloatOutput_2" op: "FloatOutput"
           device: "/job:worker/device:CPU:0" }
    node { name: "FloatOutput_3" op: "FloatOutput"
           device: "/job:ps/device:CPU:0" }
    node { name: "FloatOutput_4" op: "FloatOutput"
           device: "/job:ps/device:CPU:0" }
  """
  self.assertProtoEqualsVersion(expected, graph.as_graph_def())
|
|
|
|
def testNestingWithDeviceStringWildcard(self):
  """Checks how `*` device indices interact with concrete indices.

  An inner `GPU:*` keeps the concrete index of the enclosing `GPU:7` scope,
  while an inner `CPU:5` replaces the `*` of the enclosing `CPU:*` scope.
  """
  graph = ops.Graph()

  def add_float_op():
    graph.create_op("FloatOutput", [], [dtypes.float32])

  with graph.device("/device:GPU:7"):
    add_float_op()
    with graph.device("/device:GPU:*"):
      add_float_op()

  with graph.device("/device:CPU:*"):
    add_float_op()
    with graph.device("/device:CPU:5"):
      add_float_op()

  expected = """
    node { name: "FloatOutput" op: "FloatOutput"
           device: "/device:GPU:7" }
    node { name: "FloatOutput_1" op: "FloatOutput"
           device: "/device:GPU:7" }
    node { name: "FloatOutput_2" op: "FloatOutput"
           device: "/device:CPU:*" }
    node { name: "FloatOutput_3" op: "FloatOutput"
           device: "/device:CPU:5" }
  """
  self.assertProtoEqualsVersion(expected, graph.as_graph_def())
|
|
|
|
def testNestingErrorGraph(self):
  """Exiting an outer graph device scope before the inner one raises."""
  graph = ops.Graph()
  outer_scope = graph.device("/device:GPU:8")
  outer_scope.__enter__()
  # The outer scope is closed by hand while the GPU:9 scope is still open.
  with graph.device("/device:GPU:9"), self.assertRaises(RuntimeError):
    outer_scope.__exit__(None, None, None)
|
|
|
|
def testNestingErrorEager(self):
  """Exiting an outer eager device scope before the inner one raises."""
  with context.eager_mode():
    outer_scope = ops.device("/device:CPU:0")
    outer_scope.__enter__()
    # The outer scope is closed by hand while the `None` scope is still open.
    with ops.device(None), self.assertRaises(RuntimeError):
      outer_scope.__exit__(None, None, None)
|
|
|
|
def testNoneClearsDefault(self):
  """`device(None)` drops the enclosing device string while it is open."""
  graph = ops.Graph()

  def add_float_op():
    graph.create_op("FloatOutput", [], [dtypes.float32])

  with graph.device("/job:worker/replica:2/device:CPU:1"):
    add_float_op()
    with graph.device(None):
      # Expected to carry no device at all.
      add_float_op()
    add_float_op()

  expected = """
    node { name: "FloatOutput" op: "FloatOutput"
           device: "/job:worker/replica:2/device:CPU:1" }
    node { name: "FloatOutput_1" op: "FloatOutput" }
    node { name: "FloatOutput_2" op: "FloatOutput"
           device: "/job:worker/replica:2/device:CPU:1" }
  """
  self.assertProtoEqualsVersion(expected, graph.as_graph_def())
|
|
|
|
def testNoneIgnoresOuterDeviceFunction(self):
  """`device(None)` also suspends an enclosing device *function*."""
  graph = ops.Graph()

  def add_float_op():
    graph.create_op("FloatOutput", [], [dtypes.float32])

  with graph.device(lambda op: "/job:worker/replica:2/device:CPU:1"):
    add_float_op()
    with graph.device(None):
      # Expected to carry no device at all.
      add_float_op()
    add_float_op()

  expected = """
    node { name: "FloatOutput" op: "FloatOutput"
           device: "/job:worker/replica:2/device:CPU:1" }
    node { name: "FloatOutput_1" op: "FloatOutput" }
    node { name: "FloatOutput_2" op: "FloatOutput"
           device: "/job:worker/replica:2/device:CPU:1" }
  """
  self.assertProtoEqualsVersion(expected, graph.as_graph_def())
|
|
|
|
def _overwritingDeviceFunction(self, unused_op):
|
|
# This device function unconditionally overwrites the device of ops.
|
|
#
|
|
# NOTE(mrry): Writing device functions like this is not
|
|
# recommended. Instead, in most cases you should use
|
|
# `pydev.merge_device("/job:ps")` or simply `"/job:ps"` as the
|
|
# argument to `tf.device()` and the device component will be merged in.
|
|
return "/job:overwrite"
|
|
|
|
def testOverwritingBehavior(self):
  """An overwriting device function wins until `device(None)` disables it."""
  graph = ops.Graph()
  merged_ps = pydev.merge_device

  def add_float_op():
    graph.create_op("FloatOutput", [], [dtypes.float32])

  with graph.device(self._overwritingDeviceFunction):
    add_float_op()
    with graph.device("/job:ps"):  # Will be overwritten.
      add_float_op()
    with graph.device(merged_ps("/job:ps")):  # Will be overwritten.
      add_float_op()
    with graph.device(None):  # Disables overwriting device function
      with graph.device("/job:ps"):
        add_float_op()
    with graph.device(None):  # Disables overwriting device function
      with graph.device(merged_ps("/job:ps")):
        add_float_op()

  expected = """
    node { name: "FloatOutput" op: "FloatOutput"
           device: "/job:overwrite" }
    node { name: "FloatOutput_1" op: "FloatOutput"
           device: "/job:overwrite" }
    node { name: "FloatOutput_2" op: "FloatOutput"
           device: "/job:overwrite" }
    node { name: "FloatOutput_3" op: "FloatOutput"
           device: "/job:ps" }
    node { name: "FloatOutput_4" op: "FloatOutput"
           device: "/job:ps" }
  """
  self.assertProtoEqualsVersion(expected, graph.as_graph_def())
|
|
|
|
|
|
class MultithreadedGraphStateTest(test_util.TensorFlowTestCase):
  """Checks that per-thread graph stacks do not leak between threads.

  Every test starts three threads. Each thread enters a scope on the shared
  graph (device, colocation, control dependency or name scope), pauses, and
  then creates an op; the test asserts that each op reflects only the scope
  entered by its own thread.
  """

  class TestThread(threading.Thread):
    """Base thread that mutates a graph stack in two synchronised steps."""

    def __init__(self, graph, replica_id):
      """Stores the graph to mutate and the id used to label this thread.

      Args:
        graph: The graph whose stack state `run` is meant to mutate.
        replica_id: Integer used by subclasses to build device/op names.
      """
      super(MultithreadedGraphStateTest.TestThread, self).__init__()
      self._graph = graph
      self._replica_id = replica_id
      # This thread sets this event when it mutated the graph.  The caller can
      # wait for that.
      self.has_mutated_graph = threading.Event()
      # This thread waits for when it should continue.  The caller can set this
      # event.
      self.should_continue = threading.Event()

    def run(self):
      """Must be overridden; see the protocol described below."""
      # Mutate a graph's stack, then set `has_mutated_graph`, then wait for
      # `should_continue`, then add an op to the graph affected by the graph's
      # stack.
      raise NotImplementedError("must be implemented in descendants")

  def _runInLockstep(self, threads):
    """Starts `threads` one by one, then releases and joins them in order.

    Each thread is started and waited on until it signals
    `has_mutated_graph` before the next one is started. Afterwards each
    thread is released through `should_continue` and joined before the next
    one is released, so ops are created in list order.

    Args:
      threads: List of `TestThread` instances that have not been started.
    """
    for thread in threads:
      thread.start()
      thread.has_mutated_graph.wait()
      thread.has_mutated_graph.clear()
    for thread in threads:
      thread.should_continue.set()
      thread.join()

  def testDeviceFunctionStack(self):
    """Each thread's op gets the device scope entered by that thread."""

    class DeviceSettingThread(self.TestThread):

      def run(self):
        # Use the graph handed to the constructor rather than a closure.
        graph = self._graph
        with graph.device("/job:worker/replica:{}".format(self._replica_id)):
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          graph.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="FloatOutput_{}".format(self._replica_id))

    g = ops.Graph()
    # If `switch_to_thread_local` isn't called, then device placement of the
    # ops below is not deterministic.
    g.switch_to_thread_local()
    self._runInLockstep([DeviceSettingThread(g, i) for i in range(3)])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "FloatOutput_0" op: "FloatOutput"
             device: "/job:worker/replica:0" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:1" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2" }
    """, gd)

  def testColocateWith(self):
    """Each thread's op is colocated only with that thread's target op."""

    class ColocatingThread(self.TestThread):

      def __init__(self, graph, replica_id, op_to_colocate_with):
        super(ColocatingThread, self).__init__(graph, replica_id)
        self._op_to_colocate_with = op_to_colocate_with

      def run(self):
        # Use the graph handed to the constructor rather than a closure.
        graph = self._graph
        with graph.colocate_with(self._op_to_colocate_with):
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          graph.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="FloatOutput_{}".format(self._replica_id))

    g = ops.Graph()
    ops_to_colocate_with = []
    for i in range(3):
      with g.device("/job:worker/replica:{}".format(i)):
        ops_to_colocate_with.append(
            g.create_op(
                "FloatOutput", [], [dtypes.float32],
                name="ColocateWithMe_{}".format(i)))

    # If `switch_to_thread_local` isn't called, then `device` and `attr`
    # values for the ops below are not deterministic.
    g.switch_to_thread_local()
    self._runInLockstep([
        ColocatingThread(g, i, ops_to_colocate_with[i]) for i in range(3)
    ])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "ColocateWithMe_0" op: "FloatOutput"
             device: "/job:worker/replica:0" }
      node { name: "ColocateWithMe_1" op: "FloatOutput"
             device: "/job:worker/replica:1" }
      node { name: "ColocateWithMe_2" op: "FloatOutput"
             device: "/job:worker/replica:2" }
      node { name: "FloatOutput_0" op: "FloatOutput"
             device: "/job:worker/replica:0"
             attr { key: "_class"
               value { list {
                 s: "loc:@ColocateWithMe_0"}}}}
      node { name: "FloatOutput_1" op: "FloatOutput"
             device: "/job:worker/replica:1"
             attr { key: "_class"
               value { list {
                 s: "loc:@ColocateWithMe_1"}}}}
      node { name: "FloatOutput_2" op: "FloatOutput"
             device: "/job:worker/replica:2"
             attr { key: "_class"
               value { list {
                 s: "loc:@ColocateWithMe_2"}}}}
    """, gd)

  def testControlDependencies(self):
    """Each thread's op depends only on that thread's dependency op."""

    class DependingThread(self.TestThread):

      def __init__(self, graph, replica_id, dependency_op):
        super(DependingThread, self).__init__(graph, replica_id)
        self._dependency_op = dependency_op

      def run(self):
        # Use the graph handed to the constructor rather than a closure.
        graph = self._graph
        with graph.control_dependencies([self._dependency_op]):
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          graph.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="FloatOutput_{}".format(self._replica_id))

    g = ops.Graph()
    dependency_ops = []
    for i in range(3):
      dependency_ops.append(
          g.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="ColocateWithMe_{}".format(i)))

    # If `switch_to_thread_local` isn't called, then `input` values for the
    # ops below are not deterministic.
    g.switch_to_thread_local()
    self._runInLockstep(
        [DependingThread(g, i, dependency_ops[i]) for i in range(3)])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
      node { name: "ColocateWithMe_0" op: "FloatOutput" }
      node { name: "ColocateWithMe_1" op: "FloatOutput" }
      node { name: "ColocateWithMe_2" op: "FloatOutput" }
      node { name: "FloatOutput_0" op: "FloatOutput"
             input: "^ColocateWithMe_0" }
      node { name: "FloatOutput_1" op: "FloatOutput"
             input: "^ColocateWithMe_1" }
      node { name: "FloatOutput_2" op: "FloatOutput"
             input: "^ColocateWithMe_2" }
    """, gd)

  def testNameStack(self):
    """Each thread gets its own uniquified "foo" name scope."""

    class NameSettingThread(self.TestThread):

      def run(self):
        # Use the graph handed to the constructor rather than a closure.
        graph = self._graph
        with graph.name_scope("foo"):
          op1 = graph.create_op("FloatOutput", [], [dtypes.float32])
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          op2 = graph.create_op("FloatOutput", [], [dtypes.float32])
          self.result = (op1, op2)

    g = ops.Graph()
    threads = [NameSettingThread(g, i) for i in range(3)]
    self._runInLockstep(threads)

    suffixes = ["", "_1", "_2"]
    for t, s in zip(threads, suffixes):
      self.assertEqual("foo" + s + "/FloatOutput", t.result[0].name)
      self.assertEqual("foo" + s + "/FloatOutput_1", t.result[1].name)
|
|
|
|
|
|
class ObjectWithName(object):
  """Minimal object that exposes a read-only `name` property."""

  def __init__(self, name):
    """Remembers `name` so the `name` property can return it."""
    self._name = name

  @property
  def name(self):
    """The name given at construction time."""
    stored_name = self._name
    return stored_name
|
|
|
|
|
|
class CollectionTest(test_util.TensorFlowTestCase):
  """Tests for graph collections and the module-level collection helpers."""

  def test_get_collections(self):
    """`collections` and `get_all_collection_keys` list the keys in use."""
    g = ops.Graph()
    self.assertSequenceEqual(g.collections, [])
    g.add_to_collection("key", 12)
    g.add_to_collection("key", 15)
    self.assertSequenceEqual(g.collections, ["key"])
    g.add_to_collection("other", "foo")
    self.assertSequenceEqual(sorted(g.collections), ["key", "other"])
    self.assertSequenceEqual(
        sorted(g.get_all_collection_keys()), ["key", "other"])

  def test_add_to_collection(self):
    """Covers ordering, scope filtering, and snapshot vs. reference access."""
    g = ops.Graph()
    g.add_to_collection("key", 12)
    g.add_to_collection("other", "foo")
    g.add_to_collection("key", 34)

    # Of the three "blah" items, only blank1 matches the scope filters
    # queried below.
    g.add_to_collection("blah", 27)
    blank1 = ObjectWithName("prefix/foo")
    g.add_to_collection("blah", blank1)
    blank2 = ObjectWithName("junk/foo")
    g.add_to_collection("blah", blank2)

    self.assertEqual([12, 34], g.get_collection("key"))
    self.assertEqual([], g.get_collection("nothing"))
    self.assertEqual([27, blank1, blank2], g.get_collection("blah"))
    self.assertEqual([blank1], g.get_collection("blah", "prefix"))
    self.assertEqual([blank1], g.get_collection("blah", ".*x"))

    # Make sure that get_collection() returns a first-level
    # copy of the collection, while get_collection_ref() returns
    # the original list.
    other_collection_snapshot = g.get_collection("other")
    other_collection_ref = g.get_collection_ref("other")
    self.assertEqual(["foo"], other_collection_snapshot)
    self.assertEqual(["foo"], other_collection_ref)
    g.add_to_collection("other", "bar")
    self.assertEqual(["foo"], other_collection_snapshot)
    self.assertEqual(["foo", "bar"], other_collection_ref)
    self.assertEqual(["foo", "bar"], g.get_collection("other"))
    # assertIs reports both objects on failure, unlike assertTrue(a is b).
    self.assertIs(other_collection_ref, g.get_collection_ref("other"))

    # Verify that getting an empty collection ref returns a modifiable list.
    empty_coll_ref = g.get_collection_ref("empty")
    self.assertEqual([], empty_coll_ref)
    empty_coll = g.get_collection("empty")
    self.assertEqual([], empty_coll)
    self.assertIsNot(empty_coll, empty_coll_ref)
    empty_coll_ref2 = g.get_collection_ref("empty")
    self.assertIs(empty_coll_ref2, empty_coll_ref)
    # Add to the collection.
    empty_coll_ref.append("something")
    self.assertEqual(["something"], empty_coll_ref)
    self.assertEqual(["something"], empty_coll_ref2)
    self.assertEqual([], empty_coll)
    self.assertEqual(["something"], g.get_collection("empty"))
    empty_coll_ref3 = g.get_collection_ref("empty")
    self.assertIs(empty_coll_ref3, empty_coll_ref)

  def test_add_to_collections_uniquify(self):
    """A collection name repeated in the names list adds the value once."""
    g = ops.Graph()
    g.add_to_collections([1, 2, 1], "key")
    # Make sure "key" is not added twice
    self.assertEqual(["key"], g.get_collection(1))

  def test_add_to_collections_from_list(self):
    """Collection names may be given as a list."""
    g = ops.Graph()
    g.add_to_collections(["abc", "123"], "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_tuple(self):
    """Collection names may be given as a tuple."""
    g = ops.Graph()
    g.add_to_collections(("abc", "123"), "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_generator(self):
    """Collection names may be produced by a generator."""
    g = ops.Graph()

    def generator():
      yield "abc"
      yield "123"

    g.add_to_collections(generator(), "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_set(self):
    """Collection names may be given as a set."""
    g = ops.Graph()
    g.add_to_collections({"abc", "123"}, "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_string(self):
    """A plain string is treated as a single collection name."""
    g = ops.Graph()
    g.add_to_collections("abc", "key")
    self.assertEqual(["key"], g.get_collection("abc"))

  def test_default_graph(self):
    """The module-level helpers operate on the default graph."""
    with ops.Graph().as_default():
      ops.add_to_collection("key", 90)
      ops.add_to_collection("key", 100)
      # Collections are ordered.
      self.assertEqual([90, 100], ops.get_collection("key"))

  def test_defun(self):
    """Checks collection visibility between nested defun-decorated functions.

    The inner function sees the outer function's "int" and "tensor" entries;
    after it returns, the outer function's collections do not contain what
    the inner function added.
    """
    with context.eager_mode():

      @eager_function.defun
      def defun():
        ops.add_to_collection("int", 1)
        ops.add_to_collection("tensor", constant_op.constant(2))

        @eager_function.defun
        def inner_defun():
          self.assertEqual(ops.get_collection("int"), [1])
          three = ops.get_collection("tensor")[0] + ops.get_collection("int")[0]
          ops.add_to_collection("int", 2)
          self.assertEqual(ops.get_collection("int"), [1, 2])
          ops.add_to_collection("foo", "bar")
          self.assertEqual(ops.get_collection("foo"), ["bar"])
          return three

        self.assertEqual(ops.get_collection("int"), [1])
        three = inner_defun()
        self.assertEqual(ops.get_collection("int"), [1])
        self.assertEqual(ops.get_collection("foo"), [])
        return three

      three = defun()
      self.assertEqual(three.numpy(), 3)
|
|
|
|
|
|
# Declare the "FloatOutput" op type as not differentiable at import time.
ops.NotDifferentiable("FloatOutput")
|
|
|
|
|
|
@ops.RegisterGradient("CopyOp")
def _CopyGrad(op, x_grad):  # pylint: disable=invalid-name
  """Gradient registered for "CopyOp": returns the incoming gradient as is."""
  del op  # Unused; the gradient does not depend on the forward op.
  return x_grad
|
|
|
|
|
|
@ops.RegisterGradient("copy_override")
def _CopyOverrideGrad(op, x_grad):  # pylint: disable=invalid-name
  """Gradient registered as "copy_override": returns the gradient as is."""
  del op  # Unused; the gradient does not depend on the forward op.
  return x_grad
|
|
|
|
|
|
class RegistrationTest(test_util.TensorFlowTestCase):
  """Tests gradient function lookup, with and without override maps."""

  @test_util.run_deprecated_v1
  def testRegisterGradients(self):
    """The gradient registered for "CopyOp" is found for a copy op."""
    copied = test_ops.copy_op(test_ops.float_output())
    self.assertEqual(_CopyGrad, ops.get_gradient_function(copied.op))

  def testOverrideGradients(self):
    """`gradient_override_map` redirects the lookup to "copy_override"."""
    graph = ops.Graph()
    with graph.as_default():
      source = test_ops.float_output()
      with graph.gradient_override_map({"CopyOp": "copy_override"}):
        copied = test_ops.copy_op(source)
      grad_fn = ops.get_gradient_function(copied.op)
      self.assertEqual(_CopyOverrideGrad, grad_fn)

  def testNonExistentOverride(self):
    """Looking up an override name that was never registered raises."""
    graph = ops.Graph()
    with graph.as_default():
      source = test_ops.float_output()
      with graph.gradient_override_map({"CopyOp": "unknown_override"}):
        copied = test_ops.copy_op(source)
      with self.assertRaisesRegex(LookupError, "unknown_override"):
        ops.get_gradient_function(copied.op)
|
|
|
|
|
|
class ComparisonTest(test_util.TensorFlowTestCase):
  """Tests that tensors can be used with Python membership operators."""

  def testMembershipAllowed(self):
    """`in` / `not in` work on lists of tensors.

    The dedicated unittest assertions perform the same `isinstance` and
    membership checks as the boolean expressions they replace, but report
    the operands when they fail.
    """
    g = ops.Graph()
    t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
    t2 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop2")
    self.assertIsInstance(t1, ops.Tensor)
    self.assertIsInstance(t2, ops.Tensor)
    self.assertIn(t1, [t1])
    self.assertNotIn(t1, [t2])
|
|
|
|
|
|
class ControlDependenciesTest(test_util.TensorFlowTestCase):
  """Tests the control inputs that `control_dependencies` attaches to ops."""

  @test_util.run_deprecated_v1
  def testBasic(self):
    """Ops in the scope depend on `a` unless a data input already does."""
    graph = ops.Graph()
    with graph.as_default():
      # Creating unregistered ops with _apply_op() doesn't work with the C API
      # TODO(skyewm): address this more consistently. Possible solutions are
      # to use registered ops in all tests, create a way to register ops in
      # Python tests, or conditionally disable the op registration check in
      # the C API.
      a = constant_op.constant(1.0)
      b = constant_op.constant(1.0)
      with graph.control_dependencies([a]):
        c = constant_op.constant(1.0)
        d = array_ops.identity(b)
        e = array_ops.identity(c)

    expected_dep = [a.op]
    self.assertEqual(c.op.control_inputs, expected_dep)
    self.assertEqual(d.op.control_inputs, expected_dep)
    # e should be dominated by c.
    self.assertEqual(e.op.control_inputs, [])

  @test_util.run_in_graph_and_eager_modes
  def testEager(self):
    """A callable dependency is invoked exactly once in either mode."""

    def future():
      future.calls += 1
      return constant_op.constant(2.0)

    future.calls = 0

    if context.executing_eagerly():
      a = constant_op.constant(1.0)
      # The callable itself, not its result, is handed to the scope.
      b = future
      with ops.control_dependencies([a, b]):
        c = constant_op.constant(3.0)
    else:
      graph = ops.Graph()
      with graph.as_default():
        a = constant_op.constant(1.0)
        b = future()
        with graph.control_dependencies([a, b]):
          c = constant_op.constant(3.0)
      self.assertEqual(c.op.control_inputs, [a.op, b.op])
    # Checked last in both branches.
    self.assertEqual(future.calls, 1)

  def testBasicWithConversion(self):
    """Objects exposing `_as_graph_element` are accepted as dependencies."""
    graph = ops.Graph()
    a = _apply_op(graph, "FloatOutput", [], [dtypes.float32])

    class ConvertibleObj(object):

      def _as_graph_element(self):
        return a

    convertible = ConvertibleObj()
    with graph.control_dependencies([convertible]):
      c = _apply_op(graph, "FloatOutput", [], [dtypes.float32])

    self.assertEqual(c.op.control_inputs, [a.op])

  def testNested(self):
    """Four nested single-op scopes equal one scope listing all four ops."""
    graph = ops.Graph()

    def float_out():
      return _apply_op(graph, "FloatOutput", [], [dtypes.float32])

    a_1 = float_out()
    a_2 = float_out()
    a_3 = float_out()
    a_4 = float_out()

    with graph.control_dependencies([a_1, a_2, a_3, a_4]):
      flat = float_out()

    with graph.control_dependencies([a_1]):
      with graph.control_dependencies([a_2]):
        with graph.control_dependencies([a_3]):
          with graph.control_dependencies([a_4]):
            nested = float_out()

    self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op],
                          flat.op.control_inputs)
    self.assertItemsEqual(flat.op.control_inputs, nested.op.control_inputs)

  def testClear(self):
    """`control_dependencies(None)` clears the dependencies of outer scopes."""
    graph = ops.Graph()

    def float_out():
      return _apply_op(graph, "FloatOutput", [], [dtypes.float32])

    a_1 = float_out()
    a_2 = float_out()
    a_3 = float_out()
    a_4 = float_out()

    with graph.control_dependencies([a_1]):
      with graph.control_dependencies([a_2]):
        with graph.control_dependencies(None):
          with graph.control_dependencies([a_3]):
            with graph.control_dependencies([a_4]):
              # deps [a_3, a_4]
              b_3_4 = float_out()
            # deps = [a_3]
            b_3 = float_out()
          # deps back to None
          b_none = float_out()
        # deps back to [a_1, a_2]
        b_1_2 = float_out()
      # deps back to [a_1]
      b_1 = float_out()
      with graph.control_dependencies(None):
        # deps are None again
        b_none2 = float_out()

    self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
    self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
    self.assertItemsEqual([], b_none.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([], b_none2.op.control_inputs)

  def testComplex(self):
    """Checks which control inputs remain for ops that also take data inputs."""
    graph = ops.Graph()

    def float_out():
      return _apply_op(graph, "FloatOutput", [], [dtypes.float32])

    def two_in(x, y):
      return _apply_op(graph, "TwoFloatInputsFloatOutput", [x, y],
                       [dtypes.float32])

    # Usage pattern:
    # * Nodes a_i are constants defined at the outermost scope, and are used
    #   as control inputs for the ith nested scope.
    # * Nodes b_i are defined as Mul(a_3, a_4) at each scope.
    # * Nodes c_i are defined as Mul(a_1, b_1) at each scope.
    # * Nodes d_i are defined as Mul(b_i, c_i) at each scope.
    # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1.

    a_1 = float_out()
    a_2 = float_out()
    a_3 = float_out()
    a_4 = float_out()

    with graph.control_dependencies([a_1]):
      b_1 = two_in(a_3, a_4)
      c_1 = two_in(a_1, b_1)
      d_1 = two_in(b_1, c_1)
      e_1 = float_out()
      with graph.control_dependencies([a_2]):
        b_2 = two_in(a_3, a_4)
        c_2 = two_in(a_1, b_1)
        d_2 = two_in(b_2, c_2)
        e_2 = two_in(e_1, e_1)
        with graph.control_dependencies([a_3]):
          b_3 = two_in(a_3, a_4)
          c_3 = two_in(a_1, b_1)
          d_3 = two_in(b_3, c_3)
          e_3 = two_in(e_2, e_2)
          with graph.control_dependencies([a_4]):
            b_4 = two_in(a_3, a_4)
            c_4 = two_in(a_1, b_1)
            d_4 = two_in(b_4, c_4)
            e_4 = two_in(e_3, e_3)

    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs)

    self.assertItemsEqual([], c_1.op.control_inputs)
    self.assertItemsEqual([a_2.op], c_2.op.control_inputs)
    self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs)
    self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs)

    self.assertItemsEqual([], d_1.op.control_inputs)
    self.assertItemsEqual([], d_2.op.control_inputs)
    self.assertItemsEqual([], d_3.op.control_inputs)
    self.assertItemsEqual([], d_4.op.control_inputs)

    self.assertItemsEqual([a_1.op], e_1.op.control_inputs)
    self.assertItemsEqual([a_2.op], e_2.op.control_inputs)
    self.assertItemsEqual([a_3.op], e_3.op.control_inputs)
    self.assertItemsEqual([a_4.op], e_4.op.control_inputs)

  def testRepeatedDependency(self):
    """Two outputs of the same op yield a single control input on that op."""
    graph = ops.Graph()
    producer = graph.create_op("TwoFloatOutputs", [],
                               [dtypes.float32, dtypes.float32])
    out_0, out_1 = producer.outputs
    with graph.control_dependencies([out_0]):
      b = _apply_op(graph, "FloatOutput", [], [dtypes.float32])
      with graph.control_dependencies([out_1]):
        c = _apply_op(graph, "FloatOutput", [], [dtypes.float32])

    self.assertEqual(b.op.control_inputs, [producer])
    self.assertEqual(c.op.control_inputs, [producer])

  def testNoControlDependencyWithDataDependency(self):
    """No control input is added for an op that is already a data input."""
    graph = ops.Graph()
    a = _apply_op(graph, "FloatOutput", [], [dtypes.float32])
    with graph.control_dependencies([a]):
      b = _apply_op(graph, "Identity", [a], [dtypes.float32])

    self.assertEqual(b.op.control_inputs, [])
|
|
|
|
|
|
class OpScopeTest(test_util.TensorFlowTestCase):
  """Tests the scope strings and default graph produced by `name_scope`."""

  @test_util.run_in_graph_and_eager_modes
  def testNames(self):
    """Checks scope strings for nested, empty, and slash-terminated names."""
    with ops.name_scope("foo", skip_on_eager=False) as foo:
      self.assertEqual("foo/", foo)
      with ops.name_scope("foo2", skip_on_eager=False) as foo2:
        self.assertEqual("foo/foo2/", foo2)
      with ops.name_scope(None, skip_on_eager=False) as empty1:
        self.assertEqual("", empty1)
        with ops.name_scope("foo3", skip_on_eager=False) as foo3:
          self.assertEqual("foo3/", foo3)
      with ops.name_scope("", skip_on_eager=False) as empty2:
        self.assertEqual("", empty2)
    with ops.name_scope("foo/", skip_on_eager=False) as outer_foo:
      self.assertEqual("foo/", outer_foo)
      with ops.name_scope("", skip_on_eager=False) as empty3:
        self.assertEqual("", empty3)
      with ops.name_scope("foo4", skip_on_eager=False) as foo4:
        self.assertEqual("foo/foo4/", foo4)
      with ops.name_scope("foo5//", skip_on_eager=False) as foo5:
        self.assertEqual("foo5//", foo5)
        with ops.name_scope("foo6", skip_on_eager=False) as foo6:
          self.assertEqual("foo5//foo6/", foo6)
      with ops.name_scope("/", skip_on_eager=False) as foo7:
        self.assertEqual("/", foo7)
      with ops.name_scope("//", skip_on_eager=False) as foo8:
        self.assertEqual("//", foo8)
      with ops.name_scope("a//b/c", skip_on_eager=False) as foo9:
        self.assertEqual("foo/a//b/c/", foo9)
    with ops.name_scope("a//b/c", skip_on_eager=False) as foo10:
      self.assertEqual("a//b/c/", foo10)

  @test_util.run_in_graph_and_eager_modes
  def testEagerDefaultScopeName(self):
    """With a `None` name, the default name is used and nests normally."""
    with ops.name_scope(None, "default", skip_on_eager=False) as outer:
      self.assertEqual(outer, "default/")
      with ops.name_scope(None, "default2", skip_on_eager=False) as inner:
        self.assertEqual(inner, "default/default2/")

  @test_util.run_in_graph_and_eager_modes
  def testNameScopeV2IsReEntrant(self):
    """The same `name_scope_v2` object can be entered repeatedly and nested."""
    foo = ops.name_scope_v2("foo")
    bar = ops.name_scope_v2("bar")
    with foo as scope_name:
      self.assertEqual("foo/", scope_name)
      with foo as scope_name:
        self.assertEqual("foo/foo/", scope_name)
      with bar as scope_name:
        self.assertEqual("foo/bar/", scope_name)
        with foo as scope_name:
          self.assertEqual("foo/bar/foo/", scope_name)
    with bar as scope_name:
      self.assertEqual("bar/", scope_name)

  @test_util.run_deprecated_v1
  def testNoScopeName(self):
    """Passing values with neither a name nor a default name raises."""
    graph = ops.Graph()
    op_a = graph.create_op("A", [], [dtypes.float32])
    op_b = graph.create_op("B", [], [dtypes.float32])
    values = [op_a, op_b]
    # Values given by keyword.
    with self.assertRaises(ValueError), ops.name_scope(None, values=values):
      pass
    # Values given positionally, with an explicit `None` default name.
    with self.assertRaises(ValueError), ops.name_scope(None, None, values):
      pass

  @test_util.run_deprecated_v1
  def testEmptyScopeName(self):
    """An empty name yields an empty scope in the graph of the values."""
    g0 = ops.Graph()
    op_a = g0.create_op("A", [], [dtypes.float32])
    op_b = g0.create_op("B", [], [dtypes.float32])
    with ops.name_scope("", values=[op_a, op_b]) as scope:
      self.assertEqual("", scope)
      self.assertEqual(g0, ops.get_default_graph())
    with ops.name_scope("", "my_default_scope", [op_a, op_b]) as scope:
      self.assertEqual("", scope)
      self.assertEqual(g0, ops.get_default_graph())

  @test_util.run_deprecated_v1
  def testDefaultScopeName(self):
    """The default name is used only when the explicit name is `None`."""
    g0 = ops.Graph()
    op_a = g0.create_op("A", [], [dtypes.float32])
    op_b = g0.create_op("B", [], [dtypes.float32])
    scope_name = "my_scope"
    default_scope_name = "my_default_scope"
    with ops.name_scope(scope_name, default_scope_name,
                        [op_a, op_b]) as scope:
      self.assertEqual(scope_name + "/", scope)
      self.assertEqual(g0, ops.get_default_graph())
    with ops.name_scope(None, default_scope_name, [op_a, op_b]) as scope:
      self.assertEqual(default_scope_name + "/", scope)
      self.assertEqual(g0, ops.get_default_graph())
    # A list in the default-name position is rejected.
    with self.assertRaises(TypeError), ops.name_scope(scope_name,
                                                      [op_a, op_b]):
      pass

  def _testGraphElements(self, graph_elements):
    """Shared checks for `name_scope` called with `graph_elements` as values.

    Asserts that the scope makes the elements' graph the default graph, and
    that adding an op from a different graph raises `ValueError`.

    Args:
      graph_elements: List of graph elements that all live in one graph.
    """
    scope_name = "my_scope"
    with ops.name_scope(scope_name, values=graph_elements) as scope:
      self.assertEqual(scope_name + "/", scope)
      self.assertEqual(graph_elements[0].graph, ops.get_default_graph())
    other_graph = ops.Graph()
    foreign_op = other_graph.create_op("A", [], [dtypes.float32])
    mixed_elements = graph_elements + [foreign_op]
    with self.assertRaises(ValueError), ops.name_scope(
        scope_name, values=mixed_elements):
      pass

  @test_util.run_deprecated_v1
  def testTensor(self):
    """Runs the shared checks with two ops created on one graph."""
    g0 = ops.Graph()
    op_a = g0.create_op("A", [], [dtypes.float32])
    op_b = g0.create_op("B", [], [dtypes.float32])
    self._testGraphElements([op_a, op_b])

  @test_util.run_deprecated_v1
  def testSparseTensor(self):
    """Runs the shared checks with a `SparseTensor` among the values."""
    g0 = ops.Graph()
    op_a = g0.create_op("A", [], [dtypes.float32])
    op_b = g0.create_op("B", [], [dtypes.float32])
    # Components are created in the same order as the constructor arguments.
    first_int64 = _apply_op(g0, "Int64Output", [], [dtypes.int64])
    float_part = _apply_op(g0, "FloatOutput", [], [dtypes.float32])
    second_int64 = _apply_op(g0, "Int64Output", [], [dtypes.int64])
    sparse = sparse_tensor.SparseTensor(first_int64, float_part, second_int64)
    self._testGraphElements([op_a, sparse, b] if False else
                            [op_a, sparse, op_b])

  @test_util.run_deprecated_v1
  def testVariable(self):
    """Runs the shared checks with a `Variable` among the values."""
    g0 = ops.Graph()
    with g0.as_default():
      variable = variables.Variable([1.0])
    op_a = g0.create_op("A", [], [dtypes.float32])
    op_b = g0.create_op("B", [], [dtypes.float32])
    self._testGraphElements([op_a, variable, op_b])
|
|
|
|
|
|
class InitScopeTest(test_util.TensorFlowTestCase):
  """Tests for `ops.init_scope`."""

  def testClearsControlDependencies(self):
    """Checks control inputs of ops created inside and around `init_scope`."""
    g = ops.Graph()
    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.as_default():
      with g.control_dependencies([a_1]):
        with g.control_dependencies([a_2]):
          with ops.init_scope():
            with g.control_dependencies([a_3]):
              with g.control_dependencies([a_4]):
                # deps [a_3, a_4]
                b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
              # deps = [a_3]
              b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
            # deps back to None
            b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
          # deps back to [a_1, a_2]
          b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
        # deps back to [a_1]
        b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
        with ops.init_scope():
          # deps are None again
          b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
    self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
    self.assertItemsEqual([], b_none.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([], b_none2.op.control_inputs)

  def testLiftsOpsFromFunctions(self):
    """Checks that an op created in `init_scope` lands in the outermost graph.

    `g1` and `g2` are flagged as building functions; `g0` is not.
    """
    g0 = ops.Graph()
    g1 = ops.Graph()
    g1._building_function = True  # pylint: disable=protected-access
    g2 = ops.Graph()
    g2._building_function = True  # pylint: disable=protected-access

    with g0.as_default():
      with g1.as_default():
        with g2.as_default():
          with ops.init_scope():
            _ = constant_op.constant(1.0)

    self.assertEqual(len(g2.get_operations()), 0)
    self.assertEqual(len(g1.get_operations()), 0)
    self.assertEqual(len(g0.get_operations()), 1)

  def testPreservesDevices(self):
    """Checks the device of constants created inside and around `init_scope`."""
    g0 = ops.Graph()
    with g0.as_default(), ops.device("CPU:0"):
      g1 = ops.Graph()
      g1._building_function = True  # pylint: disable=protected-access
      with g1.as_default():
        with ops.device("GPU:0"):
          with ops.init_scope():
            # init_scope should preserve device set under `g1`.
            on_gpu = constant_op.constant(1.0)
            self.assertEqual(on_gpu.device, "/device:GPU:0")
          still_on_gpu = constant_op.constant(1.0)
          self.assertEqual(still_on_gpu.device, "/device:GPU:0")
        blank = constant_op.constant(1.0)
        self.assertEqual(blank.device, "")
        with ops.init_scope():
          now_on_cpu = constant_op.constant(1.0)
          self.assertEqual(now_on_cpu.device, "/device:CPU:0")
      on_cpu = constant_op.constant(1.0)
      self.assertEqual(on_cpu.device, "/device:CPU:0")

  def testComposes(self):
    """Checks `init_scope` when entered repeatedly under nested graphs."""
    g0 = ops.Graph()
    g1 = ops.Graph()
    g1._building_function = True  # pylint: disable=protected-access
    g2 = ops.Graph()
    g2._building_function = True  # pylint: disable=protected-access
    g3 = ops.Graph()
    g3._building_function = False  # pylint: disable=protected-access

    with g0.as_default():
      with g1.as_default():
        with ops.init_scope():
          # This op should be lifted into g0.
          _ = constant_op.constant(1.0)
          self.assertIs(g0, ops.get_default_graph())
          self.assertEqual(len(g2.get_operations()), 0)
          self.assertEqual(len(g1.get_operations()), 0)
          self.assertEqual(len(g0.get_operations()), 1)
        with g2.as_default():
          with ops.init_scope():
            # This op should be lifted into g0.
            _ = constant_op.constant(1.0)
            self.assertIs(g0, ops.get_default_graph())
          with g3.as_default():
            with ops.init_scope():
              # This op should be lifted into g3, because g3 is not building a
              # function.
              _ = constant_op.constant(1.0)
              self.assertIs(g3, ops.get_default_graph())

    self.assertEqual(len(g3.get_operations()), 1)
    self.assertEqual(len(g2.get_operations()), 0)
    self.assertEqual(len(g1.get_operations()), 0)
    self.assertEqual(len(g0.get_operations()), 2)

  def testEscapesToEagerContext(self):
    """Checks that `init_scope` runs eagerly under a function-building graph."""
    g = ops.Graph()
    g._building_function = True  # pylint: disable=protected-access
    with context.eager_mode():
      with context.graph_mode():
        with g.as_default():
          with ops.init_scope():
            # Because g is building a function, init_scope should
            # escape out to the eager context.
            self.assertTrue(context.executing_eagerly())
          # g should be reinstated as the default graph, and the
          # graph context should be re-entered.
          self.assertIs(g, ops.get_default_graph())
          self.assertFalse(context.executing_eagerly())

  def testStaysInEagerWhenOnlyEagerContextActive(self):
    """Checks that execution stays eager inside and after `init_scope`."""
    with context.eager_mode():
      with ops.init_scope():
        # `context.eager_mode()` only builds a context manager, which is always
        # truthy; `executing_eagerly()` is what reports the current mode.
        self.assertTrue(context.executing_eagerly())
      self.assertTrue(context.executing_eagerly())

  def testEscapesDefunWhenInEagerMode(self):
    """Checks variable creation in `init_scope` inside a defun, eager mode."""

    def function_with_variables():
      with ops.init_scope():
        self.v = resource_variable_ops.ResourceVariable(3)
      return self.v.assign_add(1)

    with context.eager_mode():
      # Each invocation of function_with_variables recreates a variable.
      self.assertEqual(4, int(function_with_variables()))
      self.assertEqual(4, int(function_with_variables()))

      compiled = eager_function.defun(function_with_variables)
      # The init_scope in function_with_variables lifts the variable out
      # of the graph function constructed by defun; hence,
      # compiled now appears to be stateful.
      self.assertEqual(4, int(compiled()))
      self.assertEqual(5, int(compiled()))

  def testEscapesDefunWhenInGraphMode(self):
    """Checks variable creation in `init_scope` inside a defun, graph mode."""

    def function_with_variables(name):
      with ops.init_scope():
        _ = variable_scope.get_variable(name, shape=(1,))

    g = ops.Graph()
    with g.as_default():
      with self.cached_session():
        # First ensure that graphs that are not building functions are
        # not escaped.
        function_with_variables("foo")
        with self.assertRaisesRegex(ValueError,
                                    r"Variable foo already exists.*"):
          # This will fail because reuse is not set to True.
          function_with_variables("foo")

        compiled = eager_function.defun(function_with_variables)
        compiled("bar")
        self.assertEqual(
            len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2)

        # The second call to `compiled` should not create variables: the
        # init_scope has lifted the variable creation code out of the defun.
        compiled("bar")
        self.assertEqual(
            len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2)

  def testEscapesNestedDefun(self):
    """Checks variable creation in `init_scope` across nested defuns."""

    def inner_function():
      with ops.init_scope():
        self.v = resource_variable_ops.ResourceVariable(1)
      return self.v.assign_add(2)

    def outer_function(inner=None):
      with ops.init_scope():
        self.v0 = resource_variable_ops.ResourceVariable(0)
      return self.v0.assign_add(1) + inner()

    with context.eager_mode():
      # Each invocation of outer_function recreates variables.
      self.assertEqual(4, int(outer_function(inner=inner_function)))
      self.assertEqual(4, int(outer_function(inner=inner_function)))

      compiled_inner = eager_function.defun(inner_function)
      compiled_outer = eager_function.defun(outer_function)
      # The init_scope lifts variables out of the graph functions
      # constructed by defun; hence, compiled_outer should now appear to be
      # stateful.
      self.assertEqual(4, int(compiled_outer(inner=compiled_inner)))
      self.assertEqual(7, int(compiled_outer(inner=compiled_inner)))

  @test_util.run_v1_only("b/120545219")
  def testFallsBackToGlobalGraphWhenAllGraphsAreBuildingFunctions(self):
    """Checks that `init_scope` uses the global graph as a last resort."""
    with context.graph_mode():
      ops.reset_default_graph()
      # This doesn't push anything onto the graph stack, but it does
      # set the stack's global graph.
      global_graph = ops.get_default_graph()
      fn_graph = ops.Graph()

      # pylint: disable=protected-access
      fn_graph._building_function = True
      self.assertEqual(len(ops._default_graph_stack.stack), 0)
      with fn_graph.as_default():
        self.assertEqual(len(ops._default_graph_stack.stack), 1)
        with ops.init_scope():
          self.assertGreater(len(ops._default_graph_stack.stack), 1)
          dummy = constant_op.constant(1.0)
        self.assertEqual(len(ops._default_graph_stack.stack), 1)
      # Note that the global graph is _not_ on the graph stack.
      self.assertEqual(len(ops._default_graph_stack.stack), 0)
      # Ensure that `dummy` was added to the global graph.
      self.assertEqual(global_graph, dummy.graph)
      # pylint: enable=protected-access

  def testInstallsDefaultGraphWhenGraphStackIsEmptyInGraphMode(self):
    """Checks the graph stack is non-empty only while inside `init_scope`."""
    with context.graph_mode():
      # pylint: disable=protected-access
      self.assertEqual(len(ops._default_graph_stack.stack), 0)
      with ops.init_scope():
        self.assertGreater(len(ops._default_graph_stack.stack), 0)
      self.assertEqual(len(ops._default_graph_stack.stack), 0)
      # pylint: enable=protected-access

  def testPreservesNameScopeInGraphConstruction(self):
    """Checks the name scope seen inside `init_scope` while building a graph."""
    with ops.Graph().as_default():
      function_graph = ops.Graph()
      with function_graph.as_default():
        with ops.name_scope("inner", skip_on_eager=False), ops.init_scope():
          self.assertEqual(ops.get_name_scope(), "inner")
      self.assertEqual(ops.get_name_scope(), "")

  def testEnteringGraphFromEagerIsSticky(self):
    """Checks `init_scope` stays in a graph that was entered from eager mode."""
    with context.eager_mode():
      g = ops.Graph()
      with g.as_default():
        with ops.init_scope():
          self.assertFalse(context.executing_eagerly())
          self.assertEqual(g, ops.get_default_graph())

  def testMixGraphEager(self):
    """Checks the errors raised when eager and graph tensors are mixed."""
    with context.eager_mode():
      c = constant_op.constant(1.0)
      with ops.Graph().as_default():
        with self.assertRaisesRegex(RuntimeError,
                                    "Attempting to capture an EagerTensor"):
          math_ops.add(c, c)
        c2 = constant_op.constant(2.0)
      with self.assertRaisesRegex(TypeError, "Graph tensors"):
        math_ops.add(c2, c2)

  def testPreservesNameScopeInEagerExecution(self):
    """Checks the name scope seen inside `init_scope`, eagerly and in a defun."""
    with context.eager_mode():

      def foo():
        with ops.name_scope("inner", skip_on_eager=False), ops.init_scope():
          if context.executing_eagerly():
            # A trailing slash is always appended when eager execution is
            # enabled.
            self.assertEqual(context.context().scope_name, "inner/")
          else:
            self.assertEqual(ops.get_name_scope(), "inner")

      foo()
      self.assertEqual(ops.get_name_scope(), "")
      foo_compiled = eager_function.defun(foo)
      foo_compiled()
      self.assertEqual(ops.get_name_scope(), "")

  def testExecutingEagerlyOutsideFunctions(self):
    """Checks `ops.executing_eagerly_outside_functions` in each mode."""

    @def_function.function
    def f():
      return ops.executing_eagerly_outside_functions()

    with context.graph_mode():
      self.assertFalse(ops.executing_eagerly_outside_functions())
      with session.Session():
        # Need self.evaluate for these as the return type of functions is
        # tensors.
        self.assertFalse(self.evaluate(f()))

    with context.eager_mode():
      self.assertTrue(ops.executing_eagerly_outside_functions())
      self.assertTrue(f())

      with ops.Graph().as_default():
        self.assertFalse(ops.executing_eagerly_outside_functions())
        with session.Session():
          self.assertFalse(self.evaluate(f()))
|
|
|
|
|
|
class GraphTest(test_util.TensorFlowTestCase):
  """Tests for default-graph handling and assorted `ops.Graph` methods."""

  def setUp(self):
    """Resets the default graph before each test."""
    ops.reset_default_graph()

  def _AssertDefault(self, expected):
    """Asserts that `expected` is the very object returned as default graph."""
    current = ops.get_default_graph()
    self.assertIs(expected, current)

  def testResetDefaultGraphNesting(self):
    """Resetting the default graph inside `as_default()` raises."""
    graph = ops.Graph()
    with self.assertRaises(AssertionError):
      with graph.as_default():
        ops.reset_default_graph()

  def testGraphContextManagerCancelsEager(self):
    """Entering a graph from eager mode turns eager execution off."""
    with context.eager_mode():
      with ops.Graph().as_default():
        eager = context.executing_eagerly()
        self.assertFalse(eager)

  def testGraphContextManager(self):
    """`as_default()` yields the graph it was called on."""
    graph = ops.Graph()
    with graph.as_default() as entered:
      self.assertIs(graph, entered)

  def testDefaultGraph(self):
    """Tracks the default graph through nested `as_default()` scopes."""
    orig = ops.get_default_graph()
    self.assertFalse(ops.has_default_graph())
    self._AssertDefault(orig)

    # Neither constructing a graph nor building its context manager changes
    # the default graph.
    g0 = ops.Graph()
    self.assertFalse(ops.has_default_graph())
    self._AssertDefault(orig)
    context_manager_0 = g0.as_default()
    self.assertFalse(ops.has_default_graph())
    self._AssertDefault(orig)

    with context_manager_0 as g0:
      self._AssertDefault(g0)
      with ops.Graph().as_default() as g1:
        self.assertTrue(ops.has_default_graph())
        self._AssertDefault(g1)
      self._AssertDefault(g0)

    self._AssertDefault(orig)
    self.assertFalse(ops.has_default_graph())

  def testPreventFeeding(self):
    """`prevent_feeding` flips `is_feedable` for the tensor."""
    graph = ops.Graph()
    tensor = constant_op.constant(2.0)
    self.assertTrue(graph.is_feedable(tensor))
    graph.prevent_feeding(tensor)
    self.assertFalse(graph.is_feedable(tensor))

  @test_util.run_deprecated_v1
  def testPreventFetching(self):
    """`prevent_fetching` on the op flips `is_fetchable` for its tensor."""
    graph = ops.Graph()
    tensor = constant_op.constant(2.0)
    self.assertTrue(graph.is_fetchable(tensor))
    graph.prevent_fetching(tensor.op)
    self.assertFalse(graph.is_fetchable(tensor))

  def testAsGraphElementConversions(self):
    """`as_graph_element` honors `_as_graph_element` and rejects others."""

    class ConvertibleObj(object):

      def _as_graph_element(self):
        return "FloatOutput:0"

    class NonConvertibleObj(object):

      pass

    graph = ops.Graph()
    float_output = _apply_op(graph, "FloatOutput", [], [dtypes.float32])
    converted = graph.as_graph_element(ConvertibleObj())
    self.assertEqual(float_output, converted)
    with self.assertRaises(TypeError):
      graph.as_graph_element(NonConvertibleObj())

  # Regression test against creating custom __del__ functions in classes
  # involved in cyclic references, e.g. Graph and Operation. (Python won't gc
  # cycles that require calling a __del__ method, because the __del__ method can
  # theoretically increase the object's refcount to "save" it from gc, and any
  # already-deleted objects in the cycle would have to be restored.)
  def testGarbageCollected(self):
    """A graph is collected once every reference to it has been dropped."""
    # Build a disposable graph plus a weak reference that reveals whether it
    # has been collected.
    g = ops.Graph()
    g_ref = weakref.ref(g)
    # Populate the graph with a few ops.
    with g.as_default():
      a = constant_op.constant(2.0)
      b = constant_op.constant(3.0)
      c = math_ops.add(a, b)
    # Run the graph in a disposable session.
    with session.Session(graph=g) as sess:
      self.evaluate(c)
    # Drop every strong reference, then force a collection.
    del g, a, b, c, sess
    gc.collect()
    self.assertIsNone(g_ref())

  def testRunnableAfterInvalidShape(self):
    """The graph still runs after an op failed with `ValueError`."""
    with ops.Graph().as_default():
      with self.assertRaises(ValueError):
        math_ops.add([1, 2], [1, 2, 3])
      one = constant_op.constant(1)
      with session.Session():
        self.evaluate(one)

  def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
    """Same as above, with the failure inside a `_kernel_label_map` scope."""
    graph = ops.Graph()
    with graph.as_default():
      label_map = {"KernelLabelRequired": "overload_1"}
      with graph._kernel_label_map(label_map):  # pylint: disable=protected-access
        with self.assertRaises(ValueError):
          test_ops.kernel_label_required(1)
      one = constant_op.constant(1)
      with session.Session():
        self.evaluate(one)
|
|
|
|
|
|
class AttrScopeTest(test_util.TensorFlowTestCase):
  """Tests for `Graph._attr_scope`."""

  def _get_test_attrs(self):
    """Creates a no-op and returns its `("_A", "_B")` attrs as text.

    Returns:
      A 2-tuple; an entry is `None` when `get_attr` raises `ValueError` for it.
    """
    probe = control_flow_ops.no_op()

    def read_attr(attr_name):
      try:
        return compat.as_text(probe.get_attr(attr_name))
      except ValueError:
        return None

    return (read_attr("_A"), read_attr("_B"))

  @test_util.run_deprecated_v1
  def testNoLabel(self):
    """Without any attr scope neither attribute is present."""
    with self.cached_session():
      observed = self._get_test_attrs()
      self.assertAllEqual((None, None), observed)

  @test_util.run_deprecated_v1
  def testLabelMap(self):
    """Nested attr scopes set, clear, override and restore attributes."""

    def string_attr(text):
      return attr_value_pb2.AttrValue(s=compat.as_bytes(text))

    with self.cached_session() as sess:
      # pylint: disable=protected-access
      graph = sess.graph
      a1 = self._get_test_attrs()
      with graph._attr_scope({"_A": string_attr("foo")}):
        a2 = self._get_test_attrs()
        with graph._attr_scope({"_A": None, "_B": string_attr("bar")}):
          a3 = self._get_test_attrs()
          with graph._attr_scope({"_A": string_attr("baz")}):
            a4 = self._get_test_attrs()
          a5 = self._get_test_attrs()
        a6 = self._get_test_attrs()
      a7 = self._get_test_attrs()
      # pylint: enable=protected-access

      self.assertAllEqual((None, None), a1)
      self.assertAllEqual(("foo", None), a2)
      self.assertAllEqual((None, "bar"), a3)
      self.assertAllEqual(("baz", "bar"), a4)
      self.assertAllEqual((None, "bar"), a5)
      self.assertAllEqual(("foo", None), a6)
      self.assertAllEqual((None, None), a7)
|
|
|
|
|
|
class KernelLabelTest(test_util.TensorFlowTestCase):
  """Tests for `Graph._kernel_label_map`."""

  @test_util.run_deprecated_v1
  def testNoLabel(self):
    """Without a label map the op reports the default label."""
    with self.cached_session():
      reported = test_ops.kernel_label().eval()
      self.assertAllEqual(b"My label is: default", reported)

  @test_util.run_deprecated_v1
  def testLabelMap(self):
    """Nested label maps select, override, clear and restore the label."""
    with self.cached_session() as sess:
      graph = sess.graph
      default_1 = test_ops.kernel_label()
      # pylint: disable=protected-access
      with graph._kernel_label_map({"KernelLabel": "overload_1"}):
        overload_1_1 = test_ops.kernel_label()
        with graph._kernel_label_map({"KernelLabel": "overload_2"}):
          overload_2 = test_ops.kernel_label()
          with graph._kernel_label_map({"KernelLabel": ""}):
            default_2 = test_ops.kernel_label()
        overload_1_2 = test_ops.kernel_label()
      # pylint: enable=protected-access
      default_3 = test_ops.kernel_label()

      # (expected output, op) pairs, evaluated in this order.
      expectations = [
          (b"My label is: default", default_1),
          (b"My label is: default", default_2),
          (b"My label is: default", default_3),
          (b"My label is: overload_1", overload_1_1),
          (b"My label is: overload_1", overload_1_2),
          (b"My label is: overload_2", overload_2),
      ]
      for expected, label_op in expectations:
        self.assertAllEqual(expected, self.evaluate(label_op))
|
|
|
|
|
|
class AsGraphDefTest(test_util.TensorFlowTestCase):
  """Tests for graph-def versioning and `as_graph_def(add_shapes=True)`."""

  def testGraphDefVersion(self):
    """Test that the graphdef version is plumbed through to kernels."""
    with ops.Graph().as_default() as graph:
      producer_version = graph.graph_def_versions.producer
      with self.session(graph=graph):
        kernel_version = test_ops.graph_def_version().eval()
        self.assertEqual(producer_version, kernel_version)

  def testAddShapes(self):
    """Checks the `_output_shapes` attr emitted for a range of static shapes."""
    with ops.Graph().as_default() as graph:
      outputs = _apply_op(graph, "FiveFloatOutputs", [], [dtypes.float32] * 5)
      t1, t2, t3, t4, t5 = outputs
      # Unknown rank, scalar, unknown dim, fully known, partially known.
      shapes = (None, [], [None], [43, 37], [43, None])
      for tensor, shape in zip((t1, t2, t3, t4, t5), shapes):
        tensor.set_shape(shape)

      b = constant_op.constant(1.0)  # pylint: disable=unused-variable

      graph_def = graph.as_graph_def(add_shapes=True)
      self.assertProtoEqualsVersion("""
      node { name: "FiveFloatOutputs" op: "FiveFloatOutputs"
        attr {
          key: "_output_shapes"
          value {
            list {
              shape { unknown_rank: true }
              shape { }
              shape { dim { size: -1 } }
              shape { dim { size: 43 } dim { size: 37 } }
              shape { dim { size: 43 } dim { size: -1 } }
            }
          }
        }
      }
    node { name: "Const" op: "Const"
      attr {
        key: "_output_shapes"
        value {
          list {
            shape { }
          }
        }
      }
      attr {
        key: "dtype"
        value { type: DT_FLOAT }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_FLOAT
            tensor_shape { }
         float_val: 1.0  } } } }
      """, graph_def)
|
|
|
|
|
|
@ops.RegisterStatistics("a", "flops")
def _calc_a_forward_flops(unused_graph, unused_node):
  """Statistics function registered for op type "a": always reports 20 flops."""
  fixed_flops = ops.OpStats("flops", 20)
  return fixed_flops
|
|
|
|
|
|
class StatisticsTest(test_util.TensorFlowTestCase):
  """Tests for `ops.get_stats_for_node_def` and `ops.OpStats`."""

  def testRegisteredNode(self):
    """A registered (op type, statistic) pair returns the registered value."""
    graph = ops.Graph()
    node = ops._NodeDef("a", "an_a")  # pylint: disable=protected-access
    flops = ops.get_stats_for_node_def(graph, node, "flops")
    self.assertEqual(20, flops.value)
    # A statistic that was never registered for this op type has no value.
    missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat")
    self.assertIsNone(missing_stat.value)

  def testUnregisteredNode(self):
    """An op type with no registered statistics has no value."""
    graph = ops.Graph()
    node = ops._NodeDef("b", "a_b")  # pylint: disable=protected-access
    weight_params = ops.get_stats_for_node_def(graph, node, "weight_params")
    self.assertIsNone(weight_params.value)

  def testAccumulateStatistics(self):
    """`+=` onto an `OpStats` without a value adopts the added value."""
    flops_total = ops.OpStats("flops")
    self.assertIsNone(flops_total.value)
    second_flops = ops.OpStats("flops", 3)
    flops_total += second_flops
    self.assertEqual(3, flops_total.value)
|
|
|
|
|
|
class DeviceStackTest(test_util.TensorFlowTestCase):
  """Tests for the device-assignment metadata recorded on operations."""

  @test_util.run_deprecated_v1
  def testBasicDeviceAssignmentMetadata(self):
    """Checks `_device_assignments` for string and function device specs."""

    def device_func(unused_op):
      return "/cpu:*"

    const_zero = constant_op.constant([0.0], name="zero")
    with ops.device("/cpu"):
      const_one = constant_op.constant([1.0], name="one")
      with ops.device("/cpu:0"):
        const_two = constant_op.constant([2.0], name="two")
    with ops.device(device_func):
      const_three = constant_op.constant(3.0, name="three")

    # pylint: disable=protected-access
    self.assertEqual(0, len(const_zero.op._device_assignments))

    one_list = const_one.op._device_assignments
    self.assertEqual(1, len(one_list))
    self.assertEqual("/cpu", one_list[0].obj)
    self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename))

    two_list = const_two.op._device_assignments
    self.assertEqual(2, len(two_list))
    devices = [t.obj for t in two_list]
    self.assertEqual({"/cpu", "/cpu:0"}, set(devices))

    three_list = const_three.op._device_assignments
    # pylint: enable=protected-access
    self.assertEqual(1, len(three_list))
    func_description = three_list[0].obj
    expected_regex = r"device_func<.*ops_test.py, [0-9]+"
    self.assertRegex(func_description, expected_regex)

  @test_util.run_deprecated_v1
  def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self):
    """Checks `ops.device` and `Graph.device` record equivalent metadata."""

    # NOTE: the final assertion compares recorded line numbers, so the two
    # `with` statements below must stay exactly two source lines apart.
    with ops.device("/cpu"):
      const_one = constant_op.constant([1.0], name="one")
    with ops.get_default_graph().device("/cpu"):
      const_two = constant_op.constant([2.0], name="two")

    # pylint: disable=protected-access
    one_metadata = const_one.op._device_assignments[0]
    two_metadata = const_two.op._device_assignments[0]
    # pylint: enable=protected-access

    # Verify both types of device assignment return the right stack info.
    # (The arguments were previously passed to assertRegex in swapped order,
    # which used the recorded filename as the pattern.)
    self.assertEqual("ops_test.py", os.path.basename(one_metadata.filename))
    self.assertEqual(one_metadata.filename, two_metadata.filename)
    self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno)
|
|
|
|
|
|
class ColocationGroupTest(test_util.TensorFlowTestCase):
  """Tests for `colocate_with` and the colocation groups it records."""

  @test_util.run_deprecated_v1
  def testBasic(self):
    """Only the op created inside the scope joins the anchor's group."""
    anchor = constant_op.constant([2.0], name="a")
    with ops.colocate_with(anchor.op):
      inside = constant_op.constant(3.0)
    outside = constant_op.constant(4.0)
    group = [b"loc:@a"]
    self.assertEqual(group, anchor.op.colocation_groups())
    self.assertEqual(group, inside.op.colocation_groups())
    with self.assertRaises(ValueError):
      outside.op.get_attr("_class")

  @test_util.run_deprecated_v1
  def testBasicColocationMetadata(self):
    """The colocation dict is keyed by op name and records the source file."""
    anchor = constant_op.constant([2.0], name="two")
    with ops.colocate_with(anchor.op):
      colocated = constant_op.constant(3.0, name="three")
    locations = colocated.op._colocation_dict  # pylint: disable=protected-access
    self.assertIn("two", locations)
    entry = locations["two"]
    self.assertIsNone(entry.obj)
    # The file containing the colocation statement is this test file.
    self.assertEqual("ops_test.py", os.path.basename(entry.filename))

  @test_util.run_deprecated_v1
  def testColocationDeviceInteraction(self):
    """A colocated op takes the anchor's device over the enclosing scope's."""
    with ops.device("/cpu:0"):
      with ops.device("/device:GPU:0"):
        anchor = constant_op.constant([2.0], name="a")
      with ops.colocate_with(anchor.op):
        # Created under /cpu:0 but colocated with the anchor, which is on
        # '/device:GPU:0'. colocate_with overrides devices because it is a
        # stronger constraint.
        follower = constant_op.constant(3.0)
    self.assertEqual([b"loc:@a"], follower.op.colocation_groups())
    self.assertEqual(anchor.op.device, follower.op.device)

  @test_util.run_deprecated_v1
  def testColocationCanonicalization(self):
    """A colocated op ends up with the same device string as its anchor."""
    with ops.device("/device:GPU:0"):
      _ = constant_op.constant(2.0)
    with ops.device(lambda op: "/device:GPU:0"):
      anchor = constant_op.constant(3.0)
    with ops.get_default_graph().colocate_with(anchor):
      with ops.device("/device:GPU:0"):
        follower = constant_op.constant(4.0)

    # All three constants target /device:GPU:0; the follower inherits the
    # anchor's device name, after canonicalizing the names.
    self.assertEqual(anchor.op.device, follower.op.device)

  @test_util.run_deprecated_v1
  def testLocationOverrides(self):
    """The device stack is restored after leaving a colocation scope."""
    with ops.device("/cpu:0"):
      with ops.device("/device:GPU:0"):
        a = constant_op.constant([2.0], name="a")
        # Note that this colocation is "redundant", since we are
        # within the scope of "/device:GPU:0". However, we would like to
        # preserve in the GraphDef that these two ops should be
        # colocated in a portable way.
        with ops.colocate_with(a.op):
          b = constant_op.constant(3.0)
        c = constant_op.constant(4.0)
      d = constant_op.constant(5.0)

    self.assertEqual([b"loc:@a"], b.op.colocation_groups())
    self.assertEqual("/device:GPU:0", a.op.device)
    self.assertEqual(a.op.device, b.op.device)

    # Test that device function stack is restored.
    self.assertEqual("/device:GPU:0", c.op.device)
    self.assertEqual("/device:CPU:0", d.op.device)

  @test_util.run_deprecated_v1
  def testNestedColocateWith(self):
    """Colocating with an already colocated op resolves to the root group."""
    root = constant_op.constant([2.0], name="a")
    with ops.colocate_with(root.op):
      middle = constant_op.constant(3.0)
      with ops.colocate_with(middle.op):
        leaf = constant_op.constant(4.0)
    root_group = [b"loc:@a"]
    self.assertEqual(root_group, middle.op.colocation_groups())
    self.assertEqual(root_group, leaf.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testMultiColocationGroups(self):
    """Nested scopes over independent ops accumulate both groups."""
    first = constant_op.constant([2.0], name="a")
    second = constant_op.constant(3.0, name="b")
    with ops.colocate_with(first.op):
      with ops.colocate_with(second.op):
        both = constant_op.constant(4.0)
    self.assertEqual({b"loc:@a", b"loc:@b"}, set(both.op.colocation_groups()))

  @test_util.run_deprecated_v1
  def testColocationIgnoreStack(self):
    """`ignore_existing=True` drops the groups of enclosing scopes."""
    first = constant_op.constant([2.0], name="a")
    second = constant_op.constant(3.0, name="b")
    with ops.colocate_with(first.op):
      with ops.colocate_with(second.op, ignore_existing=True):
        only_second = constant_op.constant(4.0)
    self.assertEqual({b"loc:@b"}, set(only_second.op.colocation_groups()))

  @test_util.run_deprecated_v1
  def testColocateWithReset(self):
    """`colocate_with(None, ignore_existing=True)` clears the constraints."""
    anchor = constant_op.constant([2.0], name="a")
    with ops.colocate_with(anchor.op):
      constrained = constant_op.constant(3.0, name="b")
      with ops.colocate_with(None, ignore_existing=True):
        unconstrained = constant_op.constant(4.0, name="c")
    self.assertEqual([b"loc:@a"], constrained.op.colocation_groups())
    self.assertEqual([b"loc:@c"], unconstrained.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateWithInitialNoneThenNested(self):
    """Scopes nested under a reset only see groups added after the reset."""
    anchor = constant_op.constant([2.0], name="a")
    with ops.colocate_with(anchor.op):
      with ops.colocate_with(None, ignore_existing=True):
        fresh = constant_op.constant(3.0, name="b")
        with ops.colocate_with(fresh.op):
          nested = constant_op.constant(4.0, name="c")
    fresh_group = [b"loc:@b"]
    self.assertEqual(fresh_group, fresh.op.colocation_groups())
    self.assertEqual(fresh_group, nested.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateVariables(self):
    """A variable created in the scope joins the anchor variable's group."""
    anchor = variables.Variable([2.0], name="a")
    with ops.colocate_with(anchor.op):
      follower = variables.Variable([3.0], name="b")
    self.assertEqual([b"loc:@a"], follower.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateResourceVariablesInFunction(self):
    """Inside a function, colocating with a CPU variable yields a CPU op."""
    with ops.device("/device:CPU:0"):
      cpu_variable = resource_variable_ops.ResourceVariable(1.0)

    @def_function.function
    def f():
      with ops.colocate_with(cpu_variable):
        output = array_ops.ones([], name="output")
        self.assertEqual("/device:CPU:0", output.op.device)

    f()

  def testColocateWithVariableInFunction(self):
    """A function colocating with a variable round-trips through a graph def."""
    variable = variables.Variable(1.)

    @def_function.function
    def f():
      with ops.colocate_with(variable):
        return array_ops.ones([], name="output")

    f()
    concrete = f.get_concrete_function()
    graph_def = concrete.graph.as_graph_def()
    wrap_function.function_from_graph_def(graph_def, [], ["output"])
|
|
|
|
|
|
class DeprecatedTest(test_util.TensorFlowTestCase):
  """Tests for the deprecated test op `Old`."""

  def testSuccess(self):
    """`Old` can be built and run when the producer version is set to 7."""
    with ops.Graph().as_default() as graph:
      test_util.set_producer_version(graph, 7)
      old_op = test_ops.old()
      with self.session(graph=graph):
        old_op.run()

  def _error(self):
    """Returns the error regex expected at the current GraphDef version."""
    pattern = (r"Op Old is not available in GraphDef version %d\. "
               r"It has been removed in version 8\. For reasons\.")
    return pattern % versions.GRAPH_DEF_VERSION

  def testGraphConstructionFail(self):
    """Building `Old` at the current version raises `NotImplementedError`."""
    expected_message = self._error()
    with ops.Graph().as_default():
      with self.assertRaisesRegex(NotImplementedError, expected_message):
        test_ops.old()
|
|
|
|
|
|
class NameScopeTest(test_util.TensorFlowTestCase):
  """Tests for name-scope string helpers and `Graph.get_name_scope`."""

  def testStripAndPrependScope(self):
    """Strips "hidden1" from each name, then prepends "hidden2" to the result."""
    # Each case is (input name, expected after strip, expected after prepend).
    cases = [
        # Same prefix. Should strip.
        ("hidden1/hidden1/weights", "hidden1/weights",
         "hidden2/hidden1/weights"),
        # Extra "/". Should strip.
        ("hidden1///hidden1/weights", "hidden1/weights",
         "hidden2/hidden1/weights"),
        # Same prefix. Should strip.
        ("^hidden1/hidden1/weights", "^hidden1/weights",
         "^hidden2/hidden1/weights"),
        # Same prefix. Should strip.
        ("loc:@hidden1/hidden1/weights", "loc:@hidden1/weights",
         "loc:@hidden2/hidden1/weights"),
        # Different prefix. Should keep.
        ("hhidden1/hidden1/weights", "hhidden1/hidden1/weights",
         "hidden2/hhidden1/hidden1/weights"),
        # Not a prefix. Should keep.
        ("hidden1", "hidden1", "hidden2/hidden1"),
    ]
    scope_to_strip = "hidden1"
    scope_to_add = "hidden2"
    for name, want_stripped, want_prepended in cases:
      stripped = ops.strip_name_scope(name, scope_to_strip)
      self.assertEqual(want_stripped, stripped)
      prepended = ops.prepend_name_scope(stripped, scope_to_add)
      self.assertEqual(want_prepended, prepended)

  def testGetNameScope(self):
    """`get_name_scope` follows nested scopes on the way in and out."""
    with ops.Graph().as_default() as graph:
      with ops.name_scope("scope1"):
        with ops.name_scope("scope2"):
          with ops.name_scope("scope3"):
            self.assertEqual("scope1/scope2/scope3", graph.get_name_scope())
          self.assertEqual("scope1/scope2", graph.get_name_scope())
        self.assertEqual("scope1", graph.get_name_scope())
      self.assertEqual("", graph.get_name_scope())

  def testTwoGraphs(self):
    """An invalid scope name raises `ValueError` under two nested graphs."""

    def enter_invalid_scope():
      outer_graph = ops.Graph()
      inner_graph = ops.Graph()
      with outer_graph.as_default():
        with inner_graph.as_default():
          with ops.name_scope("_"):
            pass

    with self.assertRaisesRegex(ValueError, "'_' is not a valid scope name"):
      enter_invalid_scope()
|
|
|
|
|
|
class EnableEagerExecutionTest(test_util.TensorFlowTestCase):
  """Tests for argument validation in `ops.enable_eager_execution`."""

  @test_util.run_v1_only("b/120545219")
  def testBadArgumentsToEnableEagerExecution(self):
    """Each wrongly typed argument is rejected with its own message."""
    # A device-placement constant passed where a ConfigProto is expected.
    with self.assertRaisesRegex(TypeError, "config must be a tf.ConfigProto"):
      ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT)

    # A ConfigProto passed as the device policy.
    with self.assertRaisesRegex(ValueError, "device_policy must be one of"):
      proto = config_pb2.ConfigProto()
      ops.enable_eager_execution(proto, proto)

    # A ConfigProto passed as the execution mode.
    with self.assertRaisesRegex(ValueError, "execution_mode must be one of"):
      proto = config_pb2.ConfigProto()
      ops.enable_eager_execution(proto, execution_mode=proto)
|
|
|
|
|
|
class _TupleTensor(composite_tensor.CompositeTensor):
  """`Tensor`-like `tuple`-like for custom `Tensor` conversion masquerading."""

  def __init__(self, components):
    """Converts every item of `components` to a tensor and stores the tuple.

    Args:
      components: iterable of values accepted by `ops.convert_to_tensor`.
    """
    super(_TupleTensor, self).__init__()
    self._components = tuple(ops.convert_to_tensor(c) for c in components)

  @property
  def _type_spec(self):
    """Returns a `_TupleTensorSpec` holding one spec per component."""
    # Materialize the specs: a bare generator would be exhausted after the
    # first traversal, leaving the spec empty on every later use.
    return _TupleTensorSpec(
        tuple(type_spec.from_value(c) for c in self._components))

  def __getitem__(self, key):
    """Returns the component(s) selected by `key`."""
    return self._components[key]

  def __len__(self):
    """Returns the number of components."""
    return len(self._components)

  def __iter__(self):
    """Iterates over the component tensors."""
    return iter(self._components)
|
|
|
|
|
|
class _TupleTensorSpec(type_spec.TypeSpec):
  """`TypeSpec` describing a `_TupleTensor` by its per-component specs."""

  def __init__(self, specs):
    """Stores the component specs.

    Args:
      specs: iterable of specs, one per component.
    """
    # Materialize so the specs survive repeated traversal even when the caller
    # hands in a one-shot iterator such as a generator expression.
    self._specs = tuple(specs)

  @property
  def value_type(self):
    """The Python type this spec describes."""
    return _TupleTensor

  @property
  def _component_specs(self):
    """The specs of the individual components."""
    return self._specs

  def _to_components(self, value):
    """Returns the tuple of tensors held by `value`."""
    return value._components  # pylint: disable=protected-access

  def _from_components(self, components):
    """Rebuilds a `_TupleTensor` from the output of `_to_components`."""
    # `_TupleTensor.__init__` takes a single iterable; unpacking it with `*`
    # raised TypeError for anything but a single component.
    return _TupleTensor(components)

  def _serialize(self):
    """Returns the constructor arguments that recreate this spec."""
    return (self._specs,)
|
|
|
|
|
|
class _MyTuple(object):
  """Pretend user-side class for `CustomConvertToCompositeTensorTest`."""

  def __init__(self, components):
    """Copies `components` (any iterable) into an internal tuple."""
    super(_MyTuple, self).__init__()
    self._components = tuple(components)

  # The three methods below simply delegate to the stored tuple.

  def __iter__(self):
    return iter(self._components)

  def __len__(self):
    return len(self._components)

  def __getitem__(self, key):
    return self._components[key]
|
|
|
|
|
|
# Register a converter for `_MyTuple`: the value is wrapped in a
# `_TupleTensor`, and any further positional or keyword arguments passed to
# the converter are ignored.
ops.register_tensor_conversion_function(
    _MyTuple,
    conversion_func=lambda my_tuple, *_, **__: _TupleTensor(my_tuple))
|
|
|
|
|
|
class CustomConvertToCompositeTensorTest(test_util.TensorFlowTestCase):
  """Exercises the user-registered `_MyTuple` to `_TupleTensor` conversion."""

  @test_util.disable_tfrt("TODO(kkb): This makes Kokoro tests fail.")
  def testCompositeTensorConversion(self):
    """Tests that a user can register a CompositeTensor converter."""
    user_value = _MyTuple((1, [2., 3.], [[4, 5], [6, 7]]))
    converted = ops.convert_to_tensor_or_composite(user_value)

    # As a whole, the result is the composite wrapper rather than a tensor.
    self.assertFalse(tensor_util.is_tensor(converted))
    self.assertIsInstance(converted, _TupleTensor)
    self.assertLen(converted, len(user_value))

    # Each component is a tensor whose constant value matches the input.
    for expected, component in zip(user_value, converted):
      self.assertIsInstance(component, ops.Tensor)
      self.assertTrue(tensor_util.is_tensor(component))
      self.assertAllEqual(expected, tensor_util.constant_value(component))
|
|
|
|
|
|
@test_util.disable_tfrt("Packing EagerTensors is not supported yet.")
class PackEagerTensorTest(test_util.TensorFlowTestCase):
  """Tests for `ops.pack_eager_tensors` using two logical CPU devices."""

  def setUp(self):
    """Resets the eager context and configures two logical CPUs."""
    super(PackEagerTensorTest, self).setUp()
    context._reset_context()
    physical_cpus = config.list_physical_devices("CPU")
    # Set 2 virtual CPUs
    two_logical_cpus = [context.LogicalDeviceConfiguration() for _ in range(2)]
    config.set_logical_device_configuration(physical_cpus[0], two_logical_cpus)

  def testPack(self):
    """Packs matching handles, then checks that mismatched inputs fail."""
    with context.eager_mode():
      with ops.device("CPU:0"):
        var0 = resource_variable_ops.ResourceVariable(1.0)
        c0 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      with ops.device("CPU:1"):
        var1 = resource_variable_ops.ResourceVariable(2.0)
        var2 = resource_variable_ops.ResourceVariable([3.0])
        c1 = constant_op.constant([9.0])

      # Two scalar-variable handles: the packed tensor takes its metadata
      # from the inputs and reports a COMPOSITE device.
      reference = var0.handle
      packed = ops.pack_eager_tensors([reference, var1.handle])
      self.assertTrue(packed.is_packed)
      self.assertEqual(packed.dtype, reference.dtype)
      self.assertEqual(packed.shape, reference.shape)
      self.assertEqual(packed._handle_data, reference._handle_data)
      for device_name in (packed.device, packed.backing_device):
        self.assertIn("COMPOSITE:0", device_name)
      with self.assertRaises(errors.InvalidArgumentError):
        packed.numpy()

      mismatched_inputs = (
          [var0.handle, c1],  # Different dtypes
          [c0, c1],  # Different shapes
          [var0.handle, var2.handle],  # Different handle data
      )
      for tensors in mismatched_inputs:
        with self.assertRaises(ValueError):
          ops.pack_eager_tensors(tensors)
|
|
|
|
|
|
# Hand control to the googletest runner only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
  googletest.main()
|