STT-tensorflow/tensorflow/python/framework/ops_test.py
Chuanhao Zhuge 4a05ea9a74 Enable passing TFRT python tests.
PiperOrigin-RevId: 337418589
Change-Id: I2267e24d39aa367df75176108df95ddd81d7b968
2020-10-15 17:52:05 -07:00

3620 lines
132 KiB
Python

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.framework.ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import gc
import numpy as np
import os
import threading
import weakref
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as eager_function
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import config
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.gradients # pylint: disable=unused-import
from tensorflow.python.platform import googletest
from tensorflow.python.util import compat
class ResourceTest(test_util.TensorFlowTestCase):
  """Tests for stub resource handle ops and shared-resource registration."""

  @test_util.run_deprecated_v1
  def testBuildGraph(self):
    """A create op built on a stub resource handle can be run."""
    with self.cached_session():
      handle = test_ops.stub_resource_handle_op(container="a", shared_name="b")
      create_op = test_ops.resource_create_op(handle)
      create_op.run()

  @test_util.run_deprecated_v1
  def testInitialize(self):
    """A registered resource is reported uninitialized until initialized."""
    with self.cached_session():
      handle = test_ops.stub_resource_handle_op(container="a", shared_name="b")
      resources.register_resource(
          handle=handle,
          create_op=test_ops.resource_create_op(handle),
          is_initialized_op=test_ops.resource_initialized_op(handle))

      def num_uninitialized():
        # Builds a fresh report op each call and counts the names it returns.
        report = resources.report_uninitialized_resources(
            resources.shared_resources())
        return len(report.eval())

      self.assertEqual(num_uninitialized(), 1)
      resources.initialize_resources(resources.shared_resources()).run()
      self.assertEqual(num_uninitialized(), 0)
class TensorAndShapeTest(test_util.TensorFlowTestCase):
  """Tests for Tensor shapes, conversions, refs and overloaded operators."""

  def _expected_bitwise_error_type(self):
    """Returns the error type the bitwise-error tests expect.

    These tests expect `errors.InvalidArgumentError` when executing eagerly
    and `TypeError` when building a graph. Must be called from inside the
    running test so `context.executing_eagerly()` reflects the current mode.
    """
    if context.executing_eagerly():  # :(
      return errors.InvalidArgumentError
    return TypeError

  def testShape(self):
    """A FloatOutput op starts with unknown shape until set_shape is called."""
    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
    t = op.outputs[0]
    self.assertEqual(tensor_shape.unknown_shape(), t.get_shape())
    t.set_shape([1, 2, 3])
    self.assertEqual([1, 2, 3], t.get_shape())

  def testIterable(self):
    """Iterating a graph tensor while executing eagerly raises TypeError."""
    if not context.executing_eagerly():
      self.skipTest("Eager-mode test")
    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
    t = op.outputs[0]
    with self.assertRaisesRegex(TypeError, "Cannot iterate"):
      iter(t)

  def testIterableGraph(self):
    """In graph mode, iteration errors mention the AutoGraph status."""
    if context.executing_eagerly():
      self.skipTest("Graph-mode test")
    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
    t = op.outputs[0]
    with self.assertRaisesRegex(TypeError, "iterating.*not allowed in Graph"):
      next(iter(t))
    with self.assertRaisesRegex(TypeError, "iterating.*AutoGraph did convert"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED):
        next(iter(t))
    with self.assertRaisesRegex(TypeError, "iterating.*AutoGraph is disabled"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.DISABLED):
        next(iter(t))

  def testImplicitBool(self):
    """Using a graph tensor as a bool raises TypeError naming AutoGraph status."""
    op = ops.Operation(
        ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.bool])
    t = op.outputs[0]
    with self.assertRaisesRegex(TypeError,
                                "using.*as a.*bool.*not allowed in Graph"):
      bool(t)
    with self.assertRaisesRegex(TypeError,
                                "using.*as a.*bool.*AutoGraph did convert"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED):
        bool(t)
    with self.assertRaisesRegex(TypeError,
                                "using.*as a.*bool.*AutoGraph is disabled"):
      with ag_ctx.ControlStatusCtx(ag_ctx.Status.DISABLED):
        bool(t)

  def testAddShape(self):
    """Adding [2, 3] and [1, 3] broadcasts to [2, 3]."""
    with self.cached_session():
      a = array_ops.zeros([2, 3])
      b = array_ops.ones([1, 3])
      c = a + b
      self.assertEqual([2, 3], c.shape)

  @test_util.run_deprecated_v1
  def testUnknownDim(self):
    """Unknown dimensions are preserved by addition."""
    with self.cached_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
      b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
      c = a + b
      self.assertEqual([2, None, 3], c.shape.as_list())

  @test_util.run_deprecated_v1
  def testUnknownShape(self):
    """Adding to an unknown-shape placeholder yields an unknown shape."""
    with self.cached_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
      b = array_ops.ones([1, 3])
      c = a + b
      self.assertEqual(tensor_shape.unknown_shape(), c.shape)

  @test_util.run_deprecated_v1
  def testScalarShape(self):
    """Adding two scalars yields a scalar shape."""
    with self.cached_session():
      a = array_ops.placeholder(dtype=dtypes.float32, shape=[])
      b = array_ops.ones([])
      c = a + b
      self.assertEqual(tensor_shape.TensorShape([]), c.shape)

  @test_util.run_deprecated_v1
  def testShapeFunctionError(self):
    """Incompatible shapes raise ValueError describing both input shapes."""
    with self.cached_session():
      a = array_ops.ones([1, 2, 3])
      b = array_ops.ones([4, 5, 6])
      with self.assertRaisesRegex(
          ValueError, r"Dimensions must be equal, but are 2 and 5 for .*add"
          r".*Add(V2)?.* with input shapes: \[1,2,3\], \[4,5,6\]."):
        _ = a + b

  def testNumpyArray(self):
    """Symbolic tensors refuse numpy/len conversion; EagerTensors allow it."""
    with ops.Graph().as_default():
      x = array_ops.ones((3, 4), name="test_ones")
    with self.assertRaisesRegex(NotImplementedError,
                                r"Cannot convert a symbolic.+test_ones"):
      np.array(x)
    with self.assertRaisesRegex(TypeError, "not well defined.+test_ones"):
      len(x)
    # EagerTensors should still behave as numpy arrays.
    with context.eager_mode():
      x = array_ops.ones((3, 4))
    self.assertAllEqual(x, np.ones((3, 4)))
    self.assertAllEqual(np.array(x), np.ones((3, 4)))
    self.assertEqual(len(x), 3)

  def testRef(self):
    """ref() is equal for the same object and distinct for different ones."""
    x1 = constant_op.constant(3)
    x2 = x1
    y = constant_op.constant(3)
    z = constant_op.constant([6, 10])
    w = variables.Variable(5)
    self.assertEqual(x1.ref(), x1.ref())
    self.assertEqual(x2.ref(), x2.ref())
    self.assertEqual(x1.ref(), x2.ref())
    self.assertEqual(y.ref(), y.ref())
    self.assertEqual(z.ref(), z.ref())
    self.assertEqual(w.ref(), w.ref())
    # Equal values (x1 vs y) still give distinct refs: identity, not value.
    self.assertNotEqual(x1.ref(), y.ref())
    self.assertNotEqual(x1.ref(), z.ref())
    self.assertNotEqual(x1.ref(), w.ref())
    self.assertNotEqual(y.ref(), z.ref())
    self.assertNotEqual(y.ref(), w.ref())
    self.assertNotEqual(z.ref(), w.ref())

  def testRefDeref(self):
    """ref().deref() returns the exact object the ref was taken from."""
    x1 = constant_op.constant(3)
    x2 = x1
    y = constant_op.constant(3)
    z = constant_op.constant([6, 10])
    w = variables.Variable(5)
    self.assertIs(x1, x1.ref().deref())
    self.assertIs(x2, x2.ref().deref())
    self.assertIs(x1, x2.ref().deref())
    self.assertIs(x2, x1.ref().deref())
    self.assertIs(y, y.ref().deref())
    self.assertIs(z, z.ref().deref())
    # Variables must round-trip through ref()/deref() as well.
    self.assertIs(w, w.ref().deref())
    self.assertIsNot(x1, y.ref().deref())
    self.assertIsNot(x1, z.ref().deref())
    self.assertIsNot(x1, w.ref().deref())
    self.assertIsNot(y, z.ref().deref())
    self.assertIsNot(y, w.ref().deref())
    self.assertIsNot(z, w.ref().deref())

  def testRefInSet(self):
    """Refs are hashable; aliases of the same tensor collapse in a set."""
    x1 = constant_op.constant(3)
    x2 = x1
    y = constant_op.constant(3)
    z = constant_op.constant([6, 10])
    w = variables.Variable(5)
    self.assertEqual(x1.ref(), x2.ref())
    tensor_set = {
        x1.ref(),
        x2.ref(),
        y.ref(),
        z.ref(),
        w.ref(),
    }
    self.assertEqual(len(tensor_set), 4)
    self.assertIn(x1.ref(), tensor_set)
    self.assertIn(x2.ref(), tensor_set)
    self.assertIn(y.ref(), tensor_set)
    self.assertIn(z.ref(), tensor_set)
    self.assertIn(w.ref(), tensor_set)

  def testRefInDict(self):
    """Refs work as dict keys; an alias's ref overwrites the original entry."""
    x1 = constant_op.constant(3)
    x2 = x1
    y = constant_op.constant(3)
    z = constant_op.constant([6, 10])
    w = variables.Variable(5)
    self.assertEqual(x1.ref(), x2.ref())
    tensor_dict = {
        x1.ref(): "x1",
        y.ref(): "y",
        z.ref(): "z",
        w.ref(): "w",
    }
    self.assertEqual(len(tensor_dict), 4)
    # Overwriting x1
    tensor_dict[x2.ref()] = "x2"
    self.assertEqual(len(tensor_dict), 4)
    self.assertEqual(tensor_dict[x1.ref()], "x2")
    self.assertEqual(tensor_dict[x2.ref()], "x2")
    self.assertEqual(tensor_dict[y.ref()], "y")
    self.assertEqual(tensor_dict[z.ref()], "z")
    self.assertEqual(tensor_dict[w.ref()], "w")

  def testTensorRefStrong(self):
    """A tensor ref keeps its tensor reachable after the name is deleted."""
    x = constant_op.constant(1.)
    x_ref = x.ref()
    del x
    self.assertIsNotNone(x_ref.deref())

  def testVariableRefStrong(self):
    """A variable ref keeps its variable reachable after the name is deleted."""
    x = variables.Variable(1.)
    x_ref = x.ref()
    del x
    self.assertIsNotNone(x_ref.deref())

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseAndNumeric(self):
    x = constant_op.constant([0, 1, 3])
    y = constant_op.constant([1, 1, 1])
    z = x & y
    self.assertAllEqual(z, [0, 1, 1])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseAndBool(self):
    x = constant_op.constant([False, False, True, True])
    y = constant_op.constant([False, True, False, True])
    z = x & y
    self.assertAllEqual(z, [False, False, False, True])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseAndErrors(self):
    """`&` rejects mixed int/bool operands and string operands."""
    x_int = constant_op.constant(0)
    x_bool = constant_op.constant(True)
    expected_errtype = self._expected_bitwise_error_type()
    with self.assertRaises(expected_errtype):
      _ = x_int & x_bool
    with self.assertRaises(expected_errtype):
      _ = x_int & constant_op.constant("a")
    with self.assertRaises(expected_errtype):
      _ = x_bool & x_int
    with self.assertRaises(expected_errtype):
      _ = x_bool & constant_op.constant("a")
    with self.assertRaises(expected_errtype):
      _ = constant_op.constant("a") & constant_op.constant("b")

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseOrNumeric(self):
    x = constant_op.constant([0, 1, 2])
    y = constant_op.constant([1, 1, 1])
    z = x | y
    self.assertAllEqual(z, [1, 1, 3])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseOrBool(self):
    x = constant_op.constant([False, False, True, True])
    y = constant_op.constant([False, True, False, True])
    z = x | y
    self.assertAllEqual(z, [False, True, True, True])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseOrErrors(self):
    """`|` rejects mixed int/bool operands and string operands."""
    x_int = constant_op.constant(0)
    x_bool = constant_op.constant(True)
    expected_errtype = self._expected_bitwise_error_type()
    with self.assertRaises(expected_errtype):
      _ = x_int | x_bool
    with self.assertRaises(expected_errtype):
      _ = x_int | constant_op.constant("a")
    with self.assertRaises(expected_errtype):
      _ = x_bool | x_int
    with self.assertRaises(expected_errtype):
      _ = x_bool | constant_op.constant("a")
    with self.assertRaises(expected_errtype):
      _ = constant_op.constant("a") | constant_op.constant("b")

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseXorNumeric(self):
    x = constant_op.constant([0, 1, 3])
    y = constant_op.constant([1, 1, 1])
    z = x ^ y
    self.assertAllEqual(z, [1, 0, 2])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseXorBool(self):
    x = constant_op.constant([False, False, True, True])
    y = constant_op.constant([False, True, False, True])
    z = x ^ y
    self.assertAllEqual(z, [False, True, True, False])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseXorErrors(self):
    """`^` rejects mixed int/bool operands and string operands."""
    x_int = constant_op.constant(0)
    x_bool = constant_op.constant(True)
    expected_errtype = self._expected_bitwise_error_type()
    with self.assertRaises(expected_errtype):
      _ = x_int ^ x_bool
    with self.assertRaises(expected_errtype):
      _ = x_int ^ constant_op.constant("a")
    with self.assertRaises(expected_errtype):
      _ = x_bool ^ x_int
    with self.assertRaises(expected_errtype):
      _ = x_bool ^ constant_op.constant("a")
    with self.assertRaises(expected_errtype):
      _ = constant_op.constant("a") ^ constant_op.constant("b")

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseNotNumeric(self):
    """`~` on int32 is two's-complement negation minus one."""
    x = constant_op.constant([0, dtypes.int32.min, 1])
    y = ~x
    self.assertAllEqual(y, [-1, dtypes.int32.max, -2])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseNotBool(self):
    x = constant_op.constant([False, True])
    y = ~x
    self.assertAllEqual(y, [True, False])

  @test_util.run_in_graph_and_eager_modes
  def testBitwiseNotErrors(self):
    """`~` rejects string operands."""
    expected_errtype = self._expected_bitwise_error_type()
    with self.assertRaises(expected_errtype):
      _ = ~constant_op.constant("a")
@test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesTest(test_util.TensorFlowTestCase):
  """Tests for densifying and transforming `ops.IndexedSlices`."""

  def _values_and_indices(self):
    """Builds the shared fixtures: 2x2 values [[2, 3], [5, 7]], indices [0, 2]."""
    slice_values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
    slice_indices = constant_op.constant([0, 2])
    return slice_values, slice_indices

  def testToTensor(self):
    """Densifying requires dense_shape; with it, gaps are zero-filled."""
    slice_values, slice_indices = self._values_and_indices()
    shapeless = ops.IndexedSlices(slice_values, slice_indices)
    # Without a dense_shape, conversion to a dense tensor is expected to fail.
    with self.assertRaises(ValueError):
      ops.convert_to_tensor(shapeless, name="tensor")
    self.assertEqual(tensor_shape.TensorShape(None), shapeless.shape)

    full_shape = constant_op.constant([3, 2])
    shaped = ops.IndexedSlices(slice_values, slice_indices, full_shape)
    dense = ops.convert_to_tensor(shaped, name="tensor")
    self.assertAllEqual(dense.shape, shaped.shape)
    self.assertAllEqual(self.evaluate(dense), [[2, 3], [0, 0], [5, 7]])

  @test_util.run_gpu_only
  def testEagerCopy(self):
    """Gradients through nested gathers keep the variable's [4, 1] shape."""
    with context.eager_mode():
      var = variables.Variable([[0.0], [0.0], [0.0], [0.0]], name="tensor")
      with backprop.GradientTape() as tape:
        head = array_ops.gather(array_ops.gather(var, [0, 1]), [0, 1])
        tail = array_ops.gather(array_ops.gather(var, [2, 3]), [0, 1])
        product = special_math_ops.einsum("ij,ij->i", head, tail)
      grad, = tape.gradient(product, [var])
      if isinstance(grad, ops.IndexedSlices):
        grad_values = grad.values
      else:
        grad_values = grad
      self.assertAllEqual(grad_values.get_shape(), [4, 1])

  def testNegation(self):
    """Unary minus negates values and leaves indices untouched."""
    slice_values, slice_indices = self._values_and_indices()
    negated = -ops.IndexedSlices(slice_values, slice_indices)
    self.assertAllEqual(negated.values, [[-2, -3], [-5, -7]])
    self.assertAllEqual(negated.indices, [0, 2])

  def testScalarMul(self):
    """scalar_mul scales values and leaves indices untouched."""
    slice_values, slice_indices = self._values_and_indices()
    scaled = math_ops.scalar_mul(-2, ops.IndexedSlices(slice_values,
                                                       slice_indices))
    self.assertAllEqual(scaled.values, [[-4, -6], [-10, -14]])
    self.assertAllEqual(scaled.indices, [0, 2])
@test_util.run_all_in_graph_and_eager_modes
class IndexedSlicesSpecTest(test_util.TensorFlowTestCase,
                            parameterized.TestCase):
  """Tests for `indexed_slices.IndexedSlicesSpec`."""

  def assertAllTensorsEqual(self, list1, list2):
    """Asserts both sequences have the same length and equal elements."""
    self.assertLen(list1, len(list2))
    for t1, t2 in zip(list1, list2):
      self.assertAllEqual(t1, t2)

  def testConstruction(self):
    """Checks default and explicit constructor arguments are stored."""
    spec1 = indexed_slices.IndexedSlicesSpec()
    self.assertIsNone(spec1._shape.rank)
    self.assertEqual(spec1._values_dtype, dtypes.float32)
    self.assertEqual(spec1._indices_dtype, dtypes.int64)
    self.assertIsNone(spec1._dense_shape_dtype)
    self.assertEqual(spec1._indices_shape.as_list(), [None])
    spec2 = indexed_slices.IndexedSlicesSpec([None, None], dtypes.string,
                                             dtypes.int32, dtypes.int64, [10])
    self.assertEqual(spec2._shape.as_list(), [None, None])
    self.assertEqual(spec2._values_dtype, dtypes.string)
    self.assertEqual(spec2._indices_dtype, dtypes.int32)
    self.assertEqual(spec2._dense_shape_dtype, dtypes.int64)
    self.assertEqual(spec2._indices_shape.as_list(), [10])

  def testValueType(self):
    spec1 = indexed_slices.IndexedSlicesSpec()
    self.assertEqual(spec1.value_type, ops.IndexedSlices)

  @parameterized.parameters([
      (indexed_slices.IndexedSlicesSpec(),
       (tensor_shape.TensorShape(None), dtypes.float32, dtypes.int64, None,
        tensor_shape.TensorShape([None]))),
      (indexed_slices.IndexedSlicesSpec(shape=[5, None, None]),
       (tensor_shape.TensorShape([5, None, None]), dtypes.float32,
        dtypes.int64, None, tensor_shape.TensorShape([None]))),
      (indexed_slices.IndexedSlicesSpec(
          dtype=dtypes.int32, dense_shape_dtype=dtypes.int64),
       (tensor_shape.TensorShape(None), dtypes.int32, dtypes.int64,
        dtypes.int64, tensor_shape.TensorShape([None]))),
      (indexed_slices.IndexedSlicesSpec(indices_shape=[100]),
       (tensor_shape.TensorShape(None), dtypes.float32, dtypes.int64, None,
        tensor_shape.TensorShape([100]))),
  ])  # pyformat: disable
  def testSerialize(self, spec, expected):
    serialization = spec._serialize()
    # TensorShape has an unconventional definition of equality, so we can't use
    # assertEqual directly here. But repr() is deterministic and lossless for
    # the expected values, so we can use that instead.
    self.assertEqual(repr(serialization), repr(expected))

  @parameterized.parameters([
      (indexed_slices.IndexedSlicesSpec(dtype=dtypes.string), (
          tensor_spec.TensorSpec(None, dtypes.string),
          tensor_spec.TensorSpec([None], dtypes.int64),
      )),
      (indexed_slices.IndexedSlicesSpec(
          dtype=dtypes.string, dense_shape_dtype=dtypes.int32), (
              tensor_spec.TensorSpec(None, dtypes.string),
              tensor_spec.TensorSpec([None], dtypes.int64),
              tensor_spec.TensorSpec([None], dtypes.int32),
          )),
      (indexed_slices.IndexedSlicesSpec(
          shape=[5, 10, 15], dense_shape_dtype=dtypes.int32), (
              tensor_spec.TensorSpec([None, 10, 15], dtypes.float32),
              tensor_spec.TensorSpec([None], dtypes.int64),
              tensor_spec.TensorSpec([3], dtypes.int32),
          )),
      (indexed_slices.IndexedSlicesSpec(
          shape=[5, 10, 15], dense_shape_dtype=dtypes.int32,
          indices_shape=[20]), (
              tensor_spec.TensorSpec([20, 10, 15], dtypes.float32),
              tensor_spec.TensorSpec([20], dtypes.int64),
              tensor_spec.TensorSpec([3], dtypes.int32),
          )),
  ])
  def testComponentSpecs(self, spec, expected):
    self.assertEqual(spec._component_specs, expected)

  @parameterized.parameters([
      {
          "spec": indexed_slices.IndexedSlicesSpec(),
          "values": [3.0, 5.0],
          "indices": [5, 10]
      },
      {
          "spec":
              indexed_slices.IndexedSlicesSpec(dense_shape_dtype=dtypes.int32),
          "values": [3.0, 5.0],
          "indices": [5, 10],
          "dense_shape": [100]
      },
  ])
  def testToFromComponents(self, spec, indices, values, dense_shape=None):
    """Round-trips an IndexedSlices through _to/_from_components.

    `IndexedSlices` takes `(values, indices, dense_shape)` positionally and
    the spec's components are ordered `[values, indices(, dense_shape)]`
    (see `testFromNumpyComponents`). Both must be built in that order, or a
    values/indices mix-up is never detected.
    """
    x = ops.IndexedSlices(values, indices, dense_shape)
    actual_components = spec._to_components(x)
    if dense_shape is None:
      expected_components = [values, indices]
    else:
      expected_components = [values, indices, dense_shape]
    self.assertAllTensorsEqual(actual_components, expected_components)
    reconstructed = spec._from_components(actual_components)
    self.assertAllEqual(x.values, reconstructed.values)
    self.assertAllEqual(x.indices, reconstructed.indices)
    if dense_shape is None:
      self.assertIsNone(reconstructed.dense_shape)
    else:
      self.assertAllEqual(x.dense_shape, reconstructed.dense_shape)

  @test_util.run_v1_only("IndexedSlicesValue is deprecated in v2")
  def testFromNumpyComponents(self):
    """Numpy components produce an IndexedSlicesValue in v1."""
    indices = np.array([3, 8])
    values = np.array([1.0, 9.0])
    dense_shape = np.array([100])
    spec1 = indexed_slices.IndexedSlicesSpec(dense_shape_dtype=dtypes.int32)
    st1 = spec1._from_components((values, indices, dense_shape))
    self.assertIsInstance(st1, indexed_slices.IndexedSlicesValue)
    self.assertAllEqual(st1.indices, indices)
    self.assertAllEqual(st1.values, values)
    self.assertAllEqual(st1.dense_shape, dense_shape)
    spec2 = indexed_slices.IndexedSlicesSpec()
    st2 = spec2._from_components((values, indices))
    self.assertIsInstance(st2, indexed_slices.IndexedSlicesValue)
    self.assertAllEqual(st2.indices, indices)
    self.assertAllEqual(st2.values, values)
    self.assertIsNone(st2.dense_shape)
class NodeDefConstructorTest(test_util.TensorFlowTestCase):
  """Tests for the private `ops._NodeDef` builder."""

  def testNoArgs(self):
    """With only op type and name, the NodeDef carries exactly those fields."""
    expected_text = "op: 'None' name: 'bar'"
    self.assertProtoEquals(expected_text, ops._NodeDef("None", "bar"))
def _apply_op(g, *args, **kwargs):
  """Creates an op in `g` and returns its output(s).

  Args:
    g: Object whose `create_op` builds the op (a graph in this file).
    *args: Positional arguments forwarded to `g.create_op`.
    **kwargs: Keyword arguments forwarded to `g.create_op`.

  Returns:
    The only output when the op has exactly one; otherwise the op's full
    `outputs` sequence (including when it is empty).
  """
  produced = g.create_op(*args, **kwargs).outputs
  return produced[0] if len(produced) == 1 else produced
class OperationTest(test_util.TensorFlowTestCase):
@test_util.run_deprecated_v1
def testNoInputs(self):
  """An input-less op exposes both typed outputs with correct names."""
  op = test_ops.float_output_string_output(name="myop").a.op
  self.assertLen(op.values(), 2)
  self.assertLen(op.inputs, 0)
  self.assertEqual("myop", op.name)
  float_t, label_str_t = op.values()
  # (tensor, dtype, output index, NodeDef input spelling)
  expectations = [
      (float_t, dtypes.float32, 0, "myop"),
      (label_str_t, dtypes.string, 1, "myop:1"),
  ]
  for tensor, dtype, index, node_def_input in expectations:
    self.assertEqual(dtype, tensor.dtype)
    self.assertEqual(op, tensor.op)
    self.assertEqual(index, tensor._value_index)
    self.assertLen(tensor.consumers(), 0)
    self.assertEqual(node_def_input, tensor._as_node_def_input())
  self.assertProtoEquals("op:'FloatOutputStringOutput' name:'myop'",
                         op.node_def)
@test_util.run_deprecated_v1
def testNoOutputs(self):
  """An output-less op registers itself as the consumer of its input."""
  producer = test_ops.float_output(name="myop1").op
  (produced,) = producer.values()
  consumer = test_ops.float_input(produced, name="myop2")
  self.assertLen(consumer.values(), 0)
  self.assertLen(consumer.inputs, 1)
  self.assertIs(produced, consumer.inputs[0])
  self.assertLen(produced.consumers(), 1)
  self.assertEqual(consumer, produced.consumers()[0])
  self.assertProtoEquals("op:'FloatOutput' name:'myop1'", producer.node_def)
  self.assertProtoEquals("op:'FloatInput' name:'myop2' input:'myop1'",
                         consumer.node_def)
@test_util.run_deprecated_v1
def testInputsAndOutputs(self):
  """Consumers are tracked once per use, including repeated uses."""
  op1 = test_ops.float_output(name="myop1").op
  self.assertLen(op1.values(), 1)
  (float1_t,) = op1.values()
  op2 = test_ops.float_output_string_output(name="myop2").a.op
  self.assertLen(op2.values(), 2)
  float2_t, label2_str_t = op2.values()
  # label2_str_t feeds op3 twice, so op3 should appear twice as its consumer.
  op3 = test_ops.foo2(float1_t, label2_str_t, label2_str_t, name="myop3").d.op
  self.assertLen(op3.values(), 2)
  self.assertLen(float1_t.consumers(), 1)
  self.assertEqual(op3, float1_t.consumers()[0])
  self.assertLen(float2_t.consumers(), 0)
  label_consumers = label2_str_t.consumers()
  self.assertLen(label_consumers, 2)
  for consumer in label_consumers:
    self.assertEqual(op3, consumer)
  self.assertProtoEquals("""
  op:'Foo2' name:'myop3'
  input:'myop1' input:'myop2:1' input:'myop2:1'
  """, op3.node_def)
def testDeviceObject(self):
  """_set_device accepts both a device string and a DeviceSpec."""
  from_string = ops.Operation(
      ops._NodeDef("None", "myop"), ops.Graph(), [], [])
  from_string._set_device("/job:goo/device:GPU:0")
  self.assertProtoEquals(
      "op:'None' name:'myop' device:'/job:goo/device:GPU:0' ",
      from_string.node_def)

  from_spec = ops.Operation(ops._NodeDef("None", "op2"), ops.Graph(), [], [])
  cpu_spec = pydev.DeviceSpec(job="muu", device_type="CPU", device_index=0)
  from_spec._set_device(cpu_spec)
  self.assertProtoEquals(
      "op:'None' name:'op2' device:'/job:muu/device:CPU:0'",
      from_spec.node_def)
def testReferenceInput(self):
  """Ref-typed outputs can be wired into ops, with or without input_types.

  op1 produces one float32_ref and one float32 output. op2 consumes them with
  explicit input_types; op3 consumes the same tensors without input_types.
  For each op, the test checks the resulting NodeDef input names.
  """
  g = ops.Graph()
  op1 = ops.Operation(
      ops._NodeDef("RefOutputFloatOutput", "op1"), g, [],
      [dtypes.float32_ref, dtypes.float32])
  self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def)
  self.assertEqual([], list(op1.inputs))
  ref_t, nonref_t = op1.values()
  # NOTE(mrry): Must specify input_types to preserve ref-typed input.
  op2 = ops.Operation(
      ops._NodeDef("RefInputFloatInput", "op2"),
      g, [ref_t, nonref_t], [],
      input_types=[dtypes.float32_ref, dtypes.float32])
  self.assertProtoEquals(
      "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
      op2.node_def)
  self.assertEqual([ref_t, nonref_t], list(op2.inputs))
  # No input_types here, so (per the note above) the ref input is presumably
  # not preserved as a ref; only the NodeDef input names are checked.
  op3 = ops.Operation(
      ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], [])
  self.assertProtoEquals(
      "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
      op3.node_def)
def testInvalidNames(self):
  """Operation construction rejects empty or malformed node names."""
  g = ops.Graph()
  rejected_names = ("", "_invalid", "-invalid", "/invalid", "invalid:0")
  for bad_name in rejected_names:
    with self.assertRaises(ValueError):
      ops.Operation(ops._NodeDef("op", bad_name), g)
@test_util.run_deprecated_v1
def testNoShapeFunction(self):
  """The output of test_ops.a() has an unknown shape."""
  output = test_ops.a()
  expected = tensor_shape.unknown_shape()
  self.assertEqual(expected, output.get_shape())
@test_util.run_in_graph_and_eager_modes
def testConvertToTensorNestedArray(self):
  """A nested Python list converts to a (4, 1) tensor with the same values."""
  nested = [[2], [3], [5], [7]]
  converted = ops.convert_to_tensor(nested)
  self.assertAllEqual((4, 1), converted.get_shape().as_list())
  self.assertAllEqual(nested, self.evaluate(converted))
def testShapeTuple(self):
  """A scalar constant reports the empty shape tuple."""
  with self.cached_session():
    scalar = constant_op.constant(1)
    shape = scalar._shape_tuple()  # pylint: disable=protected-access
    self.assertEqual(shape, ())
def testConvertToTensorEager(self):
  """In eager mode, both EagerTensors and Python ints convert to EagerTensors."""
  with context.eager_mode():
    t = constant_op.constant(1)
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)), which only reports "False is not true".
    self.assertIsInstance(t, ops.EagerTensor)
    self.assertIsInstance(ops.convert_to_tensor(t), ops.EagerTensor)
    self.assertIsInstance(ops.convert_to_tensor(1), ops.EagerTensor)
@test_util.run_in_graph_and_eager_modes
def testConvertToTensorNestedTuple(self):
  """A nested tuple converts to a (4, 1) tensor with the same values."""
  values = ((2,), (3,), (5,), (7,))
  tensor = ops.convert_to_tensor(values)
  self.assertAllEqual((4, 1), tensor.get_shape().as_list())
  # Evaluate the tensor whose shape was just checked, rather than converting
  # `values` a second time.
  self.assertAllEqual(values, self.evaluate(tensor))
@test_util.run_in_graph_and_eager_modes
def testConvertToTensorNestedTensors(self):
  """Lists of tensors, at either nesting depth, are packed into one tensor."""
  values = ((2,), (3,), (5,), (7,))
  builders = (
      # One tensor per row.
      lambda: [constant_op.constant(row) for row in values],
      # One scalar tensor per element.
      lambda: [[constant_op.constant(v) for v in row] for row in values],
  )
  for build in builders:
    packed = ops.convert_to_tensor(build())
    self.assertAllEqual((4, 1), packed.get_shape().as_list())
    self.assertAllEqual(values, self.evaluate(packed))
@test_util.run_in_graph_and_eager_modes
def testConvertToTensorNestedMix(self):
  """Mixed lists, tuples and tensors are packed into one (4, 1) tensor."""
  mixed = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7]))
  packed = ops.convert_to_tensor(mixed)
  expected_values = ((2,), (3,), (5,), (7,))
  self.assertAllEqual((4, 1), packed.get_shape().as_list())
  self.assertAllEqual(expected_values, self.evaluate(packed))
@test_util.run_in_graph_and_eager_modes
def testConvertToTensorPreferred(self):
  """preferred_dtype is used when compatible and ignored otherwise."""
  cases = [
      # Integers are accepted as the preferred float32.
      ([2, 3, 5, 7], dtypes.float32, dtypes.float32),
      # Convert empty tensor to anything.
      ([], dtypes.int64, dtypes.int64),
      # The preferred dtype is a type error and will convert to
      # float32 instead.
      ([1.23], dtypes.int64, dtypes.float32),
  ]
  for values, preferred, expected_dtype in cases:
    converted = ops.convert_to_tensor(values, preferred_dtype=preferred)
    self.assertEqual(expected_dtype, converted.dtype)
@test_util.run_in_graph_and_eager_modes
def testConvertToInvalidTensorType(self):
  """Forcing an incompatible dtype (float data as int64) raises TypeError."""
  values = [1.23]
  # Keep only the conversion inside assertRaises so that a TypeError raised
  # during fixture setup cannot make the test pass spuriously.
  with self.assertRaises(TypeError):
    ops.convert_to_tensor(values, dtype=dtypes.int64)
@test_util.run_in_graph_and_eager_modes
def testConvertToLongLongTensorType(self):
  """A numpy integer scalar from np.prod converts to an int64 tensor."""
  tensor = ops.convert_to_tensor(
      # Get a numpy array of dtype NPY_LONGLONG
      # NOTE(review): np.prod over the shape tuple (1,) yields a numpy integer
      # scalar; whether it is NPY_LONGLONG is platform-dependent — verify.
      np.prod(constant_op.constant([1])._shape_tuple()),
      dtype=dtypes.int64)
  self.assertEqual(dtypes.int64, tensor.dtype)
@test_util.run_in_graph_and_eager_modes
def testConvertToTensorFromInvalidTensor(self):
  """Asking for a different dtype than an existing tensor's raises ValueError."""
  float_tensor = constant_op.constant(42.0, dtype=dtypes.float32)
  requested = dtypes.int32
  with self.assertRaises(ValueError):
    ops.convert_to_tensor(float_tensor, dtype=requested)
@test_util.run_in_graph_and_eager_modes
def testConvertToTensorProtocol(self):
  """Objects defining __tf_tensor__ are converted through that hook."""

  class TensorCompatible:

    def __tf_tensor__(self, dtype=None, name=None):
      return constant_op.constant((1, 2, 3), dtype=dtype, name=name)

  converted = ops.convert_to_tensor(TensorCompatible(), dtype=dtypes.int32)
  self.assertEqual(converted.dtype, dtypes.int32)
  self.assertAllEqual((1, 2, 3), self.evaluate(converted))
@test_util.run_deprecated_v1
def testNoConvert(self):
  """An Operation (as opposed to a Tensor) cannot be converted to a Tensor."""
  no_op = control_flow_ops.no_op()
  message_pattern = "can't convert Operation '.+' to Tensor"
  with self.assertRaisesRegex(TypeError, message_pattern):
    ops.convert_to_tensor(no_op)
def testStr(self):
  """str() of an Operation matches str() of its NodeDef."""
  node_def = ops._NodeDef("None", "op1")
  operation = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32])
  expected = str(node_def)
  self.assertEqual(expected, str(operation))
def testRepr(self):
  """repr() of an Operation shows its name and type."""
  node_def = ops._NodeDef("None", "op1")
  operation = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32])
  self.assertEqual("<tf.Operation 'op1' type=None>", repr(operation))
@test_util.run_deprecated_v1
def testGetAttr(self):
  """get_attr returns Python/proto values for each attr of default_attrs().

  Covers scalar and list attrs of string, int, float, bool, shape, tensor and
  type kind, plus a func attr, and the ValueError for a missing attr name.
  The expected values are presumably the op's declared defaults (the op is
  built without arguments); verify against the test_ops registration.
  """
  op = test_ops.default_attrs()
  self.assertEqual(op.get_attr("string_val"), b"abc")
  self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""])
  self.assertEqual(op.get_attr("int_val"), 123)
  self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3])
  self.assertEqual(op.get_attr("float_val"), 10.0)
  self.assertEqual(op.get_attr("float_list_val"), [10.0])
  self.assertEqual(op.get_attr("bool_val"), True)
  self.assertEqual(op.get_attr("bool_list_val"), [True, False])
  # Shape and tensor attrs come back as protos, so compare against protos.
  self.assertEqual(op.get_attr("shape_val"),
                   tensor_shape.as_shape([2, 1]).as_proto())
  self.assertEqual(op.get_attr("shape_list_val"),
                   [tensor_shape.as_shape([]).as_proto(),
                    tensor_shape.as_shape([1]).as_proto()])
  self.assertEqual(op.get_attr("tensor_val"),
                   tensor_util.make_tensor_proto(1, dtypes.int32))
  self.assertEqual(op.get_attr("tensor_list_val"),
                   [tensor_util.make_tensor_proto(1, dtypes.int32)])
  type_val = op.get_attr("type_val")
  # First check that type_val is a DType, because the assertEqual will work
  # no matter what since DType overrides __eq__
  self.assertIsInstance(type_val, dtypes.DType)
  self.assertEqual(type_val, dtypes.int32)
  type_list_val = op.get_attr("type_list_val")
  self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val))
  self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32])

  # A func attr is returned as a NameAttrList naming the function.
  @function.Defun(dtypes.float32, func_name="MyFunc")
  def func(x):
    return x
  op = test_ops.func_attr(func)
  self.assertEqual(op.get_attr("f"),
                   attr_value_pb2.NameAttrList(name="MyFunc"))
  # Try fetching missing attr
  with self.assertRaisesRegex(
      ValueError, "Operation 'FuncAttr' has no attr named 'FakeAttr'."):
    op.get_attr("FakeAttr")
# TODO(b/65162920): remove this test when users who are directly mutating the
# node_def have been updated to proper usage.
@test_util.run_deprecated_v1
def testSetAttr(self):
  """A value written with _set_attr is returned by get_attr."""
  op = test_ops.int_attr().op
  replacement = attr_value_pb2.AttrValue(i=2)
  op._set_attr("foo", replacement)
  # TODO(skyewm): add node_def check
  self.assertEqual(op.get_attr("foo"), 2)
# TODO(nolivia): test all error cases
def testAddControlInput(self):
  """Control inputs are de-duplicated and mirrored in _control_outputs."""
  with ops.Graph().as_default():
    x, y, z = [constant_op.constant(v).op for v in (1, 2, 3)]
  # Adding the same control input twice records it only once.
  for _ in range(2):
    z._add_control_input(x)  # pylint: disable=protected-access
    self.assertEqual(z.control_inputs, [x])
  # Duplicates, within the batch or with existing inputs, are dropped too.
  z._add_control_inputs([x, y, y])  # pylint: disable=protected-access
  self.assertEqual(z.control_inputs, [x, y])
  self.assertEqual(x._control_outputs, [z])
@test_util.run_deprecated_v1
def testRemoveAllControlInputs(self):
  """_remove_all_control_inputs clears control deps but keeps data inputs."""
  a = constant_op.constant(1)
  with ops.control_dependencies([a]):
    b = constant_op.constant(2)
  c = constant_op.constant(3)
  d = constant_op.constant(4)
  e = constant_op.constant(5)
  with ops.control_dependencies([a, c]):
    f = d + e

  initial_control_inputs = [(a, []), (b, [a.op]), (f, [a.op, c.op])]
  for tensor, expected in initial_control_inputs:
    self.assertEqual(tensor.op.control_inputs, expected)

  for tensor in (a, b, f):
    tensor.op._remove_all_control_inputs()  # pylint: disable=protected-access
    self.assertEqual(tensor.op.control_inputs, [])
  # Data inputs survive the removal.
  self.assertEqual(list(f.op.inputs), [d, e])
@test_util.run_deprecated_v1
def testControlInputCycle(self):
  """A control-dependency cycle is reported when the graph is run."""
  graph = ops.Graph()
  with graph.as_default():
    z = constant_op.constant(0)
    x = constant_op.constant(1)
    y = constant_op.constant(2)
    # pylint: disable=protected-access
    y.op._add_control_input(z.op)
    y.op._add_control_input(x.op)
    # y already waits on x; making x wait on y closes a two-node cycle.
    x.op._add_control_input(y.op)
    # pylint: enable=protected-access
  with self.session(graph=graph):
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "Graph is invalid, contains a cycle with 2 nodes"):
      self.evaluate(x)
def testUpdateInput(self):
  """`_update_input` rewires one edge and updates consumers on both ends."""
  g = ops.Graph()
  with g.as_default():
    x = constant_op.constant(1)
    y = constant_op.constant(2)
    z = x + y

  def rewire_and_check(index, new_input, inputs, x_consumers, y_consumers,
                       value):
    # Rewire input `index` of z, then verify the graph structure and the
    # value z computes.
    z.op._update_input(index, new_input)  # pylint: disable=protected-access
    self.assertEqual(list(z.op.inputs), inputs)
    self.assertEqual(x.consumers(), x_consumers)
    self.assertEqual(y.consumers(), y_consumers)
    with session.Session(graph=g):
      self.assertEqual(self.evaluate(z), value)

  # z = y + y
  rewire_and_check(0, y, [y, y], [], [z.op, z.op], 4)
  # Back to z = x + y.
  rewire_and_check(0, x, [x, y], [z.op], [z.op], 3)
  # Re-assigning an input to its current tensor leaves everything unchanged.
  rewire_and_check(1, y, [x, y], [z.op], [z.op], 3)
def testUpdateInputGraphError(self):
  """Rewiring an input to a tensor from another graph raises ValueError."""
  g_0, g_1 = ops.Graph(), ops.Graph()
  with g_0.as_default():
    foreign = constant_op.constant(1)
  with g_1.as_default():
    local = constant_op.constant(2)
    product = local * 2
    with self.assertRaisesRegex(ValueError, "must be from the same graph"):
      product.op._update_input(0, foreign)  # pylint: disable=protected-access
def testUpdateInputTypeError(self):
  """A dtype-mismatched rewire is accepted but fails when the graph runs."""
  g = ops.Graph()
  with g.as_default():
    # Creation order fixes the op names (Const, Const_1, Const_2, add) that
    # the expected error message below refers to.
    int_rhs = constant_op.constant(0)
    str_const = constant_op.constant("")
    int_lhs = constant_op.constant(1)
    total = int_lhs + int_rhs
    total.op._update_input(0, str_const)  # pylint: disable=protected-access
    with session.Session(graph=g):
      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          "Input 0 of node add was passed string from Const_1:0 incompatible "
          "with expected int32"):
        self.evaluate(total)
def testUpdateInputShapeError(self):
  """Rewiring to a tensor with an incompatible static shape raises at once."""
  g = ops.Graph()
  with g.as_default():
    lhs = constant_op.constant(2, shape=[3, 1])
    rhs = constant_op.constant(0, shape=[3, 1])
    mismatched = constant_op.constant(1, shape=[2, 2])
    total = lhs + rhs
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"):
      total.op._update_input(0, mismatched)  # pylint: disable=protected-access
def testUpdateInputOutOfRange(self):
  """Updating an input index the op does not have raises OutOfRangeError."""
  g = ops.Graph()
  with g.as_default():
    scalar = constant_op.constant(1)
    # The Const op has zero inputs, so index 1 is out of range.
    with self.assertRaisesRegex(
        errors.OutOfRangeError,
        r"Cannot update edge. Input index \[1\] is greater than the number of "
        r"total inputs \[0\]."):
      scalar.op._update_input(1, scalar)  # pylint: disable=protected-access
@test_util.enable_control_flow_v2
@test_util.run_v1_only("b/120545219")
def testAddWhileInput(self):
  """Inputs can be appended to a While op only up to the length of "T"."""

  @eager_function.defun
  def test():
    output = control_flow_ops.while_loop(lambda x: x < 3, lambda x: x + 1,
                                         [1])
    while_op = output.op
    self.assertEqual(while_op.type, "StatelessWhile")
    num_inputs_before = len(while_op.inputs)

    # pylint: disable=protected-access
    # Make sure we can handle the while op having a control input.
    while_op._add_control_input(constant_op.constant(0).op)

    extra_inputs = [constant_op.constant(1.0), constant_op.constant(True)]

    # Clear output shapes to bypass shape checking.
    while_op._set_shape_list_attr("output_shapes", [])
    existing_types = [t.dtype for t in while_op.inputs]
    while_op._set_type_list_attr(
        "T", existing_types + [t.dtype for t in extra_inputs])
    while_op._add_while_inputs(extra_inputs)

    # Can't add an edge beyond what's specified by "T"
    with self.assertRaises(errors.OutOfRangeError):
      while_op._add_while_inputs([extra_inputs[1]])
    # pylint: enable=protected-access

    self.assertEqual(len(while_op.inputs), num_inputs_before + 2)  # pylint: disable=g-deprecated-assert

  test()
@test_util.run_deprecated_v1
def testOpDef(self):
  """`Operation.op_def` exposes the OpDef of a Const and an add op."""
  x = constant_op.constant(0)
  y = constant_op.constant(1)
  z = x + y

  const_def = x.op.op_def
  self.assertEqual(const_def.name, "Const")
  self.assertEqual(len(const_def.input_arg), 0)
  self.assertEqual(len(const_def.output_arg), 1)

  add_def = z.op.op_def
  # The pattern accepts both "Add" and "AddV2".
  self.assertRegex(add_def.name, "Add(V2)?")
  self.assertEqual(len(add_def.input_arg), 2)
  self.assertEqual(len(add_def.output_arg), 1)
def testInputFromDifferentGraphError(self):
  """Combining tensors that live in different graphs raises ValueError."""
  first_graph, second_graph = ops.Graph(), ops.Graph()
  with first_graph.as_default():
    outsider = constant_op.constant(1)
  with second_graph.as_default():
    insider = constant_op.constant(2)
    with self.assertRaisesRegex(ValueError, "must be from the same graph"):
      insider * outsider  # pylint: disable=pointless-statement
def testInputsAreImmutable(self):
  """`Operation.inputs` is a tuple, so callers cannot append to it."""
  g = ops.Graph()
  with g.as_default():
    producer = test_ops.int_output()
    consumer = test_ops.int_input_int_output(producer, name="myop").op
    with self.assertRaisesRegex(AttributeError,
                                "'tuple' object has no attribute 'append'"):
      consumer.inputs.append(None)
class CreateOpTest(test_util.TensorFlowTestCase):
  """Tests for `Graph.create_op` and the NodeDefs it records."""

  def testNodeDefArgs(self):
    """Names, devices and input wiring all appear in each op's NodeDef."""
    g = ops.Graph()
    op1 = g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")
    with g.device("/device:GPU:0"):
      op2 = g.create_op(
          "FloatOutputStringOutput", [], [dtypes.float32, dtypes.string], None,
          name="myop2")
    (op1_float,) = op1.values()
    op2_float, op2_string = op2.values()
    op3 = g.create_op(
        "Foo3",
        [op1_float, op2_string, op2_float],
        [dtypes.float32, dtypes.int32],
        None,
        name="myop3")

    # Only op2 was created inside the device scope.
    for expected_device, op in ((None, op1), ("/device:GPU:0", op2),
                                (None, op3)):
      self.assertDeviceEqual(expected_device, op.device)

    self.assertProtoEquals("name:'myop1' op:'FloatOutput'", op1.node_def)
    self.assertProtoEquals(
        "name:'myop2' op:'FloatOutputStringOutput' device:'/device:GPU:0'",
        op2.node_def)
    self.assertProtoEquals(
        "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo3'",
        op3.node_def)

  def testReferenceInput(self):
    """Ref-typed and plain outputs can both be wired as inputs."""
    g = ops.Graph()
    producer = g.create_op(
        "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
        name="op1")
    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'",
                           producer.node_def)
    ref_t, nonref_t = producer.values()

    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    ref_consumer = g.create_op(
        "RefInputFloatInput", [ref_t, nonref_t], [],
        input_types=[dtypes.float32_ref, dtypes.float32],
        name="op2")
    self.assertProtoEquals(
        "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
        ref_consumer.node_def)

    plain_consumer = g.create_op(
        "TwoFloatInputs", [ref_t, nonref_t], [], name="op3")
    self.assertProtoEquals(
        "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
        plain_consumer.node_def)

  def testFinalized(self):
    """A finalized graph rejects new ops until it is unfinalized."""
    g = ops.Graph()
    op_args = ("FloatOutput", [], [dtypes.float32], None)
    g.finalize()
    with self.assertRaises(RuntimeError):
      g.create_op(*op_args, name="myop1")

    # Test unfinalize.
    g._unsafe_unfinalize()  # pylint: disable=protected-access
    g.create_op(*op_args, name="myop1")
# NOTE(skyewm): these cases test the private Graph._create_op_from_tf_operation
# method. Arguably we should only test the public APIs that depend on this
# method. However, this logic is complex and tricky, and it can be difficult to
# ascertain if we have adequate coverage (e.g. a graph may run successfully if
# the control flow context isn't set properly, but a more complicated use case
# that might not be obvious to test will fail). Thus we instead explicitly test
# the low-level behavior.
class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase):
  """Tests wrapping C-API ops (from `ops._create_c_op`) as Python Operations.

  Each test builds a raw C op and then registers it with the Python graph via
  `g._create_op_from_tf_operation` or `g._add_new_tf_operations`.
  """

  @test_util.run_deprecated_v1
  def testBasic(self):
    """The wrapped op exposes name, type, inputs, outputs and graph links."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()
      c_op = ops._create_c_op(
          g, ops._NodeDef("IntInputIntOutput", "myop"), [x], [])
      op = g._create_op_from_tf_operation(c_op)

    self.assertEqual(op.name, "myop")
    self.assertEqual(op.type, "IntInputIntOutput")
    self.assertEqual(len(op.outputs), 1)
    self.assertEqual(op.outputs[0].shape, tensor_shape.unknown_shape())
    self.assertEqual(list(op.inputs), [x])
    self.assertEqual(op.control_inputs, [])
    self.assertEqual(op.graph, g)
    # The producer's consumer list now includes the wrapped op.
    self.assertEqual(x.consumers(), [op])
    self.assertIsNotNone(op.traceback)
    # The op is registered for lookup by name.
    self.assertEqual(g.get_operation_by_name("myop"), op)
    self.assertEqual(g.get_tensor_by_name("myop:0"), op.outputs[0])

  def testShape(self):
    """The static output shape of the wrapped op matches its [2, 3] input."""
    g = ops.Graph()
    with g.as_default():
      x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
      c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), [x], [])
      op = g._create_op_from_tf_operation(c_op)

      self.assertEqual(op.name, "myop")
      self.assertEqual(op.type, "Identity")
      self.assertEqual(len(op.outputs), 1)
      self.assertEqual(op.outputs[0].shape, tensor_shape.TensorShape([2, 3]))

  def testUniqueName(self):
    """Names taken by wrapped C ops are avoided by later Python-created ops."""
    g = ops.Graph()
    with g.as_default():
      c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], [])
      c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], [])
      op = g._create_op_from_tf_operation(c_op)
      op2 = g._create_op_from_tf_operation(c_op2)

      # Create ops with same names as op1 and op2. We expect the new names to be
      # uniquified.
      op3 = test_ops.int_output(name="myop").op
      op4 = test_ops.int_output(name="myop_1").op

    self.assertEqual(op.name, "myop")
    self.assertEqual(op2.name, "myop_1")
    # "myop_1" is already taken, so the second "myop" becomes "myop_2".
    self.assertEqual(op3.name, "myop_2")
    self.assertEqual(op4.name, "myop_1_1")

  @test_util.run_v1_only("b/120545219")
  def testCond(self):
    """A C op created inside a cond branch gets the cond's context."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def true_fn():
        # Create the op at the C level, then register it with the Python
        # graph while the cond context is still active.
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "cond/myop"), [x], [])
        new_ops = g._add_new_tf_operations()
        self.assertEqual(len(new_ops), 1)
        return x

      control_flow_ops.cond(x < 10, true_fn, lambda: x)

    op = g.get_operation_by_name("cond/myop")
    self.assertIsNotNone(op)
    self.assertEqual(op.name, "cond/myop")
    self.assertEqual(op.type, "IntInput")
    self.assertEqual(op.outputs, [])
    # The external input `x` is routed into the branch through a Switch.
    op_input = op.inputs[0].op
    self.assertEqual(op_input.type, "Switch")
    self.assertEqual(op_input.inputs[0], x)
    self.assertEqual(op.graph, g)
    # pylint: disable=protected-access
    self.assertIsNotNone(op._get_control_flow_context())
    self.assertEqual(op._get_control_flow_context().name,
                     "cond/cond_text")
    # pylint: enable=protected-access

  @test_util.run_v1_only("b/120545219")
  def testWhileLoop(self):
    """A C op created inside a while body gets the loop's context."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def body(i):
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "myloop/myop"), [x], [])
        new_ops = g._add_new_tf_operations()
        self.assertEqual(len(new_ops), 1)
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    self.assertEqual(op.name, "myloop/myop")
    self.assertEqual(op.type, "IntInput")
    self.assertEqual(op.outputs, [])
    # The external input `x` enters the loop through an Enter op.
    op_input = op.inputs[0].op
    self.assertEqual(op_input.type, "Enter")
    self.assertEqual(list(op_input.inputs), [x])
    self.assertEqual(op.graph, g)
    # pylint: disable=protected-access
    self.assertIsNotNone(op._get_control_flow_context())
    self.assertEqual(op._get_control_flow_context().name,
                     "myloop/while_context")
    # pylint: enable=protected-access

  @test_util.run_v1_only("b/120545219")
  def testWhileLoopWithInternalControlDep(self):
    """A control dep on an op created inside the loop body is kept as-is."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()

      def body(i):
        c = constant_op.constant(1.0, name="c")
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "myloop/myop"), [x], [])
        # The control dependency is active when the op is registered with the
        # Python graph, not when the C op is created.
        with ops.control_dependencies([c]):
          new_ops = g._add_new_tf_operations()
          self.assertEqual(len(new_ops), 1)
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    c = g.get_operation_by_name("myloop/c")
    self.assertIsNotNone(c)
    # Internal control dep is preserved
    self.assertEqual(op.control_inputs, [c])

  @test_util.run_v1_only("b/120545219")
  def testWhileLoopWithExternalControlDep(self):
    """A control dep on an op outside the loop is replaced inside the loop."""
    g = ops.Graph()
    with g.as_default():
      x = test_ops.int_output()
      c = constant_op.constant(1.0)

      def body(i):
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "myloop/myop"), [x], [])
        with ops.control_dependencies([c]):
          new_ops = g._add_new_tf_operations()
          self.assertEqual(len(new_ops), 1)
        return i

      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")

    op = g.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(op)
    # External control dep is removed and replaced with internal control dep
    self.assertNotEqual(op.control_inputs[0], c.op)
    self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
class ApplyOpTest(test_util.TensorFlowTestCase):
  """Tests for the module's `_apply_op` helper.

  The assertions below check that a single-output call yields a Tensor and a
  multi-output call yields a list of Tensors.
  """

  def testNodeDefArgs(self):
    """Output types, tensor names and NodeDefs of `_apply_op` results."""
    g = ops.Graph()
    t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
    with g.device("/device:GPU:0"):
      t2 = _apply_op(
          g, "TwoIntOutputs", [], [dtypes.int32, dtypes.int32], name="myop2")
    t3 = _apply_op(
        g,
        "Foo1", [t1, t2[1], t2[0]], [dtypes.float32, dtypes.int32],
        name="myop3")
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)), which only reports "False is not true".
    self.assertIsInstance(t1, ops.Tensor)
    self.assertIsInstance(t2, list)
    self.assertIsInstance(t3, list)
    self.assertIsInstance(t3[0], ops.Tensor)
    # Output 0 is referenced by bare op name; later outputs use "name:index".
    self.assertEqual("myop1", t1._as_node_def_input())
    self.assertEqual("myop2", t2[0]._as_node_def_input())
    self.assertEqual("myop2:1", t2[1]._as_node_def_input())
    self.assertEqual("myop3", t3[0]._as_node_def_input())
    # Validate that we got the right ops as well
    self.assertProtoEquals("name:'myop1' op:'FloatOutput'", t1.op.node_def)
    self.assertProtoEquals(
        "name:'myop2' op:'TwoIntOutputs' device:'/device:GPU:0'",
        t2[0].op.node_def)
    self.assertProtoEquals(
        "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo1'",
        t3[0].op.node_def)

  def testReferenceInput(self):
    """Ref-typed and plain outputs can both be fed to `_apply_op`."""
    g = ops.Graph()
    ref_t, nonref_t = _apply_op(
        g, "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
        name="op1")
    self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'",
                           ref_t.op.node_def)
    # NOTE(mrry): Must specify input_types to preserve ref-typed input.
    out_2 = _apply_op(
        g,
        "RefInputFloatInputIntOutput", [ref_t, nonref_t], [dtypes.int32],
        input_types=[dtypes.float32_ref, dtypes.float32],
        name="op2")
    self.assertProtoEquals(
        "op:'RefInputFloatInputIntOutput' name:'op2' input:'op1' input:'op1:1'",
        out_2.op.node_def)
    out_3 = _apply_op(
        g, "TwoFloatInputsIntOutput", [ref_t, nonref_t], [dtypes.int32],
        name="op3")
    self.assertProtoEquals(
        "op:'TwoFloatInputsIntOutput' name:'op3' input:'op1' input:'op1:1'",
        out_3.op.node_def)
class NameStackTest(test_util.TensorFlowTestCase):
  """Tests for `Graph.unique_name` and `Graph.name_scope` name generation."""

  def testBasics(self):
    """Uniquification with and without marking names used, across scopes."""
    g = ops.Graph()
    # mark_as_used=False previews the next name without reserving it.
    self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo", g.unique_name("foo"))
    self.assertEqual("foo_1", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo_1", g.unique_name("foo"))
    self.assertEqual("foo_2", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo_2", g.unique_name("foo"))
    # A generated name can itself be used as a base and gets its own suffixes.
    self.assertEqual("foo_1_1", g.unique_name("foo_1", mark_as_used=False))
    self.assertEqual("foo_1_1", g.unique_name("foo_1"))
    self.assertEqual("foo_1_2", g.unique_name("foo_1", mark_as_used=False))
    self.assertEqual("foo_1_2", g.unique_name("foo_1"))
    self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2", mark_as_used=False))
    self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2"))
    with g.name_scope("bar"):
      self.assertEqual("bar/foo", g.unique_name("foo", mark_as_used=False))
      self.assertEqual("bar/foo", g.unique_name("foo"))
      self.assertEqual("bar/foo_1", g.unique_name("foo", mark_as_used=False))
      self.assertEqual("bar/foo_1", g.unique_name("foo"))
      # name_scope(None) resets to the root scope, continuing root numbering.
      with g.name_scope(None):
        self.assertEqual("foo_3", g.unique_name("foo", mark_as_used=False))
        self.assertEqual("foo_3", g.unique_name("foo"))
      with g.name_scope("baz"):
        self.assertEqual(
            "bar/baz/foo", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz/foo", g.unique_name("foo"))
        self.assertEqual(
            "bar/baz/foo_1", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz/foo_1", g.unique_name("foo"))
      # Re-entering a scope name at the same level yields a fresh scope.
      with g.name_scope("baz"):
        self.assertEqual(
            "bar/baz_1/foo", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz_1/foo", g.unique_name("foo"))
        self.assertEqual(
            "bar/baz_1/foo_1", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar/baz_1/foo_1", g.unique_name("foo"))
    with g.name_scope("quux"):
      self.assertEqual("quux/foo", g.unique_name("foo", mark_as_used=False))
      self.assertEqual("quux/foo", g.unique_name("foo"))
    with g.name_scope("bar"):
      with g.name_scope("baz"):
        self.assertEqual(
            "bar_1/baz/foo", g.unique_name(
                "foo", mark_as_used=False))
        self.assertEqual("bar_1/baz/foo", g.unique_name("foo"))
    self.assertEqual("foo_4", g.unique_name("foo", mark_as_used=False))
    self.assertEqual("foo_4", g.unique_name("foo"))
    # "bar" and "bar_1" were taken by the scopes above.
    self.assertEqual("bar_2", g.unique_name("bar", mark_as_used=False))
    self.assertEqual("bar_2", g.unique_name("bar"))

  def testBackslashAndDashRegex(self):
    """Scope names containing backslashes and dashes are accepted."""
    # GitHub issue 39019, all should pass
    g = ops.Graph()
    with g.name_scope("n_CatCntc-campaign\\c_campaign"):
      pass
    with g.name_scope("foo"):
      with g.name_scope("n_CatCntc-campaign\\c_campaign"):
        pass
    with g.name_scope("n_CatCntc-campaign\\c_campaign"):
      with g.name_scope("foo"):
        pass

  @test_util.run_deprecated_v1
  def testNameAndVariableScope(self):
    """Graph name scopes compose with an enclosing variable_scope."""
    with self.cached_session() as sess:
      with sess.graph.name_scope("l0"):
        with variable_scope.variable_scope("l1"):
          with sess.graph.name_scope("l1") as scope:
            self.assertEqual("l0/l1/l1/", scope)
            self.assertEqual(
                "l0/l1/l1/foo",
                sess.graph.unique_name(
                    "foo", mark_as_used=False))
            self.assertEqual("l0/l1/l1/foo", sess.graph.unique_name("foo"))
          with sess.graph.name_scope("l2") as scope:
            self.assertEqual("l0/l1/l2/", scope)
            self.assertEqual(
                "l0/l1/l2/foo",
                sess.graph.unique_name(
                    "foo", mark_as_used=False))
            self.assertEqual("l0/l1/l2/foo", sess.graph.unique_name("foo"))

  def testOutOfOrderUniqueName(self):
    """A suffixed name reserved up front is skipped when numbering later."""
    g = ops.Graph()
    self.assertEqual("foo_2", g.unique_name("foo_2"))
    self.assertEqual("foo", g.unique_name("foo"))
    self.assertEqual("foo_1", g.unique_name("foo"))
    self.assertEqual("foo_3", g.unique_name("foo"))

  def testUniqueNameCaseInsensitivity(self):
    """Names differing only in case collide but keep their own casing."""
    g = ops.Graph()
    self.assertEqual("foo", g.unique_name("foo"))
    self.assertEqual("Foo_1", g.unique_name("Foo"))
    with g.name_scope("bar"):
      self.assertEqual("bar/foo", g.unique_name("foo"))
    with g.name_scope("Bar"):
      self.assertEqual("Bar_1/foo", g.unique_name("foo"))

  def testInvalidNameRaisesError(self):
    """Invalid scope names raise ValueError; a leading "_" is valid only nested."""
    g = ops.Graph()
    with g.name_scope(""):  # Should not raise
      pass
    with g.name_scope("foo/"):  # Should not raise
      with g.name_scope("_bar"):  # Should not raise
        pass
    with self.assertRaises(ValueError):
      with g.name_scope("foo:0"):
        pass
    # The same "_bar" that was accepted inside "foo/" is rejected at the root.
    with self.assertRaises(ValueError):
      with g.name_scope("_bar"):
        pass
class NameTest(test_util.TensorFlowTestCase):
  """Tests for op and tensor naming under `Graph.name_scope`."""

  def testGenerateName(self):
    """Default op names come from the op type; outputs are "<op>:<index>"."""
    g = ops.Graph()
    op0 = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
    self.assertEqual("TwoFloatOutputs", op0.name)
    self.assertEqual("TwoFloatOutputs:0", op0.outputs[0].name)
    self.assertEqual("TwoFloatOutputs:1", op0.outputs[1].name)
    op1 = g.create_op("FloatOutput", [], [dtypes.float32])
    self.assertEqual("FloatOutput", op1.name)
    self.assertEqual("FloatOutput:0", op1.outputs[0].name)
    op2 = g.create_op("FloatOutput", [], [dtypes.float32])
    self.assertEqual("FloatOutput_1", op2.name)
    self.assertEqual("FloatOutput_1:0", op2.outputs[0].name)
    op3 = g.create_op("FloatOutput", [], [dtypes.float32], name="my_op")
    self.assertEqual("my_op", op3.name)
    self.assertEqual("my_op:0", op3.outputs[0].name)

  def testNameScope(self):
    """Scope strings, scope resets, and reuse of a captured scope name."""
    g = ops.Graph()
    with g.name_scope("foo") as foo:
      self.assertEqual("foo/", foo)
      with g.name_scope("foo2") as foo2:
        self.assertEqual("foo/foo2/", foo2)
      # None and "" both reset to the root scope.
      with g.name_scope(None) as empty1:
        self.assertEqual("", empty1)
        with g.name_scope("foo3") as foo3:
          self.assertEqual("foo3/", foo3)
      with g.name_scope("") as empty2:
        self.assertEqual("", empty2)

    self.assertEqual("FloatOutput",
                     g.create_op("FloatOutput", [], [dtypes.float32]).name)
    with g.name_scope("bar") as scope:
      self.assertEqual("bar/FloatOutput",
                       g.create_op("FloatOutput", [], [dtypes.float32]).name)
      self.assertEqual("bar/FloatOutput_1",
                       g.create_op("FloatOutput", [], [dtypes.float32]).name)
      # If you use the value from "with .. as", that value is used as-is.
      self.assertEqual(
          "bar", g.create_op(
              "FloatOutput", [], [dtypes.float32], name=scope).name)
    with g.name_scope("baz") as scope:
      with g.name_scope("quux"):
        self.assertEqual("baz/quux/FloatOutput",
                         g.create_op("FloatOutput", [], [dtypes.float32]).name)
      # If you use the value from the enclosing "with .. as", nothing is pushed.
      with g.name_scope(scope):
        self.assertEqual("baz/FloatOutput",
                         g.create_op("FloatOutput", [], [dtypes.float32]).name)
        self.assertEqual(
            "baz", g.create_op(
                "FloatOutput", [], [dtypes.float32], name=scope).name)
        # An explicit name with a trailing "/" is used verbatim, minus the
        # slash.
        self.assertEqual(
            "trailing",
            g.create_op(
                "FloatOutput", [], [dtypes.float32], name="trailing/").name)
    # Re-entering "bar" by plain name creates a new scope...
    with g.name_scope("bar"):
      self.assertEqual("bar_1/FloatOutput",
                       g.create_op("FloatOutput", [], [dtypes.float32]).name)
    # ...whereas "bar/" re-enters the existing "bar" scope.
    with g.name_scope("bar/"):
      self.assertEqual("bar/FloatOutput_2",
                       g.create_op("FloatOutput", [], [dtypes.float32]).name)
class DeviceTest(test_util.TensorFlowTestCase):
  """Tests for device placement via `Graph.device` and `ops.device`.

  Most tests create ops in a fixed order, so the expected GraphDefs rely on
  the auto-generated names FloatOutput, FloatOutput_1, ... matching that order.
  """

  def testNoDevice(self):
    """An op created outside any device scope has no device set."""
    g = ops.Graph()
    op = g.create_op("FloatOutput", [], [dtypes.float32])
    self.assertDeviceEqual(None, op.device)
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
node { name: "FloatOutput" op: "FloatOutput" }
""", gd)

  def testEagerBackingDevice(self):
    """In eager mode, `device` and `backing_device` match the requested one."""
    with context.eager_mode():
      with ops.device("/device:CPU:0"):
        t = constant_op.constant(1.0)
        self.assertRegex(t.device, "/device:CPU:0")
        self.assertRegex(t.backing_device, "/device:CPU:0")

  def testDevicePartialString(self):
    """A partially specified device string is recorded as given."""
    g = ops.Graph()
    with g.device("/job:worker/replica:2"):
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
node { name: "FloatOutput" op: "FloatOutput"
device: "/job:worker/replica:2" }
""", gd)

  def testDeviceFull(self):
    """A fully specified DeviceSpec is serialized to its full device string."""
    g = ops.Graph()
    with g.device(
        pydev.DeviceSpec(
            job="worker", replica=2, task=0, device_type="CPU",
            device_index=3)):
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
node { name: "FloatOutput" op: "FloatOutput"
device: "/job:worker/replica:2/task:0/device:CPU:3" }
""", gd)

  def testNesting(self):
    """An inner scope overrides the outer one; leaving it restores the outer."""
    g = ops.Graph()
    with g.device("/job:worker/replica:2"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:worker/replica:3/task:0"):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
node { name: "FloatOutput" op: "FloatOutput"
device: "/job:worker/replica:2" }
node { name: "FloatOutput_1" op: "FloatOutput"
device: "/job:worker/replica:3/task:0" }
node { name: "FloatOutput_2" op: "FloatOutput"
device: "/job:worker/replica:2" }
""", gd)

  def testNestingString(self):
    """Same nesting check as testNesting, using plain device strings."""
    # NOTE(review): this body is identical to testNesting. Possibly one of the
    # two was meant to use `pydev.DeviceSpec` objects instead of strings;
    # confirm intent before changing or removing either test.
    g = ops.Graph()
    with g.device("/job:worker/replica:2"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:worker/replica:3/task:0"):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
node { name: "FloatOutput" op: "FloatOutput"
device: "/job:worker/replica:2" }
node { name: "FloatOutput_1" op: "FloatOutput"
device: "/job:worker/replica:3/task:0" }
node { name: "FloatOutput_2" op: "FloatOutput"
device: "/job:worker/replica:2" }
""", gd)

  def testNestingOverrideGpuCpu(self):
    """An inner GPU device overrides the outer CPU device of the same replica."""
    g = ops.Graph()
    with g.device("/job:worker/replica:2/device:CPU:1"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:worker/replica:2/device:GPU:2"):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
node { name: "FloatOutput" op: "FloatOutput"
device: "/job:worker/replica:2/device:CPU:1" }
node { name: "FloatOutput_1" op: "FloatOutput"
device: "/job:worker/replica:2/device:GPU:2" }
node { name: "FloatOutput_2" op: "FloatOutput"
device: "/job:worker/replica:2/device:CPU:1" }
""", gd)

  def testNestingWithMergeDeviceFunction(self):
    """Nested merge_device scopes merge field by field; None changes nothing."""
    g = ops.Graph()

    with g.device(pydev.merge_device("/device:GPU:0")):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(pydev.merge_device("/job:worker")):
        g.create_op("FloatOutput", [], [dtypes.float32])
        with g.device(pydev.merge_device("/device:CPU:0")):
          g.create_op("FloatOutput", [], [dtypes.float32])
          with g.device(pydev.merge_device("/job:ps")):
            g.create_op("FloatOutput", [], [dtypes.float32])
            with g.device(pydev.merge_device(None)):
              g.create_op("FloatOutput", [], [dtypes.float32])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
node { name: "FloatOutput" op: "FloatOutput"
device: "/device:GPU:0" }
node { name: "FloatOutput_1" op: "FloatOutput"
device: "/job:worker/device:GPU:0" }
node { name: "FloatOutput_2" op: "FloatOutput"
device: "/job:worker/device:CPU:0" }
node { name: "FloatOutput_3" op: "FloatOutput"
device: "/job:ps/device:CPU:0" }
node { name: "FloatOutput_4" op: "FloatOutput"
device: "/job:ps/device:CPU:0" }
""", gd)

  def testNestingWithDeviceStrings(self):
    """Nested device strings merge like merge_device; "" changes nothing."""
    g = ops.Graph()

    with g.device("/device:GPU:0"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:worker"):
        g.create_op("FloatOutput", [], [dtypes.float32])
        with g.device("/device:CPU:0"):
          g.create_op("FloatOutput", [], [dtypes.float32])
          with g.device("/job:ps"):
            g.create_op("FloatOutput", [], [dtypes.float32])
            with g.device(""):
              g.create_op("FloatOutput", [], [dtypes.float32])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
node { name: "FloatOutput" op: "FloatOutput"
device: "/device:GPU:0" }
node { name: "FloatOutput_1" op: "FloatOutput"
device: "/job:worker/device:GPU:0" }
node { name: "FloatOutput_2" op: "FloatOutput"
device: "/job:worker/device:CPU:0" }
node { name: "FloatOutput_3" op: "FloatOutput"
device: "/job:ps/device:CPU:0" }
node { name: "FloatOutput_4" op: "FloatOutput"
device: "/job:ps/device:CPU:0" }
""", gd)

  def testNestingWithDeviceStringWildcard(self):
    """Wildcard device indices ("*") interact with concrete indices."""
    g = ops.Graph()

    with g.device("/device:GPU:7"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      # A same-type wildcard keeps the enclosing concrete index (GPU:7).
      with g.device("/device:GPU:*"):
        g.create_op("FloatOutput", [], [dtypes.float32])

    with g.device("/device:CPU:*"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      # A concrete index replaces an enclosing wildcard.
      with g.device("/device:CPU:5"):
        g.create_op("FloatOutput", [], [dtypes.float32])

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
node { name: "FloatOutput" op: "FloatOutput"
device: "/device:GPU:7" }
node { name: "FloatOutput_1" op: "FloatOutput"
device: "/device:GPU:7" }
node { name: "FloatOutput_2" op: "FloatOutput"
device: "/device:CPU:*" }
node { name: "FloatOutput_3" op: "FloatOutput"
device: "/device:CPU:5" }
""", gd)

  def testNestingErrorGraph(self):
    """Exiting an outer graph device scope before the inner one raises."""
    g = ops.Graph()
    scope = g.device("/device:GPU:8")
    scope.__enter__()
    with g.device("/device:GPU:9"):
      with self.assertRaises(RuntimeError):
        scope.__exit__(None, None, None)

  def testNestingErrorEager(self):
    """Exiting an outer eager device scope before the inner one raises."""
    with context.eager_mode():
      scope = ops.device("/device:CPU:0")
      scope.__enter__()
      with ops.device(None):
        with self.assertRaises(RuntimeError):
          scope.__exit__(None, None, None)

  def testNoneClearsDefault(self):
    """`device(None)` clears the enclosing device string for its body."""
    g = ops.Graph()
    with g.device("/job:worker/replica:2/device:CPU:1"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
node { name: "FloatOutput" op: "FloatOutput"
device: "/job:worker/replica:2/device:CPU:1" }
node { name: "FloatOutput_1" op: "FloatOutput" }
node { name: "FloatOutput_2" op: "FloatOutput"
device: "/job:worker/replica:2/device:CPU:1" }
""", gd)

  def testNoneIgnoresOuterDeviceFunction(self):
    """`device(None)` also suppresses an enclosing device function."""
    g = ops.Graph()
    with g.device(lambda op: "/job:worker/replica:2/device:CPU:1"):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):
        g.create_op("FloatOutput", [], [dtypes.float32])
      g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
node { name: "FloatOutput" op: "FloatOutput"
device: "/job:worker/replica:2/device:CPU:1" }
node { name: "FloatOutput_1" op: "FloatOutput" }
node { name: "FloatOutput_2" op: "FloatOutput"
device: "/job:worker/replica:2/device:CPU:1" }
""", gd)

  def _overwritingDeviceFunction(self, unused_op):
    # This device function unconditionally overwrites the device of ops.
    #
    # NOTE(mrry): Writing device functions like this is not
    # recommended. Instead, in most cases you should use
    # `pydev.merge_device("/job:ps")` or simply `"/job:ps"` as the
    # argument to `tf.device()` and the device component will be merged in.
    return "/job:overwrite"

  def testOverwritingBehavior(self):
    """An outer overwriting function wins over inner scopes unless reset."""
    g = ops.Graph()
    with g.device(self._overwritingDeviceFunction):
      g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device("/job:ps"):  # Will be overwritten.
        g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(pydev.merge_device("/job:ps")):  # Will be overwritten.
        g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):  # Disables overwriting device function
        with g.device("/job:ps"):
          g.create_op("FloatOutput", [], [dtypes.float32])
      with g.device(None):  # Disables overwriting device function
        with g.device(pydev.merge_device("/job:ps")):
          g.create_op("FloatOutput", [], [dtypes.float32])
    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
node { name: "FloatOutput" op: "FloatOutput"
device: "/job:overwrite" }
node { name: "FloatOutput_1" op: "FloatOutput"
device: "/job:overwrite" }
node { name: "FloatOutput_2" op: "FloatOutput"
device: "/job:overwrite" }
node { name: "FloatOutput_3" op: "FloatOutput"
device: "/job:ps" }
node { name: "FloatOutput_4" op: "FloatOutput"
device: "/job:ps" }
""", gd)
class MultithreadedGraphStateTest(test_util.TensorFlowTestCase):
  """Checks that graph scoping stacks behave correctly across threads.

  Every test starts several threads. Each thread enters one graph scope
  (device, colocation, control dependency or name scope), signals the caller,
  waits to be released, and then creates an op inside that scope. The
  expected outputs check that each op reflects only the scope entered by the
  thread that created it.
  """

  class TestThread(threading.Thread):
    """Base thread that pauses between mutating a graph stack and adding an op.

    Attributes:
      has_mutated_graph: Event set by the thread once it has entered its scope.
      should_continue: Event the caller sets to let the thread create its op.
    """

    def __init__(self, graph, replica_id):
      super(MultithreadedGraphStateTest.TestThread, self).__init__()
      self._graph = graph
      self._replica_id = replica_id
      # This thread sets this event when it mutated the graph. The caller can
      # wait for that.
      self.has_mutated_graph = threading.Event()
      # This thread waits for when it should continue. The caller can set this
      # event.
      self.should_continue = threading.Event()

    def run(self):
      # Mutate a graph's stack, then set `has_mutated_graph`, then wait for
      # `should_continue`, then add an op to the graph affected by the graph's
      # stack.
      raise NotImplementedError("must be implemented in descendants")

  def _runInLockstep(self, threads):
    """Runs `threads` so every scope is entered before any op is created.

    Each thread is started and allowed to reach its `has_mutated_graph` point
    before the next one starts, so all scopes are open at the same time.
    Threads are then released one at a time, and each is joined before the
    next is released. This makes op creation order (and therefore the node
    order in the resulting GraphDef) deterministic.

    Args:
      threads: Iterable of `TestThread` instances that have not been started.
    """
    threads = list(threads)
    for t in threads:
      t.start()
      t.has_mutated_graph.wait()
      t.has_mutated_graph.clear()
    for t in threads:
      t.should_continue.set()
      t.join()

  def testDeviceFunctionStack(self):

    class DeviceSettingThread(self.TestThread):

      def run(self):
        graph = self._graph
        with graph.device("/job:worker/replica:{}".format(self._replica_id)):
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          graph.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="FloatOutput_{}".format(self._replica_id))

    g = ops.Graph()
    # If `switch_to_thread_local` isn't called, then device placement of the
    # ops below is not deterministic.
    g.switch_to_thread_local()
    self._runInLockstep(DeviceSettingThread(g, i) for i in range(3))

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
node { name: "FloatOutput_0" op: "FloatOutput"
device: "/job:worker/replica:0" }
node { name: "FloatOutput_1" op: "FloatOutput"
device: "/job:worker/replica:1" }
node { name: "FloatOutput_2" op: "FloatOutput"
device: "/job:worker/replica:2" }
""", gd)

  def testColocateWith(self):

    class ColocatingThread(self.TestThread):

      def __init__(self, graph, replica_id, op_to_colocate_with):
        super(ColocatingThread, self).__init__(graph, replica_id)
        self._op_to_colocate_with = op_to_colocate_with

      def run(self):
        graph = self._graph
        with graph.colocate_with(self._op_to_colocate_with):
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          graph.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="FloatOutput_{}".format(self._replica_id))

    g = ops.Graph()
    ops_to_colocate_with = []
    for i in range(3):
      with g.device("/job:worker/replica:{}".format(i)):
        ops_to_colocate_with.append(
            g.create_op(
                "FloatOutput", [], [dtypes.float32],
                name="ColocateWithMe_{}".format(i)))

    # If `switch_to_thread_local` isn't called, then `device` and `attr`
    # values for the ops below are not deterministic.
    g.switch_to_thread_local()
    self._runInLockstep(
        ColocatingThread(g, i, ops_to_colocate_with[i]) for i in range(3))

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
node { name: "ColocateWithMe_0" op: "FloatOutput"
device: "/job:worker/replica:0" }
node { name: "ColocateWithMe_1" op: "FloatOutput"
device: "/job:worker/replica:1" }
node { name: "ColocateWithMe_2" op: "FloatOutput"
device: "/job:worker/replica:2" }
node { name: "FloatOutput_0" op: "FloatOutput"
device: "/job:worker/replica:0"
attr { key: "_class"
value { list {
s: "loc:@ColocateWithMe_0"}}}}
node { name: "FloatOutput_1" op: "FloatOutput"
device: "/job:worker/replica:1"
attr { key: "_class"
value { list {
s: "loc:@ColocateWithMe_1"}}}}
node { name: "FloatOutput_2" op: "FloatOutput"
device: "/job:worker/replica:2"
attr { key: "_class"
value { list {
s: "loc:@ColocateWithMe_2"}}}}
""", gd)

  def testControlDependencies(self):

    class DependingThread(self.TestThread):

      def __init__(self, graph, replica_id, dependency_op):
        super(DependingThread, self).__init__(graph, replica_id)
        self._dependency_op = dependency_op

      def run(self):
        graph = self._graph
        with graph.control_dependencies([self._dependency_op]):
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          graph.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="FloatOutput_{}".format(self._replica_id))

    g = ops.Graph()
    dependency_ops = []
    for i in range(3):
      dependency_ops.append(
          g.create_op(
              "FloatOutput", [], [dtypes.float32],
              name="ColocateWithMe_{}".format(i)))

    # If `switch_to_thread_local` isn't called, then `input` values for the
    # ops below are not deterministic.
    g.switch_to_thread_local()
    self._runInLockstep(
        DependingThread(g, i, dependency_ops[i]) for i in range(3))

    gd = g.as_graph_def()
    self.assertProtoEqualsVersion("""
node { name: "ColocateWithMe_0" op: "FloatOutput" }
node { name: "ColocateWithMe_1" op: "FloatOutput" }
node { name: "ColocateWithMe_2" op: "FloatOutput" }
node { name: "FloatOutput_0" op: "FloatOutput"
input: "^ColocateWithMe_0" }
node { name: "FloatOutput_1" op: "FloatOutput"
input: "^ColocateWithMe_1" }
node { name: "FloatOutput_2" op: "FloatOutput"
input: "^ColocateWithMe_2" }
""", gd)

  def testNameStack(self):

    class NameSettingThread(self.TestThread):

      def run(self):
        graph = self._graph
        with graph.name_scope("foo"):
          op1 = graph.create_op("FloatOutput", [], [dtypes.float32])
          self.has_mutated_graph.set()
          self.should_continue.wait()
          self.should_continue.clear()
          op2 = graph.create_op("FloatOutput", [], [dtypes.float32])
          self.result = (op1, op2)

    # Unlike the tests above, this one does not call switch_to_thread_local().
    g = ops.Graph()
    threads = [NameSettingThread(g, i) for i in range(3)]
    self._runInLockstep(threads)

    # Each thread's second op lands in the same uniquified scope as its first.
    suffixes = ["", "_1", "_2"]
    for t, s in zip(threads, suffixes):
      self.assertEqual("foo" + s + "/FloatOutput", t.result[0].name)
      self.assertEqual("foo" + s + "/FloatOutput_1", t.result[1].name)
class ObjectWithName(object):
  """Test helper whose only public attribute is a read-only `name`."""

  def __init__(self, name):
    # Kept private so that the public `name` attribute cannot be reassigned.
    self._name = name

  # Read-only accessor for the name passed to the constructor.
  name = property(lambda self: self._name)
class CollectionTest(test_util.TensorFlowTestCase):
  """Tests for graph collections (`add_to_collection` / `get_collection`)."""

  def test_get_collections(self):
    g = ops.Graph()
    self.assertSequenceEqual(g.collections, [])
    g.add_to_collection("key", 12)
    g.add_to_collection("key", 15)
    self.assertSequenceEqual(g.collections, ["key"])
    g.add_to_collection("other", "foo")
    self.assertSequenceEqual(sorted(g.collections), ["key", "other"])
    self.assertSequenceEqual(
        sorted(g.get_all_collection_keys()), ["key", "other"])

  def test_add_to_collection(self):
    g = ops.Graph()
    g.add_to_collection("key", 12)
    g.add_to_collection("other", "foo")
    g.add_to_collection("key", 34)

    g.add_to_collection("blah", 27)
    blank1 = ObjectWithName("prefix/foo")
    g.add_to_collection("blah", blank1)
    blank2 = ObjectWithName("junk/foo")
    g.add_to_collection("blah", blank2)

    self.assertEqual([12, 34], g.get_collection("key"))
    self.assertEqual([], g.get_collection("nothing"))
    self.assertEqual([27, blank1, blank2], g.get_collection("blah"))
    # With a scope filter, only blank1 is returned: 27 has no `name`, and
    # blank2's name does not match either pattern.
    self.assertEqual([blank1], g.get_collection("blah", "prefix"))
    self.assertEqual([blank1], g.get_collection("blah", ".*x"))

    # Make sure that get_collection() returns a first-level
    # copy of the collection, while get_collection_ref() returns
    # the original list.
    other_collection_snapshot = g.get_collection("other")
    other_collection_ref = g.get_collection_ref("other")
    self.assertEqual(["foo"], other_collection_snapshot)
    self.assertEqual(["foo"], other_collection_ref)
    g.add_to_collection("other", "bar")
    self.assertEqual(["foo"], other_collection_snapshot)
    self.assertEqual(["foo", "bar"], other_collection_ref)
    self.assertEqual(["foo", "bar"], g.get_collection("other"))
    self.assertIs(other_collection_ref, g.get_collection_ref("other"))

    # Verify that getting an empty collection ref returns a modifiable list.
    empty_coll_ref = g.get_collection_ref("empty")
    self.assertEqual([], empty_coll_ref)
    empty_coll = g.get_collection("empty")
    self.assertEqual([], empty_coll)
    self.assertIsNot(empty_coll, empty_coll_ref)
    empty_coll_ref2 = g.get_collection_ref("empty")
    self.assertIs(empty_coll_ref2, empty_coll_ref)
    # Add to the collection.
    empty_coll_ref.append("something")
    self.assertEqual(["something"], empty_coll_ref)
    self.assertEqual(["something"], empty_coll_ref2)
    self.assertEqual([], empty_coll)
    self.assertEqual(["something"], g.get_collection("empty"))
    empty_coll_ref3 = g.get_collection_ref("empty")
    self.assertIs(empty_coll_ref3, empty_coll_ref)

  def test_add_to_collections_uniquify(self):
    g = ops.Graph()
    g.add_to_collections([1, 2, 1], "key")
    # Make sure "key" is not added twice
    self.assertEqual(["key"], g.get_collection(1))

  def test_add_to_collections_from_list(self):
    g = ops.Graph()
    g.add_to_collections(["abc", "123"], "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_tuple(self):
    g = ops.Graph()
    g.add_to_collections(("abc", "123"), "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_generator(self):
    g = ops.Graph()

    def generator():
      yield "abc"
      yield "123"

    g.add_to_collections(generator(), "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_set(self):
    g = ops.Graph()
    g.add_to_collections(set(["abc", "123"]), "key")
    self.assertEqual(["key"], g.get_collection("abc"))
    self.assertEqual(["key"], g.get_collection("123"))

  def test_add_to_collections_from_string(self):
    # A plain string is treated as a single collection name, not iterated.
    g = ops.Graph()
    g.add_to_collections("abc", "key")
    self.assertEqual(["key"], g.get_collection("abc"))

  def test_default_graph(self):
    with ops.Graph().as_default():
      ops.add_to_collection("key", 90)
      ops.add_to_collection("key", 100)
      # Collections are ordered.
      self.assertEqual([90, 100], ops.get_collection("key"))

  def test_defun(self):
    with context.eager_mode():

      @eager_function.defun
      def defun():
        ops.add_to_collection("int", 1)
        ops.add_to_collection("tensor", constant_op.constant(2))

        @eager_function.defun
        def inner_defun():
          # The inner function sees the outer function's collections...
          self.assertEqual(ops.get_collection("int"), [1])
          three = ops.get_collection("tensor")[0] + ops.get_collection("int")[0]
          ops.add_to_collection("int", 2)
          self.assertEqual(ops.get_collection("int"), [1, 2])
          ops.add_to_collection("foo", "bar")
          self.assertEqual(ops.get_collection("foo"), ["bar"])
          return three

        self.assertEqual(ops.get_collection("int"), [1])
        three = inner_defun()
        # ...but its additions do not leak back into the outer function.
        self.assertEqual(ops.get_collection("int"), [1])
        self.assertEqual(ops.get_collection("foo"), [])
        return three

      three = defun()
      self.assertEqual(three.numpy(), 3)
# "FloatOutput" is the input-less test op used throughout this file; register
# it as having no gradient.
ops.NotDifferentiable("FloatOutput")
@ops.RegisterGradient("CopyOp")
def _CopyGrad(op, x_grad):  # pylint: disable=invalid-name
  """Registered gradient for "CopyOp": passes `x_grad` through unchanged.

  Args:
    op: The forward op (unused).
    x_grad: Incoming gradient.

  Returns:
    `x_grad` itself.
  """
  del op  # Unused.
  return x_grad
@ops.RegisterGradient("copy_override")
def _CopyOverrideGrad(op, x_grad):  # pylint: disable=invalid-name
  """Gradient registered under "copy_override" for override-map tests.

  Args:
    op: The forward op (unused).
    x_grad: Incoming gradient.

  Returns:
    `x_grad` itself.
  """
  del op  # Unused.
  return x_grad
class RegistrationTest(test_util.TensorFlowTestCase):
  """Tests gradient-function lookup for registered and overridden op types."""

  @test_util.run_deprecated_v1
  def testRegisterGradients(self):
    # With no override in effect, "CopyOp" resolves to _CopyGrad.
    copied = test_ops.copy_op(test_ops.float_output())
    self.assertEqual(_CopyGrad, ops.get_gradient_function(copied.op))

  def testOverrideGradients(self):
    graph = ops.Graph()
    with graph.as_default():
      source = test_ops.float_output()
      with graph.gradient_override_map({"CopyOp": "copy_override"}):
        copied = test_ops.copy_op(source)
      # The override active when the op was created selects the gradient.
      self.assertEqual(_CopyOverrideGrad, ops.get_gradient_function(copied.op))

  def testNonExistentOverride(self):
    graph = ops.Graph()
    with graph.as_default():
      source = test_ops.float_output()
      with graph.gradient_override_map({"CopyOp": "unknown_override"}):
        copied = test_ops.copy_op(source)
      # Looking up an override name that was never registered must fail.
      with self.assertRaisesRegex(LookupError, "unknown_override"):
        ops.get_gradient_function(copied.op)
class ComparisonTest(test_util.TensorFlowTestCase):
  """Checks that graph tensors can be used with Python membership tests."""

  def testMembershipAllowed(self):
    g = ops.Graph()
    t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
    t2 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop2")
    self.assertIsInstance(t1, ops.Tensor)
    self.assertIsInstance(t2, ops.Tensor)
    # assertIn/assertNotIn use the same `in` operator as before, but report
    # the operands on failure.
    self.assertIn(t1, [t1])
    self.assertNotIn(t1, [t2])
class ControlDependenciesTest(test_util.TensorFlowTestCase):
  """Tests for `Graph.control_dependencies` scoping and pruning."""

  @test_util.run_deprecated_v1
  def testBasic(self):
    g = ops.Graph()
    with g.as_default():
      # Creating unregistered ops with _apply_op() doesn't work with the C API
      # TODO(skyewm): address this more consistently. Possible solutions are
      # to use registered ops in all tests, create a way to register ops in
      # Python tests, or conditionally disable the op registration check in
      # the C API.
      a = constant_op.constant(1.0)
      b = constant_op.constant(1.0)
      with g.control_dependencies([a]):
        c = constant_op.constant(1.0)
        d = array_ops.identity(b)
        e = array_ops.identity(c)

      # Ops created in the scope get `a.op` as a control input...
      self.assertEqual(c.op.control_inputs, [a.op])
      self.assertEqual(d.op.control_inputs, [a.op])
      # ...unless it is already implied: e should be dominated by c.
      self.assertEqual(e.op.control_inputs, [])

  @test_util.run_in_graph_and_eager_modes
  def testEager(self):
    # `future` counts its invocations so the test can verify how often it is
    # called.
    def future():
      future.calls += 1
      return constant_op.constant(2.0)
    future.calls = 0

    if context.executing_eagerly():
      # In eager mode the callable itself is passed to control_dependencies;
      # it must end up being called exactly once.
      a = constant_op.constant(1.0)
      b = future
      with ops.control_dependencies([a, b]):
        c = constant_op.constant(3.0)
      self.assertEqual(future.calls, 1)
    else:
      g = ops.Graph()
      with g.as_default():
        a = constant_op.constant(1.0)
        b = future()
        with g.control_dependencies([a, b]):
          c = constant_op.constant(3.0)
        self.assertEqual(c.op.control_inputs, [a.op, b.op])
        self.assertEqual(future.calls, 1)

  def testBasicWithConversion(self):
    g = ops.Graph()
    a = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    # An object exposing `_as_graph_element()` is accepted in place of the
    # tensor it returns.
    class ConvertibleObj(object):

      def _as_graph_element(self):
        return a

    with g.control_dependencies([ConvertibleObj()]):
      c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertEqual(c.op.control_inputs, [a.op])

  def testNested(self):
    # One scope listing a_1..a_4 must be equivalent to four nested scopes.
    g = ops.Graph()
    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1, a_2, a_3, a_4]):
      b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1]):
      with g.control_dependencies([a_2]):
        with g.control_dependencies([a_3]):
          with g.control_dependencies([a_4]):
            b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op],
                          b_1.op.control_inputs)
    self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs)

  def testClear(self):
    # control_dependencies(None) clears all outer dependencies until exited.
    g = ops.Graph()
    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1]):
      with g.control_dependencies([a_2]):
        with g.control_dependencies(None):
          with g.control_dependencies([a_3]):
            with g.control_dependencies([a_4]):
              # deps [a_3, a_4]
              b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
            # deps = [a_3]
            b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
          # deps back to None
          b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
        # deps back to [a_1, a_2]
        b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      # deps back to [a_1]
      b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      with g.control_dependencies(None):
        # deps are None again
        b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
    self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
    self.assertItemsEqual([], b_none.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([], b_none2.op.control_inputs)

  def testComplex(self):
    g = ops.Graph()

    # Usage pattern:
    # * Nodes a_i are input-less FloatOutput ops defined at the outermost
    #   scope, and are used as control inputs for the ith nested scope.
    # * Nodes b_i are defined as TwoFloatInputsFloatOutput(a_3, a_4) at each
    #   scope.
    # * Nodes c_i are defined as TwoFloatInputsFloatOutput(a_1, b_1) at each
    #   scope.
    # * Nodes d_i are defined as TwoFloatInputsFloatOutput(b_i, c_i) at each
    #   scope.
    # * Node e_1 is a FloatOutput op; nodes e_i are defined as
    #   TwoFloatInputsFloatOutput(e_i-1, e_i-1) at each scope i > 1.
    # A control input is omitted whenever it is already implied by a data
    # input created inside the same (or an enclosing) scope.

    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.control_dependencies([a_1]):
      b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                      [dtypes.float32])
      c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                      [dtypes.float32])
      d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1],
                      [dtypes.float32])
      e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      with g.control_dependencies([a_2]):
        b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                        [dtypes.float32])
        c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                        [dtypes.float32])
        d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2],
                        [dtypes.float32])
        e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1],
                        [dtypes.float32])
        with g.control_dependencies([a_3]):
          b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                          [dtypes.float32])
          c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                          [dtypes.float32])
          d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3],
                          [dtypes.float32])
          e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2],
                          [dtypes.float32])
          with g.control_dependencies([a_4]):
            b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                            [dtypes.float32])
            c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                            [dtypes.float32])
            d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4],
                            [dtypes.float32])
            e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3],
                            [dtypes.float32])

    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs)

    self.assertItemsEqual([], c_1.op.control_inputs)
    self.assertItemsEqual([a_2.op], c_2.op.control_inputs)
    self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs)
    self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs)

    self.assertItemsEqual([], d_1.op.control_inputs)
    self.assertItemsEqual([], d_2.op.control_inputs)
    self.assertItemsEqual([], d_3.op.control_inputs)
    self.assertItemsEqual([], d_4.op.control_inputs)

    self.assertItemsEqual([a_1.op], e_1.op.control_inputs)
    self.assertItemsEqual([a_2.op], e_2.op.control_inputs)
    self.assertItemsEqual([a_3.op], e_3.op.control_inputs)
    self.assertItemsEqual([a_4.op], e_4.op.control_inputs)

  def testRepeatedDependency(self):
    # Depending on either output of a two-output op yields that op itself as
    # the single control input.
    g = ops.Graph()
    a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
    a_0, a_1 = a.outputs
    with g.control_dependencies([a_0]):
      b = _apply_op(g, "FloatOutput", [], [dtypes.float32])
      with g.control_dependencies([a_1]):
        c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertEqual(b.op.control_inputs, [a])
    self.assertEqual(c.op.control_inputs, [a])

  def testNoControlDependencyWithDataDependency(self):
    # A control dependency on an op that is already a data input is dropped.
    g = ops.Graph()
    a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    with g.control_dependencies([a]):
      b = _apply_op(g, "Identity", [a], [dtypes.float32])

    self.assertEqual(b.op.control_inputs, [])
class OpScopeTest(test_util.TensorFlowTestCase):
  """Tests for `ops.name_scope` / `ops.name_scope_v2` scope-name resolution."""

  @test_util.run_in_graph_and_eager_modes
  def testNames(self):
    # A name ending in "/" is used as-is rather than nested under the current
    # scope, and None or "" resets to the root scope.
    with ops.name_scope("foo", skip_on_eager=False) as foo:
      self.assertEqual("foo/", foo)
      with ops.name_scope("foo2", skip_on_eager=False) as foo2:
        self.assertEqual("foo/foo2/", foo2)
      with ops.name_scope(None, skip_on_eager=False) as empty1:
        self.assertEqual("", empty1)
        with ops.name_scope("foo3", skip_on_eager=False) as foo3:
          self.assertEqual("foo3/", foo3)
      with ops.name_scope("", skip_on_eager=False) as empty2:
        self.assertEqual("", empty2)
    with ops.name_scope("foo/", skip_on_eager=False) as outer_foo:
      self.assertEqual("foo/", outer_foo)
      with ops.name_scope("", skip_on_eager=False) as empty3:
        self.assertEqual("", empty3)
      with ops.name_scope("foo4", skip_on_eager=False) as foo4:
        self.assertEqual("foo/foo4/", foo4)
      with ops.name_scope("foo5//", skip_on_eager=False) as foo5:
        self.assertEqual("foo5//", foo5)
        with ops.name_scope("foo6", skip_on_eager=False) as foo6:
          self.assertEqual("foo5//foo6/", foo6)
      with ops.name_scope("/", skip_on_eager=False) as foo7:
        self.assertEqual("/", foo7)
      with ops.name_scope("//", skip_on_eager=False) as foo8:
        self.assertEqual("//", foo8)
      with ops.name_scope("a//b/c", skip_on_eager=False) as foo9:
        self.assertEqual("foo/a//b/c/", foo9)
    with ops.name_scope("a//b/c", skip_on_eager=False) as foo10:
      self.assertEqual("a//b/c/", foo10)

  @test_util.run_in_graph_and_eager_modes
  def testEagerDefaultScopeName(self):
    # With name=None, the default name is used and nests like a normal name.
    with ops.name_scope(None, "default", skip_on_eager=False) as scope:
      self.assertEqual(scope, "default/")
      with ops.name_scope(None, "default2", skip_on_eager=False) as scope2:
        self.assertEqual(scope2, "default/default2/")

  @test_util.run_in_graph_and_eager_modes
  def testNameScopeV2IsReEntrant(self):
    # The same name_scope_v2 object can be entered again while already active.
    foo = ops.name_scope_v2("foo")
    bar = ops.name_scope_v2("bar")
    with foo as scope_name:
      self.assertEqual("foo/", scope_name)
      with foo as scope_name:
        self.assertEqual("foo/foo/", scope_name)
      with bar as scope_name:
        self.assertEqual("foo/bar/", scope_name)
        with foo as scope_name:
          self.assertEqual("foo/bar/foo/", scope_name)
    with bar as scope_name:
      self.assertEqual("bar/", scope_name)

  @test_util.run_deprecated_v1
  def testNoScopeName(self):
    # Passing `values` with neither a name nor a default name is an error.
    g0 = ops.Graph()
    values = [
        g0.create_op("A", [], [dtypes.float32]),
        g0.create_op("B", [], [dtypes.float32])
    ]
    with self.assertRaises(ValueError):
      with ops.name_scope(None, values=values):
        pass
    with self.assertRaises(ValueError):
      with ops.name_scope(None, None, values):
        pass

  @test_util.run_deprecated_v1
  def testEmptyScopeName(self):
    # "" yields the root scope even with a default name, and the graph of
    # `values` becomes the default graph.
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    with ops.name_scope("", values=[a, b]) as scope:
      self.assertEqual("", scope)
      self.assertEqual(g0, ops.get_default_graph())
    with ops.name_scope("", "my_default_scope", [a, b]) as scope:
      self.assertEqual("", scope)
      self.assertEqual(g0, ops.get_default_graph())

  @test_util.run_deprecated_v1
  def testDefaultScopeName(self):
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    scope_name = "my_scope"
    default_scope_name = "my_default_scope"
    with ops.name_scope(scope_name, default_scope_name, [a, b]) as scope:
      self.assertEqual("%s/" % scope_name, scope)
      self.assertEqual(g0, ops.get_default_graph())
    with ops.name_scope(None, default_scope_name, [a, b]) as scope:
      self.assertEqual("%s/" % default_scope_name, scope)
      self.assertEqual(g0, ops.get_default_graph())
    # Passing `values` positionally in the default-name slot is a TypeError.
    with self.assertRaises(TypeError):
      with ops.name_scope(scope_name, [a, b]):
        pass

  def _testGraphElements(self, graph_elements):
    """Checks name_scope's handling of `values` drawn from one or two graphs.

    Args:
      graph_elements: Graph elements that all belong to the same graph.
    """
    scope_name = "my_scope"
    with ops.name_scope(scope_name, values=graph_elements) as scope:
      self.assertEqual("%s/" % scope_name, scope)
      self.assertEqual(graph_elements[0].graph, ops.get_default_graph())
    # Mixing in an element from a different graph must be rejected.
    g1 = ops.Graph()
    a = g1.create_op("A", [], [dtypes.float32])
    with self.assertRaises(ValueError):
      with ops.name_scope(scope_name, values=graph_elements + [a]):
        pass

  @test_util.run_deprecated_v1
  def testTensor(self):
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    self._testGraphElements([a, b])

  @test_util.run_deprecated_v1
  def testSparseTensor(self):
    g0 = ops.Graph()
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    sparse = sparse_tensor.SparseTensor(
        _apply_op(g0, "Int64Output", [], [dtypes.int64]),
        _apply_op(g0, "FloatOutput", [], [dtypes.float32]),
        _apply_op(g0, "Int64Output", [], [dtypes.int64]))
    self._testGraphElements([a, sparse, b])

  @test_util.run_deprecated_v1
  def testVariable(self):
    g0 = ops.Graph()
    with g0.as_default():
      variable = variables.Variable([1.0])
    a = g0.create_op("A", [], [dtypes.float32])
    b = g0.create_op("B", [], [dtypes.float32])
    self._testGraphElements([a, variable, b])
class InitScopeTest(test_util.TensorFlowTestCase):
  """Tests for `ops.init_scope`.

  Inside init_scope, control dependencies are cleared, and op creation is
  lifted out of graphs that are building functions. Creation goes into the
  nearest non-function graph, or into the eager context.
  """

  def testClearsControlDependencies(self):
    g = ops.Graph()
    a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    with g.as_default():
      with g.control_dependencies([a_1]):
        with g.control_dependencies([a_2]):
          with ops.init_scope():
            with g.control_dependencies([a_3]):
              with g.control_dependencies([a_4]):
                # deps [a_3, a_4]
                b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
              # deps = [a_3]
              b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
            # deps back to None
            b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
          # deps back to [a_1, a_2]
          b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
        # deps back to [a_1]
        b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
        with ops.init_scope():
          # deps are None again
          b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

    self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
    self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
    self.assertItemsEqual([], b_none.op.control_inputs)
    self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
    self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
    self.assertItemsEqual([], b_none2.op.control_inputs)

  def testLiftsOpsFromFunctions(self):
    g0 = ops.Graph()
    g1 = ops.Graph()
    g1._building_function = True  # pylint: disable=protected-access
    g2 = ops.Graph()
    g2._building_function = True  # pylint: disable=protected-access

    with g0.as_default():
      with g1.as_default():
        with g2.as_default():
          with ops.init_scope():
            _ = constant_op.constant(1.0)

    # The op skips both function-building graphs and lands in g0.
    self.assertEqual(len(g2.get_operations()), 0)
    self.assertEqual(len(g1.get_operations()), 0)
    self.assertEqual(len(g0.get_operations()), 1)

  def testPreservesDevices(self):
    g0 = ops.Graph()
    with g0.as_default(), ops.device("CPU:0"):
      g1 = ops.Graph()
      g1._building_function = True  # pylint: disable=protected-access
      with g1.as_default():
        with ops.device("GPU:0"):
          with ops.init_scope():
            # init_scope should preserve device set under `g1`.
            on_gpu = constant_op.constant(1.0)
            self.assertEqual(on_gpu.device, "/device:GPU:0")
          still_on_gpu = constant_op.constant(1.0)
          self.assertEqual(still_on_gpu.device, "/device:GPU:0")
        blank = constant_op.constant(1.0)
        self.assertEqual(blank.device, "")
        with ops.init_scope():
          now_on_cpu = constant_op.constant(1.0)
          self.assertEqual(now_on_cpu.device, "/device:CPU:0")
      on_cpu = constant_op.constant(1.0)
      self.assertEqual(on_cpu.device, "/device:CPU:0")

  def testComposes(self):
    g0 = ops.Graph()
    g1 = ops.Graph()
    g1._building_function = True  # pylint: disable=protected-access
    g2 = ops.Graph()
    g2._building_function = True  # pylint: disable=protected-access
    g3 = ops.Graph()
    g3._building_function = False  # pylint: disable=protected-access

    with g0.as_default():
      with g1.as_default():
        with ops.init_scope():
          # This op should be lifted into g0.
          _ = constant_op.constant(1.0)
          self.assertIs(g0, ops.get_default_graph())
          self.assertEqual(len(g2.get_operations()), 0)
          self.assertEqual(len(g1.get_operations()), 0)
          self.assertEqual(len(g0.get_operations()), 1)
        with g2.as_default():
          with ops.init_scope():
            # This op should be lifted into g0.
            _ = constant_op.constant(1.0)
            self.assertIs(g0, ops.get_default_graph())
            with g3.as_default():
              with ops.init_scope():
                # This op should be lifted into g3, because g3 is not building a
                # function.
                _ = constant_op.constant(1.0)
                self.assertIs(g3, ops.get_default_graph())

    self.assertEqual(len(g3.get_operations()), 1)
    self.assertEqual(len(g2.get_operations()), 0)
    self.assertEqual(len(g1.get_operations()), 0)
    self.assertEqual(len(g0.get_operations()), 2)

  def testEscapesToEagerContext(self):
    g = ops.Graph()
    g._building_function = True  # pylint: disable=protected-access
    with context.eager_mode():
      with context.graph_mode():
        with g.as_default():
          with ops.init_scope():
            # Because g is building a function, init_scope should
            # escape out to the eager context.
            self.assertTrue(context.executing_eagerly())
          # g should be reinstated as the default graph, and the
          # graph context should be re-entered.
          self.assertIs(g, ops.get_default_graph())
          self.assertFalse(context.executing_eagerly())

  def testStaysInEagerWhenOnlyEagerContextActive(self):
    # Query the actual execution mode. Calling `context.eager_mode()` here
    # would only construct a context manager, which is always truthy, and so
    # could never make the assertion fail.
    with context.eager_mode():
      with ops.init_scope():
        self.assertTrue(context.executing_eagerly())
      self.assertTrue(context.executing_eagerly())

  def testEscapesDefunWhenInEagerMode(self):

    def function_with_variables():
      with ops.init_scope():
        self.v = resource_variable_ops.ResourceVariable(3)
      return self.v.assign_add(1)

    with context.eager_mode():
      # Each invocation of function_with_variables recreates a variable.
      self.assertEqual(4, int(function_with_variables()))
      self.assertEqual(4, int(function_with_variables()))

      compiled = eager_function.defun(function_with_variables)
      # The init_scope in function_with_variables lifts the variable out
      # of the graph function constructed by defun; hence,
      # compiled now appears to be stateful.
      self.assertEqual(4, int(compiled()))
      self.assertEqual(5, int(compiled()))

  def testEscapesDefunWhenInGraphMode(self):

    def function_with_variables(name):
      with ops.init_scope():
        _ = variable_scope.get_variable(name, shape=(1,))

    g = ops.Graph()
    with g.as_default():
      with self.cached_session():
        # First ensure that graphs that are not building functions are
        # not escaped.
        function_with_variables("foo")
        with self.assertRaisesRegex(ValueError,
                                    r"Variable foo already exists.*"):
          # This will fail because reuse is not set to True.
          function_with_variables("foo")

        compiled = eager_function.defun(function_with_variables)
        compiled("bar")
        self.assertEqual(
            len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2)

        # The second call to `compiled` should not create variables: the
        # init_scope has lifted the variable creation code out of the defun.
        compiled("bar")
        self.assertEqual(
            len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2)

  def testEscapesNestedDefun(self):

    def inner_function():
      with ops.init_scope():
        self.v = resource_variable_ops.ResourceVariable(1)
      return self.v.assign_add(2)

    def outer_function(inner=None):
      with ops.init_scope():
        self.v0 = resource_variable_ops.ResourceVariable(0)
      return self.v0.assign_add(1) + inner()

    with context.eager_mode():
      # Each invocation of outer_function recreates variables.
      self.assertEqual(4, int(outer_function(inner=inner_function)))
      self.assertEqual(4, int(outer_function(inner=inner_function)))

      compiled_inner = eager_function.defun(inner_function)
      compiled_outer = eager_function.defun(outer_function)
      # The init_scope lifts variables out of the graph functions
      # constructed by defun; hence, compiled_outer should now appear to be
      # stateful.
      self.assertEqual(4, int(compiled_outer(inner=compiled_inner)))
      self.assertEqual(7, int(compiled_outer(inner=compiled_inner)))

  @test_util.run_v1_only("b/120545219")
  def testFallsBackToGlobalGraphWhenAllGraphsAreBuildingFunctions(self):
    with context.graph_mode():
      ops.reset_default_graph()
      # This doesn't push anything onto the graph stack, but it does
      # set the stack's global graph.
      global_graph = ops.get_default_graph()
      fn_graph = ops.Graph()

      # pylint: disable=protected-access
      fn_graph._building_function = True
      self.assertEqual(len(ops._default_graph_stack.stack), 0)
      with fn_graph.as_default():
        self.assertEqual(len(ops._default_graph_stack.stack), 1)
        with ops.init_scope():
          self.assertGreater(len(ops._default_graph_stack.stack), 1)
          dummy = constant_op.constant(1.0)
        self.assertEqual(len(ops._default_graph_stack.stack), 1)
      # Note that the global graph is _not_ on the graph stack.
      self.assertEqual(len(ops._default_graph_stack.stack), 0)
      # Ensure that `dummy` was added to the global graph.
      self.assertEqual(global_graph, dummy.graph)
      # pylint: enable=protected-access

  def testInstallsDefaultGraphWhenGraphStackIsEmptyInGraphMode(self):
    with context.graph_mode():
      # pylint: disable=protected-access
      self.assertEqual(len(ops._default_graph_stack.stack), 0)
      with ops.init_scope():
        self.assertGreater(len(ops._default_graph_stack.stack), 0)
      self.assertEqual(len(ops._default_graph_stack.stack), 0)
      # pylint: enable=protected-access

  def testPreservesNameScopeInGraphConstruction(self):
    with ops.Graph().as_default():
      function_graph = ops.Graph()
      with function_graph.as_default():
        with ops.name_scope("inner", skip_on_eager=False), ops.init_scope():
          self.assertEqual(ops.get_name_scope(), "inner")
      self.assertEqual(ops.get_name_scope(), "")

  def testEnteringGraphFromEagerIsSticky(self):
    with context.eager_mode():
      g = ops.Graph()
      with g.as_default():
        with ops.init_scope():
          self.assertFalse(context.executing_eagerly())
          self.assertEqual(g, ops.get_default_graph())

  def testMixGraphEager(self):
    with context.eager_mode():
      c = constant_op.constant(1.0)
      with ops.Graph().as_default():
        with self.assertRaisesRegex(RuntimeError,
                                    "Attempting to capture an EagerTensor"):
          math_ops.add(c, c)
        c2 = constant_op.constant(2.0)
      # Back in eager mode, graph tensors cannot be consumed directly.
      with self.assertRaisesRegex(TypeError, "Graph tensors"):
        math_ops.add(c2, c2)

  def testPreservesNameScopeInEagerExecution(self):
    with context.eager_mode():

      def foo():
        with ops.name_scope("inner", skip_on_eager=False), ops.init_scope():
          if context.executing_eagerly():
            # A trailing slash is always appended when eager execution is
            # enabled.
            self.assertEqual(context.context().scope_name, "inner/")
          else:
            self.assertEqual(ops.get_name_scope(), "inner")

      foo()
      self.assertEqual(ops.get_name_scope(), "")
      foo_compiled = eager_function.defun(foo)
      foo_compiled()
      self.assertEqual(ops.get_name_scope(), "")

  def testExecutingEagerlyOutsideFunctions(self):

    @def_function.function
    def f():
      return ops.executing_eagerly_outside_functions()

    with context.graph_mode():
      self.assertFalse(ops.executing_eagerly_outside_functions())
      with session.Session():
        # Need self.evaluate for these as the return type of functions is
        # tensors.
        self.assertFalse(self.evaluate(f()))

    with context.eager_mode():
      self.assertTrue(ops.executing_eagerly_outside_functions())
      self.assertTrue(f())

      with ops.Graph().as_default():
        self.assertFalse(ops.executing_eagerly_outside_functions())
        with session.Session():
          self.assertFalse(self.evaluate(f()))
class GraphTest(test_util.TensorFlowTestCase):
  """Tests for default-graph handling, feeding/fetching and Graph lifetime."""

  def setUp(self):
    # Chain to TensorFlowTestCase.setUp first (as PackEagerTensorTest does);
    # overriding setUp without chaining skipped the base-class setup for
    # every test in this class.
    super(GraphTest, self).setUp()
    ops.reset_default_graph()

  def _AssertDefault(self, expected):
    """Asserts that `expected` is the current default graph."""
    self.assertIs(expected, ops.get_default_graph())

  def testResetDefaultGraphNesting(self):
    # reset_default_graph() inside another graph's as_default() is rejected.
    g0 = ops.Graph()
    with self.assertRaises(AssertionError):
      with g0.as_default():
        ops.reset_default_graph()

  def testGraphContextManagerCancelsEager(self):
    with context.eager_mode():
      with ops.Graph().as_default():
        self.assertFalse(context.executing_eagerly())

  def testGraphContextManager(self):
    g0 = ops.Graph()
    with g0.as_default() as g1:
      self.assertIs(g0, g1)

  def testDefaultGraph(self):
    orig = ops.get_default_graph()
    self.assertFalse(ops.has_default_graph())
    self._AssertDefault(orig)
    # Neither constructing a Graph nor creating its as_default() context
    # manager changes the default graph; only entering the manager does.
    g0 = ops.Graph()
    self.assertFalse(ops.has_default_graph())
    self._AssertDefault(orig)
    context_manager_0 = g0.as_default()
    self.assertFalse(ops.has_default_graph())
    self._AssertDefault(orig)
    with context_manager_0 as g0:
      self._AssertDefault(g0)
      with ops.Graph().as_default() as g1:
        self.assertTrue(ops.has_default_graph())
        self._AssertDefault(g1)
      self._AssertDefault(g0)
    self._AssertDefault(orig)
    self.assertFalse(ops.has_default_graph())

  def testPreventFeeding(self):
    g = ops.Graph()
    a = constant_op.constant(2.0)
    self.assertTrue(g.is_feedable(a))
    g.prevent_feeding(a)
    self.assertFalse(g.is_feedable(a))

  @test_util.run_deprecated_v1
  def testPreventFetching(self):
    g = ops.Graph()
    a = constant_op.constant(2.0)
    self.assertTrue(g.is_fetchable(a))
    g.prevent_fetching(a.op)
    self.assertFalse(g.is_fetchable(a))

  def testAsGraphElementConversions(self):

    class ConvertibleObj(object):

      def _as_graph_element(self):
        return "FloatOutput:0"

    class NonConvertibleObj(object):

      pass

    g = ops.Graph()
    a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
    self.assertEqual(a, g.as_graph_element(ConvertibleObj()))
    with self.assertRaises(TypeError):
      g.as_graph_element(NonConvertibleObj())

  # Regression test against creating custom __del__ functions in classes
  # involved in cyclic references, e.g. Graph and Operation. (Python won't gc
  # cycles that require calling a __del__ method, because the __del__ method can
  # theoretically increase the object's refcount to "save" it from gc, and any
  # already-deleted objects in the cycle would have to be restored.)
  def testGarbageCollected(self):
    # Create a graph we can delete and a weak reference to monitor if it's gc'd
    g = ops.Graph()
    g_ref = weakref.ref(g)
    # Create some ops
    with g.as_default():
      a = constant_op.constant(2.0)
      b = constant_op.constant(3.0)
      c = math_ops.add(a, b)
    # Create a session we can delete
    with session.Session(graph=g) as sess:
      self.evaluate(c)
    # Delete all references and trigger gc
    del g
    del a
    del b
    del c
    del sess
    gc.collect()
    self.assertIsNone(g_ref())

  def testRunnableAfterInvalidShape(self):
    with ops.Graph().as_default():
      with self.assertRaises(ValueError):
        math_ops.add([1, 2], [1, 2, 3])
      # The graph must still be runnable after the failed op construction.
      a = constant_op.constant(1)
      with session.Session():
        self.evaluate(a)

  def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
    g = ops.Graph()
    with g.as_default():
      # pylint: disable=protected-access
      with g._kernel_label_map({"KernelLabelRequired": "overload_1"}):
        with self.assertRaises(ValueError):
          test_ops.kernel_label_required(1)
      # pylint: enable=protected-access
      a = constant_op.constant(1)
      with session.Session():
        self.evaluate(a)
class AttrScopeTest(test_util.TensorFlowTestCase):
  """Tests for attribute overrides applied through `Graph._attr_scope`."""

  def _get_test_attrs(self):
    """Creates a no-op and reads back its `_A` and `_B` attrs.

    Returns:
      A 2-tuple `(a, b)`. Each entry is the attr value converted to text, or
      None when `get_attr` raises ValueError for that attr.
    """
    probe = control_flow_ops.no_op()

    def read(attr_name):
      try:
        return compat.as_text(probe.get_attr(attr_name))
      except ValueError:
        return None

    return tuple(read(attr_name) for attr_name in ("_A", "_B"))

  @test_util.run_deprecated_v1
  def testNoLabel(self):
    with self.cached_session():
      observed = self._get_test_attrs()
      self.assertAllEqual((None, None), observed)

  @test_util.run_deprecated_v1
  def testLabelMap(self):
    foo = attr_value_pb2.AttrValue(s=compat.as_bytes("foo"))
    bar = attr_value_pb2.AttrValue(s=compat.as_bytes("bar"))
    baz = attr_value_pb2.AttrValue(s=compat.as_bytes("baz"))
    with self.cached_session() as sess:
      graph = sess.graph
      # pylint: disable=protected-access
      a1 = self._get_test_attrs()
      with graph._attr_scope({"_A": foo}):
        a2 = self._get_test_attrs()
        # Mapping "_A" to None removes it inside this nested scope.
        with graph._attr_scope({"_A": None, "_B": bar}):
          a3 = self._get_test_attrs()
          with graph._attr_scope({"_A": baz}):
            a4 = self._get_test_attrs()
          a5 = self._get_test_attrs()
        a6 = self._get_test_attrs()
      a7 = self._get_test_attrs()
      # pylint: enable=protected-access

      # Each scope exit restores exactly the attrs of the enclosing scope.
      expected = [(None, None), ("foo", None), (None, "bar"), ("baz", "bar"),
                  (None, "bar"), ("foo", None), (None, None)]
      observed = [a1, a2, a3, a4, a5, a6, a7]
      for want, got in zip(expected, observed):
        self.assertAllEqual(want, got)
class KernelLabelTest(test_util.TensorFlowTestCase):
  """Tests that `Graph._kernel_label_map` selects which kernel an op runs."""

  @test_util.run_deprecated_v1
  def testNoLabel(self):
    with self.cached_session():
      label = test_ops.kernel_label()
      self.assertAllEqual(b"My label is: default", label.eval())

  @test_util.run_deprecated_v1
  def testLabelMap(self):
    with self.cached_session() as sess:
      graph = sess.graph
      default_1 = test_ops.kernel_label()
      # pylint: disable=protected-access
      with graph._kernel_label_map({"KernelLabel": "overload_1"}):
        overload_1_1 = test_ops.kernel_label()
        with graph._kernel_label_map({"KernelLabel": "overload_2"}):
          overload_2 = test_ops.kernel_label()
          # An empty label selects the default kernel again.
          with graph._kernel_label_map({"KernelLabel": ""}):
            default_2 = test_ops.kernel_label()
        # Leaving the inner map restores the enclosing "overload_1" label.
        overload_1_2 = test_ops.kernel_label()
      # pylint: enable=protected-access
      default_3 = test_ops.kernel_label()

      expectations = [
          (b"My label is: default", default_1),
          (b"My label is: default", default_2),
          (b"My label is: default", default_3),
          (b"My label is: overload_1", overload_1_1),
          (b"My label is: overload_1", overload_1_2),
          (b"My label is: overload_2", overload_2),
      ]
      for expected_label, label_op in expectations:
        self.assertAllEqual(expected_label, self.evaluate(label_op))
class AsGraphDefTest(test_util.TensorFlowTestCase):
  """Tests for `Graph.as_graph_def` and GraphDef version plumbing."""

  def testGraphDefVersion(self):
    """Test that the graphdef version is plumbed through to kernels."""
    with ops.Graph().as_default() as graph:
      producer_version = graph.graph_def_versions.producer
      with self.session(graph=graph):
        kernel_version = test_ops.graph_def_version().eval()
        self.assertEqual(producer_version, kernel_version)

  def testAddShapes(self):
    with ops.Graph().as_default() as graph:
      t1, t2, t3, t4, t5 = _apply_op(graph, "FiveFloatOutputs", [],
                                     [dtypes.float32] * 5)
      # Unknown rank, scalar, unknown dim, fully known, partially known.
      shapes = (None, [], [None], [43, 37], [43, None])
      for tensor, shape in zip((t1, t2, t3, t4, t5), shapes):
        tensor.set_shape(shape)
      # Adds the "Const" node expected in the GraphDef below.
      constant_op.constant(1.0)

      graph_def = graph.as_graph_def(add_shapes=True)
      self.assertProtoEqualsVersion("""
node { name: "FiveFloatOutputs" op: "FiveFloatOutputs"
attr {
key: "_output_shapes"
value {
list {
shape { unknown_rank: true }
shape { }
shape { dim { size: -1 } }
shape { dim { size: 43 } dim { size: 37 } }
shape { dim { size: 43 } dim { size: -1 } }
}
}
}
}
node { name: "Const" op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape { }
}
}
}
attr {
key: "dtype"
value { type: DT_FLOAT }
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape { }
float_val: 1.0 } } } }
""", graph_def)
@ops.RegisterStatistics("a", "flops")
def _calc_a_forward_flops(unused_graph, unused_node):
  """Registered "flops" statistic for op type "a": always 20."""
  flop_count = 20
  return ops.OpStats("flops", flop_count)
class StatisticsTest(test_util.TensorFlowTestCase):
  """Tests for op statistics lookup and accumulation."""

  def testRegisteredNode(self):
    graph = ops.Graph()
    node = ops._NodeDef("a", "an_a")  # pylint: disable=protected-access
    # `_calc_a_forward_flops` registers 20 flops for op type "a".
    flops = ops.get_stats_for_node_def(graph, node, "flops")
    self.assertEqual(20, flops.value)
    # A statistic with no registered function reports no value.
    missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat")
    self.assertIsNone(missing_stat.value)

  def testUnregisteredNode(self):
    graph = ops.Graph()
    node = ops._NodeDef("b", "a_b")  # pylint: disable=protected-access
    weight_params = ops.get_stats_for_node_def(graph, node, "weight_params")
    self.assertIsNone(weight_params.value)

  def testAccumulateStatistics(self):
    # An OpStats created without a value starts as None; adding a valued
    # OpStats to it yields that value.
    flops_total = ops.OpStats("flops")
    self.assertIsNone(flops_total.value)
    second_flops = ops.OpStats("flops", 3)
    flops_total += second_flops
    self.assertEqual(3, flops_total.value)
class DeviceStackTest(test_util.TensorFlowTestCase):
  """Tests for the device-assignment metadata recorded on each op."""

  @test_util.run_deprecated_v1
  def testBasicDeviceAssignmentMetadata(self):

    # NOTE: `expected_regex` below asserts on the name `device_func`.
    def device_func(unused_op):
      return "/cpu:*"

    const_zero = constant_op.constant([0.0], name="zero")
    with ops.device("/cpu"):
      const_one = constant_op.constant([1.0], name="one")
      with ops.device("/cpu:0"):
        const_two = constant_op.constant([2.0], name="two")
    with ops.device(device_func):
      const_three = constant_op.constant(3.0, name="three")

    # pylint: disable=protected-access
    # No device scope was active when "zero" was created.
    self.assertEqual(0, len(const_zero.op._device_assignments))

    one_list = const_one.op._device_assignments
    self.assertEqual(1, len(one_list))
    self.assertEqual("/cpu", one_list[0].obj)
    self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename))

    # Nested scopes yield one entry per enclosing ops.device call.
    two_list = const_two.op._device_assignments
    self.assertEqual(2, len(two_list))
    devices = [t.obj for t in two_list]
    self.assertEqual(set(["/cpu", "/cpu:0"]), set(devices))

    # A device function is recorded as a description that includes its name
    # and a file/line location.
    three_list = const_three.op._device_assignments
    # pylint: enable=protected-access
    self.assertEqual(1, len(three_list))
    func_description = three_list[0].obj
    expected_regex = r"device_func<.*ops_test.py, [0-9]+"
    self.assertRegex(func_description, expected_regex)

  @test_util.run_deprecated_v1
  def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self):
    # NOTE: the two `with` lines below must stay exactly two lines apart; the
    # last assertion compares their recorded line numbers.
    with ops.device("/cpu"):
      const_one = constant_op.constant([1.0], name="one")
    with ops.get_default_graph().device("/cpu"):
      const_two = constant_op.constant([2.0], name="two")
    # pylint: disable=protected-access
    one_metadata = const_one.op._device_assignments[0]
    two_metadata = const_two.op._device_assignments[0]
    # pylint: enable=protected-access

    # Verify both types of device assignment return the right stack info.
    # Use an exact comparison: assertRegex(text, regex) with the arguments
    # swapped used the recorded basename as the pattern, so an empty basename
    # would have matched anything.
    self.assertEqual("ops_test.py", os.path.basename(one_metadata.filename))
    self.assertEqual(one_metadata.filename, two_metadata.filename)
    self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno)
class ColocationGroupTest(test_util.TensorFlowTestCase):
  """Tests for `colocate_with` and the colocation groups it records."""

  @test_util.run_deprecated_v1
  def testBasic(self):
    anchor = constant_op.constant([2.0], name="a")
    with ops.colocate_with(anchor.op):
      colocated = constant_op.constant(3.0)
    unconstrained = constant_op.constant(4.0)
    self.assertEqual([b"loc:@a"], anchor.op.colocation_groups())
    self.assertEqual([b"loc:@a"], colocated.op.colocation_groups())
    # The op created outside colocate_with has no "_class" attr.
    with self.assertRaises(ValueError):
      unconstrained.op.get_attr("_class")

  @test_util.run_deprecated_v1
  def testBasicColocationMetadata(self):
    const_two = constant_op.constant([2.0], name="two")
    with ops.colocate_with(const_two.op):
      const_three = constant_op.constant(3.0, name="three")
    # pylint: disable=protected-access
    colocation_dict = const_three.op._colocation_dict
    # pylint: enable=protected-access
    self.assertIn("two", colocation_dict)
    record = colocation_dict["two"]
    self.assertIsNone(record.obj)
    # The recorded filename is that of the file containing the colocate_with
    # statement, i.e. this test file.
    self.assertEqual("ops_test.py", os.path.basename(record.filename))

  @test_util.run_deprecated_v1
  def testColocationDeviceInteraction(self):
    with ops.device("/cpu:0"):
      with ops.device("/device:GPU:0"):
        gpu_op = constant_op.constant([2.0], name="a")
      with ops.colocate_with(gpu_op.op):
        # Created inside the /cpu:0 scope, but colocated with the GPU op:
        # colocation takes precedence over the device scope.
        follower = constant_op.constant(3.0)
    self.assertEqual([b"loc:@a"], follower.op.colocation_groups())
    self.assertEqual(gpu_op.op.device, follower.op.device)

  @test_util.run_deprecated_v1
  def testColocationCanonicalization(self):
    with ops.device("/device:GPU:0"):
      constant_op.constant(2.0)
    with ops.device(lambda unused_op: "/device:GPU:0"):
      b = constant_op.constant(3.0)
    with ops.get_default_graph().colocate_with(b):
      with ops.device("/device:GPU:0"):
        c = constant_op.constant(4.0)
    # Every op asks for /device:GPU:0. 'c' inherits 'b''s device name, and
    # the two must compare equal once the names are canonicalized.
    self.assertEqual(b.op.device, c.op.device)

  @test_util.run_deprecated_v1
  def testLocationOverrides(self):
    with ops.device("/cpu:0"):
      with ops.device("/device:GPU:0"):
        a = constant_op.constant([2.0], name="a")
        # Redundant colocation (we are already under /device:GPU:0), kept so
        # the GraphDef records the colocation portably.
        with ops.colocate_with(a.op):
          b = constant_op.constant(3.0)
        c = constant_op.constant(4.0)
      d = constant_op.constant(5.0)
    self.assertEqual([b"loc:@a"], b.op.colocation_groups())
    self.assertEqual("/device:GPU:0", a.op.device)
    self.assertEqual(a.op.device, b.op.device)
    # Leaving each scope restores the device of the enclosing one.
    self.assertEqual("/device:GPU:0", c.op.device)
    self.assertEqual("/device:CPU:0", d.op.device)

  @test_util.run_deprecated_v1
  def testNestedColocateWith(self):
    root = constant_op.constant([2.0], name="a")
    with ops.colocate_with(root.op):
      middle = constant_op.constant(3.0)
      with ops.colocate_with(middle.op):
        leaf = constant_op.constant(4.0)
    # Colocating with an already-colocated op resolves to the same group.
    self.assertEqual([b"loc:@a"], middle.op.colocation_groups())
    self.assertEqual([b"loc:@a"], leaf.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testMultiColocationGroups(self):
    first = constant_op.constant([2.0], name="a")
    second = constant_op.constant(3.0, name="b")
    with ops.colocate_with(first.op), ops.colocate_with(second.op):
      both = constant_op.constant(4.0)
    self.assertEqual(set([b"loc:@a", b"loc:@b"]),
                     set(both.op.colocation_groups()))

  @test_util.run_deprecated_v1
  def testColocationIgnoreStack(self):
    first = constant_op.constant([2.0], name="a")
    second = constant_op.constant(3.0, name="b")
    with ops.colocate_with(first.op):
      # ignore_existing=True drops the outer colocation with 'a'.
      with ops.colocate_with(second.op, ignore_existing=True):
        only_b = constant_op.constant(4.0)
    self.assertEqual(set([b"loc:@b"]), set(only_b.op.colocation_groups()))

  @test_util.run_deprecated_v1
  def testColocateWithReset(self):
    a = constant_op.constant([2.0], name="a")
    with ops.colocate_with(a.op):
      b = constant_op.constant(3.0, name="b")
      # colocate_with(None, ignore_existing=True) clears colocation, so 'c'
      # lands only in its own group.
      with ops.colocate_with(None, ignore_existing=True):
        c = constant_op.constant(4.0, name="c")
    self.assertEqual([b"loc:@a"], b.op.colocation_groups())
    self.assertEqual([b"loc:@c"], c.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateWithInitialNoneThenNested(self):
    a = constant_op.constant([2.0], name="a")
    with ops.colocate_with(a.op):
      with ops.colocate_with(None, ignore_existing=True):
        b = constant_op.constant(3.0, name="b")
        with ops.colocate_with(b.op):
          c = constant_op.constant(4.0, name="c")
    self.assertEqual([b"loc:@b"], b.op.colocation_groups())
    self.assertEqual([b"loc:@b"], c.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateVariables(self):
    first_var = variables.Variable([2.0], name="a")
    with ops.colocate_with(first_var.op):
      second_var = variables.Variable([3.0], name="b")
    self.assertEqual([b"loc:@a"], second_var.op.colocation_groups())

  @test_util.run_deprecated_v1
  def testColocateResourceVariablesInFunction(self):
    with ops.device("/device:CPU:0"):
      cpu_var = resource_variable_ops.ResourceVariable(1.0)

    @def_function.function
    def build_output():
      with ops.colocate_with(cpu_var):
        output = array_ops.ones([], name="output")
        self.assertEqual("/device:CPU:0", output.op.device)

    build_output()

  def testColocateWithVariableInFunction(self):
    var = variables.Variable(1.)

    @def_function.function
    def build_output():
      with ops.colocate_with(var):
        return array_ops.ones([], name="output")

    build_output()
    # The traced graph must be importable again through wrap_function.
    traced_graph_def = build_output.get_concrete_function().graph.as_graph_def()
    wrap_function.function_from_graph_def(traced_graph_def, [], ["output"])
class DeprecatedTest(test_util.TensorFlowTestCase):
  """Tests for the deprecated `Old` test op."""

  def testSuccess(self):
    with ops.Graph().as_default() as graph:
      # Producer version 7 is below the version 8 named in `_error`, so the
      # op is still allowed.
      test_util.set_producer_version(graph, 7)
      old_op = test_ops.old()
      with self.session(graph=graph):
        old_op.run()

  def _error(self):
    """Returns the regex expected when `Old` is built at the current version."""
    pattern = (r"Op Old is not available in GraphDef version %d\. "
               r"It has been removed in version 8\. For reasons\.")
    return pattern % versions.GRAPH_DEF_VERSION

  def testGraphConstructionFail(self):
    with ops.Graph().as_default():
      expected_message = self._error()
      with self.assertRaisesRegex(NotImplementedError, expected_message):
        test_ops.old()
class NameScopeTest(test_util.TensorFlowTestCase):
  """Tests for name-scope helpers and `Graph.get_name_scope`."""

  def testStripAndPrependScope(self):
    # Each case: (name, expected after stripping "hidden1",
    #             expected after then prepending "hidden2").
    cases = [
        # Same prefix. Should strip.
        ("hidden1/hidden1/weights", "hidden1/weights",
         "hidden2/hidden1/weights"),
        # Extra "/". Should strip.
        ("hidden1///hidden1/weights", "hidden1/weights",
         "hidden2/hidden1/weights"),
        # Same prefix after a leading "^". Should strip; "^" is kept.
        ("^hidden1/hidden1/weights", "^hidden1/weights",
         "^hidden2/hidden1/weights"),
        # Same prefix after a leading "loc:@". Should strip; "loc:@" is kept.
        ("loc:@hidden1/hidden1/weights", "loc:@hidden1/weights",
         "loc:@hidden2/hidden1/weights"),
        # Different prefix. Should keep.
        ("hhidden1/hidden1/weights", "hhidden1/hidden1/weights",
         "hidden2/hhidden1/hidden1/weights"),
        # Not a prefix. Should keep.
        ("hidden1", "hidden1", "hidden2/hidden1"),
    ]
    for name, want_stripped, want_prepended in cases:
      stripped = ops.strip_name_scope(name, "hidden1")
      self.assertEqual(want_stripped, stripped)
      self.assertEqual(want_prepended,
                       ops.prepend_name_scope(stripped, "hidden2"))

  def testGetNameScope(self):
    with ops.Graph().as_default() as graph:
      with ops.name_scope("scope1"):
        with ops.name_scope("scope2"):
          with ops.name_scope("scope3"):
            self.assertEqual("scope1/scope2/scope3", graph.get_name_scope())
          self.assertEqual("scope1/scope2", graph.get_name_scope())
        self.assertEqual("scope1", graph.get_name_scope())
      self.assertEqual("", graph.get_name_scope())

  def testTwoGraphs(self):
    outer_graph = ops.Graph()
    inner_graph = ops.Graph()
    with self.assertRaisesRegex(ValueError, "'_' is not a valid scope name"):
      with outer_graph.as_default(), inner_graph.as_default():
        with ops.name_scope("_"):
          pass
class EnableEagerExecutionTest(test_util.TensorFlowTestCase):
  """Argument validation for `ops.enable_eager_execution`."""

  @test_util.run_v1_only("b/120545219")
  def testBadArgumentsToEnableEagerExecution(self):
    # A non-ConfigProto first argument is rejected with TypeError.
    with self.assertRaisesRegex(TypeError, "config must be a tf.ConfigProto"):
      ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT)
    # A ConfigProto given as the second positional argument is rejected.
    with self.assertRaisesRegex(ValueError, "device_policy must be one of"):
      proto = config_pb2.ConfigProto()
      ops.enable_eager_execution(proto, proto)
    # A ConfigProto given as execution_mode is rejected.
    with self.assertRaisesRegex(ValueError, "execution_mode must be one of"):
      proto = config_pb2.ConfigProto()
      ops.enable_eager_execution(proto, execution_mode=proto)
class _TupleTensor(composite_tensor.CompositeTensor):
  """`Tensor`-like `tuple`-like for custom `Tensor` conversion masquerading.

  Every element passed to the constructor is converted with
  `ops.convert_to_tensor`. Indexing, `len()` and iteration delegate to the
  resulting tuple of tensors.
  """

  def __init__(self, components):
    super(_TupleTensor, self).__init__()
    self._components = tuple(ops.convert_to_tensor(c) for c in components)

  @property
  def _type_spec(self):
    # Pass a materialized tuple, not a generator. A generator stored on the
    # spec would be exhausted after its first traversal (by
    # `_component_specs` or `_serialize`), leaving an empty spec behind.
    return _TupleTensorSpec(
        tuple(type_spec.from_value(c) for c in self._components))

  def __getitem__(self, key):
    return self._components[key]

  def __len__(self):
    return len(self._components)

  def __iter__(self):
    return iter(self._components)
class _TupleTensorSpec(type_spec.TypeSpec):
  """`TypeSpec` for `_TupleTensor`: one component spec per tuple element."""

  def __init__(self, specs):
    # Accept any iterable of specs but store a tuple, so the specs can be
    # traversed more than once (component specs, serialization).
    self._specs = tuple(specs)

  value_type = property(lambda self: _TupleTensor)
  _component_specs = property(lambda self: self._specs)

  def _to_components(self, value):
    return value._components  # pylint: disable=protected-access

  def _from_components(self, components):
    # `_TupleTensor.__init__` takes the components as one iterable; unpacking
    # them with `*` passed one positional argument per component and raised
    # TypeError whenever there was more than one.
    return _TupleTensor(components)

  def _serialize(self):
    return (self._specs,)
class _MyTuple(object):
  """Pretend user-side class for `CustomConvertToCompositeTensorTest`.

  A plain tuple wrapper: the constructor accepts any iterable and stores it
  as a tuple, and indexing, `len()` and iteration delegate to that tuple.
  """

  def __init__(self, components):
    super(_MyTuple, self).__init__()
    self._components = tuple(components)

  def __getitem__(self, key):
    return self._components[key]

  def __len__(self):
    return len(self._components)

  def __iter__(self):
    return iter(self._components)
def _my_tuple_to_tuple_tensor(value, *unused_args, **unused_kwargs):
  """Conversion function that wraps a `_MyTuple` in a `_TupleTensor`.

  Any extra positional or keyword arguments are accepted and ignored.
  """
  return _TupleTensor(value)


ops.register_tensor_conversion_function(
    _MyTuple, conversion_func=_my_tuple_to_tuple_tensor)
class CustomConvertToCompositeTensorTest(test_util.TensorFlowTestCase):
  """Tests user-registered conversion of a custom type to a CompositeTensor."""

  @test_util.disable_tfrt("TODO(kkb): This makes Kokoro tests fail.")
  def testCompositeTensorConversion(self):
    """Tests that a user can register a CompositeTensor converter."""
    source = _MyTuple((1, [2., 3.], [[4, 5], [6, 7]]))
    converted = ops.convert_to_tensor_or_composite(source)
    # The result is the registered composite type, not a plain tensor ...
    self.assertFalse(tensor_util.is_tensor(converted))
    self.assertIsInstance(converted, _TupleTensor)
    self.assertLen(converted, len(source))
    # ... and each of its components is a Tensor holding the matching input.
    for original, component in zip(source, converted):
      self.assertIsInstance(component, ops.Tensor)
      self.assertTrue(tensor_util.is_tensor(component))
      self.assertAllEqual(original, tensor_util.constant_value(component))
@test_util.disable_tfrt("Packing EagerTensors is not supported yet.")
class PackEagerTensorTest(test_util.TensorFlowTestCase):
  """Tests for `ops.pack_eager_tensors`."""

  def setUp(self):
    super(PackEagerTensorTest, self).setUp()
    context._reset_context()  # pylint: disable=protected-access
    # Set 2 virtual CPUs on the first physical CPU (used as CPU:0 / CPU:1).
    physical_cpus = config.list_physical_devices("CPU")
    config.set_logical_device_configuration(
        physical_cpus[0],
        [context.LogicalDeviceConfiguration() for _ in range(2)])

  def testPack(self):
    with context.eager_mode():
      with ops.device("CPU:0"):
        var0 = resource_variable_ops.ResourceVariable(1.0)
        c0 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
      with ops.device("CPU:1"):
        var1 = resource_variable_ops.ResourceVariable(2.0)
        var2 = resource_variable_ops.ResourceVariable([3.0])
        c1 = constant_op.constant([9.0])

      packed = ops.pack_eager_tensors([var0.handle, var1.handle])
      reference = var0.handle
      # The packed handle carries the dtype, shape and handle data of its
      # inputs ...
      self.assertTrue(packed.is_packed)
      self.assertEqual(packed.dtype, reference.dtype)
      self.assertEqual(packed.shape, reference.shape)
      # pylint: disable=protected-access
      self.assertEqual(packed._handle_data, reference._handle_data)
      # pylint: enable=protected-access
      # ... is placed on a COMPOSITE device ...
      for device_name in (packed.device, packed.backing_device):
        self.assertIn("COMPOSITE:0", device_name)
      # ... and calling numpy() on it raises InvalidArgumentError.
      with self.assertRaises(errors.InvalidArgumentError):
        packed.numpy()

      # Inputs that disagree in dtype, shape or handle data are rejected.
      mismatched_inputs = (
          [var0.handle, c1],  # Different dtypes
          [c0, c1],  # Different shapes
          [var0.handle, var2.handle],  # Different handle data
      )
      for tensors in mismatched_inputs:
        with self.assertRaises(ValueError):
          ops.pack_eager_tensors(tensors)
if __name__ == "__main__":
  # Run this module's tests via googletest when executed as a script.
  googletest.main()